1use crate::vectordb::{SearchResult, VectorDbClient, VectorDbError, VectorPoint, WriteConsistency};
2use std::collections::HashMap;
3
4#[derive(Default)]
5pub struct MockVectorDbClient {
7 collections: std::sync::RwLock<HashMap<String, MockCollection>>,
8}
9
10#[derive(Default, Clone)]
11struct MockCollection {
12 vector_size: u64,
13 points: HashMap<u64, MockStoredPoint>,
14}
15
/// A stored point: its vector plus the payload fields echoed back in search results.
#[derive(Clone)]
struct MockStoredPoint {
    /// The embedding compared against search queries.
    vector: Vec<f32>,
    /// Owning tenant; matched against the optional tenant filter in search.
    tenant_id: u64,
    /// Opaque hash copied through to search results unchanged.
    context_hash: u64,
    /// Timestamp copied through to search results unchanged.
    timestamp: i64,
    /// Optional key copied through to search results unchanged.
    storage_key: Option<String>,
}
24
25impl MockVectorDbClient {
26 pub fn new() -> Self {
28 Self::default()
29 }
30
31 pub fn point_count(&self, collection: &str) -> Option<usize> {
33 self.collections
34 .read()
35 .ok()?
36 .get(collection)
37 .map(|c| c.points.len())
38 }
39
40 #[cfg(test)]
43 pub fn poison_lock(&self) {
44 use std::thread;
45
46 let collections_ptr = &self.collections as *const _ as usize;
47 let handle = thread::spawn(move || {
48 let collections: &std::sync::RwLock<HashMap<String, MockCollection>> =
50 unsafe { &*(collections_ptr as *const _) };
51 let _guard = collections.write().unwrap();
52 panic!("Intentional panic to poison lock for testing");
53 });
54 let _ = handle.join();
56 }
57}
58
59impl VectorDbClient for MockVectorDbClient {
61 async fn ensure_collection(&self, name: &str, vector_size: u64) -> Result<(), VectorDbError> {
62 let mut collections =
63 self.collections
64 .write()
65 .map_err(|_| VectorDbError::CreateCollectionFailed {
66 collection: name.to_string(),
67 message: "lock poisoned".to_string(),
68 })?;
69
70 collections
71 .entry(name.to_string())
72 .or_insert(MockCollection {
73 vector_size,
74 points: HashMap::new(),
75 });
76
77 Ok(())
78 }
79
80 async fn upsert_points(
81 &self,
82 collection: &str,
83 points: Vec<VectorPoint>,
84 _consistency: WriteConsistency,
85 ) -> Result<(), VectorDbError> {
86 let mut collections =
87 self.collections
88 .write()
89 .map_err(|_| VectorDbError::UpsertFailed {
90 collection: collection.to_string(),
91 message: "lock poisoned".to_string(),
92 })?;
93
94 let coll =
95 collections
96 .get_mut(collection)
97 .ok_or_else(|| VectorDbError::CollectionNotFound {
98 collection: collection.to_string(),
99 })?;
100
101 for point in points {
102 if point.vector.len() as u64 != coll.vector_size {
103 return Err(VectorDbError::InvalidDimension {
104 expected: coll.vector_size as usize,
105 actual: point.vector.len(),
106 });
107 }
108
109 coll.points.insert(
110 point.id,
111 MockStoredPoint {
112 vector: point.vector,
113 tenant_id: point.tenant_id,
114 context_hash: point.context_hash,
115 timestamp: point.timestamp,
116 storage_key: point.storage_key,
117 },
118 );
119 }
120
121 Ok(())
122 }
123
124 async fn search(
125 &self,
126 collection: &str,
127 query: Vec<f32>,
128 limit: u64,
129 tenant_filter: Option<u64>,
130 ) -> Result<Vec<SearchResult>, VectorDbError> {
131 let collections = self
132 .collections
133 .read()
134 .map_err(|_| VectorDbError::SearchFailed {
135 collection: collection.to_string(),
136 message: "lock poisoned".to_string(),
137 })?;
138
139 let coll =
140 collections
141 .get(collection)
142 .ok_or_else(|| VectorDbError::CollectionNotFound {
143 collection: collection.to_string(),
144 })?;
145
146 let mut results: Vec<SearchResult> = coll
147 .points
148 .iter()
149 .filter(|(_, p)| tenant_filter.is_none() || tenant_filter == Some(p.tenant_id))
150 .map(|(&id, p)| {
151 let score = cosine_similarity(&query, &p.vector);
152 SearchResult {
153 id,
154 score,
155 tenant_id: p.tenant_id,
156 context_hash: p.context_hash,
157 timestamp: p.timestamp,
158 storage_key: p.storage_key.clone(),
159 }
160 })
161 .collect();
162
163 results.sort_by(|a, b| {
164 b.score
165 .partial_cmp(&a.score)
166 .unwrap_or(std::cmp::Ordering::Equal)
167 });
168
169 results.truncate(limit as usize);
170 Ok(results)
171 }
172
173 async fn delete_points(&self, collection: &str, ids: Vec<u64>) -> Result<(), VectorDbError> {
174 let mut collections =
175 self.collections
176 .write()
177 .map_err(|_| VectorDbError::DeleteFailed {
178 collection: collection.to_string(),
179 message: "lock poisoned".to_string(),
180 })?;
181
182 let coll =
183 collections
184 .get_mut(collection)
185 .ok_or_else(|| VectorDbError::CollectionNotFound {
186 collection: collection.to_string(),
187 })?;
188
189 for id in ids {
190 coll.points.remove(&id);
191 }
192
193 Ok(())
194 }
195}
196
/// Cosine similarity of `a` and `b`: `a·b / (|a| * |b|)`.
///
/// Returns `0.0` when the slices differ in length, are empty, or either one has zero
/// magnitude, since the ratio is undefined in those cases.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Euclidean length of a vector.
    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|&x| x * x).sum::<f32>().sqrt() };
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(&x, &y)| x * y).sum();
    dot / (mag_a * mag_b)
}