1use std::collections::HashMap;
7use std::sync::Mutex;
8use std::time::{Duration, Instant};
9
10use uuid::Uuid;
11
12use crate::model::memory::MemoryRecord;
13
14pub struct MemoryCache {
16 entries: Mutex<HashMap<Uuid, CacheEntry>>,
17 ttl: Duration,
18 max_entries: usize,
19}
20
21struct CacheEntry {
22 record: MemoryRecord,
23 inserted_at: Instant,
24}
25
26impl MemoryCache {
27 pub fn new(ttl_seconds: u64, max_entries: usize) -> Self {
29 Self {
30 entries: Mutex::new(HashMap::new()),
31 ttl: Duration::from_secs(ttl_seconds),
32 max_entries,
33 }
34 }
35
36 pub fn get(&self, id: Uuid) -> Option<MemoryRecord> {
38 let mut entries = self.entries.lock().unwrap_or_else(|e| e.into_inner());
39 if let Some(entry) = entries.get(&id) {
40 if entry.inserted_at.elapsed() < self.ttl {
41 return Some(entry.record.clone());
42 }
43 entries.remove(&id);
45 }
46 None
47 }
48
49 pub fn put(&self, record: MemoryRecord) {
51 let mut entries = self.entries.lock().unwrap_or_else(|e| e.into_inner());
52
53 if entries.len() >= self.max_entries {
55 let now = Instant::now();
56 entries.retain(|_, e| now.duration_since(e.inserted_at) < self.ttl);
57 }
58
59 if entries.len() >= self.max_entries
61 && let Some(&oldest_id) = entries
62 .iter()
63 .min_by_key(|(_, e)| e.inserted_at)
64 .map(|(id, _)| id)
65 {
66 entries.remove(&oldest_id);
67 }
68
69 if entries.len() >= self.max_entries && !entries.contains_key(&record.id) {
71 return;
72 }
73
74 entries.insert(
75 record.id,
76 CacheEntry {
77 record,
78 inserted_at: Instant::now(),
79 },
80 );
81 }
82
83 pub fn invalidate(&self, id: Uuid) {
85 self.entries
86 .lock()
87 .unwrap_or_else(|e| e.into_inner())
88 .remove(&id);
89 }
90
91 pub fn clear(&self) {
93 self.entries
94 .lock()
95 .unwrap_or_else(|e| e.into_inner())
96 .clear();
97 }
98
99 pub fn len(&self) -> usize {
101 self.entries.lock().unwrap_or_else(|e| e.into_inner()).len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.entries
107 .lock()
108 .unwrap_or_else(|e| e.into_inner())
109 .is_empty()
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    fn make_record(id: Uuid) -> MemoryRecord {
        MemoryRecord {
            id,
            agent_id: "test".to_string(),
            content: format!("content-{id}"),
            memory_type: crate::model::memory::MemoryType::Episodic,
            scope: crate::model::memory::Scope::Private,
            importance: 0.5,
            tags: vec![],
            embedding: None,
            metadata: serde_json::Value::Null,
            source_type: crate::model::memory::SourceType::Agent,
            source_id: None,
            consolidation_state: crate::model::memory::ConsolidationState::Raw,
            access_count: 0,
            org_id: None,
            thread_id: None,
            content_hash: vec![],
            prev_hash: None,
            created_at: String::new(),
            updated_at: String::new(),
            last_accessed_at: None,
            expires_at: None,
            deleted_at: None,
            decay_rate: None,
            created_by: None,
            version: 1,
            prev_version_id: None,
            quarantined: false,
            quarantine_reason: None,
            decay_function: None,
        }
    }

    /// Waits briefly so consecutive insertions get strictly increasing `Instant`s, keeping
    /// oldest-first eviction deterministic even on coarse-resolution clocks.
    fn tick() {
        std::thread::sleep(std::time::Duration::from_millis(2));
    }

    #[test]
    fn test_cache_put_and_get() {
        let cache = MemoryCache::new(60, 100);
        let id = Uuid::now_v7();
        let record = make_record(id);

        cache.put(record.clone());
        let cached = cache.get(id).unwrap();
        assert_eq!(cached.id, id);
        assert_eq!(cached.content, record.content);
    }

    #[test]
    fn test_cache_miss() {
        let cache = MemoryCache::new(60, 100);
        assert!(cache.get(Uuid::now_v7()).is_none());
    }

    #[test]
    fn test_cache_invalidate() {
        let cache = MemoryCache::new(60, 100);
        let id = Uuid::now_v7();
        cache.put(make_record(id));
        assert!(cache.get(id).is_some());

        cache.invalidate(id);
        assert!(cache.get(id).is_none());
    }

    #[test]
    fn test_cache_max_entries() {
        let cache = MemoryCache::new(60, 2);

        let id1 = Uuid::now_v7();
        let id2 = Uuid::now_v7();
        let id3 = Uuid::now_v7();

        cache.put(make_record(id1));
        tick();
        cache.put(make_record(id2));
        assert_eq!(cache.len(), 2);

        tick();
        cache.put(make_record(id3));
        assert_eq!(cache.len(), 2);

        // The entry with the oldest insertion time is the one evicted.
        assert!(cache.get(id1).is_none());
        assert!(cache.get(id2).is_some());
        assert!(cache.get(id3).is_some());
    }

    #[test]
    fn test_cache_zero_ttl_expires_on_read() {
        let cache = MemoryCache::new(0, 100);
        let id = Uuid::now_v7();
        cache.put(make_record(id));
        assert_eq!(cache.len(), 1);

        assert!(cache.get(id).is_none());
        // The expired entry is purged by the lookup itself.
        assert!(cache.is_empty());
    }

    #[test]
    fn test_cache_zero_capacity_stores_nothing() {
        let cache = MemoryCache::new(60, 0);
        let id = Uuid::now_v7();
        cache.put(make_record(id));
        assert!(cache.is_empty());
        assert!(cache.get(id).is_none());
    }

    #[test]
    fn test_cache_full_of_expired_entries_is_purged() {
        // With a zero TTL every stored entry is expired, so filling the cache and inserting
        // one more must purge them all instead of evicting a single entry.
        let cache = MemoryCache::new(0, 2);
        cache.put(make_record(Uuid::now_v7()));
        cache.put(make_record(Uuid::now_v7()));
        assert_eq!(cache.len(), 2);

        cache.put(make_record(Uuid::now_v7()));
        assert_eq!(cache.len(), 1);
    }

    #[test]
    fn test_cache_put_replaces_existing() {
        let cache = MemoryCache::new(60, 100);
        let id = Uuid::now_v7();
        cache.put(make_record(id));

        let mut updated = make_record(id);
        updated.content = "updated".to_string();
        cache.put(updated);

        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get(id).unwrap().content, "updated");
    }

    #[test]
    fn test_cache_clear() {
        let cache = MemoryCache::new(60, 100);
        cache.put(make_record(Uuid::now_v7()));
        cache.put(make_record(Uuid::now_v7()));
        assert_eq!(cache.len(), 2);

        cache.clear();
        assert!(cache.is_empty());
        assert_eq!(cache.len(), 0);
    }
}