use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use async_trait::async_trait;
use uuid::Uuid;
use crate::store::{RoutingStore, RoutingStoreError};
use crate::{Route, RoutingEngine};
/// How long a per-org routing engine is reused before being rebuilt from the
/// inner store when [`CachingRoutingStore::new`] is used (one minute).
pub const DEFAULT_TTL: Duration = Duration::from_millis(60_000);
/// A cached [`RoutingEngine`] together with the moment it was stored.
#[derive(Debug)]
struct Cached {
    engine: Arc<RoutingEngine>,
    /// When the entry was inserted. Freshness is checked as
    /// `cached_at.elapsed() < ttl` instead of precomputing `now + ttl`, so a
    /// very large TTL (e.g. `Duration::MAX`) cannot overflow `Instant`
    /// arithmetic and panic.
    cached_at: Instant,
}

/// A [`RoutingStore`] decorator that memoizes, per org, a [`RoutingEngine`]
/// built from the inner store's `list_for_org`, reusing it for up to `ttl`.
///
/// Route creations and successful deletions made through this wrapper
/// invalidate the affected org's entry. Changes made directly against the
/// inner store are only picked up once the entry's TTL has elapsed.
#[derive(Debug)]
pub struct CachingRoutingStore {
    inner: Arc<dyn RoutingStore>,
    ttl: Duration,
    cache: tokio::sync::RwLock<HashMap<Uuid, Cached>>,
    /// Incremented (while holding the `cache` write lock) by every
    /// [`CachingRoutingStore::invalidate`]. A fetch that began before an
    /// increment must not be cached: it may have read the inner store before
    /// the write that triggered the invalidation.
    invalidations: std::sync::atomic::AtomicU64,
}

impl CachingRoutingStore {
    /// Wraps `inner`, caching engines for [`DEFAULT_TTL`].
    pub fn new(inner: Arc<dyn RoutingStore>) -> Self {
        Self::with_ttl(inner, DEFAULT_TTL)
    }

    /// Wraps `inner`, reusing each org's engine for `ttl` after it is built.
    ///
    /// A zero `ttl` effectively disables caching: every lookup refetches.
    pub fn with_ttl(inner: Arc<dyn RoutingStore>, ttl: Duration) -> Self {
        Self {
            inner,
            ttl,
            cache: tokio::sync::RwLock::new(HashMap::new()),
            invalidations: std::sync::atomic::AtomicU64::new(0),
        }
    }

    /// Returns the routing engine for `org_id`, rebuilding it from the inner
    /// store when there is no entry or the entry is at least `ttl` old.
    ///
    /// Concurrent misses for the same org may each fetch from the inner store;
    /// the last successful fetch wins.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner store's `list_for_org`; nothing is
    /// cached in that case.
    pub async fn engine_for(&self, org_id: Uuid) -> Result<Arc<RoutingEngine>, RoutingStoreError> {
        {
            let entries = self.cache.read().await;
            if let Some(entry) = entries.get(&org_id) {
                if entry.cached_at.elapsed() < self.ttl {
                    return Ok(Arc::clone(&entry.engine));
                }
            }
        }

        // Snapshot the invalidation counter *before* reading the inner store.
        let seen = self
            .invalidations
            .load(std::sync::atomic::Ordering::SeqCst);
        let routes = self.inner.list_for_org(org_id).await?;
        let engine = Arc::new(RoutingEngine::with_routes(routes));

        let mut entries = self.cache.write().await;
        // If an invalidation ran while we were fetching, `routes` may predate the
        // write that caused it. Caching them would resurrect stale data for a
        // whole TTL, so hand them to this caller only and let the next lookup
        // refetch.
        if self
            .invalidations
            .load(std::sync::atomic::Ordering::SeqCst)
            == seen
        {
            entries.insert(
                org_id,
                Cached {
                    engine: Arc::clone(&engine),
                    cached_at: Instant::now(),
                },
            );
        }
        Ok(engine)
    }

    /// Drops the cached engine for `org_id` so the next lookup refetches it.
    pub async fn invalidate(&self, org_id: Uuid) {
        let mut entries = self.cache.write().await;
        // Bump under the write lock: `engine_for` re-checks the counter under the
        // same lock before inserting, so a racing fetch cannot slip in between.
        self.invalidations
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
        entries.remove(&org_id);
    }
}
/// Serves `list_for_org` from the per-org engine cache and delegates every
/// other operation to the inner store. `create_route` always invalidates the
/// org's cache entry on success; `delete_route` does so only when something
/// was actually removed.
#[async_trait]
impl RoutingStore for CachingRoutingStore {
    /// Cached read: returns a copy of the cached engine's `routes()`.
    ///
    /// NOTE(review): if `RoutingEngine::with_routes` reorders or filters its
    /// input, this can differ from the inner store's `list_for_org` — confirm.
    async fn list_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
        self.engine_for(org_id)
            .await
            .map(|cached_engine| cached_engine.routes().to_vec())
    }

    /// Uncached: always forwarded to the inner store.
    async fn list_all_for_org(&self, org_id: Uuid) -> Result<Vec<Route>, RoutingStoreError> {
        self.inner.list_all_for_org(org_id).await
    }

    /// Uncached: always forwarded to the inner store.
    async fn get_route(&self, org_id: Uuid, id: Uuid) -> Result<Option<Route>, RoutingStoreError> {
        self.inner.get_route(org_id, id).await
    }

    /// Forwards the insert, then invalidates the org's cached engine so the
    /// new route is visible on the next lookup. On error nothing is invalidated.
    async fn create_route(
        &self,
        org_id: Uuid,
        spec: crate::store::NewRoute,
    ) -> Result<Route, RoutingStoreError> {
        let new_route = self.inner.create_route(org_id, spec).await?;
        self.invalidate(org_id).await;
        Ok(new_route)
    }

    /// Forwards the delete; invalidates the org's cached engine only when the
    /// inner store reports that a route was removed.
    async fn delete_route(&self, org_id: Uuid, id: Uuid) -> Result<bool, RoutingStoreError> {
        let was_removed = self.inner.delete_route(org_id, id).await?;
        if !was_removed {
            // The inner store did not change, so the cached engine is still valid.
            return Ok(false);
        }
        self.invalidate(org_id).await;
        Ok(true)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::InMemoryRoutingStore;
    use crate::{RouteAction, RouteConditions};

    /// Builds the action shared by every fixture route.
    fn action(target: &str) -> RouteAction {
        RouteAction {
            target_model: target.into(),
            fallbacks: Vec::new(),
            disable_cache: false,
            max_cost_usd: None,
        }
    }

    /// A fully-formed route fixture (with a fresh id) for seeding the backing store.
    fn route(name: &str, target: &str) -> Route {
        Route {
            id: Uuid::now_v7(),
            name: name.into(),
            priority: 10,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    /// A route spec fixture for exercising `create_route` through the cache.
    fn new_route(name: &str, target: &str) -> crate::store::NewRoute {
        crate::store::NewRoute {
            name: name.into(),
            priority: 10,
            enabled: true,
            when: RouteConditions::default(),
            then: action(target),
        }
    }

    #[tokio::test]
    async fn caches_within_ttl() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        backing.set_routes(org, vec![route("a", "m1")]);
        let cache = CachingRoutingStore::with_ttl(
            backing.clone() as Arc<dyn RoutingStore>,
            Duration::from_secs(60),
        );
        let e1 = cache.engine_for(org).await.unwrap();
        backing.set_routes(org, vec![route("b", "m2"), route("c", "m3")]);
        let e2 = cache.engine_for(org).await.unwrap();
        assert!(Arc::ptr_eq(&e1, &e2));
        assert_eq!(e2.routes().len(), 1);
    }

    #[tokio::test]
    async fn refreshes_after_ttl_expires() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        backing.set_routes(org, vec![route("a", "m1")]);
        let cache = CachingRoutingStore::with_ttl(
            backing.clone() as Arc<dyn RoutingStore>,
            Duration::from_millis(50),
        );
        let e1 = cache.engine_for(org).await.unwrap();
        assert_eq!(e1.routes().len(), 1);
        backing.set_routes(org, vec![route("b", "m2"), route("c", "m3")]);
        // The cache stamps entries with `std::time::Instant`, so tokio's paused
        // clock would not affect it; a real (short) sleep is required.
        tokio::time::sleep(Duration::from_millis(80)).await;
        let e2 = cache.engine_for(org).await.unwrap();
        assert_eq!(e2.routes().len(), 2);
    }

    #[tokio::test]
    async fn invalidate_forces_refresh() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        backing.set_routes(org, vec![route("a", "m1")]);
        let cache = CachingRoutingStore::with_ttl(
            backing.clone() as Arc<dyn RoutingStore>,
            Duration::from_secs(3600),
        );
        let _ = cache.engine_for(org).await.unwrap();
        backing.set_routes(org, vec![route("b", "m2")]);
        cache.invalidate(org).await;
        let e = cache.engine_for(org).await.unwrap();
        assert_eq!(e.routes()[0].name, "b");
    }

    #[tokio::test]
    async fn empty_org_caches_too() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        let cache = CachingRoutingStore::with_ttl(
            backing as Arc<dyn RoutingStore>,
            Duration::from_secs(60),
        );
        let e1 = cache.engine_for(org).await.unwrap();
        assert!(e1.routes().is_empty());
        // An empty result must be cached like any other, not refetched each time.
        let e2 = cache.engine_for(org).await.unwrap();
        assert!(Arc::ptr_eq(&e1, &e2));
    }

    #[tokio::test]
    async fn create_invalidates_so_engine_sees_it() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        let cache = CachingRoutingStore::with_ttl(
            backing as Arc<dyn RoutingStore>,
            Duration::from_secs(3600),
        );
        assert_eq!(cache.engine_for(org).await.unwrap().routes().len(), 0);
        cache.create_route(org, new_route("x", "m")).await.unwrap();
        assert_eq!(cache.engine_for(org).await.unwrap().routes().len(), 1);
    }

    #[tokio::test]
    async fn delete_invalidates_so_engine_forgets_it() {
        let backing = Arc::new(InMemoryRoutingStore::new());
        let org = Uuid::now_v7();
        let cache = CachingRoutingStore::with_ttl(
            backing as Arc<dyn RoutingStore>,
            Duration::from_secs(3600),
        );
        let created = cache.create_route(org, new_route("x", "m")).await.unwrap();
        // Warm the cache with the engine that still contains the route.
        assert_eq!(cache.engine_for(org).await.unwrap().routes().len(), 1);
        assert!(cache.delete_route(org, created.id).await.unwrap());
        assert!(cache.engine_for(org).await.unwrap().routes().is_empty());
    }
}