llm_memory_graph/storage/
cache.rs1use crate::types::{Edge, EdgeId, Node, NodeId};
7use moka::future::Cache;
8use std::time::Duration;
9
10#[derive(Clone)]
15pub struct StorageCache {
16 node_cache: Cache<NodeId, Node>,
18 edge_cache: Cache<EdgeId, Edge>,
20}
21
22impl StorageCache {
23 pub fn new() -> Self {
30 Self::with_capacity(10_000, 50_000)
31 }
32
33 pub fn with_capacity(node_capacity: u64, edge_capacity: u64) -> Self {
35 let node_cache = Cache::builder()
36 .max_capacity(node_capacity)
37 .time_to_live(Duration::from_secs(300)) .build();
39
40 let edge_cache = Cache::builder()
41 .max_capacity(edge_capacity)
42 .time_to_live(Duration::from_secs(300))
43 .build();
44
45 Self {
46 node_cache,
47 edge_cache,
48 }
49 }
50
51 pub fn with_ttl(ttl_secs: u64) -> Self {
53 let node_cache = Cache::builder()
54 .max_capacity(10_000)
55 .time_to_live(Duration::from_secs(ttl_secs))
56 .build();
57
58 let edge_cache = Cache::builder()
59 .max_capacity(50_000)
60 .time_to_live(Duration::from_secs(ttl_secs))
61 .build();
62
63 Self {
64 node_cache,
65 edge_cache,
66 }
67 }
68
69 pub async fn get_node(&self, id: &NodeId) -> Option<Node> {
71 self.node_cache.get(id).await
72 }
73
74 pub async fn insert_node(&self, id: NodeId, node: Node) {
76 self.node_cache.insert(id, node).await;
77 }
78
79 pub async fn invalidate_node(&self, id: &NodeId) {
81 self.node_cache.invalidate(id).await;
82 }
83
84 pub async fn get_edge(&self, id: &EdgeId) -> Option<Edge> {
86 self.edge_cache.get(id).await
87 }
88
89 pub async fn insert_edge(&self, id: EdgeId, edge: Edge) {
91 self.edge_cache.insert(id, edge).await;
92 }
93
94 pub async fn invalidate_edge(&self, id: &EdgeId) {
96 self.edge_cache.invalidate(id).await;
97 }
98
99 pub async fn stats(&self) -> CacheStats {
104 self.node_cache.run_pending_tasks().await;
106 self.edge_cache.run_pending_tasks().await;
107
108 CacheStats {
109 node_cache_size: self.node_cache.entry_count(),
110 edge_cache_size: self.edge_cache.entry_count(),
111 node_cache_hits: 0, node_cache_misses: 0,
113 edge_cache_hits: 0,
114 edge_cache_misses: 0,
115 }
116 }
117
118 pub fn clear(&self) {
120 self.node_cache.invalidate_all();
121 self.edge_cache.invalidate_all();
122 }
123}
124
125impl Default for StorageCache {
126 fn default() -> Self {
127 Self::new()
128 }
129}
130
/// Point-in-time snapshot of cache occupancy and lookup counters.
#[derive(Debug, Clone)]
pub struct CacheStats {
    /// Number of entries reported by the node cache.
    pub node_cache_size: u64,
    /// Number of entries reported by the edge cache.
    pub edge_cache_size: u64,
    /// Node lookups that found a cached value.
    pub node_cache_hits: u64,
    /// Node lookups that found nothing.
    pub node_cache_misses: u64,
    /// Edge lookups that found a cached value.
    pub edge_cache_hits: u64,
    /// Edge lookups that found nothing.
    pub edge_cache_misses: u64,
}

impl CacheStats {
    /// Fraction of node lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no node lookups have been recorded.
    pub fn node_hit_rate(&self) -> f64 {
        Self::hit_rate(self.node_cache_hits, self.node_cache_misses)
    }

    /// Fraction of edge lookups that were hits, in `0.0..=1.0`.
    ///
    /// Returns `0.0` when no edge lookups have been recorded.
    pub fn edge_hit_rate(&self) -> f64 {
        Self::hit_rate(self.edge_cache_hits, self.edge_cache_misses)
    }

    /// Computes `hits / (hits + misses)`, treating an empty sample as `0.0`.
    fn hit_rate(hits: u64, misses: u64) -> f64 {
        // The fields are public, so the two counters may sum past `u64::MAX`.
        // Widening first avoids an overflow panic (debug) or a wrapped total
        // (release) while giving the same result for every in-range sum.
        let total = u128::from(hits) + u128::from(misses);
        if total == 0 {
            0.0
        } else {
            hits as f64 / total as f64
        }
    }
}
169
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ConversationSession, PromptNode, SessionId};

    /// A freshly constructed cache reports no entries in either sub-cache.
    #[tokio::test]
    async fn test_cache_creation() {
        let cache = StorageCache::new();
        let stats = cache.stats().await;

        assert_eq!(stats.node_cache_size, 0);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// A node is absent before insertion and retrievable afterwards.
    #[tokio::test]
    async fn test_node_cache() {
        let cache = StorageCache::new();

        let node = Node::Session(ConversationSession::new());
        let node_id = node.id();

        assert!(cache.get_node(&node_id).await.is_none());

        cache.insert_node(node_id, node).await;

        let cached = cache.get_node(&node_id).await;
        assert!(cached.is_some());
        assert_eq!(cached.unwrap().id(), node_id);
    }

    /// Invalidating a node makes subsequent lookups miss.
    #[tokio::test]
    async fn test_node_cache_invalidation() {
        let cache = StorageCache::new();

        let node = Node::Session(ConversationSession::new());
        let node_id = node.id();

        cache.insert_node(node_id, node).await;
        assert!(cache.get_node(&node_id).await.is_some());

        cache.invalidate_node(&node_id).await;

        // The invalidation has been awaited, so the entry must no longer be
        // visible; previously this test slept and then asserted nothing.
        assert!(cache.get_node(&node_id).await.is_none());
    }

    /// After one miss, one insert and one hit the cache holds exactly one node.
    #[tokio::test]
    async fn test_cache_stats() {
        let cache = StorageCache::new();

        let node = Node::Session(ConversationSession::new());
        let node_id = node.id();

        let result = cache.get_node(&node_id).await;
        assert!(result.is_none());

        cache.insert_node(node_id, node).await;

        let result = cache.get_node(&node_id).await;
        assert!(result.is_some());

        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 1);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// Custom capacities still produce an empty cache.
    #[tokio::test]
    async fn test_custom_capacity() {
        let cache = StorageCache::with_capacity(100, 200);
        let stats = cache.stats().await;

        assert_eq!(stats.node_cache_size, 0);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// A custom time-to-live still produces an empty, usable cache.
    #[tokio::test]
    async fn test_custom_ttl() {
        let cache = StorageCache::with_ttl(60);
        let stats = cache.stats().await;

        assert_eq!(stats.node_cache_size, 0);
        assert_eq!(stats.edge_cache_size, 0);
    }

    /// Hit rates are `hits / (hits + misses)` and `0.0` for an empty sample.
    #[test]
    fn test_hit_rate_computation() {
        let stats = CacheStats {
            node_cache_size: 0,
            edge_cache_size: 0,
            node_cache_hits: 3,
            node_cache_misses: 1,
            edge_cache_hits: 0,
            edge_cache_misses: 0,
        };

        assert_eq!(stats.node_hit_rate(), 0.75);
        assert_eq!(stats.edge_hit_rate(), 0.0);
    }

    /// Two tasks inserting through clones populate the same shared cache.
    #[tokio::test]
    async fn test_concurrent_cache_access() {
        let cache = StorageCache::new();
        let cache_clone1 = cache.clone();
        let cache_clone2 = cache.clone();

        let session_id = SessionId::new();

        let handle1 = tokio::spawn(async move {
            for i in 0..50 {
                let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                // Copy the id out first so the prompt can be moved into the node.
                let prompt_id = prompt.id;
                cache_clone1
                    .insert_node(prompt_id, Node::Prompt(prompt))
                    .await;
            }
        });

        let handle2 = tokio::spawn(async move {
            for i in 50..100 {
                let prompt = PromptNode::new(session_id, format!("Prompt {}", i));
                let prompt_id = prompt.id;
                cache_clone2
                    .insert_node(prompt_id, Node::Prompt(prompt))
                    .await;
            }
        });

        handle1.await.unwrap();
        handle2.await.unwrap();

        let stats = cache.stats().await;
        assert_eq!(stats.node_cache_size, 100);
    }
}