use std::{marker::PhantomData, ops::DerefMut, sync::Arc};
use anyhow::{Context, Result};
use async_trait::async_trait;
use sqlx::SqliteConnection;
use tokio::{
select,
sync::{watch, Mutex},
};
use tokio_util::sync::CancellationToken;
use super::sqlite::SqliteStore;
/// An object that can be persisted in SQLite and passed through a [`channel`].
///
/// Every stored object is identified by an `i32` ID. [`Sender::send`] publishes the ID
/// returned by [`Storable::store`], and [`Receiver`] compares published IDs against the
/// ID of the last object it received to decide whether something new is available.
#[async_trait]
pub trait Storable: Sized {
    /// Returns this object's storage ID.
    ///
    /// [`Receiver::recv`] records it as the last received ID, and [`Receiver::ack`] passes
    /// it to [`Storable::remove`].
    fn id(&self) -> i32;
    /// Persists `self` using `conn` and returns the ID it was stored under.
    ///
    /// The returned ID is what [`Sender::send`] announces to receivers.
    async fn store(&self, conn: &mut SqliteConnection) -> Result<i32>;
    /// Loads the next object to deliver, or `Ok(None)` if there is none.
    ///
    /// [`Receiver::recv`] passes the ID of the previously received object as
    /// `minimum_id`, or `i32::MIN` if nothing has been received yet.
    /// NOTE(review): this suggests `minimum_id` should be treated as exclusive
    /// (otherwise an un-acked object would be returned again) — confirm with implementors.
    async fn load(conn: &mut SqliteConnection, minimum_id: i32) -> Result<Option<Self>>;
    /// Deletes the stored object with the given ID (called by [`Receiver::ack`]).
    async fn remove(conn: &mut SqliteConnection, id: i32) -> Result<()>;
    /// Returns how many objects of this type are stored
    /// (exposed through [`Sender::count`] and [`Receiver::count`]).
    async fn count(conn: &mut SqliteConnection) -> Result<usize>;
}
/// Creates a SQLite-backed channel for objects of type `T`.
///
/// Objects given to [`Sender::send`] are written into `store`. The [`Receiver`] is woken
/// through an in-memory [`watch`] channel that carries the highest ID published by the
/// senders, and it then loads the object from its own handle on `store`. Each half holds
/// its own clone of `store`.
pub fn channel<T: Storable>(store: SqliteStore) -> (Sender<T>, Receiver<T>) {
    // `None` means no object has been published yet in this run.
    let (watch_tx, watch_rx) = watch::channel(None);
    let sender = Sender {
        store: store.clone(),
        last_saved: Arc::new(Mutex::new(watch_tx)),
        phantom: PhantomData,
    };
    // Move the original handle into the receiver; only one clone is needed.
    let receiver = Receiver {
        store,
        last_saved: watch_rx,
        last_received: None,
        phantom: PhantomData,
    };
    (sender, receiver)
}
/// Producing half of a SQLite-backed channel created by [`channel`].
///
/// Cloning a `Sender` yields another handle on the same store and the same shared
/// notification state, so every clone feeds the same [`Receiver`].
#[derive(Debug)]
pub struct Sender<T> {
    /// Store that [`Sender::send`] persists objects into.
    store: SqliteStore,
    /// Highest ID announced to the receiver. It is shared by every clone, and the mutex
    /// serialises the updates made in [`Sender::send`].
    last_saved: Arc<Mutex<watch::Sender<Option<i32>>>>,
    phantom: PhantomData<T>,
}

// Implemented by hand because `#[derive(Clone)]` would add a `T: Clone` bound even though no
// `T` value is stored. Cloning only duplicates the store handle and the shared `Arc`.
impl<T> Clone for Sender<T> {
    fn clone(&self) -> Self {
        Self {
            store: self.store.clone(),
            last_saved: Arc::clone(&self.last_saved),
            phantom: PhantomData,
        }
    }
}
/// Consuming half of a SQLite-backed channel created by [`channel`].
#[derive(Debug)]
pub struct Receiver<T> {
    /// Store that objects are loaded from (in `recv`) and removed from (in `ack`).
    store: SqliteStore,
    /// Highest ID announced by the sender side; `None` until something has been sent.
    last_saved: watch::Receiver<Option<i32>>,
    /// ID of the last object returned by `recv`. It is passed as `minimum_id` to
    /// `Storable::load`, or `i32::MIN` is passed when this is `None`.
    last_received: Option<i32>,
    phantom: PhantomData<T>,
}
impl<T: Storable> Sender<T> {
    /// Persists `obj` and notifies the receiver that a new object is available.
    ///
    /// The object is written with [`Storable::store`]. The ID it returns is announced only
    /// if it is greater than the highest ID announced so far, so the announced value never
    /// moves backwards even when concurrent senders finish out of order.
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Storable::store`]; nothing is announced in that case.
    pub async fn send(&self, obj: &T) -> Result<()> {
        let mut conn = self.store.connection().await;
        let id = obj.store(conn.deref_mut()).await?;
        let last_saved = self.last_saved.lock().await;
        // Check before writing rather than writing and reverting. Receivers then never
        // observe a regressed high-water mark, and they are only woken when it actually
        // advances. `None < Some(_)`, so the first ID is always published.
        last_saved.send_if_modified(|current| {
            if *current < Some(id) {
                *current = Some(id);
                true
            } else {
                false
            }
        });
        Ok(())
    }

    /// Returns the number of stored objects of type `T`, as reported by [`Storable::count`].
    ///
    /// # Errors
    ///
    /// Returns any error produced by [`Storable::count`].
    pub async fn count(&self) -> Result<usize> {
        let mut conn = self.store.connection().await;
        T::count(conn.deref_mut()).await
    }
}
impl<T: Storable> Receiver<T> {
    /// Waits until an object newer than the last received one is announced, then loads it.
    ///
    /// The object is fetched with [`Storable::load`], and its ID becomes the new "last
    /// received" marker. `recv` does not remove the object from the store; call
    /// [`Receiver::ack`] for that.
    ///
    /// # Errors
    ///
    /// Fails if `cancellation` fires while waiting, if the announcing side is gone, if
    /// loading fails, or if no object can be loaded even though one was announced.
    pub async fn recv(&mut self, cancellation: &Option<CancellationToken>) -> Result<T> {
        let announced = self.wait_new(cancellation).await?;
        let mut conn = self.store.connection().await;
        let after = self.last_received.unwrap_or(i32::MIN);
        let loaded = T::load(&mut *conn, after).await?;
        let obj = match loaded {
            Some(found) => found,
            None => {
                return Err(anyhow::anyhow!(
                    "Unable to retrieve object with ID {:?} that should have already been stored.",
                    announced
                ))
            }
        };
        self.last_received = Some(obj.id());
        Ok(obj)
    }

    /// Returns the announced ID once it exceeds `last_received`, waiting for changes as needed.
    async fn wait_new(&mut self, cancellation: &Option<CancellationToken>) -> Result<i32> {
        loop {
            let newest = *self.last_saved.borrow_and_update();
            // `Option` ordering puts `None` below every `Some`, so this also covers the
            // "nothing received yet" and "nothing sent yet" cases.
            if newest > self.last_received {
                return Ok(newest.expect("Last inserted cannot be None."));
            }
            let changed = self.last_saved.changed();
            let outcome = match cancellation {
                Some(token) => select! {
                    outcome = changed => outcome,
                    _ = token.cancelled() => anyhow::bail!("Task cancelled."),
                },
                None => changed.await,
            };
            outcome.context("No more messages will be received in this run")?;
        }
    }

    /// Removes `obj` from the store via [`Storable::remove`], using its [`Storable::id`].
    pub async fn ack(&self, obj: &T) -> Result<()> {
        let mut conn = self.store.connection().await;
        T::remove(&mut *conn, obj.id()).await
    }

    /// Returns the number of stored objects of type `T`, as reported by [`Storable::count`].
    pub async fn count(&self) -> Result<usize> {
        let mut conn = self.store.connection().await;
        T::count(&mut *conn).await
    }
}