use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use dashmap::DashMap;
use crate::connector::CacheConnectorConfig;
use crate::errors::OrionError;
/// Asynchronous string key/value cache interface implemented by every backend in this file.
///
/// All operations report backend failures as [`OrionError`].
#[async_trait]
pub trait CacheBackend: Send + Sync {
/// Returns the value stored under `key`, or `Ok(None)` when no value is available.
async fn get(&self, key: &str) -> Result<Option<String>, OrionError>;
/// Stores `value` under `key` without an expiry.
async fn set(&self, key: &str, value: &str) -> Result<(), OrionError>;
/// Stores `value` under `key` with a time-to-live of `ttl_secs` seconds.
async fn set_ex(&self, key: &str, value: &str, ttl_secs: u64) -> Result<(), OrionError>;
/// Records `key` for `window_secs` seconds unless it is already present.
///
/// In the implementations in this file, `Ok(true)` means this call recorded the key and
/// `Ok(false)` means a live entry for the key already existed.
async fn check_and_insert(&self, key: &str, window_secs: u64) -> Result<bool, OrionError>;
}
/// A single cached value together with its optional expiry deadline.
struct MemoryEntry {
/// The cached string.
value: String,
/// Instant from which the entry counts as expired; `None` means it never expires.
expires_at: Option<Instant>,
}
/// In-process [`CacheBackend`] that keeps its entries in a `DashMap`.
pub struct MemoryCacheBackend {
/// Cached entries indexed by cache key.
entries: DashMap<String, MemoryEntry>,
}
impl MemoryCacheBackend {
    /// Creates an empty in-memory cache and starts its periodic cleanup task.
    ///
    /// The task wakes every `cleanup_interval_secs` seconds (clamped to at least one
    /// second) and removes expired entries. It only keeps a weak reference to the cache,
    /// so it exits on the first wake-up after the last `Arc` has been dropped.
    ///
    /// When no Tokio runtime is active on the calling thread, no cleanup task is started
    /// (instead of panicking); expired entries are then only reclaimed when their key is
    /// accessed again.
    pub fn new(cleanup_interval_secs: u64) -> Arc<Self> {
        let store = Arc::new(Self {
            entries: DashMap::new(),
        });

        // `tokio::spawn` panics without a runtime context. This constructor is synchronous,
        // so look the runtime up explicitly and degrade gracefully if there is none.
        if let Ok(runtime) = tokio::runtime::Handle::try_current() {
            let weak = Arc::downgrade(&store);
            let interval = Duration::from_secs(cleanup_interval_secs.max(1));
            runtime.spawn(async move {
                loop {
                    tokio::time::sleep(interval).await;
                    // The strong reference lives only for one purge pass, so the task
                    // never keeps the cache alive on its own.
                    let Some(store) = weak.upgrade() else {
                        break;
                    };
                    store.purge_expired();
                }
            });
        }

        store
    }

    /// Removes every entry whose deadline is at or before the current instant.
    ///
    /// Entries without a deadline are always kept.
    fn purge_expired(&self) {
        let now = Instant::now();
        self.entries
            .retain(|_, entry| entry.expires_at.is_none_or(|exp| exp > now));
    }
}
#[async_trait]
impl CacheBackend for MemoryCacheBackend {
    /// Returns the live value for `key`.
    ///
    /// An entry whose deadline has passed is evicted and reported as `Ok(None)`.
    async fn get(&self, key: &str) -> Result<Option<String>, OrionError> {
        let now = Instant::now();
        {
            let Some(entry) = self.entries.get(key) else {
                return Ok(None);
            };
            if entry.expires_at.is_none_or(|exp| exp > now) {
                return Ok(Some(entry.value.clone()));
            }
            // The read guard ends with this scope; it has to be released before the map
            // is mutated below.
        }
        // Re-check the deadline while removing: another task may have stored a fresh
        // value between releasing the guard and this call, and that value must survive.
        self.entries
            .remove_if(key, |_, entry| entry.expires_at.is_some_and(|exp| exp <= now));
        Ok(None)
    }

    /// Stores `value` under `key` without a deadline, replacing any previous entry.
    async fn set(&self, key: &str, value: &str) -> Result<(), OrionError> {
        self.entries.insert(
            key.to_string(),
            MemoryEntry {
                value: value.to_string(),
                expires_at: None,
            },
        );
        Ok(())
    }

    /// Stores `value` under `key` so that it expires `ttl_secs` seconds from now.
    ///
    /// A TTL so large that the deadline cannot be represented as an `Instant` is stored
    /// without a deadline instead of panicking on overflow.
    async fn set_ex(&self, key: &str, value: &str, ttl_secs: u64) -> Result<(), OrionError> {
        self.entries.insert(
            key.to_string(),
            MemoryEntry {
                value: value.to_string(),
                // `checked_add` yields `None` on overflow, which is exactly "never expires".
                expires_at: Instant::now().checked_add(Duration::from_secs(ttl_secs)),
            },
        );
        Ok(())
    }

    /// Records `key` for `window_secs` seconds unless a live entry already exists.
    ///
    /// Returns `Ok(true)` when the key was absent or expired and has now been recorded,
    /// and `Ok(false)` when a live entry (including one without a deadline) is present.
    /// The check and the insert happen under a single map entry lock.
    async fn check_and_insert(&self, key: &str, window_secs: u64) -> Result<bool, OrionError> {
        use dashmap::mapref::entry::Entry;

        let now = Instant::now();
        // Overflowing windows become "no deadline" rather than a panic.
        let expires_at = now.checked_add(Duration::from_secs(window_secs));
        let marker = || MemoryEntry {
            value: "1".to_string(),
            expires_at,
        };

        match self.entries.entry(key.to_string()) {
            Entry::Vacant(slot) => {
                slot.insert(marker());
                Ok(true)
            }
            Entry::Occupied(mut slot) => {
                let expired = slot.get().expires_at.is_some_and(|exp| now >= exp);
                if expired {
                    // The previous window is over; start a new one in place.
                    slot.insert(marker());
                }
                Ok(expired)
            }
        }
    }
}
/// [`CacheBackend`] that forwards every operation to Redis.
pub struct RedisCacheBackend {
/// Multiplexed connection handle; each operation works on its own clone of it.
conn: redis::aio::MultiplexedConnection,
}
impl RedisCacheBackend {
pub fn new(conn: redis::aio::MultiplexedConnection) -> Self {
Self { conn }
}
}
/// Wraps a Redis client error in [`OrionError::InternalSource`] with the given context.
fn redis_op_error(context: String, source: redis::RedisError) -> OrionError {
    OrionError::InternalSource {
        context,
        source: Box::new(source),
    }
}

#[async_trait]
impl CacheBackend for RedisCacheBackend {
    /// Issues a Redis `GET` for `key`.
    async fn get(&self, key: &str) -> Result<Option<String>, OrionError> {
        use redis::AsyncCommands;

        let mut conn = self.conn.clone();
        let reply: redis::RedisResult<Option<String>> = conn.get(key).await;
        reply.map_err(|err| redis_op_error(format!("Redis GET failed for key '{key}'"), err))
    }

    /// Issues a Redis `SET` storing `value` under `key`.
    async fn set(&self, key: &str, value: &str) -> Result<(), OrionError> {
        use redis::AsyncCommands;

        let mut conn = self.conn.clone();
        let reply: redis::RedisResult<()> = conn.set(key, value).await;
        reply.map_err(|err| redis_op_error(format!("Redis SET failed for key '{key}'"), err))
    }

    /// Stores `value` under `key` with a time-to-live of `ttl_secs` seconds.
    async fn set_ex(&self, key: &str, value: &str, ttl_secs: u64) -> Result<(), OrionError> {
        use redis::AsyncCommands;

        let mut conn = self.conn.clone();
        let reply: redis::RedisResult<()> = conn.set_ex(key, value, ttl_secs).await;
        reply.map_err(|err| redis_op_error(format!("Redis SETEX failed for key '{key}'"), err))
    }

    /// Sends `SET key 1 NX EX window_secs` and reports whether Redis returned a value.
    ///
    /// `Ok(true)` means the reply was non-nil, `Ok(false)` means it was nil.
    async fn check_and_insert(&self, key: &str, window_secs: u64) -> Result<bool, OrionError> {
        let mut conn = self.conn.clone();

        let mut command = redis::cmd("SET");
        command
            .arg(key)
            .arg("1")
            .arg("NX")
            .arg("EX")
            .arg(window_secs);

        let reply: Option<String> = command.query_async(&mut conn).await.map_err(|err| {
            redis_op_error(format!("Redis SET NX EX failed for key '{key}'"), err)
        })?;
        Ok(reply.is_some())
    }
}
/// Hands out [`CacheBackend`] instances: one shared in-memory backend plus Redis-backed ones.
pub struct CachePool {
/// The in-memory backend returned for the `"memory"` backend kind.
memory: Arc<MemoryCacheBackend>,
/// Provider of Redis connections, queried by connector name and config.
redis: Arc<super::redis_pool::RedisPoolCache>,
}
impl CachePool {
pub fn new(max_redis_pool_entries: usize, cleanup_interval_secs: u64) -> Self {
Self {
memory: MemoryCacheBackend::new(cleanup_interval_secs),
redis: Arc::new(super::redis_pool::RedisPoolCache::new(
max_redis_pool_entries,
)),
}
}
pub async fn get_backend(
&self,
connector_name: &str,
config: &CacheConnectorConfig,
) -> Result<Arc<dyn CacheBackend>, OrionError> {
match config.backend.as_str() {
"memory" => Ok(self.memory.clone() as Arc<dyn CacheBackend>),
"redis" => {
let conn = self.redis.get_conn(connector_name, config).await?;
Ok(Arc::new(RedisCacheBackend::new(conn)))
}
other => Err(OrionError::BadRequest(format!(
"Unknown cache backend '{other}'. Must be 'redis' or 'memory'"
))),
}
}
pub fn memory(&self) -> Arc<dyn CacheBackend> {
self.memory.clone() as Arc<dyn CacheBackend>
}
pub async fn evict_pool(&self, connector_name: &str) {
self.redis.evict(connector_name).await;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A missing key reads as `None`; after `set` the stored value is returned.
    #[tokio::test]
    async fn test_memory_get_set() {
        let cache = MemoryCacheBackend::new(60);

        let before = cache.get("k1").await.expect("test");
        assert!(before.is_none());

        cache.set("k1", "v1").await.expect("test");
        let after = cache.get("k1").await.expect("test");
        assert_eq!(after, Some("v1".to_string()));
    }

    /// A value stored with a one-second TTL is readable at first and gone two seconds later.
    #[tokio::test]
    async fn test_memory_set_ex_expires() {
        let cache = MemoryCacheBackend::new(60);
        cache.set_ex("k1", "v1", 1).await.expect("test");

        let fresh = cache.get("k1").await.expect("test");
        assert_eq!(fresh, Some("v1".to_string()));

        tokio::time::sleep(Duration::from_secs(2)).await;
        let stale = cache.get("k1").await.expect("test");
        assert!(stale.is_none());
    }

    /// The first `check_and_insert` for a key reports it as new.
    #[tokio::test]
    async fn test_memory_check_and_insert_new() {
        let cache = MemoryCacheBackend::new(60);

        let inserted = cache.check_and_insert("dedup-1", 300).await.expect("test");
        assert!(inserted);
    }

    /// A second `check_and_insert` inside the window reports a duplicate.
    #[tokio::test]
    async fn test_memory_check_and_insert_duplicate() {
        let cache = MemoryCacheBackend::new(60);

        let first = cache.check_and_insert("dedup-1", 300).await.expect("test");
        assert!(first);

        let second = cache.check_and_insert("dedup-1", 300).await.expect("test");
        assert!(!second);
    }

    /// Once the window has elapsed the same key is accepted again.
    #[tokio::test]
    async fn test_memory_check_and_insert_expired() {
        let cache = MemoryCacheBackend::new(60);

        let first = cache.check_and_insert("k", 1).await.expect("test");
        assert!(first);

        tokio::time::sleep(Duration::from_secs(2)).await;
        let again = cache.check_and_insert("k", 1).await.expect("test");
        assert!(again);
    }

    /// `purge_expired` drops only the entry whose TTL has elapsed.
    #[tokio::test]
    async fn test_memory_purge_expired() {
        let cache = MemoryCacheBackend::new(60);
        cache.set_ex("keep", "val", 3600).await.expect("test");
        cache.set_ex("expire", "val", 1).await.expect("test");

        tokio::time::sleep(Duration::from_secs(2)).await;
        cache.purge_expired();

        let kept = cache.get("keep").await.expect("test");
        assert!(kept.is_some());
        let purged = cache.get("expire").await.expect("test");
        assert!(purged.is_none());
    }

    /// A later `set` replaces the value written by an earlier one.
    #[tokio::test]
    async fn test_memory_set_overwrites() {
        let cache = MemoryCacheBackend::new(60);
        cache.set("k", "v1").await.expect("test");
        cache.set("k", "v2").await.expect("test");

        let current = cache.get("k").await.expect("test");
        assert_eq!(current, Some("v2".to_string()));
    }
}