1use std::collections::HashMap;
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use tokio::sync::RwLock;
28use uuid::Uuid;
29
30use crate::error::{VectorDBError, VectorDBResult};
31use crate::models::{SearchResult, VectorPoint};
32use crate::vector_db_trait::VectorDB;
33
34#[derive(Debug)]
35struct Collection {
36 dimension: usize,
37 points: Vec<VectorPoint>,
38}
39
40#[derive(Debug, Clone, Default)]
46pub struct BruteForceVectorDB {
47 collections: Arc<RwLock<HashMap<String, Collection>>>,
48}
49
50impl BruteForceVectorDB {
51 pub fn new() -> Self {
53 Self {
54 collections: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 fn key(data_type: &str, field_name: &str) -> String {
60 format!("{data_type}_{field_name}")
61 }
62}
63
/// Cosine similarity of two vectors: `dot(a, b) / (|a| * |b|)`.
///
/// If the norm product is exactly zero (for finite inputs, when either vector
/// is all zeros), the result is `0.0` instead of `NaN`. Tiny but non-zero
/// vectors still get their true cosine; the denominator is not clamped.
///
/// Callers must pass slices of equal length. This is only checked in debug
/// builds via `debug_assert_eq!`; in release builds the trailing components of
/// the longer slice are ignored.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity inputs must match");
    // One pass accumulating the dot product and both squared norms, in the same
    // left-to-right order as an index loop, without per-element bounds checks.
    let (dot, norm_a_sq, norm_b_sq) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Take the square roots separately so `na * nb` cannot overflow before the sqrt.
    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        0.0
    } else {
        dot / denom
    }
}
81
82#[async_trait]
83impl VectorDB for BruteForceVectorDB {
84 async fn create_collection(
85 &self,
86 data_type: &str,
87 field_name: &str,
88 dimension: usize,
89 ) -> VectorDBResult<()> {
90 let key = Self::key(data_type, field_name);
91 let mut g = self.collections.write().await;
92 if g.contains_key(&key) {
93 return Err(VectorDBError::CollectionExists(key));
94 }
95 g.insert(
96 key,
97 Collection {
98 dimension,
99 points: Vec::new(),
100 },
101 );
102 Ok(())
103 }
104
105 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
106 let g = self.collections.read().await;
107 Ok(g.contains_key(&Self::key(data_type, field_name)))
108 }
109
110 async fn index_points(
111 &self,
112 data_type: &str,
113 field_name: &str,
114 points: &[VectorPoint],
115 ) -> VectorDBResult<()> {
116 if points.is_empty() {
117 return Ok(());
118 }
119 let key = Self::key(data_type, field_name);
120 let mut g = self.collections.write().await;
121 let coll = g
122 .get_mut(&key)
123 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
124
125 for p in points {
127 if p.vector.len() != coll.dimension {
128 return Err(VectorDBError::DimensionMismatch {
129 collection: key.clone(),
130 expected: coll.dimension,
131 actual: p.vector.len(),
132 });
133 }
134 }
135
136 for p in points {
140 if let Some(existing) = coll.points.iter_mut().find(|x| x.id == p.id) {
141 let mut merged = p.clone();
142 merged.merge_dataset_membership(existing);
143 *existing = merged;
144 } else {
145 coll.points.push(p.clone());
146 }
147 }
148 Ok(())
149 }
150
151 async fn search_similar(
152 &self,
153 data_type: &str,
154 field_name: &str,
155 query_vector: &[f32],
156 top_k: usize,
157 ) -> VectorDBResult<Vec<SearchResult>> {
158 let key = Self::key(data_type, field_name);
159 let mut scored: Vec<(Uuid, f32, HashMap<String, serde_json::Value>)> = {
163 let g = self.collections.read().await;
164 let coll = g
165 .get(&key)
166 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
167 if query_vector.len() != coll.dimension {
168 return Err(VectorDBError::DimensionMismatch {
169 collection: key.clone(),
170 expected: coll.dimension,
171 actual: query_vector.len(),
172 });
173 }
174 coll.points
175 .iter()
176 .map(|p| {
177 (
178 p.id,
179 cosine_similarity(&p.vector, query_vector),
180 p.metadata.clone(),
181 )
182 })
183 .collect()
184 };
185
186 scored.sort_by(|a, b| b.1.total_cmp(&a.1));
188 scored.truncate(top_k);
189 Ok(scored
190 .into_iter()
191 .map(|(id, score, metadata)| SearchResult {
192 id,
193 score,
194 metadata,
195 })
196 .collect())
197 }
198
199 async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
200 let mut g = self.collections.write().await;
201 g.remove(&Self::key(data_type, field_name));
202 Ok(())
203 }
204
205 async fn delete_points(
206 &self,
207 data_type: &str,
208 field_name: &str,
209 point_ids: &[Uuid],
210 ) -> VectorDBResult<()> {
211 let key = Self::key(data_type, field_name);
212 let mut g = self.collections.write().await;
213 let coll = g
214 .get_mut(&key)
215 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
216 coll.points.retain(|p| !point_ids.contains(&p.id));
217 Ok(())
218 }
219
220 async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
221 let key = Self::key(data_type, field_name);
222 let g = self.collections.read().await;
223 let coll = g
224 .get(&key)
225 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
226 Ok(coll.points.len())
227 }
228
229 async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
230 let g = self.collections.read().await;
231 let mut out: Vec<(String, String)> = g
232 .keys()
233 .filter_map(|k| {
234 k.split_once('_')
235 .map(|(a, b)| (a.to_string(), b.to_string()))
236 })
237 .collect();
238 out.sort();
239 Ok(out)
240 }
241}
242
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable"
)]
mod tests {
    use super::*;
    use std::collections::HashMap as Hm;
    use uuid::Uuid;

    /// Builds a point with a deterministic id (`Uuid::from_u128(id_seed)`) and
    /// empty metadata.
    fn point(id_seed: u128, v: Vec<f32>) -> VectorPoint {
        VectorPoint {
            id: Uuid::from_u128(id_seed),
            vector: v,
            metadata: Hm::new(),
        }
    }

    #[tokio::test]
    async fn create_then_has_collection() {
        let db = BruteForceVectorDB::new();
        assert!(!db.has_collection("T", "f").await.unwrap());
        db.create_collection("T", "f", 4).await.unwrap();
        assert!(db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn create_duplicate_returns_exists() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 4).await.unwrap();
        let err = db.create_collection("T", "f", 4).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionExists(ref k) if k == "T_f"),
            "expected CollectionExists, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let p = point(1, vec![1.0, 2.0]);
        let err = db.index_points("T", "f", &[p]).await.unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch 3 vs 2, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_rejects_whole_batch_on_dimension_mismatch() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let good = point(1, vec![1.0, 0.0]);
        let bad = point(2, vec![1.0, 0.0, 0.0]);
        assert!(db.index_points("T", "f", &[good, bad]).await.is_err());
        assert_eq!(
            db.collection_size("T", "f").await.unwrap(),
            0,
            "a rejected batch must not partially apply"
        );
    }

    #[tokio::test]
    async fn index_into_missing_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db
            .index_points("T", "f", &[point(1, vec![1.0])])
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(ref k) if k == "T_f"),
            "expected CollectionNotFound, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_empty_batch_is_noop() {
        let db = BruteForceVectorDB::new();
        db.index_points("T", "f", &[]).await.unwrap();
        assert!(!db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn index_replaces_by_id() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let p_v1 = point(1, vec![1.0, 0.0]);
        let p_v2 = point(1, vec![0.0, 1.0]);
        db.index_points("T", "f", &[p_v1]).await.unwrap();
        db.index_points("T", "f", &[p_v2]).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);

        let results = db.search_similar("T", "f", &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "upserted vector should score 1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_ranks_descending() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let a = point(1, vec![1.0, 0.0, 0.0]);
        let b = point(2, vec![0.0, 1.0, 0.0]);
        let c = point(3, vec![0.0, 0.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();

        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 3)
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, Uuid::from_u128(1), "A should rank first");
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "self-similarity should be ~1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_truncates_to_top_k() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        db.index_points(
            "T",
            "f",
            &[
                point(1, vec![1.0, 0.0]),
                point(2, vec![0.0, 1.0]),
                point(3, vec![1.0, 1.0]),
            ],
        )
        .await
        .unwrap();

        let results = db.search_similar("T", "f", &[1.0, 0.0], 2).await.unwrap();
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].id, Uuid::from_u128(1));
        assert_eq!(results[1].id, Uuid::from_u128(3));

        let none = db.search_similar("T", "f", &[1.0, 0.0], 0).await.unwrap();
        assert!(none.is_empty());
    }

    #[tokio::test]
    async fn search_empty_collection_returns_empty() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 5)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_query_dim_mismatch_returns_error() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 3).await.unwrap();
        let err = db
            .search_similar("T", "f", &[1.0, 0.0], 5)
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}",
        );
    }

    #[tokio::test]
    async fn search_missing_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db.search_similar("T", "f", &[1.0], 1).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }

    #[tokio::test]
    async fn delete_points_removes_matching_ids() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        let a = point(1, vec![1.0, 0.0]);
        let b = point(2, vec![0.0, 1.0]);
        let c = point(3, vec![1.0, 1.0]);
        db.index_points("T", "f", &[a, b, c]).await.unwrap();
        db.delete_points("T", "f", &[Uuid::from_u128(1), Uuid::from_u128(3)])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn delete_points_missing_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db
            .delete_points("T", "f", &[Uuid::from_u128(1)])
            .await
            .unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }

    #[tokio::test]
    async fn delete_collection_is_idempotent() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        db.delete_collection("T", "f").await.unwrap();
        // A second delete of the now-missing collection must still succeed.
        db.delete_collection("T", "f").await.unwrap();
        assert!(!db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn list_collections_returns_pairs() {
        let db = BruteForceVectorDB::new();
        let empty = db.list_collections().await.unwrap();
        assert!(empty.is_empty());

        db.create_collection("DocumentChunk", "text", 3)
            .await
            .unwrap();
        db.create_collection("Entity", "name", 3).await.unwrap();

        let pairs = db.list_collections().await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert!(pairs.contains(&("DocumentChunk".to_string(), "text".to_string())));
        assert!(pairs.contains(&("Entity".to_string(), "name".to_string())));
    }

    #[tokio::test]
    async fn list_collections_keeps_underscores_in_field_name() {
        let db = BruteForceVectorDB::new();
        db.create_collection("Entity", "display_name", 3)
            .await
            .unwrap();
        let pairs = db.list_collections().await.unwrap();
        assert_eq!(
            pairs,
            vec![("Entity".to_string(), "display_name".to_string())]
        );
    }

    #[tokio::test]
    async fn collection_size_after_upsert() {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", 2).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 0);
        db.index_points(
            "T",
            "f",
            &[point(1, vec![1.0, 0.0]), point(2, vec![0.0, 1.0])],
        )
        .await
        .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
        // Re-indexing an existing id replaces it rather than adding a new point.
        db.index_points("T", "f", &[point(1, vec![0.5, 0.5])])
            .await
            .unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn collection_size_unknown_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db.collection_size("T", "f").await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }
}