use std::collections::HashMap;
use std::collections::VecDeque;
use std::sync::Arc;
use std::time::{Duration, Instant};
use serde_json::Value;
use tokio::sync::Mutex;
use crate::state::State;
type MemoryStore = Arc<Mutex<(LruCache, Option<Duration>)>>;
/// Storage backend a node cache writes its results to.
#[derive(Debug, Clone)]
pub enum CacheBackend {
    /// In-process LRU cache; `max_entries` sizes how many results it keeps.
    InMemory { max_entries: usize },
    /// Redis server reachable at `url` (only with the `redis-cache` feature).
    #[cfg(feature = "redis-cache")]
    Redis { url: String },
}

impl Default for CacheBackend {
    /// An in-memory LRU cache sized for 128 entries.
    fn default() -> Self {
        CacheBackend::InMemory { max_entries: 128 }
    }
}

/// Caching configuration for a single node.
#[derive(Debug, Clone)]
pub struct NodeCachePolicy {
    /// Where cached results are stored.
    pub backend: CacheBackend,
    /// Maximum age of a cached result; `None` means entries do not expire by age.
    pub ttl: Option<Duration>,
}
pub fn compute_cache_key(node_name: &str, input_state: &State) -> String {
let mut sorted_keys: Vec<&String> = input_state.keys().collect();
sorted_keys.sort();
let mut canonical = serde_json::Map::new();
for key in sorted_keys {
if let Some(value) = input_state.get(key) {
canonical.insert(key.clone(), value.clone());
}
}
let state_json = serde_json::to_string(&canonical).unwrap_or_default();
let input = format!("{node_name}{state_json}");
let hash = blake3::hash(input.as_bytes());
hash.to_hex().to_string()
}
/// Bounded key/value store with least-recently-used eviction and an optional
/// age check applied at lookup time.
///
/// Invariant: every key in `map` appears exactly once in `order`, and vice versa.
#[derive(Debug)]
struct LruCache {
    /// Cached values paired with the instant they were last (re)inserted;
    /// that instant is what TTL checks compare against.
    map: HashMap<String, (Value, Instant)>,
    /// Recency queue: front = least recently used, back = most recently used.
    order: VecDeque<String>,
    /// Maximum number of stored entries; `0` disables storage entirely.
    max_entries: usize,
}

impl LruCache {
    /// Cap on eager preallocation, so a huge `max_entries` (e.g. `usize::MAX`
    /// used as "unbounded") neither aborts on capacity overflow nor reserves
    /// memory that may never be used. Collections still grow on demand.
    const MAX_PREALLOC: usize = 1024;

    /// Creates an empty cache holding at most `max_entries` entries.
    fn new(max_entries: usize) -> Self {
        let prealloc = max_entries.min(Self::MAX_PREALLOC);
        Self {
            map: HashMap::with_capacity(prealloc),
            order: VecDeque::with_capacity(prealloc),
            max_entries,
        }
    }

    /// Returns a clone of the value under `key` and marks it most recently used.
    ///
    /// When `ttl` is `Some` and the entry was inserted longer ago than `ttl`,
    /// the entry is removed and `None` is returned.
    fn get(&mut self, key: &str, ttl: Option<Duration>) -> Option<Value> {
        let (value, inserted_at) = self.map.get(key)?;
        let expired = matches!(ttl, Some(ttl) if inserted_at.elapsed() > ttl);
        if expired {
            self.remove_entry(key);
            return None;
        }
        let value = value.clone();
        self.touch(key);
        Some(value)
    }

    /// Inserts or replaces `key`, marks it most recently used and restarts its TTL clock.
    ///
    /// Inserting a new key into a full cache first evicts the least recently
    /// used entry. With `max_entries == 0` nothing is stored.
    fn insert(&mut self, key: String, value: Value) {
        if self.max_entries == 0 {
            return;
        }
        let now = Instant::now();
        if let Some(slot) = self.map.get_mut(&key) {
            // Existing key: replace in place; size is unchanged, so no eviction.
            *slot = (value, now);
            self.touch(&key);
            return;
        }
        if self.map.len() >= self.max_entries {
            if let Some(evicted) = self.order.pop_front() {
                self.map.remove(&evicted);
            }
        }
        self.order.push_back(key.clone());
        self.map.insert(key, (value, now));
    }

    /// Moves `key` to the most-recently-used end of `order`, reusing its
    /// existing `String` when present.
    fn touch(&mut self, key: &str) {
        let pos = self.order.iter().position(|k| k == key);
        let existing = pos.and_then(|pos| self.order.remove(pos));
        self.order.push_back(existing.unwrap_or_else(|| key.to_owned()));
    }

    /// Removes `key` from both the value map and the recency queue.
    fn remove_entry(&mut self, key: &str) {
        self.map.remove(key);
        let pos = self.order.iter().position(|k| k == key);
        if let Some(pos) = pos {
            self.order.remove(pos);
        }
    }
}
/// Per-node result cache backed by an in-process LRU store or, when the
/// `redis-cache` feature is enabled, by a Redis server.
///
/// [`NodeCache::from_policy`] populates at most one backend. If none is
/// populated (an unparseable Redis URL), every lookup misses and every store
/// is a no-op.
pub struct NodeCache {
    /// In-memory LRU store plus the policy TTL checked on every lookup.
    memory: Option<MemoryStore>,
    /// Redis client, present only for the Redis backend.
    #[cfg(feature = "redis-cache")]
    redis: Option<fred::clients::Client>,
    /// Policy TTL applied to Redis writes when `set` is not given one.
    #[cfg(feature = "redis-cache")]
    redis_ttl: Option<Duration>,
}

impl NodeCache {
    /// Builds a cache for `policy`.
    ///
    /// A Redis URL that fails to parse leaves the cache without a backend,
    /// rather than silently falling back to the default Redis configuration.
    pub fn from_policy(policy: &NodeCachePolicy) -> Self {
        match &policy.backend {
            CacheBackend::InMemory { max_entries } => Self {
                memory: Some(Arc::new(Mutex::new((LruCache::new(*max_entries), policy.ttl)))),
                #[cfg(feature = "redis-cache")]
                redis: None,
                #[cfg(feature = "redis-cache")]
                redis_ttl: None,
            },
            #[cfg(feature = "redis-cache")]
            CacheBackend::Redis { url } => Self {
                memory: None,
                // NOTE(review): the client is only constructed here; nothing in this file
                // calls `init`/`connect` on it — confirm a caller connects it before use.
                redis: fred::types::Config::from_url(url)
                    .ok()
                    .map(|config| fred::clients::Client::new(config, None, None, None)),
                redis_ttl: policy.ttl,
            },
        }
    }

    /// Looks up `key`, returning the cached value if present and not expired.
    ///
    /// Redis errors and undecodable payloads are treated as cache misses.
    pub async fn get(&self, key: &str) -> Option<Value> {
        if let Some(memory) = &self.memory {
            let mut guard = memory.lock().await;
            let (lru, ttl) = &mut *guard;
            return lru.get(key, *ttl);
        }
        #[cfg(feature = "redis-cache")]
        if let Some(redis) = &self.redis {
            use fred::interfaces::KeysInterface;
            let raw: Option<String> = redis.get(key).await.ok()?;
            return raw.and_then(|s| serde_json::from_str(&s).ok());
        }
        None
    }

    /// Stores `value` under `key`. Caching is best-effort, so failures are ignored.
    ///
    /// The in-memory backend ignores `_ttl` and applies the policy TTL when the
    /// entry is read. The Redis backend uses `_ttl` when given and otherwise
    /// the policy TTL.
    pub async fn set(&self, key: &str, value: Value, _ttl: Option<Duration>) {
        if let Some(memory) = &self.memory {
            let mut guard = memory.lock().await;
            let (lru, _) = &mut *guard;
            lru.insert(key.to_string(), value);
        } else {
            #[cfg(feature = "redis-cache")]
            if let Some(redis) = &self.redis {
                use fred::interfaces::KeysInterface;
                // A value that cannot be serialized is simply not cached.
                if let Ok(serialized) = serde_json::to_string(&value) {
                    let expiration = _ttl.or(self.redis_ttl).map(Self::redis_expiration);
                    let _: std::result::Result<(), _> =
                        redis.set(key, serialized, expiration, None, false).await;
                }
            }
        }
    }

    /// Converts a TTL into a Redis `PX` (milliseconds) expiration.
    ///
    /// Whole-second `EX` would truncate e.g. 500 ms to `EX 0`, which Redis
    /// rejects, dropping the write. The value is clamped to `1..=i64::MAX`
    /// instead of wrapping on overflow.
    #[cfg(feature = "redis-cache")]
    fn redis_expiration(ttl: Duration) -> fred::types::Expiration {
        let millis = i64::try_from(ttl.as_millis()).unwrap_or(i64::MAX).max(1);
        fred::types::Expiration::PX(millis)
    }
}
impl std::fmt::Debug for NodeCache {
    /// Reports which backends are configured, without dumping cached values
    /// or client internals.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("NodeCache");
        out.field("has_memory", &self.memory.is_some());
        #[cfg(feature = "redis-cache")]
        out.field("has_redis", &self.redis.is_some());
        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_compute_cache_key_deterministic() {
        let mut state = State::new();
        state.insert("a".to_string(), json!(1));
        state.insert("b".to_string(), json!("hello"));
        let key1 = compute_cache_key("node1", &state);
        let key2 = compute_cache_key("node1", &state);
        assert_eq!(key1, key2);
        // BLAKE3 digests are 32 bytes, i.e. 64 hex characters.
        assert_eq!(key1.len(), 64);
    }

    #[test]
    fn test_compute_cache_key_different_nodes() {
        let state = State::new();
        let key1 = compute_cache_key("node_a", &state);
        let key2 = compute_cache_key("node_b", &state);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_compute_cache_key_different_state() {
        let mut state1 = State::new();
        state1.insert("x".to_string(), json!(1));
        let mut state2 = State::new();
        state2.insert("x".to_string(), json!(2));
        let key1 = compute_cache_key("node", &state1);
        let key2 = compute_cache_key("node", &state2);
        assert_ne!(key1, key2);
    }

    #[test]
    fn test_compute_cache_key_empty_vs_populated_state() {
        let empty = State::new();
        let mut populated = State::new();
        populated.insert("a".to_string(), json!(null));
        assert_ne!(compute_cache_key("node", &empty), compute_cache_key("node", &populated));
    }

    #[test]
    fn test_compute_cache_key_order_independent() {
        let mut state1 = State::new();
        state1.insert("a".to_string(), json!(1));
        state1.insert("b".to_string(), json!(2));
        let mut state2 = State::new();
        state2.insert("b".to_string(), json!(2));
        state2.insert("a".to_string(), json!(1));
        let key1 = compute_cache_key("node", &state1);
        let key2 = compute_cache_key("node", &state2);
        assert_eq!(key1, key2);
    }

    #[tokio::test]
    async fn test_node_cache_in_memory_basic() {
        let policy =
            NodeCachePolicy { backend: CacheBackend::InMemory { max_entries: 10 }, ttl: None };
        let cache = NodeCache::from_policy(&policy);
        assert!(cache.get("key1").await.is_none());
        cache.set("key1", json!({"result": 42}), None).await;
        assert_eq!(cache.get("key1").await, Some(json!({"result": 42})));
    }

    #[tokio::test]
    async fn test_node_cache_lru_eviction() {
        let policy =
            NodeCachePolicy { backend: CacheBackend::InMemory { max_entries: 3 }, ttl: None };
        let cache = NodeCache::from_policy(&policy);
        cache.set("a", json!(1), None).await;
        cache.set("b", json!(2), None).await;
        cache.set("c", json!(3), None).await;
        assert_eq!(cache.get("b").await, Some(json!(2)));
        assert_eq!(cache.get("c").await, Some(json!(3)));
        cache.set("d", json!(4), None).await;
        assert!(cache.get("a").await.is_none());
        assert_eq!(cache.get("b").await, Some(json!(2)));
        assert_eq!(cache.get("c").await, Some(json!(3)));
        assert_eq!(cache.get("d").await, Some(json!(4)));
    }

    #[tokio::test]
    async fn test_node_cache_overwrite_refreshes_recency() {
        let policy =
            NodeCachePolicy { backend: CacheBackend::InMemory { max_entries: 2 }, ttl: None };
        let cache = NodeCache::from_policy(&policy);
        cache.set("a", json!(1), None).await;
        cache.set("b", json!(2), None).await;
        // Overwriting an existing key must not evict anything and makes it most recent,
        // so the next insertion evicts "b" rather than "a".
        cache.set("a", json!(10), None).await;
        cache.set("c", json!(3), None).await;
        assert_eq!(cache.get("a").await, Some(json!(10)));
        assert!(cache.get("b").await.is_none());
        assert_eq!(cache.get("c").await, Some(json!(3)));
    }

    #[tokio::test]
    async fn test_node_cache_ttl_expiration() {
        let policy = NodeCachePolicy {
            backend: CacheBackend::InMemory { max_entries: 10 },
            ttl: Some(Duration::from_millis(50)),
        };
        let cache = NodeCache::from_policy(&policy);
        cache.set("key", json!("value"), Some(Duration::from_millis(50))).await;
        assert_eq!(cache.get("key").await, Some(json!("value")));
        tokio::time::sleep(Duration::from_millis(60)).await;
        assert!(cache.get("key").await.is_none());
    }
}