use std::time::{Duration, SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use super::{Cache, CacheError};
use crate::core::SqlValue;
use crate::sql::{raw_execute_pool, raw_query_pool, Pool};
/// A [`Cache`] implementation that stores each entry as one row of a SQL table.
///
/// Each row holds the key (`cache_key`), the cached string (`value`), and an expiry
/// timestamp (`expires`, Unix milliseconds, `0` meaning "never expires").
#[derive(Clone)]
pub struct DatabaseCache {
    /// Pool every cache statement is executed against; its dialect drives SQL generation.
    pool: Pool,
    /// Backing table name, stored unquoted and quoted with the pool's dialect when SQL is built.
    table: String,
}
impl DatabaseCache {
    /// Creates a cache that stores its entries in `table`, using connections from `pool`.
    ///
    /// No SQL is executed here; call [`DatabaseCache::ensure_table`] to create the table.
    #[must_use]
    pub fn new(pool: Pool, table: impl Into<String>) -> Self {
        Self {
            pool,
            table: table.into(),
        }
    }

    /// Returns the unquoted name of the backing table.
    #[must_use]
    pub fn table(&self) -> &str {
        &self.table
    }

    /// Creates the backing table if it does not already exist (`CREATE TABLE IF NOT EXISTS`).
    ///
    /// Column types are picked per dialect name: `"postgres"`, `"mysql"`, and a fallback used
    /// for every other dialect. `expires` stores a Unix timestamp in milliseconds, with `0`
    /// meaning the entry never expires.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    pub async fn ensure_table(&self) -> Result<(), CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let sql = match dialect.name() {
            "postgres" => format!(
                "CREATE TABLE IF NOT EXISTS {table} (\
                 cache_key TEXT PRIMARY KEY, \
                 value TEXT NOT NULL, \
                 expires BIGINT NOT NULL DEFAULT 0\
                 )"
            ),
            // MySQL cannot use an unbounded TEXT column as a primary key, hence VARCHAR(255).
            "mysql" => format!(
                "CREATE TABLE IF NOT EXISTS {table} (\
                 cache_key VARCHAR(255) PRIMARY KEY, \
                 value LONGTEXT NOT NULL, \
                 expires BIGINT NOT NULL DEFAULT 0\
                 )"
            ),
            _ => format!(
                "CREATE TABLE IF NOT EXISTS {table} (\
                 cache_key TEXT PRIMARY KEY, \
                 value TEXT NOT NULL, \
                 expires INTEGER NOT NULL DEFAULT 0\
                 )"
            ),
        };
        raw_execute_pool(&self.pool, &sql, vec![])
            .await
            .map_err(|e| CacheError::Connection(format!("ensure_table: {e}")))?;
        Ok(())
    }

    /// Drops the backing table if it exists, discarding all cached entries.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    pub async fn drop_table(&self) -> Result<(), CacheError> {
        let table = self.pool.dialect().quote_ident(&self.table);
        let sql = format!("DROP TABLE IF EXISTS {table}");
        raw_execute_pool(&self.pool, &sql, vec![])
            .await
            .map_err(|e| CacheError::Connection(format!("drop_table: {e}")))?;
        Ok(())
    }

    /// Deletes every entry whose TTL has elapsed; entries with `expires = 0` are kept.
    ///
    /// Returns the count reported by `raw_execute_pool` for the `DELETE` (presumably the
    /// number of affected rows — confirm against the driver).
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    pub async fn purge_expired(&self) -> Result<u64, CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let p1 = dialect.placeholder(1);
        // `<=` (not `<`) so purging agrees with `get`, which already treats an entry as
        // expired once `now >= expires`.
        let sql = format!("DELETE FROM {table} WHERE expires != 0 AND expires <= {p1}");
        let now = Self::now_unix_ms();
        raw_execute_pool(&self.pool, &sql, vec![SqlValue::I64(now)])
            .await
            .map_err(|e| CacheError::Connection(format!("purge_expired: {e}")))
    }

    /// Current wall-clock time in milliseconds since the Unix epoch.
    ///
    /// Returns `0` if the system clock reads earlier than the epoch, and saturates at
    /// `i64::MAX` if the millisecond count does not fit in an `i64`.
    fn now_unix_ms() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|d| i64::try_from(d.as_millis()).unwrap_or(i64::MAX))
            .unwrap_or(0)
    }

    /// Converts an optional TTL into the value stored in the `expires` column.
    ///
    /// `None` maps to `0` ("never expires"). `Some(ttl)` maps to now plus `ttl`, in Unix
    /// milliseconds, saturating at `i64::MAX` instead of overflowing.
    fn expires_for(ttl: Option<Duration>) -> i64 {
        ttl.map(|d| {
            let ms = i64::try_from(d.as_millis()).unwrap_or(i64::MAX);
            Self::now_unix_ms().saturating_add(ms)
        })
        .unwrap_or(0)
    }
}
#[async_trait]
impl Cache for DatabaseCache {
    /// Fetches the value stored under `key`, or `None` if it is absent or expired.
    ///
    /// An expired row is evicted on a best-effort basis; a failure of that eviction is
    /// ignored and the call still reports a miss.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the lookup query fails.
    async fn get(&self, key: &str) -> Result<Option<String>, CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let p1 = dialect.placeholder(1);
        // Resolved up front so `dialect` is not needed after the first suspension point.
        let p2 = dialect.placeholder(2);
        let sql = format!("SELECT value, expires FROM {table} WHERE cache_key = {p1} LIMIT 1");
        let rows: Vec<(String, i64)> =
            raw_query_pool(&sql, vec![SqlValue::String(key.to_owned())], &self.pool)
                .await
                .map_err(|e| CacheError::Connection(format!("get: {e}")))?;
        let Some((value, expires)) = rows.into_iter().next() else {
            return Ok(None);
        };
        let now = Self::now_unix_ms();
        if expires != 0 && now >= expires {
            // Lazy eviction. The expiry condition is re-checked inside the DELETE so that a
            // concurrent `set` which refreshed this key after our SELECT is not wiped out
            // (an unconditional delete-by-key would destroy the fresh entry).
            // Best-effort: an eviction failure must not turn a cache miss into an error.
            let evict = format!(
                "DELETE FROM {table} WHERE cache_key = {p1} AND expires != 0 AND expires <= {p2}"
            );
            let _ = raw_execute_pool(
                &self.pool,
                &evict,
                vec![SqlValue::String(key.to_owned()), SqlValue::I64(now)],
            )
            .await;
            return Ok(None);
        }
        Ok(Some(value))
    }

    /// Inserts `value` under `key`, overwriting any existing entry (an upsert).
    ///
    /// `ttl = None` stores `expires = 0`, i.e. the entry never expires.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    async fn set(&self, key: &str, value: &str, ttl: Option<Duration>) -> Result<(), CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let p1 = dialect.placeholder(1);
        let p2 = dialect.placeholder(2);
        let p3 = dialect.placeholder(3);
        let expires = Self::expires_for(ttl);
        let sql = match dialect.name() {
            // NOTE: `VALUES(col)` in ON DUPLICATE KEY UPDATE is deprecated since MySQL 8.0.20
            // but still accepted (and is what MariaDB supports).
            "mysql" => format!(
                "INSERT INTO {table} (cache_key, value, expires) \
                 VALUES ({p1}, {p2}, {p3}) \
                 ON DUPLICATE KEY UPDATE value = VALUES(value), expires = VALUES(expires)"
            ),
            _ => format!(
                "INSERT INTO {table} (cache_key, value, expires) \
                 VALUES ({p1}, {p2}, {p3}) \
                 ON CONFLICT (cache_key) DO UPDATE SET value = EXCLUDED.value, expires = EXCLUDED.expires"
            ),
        };
        raw_execute_pool(
            &self.pool,
            &sql,
            vec![
                SqlValue::String(key.to_owned()),
                SqlValue::String(value.to_owned()),
                SqlValue::I64(expires),
            ],
        )
        .await
        .map_err(|e| CacheError::Connection(format!("set: {e}")))?;
        Ok(())
    }

    /// Removes the entry stored under `key`.
    ///
    /// The statement's result count is discarded, so deleting a missing key is not reported.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    async fn delete(&self, key: &str) -> Result<(), CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let p1 = dialect.placeholder(1);
        let sql = format!("DELETE FROM {table} WHERE cache_key = {p1}");
        raw_execute_pool(&self.pool, &sql, vec![SqlValue::String(key.to_owned())])
            .await
            .map_err(|e| CacheError::Connection(format!("delete: {e}")))?;
        Ok(())
    }

    /// Returns whether `key` currently holds an unexpired entry.
    ///
    /// Implemented on top of [`Cache::get`], so it fetches the value and may evict an
    /// expired row as a side effect.
    ///
    /// # Errors
    ///
    /// Propagates any error from `get`.
    async fn exists(&self, key: &str) -> Result<bool, CacheError> {
        Ok(self.get(key).await?.is_some())
    }

    /// Deletes every row from the backing table, expired or not.
    ///
    /// # Errors
    ///
    /// Returns [`CacheError::Connection`] if the statement fails.
    async fn clear(&self) -> Result<(), CacheError> {
        let dialect = self.pool.dialect();
        let table = dialect.quote_ident(&self.table);
        let sql = format!("DELETE FROM {table}");
        raw_execute_pool(&self.pool, &sql, vec![])
            .await
            .map_err(|e| CacheError::Connection(format!("clear: {e}")))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// No TTL is encoded as the "never expires" sentinel `0`.
    #[test]
    fn expires_for_zero_when_no_ttl() {
        assert_eq!(DatabaseCache::expires_for(None), 0);
    }

    /// A TTL is added to the current wall-clock time in milliseconds.
    #[test]
    fn expires_for_offsets_from_now() {
        let before = DatabaseCache::now_unix_ms();
        let ts = DatabaseCache::expires_for(Some(Duration::from_secs(60)));
        let after = DatabaseCache::now_unix_ms();
        assert!(
            (before + 60_000..=after + 60_000).contains(&ts),
            "expected expires in [{before}+60_000, {after}+60_000], got {ts}"
        );
    }

    /// A zero TTL means "due now", not the `0` never-expires sentinel.
    #[test]
    fn expires_for_zero_ttl_is_due_immediately() {
        let before = DatabaseCache::now_unix_ms();
        let ts = DatabaseCache::expires_for(Some(Duration::ZERO));
        let after = DatabaseCache::now_unix_ms();
        assert!(
            (before..=after).contains(&ts),
            "expected expires in [{before}, {after}], got {ts}"
        );
    }

    /// A TTL too large for `i64` milliseconds saturates instead of overflowing.
    #[test]
    fn expires_for_saturates_on_huge_ttl() {
        assert_eq!(DatabaseCache::expires_for(Some(Duration::MAX)), i64::MAX);
    }
}