1use std::collections::HashMap;
24use std::sync::Arc;
25
26use async_trait::async_trait;
27use tokio::sync::RwLock;
28use uuid::Uuid;
29
30use crate::error::{VectorDBError, VectorDBResult};
31use crate::models::{SearchResult, VectorPoint};
32use crate::vector_db_trait::VectorDB;
33
34#[derive(Debug)]
35struct Collection {
36 dimension: usize,
37 points: Vec<VectorPoint>,
38}
39
40#[derive(Debug, Clone, Default)]
46pub struct BruteForceVectorDB {
47 collections: Arc<RwLock<HashMap<String, Collection>>>,
48}
49
50impl BruteForceVectorDB {
51 pub fn new() -> Self {
53 Self {
54 collections: Arc::new(RwLock::new(HashMap::new())),
55 }
56 }
57
58 fn key(data_type: &str, field_name: &str) -> String {
60 format!("{data_type}_{field_name}")
61 }
62}
63
/// Returns the cosine similarity of `a` and `b`.
///
/// The result is `0.0` when either vector has zero magnitude (including two
/// empty slices), because the angle is undefined there.
///
/// Sums are accumulated in `f64` and the denominator is used as-is rather than
/// being clamped to `f32::EPSILON`: clamping silently shrank the score of
/// vectors whose norm product is below epsilon (e.g. `[1e-5, 0.0]` compared
/// with itself scored ~8e-4 instead of 1.0).
///
/// # Panics
///
/// In debug builds, panics if the slices differ in length. In release builds
/// the surplus tail of the longer slice is ignored, so callers are expected to
/// validate dimensions first.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "cosine_similarity inputs must match");

    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f64, 0.0f64, 0.0f64),
        |(dot, norm_a_sq, norm_b_sq), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, norm_a_sq + x * x, norm_b_sq + y * y)
        },
    );

    let denom = norm_a_sq.sqrt() * norm_b_sq.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    // Narrowing is intentional: the public score type is `f32`.
    (dot / denom) as f32
}
81
82#[async_trait]
83impl VectorDB for BruteForceVectorDB {
84 async fn create_collection(
85 &self,
86 data_type: &str,
87 field_name: &str,
88 dimension: usize,
89 ) -> VectorDBResult<()> {
90 let key = Self::key(data_type, field_name);
91 let mut g = self.collections.write().await;
92 if g.contains_key(&key) {
93 return Err(VectorDBError::CollectionExists(key));
94 }
95 g.insert(
96 key,
97 Collection {
98 dimension,
99 points: Vec::new(),
100 },
101 );
102 Ok(())
103 }
104
105 async fn has_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<bool> {
106 let g = self.collections.read().await;
107 Ok(g.contains_key(&Self::key(data_type, field_name)))
108 }
109
110 async fn index_points(
111 &self,
112 data_type: &str,
113 field_name: &str,
114 points: &[VectorPoint],
115 ) -> VectorDBResult<()> {
116 if points.is_empty() {
117 return Ok(());
118 }
119 let key = Self::key(data_type, field_name);
120 let mut g = self.collections.write().await;
121 let coll = g
122 .get_mut(&key)
123 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
124
125 for p in points {
127 if p.vector.len() != coll.dimension {
128 return Err(VectorDBError::DimensionMismatch {
129 collection: key.clone(),
130 expected: coll.dimension,
131 actual: p.vector.len(),
132 });
133 }
134 }
135
136 for p in points {
138 if let Some(existing) = coll.points.iter_mut().find(|x| x.id == p.id) {
139 *existing = p.clone();
140 } else {
141 coll.points.push(p.clone());
142 }
143 }
144 Ok(())
145 }
146
147 async fn search_similar(
148 &self,
149 data_type: &str,
150 field_name: &str,
151 query_vector: &[f32],
152 top_k: usize,
153 ) -> VectorDBResult<Vec<SearchResult>> {
154 let key = Self::key(data_type, field_name);
155 let mut scored: Vec<(Uuid, f32, HashMap<String, serde_json::Value>)> = {
159 let g = self.collections.read().await;
160 let coll = g
161 .get(&key)
162 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
163 if query_vector.len() != coll.dimension {
164 return Err(VectorDBError::DimensionMismatch {
165 collection: key.clone(),
166 expected: coll.dimension,
167 actual: query_vector.len(),
168 });
169 }
170 coll.points
171 .iter()
172 .map(|p| {
173 (
174 p.id,
175 cosine_similarity(&p.vector, query_vector),
176 p.metadata.clone(),
177 )
178 })
179 .collect()
180 };
181
182 scored.sort_by(|a, b| b.1.total_cmp(&a.1));
184 scored.truncate(top_k);
185 Ok(scored
186 .into_iter()
187 .map(|(id, score, metadata)| SearchResult {
188 id,
189 score,
190 metadata,
191 })
192 .collect())
193 }
194
195 async fn delete_collection(&self, data_type: &str, field_name: &str) -> VectorDBResult<()> {
196 let mut g = self.collections.write().await;
197 g.remove(&Self::key(data_type, field_name));
198 Ok(())
199 }
200
201 async fn delete_points(
202 &self,
203 data_type: &str,
204 field_name: &str,
205 point_ids: &[Uuid],
206 ) -> VectorDBResult<()> {
207 let key = Self::key(data_type, field_name);
208 let mut g = self.collections.write().await;
209 let coll = g
210 .get_mut(&key)
211 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
212 coll.points.retain(|p| !point_ids.contains(&p.id));
213 Ok(())
214 }
215
216 async fn collection_size(&self, data_type: &str, field_name: &str) -> VectorDBResult<usize> {
217 let key = Self::key(data_type, field_name);
218 let g = self.collections.read().await;
219 let coll = g
220 .get(&key)
221 .ok_or_else(|| VectorDBError::CollectionNotFound(key.clone()))?;
222 Ok(coll.points.len())
223 }
224
225 async fn list_collections(&self) -> VectorDBResult<Vec<(String, String)>> {
226 let g = self.collections.read().await;
227 let mut out: Vec<(String, String)> = g
228 .keys()
229 .filter_map(|k| {
230 k.split_once('_')
231 .map(|(a, b)| (a.to_string(), b.to_string()))
232 })
233 .collect();
234 out.sort();
235 Ok(out)
236 }
237}
238
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable"
)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use uuid::Uuid;

    /// Deterministic id derived from a small integer seed.
    fn id(seed: u128) -> Uuid {
        Uuid::from_u128(seed)
    }

    /// Builds a point with the seeded id, the given vector and no metadata.
    fn point(id_seed: u128, v: Vec<f32>) -> VectorPoint {
        VectorPoint {
            id: id(id_seed),
            vector: v,
            metadata: HashMap::new(),
        }
    }

    /// Returns a database that already holds the empty collection `("T", "f")`
    /// with the requested dimension.
    async fn db_with_collection(dimension: usize) -> BruteForceVectorDB {
        let db = BruteForceVectorDB::new();
        db.create_collection("T", "f", dimension).await.unwrap();
        db
    }

    #[tokio::test]
    async fn create_then_has_collection() {
        let db = BruteForceVectorDB::new();
        let present_before = db.has_collection("T", "f").await.unwrap();
        assert!(!present_before);

        db.create_collection("T", "f", 4).await.unwrap();
        let present_after = db.has_collection("T", "f").await.unwrap();
        assert!(present_after);
    }

    #[tokio::test]
    async fn create_duplicate_returns_exists() {
        let db = db_with_collection(4).await;
        let err = db.create_collection("T", "f", 4).await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionExists(ref k) if k == "T_f"),
            "expected CollectionExists, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_dim_mismatch_returns_error() {
        let db = db_with_collection(3).await;
        // Two components against a three-dimensional collection.
        let too_short = point(1, vec![1.0, 2.0]);
        let err = db.index_points("T", "f", &[too_short]).await.unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch 3 vs 2, got {err:?}",
        );
    }

    #[tokio::test]
    async fn index_replaces_by_id() {
        let db = db_with_collection(2).await;
        let first_version = point(1, vec![1.0, 0.0]);
        let second_version = point(1, vec![0.0, 1.0]);

        db.index_points("T", "f", &[first_version]).await.unwrap();
        db.index_points("T", "f", &[second_version]).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);

        // Querying with the second vector must hit the replaced point exactly.
        let results = db.search_similar("T", "f", &[0.0, 1.0], 1).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "upserted vector should score 1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_ranks_descending() {
        let db = db_with_collection(3).await;
        let batch = [
            point(1, vec![1.0, 0.0, 0.0]),
            point(2, vec![0.0, 1.0, 0.0]),
            point(3, vec![0.0, 0.0, 1.0]),
        ];
        db.index_points("T", "f", &batch).await.unwrap();

        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 3)
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].id, Uuid::from_u128(1), "A should rank first");
        assert!(results[0].score >= results[1].score);
        assert!(results[1].score >= results[2].score);
        assert!(
            (results[0].score - 1.0).abs() < 1e-5,
            "self-similarity should be ~1.0, got {}",
            results[0].score
        );
    }

    #[tokio::test]
    async fn search_empty_collection_returns_empty() {
        let db = db_with_collection(3).await;
        let results = db
            .search_similar("T", "f", &[1.0, 0.0, 0.0], 5)
            .await
            .unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn search_query_dim_mismatch_returns_error() {
        let db = db_with_collection(3).await;
        let err = db
            .search_similar("T", "f", &[1.0, 0.0], 5)
            .await
            .unwrap_err();
        assert!(
            matches!(
                err,
                VectorDBError::DimensionMismatch {
                    expected: 3,
                    actual: 2,
                    ..
                }
            ),
            "expected DimensionMismatch, got {err:?}",
        );
    }

    #[tokio::test]
    async fn delete_points_removes_matching_ids() {
        let db = db_with_collection(2).await;
        let batch = [
            point(1, vec![1.0, 0.0]),
            point(2, vec![0.0, 1.0]),
            point(3, vec![1.0, 1.0]),
        ];
        db.index_points("T", "f", &batch).await.unwrap();

        let doomed = [id(1), id(3)];
        db.delete_points("T", "f", &doomed).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 1);
    }

    #[tokio::test]
    async fn delete_collection_is_idempotent() {
        let db = db_with_collection(2).await;
        // The second removal targets a collection that is already gone.
        for _ in 0..2 {
            db.delete_collection("T", "f").await.unwrap();
        }
        assert!(!db.has_collection("T", "f").await.unwrap());
    }

    #[tokio::test]
    async fn list_collections_returns_pairs() {
        let db = BruteForceVectorDB::new();
        let empty = db.list_collections().await.unwrap();
        assert!(empty.is_empty());

        db.create_collection("DocumentChunk", "text", 3)
            .await
            .unwrap();
        db.create_collection("Entity", "name", 3).await.unwrap();

        let pairs = db.list_collections().await.unwrap();
        assert_eq!(pairs.len(), 2);
        assert!(pairs.contains(&("DocumentChunk".to_string(), "text".to_string())));
        assert!(pairs.contains(&("Entity".to_string(), "name".to_string())));
    }

    #[tokio::test]
    async fn collection_size_after_upsert() {
        let db = db_with_collection(2).await;
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 0);

        let initial = [point(1, vec![1.0, 0.0]), point(2, vec![0.0, 1.0])];
        db.index_points("T", "f", &initial).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);

        // Re-indexing id 1 replaces the stored point instead of adding one.
        let replacement = [point(1, vec![0.5, 0.5])];
        db.index_points("T", "f", &replacement).await.unwrap();
        assert_eq!(db.collection_size("T", "f").await.unwrap(), 2);
    }

    #[tokio::test]
    async fn collection_size_unknown_collection_errors() {
        let db = BruteForceVectorDB::new();
        let err = db.collection_size("T", "f").await.unwrap_err();
        assert!(
            matches!(err, VectorDBError::CollectionNotFound(_)),
            "expected CollectionNotFound, got {err:?}",
        );
    }
}