rag_plusplus_core/store/
memory.rs1use crate::error::{Error, Result};
18use crate::store::traits::RecordStore;
19use crate::types::{MemoryRecord, RecordId, RecordStatus};
20use ahash::AHashMap;
21
22#[derive(Debug, Default)]
40pub struct InMemoryStore {
41 records: AHashMap<RecordId, MemoryRecord>,
43 estimated_bytes: usize,
45}
46
47impl InMemoryStore {
48 #[must_use]
50 pub fn new() -> Self {
51 Self {
52 records: AHashMap::new(),
53 estimated_bytes: 0,
54 }
55 }
56
57 #[must_use]
59 pub fn with_capacity(capacity: usize) -> Self {
60 Self {
61 records: AHashMap::with_capacity(capacity),
62 estimated_bytes: 0,
63 }
64 }
65
66 fn estimate_record_size(record: &MemoryRecord) -> usize {
68 let base = std::mem::size_of::<MemoryRecord>();
70
71 let id_size = record.id.capacity();
73 let context_size = record.context.capacity();
74
75 let embedding_size = record.embedding.capacity() * std::mem::size_of::<f32>();
77
78 let metadata_size = record.metadata.len() * 64; base + id_size + context_size + embedding_size + metadata_size
82 }
83
84 #[must_use]
86 pub fn stats(&self) -> StoreStats {
87 let active_count = self
88 .records
89 .values()
90 .filter(|r| r.status == RecordStatus::Active)
91 .count();
92
93 let total_outcomes: u64 = self.records.values().map(|r| r.stats.count()).sum();
94
95 StoreStats {
96 total_records: self.records.len(),
97 active_records: active_count,
98 total_outcome_updates: total_outcomes,
99 memory_bytes: self.estimated_bytes,
100 }
101 }
102
103 pub fn iter(&self) -> impl Iterator<Item = &MemoryRecord> {
105 self.records.values()
106 }
107
108 pub fn iter_active(&self) -> impl Iterator<Item = &MemoryRecord> {
110 self.records
111 .values()
112 .filter(|r| r.status == RecordStatus::Active)
113 }
114
115 fn get_mut(&mut self, id: &RecordId) -> Option<&mut MemoryRecord> {
117 self.records.get_mut(id)
118 }
119
120 #[must_use]
122 pub fn memory_bytes(&self) -> usize {
123 self.estimated_bytes
124 }
125}
126
127impl RecordStore for InMemoryStore {
128 fn insert(&mut self, record: MemoryRecord) -> Result<RecordId> {
129 let id = record.id.clone();
130
131 if self.records.contains_key(&id) {
132 return Err(Error::DuplicateRecord {
133 record_id: id.to_string(),
134 });
135 }
136
137 let size = Self::estimate_record_size(&record);
138 self.records.insert(id.clone(), record);
139 self.estimated_bytes += size;
140
141 Ok(id)
142 }
143
144 fn get(&self, id: &RecordId) -> Option<MemoryRecord> {
145 self.records.get(id).cloned()
146 }
147
148 fn contains(&self, id: &RecordId) -> bool {
149 self.records.contains_key(id)
150 }
151
152 fn update_stats(&mut self, id: &RecordId, outcome: f64) -> Result<()> {
153 let record = self.get_mut(id).ok_or_else(|| Error::RecordNotFound {
154 record_id: id.to_string(),
155 })?;
156
157 record.stats.update_scalar(outcome);
159
160 Ok(())
161 }
162
163 fn remove(&mut self, id: &RecordId) -> Result<bool> {
164 if let Some(record) = self.records.remove(id) {
165 let size = Self::estimate_record_size(&record);
166 self.estimated_bytes = self.estimated_bytes.saturating_sub(size);
167 Ok(true)
168 } else {
169 Ok(false)
170 }
171 }
172
173 fn len(&self) -> usize {
174 self.records.len()
175 }
176
177 fn clear(&mut self) {
178 self.records.clear();
179 self.estimated_bytes = 0;
180 }
181
182 fn ids(&self) -> Vec<RecordId> {
183 self.records.keys().cloned().collect()
184 }
185
186 fn memory_usage(&self) -> usize {
187 self.estimated_bytes
188 }
189}
190
/// Point-in-time summary of a store's contents, as returned by
/// `InMemoryStore::stats`.
#[derive(Debug, Clone, Default)]
pub struct StoreStats {
    /// Number of records held, regardless of status.
    pub total_records: usize,
    /// Number of records whose status is `RecordStatus::Active`.
    pub active_records: usize,
    /// Sum of each record's outcome-statistics count.
    pub total_outcome_updates: u64,
    /// Approximate memory held by the stored records, in bytes.
    pub memory_bytes: usize,
}
203
#[cfg(test)]
mod tests {
    use super::*;
    use crate::store::traits::tests::*;
    use crate::OutcomeStats;

    /// Builds an active record with a fixed three-element embedding and a
    /// context string derived from `id`.
    fn create_test_record(id: &str) -> MemoryRecord {
        let context = format!("Context for {id}");
        MemoryRecord {
            id: id.into(),
            embedding: vec![1.0, 2.0, 3.0],
            context,
            outcome: 0.5,
            metadata: Default::default(),
            created_at: 1234567890,
            status: RecordStatus::Active,
            stats: OutcomeStats::new(1),
        }
    }

    /// Returns a store pre-filled with `count` test records whose ids are
    /// `{prefix}-0`, `{prefix}-1`, and so on.
    fn filled_store(prefix: &str, count: usize) -> InMemoryStore {
        let mut store = InMemoryStore::new();
        for index in 0..count {
            let record = create_test_record(&format!("{prefix}-{index}"));
            store.insert(record).unwrap();
        }
        store
    }

    #[test]
    fn test_new_store() {
        let fresh = InMemoryStore::new();
        assert!(fresh.is_empty());
        assert_eq!(fresh.len(), 0);
    }

    #[test]
    fn test_with_capacity() {
        let presized = InMemoryStore::with_capacity(100);
        assert!(presized.is_empty());
    }

    #[test]
    fn test_insert_and_get() {
        let mut store = InMemoryStore::new();

        let id = store.insert(create_test_record("test-1")).unwrap();
        assert_eq!(id.as_str(), "test-1");

        let fetched = store.get(&id).unwrap();
        assert_eq!(fetched.id.as_str(), "test-1");
        assert_eq!(fetched.context, "Context for test-1");
    }

    #[test]
    fn test_duplicate_insert_error() {
        let mut store = InMemoryStore::new();
        let first = create_test_record("dup");
        let second = first.clone();

        store.insert(first).unwrap();

        // Same id a second time must be rejected.
        assert!(store.insert(second).is_err());
    }

    #[test]
    fn test_update_stats() {
        let mut store = InMemoryStore::new();
        store.insert(create_test_record("stats-test")).unwrap();

        let id: RecordId = "stats-test".into();
        for outcome in [0.8, 0.9, 0.7] {
            store.update_stats(&id, outcome).unwrap();
        }

        let updated = store.get(&id).unwrap();
        assert_eq!(updated.stats.count(), 3);
        let mean = updated.stats.mean_scalar().unwrap();
        assert!((mean - 0.8).abs() < 0.01);
    }

    #[test]
    fn test_update_stats_not_found() {
        let mut store = InMemoryStore::new();
        let missing: RecordId = "nonexistent".into();
        assert!(store.update_stats(&missing, 0.5).is_err());
    }

    #[test]
    fn test_remove() {
        let mut store = InMemoryStore::new();
        store.insert(create_test_record("to-remove")).unwrap();
        assert_eq!(store.len(), 1);

        let id: RecordId = "to-remove".into();

        let first_attempt = store.remove(&id).unwrap();
        assert!(first_attempt);
        assert_eq!(store.len(), 0);

        // Removing an id that is no longer present reports `false`, not an error.
        let second_attempt = store.remove(&id).unwrap();
        assert!(!second_attempt);
    }

    #[test]
    fn test_iter() {
        let store = filled_store("iter", 5);
        assert_eq!(store.iter().count(), 5);
    }

    #[test]
    fn test_iter_active() {
        let mut store = InMemoryStore::new();

        // Even indices are archived, leaving indices 1 and 3 active.
        for index in 0..5 {
            let status = if index % 2 == 0 {
                RecordStatus::Archived
            } else {
                RecordStatus::Active
            };
            let record = MemoryRecord {
                status,
                ..create_test_record(&format!("active-{index}"))
            };
            store.insert(record).unwrap();
        }

        assert_eq!(store.iter_active().count(), 2);
    }

    #[test]
    fn test_stats() {
        let mut store = filled_store("stat", 10);

        let updates = [("stat-0", 0.5), ("stat-0", 0.6), ("stat-1", 0.7)];
        for (name, outcome) in updates {
            store.update_stats(&name.into(), outcome).unwrap();
        }

        let snapshot = store.stats();
        assert_eq!(snapshot.total_records, 10);
        assert_eq!(snapshot.active_records, 10);
        assert_eq!(snapshot.total_outcome_updates, 3);
        assert!(snapshot.memory_bytes > 0);
    }

    #[test]
    fn test_memory_tracking() {
        let mut store = InMemoryStore::new();
        assert_eq!(store.memory_usage(), 0);

        store.insert(create_test_record("mem-1")).unwrap();
        let with_one = store.memory_usage();
        assert!(with_one > 0);

        store.insert(create_test_record("mem-2")).unwrap();
        let with_two = store.memory_usage();
        assert!(with_two > with_one);

        store.remove(&"mem-1".into()).unwrap();
        let after_removal = store.memory_usage();
        assert!(after_removal < with_two);
    }

    #[test]
    fn test_clear() {
        let mut store = filled_store("clear", 10);
        assert_eq!(store.len(), 10);
        assert!(store.memory_usage() > 0);

        store.clear();

        assert!(store.is_empty());
        assert_eq!(store.memory_usage(), 0);
    }

    #[test]
    fn test_ids() {
        let store = filled_store("id", 5);
        assert_eq!(store.ids().len(), 5);
    }

    // The remaining tests run the shared `RecordStore` conformance helpers
    // against this implementation.

    #[test]
    fn test_trait_basic_crud() {
        let mut store = InMemoryStore::new();
        test_basic_crud(&mut store);
    }

    #[test]
    fn test_trait_batch_operations() {
        let mut store = InMemoryStore::new();
        test_batch_operations(&mut store);
    }

    #[test]
    fn test_trait_duplicate_insert() {
        let mut store = InMemoryStore::new();
        test_duplicate_insert(&mut store);
    }
}