1use std::sync::RwLock;
3
4use rayon::prelude::*;
5use tracing::instrument;
6
7use crate::{
8 config::VectorConfig,
9 error::{VectorError, VectorResult},
10 index::hnsw::HnswIndex,
11 types::DistanceMetric,
12};
13
/// Returns the cosine *distance* between `a` and `b`, i.e. `1 - cos(a, b)`.
///
/// Despite the function name, smaller values mean "more similar": the result
/// is close to `0.0` for vectors pointing the same way, `1.0` for orthogonal
/// vectors and close to `2.0` for opposite vectors.
///
/// If either vector has zero norm the angle is undefined and `1.0` is
/// returned. When the slices differ in length, the dot product only covers
/// the common prefix, while each norm is computed over its full slice.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    // Euclidean (L2) norm of a single vector.
    let norm = |v: &[f32]| v.iter().map(|c| c * c).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (norm(a), norm(b));

    // A zero-length direction has no angle; report the fixed distance 1.0.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 1.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    1.0 - inner / (norm_a * norm_b)
}
27
/// Returns the Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; if the slices differ in length only
/// the common prefix contributes to the result.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(p, q)| {
            let delta = p - q;
            delta * delta
        })
        .sum();
    squared_sum.sqrt()
}
36
/// Returns the *negated* dot product of `a` and `b`.
///
/// The sign is flipped so that a larger inner product yields a smaller
/// value, which lets callers rank results in ascending order just like the
/// other distance functions. Only the common prefix is used when the slices
/// differ in length.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let inner: f32 = a.iter().zip(b).map(|(p, q)| p * q).sum();
    -inner
}
41
42pub struct FlatIndex {
46 vectors: RwLock<Vec<(usize, Vec<f32>)>>,
47 pub dimensions: usize,
49 pub distance: DistanceMetric,
51}
52
53impl FlatIndex {
54 pub fn new(dimensions: usize, distance: DistanceMetric) -> Self {
56 FlatIndex {
57 vectors: RwLock::new(Vec::new()),
58 dimensions,
59 distance,
60 }
61 }
62
63 #[instrument(skip(self, vector))]
65 pub fn insert(&self, id: usize, vector: Vec<f32>) -> VectorResult<()> {
66 if vector.len() != self.dimensions {
67 return Err(VectorError::DimensionMismatch {
68 expected: self.dimensions,
69 got: vector.len(),
70 });
71 }
72 self.vectors
73 .write()
74 .map_err(|e| VectorError::Index(e.to_string()))?
75 .push((id, vector));
76 Ok(())
77 }
78
79 #[instrument(skip(self, items))]
81 pub fn insert_batch(&self, items: Vec<(usize, Vec<f32>)>) -> VectorResult<()> {
82 for (_, v) in &items {
83 if v.len() != self.dimensions {
84 return Err(VectorError::DimensionMismatch {
85 expected: self.dimensions,
86 got: v.len(),
87 });
88 }
89 }
90 self.vectors
91 .write()
92 .map_err(|e| VectorError::Index(e.to_string()))?
93 .extend(items);
94 Ok(())
95 }
96
97 #[instrument(skip(self, query))]
99 pub fn search(&self, query: &[f32], top_k: usize) -> VectorResult<Vec<(usize, f32)>> {
100 if query.len() != self.dimensions {
101 return Err(VectorError::DimensionMismatch {
102 expected: self.dimensions,
103 got: query.len(),
104 });
105 }
106 let vecs = self
107 .vectors
108 .read()
109 .map_err(|e| VectorError::Index(e.to_string()))?;
110 let dist = self.distance;
111 let mut scores: Vec<(usize, f32)> = vecs
112 .par_iter()
113 .map(|(id, v)| {
114 let d = match dist {
115 DistanceMetric::Cosine => cosine_similarity(query, v),
116 DistanceMetric::Euclidean => euclidean_distance(query, v),
117 DistanceMetric::DotProduct => dot_product(query, v),
118 };
119 (*id, d)
120 })
121 .collect();
122 scores.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
123 scores.truncate(top_k);
124 Ok(scores)
125 }
126
127 #[instrument(skip(self))]
129 pub fn delete(&self, id: usize) -> VectorResult<bool> {
130 let mut vecs = self
131 .vectors
132 .write()
133 .map_err(|e| VectorError::Index(e.to_string()))?;
134 let before = vecs.len();
135 vecs.retain(|(vid, _)| *vid != id);
136 Ok(vecs.len() < before)
137 }
138
139 pub fn len(&self) -> usize {
141 self.vectors.read().map(|v| v.len()).unwrap_or(0)
142 }
143
144 pub fn is_empty(&self) -> bool {
146 self.len() == 0
147 }
148
149 pub fn all_vectors(&self) -> VectorResult<Vec<(usize, Vec<f32>)>> {
151 Ok(self
152 .vectors
153 .read()
154 .map_err(|e| VectorError::Index(e.to_string()))?
155 .clone())
156 }
157
158 #[instrument(skip(self, config))]
160 pub fn to_hnsw(&self, config: &VectorConfig) -> VectorResult<HnswIndex> {
161 let hnsw = HnswIndex::new_with_dimensions(config, self.distance, self.dimensions)?;
162 let items = self.all_vectors()?;
163 hnsw.insert_batch(&items)?;
164 Ok(hnsw)
165 }
166}
167
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Perpendicular vectors have cosine 0, so the reported distance is 1.
    #[test]
    fn cosine_orthogonal_vectors() {
        let (x_axis, y_axis) = ([1.0f32, 0.0], [0.0f32, 1.0]);
        let distance = cosine_similarity(&x_axis, &y_axis);
        assert_abs_diff_eq!(distance, 1.0, epsilon = 1e-6);
    }

    /// A vector compared with itself has distance 0.
    #[test]
    fn cosine_identical_vectors() {
        let ones = [1.0f32; 3];
        let distance = cosine_similarity(&ones, &ones);
        assert_abs_diff_eq!(distance, 0.0, epsilon = 1e-6);
    }

    /// Classic 3-4-5 right triangle.
    #[test]
    fn euclidean_known_distance() {
        let origin = [0.0f32; 3];
        let target = [3.0f32, 4.0, 0.0];
        assert_abs_diff_eq!(euclidean_distance(&origin, &target), 5.0, epsilon = 1e-6);
    }

    /// The distance from a point to itself is 0.
    #[test]
    fn euclidean_same_point() {
        let point = [1.0f32, 2.0, 3.0];
        assert_abs_diff_eq!(euclidean_distance(&point, &point), 0.0, epsilon = 1e-6);
    }

    /// 1*4 + 2*5 + 3*6 = 32, returned with its sign flipped.
    #[test]
    fn dot_product_known() {
        let lhs = [1.0f32, 2.0, 3.0];
        let rhs = [4.0f32, 5.0, 6.0];
        assert_abs_diff_eq!(dot_product(&lhs, &rhs), -32.0, epsilon = 1e-6);
    }

    /// The nearest stored point comes back first and `top_k` caps the result.
    #[test]
    fn flat_index_insert_search() {
        let index = FlatIndex::new(2, DistanceMetric::Euclidean);
        for (id, point) in [(0, [0.0f32, 0.0]), (1, [1.0, 1.0]), (2, [10.0, 10.0])] {
            index.insert(id, point.to_vec()).unwrap();
        }
        let hits = index.search(&[0.1, 0.1], 2).unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].0, 0);
    }

    /// Deleting an existing id reports success and shrinks the index.
    #[test]
    fn flat_index_delete() {
        let index = FlatIndex::new(2, DistanceMetric::Euclidean);
        index.insert(42, vec![1.0, 1.0]).unwrap();
        assert_eq!(index.len(), 1);
        let removed = index.delete(42).unwrap();
        assert!(removed);
        assert_eq!(index.len(), 0);
    }

    /// Inserting a vector of the wrong length is rejected.
    #[test]
    fn dimension_mismatch_returns_error() {
        let index = FlatIndex::new(3, DistanceMetric::Euclidean);
        let outcome = index.insert(0, vec![1.0, 2.0]);
        assert!(outcome.unwrap_err().is_dimension_mismatch());
    }
}