1use std::{collections::HashMap, path::Path, sync::Mutex};
2
3use crate::{
4 error::MemoryError,
5 index::VectorStore,
6 types::{Scope, ScopeFilter},
7};
8
9struct InMemoryState {
14 entries: HashMap<String, (Scope, Vec<f32>)>,
16 key_counter: u64,
18 key_map: HashMap<String, u64>,
20 commit_sha: Option<String>,
22}
23
24#[non_exhaustive]
33pub struct InMemoryStore {
34 state: Mutex<InMemoryState>,
35 dimensions: usize,
36 ready: bool,
37}
38
39impl InMemoryStore {
40 pub fn new(dimensions: usize) -> Self {
45 Self {
46 state: Mutex::new(InMemoryState {
47 entries: HashMap::new(),
48 key_counter: 0,
49 key_map: HashMap::new(),
50 commit_sha: None,
51 }),
52 dimensions,
53 ready: true,
54 }
55 }
56
57 pub fn with_ready(mut self, ready: bool) -> Self {
59 self.ready = ready;
60 self
61 }
62}
63
/// Cosine similarity of `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// The dot product and norms are accumulated in `f64`. With an `f32`
/// accumulator, squaring components larger than about `1.8e19` overflows to
/// infinity (yielding `inf / inf = NaN`), and squaring components smaller than
/// about `1e-23` underflows to zero, so identical tiny vectors were reported
/// as unrelated. Every `f32` converts losslessly to `f64`, and its square fits
/// in `f64` without overflow or underflow.
///
/// Returns `0.0` when either vector has zero norm. If the slices differ in
/// length, only the common prefix contributes to the dot product (`zip`
/// semantics) while each norm covers its whole slice. NaN components
/// propagate to a NaN result.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let dot: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| f64::from(x) * f64::from(y))
        .sum();
    let norm_a = a
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();
    let norm_b = b
        .iter()
        .map(|&y| f64::from(y) * f64::from(y))
        .sum::<f64>()
        .sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        // Narrowing back to f32 is intentional: the clamped value is in [-1, 1].
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0) as f32
    }
}
77
/// Converts a cosine similarity into a distance where smaller means closer.
///
/// For similarities in `[-1, 1]` the distance lies in `[0, 2]`: identical
/// direction maps to `0`, orthogonal to `1`, and opposite to `2`.
fn similarity_to_distance(sim: f32) -> f32 {
    let perfect_match = 1.0_f32;
    perfect_match - sim
}
83
// Marker impl of `crate::index::sealed::Sealed` for `InMemoryStore`.
// NOTE(review): presumably `VectorStore` has `Sealed` as a supertrait
// (sealed-trait pattern) so only types inside this crate can implement it;
// confirm in `crate::index`.
impl crate::index::sealed::Sealed for InMemoryStore {}
85
86impl VectorStore for InMemoryStore {
87 fn add(
92 &self,
93 scope: &Scope,
94 vector: &[f32],
95 qualified_name: String,
96 ) -> Result<u64, MemoryError> {
97 if vector.len() != self.dimensions {
98 return Err(MemoryError::InvalidInput {
99 reason: format!(
100 "expected {} dimensions, got {}",
101 self.dimensions,
102 vector.len()
103 ),
104 });
105 }
106 let mut state = self
107 .state
108 .lock()
109 .expect("lock poisoned — prior panic corrupted state");
110 let key = state.key_counter;
111 state.key_counter = state
112 .key_counter
113 .checked_add(1)
114 .expect("key space exhausted");
115 state
116 .entries
117 .insert(qualified_name.clone(), (scope.clone(), vector.to_vec()));
118 state.key_map.insert(qualified_name, key);
119 Ok(key)
120 }
121
122 fn remove(&self, _scope: &Scope, qualified_name: &str) -> Result<(), MemoryError> {
123 let mut state = self
124 .state
125 .lock()
126 .expect("lock poisoned — prior panic corrupted state");
127 state.entries.remove(qualified_name);
128 state.key_map.remove(qualified_name);
129 Ok(())
130 }
131
132 fn search(
133 &self,
134 filter: &ScopeFilter,
135 query: &[f32],
136 limit: usize,
137 ) -> Result<Vec<(u64, String, f32)>, MemoryError> {
138 if query.len() != self.dimensions {
139 return Err(MemoryError::InvalidInput {
140 reason: format!(
141 "expected {} dimensions, got {}",
142 self.dimensions,
143 query.len()
144 ),
145 });
146 }
147 let state = self
148 .state
149 .lock()
150 .expect("lock poisoned — prior panic corrupted state");
151
152 let mut candidates: Vec<(u64, String, f32)> = state
153 .entries
154 .iter()
155 .filter(|(_, (scope, _))| scope_matches(filter, scope))
156 .map(|(name, (_, vec))| {
157 let key = state
158 .key_map
159 .get(name)
160 .copied()
161 .expect("invariant: every entry has a key_map entry");
162 let sim = cosine_similarity(query, vec);
163 let dist = similarity_to_distance(sim);
164 (key, name.clone(), dist)
165 })
166 .collect();
167
168 candidates.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(std::cmp::Ordering::Equal));
170 candidates.truncate(limit);
171 Ok(candidates)
172 }
173
174 fn find_by_name(&self, qualified_name: &str) -> Option<u64> {
175 let state = self
176 .state
177 .lock()
178 .expect("lock poisoned — prior panic corrupted state");
179 state.key_map.get(qualified_name).copied()
180 }
181
182 fn save(&self, _dir: &Path) -> Result<(), MemoryError> {
183 Ok(())
185 }
186
187 fn is_ready(&self) -> bool {
188 self.ready
189 }
190
191 fn dimensions(&self) -> usize {
192 self.dimensions
193 }
194
195 fn commit_sha(&self) -> Option<String> {
196 let state = self
197 .state
198 .lock()
199 .expect("lock poisoned — prior panic corrupted state");
200 state.commit_sha.clone()
201 }
202
203 fn set_commit_sha(&self, sha: Option<&str>) {
204 let mut state = self
205 .state
206 .lock()
207 .expect("lock poisoned — prior panic corrupted state");
208 state.commit_sha = sha.map(|s| s.to_owned());
209 }
210}
211
212fn scope_matches(filter: &ScopeFilter, scope: &Scope) -> bool {
214 match filter {
215 ScopeFilter::All => true,
216 ScopeFilter::GlobalOnly => matches!(scope, Scope::Global),
217 ScopeFilter::ProjectAndGlobal(project_name) => match scope {
218 Scope::Global => true,
219 Scope::Project(p) => p == project_name,
220 },
221 }
222}
223
#[cfg(test)]
mod tests {
    // Unit tests for `InMemoryStore`, exercised mostly through the
    // `VectorStore` trait object so they follow the public contract.
    use super::*;
    use crate::index::VectorStore;

    fn make_store() -> InMemoryStore {
        InMemoryStore::new(4)
    }

    fn vec_a() -> Vec<f32> {
        vec![1.0, 0.0, 0.0, 0.0]
    }

    fn vec_b() -> Vec<f32> {
        vec![0.0, 1.0, 0.0, 0.0]
    }

    fn vec_c() -> Vec<f32> {
        vec![0.0, 0.0, 1.0, 0.0]
    }

    #[test]
    fn tc02a_add_and_find_by_name() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/mem1".to_string())
            .expect("add failed");
        assert!(
            store.find_by_name("global/mem1").is_some(),
            "TC-02a: find_by_name should return Some after add"
        );
    }

    #[test]
    fn tc02b_remove_clears_entry() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/mem2".to_string())
            .expect("add failed");
        store
            .remove(&Scope::Global, "global/mem2")
            .expect("remove failed");
        assert!(
            store.find_by_name("global/mem2").is_none(),
            "TC-02b: find_by_name should return None after remove"
        );
    }

    #[test]
    fn tc02c_search_global_only() {
        let store: &dyn VectorStore = &make_store();
        let proj = Scope::Project("p".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/g1".to_string())
            .expect("add global");
        store
            .add(&proj, &vec_b(), "projects/p/p1".to_string())
            .expect("add project");

        let results = store
            .search(&ScopeFilter::GlobalOnly, &vec_a(), 10)
            .expect("search failed");
        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
        assert!(names.contains(&"global/g1"), "should contain global");
        assert!(
            !names.contains(&"projects/p/p1"),
            "should NOT contain project"
        );
    }

    #[test]
    fn tc02d_search_project_and_global() {
        let store: &dyn VectorStore = &make_store();
        let proj_a = Scope::Project("alpha".to_string());
        let proj_b = Scope::Project("beta".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/g1".to_string())
            .expect("add global");
        store
            .add(&proj_a, &vec_b(), "projects/alpha/a1".to_string())
            .expect("add alpha");
        store
            .add(&proj_b, &vec_c(), "projects/beta/b1".to_string())
            .expect("add beta");

        let results = store
            .search(
                &ScopeFilter::ProjectAndGlobal("alpha".to_string()),
                &vec_a(),
                10,
            )
            .expect("search failed");
        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
        assert!(names.contains(&"global/g1"), "should contain global");
        assert!(names.contains(&"projects/alpha/a1"), "should contain alpha");
        assert!(
            !names.contains(&"projects/beta/b1"),
            "should NOT contain beta"
        );
    }

    #[test]
    fn tc02e_search_all() {
        let store: &dyn VectorStore = &make_store();
        let proj = Scope::Project("foo".to_string());

        store
            .add(&Scope::Global, &vec_a(), "global/x".to_string())
            .expect("add global");
        store
            .add(&proj, &vec_b(), "projects/foo/y".to_string())
            .expect("add project");

        let results = store
            .search(&ScopeFilter::All, &vec_a(), 10)
            .expect("search failed");
        let names: Vec<&str> = results.iter().map(|(_, n, _)| n.as_str()).collect();
        assert!(names.contains(&"global/x"), "all should include global");
        assert!(
            names.contains(&"projects/foo/y"),
            "all should include project"
        );
    }

    #[test]
    fn tc05c_in_memory_store_returns_ok_variants() {
        let store: &dyn VectorStore = &make_store();
        let result = store.add(&Scope::Global, &vec_a(), "global/tc05c".to_string());
        assert!(
            result.is_ok(),
            "TC-05c: add should return Ok, got: {:?}",
            result
        );
        let result = store.search(&ScopeFilter::All, &vec_a(), 5);
        assert!(result.is_ok(), "TC-05c: search should return Ok");
        let result = store.remove(&Scope::Global, "global/tc05c");
        assert!(result.is_ok(), "TC-05c: remove should return Ok");
    }

    #[test]
    fn tc06b_in_memory_store_is_ready_default_true() {
        let store = InMemoryStore::new(4);
        assert!(
            store.is_ready(),
            "TC-06b: InMemoryStore::is_ready() should return true by default"
        );
    }

    #[test]
    fn tc06b_in_memory_store_is_ready_configured_false() {
        let store = InMemoryStore::new(4).with_ready(false);
        assert!(
            !store.is_ready(),
            "TC-06b: InMemoryStore::is_ready() should return false when configured so"
        );
    }

    #[test]
    fn in_memory_store_dimensions() {
        let store = InMemoryStore::new(128);
        assert_eq!(store.dimensions(), 128);
    }

    #[test]
    fn in_memory_store_commit_sha_round_trip() {
        let store: &dyn VectorStore = &InMemoryStore::new(4);
        assert!(store.commit_sha().is_none());
        store.set_commit_sha(Some("abc123"));
        assert_eq!(store.commit_sha(), Some("abc123".to_string()));
        store.set_commit_sha(None);
        assert!(store.commit_sha().is_none());
    }

    #[test]
    fn in_memory_store_save_is_noop() {
        let store: &dyn VectorStore = &make_store();
        let dir = tempfile::tempdir().expect("tempdir");
        store
            .add(&Scope::Global, &vec_a(), "global/save-test".to_string())
            .expect("add");
        let result = store.save(dir.path());
        assert!(result.is_ok(), "save should be a no-op Ok");
    }

    #[test]
    fn in_memory_store_search_results_sorted_by_distance() {
        let store: &dyn VectorStore = &make_store();
        store
            .add(&Scope::Global, &vec_a(), "global/closest".to_string())
            .expect("add a");
        store
            .add(&Scope::Global, &vec_b(), "global/farther".to_string())
            .expect("add b");

        let results = store
            .search(&ScopeFilter::All, &vec_a(), 10)
            .expect("search");
        assert_eq!(results.len(), 2);
        assert!(
            results[0].2 <= results[1].2,
            "results should be sorted by ascending distance"
        );
        assert_eq!(results[0].1, "global/closest");
    }

    #[test]
    fn tc05c_in_memory_store_dimension_mismatch_returns_invalid_input() {
        let store = InMemoryStore::new(4);
        // A fixed-size array suffices; `add` only borrows the slice.
        let wrong_dims = [1.0_f32, 0.0];
        let err = store
            .add(&Scope::Global, &wrong_dims, "global/bad-dims".to_string())
            .unwrap_err();
        assert!(
            matches!(err, MemoryError::InvalidInput { .. }),
            "TC-05c: dimension mismatch should return InvalidInput, got: {:?}",
            err
        );
    }

    #[test]
    fn in_memory_store_search_dimension_mismatch_returns_invalid_input() {
        let store = make_store();
        let err = store
            .search(&ScopeFilter::All, &[1.0, 0.0], 10)
            .unwrap_err();
        assert!(
            matches!(err, MemoryError::InvalidInput { .. }),
            "search with wrong dimensions should return InvalidInput, got: {:?}",
            err
        );
    }

    #[test]
    fn in_memory_store_upsert_overwrites() {
        let store: &dyn VectorStore = &make_store();
        let name = "global/upsert-me".to_string();
        let key1 = store
            .add(&Scope::Global, &vec_a(), name.clone())
            .expect("first add");
        let key2 = store
            .add(&Scope::Global, &vec_b(), name.clone())
            .expect("second add");
        assert_ne!(key1, key2);
        assert_eq!(store.find_by_name(&name), Some(key2));
    }

    #[test]
    fn in_memory_store_upsert_leaves_single_search_hit() {
        let store: &dyn VectorStore = &make_store();
        let name = "global/upsert-search".to_string();
        store
            .add(&Scope::Global, &vec_a(), name.clone())
            .expect("first add");
        let key2 = store
            .add(&Scope::Global, &vec_b(), name.clone())
            .expect("second add");

        let results = store
            .search(&ScopeFilter::All, &vec_b(), 10)
            .expect("search");
        assert_eq!(results.len(), 1, "upsert must not leave a stale entry");
        assert_eq!(results[0].0, key2);
        assert_eq!(results[0].1, name);
        assert_eq!(results[0].2, 0.0, "overwritten vector should be an exact match");
    }

    #[test]
    fn in_memory_store_search_respects_limit() {
        let store: &dyn VectorStore = &make_store();
        for (i, v) in [vec_a(), vec_b(), vec_c()].iter().enumerate() {
            store
                .add(&Scope::Global, v, format!("global/lim{i}"))
                .expect("add");
        }

        let results = store
            .search(&ScopeFilter::All, &vec_a(), 2)
            .expect("search");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].1, "global/lim0");

        let none = store
            .search(&ScopeFilter::All, &vec_a(), 0)
            .expect("search");
        assert!(none.is_empty(), "limit 0 should return no results");
    }

    #[test]
    fn in_memory_store_remove_missing_is_ok() {
        let store: &dyn VectorStore = &make_store();
        assert!(store.remove(&Scope::Global, "global/never-added").is_ok());
        assert!(store.find_by_name("global/never-added").is_none());
    }
}