// rag_plusplus_core/index/traits.rs

//! Index Traits and Common Types
//!
//! Defines the core abstraction for vector indexes.

use crate::error::Result;
use std::collections::HashSet;
use std::fmt::Debug;

/// Distance metric for similarity computation.
///
/// The variants rank raw values in opposite directions: for
/// [`DistanceType::L2`] a smaller value is a closer match, while for
/// [`DistanceType::InnerProduct`] and [`DistanceType::Cosine`] a larger value
/// is a closer match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DistanceType {
    /// Euclidean (L2) distance - smaller is more similar. This is the default.
    #[default]
    L2,
    /// Inner product - larger is more similar (use with normalized vectors).
    InnerProduct,
    /// Cosine similarity - larger is more similar.
    Cosine,
}

20/// Index configuration.
21#[derive(Debug, Clone)]
22pub struct IndexConfig {
23    /// Vector dimension
24    pub dimension: usize,
25    /// Distance metric
26    pub distance_type: DistanceType,
27    /// Whether to normalize vectors before indexing
28    pub normalize: bool,
29}
30
31impl IndexConfig {
32    /// Create new index configuration.
33    #[must_use]
34    pub fn new(dimension: usize) -> Self {
35        Self {
36            dimension,
37            distance_type: DistanceType::L2,
38            normalize: false,
39        }
40    }
41
42    /// Set distance type.
43    #[must_use]
44    pub const fn with_distance(mut self, distance_type: DistanceType) -> Self {
45        self.distance_type = distance_type;
46        self
47    }
48
49    /// Enable vector normalization.
50    #[must_use]
51    pub const fn with_normalize(mut self, normalize: bool) -> Self {
52        self.normalize = normalize;
53        self
54    }
55}
56
57/// Result from a nearest neighbor search.
58#[derive(Debug, Clone)]
59pub struct SearchResult {
60    /// Record ID
61    pub id: String,
62    /// Distance from query (interpretation depends on metric)
63    pub distance: f32,
64    /// Similarity score (higher = more similar, normalized to 0-1)
65    pub score: f32,
66}
67
68impl SearchResult {
69    /// Create a new search result.
70    #[must_use]
71    pub fn new(id: String, distance: f32, distance_type: DistanceType) -> Self {
72        let score = Self::distance_to_score(distance, distance_type);
73        Self { id, distance, score }
74    }
75
76    /// Convert distance to similarity score (0-1, higher is better).
77    fn distance_to_score(distance: f32, distance_type: DistanceType) -> f32 {
78        match distance_type {
79            DistanceType::L2 => {
80                // Convert L2 distance to similarity: 1 / (1 + distance)
81                1.0 / (1.0 + distance)
82            }
83            DistanceType::InnerProduct | DistanceType::Cosine => {
84                // Inner product / cosine: already a similarity (may need clamping)
85                distance.clamp(0.0, 1.0)
86            }
87        }
88    }
89}
90
91/// Core trait for vector indexes.
92///
93/// All index implementations must implement this trait.
94pub trait VectorIndex: Send + Sync + Debug {
95    /// Add a vector to the index.
96    ///
97    /// # Arguments
98    ///
99    /// * `id` - Unique identifier for the vector
100    /// * `vector` - The vector to index
101    ///
102    /// # Errors
103    ///
104    /// Returns error if dimension mismatch or capacity exceeded.
105    fn add(&mut self, id: String, vector: &[f32]) -> Result<()>;
106
107    /// Add multiple vectors in batch.
108    ///
109    /// Default implementation calls `add` repeatedly.
110    fn add_batch(&mut self, ids: Vec<String>, vectors: &[Vec<f32>]) -> Result<()> {
111        for (id, vector) in ids.into_iter().zip(vectors.iter()) {
112            self.add(id, vector)?;
113        }
114        Ok(())
115    }
116
117    /// Search for k nearest neighbors.
118    ///
119    /// # Arguments
120    ///
121    /// * `query` - Query vector
122    /// * `k` - Number of neighbors to return
123    ///
124    /// # Returns
125    ///
126    /// Vector of search results, sorted by distance (ascending).
127    fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
128
129    /// Search with pre-filter (IDs to consider).
130    ///
131    /// Default implementation searches all, then filters.
132    fn search_with_ids(&self, query: &[f32], k: usize, ids: &[String]) -> Result<Vec<SearchResult>> {
133        let results = self.search(query, self.len().min(k * 10))?;
134        let id_set: std::collections::HashSet<_> = ids.iter().collect();
135        Ok(results
136            .into_iter()
137            .filter(|r| id_set.contains(&r.id))
138            .take(k)
139            .collect())
140    }
141
142    /// Remove a vector from the index.
143    ///
144    /// # Returns
145    ///
146    /// `true` if vector was found and removed, `false` otherwise.
147    fn remove(&mut self, id: &str) -> Result<bool>;
148
149    /// Check if ID exists in index.
150    fn contains(&self, id: &str) -> bool;
151
152    /// Number of vectors in the index.
153    fn len(&self) -> usize;
154
155    /// Whether the index is empty.
156    fn is_empty(&self) -> bool {
157        self.len() == 0
158    }
159
160    /// Vector dimension.
161    fn dimension(&self) -> usize;
162
163    /// Distance type used by this index.
164    fn distance_type(&self) -> DistanceType;
165
166    /// Clear all vectors from the index.
167    fn clear(&mut self);
168
169    /// Get memory usage estimate in bytes.
170    fn memory_usage(&self) -> usize;
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 distances map to scores through `1 / (1 + distance)`.
    #[test]
    fn test_l2_score() {
        // Distance 0 -> score 1.0
        let r = SearchResult::new("a".to_string(), 0.0, DistanceType::L2);
        assert!((r.score - 1.0).abs() < 1e-6);

        // Distance 1 -> score 0.5
        let r = SearchResult::new("b".to_string(), 1.0, DistanceType::L2);
        assert!((r.score - 0.5).abs() < 1e-6);
    }

    /// Inner product / cosine values are used as scores, clamped to 0-1.
    #[test]
    fn test_similarity_score_is_clamped() {
        // In-range value passes through unchanged
        let r = SearchResult::new("c".to_string(), 0.25, DistanceType::Cosine);
        assert!((r.score - 0.25).abs() < 1e-6);

        // Above 1 -> score 1.0
        let r = SearchResult::new("d".to_string(), 1.5, DistanceType::InnerProduct);
        assert!((r.score - 1.0).abs() < 1e-6);

        // Below 0 -> score 0.0
        let r = SearchResult::new("e".to_string(), -0.2, DistanceType::Cosine);
        assert!(r.score.abs() < 1e-6);
    }

    /// The builder methods set exactly the fields they name.
    #[test]
    fn test_config() {
        let config = IndexConfig::new(256)
            .with_distance(DistanceType::Cosine)
            .with_normalize(true);

        assert_eq!(config.dimension, 256);
        assert_eq!(config.distance_type, DistanceType::Cosine);
        assert!(config.normalize);
    }

    /// A fresh configuration uses L2 and leaves normalization off.
    #[test]
    fn test_config_defaults() {
        let config = IndexConfig::new(8);

        assert_eq!(config.dimension, 8);
        assert_eq!(config.distance_type, DistanceType::L2);
        assert!(!config.normalize);
    }
}