Skip to main content

nexus_memory_vectors/
database.rs

1//! Vector database implementation
2//!
3//! This module provides vector storage and retrieval with graph tree organization
4//! for efficient hierarchical memory management.
5//!
6//! ## Performance Targets
7//! - Search latency: <10ms for 1k vectors
8//! - Embedding dimension: 384 (all-MiniLM-L6-v2)
9//! - Cosine similarity for semantic search
10//!
11//! ## Status
12//! **This module is deprecated.** The live cognition path uses `SemanticSearch`
13//! (in `crate::search`) over storage-backed `VectorEntry` slices. `VectorDatabase`
14//! is an in-memory test/development abstraction not used by the shipped retrieval path.
15
16#![allow(deprecated)]
17
18use crate::graph::GraphTree;
19use crate::{SearchLatency, VectorEntry, EMBEDDING_DIMENSION};
20use serde::{Deserialize, Serialize};
21use std::collections::HashMap;
22use std::time::Instant;
23
/// Number of results a search returns when the caller has no specific limit in mind.
pub const DEFAULT_SEARCH_LIMIT: usize = 10;

/// Minimum similarity a result must reach by default.
///
/// `0.0` keeps every candidate whose cosine similarity is non-negative.
pub const DEFAULT_SIMILARITY_THRESHOLD: f32 = 0.0;
29
30/// Vector database for storing and searching embeddings
31///
32/// **Deprecated**: `VectorDatabase` is an in-memory test/development abstraction.
33/// The live retrieval path uses `SemanticSearch` over storage-backed `VectorEntry`
34/// slices (see `crate::search::SemanticSearch`). `SemanticSearch` is the actual
35/// runtime retrieval path used by `RepresentationService`.
36#[deprecated(
37    since = "0.1.0",
38    note = "Use SemanticSearch over storage-backed VectorEntry slices for runtime retrieval. VectorDatabase is an internal/test abstraction."
39)]
40#[derive(Debug, Default)]
41pub struct VectorDatabase {
42    /// In-memory vector storage
43    vectors: HashMap<i64, VectorEntry>,
44
45    /// Graph tree for hierarchical organization
46    tree: GraphTree,
47
48    /// Embedding dimension
49    dimension: usize,
50
51    /// Namespace index for fast filtering
52    namespace_index: HashMap<i64, Vec<i64>>,
53
54    /// Category index for fast filtering
55    category_index: HashMap<String, Vec<i64>>,
56}
57
58/// Result of a vector search
59///
60/// **Deprecated**: This type is only used by the deprecated `VectorDatabase`.
61/// Runtime retrieval uses `SearchResult` from `crate::search`.
62#[deprecated(
63    since = "0.1.0",
64    note = "Use search::SearchResult instead. This type belongs to the deprecated VectorDatabase."
65)]
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct VectorSearchResult {
68    /// Memory ID
69    pub id: i64,
70
71    /// Similarity score (0.0 to 1.0)
72    pub similarity: f32,
73
74    /// Boosted score (after tree-based boosting)
75    pub boosted_score: f32,
76}
77
78impl VectorDatabase {
79    /// Create a new vector database
80    pub fn new() -> Self {
81        Self {
82            vectors: HashMap::new(),
83            tree: GraphTree::new(),
84            dimension: EMBEDDING_DIMENSION,
85            namespace_index: HashMap::new(),
86            category_index: HashMap::new(),
87        }
88    }
89
90    /// Create with custom dimension
91    pub fn with_dimension(dimension: usize) -> Self {
92        Self {
93            vectors: HashMap::new(),
94            tree: GraphTree::new(),
95            dimension,
96            namespace_index: HashMap::new(),
97            category_index: HashMap::new(),
98        }
99    }
100
101    /// Create a new in-memory database (async for API compatibility)
102    pub async fn in_memory() -> crate::Result<Self> {
103        Ok(Self::new())
104    }
105
106    /// Insert a vector with priority for graph tree
107    pub fn insert_with_priority(
108        &mut self,
109        entry: VectorEntry,
110        priority: Option<u8>,
111    ) -> crate::Result<()> {
112        if entry.embedding.len() != self.dimension {
113            return Err(nexus_core::NexusError::InvalidInput(format!(
114                "Vector dimension mismatch: expected {}, got {}",
115                self.dimension,
116                entry.embedding.len()
117            )));
118        }
119
120        let id = entry.id;
121        let namespace_id = entry.namespace_id;
122        let category = entry.category.clone();
123        let lane_type = entry.memory_lane_type.clone();
124
125        // Add to graph tree
126        self.tree
127            .add_memory(id, &category, lane_type.as_deref(), priority);
128
129        // Update indices
130        self.namespace_index
131            .entry(namespace_id)
132            .or_default()
133            .push(id);
134
135        self.category_index.entry(category).or_default().push(id);
136
137        // Store the entry
138        self.vectors.insert(id, entry);
139        Ok(())
140    }
141
142    /// Insert a vector (backward compatible)
143    pub fn insert(&mut self, entry: VectorEntry) -> crate::Result<()> {
144        self.insert_with_priority(entry, None)
145    }
146
147    /// Get a vector by ID
148    pub fn get(&self, id: i64) -> Option<&VectorEntry> {
149        self.vectors.get(&id)
150    }
151
152    /// Remove a vector by ID
153    pub fn remove(&mut self, id: i64) -> Option<VectorEntry> {
154        if let Some(entry) = self.vectors.remove(&id) {
155            // Remove from tree
156            self.tree.remove_memory(id);
157
158            // Remove from indices
159            if let Some(ns_vec) = self.namespace_index.get_mut(&entry.namespace_id) {
160                ns_vec.retain(|&i| i != id);
161            }
162
163            if let Some(cat_vec) = self.category_index.get_mut(&entry.category) {
164                cat_vec.retain(|&i| i != id);
165            }
166
167            Some(entry)
168        } else {
169            None
170        }
171    }
172
173    /// Get all vector IDs
174    pub fn ids(&self) -> Vec<i64> {
175        self.vectors.keys().copied().collect()
176    }
177
178    /// Get vector count
179    pub fn len(&self) -> usize {
180        self.vectors.len()
181    }
182
183    /// Check if empty
184    pub fn is_empty(&self) -> bool {
185        self.vectors.is_empty()
186    }
187
188    /// Get vectors by namespace
189    pub fn by_namespace(&self, namespace_id: i64) -> Vec<&VectorEntry> {
190        self.vectors
191            .values()
192            .filter(|v| v.namespace_id == namespace_id)
193            .collect()
194    }
195
196    /// Get vectors by category
197    pub fn by_category(&self, category: &str) -> Vec<&VectorEntry> {
198        self.vectors
199            .values()
200            .filter(|v| v.category == category)
201            .collect()
202    }
203
204    /// Get the embedding dimension
205    pub fn dimension(&self) -> usize {
206        self.dimension
207    }
208
209    /// Get reference to the graph tree
210    pub fn tree(&self) -> &GraphTree {
211        &self.tree
212    }
213
214    /// Get mutable reference to the graph tree
215    pub fn tree_mut(&mut self) -> &mut GraphTree {
216        &mut self.tree
217    }
218
219    /// Search for similar vectors
220    ///
221    /// Returns results sorted by boosted score (descending).
222    /// Target latency: <10ms
223    pub fn search(
224        &self,
225        query: &[f32],
226        namespace_id: i64,
227        limit: usize,
228        threshold: f32,
229    ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
230        let start = Instant::now();
231
232        // Validate query dimension
233        if query.len() != self.dimension {
234            return Err(nexus_core::NexusError::InvalidInput(format!(
235                "Query dimension mismatch: expected {}, got {}",
236                self.dimension,
237                query.len()
238            )));
239        }
240
241        // Get candidate IDs from namespace
242        let candidate_ids = self
243            .namespace_index
244            .get(&namespace_id)
245            .map(|v| v.as_slice())
246            .unwrap_or(&[]);
247
248        // Calculate similarities
249        let mut results: Vec<VectorSearchResult> = candidate_ids
250            .iter()
251            .filter_map(|&id| {
252                let entry = self.vectors.get(&id)?;
253                let similarity = cosine_similarity(query, &entry.embedding);
254
255                if similarity >= threshold {
256                    let boosted_score = self.tree.calculate_boosted_score(id, similarity);
257                    Some(VectorSearchResult {
258                        id,
259                        similarity,
260                        boosted_score,
261                    })
262                } else {
263                    None
264                }
265            })
266            .collect();
267
268        // Sort by boosted score (descending)
269        results.sort_by(|a, b| {
270            b.boosted_score
271                .partial_cmp(&a.boosted_score)
272                .unwrap_or(std::cmp::Ordering::Equal)
273        });
274
275        // Truncate to limit
276        results.truncate(limit);
277
278        let total_time = start.elapsed();
279
280        let latency = SearchLatency {
281            total_ms: total_time.as_millis() as u64,
282            vector_comparison_ms: total_time.as_millis() as u64,
283            graph_traversal_ms: None,
284        };
285
286        Ok((results, latency))
287    }
288
289    /// Search with category filter
290    pub fn search_by_category(
291        &self,
292        query: &[f32],
293        namespace_id: i64,
294        category: &str,
295        limit: usize,
296        threshold: f32,
297    ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
298        let start = Instant::now();
299
300        // Validate query dimension
301        if query.len() != self.dimension {
302            return Err(nexus_core::NexusError::InvalidInput(format!(
303                "Query dimension mismatch: expected {}, got {}",
304                self.dimension,
305                query.len()
306            )));
307        }
308
309        // Get candidates filtered by category
310        let category_ids: std::collections::HashSet<i64> = self
311            .category_index
312            .get(category)
313            .map(|v| v.iter().copied().collect())
314            .unwrap_or_default();
315
316        let namespace_ids: std::collections::HashSet<i64> = self
317            .namespace_index
318            .get(&namespace_id)
319            .map(|v| v.iter().copied().collect())
320            .unwrap_or_default();
321
322        // Intersect namespace and category
323        let candidate_ids: Vec<i64> = category_ids.intersection(&namespace_ids).copied().collect();
324
325        // Calculate similarities
326        let mut results: Vec<VectorSearchResult> = candidate_ids
327            .iter()
328            .filter_map(|&id| {
329                let entry = self.vectors.get(&id)?;
330                let similarity = cosine_similarity(query, &entry.embedding);
331
332                if similarity >= threshold {
333                    let boosted_score = self.tree.calculate_boosted_score(id, similarity);
334                    Some(VectorSearchResult {
335                        id,
336                        similarity,
337                        boosted_score,
338                    })
339                } else {
340                    None
341                }
342            })
343            .collect();
344
345        // Sort by boosted score (descending)
346        results.sort_by(|a, b| {
347            b.boosted_score
348                .partial_cmp(&a.boosted_score)
349                .unwrap_or(std::cmp::Ordering::Equal)
350        });
351
352        // Truncate to limit
353        results.truncate(limit);
354
355        let total_time = start.elapsed();
356
357        let latency = SearchLatency {
358            total_ms: total_time.as_millis() as u64,
359            vector_comparison_ms: total_time.as_millis() as u64,
360            graph_traversal_ms: None,
361        };
362
363        Ok((results, latency))
364    }
365
366    /// Batch insert multiple vectors
367    ///
368    /// More efficient than individual inserts for bulk loading.
369    pub fn insert_batch(&mut self, entries: Vec<VectorEntry>) -> crate::Result<usize> {
370        let mut success_count = 0;
371        for entry in entries {
372            match self.insert(entry) {
373                Ok(()) => success_count += 1,
374                Err(_) => continue, // Skip invalid entries
375            }
376        }
377        Ok(success_count)
378    }
379
380    /// Batch insert with priorities
381    pub fn insert_batch_with_priorities(
382        &mut self,
383        entries: Vec<(VectorEntry, Option<u8>)>,
384    ) -> crate::Result<usize> {
385        let mut success_count = 0;
386        for (entry, priority) in entries {
387            match self.insert_with_priority(entry, priority) {
388                Ok(()) => success_count += 1,
389                Err(_) => continue,
390            }
391        }
392        Ok(success_count)
393    }
394
395    /// Batch remove multiple vectors
396    pub fn remove_batch(&mut self, ids: &[i64]) -> Vec<Option<VectorEntry>> {
397        ids.iter().map(|&id| self.remove(id)).collect()
398    }
399
400    /// Search with multiple queries (for batch operations)
401    pub fn search_batch(
402        &self,
403        queries: &[Vec<f32>],
404        namespace_id: i64,
405        limit: usize,
406        threshold: f32,
407    ) -> crate::Result<Vec<(Vec<VectorSearchResult>, SearchLatency)>> {
408        let mut results = Vec::with_capacity(queries.len());
409        for query in queries {
410            results.push(self.search(query, namespace_id, limit, threshold)?);
411        }
412        Ok(results)
413    }
414
415    /// Find similar vectors to a stored vector by ID
416    pub fn find_similar(
417        &self,
418        memory_id: i64,
419        limit: usize,
420        threshold: f32,
421    ) -> crate::Result<(Vec<VectorSearchResult>, SearchLatency)> {
422        let start = Instant::now();
423
424        let entry = self
425            .vectors
426            .get(&memory_id)
427            .ok_or(nexus_core::NexusError::MemoryNotFound(memory_id))?;
428
429        let query = entry.embedding.clone();
430        let namespace_id = entry.namespace_id;
431
432        let (mut results, latency) = self.search(&query, namespace_id, limit + 1, threshold)?;
433
434        // Remove the query vector itself from results
435        results.retain(|r| r.id != memory_id);
436        results.truncate(limit);
437
438        let total_time = start.elapsed();
439        let adjusted_latency = SearchLatency {
440            total_ms: total_time.as_millis() as u64,
441            vector_comparison_ms: latency.vector_comparison_ms,
442            graph_traversal_ms: latency.graph_traversal_ms,
443        };
444
445        Ok((results, adjusted_latency))
446    }
447
448    /// Get statistics about the vector database
449    pub fn stats(&self) -> VectorDatabaseStats {
450        let mut category_counts = HashMap::new();
451        let mut namespace_counts = HashMap::new();
452
453        for entry in self.vectors.values() {
454            *category_counts.entry(entry.category.clone()).or_insert(0) += 1;
455            *namespace_counts.entry(entry.namespace_id).or_insert(0) += 1;
456        }
457
458        VectorDatabaseStats {
459            total_vectors: self.vectors.len(),
460            dimension: self.dimension,
461            category_counts,
462            namespace_counts,
463            tree_stats: self.tree.stats(),
464        }
465    }
466
467    /// Clear all vectors from the database
468    pub fn clear(&mut self) {
469        self.vectors.clear();
470        self.namespace_index.clear();
471        self.category_index.clear();
472        self.tree = GraphTree::new();
473    }
474
475    /// Check if a vector exists
476    pub fn contains(&self, id: i64) -> bool {
477        self.vectors.contains_key(&id)
478    }
479
480    /// Get all vectors as a slice (read-only access)
481    pub fn all_vectors(&self) -> Vec<&VectorEntry> {
482        self.vectors.values().collect()
483    }
484
485    /// Update an existing vector's embedding
486    pub fn update_embedding(&mut self, id: i64, new_embedding: Vec<f32>) -> crate::Result<()> {
487        if new_embedding.len() != self.dimension {
488            return Err(nexus_core::NexusError::InvalidInput(format!(
489                "Vector dimension mismatch: expected {}, got {}",
490                self.dimension,
491                new_embedding.len()
492            )));
493        }
494
495        let entry = self
496            .vectors
497            .get_mut(&id)
498            .ok_or(nexus_core::NexusError::MemoryNotFound(id))?;
499
500        entry.embedding = new_embedding;
501        entry.created_at = chrono::Utc::now();
502        Ok(())
503    }
504}
505
506/// Statistics about the vector database
507///
508/// **Deprecated**: This type belongs to the deprecated `VectorDatabase`.
509#[deprecated(since = "0.1.0", note = "Belongs to the deprecated VectorDatabase.")]
510#[derive(Debug, Clone, Serialize, Deserialize)]
511pub struct VectorDatabaseStats {
512    /// Total number of vectors
513    pub total_vectors: usize,
514    /// Embedding dimension
515    pub dimension: usize,
516    /// Count by category
517    pub category_counts: HashMap<String, usize>,
518    /// Count by namespace
519    pub namespace_counts: HashMap<i64, usize>,
520    /// Graph tree statistics
521    pub tree_stats: crate::graph::TreeStats,
522}
523
/// Calculate cosine similarity between two vectors
///
/// Returns a value in `[-1.0, 1.0]`. Returns `0.0` when the slices differ in
/// length, are empty, or when either vector has zero magnitude. NaN or
/// infinite components propagate as a NaN result.
///
/// The dot product and both squared norms are accumulated in `f64` in a single
/// pass, so large components do not overflow to infinity and tiny non-zero
/// vectors are not mistaken for zero vectors.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() || a.is_empty() {
        return 0.0;
    }

    let (dot, sq_norm_a, sq_norm_b) = a.iter().zip(b).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, sq_a, sq_b), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, sq_a + x * x, sq_b + y * y)
        },
    );

    if sq_norm_a == 0.0 || sq_norm_b == 0.0 {
        return 0.0;
    }

    // Clamp away rounding drift just outside [-1, 1].
    ((dot / (sq_norm_a.sqrt() * sq_norm_b.sqrt())) as f32).clamp(-1.0, 1.0)
}
540
/// Calculate euclidean distance between two vectors
///
/// Returns `f32::MAX` when the slices differ in length (kept for backward
/// compatibility) and `0.0` for two empty slices.
///
/// Squared differences are accumulated in `f64`, so the result stays finite
/// whenever the true distance fits in an `f32`.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }

    let sum_of_squares: f64 = a
        .iter()
        .zip(b)
        .map(|(&x, &y)| {
            let delta = f64::from(x) - f64::from(y);
            delta * delta
        })
        .sum();

    sum_of_squares.sqrt() as f32
}
553
/// Dot product of `a` and `b`.
///
/// Slices of different lengths yield `0.0`; two empty slices also yield zero.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs).sum()
    } else {
        0.0
    }
}
562
/// Normalize a vector in place
///
/// Scales `v` to unit length. The vector is left untouched when its norm is
/// zero or not finite (empty input, all zeros, or NaN/infinite components).
///
/// The norm is computed in `f64`: with `f32` accumulation, large components
/// overflow the norm to infinity (which would zero the whole vector) and tiny
/// components underflow it to zero (which would skip normalization).
pub fn normalize_vector(v: &mut [f32]) {
    let norm = v
        .iter()
        .map(|&x| f64::from(x) * f64::from(x))
        .sum::<f64>()
        .sqrt();

    if norm > 0.0 && norm.is_finite() {
        for x in v.iter_mut() {
            *x = (f64::from(*x) / norm) as f32;
        }
    }
}
572
573/// Batch similarity computation using SIMD-friendly approach
574pub fn batch_cosine_similarity(query: &[f32], vectors: &[&[f32]]) -> Vec<f32> {
575    vectors
576        .iter()
577        .map(|v| cosine_similarity(query, v))
578        .collect()
579}
580
581/// Find top-k similar vectors
582pub fn top_k_similar(
583    query: &[f32],
584    vectors: &[(i64, &[f32])],
585    k: usize,
586    threshold: f32,
587) -> Vec<(i64, f32)> {
588    let mut scored: Vec<(i64, f32)> = vectors
589        .iter()
590        .filter_map(|(id, vec)| {
591            let sim = cosine_similarity(query, vec);
592            if sim >= threshold {
593                Some((*id, sim))
594            } else {
595                None
596            }
597        })
598        .collect();
599
600    // Partial sort for top-k
601    scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
602    scored.truncate(k);
603    scored
604}
605
#[cfg(test)]
#[allow(deprecated)]
mod tests {
    use super::*;
    use crate::VectorEntry;

    /// Build an entry from explicit parts.
    fn entry_with(id: i64, embedding: Vec<f32>, category: &str, namespace_id: i64) -> VectorEntry {
        VectorEntry::new(id, embedding, category.to_string(), namespace_id)
    }

    /// "general" entry whose embedding repeats `value` in every dimension.
    fn create_test_entry_with_embedding(id: i64, namespace_id: i64, value: f32) -> VectorEntry {
        entry_with(id, vec![value; EMBEDDING_DIMENSION], "general", namespace_id)
    }

    /// "general" entry with a uniform 0.1 embedding.
    fn create_test_entry(id: i64, namespace_id: i64) -> VectorEntry {
        create_test_entry_with_embedding(id, namespace_id, 0.1)
    }

    /// Fresh database pre-loaded with `entries`; panics if any insert fails.
    fn db_with(entries: impl IntoIterator<Item = VectorEntry>) -> VectorDatabase {
        let mut db = VectorDatabase::new();
        for entry in entries {
            db.insert(entry).unwrap();
        }
        db
    }

    #[test]
    fn test_insert_and_get() {
        let db = db_with([create_test_entry(1, 1)]);

        assert!(db.get(1).is_some());
        assert_eq!(db.len(), 1);
    }

    #[test]
    fn test_remove() {
        let mut db = db_with([create_test_entry(1, 1)]);

        let removed = db.remove(1);

        assert!(removed.is_some());
        assert!(db.is_empty());
    }

    #[test]
    fn test_dimension_mismatch() {
        let mut db = VectorDatabase::new();
        let too_short = entry_with(1, vec![0.1; 100], "general", 1);

        assert!(db.insert(too_short).is_err());
    }

    #[test]
    fn test_by_namespace() {
        let db = db_with([
            create_test_entry(1, 1),
            create_test_entry(2, 1),
            create_test_entry(3, 2),
        ]);

        assert_eq!(db.by_namespace(1).len(), 2);
        assert_eq!(db.by_namespace(2).len(), 1);
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let v = vec![0.5; EMBEDDING_DIMENSION];
        let sim = cosine_similarity(&v, &v.clone());
        assert!((sim - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        // `a` occupies the first half of the dimensions, `b` the second half.
        let half = EMBEDDING_DIMENSION / 2;
        let a: Vec<f32> = (0..EMBEDDING_DIMENSION)
            .map(|i| if i < half { 1.0 } else { 0.0 })
            .collect();
        let b: Vec<f32> = (0..EMBEDDING_DIMENSION)
            .map(|i| if i < half { 0.0 } else { 1.0 })
            .collect();

        let sim = cosine_similarity(&a, &b);
        assert!((sim - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_search_basic() {
        let db = db_with([
            create_test_entry_with_embedding(1, 1, 0.5),
            create_test_entry_with_embedding(2, 1, 0.51),
            create_test_entry_with_embedding(3, 1, 0.1),
        ]);

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (results, latency) = db.search(&query, 1, 10, 0.0).unwrap();

        assert_eq!(results.len(), 3);
        // The first hit must be at least as similar as the second
        assert!(results[0].similarity >= results[1].similarity);
        println!("Search latency: {:?}", latency);
    }

    #[test]
    fn test_search_with_threshold() {
        // Uniform vectors all point the same way (cosine 1.0), so the two
        // embeddings are given different directions via their first component.
        let mut matching = vec![0.5; EMBEDDING_DIMENSION];
        matching[0] = 1.0;

        let mut opposing = vec![0.1; EMBEDDING_DIMENSION];
        opposing[0] = -1.0;

        let db = db_with([
            entry_with(1, matching.clone(), "general", 1),
            entry_with(2, opposing, "general", 1),
        ]);

        let (results, _) = db.search(&matching, 1, 10, 0.9).unwrap();

        // With a high threshold only the entry identical to the query survives
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_search_by_category() {
        let db = db_with([
            entry_with(1, vec![0.5; EMBEDDING_DIMENSION], "general", 1),
            entry_with(2, vec![0.5; EMBEDDING_DIMENSION], "facts", 1),
        ]);

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (results, _) = db
            .search_by_category(&query, 1, "general", 10, 0.0)
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1);
    }

    #[test]
    fn test_search_latency_target() {
        let db = db_with((0..1000).map(|id| create_test_entry_with_embedding(id, 1, 0.5)));

        let query = vec![0.5; EMBEDDING_DIMENSION];
        let (_, latency) = db.search(&query, 1, 10, 0.0).unwrap();

        // In-memory search over 1k vectors should finish well inside the bound
        println!("Search latency: {:?}", latency);
        assert!(
            latency.total_ms < 100,
            "Search took {}ms, expected <100ms",
            latency.total_ms
        );
    }

    #[test]
    fn test_insert_with_priority() {
        let mut db = VectorDatabase::new();
        db.insert_with_priority(create_test_entry(1, 1), Some(1))
            .unwrap(); // High priority

        // The tree node must exist and carry the high-priority weight
        let node = db.tree().get(1);
        assert!(node.is_some());
        let node = node.unwrap();
        assert!((node.weight - 1.5).abs() < 0.01);
    }

    #[tokio::test]
    async fn test_in_memory_creation() {
        let db = VectorDatabase::in_memory().await.unwrap();
        assert!(db.is_empty());
    }

    #[test]
    fn test_batch_insert() {
        let mut db = VectorDatabase::new();
        let batch: Vec<VectorEntry> = (1..=3).map(|id| create_test_entry(id, 1)).collect();

        let count = db.insert_batch(batch).unwrap();

        assert_eq!(count, 3);
        assert_eq!(db.len(), 3);
    }

    #[test]
    fn test_batch_insert_with_invalid() {
        let mut db = VectorDatabase::new();
        let batch = vec![
            create_test_entry(1, 1),
            entry_with(2, vec![0.1; 100], "general", 1), // wrong dimension
            create_test_entry(3, 1),
        ];

        let count = db.insert_batch(batch).unwrap();

        // The malformed entry is skipped, the other two are stored
        assert_eq!(count, 2);
        assert_eq!(db.len(), 2);
    }

    #[test]
    fn test_batch_remove() {
        let mut db = db_with((1..=3).map(|id| create_test_entry(id, 1)));

        let removed = db.remove_batch(&[1, 2, 999]);

        assert_eq!(removed.len(), 3);
        assert!(removed[0].is_some());
        assert!(removed[1].is_some());
        assert!(removed[2].is_none()); // 999 was never stored
        assert_eq!(db.len(), 1);
    }

    #[test]
    fn test_find_similar() {
        // Three directions: the query, a near neighbour, and a distant one
        let mut query_vec = vec![0.5; EMBEDDING_DIMENSION];
        query_vec[0] = 1.0;

        let mut near = vec![0.5; EMBEDDING_DIMENSION];
        near[0] = 0.95;

        let mut distant = vec![0.1; EMBEDDING_DIMENSION];
        distant[0] = -1.0;

        let db = db_with([
            entry_with(1, query_vec, "general", 1),
            entry_with(2, near, "general", 1),
            entry_with(3, distant, "general", 1),
        ]);

        let (results, _) = db.find_similar(1, 10, 0.0).unwrap();

        // The source entry itself is excluded
        assert!(!results.iter().any(|hit| hit.id == 1));
        // The near neighbour ranks first
        assert_eq!(results[0].id, 2);
    }

    #[test]
    fn test_stats() {
        let db = db_with([
            entry_with(1, vec![0.1; EMBEDDING_DIMENSION], "general", 1),
            entry_with(2, vec![0.1; EMBEDDING_DIMENSION], "general", 1),
            entry_with(3, vec![0.1; EMBEDDING_DIMENSION], "facts", 2),
        ]);

        let stats = db.stats();

        assert_eq!(stats.total_vectors, 3);
        assert_eq!(stats.dimension, EMBEDDING_DIMENSION);
        assert_eq!(stats.category_counts.get("general").copied().unwrap_or(0), 2);
        assert_eq!(stats.category_counts.get("facts").copied().unwrap_or(0), 1);
    }

    #[test]
    fn test_clear() {
        let mut db = db_with([create_test_entry(1, 1), create_test_entry(2, 1)]);

        db.clear();

        assert!(db.is_empty());
    }

    #[test]
    fn test_contains() {
        let db = db_with([create_test_entry(1, 1)]);

        assert!(db.contains(1));
        assert!(!db.contains(2));
    }

    #[test]
    fn test_update_embedding() {
        let mut db = db_with([create_test_entry(1, 1)]);
        let replacement = vec![0.9; EMBEDDING_DIMENSION];

        db.update_embedding(1, replacement.clone()).unwrap();

        assert_eq!(db.get(1).unwrap().embedding, replacement);
    }

    #[test]
    fn test_update_embedding_nonexistent() {
        let mut db = VectorDatabase::new();
        let outcome = db.update_embedding(999, vec![0.1; EMBEDDING_DIMENSION]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_euclidean_distance() {
        let dist = euclidean_distance(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!((dist - 2.0_f32.sqrt()).abs() < 0.001);
    }

    #[test]
    fn test_dot_product() {
        // 1*4 + 2*5 + 3*6 = 32
        let prod = dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]);
        assert!((prod - 32.0).abs() < 0.001);
    }

    #[test]
    fn test_normalize_vector() {
        let mut v = vec![3.0, 4.0];
        normalize_vector(&mut v);
        assert!((v[0] - 0.6).abs() < 0.001);
        assert!((v[1] - 0.8).abs() < 0.001);
    }

    #[test]
    fn test_top_k_similar() {
        let query = vec![1.0, 0.0];
        let vectors: Vec<(i64, &[f32])> = vec![
            (1, &[1.0, 0.0]),     // similarity 1.0
            (2, &[0.0, 1.0]),     // similarity 0.0
            (3, &[0.707, 0.707]), // similarity ~0.707
            (4, &[0.9, 0.1]),     // similarity ~0.9
        ];

        let top_k = top_k_similar(&query, &vectors, 2, 0.0);

        let ranked_ids: Vec<i64> = top_k.iter().map(|&(id, _)| id).collect();
        assert_eq!(top_k.len(), 2);
        assert_eq!(ranked_ids[0], 1); // Most similar
        assert_eq!(ranked_ids[1], 4); // Second most similar
    }
}