rag_plusplus_core/index/
traits.rs1use crate::error::Result;
6use std::fmt::Debug;
7
/// Metric an index uses when comparing vectors.
///
/// [`DistanceType::L2`] is the variant returned by `DistanceType::default()`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum DistanceType {
    /// L2 metric; this is the default variant.
    #[default]
    L2,
    /// Inner-product metric.
    InnerProduct,
    /// Cosine metric.
    Cosine,
}
19
20#[derive(Debug, Clone)]
22pub struct IndexConfig {
23 pub dimension: usize,
25 pub distance_type: DistanceType,
27 pub normalize: bool,
29}
30
31impl IndexConfig {
32 #[must_use]
34 pub fn new(dimension: usize) -> Self {
35 Self {
36 dimension,
37 distance_type: DistanceType::L2,
38 normalize: false,
39 }
40 }
41
42 #[must_use]
44 pub const fn with_distance(mut self, distance_type: DistanceType) -> Self {
45 self.distance_type = distance_type;
46 self
47 }
48
49 #[must_use]
51 pub const fn with_normalize(mut self, normalize: bool) -> Self {
52 self.normalize = normalize;
53 self
54 }
55}
56
/// A single hit produced by an index search.
#[derive(Clone, Debug)]
pub struct SearchResult {
    /// Identifier of the matched entry.
    pub id: String,
    /// Raw metric value reported for this hit.
    pub distance: f32,
    /// Score for this hit; [`SearchResult::new`] derives it from `distance`.
    pub score: f32,
}
67
68impl SearchResult {
69 #[must_use]
71 pub fn new(id: String, distance: f32, distance_type: DistanceType) -> Self {
72 let score = Self::distance_to_score(distance, distance_type);
73 Self { id, distance, score }
74 }
75
76 fn distance_to_score(distance: f32, distance_type: DistanceType) -> f32 {
78 match distance_type {
79 DistanceType::L2 => {
80 1.0 / (1.0 + distance)
82 }
83 DistanceType::InnerProduct | DistanceType::Cosine => {
84 distance.clamp(0.0, 1.0)
86 }
87 }
88 }
89}
90
91pub trait VectorIndex: Send + Sync + Debug {
95 fn add(&mut self, id: String, vector: &[f32]) -> Result<()>;
106
107 fn add_batch(&mut self, ids: Vec<String>, vectors: &[Vec<f32>]) -> Result<()> {
111 for (id, vector) in ids.into_iter().zip(vectors.iter()) {
112 self.add(id, vector)?;
113 }
114 Ok(())
115 }
116
117 fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>>;
128
129 fn search_with_ids(&self, query: &[f32], k: usize, ids: &[String]) -> Result<Vec<SearchResult>> {
133 let results = self.search(query, self.len().min(k * 10))?;
134 let id_set: std::collections::HashSet<_> = ids.iter().collect();
135 Ok(results
136 .into_iter()
137 .filter(|r| id_set.contains(&r.id))
138 .take(k)
139 .collect())
140 }
141
142 fn remove(&mut self, id: &str) -> Result<bool>;
148
149 fn contains(&self, id: &str) -> bool;
151
152 fn len(&self) -> usize;
154
155 fn is_empty(&self) -> bool {
157 self.len() == 0
158 }
159
160 fn dimension(&self) -> usize;
162
163 fn distance_type(&self) -> DistanceType;
165
166 fn clear(&mut self);
168
169 fn memory_usage(&self) -> usize;
171}
172
#[cfg(test)]
mod tests {
    use super::*;

    /// L2 distances of 0 and 1 must score 1 and 0.5 respectively.
    #[test]
    fn test_l2_score() {
        let cases = [("a", 0.0_f32, 1.0_f32), ("b", 1.0, 0.5)];

        for (id, distance, expected) in cases {
            let hit = SearchResult::new(id.to_string(), distance, DistanceType::L2);
            assert!((hit.score - expected).abs() < 1e-6);
        }
    }

    /// The builder methods must set exactly the requested fields.
    #[test]
    fn test_config() {
        let IndexConfig {
            dimension,
            distance_type,
            normalize,
        } = IndexConfig::new(256)
            .with_distance(DistanceType::Cosine)
            .with_normalize(true);

        assert_eq!(dimension, 256);
        assert_eq!(distance_type, DistanceType::Cosine);
        assert!(normalize);
    }
}