/// An identified item (or query) made of an ordered list of sub-vectors.
///
/// During a search, the sub-vector at position `i` is compared with the
/// sub-vector at the same position of the other operand.
#[derive(Debug, Clone, PartialEq)]
pub struct MultiVector {
    /// Caller-chosen identifier; it is copied into every [`SearchCandidate`]
    /// produced for this item and is the key used when removing items.
    pub id: usize,
    /// The sub-vectors, in positional order.
    pub vectors: Vec<Vec<f32>>,
}
15
/// Scoring function used to compare two sub-vectors.
///
/// Whatever the variant, the index turns the result into a score where a
/// larger value means a better match.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Euclidean distance; the index negates it so closer vectors score higher.
    L2,
    /// Cosine similarity; `0.0` when either vector has zero norm.
    Cosine,
    /// Plain dot product.
    DotProduct,
}
26
27#[derive(Debug, Clone)]
29pub struct ProductSearchConfig {
30 pub sub_dimensions: usize,
32 pub distance_metric: DistanceMetric,
34}
35
/// One ranked search hit.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchCandidate {
    /// Identifier of the matching item.
    pub id: usize,
    /// Score of each compared sub-vector, in positional order.
    pub scores: Vec<f32>,
    /// Ranking score (higher is better): the arithmetic mean of `scores` for
    /// a full search, or the single sub-vector score for a sub-vector search.
    pub combined_score: f32,
}
46
47pub struct ProductSearchIndex {
51 config: ProductSearchConfig,
52 items: Vec<MultiVector>,
53}
54
55impl ProductSearchIndex {
56 pub fn new(config: ProductSearchConfig) -> Self {
58 Self {
59 config,
60 items: Vec::new(),
61 }
62 }
63
64 pub fn insert(&mut self, item: MultiVector) {
66 self.items.push(item);
67 }
68
69 pub fn search(&self, query: &MultiVector, k: usize) -> Vec<SearchCandidate> {
74 let mut candidates: Vec<SearchCandidate> = self
75 .items
76 .iter()
77 .filter_map(|item| self.score_all(query, item))
78 .collect();
79
80 candidates.sort_by(|a, b| {
82 b.combined_score
83 .partial_cmp(&a.combined_score)
84 .unwrap_or(std::cmp::Ordering::Equal)
85 });
86 candidates.truncate(k);
87 candidates
88 }
89
90 pub fn search_sub(&self, query_sub: &[f32], sub_idx: usize, k: usize) -> Vec<SearchCandidate> {
92 let mut candidates: Vec<SearchCandidate> = self
93 .items
94 .iter()
95 .filter_map(|item| {
96 let item_sub = item.vectors.get(sub_idx)?;
97 if item_sub.len() != query_sub.len() {
98 return None;
99 }
100 let score = self.compute_score(query_sub, item_sub);
101 Some(SearchCandidate {
102 id: item.id,
103 scores: vec![score],
104 combined_score: score,
105 })
106 })
107 .collect();
108
109 candidates.sort_by(|a, b| {
110 b.combined_score
111 .partial_cmp(&a.combined_score)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 candidates.truncate(k);
115 candidates
116 }
117
118 pub fn item_count(&self) -> usize {
120 self.items.len()
121 }
122
123 pub fn sub_dimension_count(&self) -> usize {
125 self.config.sub_dimensions
126 }
127
128 pub fn remove(&mut self, id: usize) -> bool {
130 let before = self.items.len();
131 self.items.retain(|item| item.id != id);
132 self.items.len() < before
133 }
134
135 fn score_all(&self, query: &MultiVector, item: &MultiVector) -> Option<SearchCandidate> {
139 let n_subs = query.vectors.len().min(item.vectors.len());
140 if n_subs == 0 {
141 return None;
142 }
143 let mut scores: Vec<f32> = Vec::with_capacity(n_subs);
144 for i in 0..n_subs {
145 let qv = &query.vectors[i];
146 let iv = &item.vectors[i];
147 if qv.len() != iv.len() {
148 return None;
149 }
150 scores.push(self.compute_score(qv, iv));
151 }
152 let combined_score = scores.iter().sum::<f32>() / scores.len() as f32;
153 Some(SearchCandidate {
154 id: item.id,
155 scores,
156 combined_score,
157 })
158 }
159
160 fn compute_score(&self, a: &[f32], b: &[f32]) -> f32 {
162 match &self.config.distance_metric {
163 DistanceMetric::L2 => -l2_distance(a, b),
164 DistanceMetric::Cosine => cosine_sim(a, b),
165 DistanceMetric::DotProduct => dot_product(a, b),
166 }
167 }
168}
169
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// extra trailing components of the longer one are ignored.
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_sum: f32 = a
        .iter()
        .zip(b)
        .map(|(lhs, rhs)| {
            let diff = lhs - rhs;
            diff.powi(2)
        })
        .sum();
    squared_sum.sqrt()
}
180
181pub fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
183 let dot = dot_product(a, b);
184 let norm_a = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
185 let norm_b = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
186 if norm_a == 0.0 || norm_b == 0.0 {
187 0.0
188 } else {
189 dot / (norm_a * norm_b)
190 }
191}
192
/// Dot product of `a` and `b`.
///
/// Components are paired positionally; when the slices differ in length the
/// extra trailing components of the longer one are ignored.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
197
#[cfg(test)]
mod tests {
    use super::*;

    // ---- fixtures --------------------------------------------------------

    /// Payload made of a single sub-vector.
    fn vec1(v: &[f32]) -> Vec<Vec<f32>> {
        vec![Vec::from(v)]
    }

    /// Payload made of two sub-vectors.
    fn vec2(v1: &[f32], v2: &[f32]) -> Vec<Vec<f32>> {
        vec![Vec::from(v1), Vec::from(v2)]
    }

    /// Configuration with an explicit sub-vector count.
    fn config(sub_dimensions: usize, distance_metric: DistanceMetric) -> ProductSearchConfig {
        ProductSearchConfig {
            sub_dimensions,
            distance_metric,
        }
    }

    /// Single-sub-vector configuration for `metric`.
    fn cfg(metric: DistanceMetric) -> ProductSearchConfig {
        config(1, metric)
    }

    fn mv(id: usize, vecs: Vec<Vec<f32>>) -> MultiVector {
        MultiVector { id, vectors: vecs }
    }

    /// Index built from `config` with `items` inserted in the given order.
    fn populated(config: ProductSearchConfig, items: Vec<MultiVector>) -> ProductSearchIndex {
        let mut index = ProductSearchIndex::new(config);
        for item in items {
            index.insert(item);
        }
        index
    }

    /// Asserts that `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(actual: f32, expected: f32, tolerance: f32) {
        assert!((actual - expected).abs() < tolerance);
    }

    // ---- free functions --------------------------------------------------

    #[test]
    fn test_l2_distance_zero() {
        assert_close(l2_distance(&[1.0, 2.0], &[1.0, 2.0]), 0.0, 1e-6);
    }

    #[test]
    fn test_l2_distance_known() {
        // 3-4-5 right triangle.
        assert_close(l2_distance(&[0.0, 0.0], &[3.0, 4.0]), 5.0, 1e-5);
    }

    #[test]
    fn test_cosine_sim_identical() {
        let v = [1.0f32, 0.0, 0.0];
        assert_close(cosine_sim(&v, &v), 1.0, 1e-6);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        assert_close(cosine_sim(&[1.0, 0.0], &[0.0, 1.0]), 0.0, 1e-6);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        assert_close(cosine_sim(&[1.0, 0.0], &[-1.0, 0.0]), -1.0, 1e-6);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        assert_eq!(cosine_sim(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    #[test]
    fn test_dot_product_basic() {
        assert_close(dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]), 32.0, 1e-6);
    }

    #[test]
    fn test_dot_product_zero() {
        assert_eq!(dot_product(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    // ---- insertion and configuration ------------------------------------

    #[test]
    fn test_insert_increments_count() {
        let mut index = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        index.insert(mv(1, vec1(&[1.0])));
        assert_eq!(index.item_count(), 1);
    }

    #[test]
    fn test_insert_multiple() {
        let mut index = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        index.insert(mv(1, vec1(&[1.0])));
        index.insert(mv(2, vec1(&[2.0])));
        assert_eq!(index.item_count(), 2);
    }

    #[test]
    fn test_sub_dimension_count() {
        let index = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 3,
            distance_metric: DistanceMetric::Cosine,
        });
        assert_eq!(index.sub_dimension_count(), 3);
    }

    // ---- full search -----------------------------------------------------

    #[test]
    fn test_search_l2_nearest_neighbor() {
        let index = populated(
            cfg(DistanceMetric::L2),
            vec![mv(1, vec1(&[0.0])), mv(2, vec1(&[10.0]))],
        );
        let hits = index.search(&mv(0, vec1(&[0.5])), 1);
        assert_eq!(hits.len(), 1);
        // 0.0 is closer to 0.5 than 10.0 is.
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_search_l2_same_vector_best_score() {
        let index = populated(
            cfg(DistanceMetric::L2),
            vec![
                mv(1, vec1(&[1.0, 2.0, 3.0])),
                mv(2, vec1(&[10.0, 10.0, 10.0])),
            ],
        );
        let hits = index.search(&mv(0, vec1(&[1.0, 2.0, 3.0])), 2);
        // The exact match has distance zero and therefore ranks first.
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_search_l2_k_limit() {
        let items = (0..10usize).map(|i| mv(i, vec1(&[i as f32]))).collect();
        let index = populated(cfg(DistanceMetric::L2), items);
        let hits = index.search(&mv(99, vec1(&[0.0])), 3);
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_search_cosine_identical_is_top() {
        let index = populated(
            cfg(DistanceMetric::Cosine),
            vec![mv(1, vec1(&[1.0, 0.0])), mv(2, vec1(&[0.0, 1.0]))],
        );
        let hits = index.search(&mv(0, vec1(&[1.0, 0.0])), 2);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_search_dot_product() {
        let index = populated(
            cfg(DistanceMetric::DotProduct),
            vec![mv(1, vec1(&[1.0, 2.0])), mv(2, vec1(&[3.0, 4.0]))],
        );
        let hits = index.search(&mv(0, vec1(&[1.0, 1.0])), 2);
        // Dot products are 3 and 7; the larger one wins.
        assert_eq!(hits[0].id, 2);
    }

    #[test]
    fn test_search_multi_vector_combination() {
        let index = populated(
            config(2, DistanceMetric::Cosine),
            vec![
                // Matches the query on the first sub-vector only.
                mv(1, vec2(&[1.0, 0.0], &[0.0, 1.0])),
                // Matches the query on both sub-vectors.
                mv(2, vec2(&[1.0, 0.0], &[1.0, 0.0])),
            ],
        );
        let hits = index.search(&mv(0, vec2(&[1.0, 0.0], &[1.0, 0.0])), 2);
        assert_eq!(hits[0].id, 2);
    }

    #[test]
    fn test_search_candidate_scores_count_equals_sub_vectors() {
        let index = populated(
            config(3, DistanceMetric::Cosine),
            vec![mv(1, vec![vec![1.0], vec![1.0], vec![1.0]])],
        );
        let query = mv(0, vec![vec![1.0], vec![1.0], vec![1.0]]);
        let hits = index.search(&query, 1);
        assert_eq!(hits[0].scores.len(), 3);
    }

    // ---- sub-vector search ----------------------------------------------

    #[test]
    fn test_search_sub_single_dimension() {
        let index = populated(
            config(2, DistanceMetric::L2),
            vec![mv(1, vec2(&[0.0], &[10.0])), mv(2, vec2(&[5.0], &[10.0]))],
        );
        let hits = index.search_sub(&[0.0], 0, 1);
        assert_eq!(hits.len(), 1);
        // Only position 0 is compared, where item 1 matches exactly.
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_search_sub_k_limit() {
        let items = (0..5usize)
            .map(|i| mv(i, vec1(&[i as f32 + 1.0])))
            .collect();
        let index = populated(config(1, DistanceMetric::Cosine), items);
        let hits = index.search_sub(&[1.0], 0, 2);
        assert_eq!(hits.len(), 2);
    }

    // ---- removal ---------------------------------------------------------

    #[test]
    fn test_remove_existing_item() {
        let mut index = populated(cfg(DistanceMetric::L2), vec![mv(42, vec1(&[1.0]))]);
        assert!(index.remove(42));
        assert_eq!(index.item_count(), 0);
    }

    #[test]
    fn test_remove_nonexistent_item() {
        let mut index = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        assert!(!index.remove(99));
    }

    #[test]
    fn test_remove_does_not_affect_other_items() {
        let mut index = populated(
            cfg(DistanceMetric::L2),
            vec![mv(1, vec1(&[1.0])), mv(2, vec1(&[2.0]))],
        );
        index.remove(1);
        assert_eq!(index.item_count(), 1);
        let hits = index.search(&mv(0, vec1(&[2.0])), 1);
        assert_eq!(hits[0].id, 2);
    }

    // ---- empty index -----------------------------------------------------

    #[test]
    fn test_search_empty_index() {
        let index = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        let hits = index.search(&mv(0, vec1(&[1.0])), 5);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_search_sub_empty_index() {
        let index = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        let hits = index.search_sub(&[1.0], 0, 5);
        assert!(hits.is_empty());
    }

    // ---- score combination ----------------------------------------------

    #[test]
    fn test_combined_score_is_mean_of_scores() {
        let index = populated(
            config(2, DistanceMetric::Cosine),
            vec![mv(1, vec2(&[1.0, 0.0], &[1.0, 0.0]))],
        );
        let hits = index.search(&mv(0, vec2(&[1.0, 0.0], &[1.0, 0.0])), 1);
        let top = &hits[0];
        let mean = top.scores.iter().sum::<f32>() / top.scores.len() as f32;
        assert_close(top.combined_score, mean, 1e-5);
    }
}