use crate::types::{Edge, EdgeId, Node, NodeId};
use moka::future::Cache;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
#[derive(Clone)]
pub struct StorageCache {
node_cache: Cache<NodeId, Node>,
edge_cache: Cache<EdgeId, Edge>,
}
impl StorageCache {
pub fn new() -> Self {
Self::with_capacity(10_000, 50_000)
}
pub fn with_capacity(node_capacity: u64, edge_capacity: u64) -> Self {
let node_cache = Cache::builder()
.max_capacity(node_capacity)
.time_to_live(Duration::from_secs(300)) .build();
let edge_cache = Cache::builder()
.max_capacity(edge_capacity)
.time_to_live(Duration::from_secs(300))
.build();
Self {
node_cache,
edge_cache,
}
}
pub fn with_ttl(ttl_secs: u64) -> Self {
let node_cache = Cache::builder()
.max_capacity(10_000)
.time_to_live(Duration::from_secs(ttl_secs))
.build();
let edge_cache = Cache::builder()
.max_capacity(50_000)
.time_to_live(Duration::from_secs(ttl_secs))
.build();
Self {
node_cache,
edge_cache,
}
}
pub async fn get_node(&self, id: &NodeId) -> Option<Node> {
self.node_cache.get(id).await
}
pub async fn insert_node(&self, id: NodeId, node: Node) {
self.node_cache.insert(id, node).await;
}
pub async fn invalidate_node(&self, id: &NodeId) {
self.node_cache.invalidate(id).await;
}
pub async fn get_edge(&self, id: &EdgeId) -> Option<Edge> {
self.edge_cache.get(id).await
}
pub async fn insert_edge(&self, id: EdgeId, edge: Edge) {
self.edge_cache.insert(id, edge).await;
}
pub async fn invalidate_edge(&self, id: &EdgeId) {
self.edge_cache.invalidate(id).await;
}
pub async fn stats(&self) -> CacheStats {
self.node_cache.run_pending_tasks().await;
self.edge_cache.run_pending_tasks().await;
CacheStats {
node_cache_size: self.node_cache.entry_count(),
edge_cache_size: self.edge_cache.entry_count(),
node_cache_hits: 0, node_cache_misses: 0,
edge_cache_hits: 0,
edge_cache_misses: 0,
}
}
pub fn clear(&self) {
self.node_cache.invalidate_all();
self.edge_cache.invalidate_all();
}
}
impl Default for StorageCache {
fn default() -> Self {
Self::new()
}
}
/// Snapshot of cache occupancy and lookup counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Entry count reported for the node cache.
    pub node_cache_size: u64,
    /// Entry count reported for the edge cache.
    pub edge_cache_size: u64,
    /// Node lookups counted as hits.
    pub node_cache_hits: u64,
    /// Node lookups counted as misses.
    pub node_cache_misses: u64,
    /// Edge lookups counted as hits.
    pub edge_cache_hits: u64,
    /// Edge lookups counted as misses.
    pub edge_cache_misses: u64,
}

impl CacheStats {
    /// Fraction of node lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no node lookups have been counted.
    pub fn node_hit_rate(&self) -> f64 {
        Self::hit_ratio(self.node_cache_hits, self.node_cache_misses)
    }

    /// Fraction of edge lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no edge lookups have been counted.
    pub fn edge_hit_rate(&self) -> f64 {
        Self::hit_ratio(self.edge_cache_hits, self.edge_cache_misses)
    }

    /// Computes `hits / (hits + misses)`, yielding `0.0` for an empty sample.
    fn hit_ratio(hits: u64, misses: u64) -> f64 {
        // Sum in u128: `hits + misses` in u64 would panic (debug) or wrap
        // (release) once the two counters together exceed `u64::MAX`.
        let total = u128::from(hits) + u128::from(misses);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, PromptNode, SessionId};

    /// Builds a fresh session-backed node and returns it with its id.
    fn session_node() -> (NodeId, Node) {
        let node = Node::Session(ConversationSession::new());
        (node.id(), node)
    }

    /// A new cache reports no entries in either store.
    #[tokio::test]
    async fn test_cache_creation() {
        let cache = StorageCache::new();
        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 0);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// A node is absent before insertion and retrievable afterwards.
    #[tokio::test]
    async fn test_node_cache() {
        let cache = StorageCache::new();
        let (node_id, node) = session_node();

        assert!(cache.get_node(&node_id).await.is_none());

        cache.insert_node(node_id, node).await;
        let cached = cache
            .get_node(&node_id)
            .await
            .expect("node was inserted just above");
        assert_eq!(cached.id(), node_id);
    }

    /// Invalidating a node makes subsequent lookups miss.
    #[tokio::test]
    async fn test_node_cache_invalidation() {
        let cache = StorageCache::new();
        let (node_id, node) = session_node();

        cache.insert_node(node_id, node).await;
        assert!(cache.get_node(&node_id).await.is_some());

        cache.invalidate_node(&node_id).await;
        // The point of the test: the entry must actually be gone.
        assert!(cache.get_node(&node_id).await.is_none());
    }

    /// `clear` discards entries that were inserted before the call.
    #[tokio::test]
    async fn test_clear_discards_entries() {
        let cache = StorageCache::new();
        let (node_id, node) = session_node();

        cache.insert_node(node_id, node).await;
        cache.clear();
        assert!(cache.get_node(&node_id).await.is_none());
    }

    /// `stats` reflects an inserted node in the node-cache size.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = StorageCache::new();
        let (node_id, node) = session_node();

        assert!(cache.get_node(&node_id).await.is_none());

        cache.insert_node(node_id, node).await;
        assert!(cache.get_node(&node_id).await.is_some());

        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 1);
    }

    /// A cache built with custom capacities starts out empty.
    #[tokio::test]
    async fn test_custom_capacity() {
        let cache = StorageCache::with_capacity(100, 200);
        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 0);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// A cache built with a custom TTL serves entries inserted moments ago.
    #[tokio::test]
    async fn test_custom_ttl() {
        let cache = StorageCache::with_ttl(60);
        let (node_id, node) = session_node();

        cache.insert_node(node_id, node).await;
        assert!(cache.get_node(&node_id).await.is_some());
    }

    /// Clones share one store: inserts made through two clones from separate
    /// tasks are all visible through the original handle.
    #[tokio::test]
    async fn test_concurrent_cache_access() {
        let cache = StorageCache::new();
        let session_id = SessionId::new();

        let first_writer = cache.clone();
        let first = tokio::spawn(async move {
            for i in 0..50 {
                let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                let id = prompt.id;
                first_writer.insert_node(id, Node::Prompt(prompt)).await;
            }
        });

        let second_writer = cache.clone();
        let second = tokio::spawn(async move {
            for i in 50..100 {
                let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                let id = prompt.id;
                second_writer.insert_node(id, Node::Prompt(prompt)).await;
            }
        });

        first.await.unwrap();
        second.await.unwrap();

        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 100);
    }
}