// oxirs_vec/product_search.rs

//! Multi-vector product search combining multiple embedding sub-vectors.
//!
//! A `ProductSearchIndex` stores `MultiVector` items (each item has multiple
//! sub-vectors of potentially different dimensions) and provides search
//! functionality that computes per-sub-vector scores and combines them into
//! a single score (their arithmetic mean).

// ── Types ─────────────────────────────────────────────────────────────────────

/// An item in the index (also used as the query type for searches).
///
/// `vectors[i]` is the i-th sub-vector.  Sub-vectors may have different
/// lengths from one another; during a search, sub-vector `i` of the query is
/// only ever compared with sub-vector `i` of a stored item.
#[derive(Debug, Clone)]
pub struct MultiVector {
    /// Identifier of the item; copied into the `id` of every search result
    /// produced for it.
    pub id: usize,
    /// The sub-vectors, in positional order.
    pub vectors: Vec<Vec<f32>>,
}

/// The distance metric used for scoring.
///
/// Every metric is turned into a "higher is better" score before ranking.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance, i.e. the square root of the summed squared
    /// differences (lower is better → negated for the combined score).
    L2,
    /// Cosine similarity (higher is better).
    Cosine,
    /// Raw dot-product (higher is better).
    DotProduct,
}

27/// Configuration for the index.
28#[derive(Debug, Clone)]
29pub struct ProductSearchConfig {
30    /// Number of sub-vectors per item.
31    pub sub_dimensions: usize,
32    /// Distance metric used during search.
33    pub distance_metric: DistanceMetric,
34}
35
/// A single search result candidate.
#[derive(Debug, Clone)]
pub struct SearchCandidate {
    /// The item's id.
    pub id: usize,
    /// One score per compared sub-vector, in sub-vector order.  Scores are
    /// always "higher is better" (distances are stored negated).
    pub scores: Vec<f32>,
    /// The arithmetic mean of `scores`; results are ranked by this value.
    pub combined_score: f32,
}

// ── ProductSearchIndex ────────────────────────────────────────────────────────

49/// Multi-vector product search index.
50pub struct ProductSearchIndex {
51    config: ProductSearchConfig,
52    items: Vec<MultiVector>,
53}
54
55impl ProductSearchIndex {
56    /// Create an empty index.
57    pub fn new(config: ProductSearchConfig) -> Self {
58        Self {
59            config,
60            items: Vec::new(),
61        }
62    }
63
64    /// Insert an item into the index.
65    pub fn insert(&mut self, item: MultiVector) {
66        self.items.push(item);
67    }
68
69    /// Search for the `k` nearest items to `query` across all sub-vectors.
70    ///
71    /// The combined score is the mean of per-sub-vector scores.  Items are
72    /// returned sorted by combined score (descending for similarity metrics).
73    pub fn search(&self, query: &MultiVector, k: usize) -> Vec<SearchCandidate> {
74        let mut candidates: Vec<SearchCandidate> = self
75            .items
76            .iter()
77            .filter_map(|item| self.score_all(query, item))
78            .collect();
79
80        // Sort: higher combined score = better (for L2 we negate so still descending)
81        candidates.sort_by(|a, b| {
82            b.combined_score
83                .partial_cmp(&a.combined_score)
84                .unwrap_or(std::cmp::Ordering::Equal)
85        });
86        candidates.truncate(k);
87        candidates
88    }
89
90    /// Search using only the sub-vector at index `sub_idx`.
91    pub fn search_sub(&self, query_sub: &[f32], sub_idx: usize, k: usize) -> Vec<SearchCandidate> {
92        let mut candidates: Vec<SearchCandidate> = self
93            .items
94            .iter()
95            .filter_map(|item| {
96                let item_sub = item.vectors.get(sub_idx)?;
97                if item_sub.len() != query_sub.len() {
98                    return None;
99                }
100                let score = self.compute_score(query_sub, item_sub);
101                Some(SearchCandidate {
102                    id: item.id,
103                    scores: vec![score],
104                    combined_score: score,
105                })
106            })
107            .collect();
108
109        candidates.sort_by(|a, b| {
110            b.combined_score
111                .partial_cmp(&a.combined_score)
112                .unwrap_or(std::cmp::Ordering::Equal)
113        });
114        candidates.truncate(k);
115        candidates
116    }
117
118    /// Number of items in the index.
119    pub fn item_count(&self) -> usize {
120        self.items.len()
121    }
122
123    /// The configured number of sub-dimensions.
124    pub fn sub_dimension_count(&self) -> usize {
125        self.config.sub_dimensions
126    }
127
128    /// Remove an item by id.  Returns `true` if the item existed.
129    pub fn remove(&mut self, id: usize) -> bool {
130        let before = self.items.len();
131        self.items.retain(|item| item.id != id);
132        self.items.len() < before
133    }
134
135    // ── Private helpers ────────────────────────────────────────────────────────
136
137    /// Score `query` against `item` across all matching sub-vectors.
138    fn score_all(&self, query: &MultiVector, item: &MultiVector) -> Option<SearchCandidate> {
139        let n_subs = query.vectors.len().min(item.vectors.len());
140        if n_subs == 0 {
141            return None;
142        }
143        let mut scores: Vec<f32> = Vec::with_capacity(n_subs);
144        for i in 0..n_subs {
145            let qv = &query.vectors[i];
146            let iv = &item.vectors[i];
147            if qv.len() != iv.len() {
148                return None;
149            }
150            scores.push(self.compute_score(qv, iv));
151        }
152        let combined_score = scores.iter().sum::<f32>() / scores.len() as f32;
153        Some(SearchCandidate {
154            id: item.id,
155            scores,
156            combined_score,
157        })
158    }
159
160    /// Compute a single sub-vector score according to the configured metric.
161    fn compute_score(&self, a: &[f32], b: &[f32]) -> f32 {
162        match &self.config.distance_metric {
163            DistanceMetric::L2 => -l2_distance(a, b),
164            DistanceMetric::Cosine => cosine_sim(a, b),
165            DistanceMetric::DotProduct => dot_product(a, b),
166        }
167    }
168}
169
// ── Distance/similarity functions ────────────────────────────────────────────

/// Euclidean (L2) distance between `a` and `b`.
///
/// Computes the square root of the sum of squared element-wise differences.
/// If the slices differ in length, only the common prefix is compared (the
/// extra elements of the longer slice are ignored).
pub fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b.iter())
        .map(|(x, y)| (x - y).powi(2))
        .sum::<f32>()
        .sqrt()
}

181/// Cosine similarity in [−1, 1].
182pub fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
183    let dot = dot_product(a, b);
184    let norm_a = a.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
185    let norm_b = b.iter().map(|x| x.powi(2)).sum::<f32>().sqrt();
186    if norm_a == 0.0 || norm_b == 0.0 {
187        0.0
188    } else {
189        dot / (norm_a * norm_b)
190    }
191}
192
/// Raw dot product of `a` and `b`.
///
/// If the slices differ in length, only the common prefix contributes.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}

// ── Tests ─────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Wrap one slice as a single-sub-vector list.
    fn vec1(v: &[f32]) -> Vec<Vec<f32>> {
        vec![v.to_vec()]
    }

    /// Wrap two slices as a two-sub-vector list.
    fn vec2(v1: &[f32], v2: &[f32]) -> Vec<Vec<f32>> {
        vec![v1.to_vec(), v2.to_vec()]
    }

    /// Config with one declared sub-vector and the given metric.
    fn cfg(metric: DistanceMetric) -> ProductSearchConfig {
        ProductSearchConfig {
            sub_dimensions: 1,
            distance_metric: metric,
        }
    }

    /// Shorthand constructor for a `MultiVector`.
    fn mv(id: usize, vecs: Vec<Vec<f32>>) -> MultiVector {
        MultiVector { id, vectors: vecs }
    }

    // ── l2_distance ───────────────────────────────────────────────────────────

    #[test]
    fn test_l2_distance_zero() {
        assert!((l2_distance(&[1.0, 2.0], &[1.0, 2.0])).abs() < 1e-6);
    }

    #[test]
    fn test_l2_distance_known() {
        // 3-4-5 triangle
        assert!((l2_distance(&[0.0, 0.0], &[3.0, 4.0]) - 5.0).abs() < 1e-5);
    }

    // ── cosine_sim ────────────────────────────────────────────────────────────

    #[test]
    fn test_cosine_sim_identical() {
        let v = [1.0f32, 0.0, 0.0];
        assert!((cosine_sim(&v, &v) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_orthogonal() {
        assert!((cosine_sim(&[1.0, 0.0], &[0.0, 1.0])).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_opposite() {
        assert!((cosine_sim(&[1.0, 0.0], &[-1.0, 0.0]) + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_sim_zero_vector() {
        assert_eq!(cosine_sim(&[0.0, 0.0], &[1.0, 0.0]), 0.0);
    }

    // ── dot_product ───────────────────────────────────────────────────────────

    #[test]
    fn test_dot_product_basic() {
        assert!((dot_product(&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]) - 32.0).abs() < 1e-6);
    }

    #[test]
    fn test_dot_product_zero() {
        assert_eq!(dot_product(&[1.0, 0.0], &[0.0, 1.0]), 0.0);
    }

    // ── insert / item_count ───────────────────────────────────────────────────

    #[test]
    fn test_insert_increments_count() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(1, vec1(&[1.0])));
        assert_eq!(idx.item_count(), 1);
    }

    #[test]
    fn test_insert_multiple() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(1, vec1(&[1.0])));
        idx.insert(mv(2, vec1(&[2.0])));
        assert_eq!(idx.item_count(), 2);
    }

    // ── sub_dimension_count ───────────────────────────────────────────────────

    #[test]
    fn test_sub_dimension_count() {
        let idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 3,
            distance_metric: DistanceMetric::Cosine,
        });
        assert_eq!(idx.sub_dimension_count(), 3);
    }

    // ── search L2 ─────────────────────────────────────────────────────────────

    #[test]
    fn test_search_l2_nearest_neighbor() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(1, vec1(&[0.0])));
        idx.insert(mv(2, vec1(&[10.0])));
        let q = mv(0, vec1(&[0.5]));
        let results = idx.search(&q, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1); // closer to 0.0 than 10.0
    }

    #[test]
    fn test_search_l2_same_vector_best_score() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(1, vec1(&[1.0, 2.0, 3.0])));
        idx.insert(mv(2, vec1(&[10.0, 10.0, 10.0])));
        let q = mv(0, vec1(&[1.0, 2.0, 3.0]));
        let results = idx.search(&q, 2);
        assert_eq!(results[0].id, 1); // exact match
    }

    #[test]
    fn test_search_l2_k_limit() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        for i in 0..10usize {
            idx.insert(mv(i, vec1(&[i as f32])));
        }
        let q = mv(99, vec1(&[0.0]));
        let results = idx.search(&q, 3);
        assert_eq!(results.len(), 3);
    }

    // ── search Cosine ─────────────────────────────────────────────────────────

    #[test]
    fn test_search_cosine_identical_is_top() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::Cosine));
        idx.insert(mv(1, vec1(&[1.0, 0.0])));
        idx.insert(mv(2, vec1(&[0.0, 1.0])));
        let q = mv(0, vec1(&[1.0, 0.0]));
        let results = idx.search(&q, 2);
        assert_eq!(results[0].id, 1);
    }

    // ── search DotProduct ─────────────────────────────────────────────────────

    #[test]
    fn test_search_dot_product() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::DotProduct));
        idx.insert(mv(1, vec1(&[1.0, 2.0])));
        idx.insert(mv(2, vec1(&[3.0, 4.0])));
        let q = mv(0, vec1(&[1.0, 1.0]));
        let results = idx.search(&q, 2);
        // dot([1,1],[3,4])=7 > dot([1,1],[1,2])=3 → item 2 first
        assert_eq!(results[0].id, 2);
    }

    // ── multi-vector combination ──────────────────────────────────────────────

    #[test]
    fn test_search_multi_vector_combination() {
        let mut idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 2,
            distance_metric: DistanceMetric::Cosine,
        });
        // Item 1: very similar in sub0, very dissimilar in sub1
        idx.insert(mv(1, vec2(&[1.0, 0.0], &[0.0, 1.0])));
        // Item 2: similar in both sub0 and sub1
        idx.insert(mv(2, vec2(&[1.0, 0.0], &[1.0, 0.0])));
        let q = mv(0, vec2(&[1.0, 0.0], &[1.0, 0.0]));
        let results = idx.search(&q, 2);
        // Item 2 should have higher combined score
        assert_eq!(results[0].id, 2);
    }

    #[test]
    fn test_search_candidate_scores_count_equals_sub_vectors() {
        let mut idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 3,
            distance_metric: DistanceMetric::Cosine,
        });
        idx.insert(mv(1, vec![vec![1.0], vec![1.0], vec![1.0]]));
        let q = mv(0, vec![vec![1.0], vec![1.0], vec![1.0]]);
        let results = idx.search(&q, 1);
        assert_eq!(results[0].scores.len(), 3);
    }

    // ── search_sub ────────────────────────────────────────────────────────────

    #[test]
    fn test_search_sub_single_dimension() {
        let mut idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 2,
            distance_metric: DistanceMetric::L2,
        });
        idx.insert(mv(1, vec2(&[0.0], &[10.0])));
        idx.insert(mv(2, vec2(&[5.0], &[10.0])));
        let results = idx.search_sub(&[0.0], 0, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, 1); // item 1 has [0.0] in sub0
    }

    #[test]
    fn test_search_sub_k_limit() {
        let mut idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 1,
            distance_metric: DistanceMetric::Cosine,
        });
        for i in 0..5usize {
            idx.insert(mv(i, vec1(&[i as f32 + 1.0])));
        }
        let results = idx.search_sub(&[1.0], 0, 2);
        assert_eq!(results.len(), 2);
    }

    // ── remove ────────────────────────────────────────────────────────────────

    #[test]
    fn test_remove_existing_item() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(42, vec1(&[1.0])));
        assert!(idx.remove(42));
        assert_eq!(idx.item_count(), 0);
    }

    #[test]
    fn test_remove_nonexistent_item() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        assert!(!idx.remove(99));
    }

    #[test]
    fn test_remove_does_not_affect_other_items() {
        let mut idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        idx.insert(mv(1, vec1(&[1.0])));
        idx.insert(mv(2, vec1(&[2.0])));
        idx.remove(1);
        assert_eq!(idx.item_count(), 1);
        let q = mv(0, vec1(&[2.0]));
        let results = idx.search(&q, 1);
        assert_eq!(results[0].id, 2);
    }

    // ── empty index ───────────────────────────────────────────────────────────

    #[test]
    fn test_search_empty_index() {
        let idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        let q = mv(0, vec1(&[1.0]));
        let results = idx.search(&q, 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_search_sub_empty_index() {
        let idx = ProductSearchIndex::new(cfg(DistanceMetric::L2));
        let results = idx.search_sub(&[1.0], 0, 5);
        assert!(results.is_empty());
    }

    // ── combined_score correctness ────────────────────────────────────────────

    #[test]
    fn test_combined_score_is_mean_of_scores() {
        let mut idx = ProductSearchIndex::new(ProductSearchConfig {
            sub_dimensions: 2,
            distance_metric: DistanceMetric::Cosine,
        });
        idx.insert(mv(1, vec2(&[1.0, 0.0], &[1.0, 0.0])));
        let q = mv(0, vec2(&[1.0, 0.0], &[1.0, 0.0]));
        let results = idx.search(&q, 1);
        let c = &results[0];
        let expected = c.scores.iter().sum::<f32>() / c.scores.len() as f32;
        assert!((c.combined_score - expected).abs() < 1e-5);
    }
}