use super::{Driver, Error as CacheError};
use async_trait::async_trait;
use ensemble::{types::DateTime, Model};
use std::time::Duration;
/// One row of the `cache` table that backs the database cache driver.
#[derive(Debug, Model)]
#[ensemble(table = "cache")]
struct CacheEntry {
    // Cache key; marked as the table's primary key.
    #[model(primary)]
    pub key: String,
    // Raw value bytes, stored exactly as handed to `Driver::put`.
    pub value: Vec<u8>,
    // Instant after which the driver's reads ignore this row; `None` means it never expires.
    pub expiration: Option<DateTime>,
}
/// Cache driver that persists entries in a database table via the `ensemble` ORM.
///
/// The driver carries no state of its own; every operation goes through `ensemble`'s
/// query builder. NOTE(review): presumably `ensemble::setup` must have been called
/// beforehand (the test module does this) — confirm for production call sites.
// The type name repeats its (presumably `database`) module name, which this pedantic lint
// would otherwise flag.
#[allow(clippy::module_name_repetitions)]
#[derive(Debug, Default)]
pub struct DatabaseDriver;
impl DatabaseDriver {
#[must_use]
pub const fn new() -> Self {
Self
}
}
fn map_db_err(e: ensemble::Error) -> CacheError {
CacheError::Other(Box::new(e))
}
#[async_trait]
impl Driver for DatabaseDriver {
    /// Fetches the bytes stored under `key`.
    ///
    /// Returns `Ok(None)` when no row exists for `key` or when its expiration is in the past.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Other`] wrapping the database error if the query fails.
    async fn get(&self, key: &str) -> Result<Option<Vec<u8>>, CacheError> {
        let entry = CacheEntry::query()
            .r#where("key", '=', key)
            // The expiry check is grouped so that it is AND-ed with the key filter as a unit.
            .where_group(|query| {
                query
                    .where_null("expiration")
                    .or_where("expiration", '>', DateTime::now())
            })
            .first::<CacheEntry>()
            .await
            .map_err(map_db_err)?;

        Ok(entry.map(|entry| entry.value))
    }

    /// Reports whether a non-expired entry exists for `key`.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Other`] wrapping the database error if the query fails.
    async fn has(&self, key: &str) -> Result<bool, CacheError> {
        let count = CacheEntry::query()
            .r#where("key", '=', key)
            // Without this group the filter reads
            // `key = ? AND expiration IS NULL OR expiration > ?`. SQL parses that as
            // `(key = ? AND expiration IS NULL) OR expiration > ?`, which matches any
            // unexpired entry whatever its key. The group keeps it consistent with `get`.
            .where_group(|query| {
                query
                    .where_null("expiration")
                    .or_where("expiration", '>', DateTime::now())
            })
            .count()
            .await
            .map_err(map_db_err)?;

        Ok(count != 0)
    }

    /// Stores `value` under `key`, replacing any existing entry.
    ///
    /// When `duration` is `Some`, the entry expires that long after now. When it is `None`,
    /// the entry never expires.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Other`] wrapping the database error if either the delete or the
    /// insert fails.
    async fn put(
        &self,
        key: &str,
        value: Vec<u8>,
        duration: Option<Duration>,
    ) -> Result<(), CacheError> {
        let expiration = duration.map(|duration| DateTime::now() + duration);

        // NOTE(review): delete-then-insert is not atomic. A concurrent `put` for the same key
        // between these two statements could make the insert collide on the primary key.
        // Consider a transaction or upsert if the ORM supports one.
        CacheEntry::query()
            .r#where("key", '=', key)
            .delete()
            .await
            .map_err(map_db_err)?;

        CacheEntry::create(CacheEntry {
            value,
            expiration,
            key: key.to_string(),
        })
        .await
        .map_err(map_db_err)?;

        Ok(())
    }

    /// Removes the entry for `key`, whether or not it has expired.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Other`] wrapping the database error if the delete fails.
    async fn forget(&self, key: &str) -> Result<(), CacheError> {
        CacheEntry::query()
            .r#where("key", '=', key)
            .delete()
            .await
            .map_err(map_db_err)?;

        Ok(())
    }

    /// Deletes every row from the cache table.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Other`] wrapping the database error if the delete fails.
    async fn flush(&self) -> Result<(), CacheError> {
        CacheEntry::query().delete().await.map_err(map_db_err)?;

        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;
    use crate::Cache;

    /// End-to-end round trip through `Cache` backed by `DatabaseDriver`.
    ///
    /// Requires a database URL in the `DATABASE_URL` environment variable; the test panics
    /// if it is unset. NOTE(review): presumably the `cache` table must already exist in that
    /// database — confirm how it is created for test runs.
    #[tokio::test]
    async fn test_database_driver() {
        ensemble::setup(&env::var("DATABASE_URL").expect("DATABASE_URL not set")).unwrap();
        let cache = Cache::new(DatabaseDriver::new());

        // Initially `foo` is expected to be absent: neither readable nor reported present.
        assert_eq!(cache.get::<String>("foo").await.unwrap(), None);
        assert!(!cache.has("foo").await.unwrap());

        // After a put with a 10-second duration, the value round-trips and `has` sees it.
        cache
            .put("foo", &"bar".to_string(), Duration::from_secs(10))
            .await
            .unwrap();
        assert_eq!(cache.get("foo").await.unwrap(), Some("bar".to_string()));
        assert!(cache.has("foo").await.unwrap());

        // Forgetting the key makes it absent again.
        cache.forget("foo").await.unwrap();
        assert_eq!(cache.get::<String>("foo").await.unwrap(), None);
        assert!(!cache.has("foo").await.unwrap());

        // NOTE(review): no assertion follows this short-duration put, so expiry is never
        // actually checked. Presumably a wait-then-assert-absent step is missing. This put
        // also appears to leave `foo` stored after the test finishes.
        cache
            .put("foo", &"bar".to_string(), Duration::from_secs(1))
            .await
            .unwrap();
    }
}