use std::{
collections::BTreeMap,
sync::{Arc, RwLock},
};
use async_trait::async_trait;
use sqlx::Database;
/// Backend module for PostgreSQL (per its name); compiled only when the `postgres`
/// cargo feature is enabled.
#[cfg(feature = "postgres")]
pub mod postgres;
/// A single record held by a [`PromadRepo`], decoded from a database row via `sqlx::FromRow`.
///
/// Lookups and deletions address a row by its `name`; `CachedPromadRepo` keys and orders
/// its in-memory cache by `ordering_key`.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct PromadRow {
    /// Identifier used by `get` and `delete`.
    pub(crate) name: String,
    /// Integer sort key; also the cache key in `CachedPromadRepo`.
    pub(crate) ordering_key: i64,
    /// Timestamp stored in the `created_at` column, in UTC.
    pub(crate) created_at: chrono::DateTime<chrono::Utc>,
}
/// Storage backend for [`PromadRow`] records, generic over the sqlx [`Database`] in use.
///
/// Every operation borrows a connection (`conn`) from the caller instead of owning one.
/// Implementors must be `Send + Sync`; the async methods are declared via `#[async_trait]`.
#[async_trait]
pub trait PromadRepo<DB: Database>: Send + Sync {
/// Constructs a repository instance.
///
/// The `Self: Sized` bound keeps this method out of the vtable so the trait stays usable
/// as a trait object (`Box<dyn PromadRepo<DB>>`).
fn new() -> Self
where
Self: Sized;
/// Runs the backend's initialisation step on `conn`.
// NOTE(review): presumably creates the backing table if missing — confirm in the backends.
async fn init<'a>(
&self,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<()>;
/// Puts the backend into read-only mode.
// NOTE(review): the exact meaning of "read-only" is backend-specific — confirm in the impls.
async fn set_read_only<'a>(
&self,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<()>;
/// Returns all stored rows.
///
/// This trait does not specify an order; `CachedPromadRepo` serves cached results in
/// `ordering_key` order.
async fn get_all<'a>(
&self,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<Vec<PromadRow>>;
/// Looks up a row by `name`; `Ok(None)` means no such row exists.
async fn get<'a>(
&self,
name: &str,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<Option<PromadRow>>;
/// Persists `row`.
async fn insert<'a>(
&self,
row: &PromadRow,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<()>;
/// Deletes a row.
///
/// Despite its name, `row` carries a row *name*: `CachedPromadRepo` passes the same value
/// it then uses to evict cached entries by `name`.
// NOTE(review): the `&'static str` requirement prevents passing runtime-built strings;
// `&str` would suffice, but changing it requires updating every implementation together.
async fn delete<'a>(
&self,
row: &'static str,
conn: &'a mut <DB as Database>::Connection,
) -> crate::error::Result<()>;
}
/// A [`PromadRepo`] decorator that mirrors rows in an in-process cache keyed by
/// `ordering_key`, delegating persistence to an inner repository built from `N`.
pub struct CachedPromadRepo<DB, N>
where
    DB: Database,
    N: PromadRepo<DB>,
{
    /// The wrapped repository; all writes and all cold-cache reads go through it.
    inner: Box<dyn PromadRepo<DB>>,
    /// Rows observed so far, keyed (and therefore iterated) by `ordering_key`.
    cache: Arc<RwLock<BTreeMap<i64, PromadRow>>>,
    /// `true` once `get_all` has populated the cache from `inner`; from then on reads are
    /// answered from the cache alone.
    is_db_loaded: Arc<RwLock<bool>>,
    /// Records the concrete inner type `N`, which `inner` erases.
    _marker: std::marker::PhantomData<N>,
}
#[async_trait]
impl<DB: Database, N: PromadRepo<DB> + 'static> PromadRepo<DB> for CachedPromadRepo<DB, N> {
    /// Wraps a freshly constructed `N`, starting with an empty cache that is not yet
    /// marked as loaded.
    fn new() -> Self {
        Self {
            inner: Box::new(N::new()),
            cache: Arc::new(RwLock::new(BTreeMap::new())),
            is_db_loaded: Arc::new(RwLock::new(false)),
            _marker: Default::default(),
        }
    }

    /// Delegates to the inner repository; the cache is left untouched.
    async fn init<'a>(
        &self,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<()> {
        self.inner.init(conn).await
    }

    /// Delegates to the inner repository; the cache is left untouched.
    async fn set_read_only<'a>(
        &self,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<()> {
        self.inner.set_read_only(conn).await
    }

    /// Returns every cached row, ordered by `ordering_key`.
    ///
    /// If the cache is already fully loaded, `conn` is not used. Otherwise all rows are
    /// fetched from the inner repository, merged into the cache, and the cache is marked as
    /// loaded. The result is always a snapshot of the cache, so cold and warm calls return
    /// rows in the same (`ordering_key`) order.
    ///
    /// # Errors
    /// Propagates errors from the inner repository and from poisoned locks.
    async fn get_all<'a>(
        &self,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<Vec<PromadRow>> {
        {
            // Scoped so the std lock guards are dropped before the `.await` below.
            let is_db_loaded = self.is_db_loaded.read()?;
            if *is_db_loaded {
                let cache = self.cache.read()?;
                return Ok(cache.values().cloned().collect());
            }
        }
        let rows = self.inner.get_all(conn).await?;
        let snapshot: Vec<PromadRow> = {
            let mut cache = self.cache.write()?;
            // Move the fetched rows into the cache rather than cloning each one.
            cache.extend(rows.into_iter().map(|row| (row.ordering_key, row)));
            cache.values().cloned().collect()
        };
        // The cache guard is released before the flag is locked: the fast path above takes
        // the flag first and the cache second, so holding both here could deadlock.
        *self.is_db_loaded.write()? = true;
        Ok(snapshot)
    }

    /// Looks up the row called `name`.
    ///
    /// With a fully loaded cache the answer comes from the cache alone, and a miss yields
    /// `Ok(None)` without querying. Otherwise the inner repository is queried and a found
    /// row is added to the cache.
    ///
    /// # Errors
    /// Propagates errors from the inner repository and from poisoned locks.
    async fn get<'a>(
        &self,
        name: &str,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<Option<PromadRow>> {
        {
            // Scoped so the std lock guards are dropped before the `.await` below.
            let is_db_loaded = self.is_db_loaded.read()?;
            if *is_db_loaded {
                let cache = self.cache.read()?;
                return Ok(cache.values().find(|row| row.name == name).cloned());
            }
        }
        let row = self.inner.get(name, conn).await?;
        if let Some(found) = &row {
            self.cache.write()?.insert(found.ordering_key, found.clone());
        }
        Ok(row)
    }

    /// Persists `row` through the inner repository, then caches it under its
    /// `ordering_key`. The cache is updated only if the write succeeded.
    ///
    /// # Errors
    /// Propagates errors from the inner repository and from a poisoned cache lock.
    async fn insert<'a>(
        &self,
        row: &PromadRow,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<()> {
        self.inner.insert(row, conn).await?;
        self.cache.write()?.insert(row.ordering_key, row.clone());
        Ok(())
    }

    /// Deletes the row called `name` through the inner repository, then evicts every
    /// cached row with that name. The cache is updated only if the delete succeeded.
    ///
    /// # Errors
    /// Propagates errors from the inner repository and from a poisoned cache lock.
    async fn delete<'a>(
        &self,
        name: &'static str,
        conn: &'a mut <DB as Database>::Connection,
    ) -> crate::error::Result<()> {
        self.inner.delete(name, conn).await?;
        self.cache.write()?.retain(|_, row| row.name != name);
        Ok(())
    }
}