1#![allow(deprecated)]
17
18use crate::graph::GraphTree;
19use crate::{SearchLatency, VectorEntry, EMBEDDING_DIMENSION};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::time::Instant;
23
/// Suggested default for the `limit` argument of the search methods (maximum number of hits).
///
/// NOTE(review): not referenced anywhere in the visible part of this file — confirm the callers.
pub const DEFAULT_SEARCH_LIMIT: usize = 10;

/// Suggested default for the `threshold` argument of the search methods.
///
/// The searches in this file keep a candidate when `similarity >= threshold`, so `0.0` keeps
/// every candidate whose cosine similarity is non-negative.
pub const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.0;
29
30#[deprecated(
37 since = "0.1.0",
38 note = "Use SemanticSearch over storage-backed VectorEntry slices for runtime retrieval. VectorDatabase is an internal/test abstraction."
39)]
40#[derive(Debug, Default)]
41pub struct VectorDatabase {
42 vectors: HashMap<i64, VectorEntry>,
44
45 tree: GraphTree,
47
48 dimension: usize,
50
51 namespace_index: HashMap<i64, Vec<i64>>,
53
54 category_index: HashMap<String, Vec<i64>>,
56}
57
/// One hit produced by the `VectorDatabase` search methods.
#[deprecated(
    since = "0.1.0",
    note = "Use search::SearchResult instead. This type belongs to the deprecated VectorDatabase."
)]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorSearchResult {
    /// Id of the matching stored entry.
    pub id: i64,

    /// Cosine similarity between the query and the stored embedding.
    pub similarity: f32,

    /// Value returned by `GraphTree::calculate_boosted_score(id, similarity)`; search results
    /// are ordered by this field, highest first.
    pub boosted_score: f32,
}
77
78impl VectorDatabase {
79 pub fn new() -> Self {
81 Self {
82 vectors: HashMap::new(),
83 tree: GraphTree::new(),
84 dimension: EMBEDDING_DIMENSION,
85 namespace_index: HashMap::new(),
86 category_index: HashMap::new(),
87 }
88 }
89
90 pub fn with_dimension(dimension: usize) -> Self {
92 Self {
93 vectors: HashMap::new(),
94 tree: GraphTree::new(),
95 dimension,
96 namespace_index: HashMap::new(),
97 category_index: HashMap::new(),
98 }
99 }
100
101 pub async fn in_memory() -> crate::Result<Self> {
103 Ok(Self::new())
104 }
105
106 pub fn insert_with_priority(
108 &mut self,
109 entry: VectorEntry,
110 priority: Option<u8>,
111 ) -> crate::Result<()> {
112 if entry.embedding.len() != self.dimension {
113 return Err(nexus_core::NexusError::InvalidInput(format!(
114 "Vector dimension mismatch: expected {}, got {}",
115 self.dimension,
116 entry.embedding.len()
117 )));
118 }
119
120 let id = entry.id;
121 let namespace_id = entry.namespace_id;
122 let category = entry.category.clone();
123 let lane_type = entry.memory_lane_type.clone();
124
125 self.tree
127 .add_memory(id, &category, lane_type.as_deref(), priority);
128
129 self.namespace_index
131 .entry(namespace_id)
132 .or_default()
133 .push(id);
134
135 self.category_index.entry(category).or_default().push(id);
136
137 self.vectors.insert(id, entry);
139 Ok(())
140 }
141
142 pub fn insert(&mut self, entry: VectorEntry) -> crate::Result<()> {
144 self.insert_with_priority(entry, None)
145 }
146
147 pub fn get(&self, id: i64) -> Option<&VectorEntry> {
149 self.vectors.get(&id)
150 }
151
152 pub fn remove(&mut self, id: i64) -> Option<VectorEntry> {
154 if let Some(entry) = self.vectors.remove(&id) {
155 self.tree.remove_memory(id);
157
158 if let Some(ns_vec) = self.namespace_index.get_mut(&entry.namespace_id) {
160 ns_vec.retain(|&i| i != id);
161 }
162
163 if let Some(cat_vec) = self.category_index.get_mut(&entry.category) {
164 cat_vec.retain(|&i| i != id);
165 }
166
167 Some(entry)
168 } else {
169 None
170 }
171 }
172
173 pub fn ids(&self) -> Vec<i64> {
175 self.vectors.keys().copied().collect()
176 }
177
178 pub fn len(&self) -> usize {
180 self.vectors.len()
181 }
182
183 pub fn is_empty(&self) -> bool {
185 self.vectors.is_empty()
186 }
187
188 pub fn by_namespace(&self, namespace_id: i64) -> Vec<&VectorEntry> {
190 self.vectors
191 .values()
192 .filter(|v| v.namespace_id == namespace_id)
193 .collect()
194 }
195
196 pub fn by_category(&self, category: &str) -> Vec<&VectorEntry> {
198 self.vectors
199 .values()
200 .filter(|v| v.category == category)
201 .collect()
202 }
203
204 pub fn dimension(&self) -> usize {
206 self.dimension
207 }
208
209 pub fn tree(&self) -> &GraphTree {
211 &self.tree
212 }
213
214 pub fn tree_mut(&mut self) -> &mut GraphTree {
216 &mut self.tree
217 }
218
219 pub fn search(
224 &self,
225 query: &[f32],
226 namespace_id: i64,
227 limit: usize,
228 threshold: f32,
229 ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
230 let start = Instant::now();
231
232 if query.len() != self.dimension {
234 return Err(nexus_core::NexusError::InvalidInput(format!(
235 "Query dimension mismatch: expected {}, got {}",
236 self.dimension,
237 query.len()
238 )));
239 }
240
241 let candidate_ids = self
243 .namespace_index
244 .get(&namespace_id)
245 .map(|v| v.as_slice())
246 .unwrap_or(&[]);
247
248 let mut results: Vec<VectorSearchResult> = candidate_ids
250 .iter()
251 .filter_map(|&id| {
252 let entry = self.vectors.get(&id)?;
253 let similarity = cosine_similarity(query, &entry.embedding);
254
255 if similarity >= threshold {
256 let boosted_score = self.tree.calculate_boosted_score(id, similarity);
257 Some(VectorSearchResult {
258 id,
259 similarity,
260 boosted_score,
261 })
262 } else {
263 None
264 }
265 })
266 .collect();
267
268 results.sort_by(|a, b| {
270 b.boosted_score
271 .partial_cmp(&a.boosted_score)
272 .unwrap_or(std::cmp::Ordering::Equal)
273 });
274
275 results.truncate(limit);
277
278 let total_time = start.elapsed();
279
280 let latency = SearchLatency {
281 total_ms: total_time.as_millis() as u64,
282 vector_comparison_ms: total_time.as_millis() as u64,
283 graph_traversal_ms: None,
284 };
285
286 Ok((results, latency))
287 }
288
289 pub fn search_by_category(
291 &self,
292 query: &[f32],
293 namespace_id: i64,
294 category: &str,
295 limit: usize,
296 threshold: f32,
297 ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
298 let start = Instant::now();
299
300 if query.len() != self.dimension {
302 return Err(nexus_core::NexusError::InvalidInput(format!(
303 "Query dimension mismatch: expected {}, got {}",
304 self.dimension,
305 query.len()
306 )));
307 }
308
309 let category_ids: std::collections::HashSet<i64> = self
311 .category_index
312 .get(category)
313 .map(|v| v.iter().copied().collect())
314 .unwrap_or_default();
315
316 let namespace_ids: std::collections::HashSet<i64> = self
317 .namespace_index
318 .get(&namespace_id)
319 .map(|v| v.iter().copied().collect())
320 .unwrap_or_default();
321
322 let candidate_ids: Vec<i64> = category_ids.intersection(&namespace_ids).copied().collect();
324
325 let mut results: Vec<VectorSearchResult> = candidate_ids
327 .iter()
328 .filter_map(|&id| {
329 let entry = self.vectors.get(&id)?;
330 let similarity = cosine_similarity(query, &entry.embedding);
331
332 if similarity >= threshold {
333 let boosted_score = self.tree.calculate_boosted_score(id, similarity);
334 Some(VectorSearchResult {
335 id,
336 similarity,
337 boosted_score,
338 })
339 } else {
340 None
341 }
342 })
343 .collect();
344
345 results.sort_by(|a, b| {
347 b.boosted_score
348 .partial_cmp(&a.boosted_score)
349 .unwrap_or(std::cmp::Ordering::Equal)
350 });
351
352 results.truncate(limit);
354
355 let total_time = start.elapsed();
356
357 let latency = SearchLatency {
358 total_ms: total_time.as_millis() as u64,
359 vector_comparison_ms: total_time.as_millis() as u64,
360 graph_traversal_ms: None,
361 };
362
363 Ok((results, latency))
364 }
365
366 pub fn insert_batch(&mut self, entries: Vec<VectorEntry>) -> crate::Result<usize> {
370 let mut success_count = 0;
371 for entry in entries {
372 match self.insert(entry) {
373 Ok(()) => success_count += 1,
374 Err(_) => continue, }
376 }
377 Ok(success_count)
378 }
379
380 pub fn insert_batch_with_priorities(
382 &mut self,
383 entries: Vec<(VectorEntry, Option<u8>)>,
384 ) -> crate::Result<usize> {
385 let mut success_count = 0;
386 for (entry, priority) in entries {
387 match self.insert_with_priority(entry, priority) {
388 Ok(()) => success_count += 1,
389 Err(_) => continue,
390 }
391 }
392 Ok(success_count)
393 }
394
395 pub fn remove_batch(&mut self, ids: &[i64]) -> Vec<Option<VectorEntry>> {
397 ids.iter().map(|&id| self.remove(id)).collect()
398 }
399
400 pub fn search_batch(
402 &self,
403 queries: &[Vec<f32>],
404 namespace_id: i64,
405 limit: usize,
406 threshold: f32,
407 ) -> crate::Result<Vec<(Vec<VectorSearchResult>, SearchLatency)>> {
408 let mut results = Vec::with_capacity(queries.len());
409 for query in queries {
410 results.push(self.search(query, namespace_id, limit, threshold)?);
411 }
412 Ok(results)
413 }
414
415 pub fn find_similar(
417 &self,
418 memory_id: i64,
419 limit: usize,
420 threshold: f32,
421 ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
422 let start = Instant::now();
423
424 let entry = self
425 .vectors
426 .get(&memory_id)
427 .ok_or(nexus_core::NexusError::MemoryNotFound(memory_id))?;
428
429 let query = entry.embedding.clone();
430 let namespace_id = entry.namespace_id;
431
432 let (mut results, latency) = self.search(&query, namespace_id, limit + 1, threshold)?;
433
434 results.retain(|r| r.id != memory_id);
436 results.truncate(limit);
437
438 let total_time = start.elapsed();
439 let adjusted_latency = SearchLatency {
440 total_ms: total_time.as_millis() as u64,
441 vector_comparison_ms: latency.vector_comparison_ms,
442 graph_traversal_ms: latency.graph_traversal_ms,
443 };
444
445 Ok((results, adjusted_latency))
446 }
447
448 pub fn stats(&self) -> VectorDatabaseStats {
450 let mut category_counts = HashMap::new();
451 let mut namespace_counts = HashMap::new();
452
453 for entry in self.vectors.values() {
454 *category_counts.entry(entry.category.clone()).or_insert(0) += 1;
455 *namespace_counts.entry(entry.namespace_id).or_insert(0) += 1;
456 }
457
458 VectorDatabaseStats {
459 total_vectors: self.vectors.len(),
460 dimension: self.dimension,
461 category_counts,
462 namespace_counts,
463 tree_stats: self.tree.stats(),
464 }
465 }
466
467 pub fn clear(&mut self) {
469 self.vectors.clear();
470 self.namespace_index.clear();
471 self.category_index.clear();
472 self.tree = GraphTree::new();
473 }
474
475 pub fn contains(&self, id: i64) -> bool {
477 self.vectors.contains_key(&id)
478 }
479
480 pub fn all_vectors(&self) -> Vec<&VectorEntry> {
482 self.vectors.values().collect()
483 }
484
485 pub fn update_embedding(&mut self, id: i64, new_embedding: Vec<f32>) -> crate::Result<()> {
487 if new_embedding.len() != self.dimension {
488 return Err(nexus_core::NexusError::InvalidInput(format!(
489 "Vector dimension mismatch: expected {}, got {}",
490 self.dimension,
491 new_embedding.len()
492 )));
493 }
494
495 let entry = self
496 .vectors
497 .get_mut(&id)
498 .ok_or(nexus_core::NexusError::MemoryNotFound(id))?;
499
500 entry.embedding = new_embedding;
501 entry.created_at = chrono::Utc::now();
502 Ok(())
503 }
504}
505
/// Snapshot returned by `VectorDatabase::stats`.
#[deprecated(since = "0.1.0", note = "Belongs to the deprecated VectorDatabase.")]
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorDatabaseStats {
    /// Number of stored entries.
    pub total_vectors: usize,
    /// Vector length the database requires.
    pub dimension: usize,
    /// Category name -> number of stored entries with that category.
    pub category_counts: HashMap<String, usize>,
    /// Namespace id -> number of stored entries in that namespace.
    pub namespace_counts: HashMap<i64, usize>,
    /// Statistics reported by the database's `GraphTree`.
    pub tree_stats: crate::graph::TreeStats,
}
523
/// Cosine of the angle between `a` and `b`, clamped to `[-1.0, 1.0]`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or when either
/// vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // A zero-length vector has no direction; report "no similarity" instead of dividing by 0.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();

    // Rounding can push the ratio slightly outside the valid cosine range.
    (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
}
540
/// Euclidean (L2) distance between `a` and `b`.
///
/// Returns `f32::MAX` when the slices differ in length.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    let squared_gaps = a.iter().zip(b).map(|(lhs, rhs)| {
        let gap = lhs - rhs;
        gap.powi(2)
    });
    let total: f32 = squared_gaps.sum();
    total.sqrt()
}
553
/// Dot product of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
562
/// Scales `v` in place to unit Euclidean length.
///
/// The slice is left untouched when its norm is not greater than zero (for example an
/// all-zero or empty slice).
pub fn normalize_vector(v: &mut [f32]) {
    let norm = v.iter().map(|c| c * c).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|c| *c /= norm);
    }
}
572
573pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
575 vectors
576 .iter()
577 .map(|v| cosine_similarity(query, v))
578 .collect()
579}
580
581pub fn top_k_similar(
583 query: &[f32],
584 vectors: &[(i64, &[f32])],
585 k: usize,
586 threshold: f32,
587) -> Vec<(i64, f32)> {
588 let mut scored: Vec<(i64, f32)> = vectors
589 .iter()
590 .filter_map(|(id, vec)| {
591 let sim = cosine_similarity(query, vec);
592 if sim >= threshold {
593 Some((*id, sim))
594 } else {
595 None
596 }
597 })
598 .collect();
599
600 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
602 scored.truncate(k);
603 scored
604}
605
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::VectorEntry;

    /// Builds an entry in category "general" whose embedding is `0.1` in every component.
    fn create_test_entry(id: i64, namespace_id: i64) -> VectorEntry {
        VectorEntry::new(
            id,
            vec![0.1; EMBEDDING_DIMENSION],
            "general".to_string(),
            namespace_id,
        )
    }

    /// Builds an entry in category "general" whose embedding is `value` in every component.
    fn create_test_entry_with_embedding(id: i64, namespace_id: i64, value: f32) -> VectorEntry {
        VectorEntry::new(
            id,
            vec![value; EMBEDDING_DIMENSION],
            "general".to_string(),
            namespace_id,
        )
    }

    /// An inserted entry can be fetched by id and is counted by `len`.
    #[test]
    fn test_insert_and_get() {
        let mut db = VectorDatabase::new();
        let entry = create_test_entry(1, 1);

        db.insert(entry.clone()).unwrap();

        assert!(db.get(1).is_some());
        assert_eq!(db.len(), 1);
    }

    /// Removing the only entry returns it and leaves the database empty.
    #[test]
    fn test_remove() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();

        let removed = db.remove(1);

        assert!(removed.is_some());
        assert!(db.is_empty());
    }

    /// A 100-component embedding is rejected by a database built with `new()`.
    #[test]
    fn test_dimension_mismatch() {
        let mut db = VectorDatabase::new();
        let bad_entry = VectorEntry::new(1, vec![0.1; 100], "general".to_string(), 1);

        let result = db.insert(bad_entry);

        assert!(result.is_err());
    }

    /// `by_namespace` returns only the entries of the requested namespace.
    #[test]
    fn test_by_namespace() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();
        db.insert(create_test_entry(2, 1)).unwrap();
        db.insert(create_test_entry(3, 2)).unwrap();

        let ns1 = db.by_namespace(1);
        let ns2 = db.by_namespace(2);

        assert_eq!(ns1.len(), 2);
        assert_eq!(ns2.len(), 1);
    }

    /// Identical vectors have cosine similarity ~1.
    #[test]
    fn test_cosine_similarity_identical() {
        let a = vec![0.5; EMBEDDING_DIMENSION];
        let b = vec![0.5; EMBEDDING_DIMENSION];
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 0.001);
    }

    /// Vectors with disjoint non-zero halves are orthogonal: cosine similarity ~0.
    #[test]
    fn test_cosine_similarity_orthogonal() {
        let mut a = vec![0.0; EMBEDDING_DIMENSION];
        let mut b = vec![0.0; EMBEDDING_DIMENSION];
        // `a` is 1.0 on the first half of the components, `b` on the second half.
        for i in 0..EMBEDDING_DIMENSION {
            if i < EMBEDDING_DIMENSION / 2 {
                a[i] = 1.0;
            } else {
                b[i] = 1.0;
            }
        }
        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 0.001);
    }

    /// A threshold-0 search over three entries in one namespace returns all three.
    ///
    /// NOTE(review): all three embeddings are constant positive vectors, i.e. parallel to the
    /// query, so every cosine similarity is ~1.0 and the ordering assertion is weak.
    #[test]
    fn test_search_basic() {
        let mut db = VectorDatabase::new();

        db.insert(create_test_entry_with_embedding(1, 1, 0.5))
            .unwrap();
        db.insert(create_test_entry_with_embedding(2, 1, 0.51))
            .unwrap();
        db.insert(create_test_entry_with_embedding(3, 1, 0.1))
            .unwrap();

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (results, latency) = db.search(&query, 1, 10, 0.0).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results[0].similarity >= results[1].similarity);
        println!("Search latency: {:?}", latency);
    }

    /// With threshold 0.9, querying with entry 1's own embedding is expected to return only
    /// entry 1.
    #[test]
    fn test_search_with_threshold() {
        let mut db = VectorDatabase::new();

        // The two embeddings differ in fill value and have opposite signs in component 0.
        let mut embedding1 = vec![0.5; EMBEDDING_DIMENSION];
        embedding1[0] = 1.0;
        let mut embedding2 = vec![0.1; EMBEDDING_DIMENSION];
        embedding2[0] = -1.0;
        let entry1 = VectorEntry::new(1, embedding1.clone(), "general".to_string(), 1);
        let entry2 = VectorEntry::new(2, embedding2, "general".to_string(), 1);

        db.insert(entry1).unwrap();
        db.insert(entry2).unwrap();

        let query = embedding1.clone();
        let (results, _) = db.search(&query, 1, 10, 0.9).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    /// `search_by_category` ignores entries of other categories in the same namespace.
    #[test]
    fn test_search_by_category() {
        let mut db = VectorDatabase::new();

        let entry1 = VectorEntry::new(1, vec![0.5; EMBEDDING_DIMENSION], "general".to_string(), 1);
        let entry2 = VectorEntry::new(2, vec![0.5; EMBEDDING_DIMENSION], "facts".to_string(), 1);

        db.insert(entry1).unwrap();
        db.insert(entry2).unwrap();

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (results, _) = db
            .search_by_category(&query, 1, "general", 10, 0.0)
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    /// Searching 1000 entries must report a total latency under 100 ms.
    ///
    /// NOTE(review): this asserts on wall-clock time and may be flaky on slow or loaded
    /// machines.
    #[test]
    fn test_search_latency_target() {
        let mut db = VectorDatabase::new();

        for i in 0..1000 {
            db.insert(create_test_entry_with_embedding(i, 1, 0.5))
                .unwrap();
        }

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (_, latency) = db.search(&query, 1, 10, 0.0).unwrap();

        println!("Search latency: {:?}", latency);
        assert!(
            latency.total_ms < 100,
            "Search took {}ms, expected <100ms",
            latency.total_ms
        );
    }

    /// Inserting with priority `Some(1)` registers a tree node whose weight is ~1.5.
    ///
    /// The priority-to-weight mapping is defined by `GraphTree`, outside this file.
    #[test]
    fn test_insert_with_priority() {
        let mut db = VectorDatabase::new();

        let entry = create_test_entry(1, 1);
        db.insert_with_priority(entry, Some(1)).unwrap();
        let tree = db.tree();
        let node = tree.get(1);
        assert!(node.is_some());
        let node = node.unwrap();
        assert!((node.weight - 1.5).abs() < 0.01);
    }

    /// The async constructor yields an empty database.
    #[tokio::test]
    async fn test_in_memory_creation() {
        let db = VectorDatabase::in_memory().await.unwrap();
        assert!(db.is_empty());
    }

    /// A batch of valid entries is inserted completely.
    #[test]
    fn test_batch_insert() {
        let mut db = VectorDatabase::new();
        let entries = vec![
            create_test_entry(1, 1),
            create_test_entry(2, 1),
            create_test_entry(3, 1),
        ];

        let count = db.insert_batch(entries).unwrap();
        assert_eq!(count, 3);
        assert_eq!(db.len(), 3);
    }

    /// An entry with the wrong dimension is skipped; the rest of the batch is still inserted.
    #[test]
    fn test_batch_insert_with_invalid() {
        let mut db = VectorDatabase::new();
        let entries = vec![
            create_test_entry(1, 1),
            // Wrong dimension: expected to be rejected.
            VectorEntry::new(2, vec![0.1; 100], "general".to_string(), 1),
            create_test_entry(3, 1),
        ];

        let count = db.insert_batch(entries).unwrap();
        assert_eq!(count, 2);
        assert_eq!(db.len(), 2);
    }

    /// `remove_batch` reports one outcome per requested id, `None` for ids that are not stored.
    #[test]
    fn test_batch_remove() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();
        db.insert(create_test_entry(2, 1)).unwrap();
        db.insert(create_test_entry(3, 1)).unwrap();

        let removed = db.remove_batch(&[1, 2, 999]);
        assert_eq!(removed.len(), 3);
        assert!(removed[0].is_some());
        assert!(removed[1].is_some());
        // Id 999 was never inserted.
        assert!(removed[2].is_none());
        assert_eq!(db.len(), 1);
    }

    /// `find_similar` excludes the source entry and ranks the closest neighbour first.
    #[test]
    fn test_find_similar() {
        let mut db = VectorDatabase::new();

        // Source entry.
        let mut e1 = vec![0.5; EMBEDDING_DIMENSION];
        e1[0] = 1.0;

        // Differs from the source only slightly, in component 0.
        let mut e2 = vec![0.5; EMBEDDING_DIMENSION];
        e2[0] = 0.95;

        // Different fill value and opposite sign in component 0.
        let mut e3 = vec![0.1; EMBEDDING_DIMENSION];
        e3[0] = -1.0;

        db.insert(VectorEntry::new(1, e1.clone(), "general".to_string(), 1))
            .unwrap();
        db.insert(VectorEntry::new(2, e2, "general".to_string(), 1))
            .unwrap();
        db.insert(VectorEntry::new(3, e3, "general".to_string(), 1))
            .unwrap();

        let (results, _) = db.find_similar(1, 10, 0.0).unwrap();

        assert!(!results.iter().any(|r| r.id == 1));
        assert_eq!(results[0].id, 2);
    }

    /// `stats` reports the entry count, the dimension and per-category counts.
    #[test]
    fn test_stats() {
        let mut db = VectorDatabase::new();
        db.insert(VectorEntry::new(
            1,
            vec![0.1; EMBEDDING_DIMENSION],
            "general".to_string(),
            1,
        ))
        .unwrap();
        db.insert(VectorEntry::new(
            2,
            vec![0.1; EMBEDDING_DIMENSION],
            "general".to_string(),
            1,
        ))
        .unwrap();
        db.insert(VectorEntry::new(
            3,
            vec![0.1; EMBEDDING_DIMENSION],
            "facts".to_string(),
            2,
        ))
        .unwrap();

        let stats = db.stats();
        assert_eq!(stats.total_vectors, 3);
        assert_eq!(stats.dimension, EMBEDDING_DIMENSION);
        assert_eq!(*stats.category_counts.get("general").unwrap_or(&0), 2);
        assert_eq!(*stats.category_counts.get("facts").unwrap_or(&0), 1);
    }

    /// `clear` removes every entry.
    #[test]
    fn test_clear() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();
        db.insert(create_test_entry(2, 1)).unwrap();

        db.clear();
        assert!(db.is_empty());
    }

    /// `contains` is true only for stored ids.
    #[test]
    fn test_contains() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();

        assert!(db.contains(1));
        assert!(!db.contains(2));
    }

    /// `update_embedding` replaces the stored embedding.
    #[test]
    fn test_update_embedding() {
        let mut db = VectorDatabase::new();
        db.insert(create_test_entry(1, 1)).unwrap();

        let new_embedding = vec![0.9; EMBEDDING_DIMENSION];
        db.update_embedding(1, new_embedding.clone()).unwrap();

        let entry = db.get(1).unwrap();
        assert_eq!(entry.embedding, new_embedding);
    }

    /// Updating an id that is not stored is an error.
    #[test]
    fn test_update_embedding_nonexistent() {
        let mut db = VectorDatabase::new();
        let result = db.update_embedding(999, vec![0.1; EMBEDDING_DIMENSION]);
        assert!(result.is_err());
    }

    /// Distance between two distinct unit axes is sqrt(2).
    #[test]
    fn test_euclidean_distance() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let dist = euclidean_distance(&a, &b);
        assert!((dist - 2.0_f32.sqrt()).abs() < 0.001);
    }

    /// 1*4 + 2*5 + 3*6 = 32.
    #[test]
    fn test_dot_product() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let prod = dot_product(&a, &b);
        assert!((prod - 32.0).abs() < 0.001);
    }

    /// [3, 4] has norm 5, so it normalizes to [0.6, 0.8].
    #[test]
    fn test_normalize_vector() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);
        assert!((v[0] - 0.6).abs() < 0.001);
        assert!((v[1] - 0.8).abs() < 0.001);
    }

    /// The two best matches for [1, 0] are the identical vector, then the nearly parallel one.
    #[test]
    fn test_top_k_similar() {
        let query = vec![1.0, 0.0];
        let vectors: Vec<(i64, &[f32])> = vec![
            // Identical to the query: similarity 1.0.
            (1, &[1.0, 0.0]),
            // Orthogonal to the query: similarity 0.0.
            (2, &[0.0, 1.0]),
            // 45 degrees from the query: similarity ~0.707.
            (3, &[0.707, 0.707]),
            // Nearly parallel to the query: similarity ~0.994.
            (4, &[0.9, 0.1]),
        ];

        let top_k = top_k_similar(&query, &vectors, 2, 0.0);
        assert_eq!(top_k.len(), 2);
        assert_eq!(top_k[0].0, 1);
        assert_eq!(top_k[1].0, 4);
    }
}