Skip to main content

oxirs_vec/
embedding_similarity.rs

1//! Vector embedding similarity metrics and nearest-neighbour utilities.
2//!
3//! Provides `EmbeddingSimilarity` with multiple distance / similarity metrics
4//! and a `top_k` search over an in-memory corpus.
5
6// ── SimilarityMetric ──────────────────────────────────────────────────────────
7
/// A distance or similarity measure between two embedding vectors.
///
/// Every metric is reported as a *similarity* score where a higher value means
/// "more similar". The distance-based metrics (Euclidean, Manhattan, Chebyshev)
/// map a non-negative distance `d` to `1 / (1 + d)`, so their scores lie in
/// `[0, 1]` with `1` meaning the vectors are identical.
///
/// The type is `Copy + Eq + Hash`, so it can be passed by value and used as a
/// key in hash-based collections (e.g. caching results per metric).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SimilarityMetric {
    /// Cosine similarity (1 = identical direction, -1 = opposite, 0 if either
    /// vector is the zero vector).
    Cosine,
    /// Raw dot product (unbounded).
    DotProduct,
    /// Euclidean (L2) distance — converted to a similarity score as `1 / (1 + d)`.
    Euclidean,
    /// Manhattan (L1) distance — similarity as `1 / (1 + d)`.
    Manhattan,
    /// Chebyshev (L∞) distance — similarity as `1 / (1 + d)`.
    Chebyshev,
}
22
23// ── SimilarityResult ──────────────────────────────────────────────────────────
24
/// One hit produced by a nearest-neighbour search.
///
/// The fields are plain public data; `top_k` fills in `index` and `score`
/// and leaves `label` as `None`.
#[derive(Clone, Debug, PartialEq)]
pub struct SimilarityResult {
    /// Position of the matching vector within the searched corpus.
    pub index: usize,
    /// Similarity to the query. Larger values always indicate a closer match,
    /// whichever metric produced the score.
    pub score: f64,
    /// Human-readable label for the hit, if the caller attaches one.
    pub label: Option<String>,
}
35
36// ── EmbeddingSimilarity ───────────────────────────────────────────────────────
37
/// Namespace for vector-embedding similarity utilities.
///
/// This is a zero-sized marker type: all functionality lives in associated
/// functions (e.g. `EmbeddingSimilarity::cosine`), so there is no state to
/// construct. The standard derives are provided so the type can still be
/// printed, copied or default-constructed when used generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct EmbeddingSimilarity;
40
41impl EmbeddingSimilarity {
42    // ── Individual metrics ────────────────────────────────────────────────────
43
44    /// Cosine similarity between two vectors.
45    ///
46    /// Returns `0.0` if either vector is the zero vector.
47    pub fn cosine(a: &[f64], b: &[f64]) -> f64 {
48        let dot = Self::dot_product(a, b);
49        let norm_a = Self::l2_norm(a);
50        let norm_b = Self::l2_norm(b);
51        if norm_a == 0.0 || norm_b == 0.0 {
52            return 0.0;
53        }
54        dot / (norm_a * norm_b)
55    }
56
57    /// Raw dot product of two vectors.
58    pub fn dot_product(a: &[f64], b: &[f64]) -> f64 {
59        a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
60    }
61
62    /// Euclidean distance converted to a similarity score `1 / (1 + distance)`.
63    pub fn euclidean(a: &[f64], b: &[f64]) -> f64 {
64        let dist: f64 = a
65            .iter()
66            .zip(b.iter())
67            .map(|(x, y)| (x - y).powi(2))
68            .sum::<f64>()
69            .sqrt();
70        1.0 / (1.0 + dist)
71    }
72
73    /// Manhattan (L1) distance converted to a similarity score `1 / (1 + distance)`.
74    pub fn manhattan(a: &[f64], b: &[f64]) -> f64 {
75        let dist: f64 = a.iter().zip(b.iter()).map(|(x, y)| (x - y).abs()).sum();
76        1.0 / (1.0 + dist)
77    }
78
79    /// Chebyshev (L∞) distance converted to a similarity score `1 / (1 + distance)`.
80    pub fn chebyshev(a: &[f64], b: &[f64]) -> f64 {
81        let dist = a
82            .iter()
83            .zip(b.iter())
84            .map(|(x, y)| (x - y).abs())
85            .fold(0.0_f64, f64::max);
86        1.0 / (1.0 + dist)
87    }
88
89    // ── Dispatch ──────────────────────────────────────────────────────────────
90
91    /// Compute the similarity between `a` and `b` using the given `metric`.
92    pub fn compute(a: &[f64], b: &[f64], metric: SimilarityMetric) -> f64 {
93        match metric {
94            SimilarityMetric::Cosine => Self::cosine(a, b),
95            SimilarityMetric::DotProduct => Self::dot_product(a, b),
96            SimilarityMetric::Euclidean => Self::euclidean(a, b),
97            SimilarityMetric::Manhattan => Self::manhattan(a, b),
98            SimilarityMetric::Chebyshev => Self::chebyshev(a, b),
99        }
100    }
101
102    // ── Top-k search ──────────────────────────────────────────────────────────
103
104    /// Return the top-`k` most similar vectors from `corpus` relative to `query`.
105    ///
106    /// Results are sorted descending by score (most similar first).
107    /// If `k` exceeds `corpus.len()`, all entries are returned.
108    pub fn top_k(
109        query: &[f64],
110        corpus: &[Vec<f64>],
111        k: usize,
112        metric: SimilarityMetric,
113    ) -> Vec<SimilarityResult> {
114        let mut scored: Vec<SimilarityResult> = corpus
115            .iter()
116            .enumerate()
117            .map(|(i, v)| SimilarityResult {
118                index: i,
119                score: Self::compute(query, v, metric),
120                label: None,
121            })
122            .collect();
123
124        // Sort descending by score.
125        scored.sort_by(|a, b| {
126            b.score
127                .partial_cmp(&a.score)
128                .unwrap_or(std::cmp::Ordering::Equal)
129        });
130        scored.truncate(k);
131        scored
132    }
133
134    // ── Normalisation ─────────────────────────────────────────────────────────
135
136    /// L2-normalise a vector (return a unit vector).
137    ///
138    /// If the input is the zero vector, returns a zero vector.
139    pub fn normalize(v: &[f64]) -> Vec<f64> {
140        let norm = Self::l2_norm(v);
141        if norm == 0.0 {
142            return vec![0.0; v.len()];
143        }
144        v.iter().map(|x| x / norm).collect()
145    }
146
147    // ── Pairwise matrix ───────────────────────────────────────────────────────
148
149    /// Compute the N×N pairwise similarity matrix for `corpus`.
150    ///
151    /// `result[i][j]` is the similarity between `corpus[i]` and `corpus[j]`.
152    pub fn pairwise(corpus: &[Vec<f64>], metric: SimilarityMetric) -> Vec<Vec<f64>> {
153        let n = corpus.len();
154        let mut matrix = vec![vec![0.0_f64; n]; n];
155        for i in 0..n {
156            for j in 0..n {
157                matrix[i][j] = Self::compute(&corpus[i], &corpus[j], metric);
158            }
159        }
160        matrix
161    }
162
163    // ── Private ───────────────────────────────────────────────────────────────
164
165    fn l2_norm(v: &[f64]) -> f64 {
166        v.iter().map(|x| x * x).sum::<f64>().sqrt()
167    }
168}
169
170// ── Tests ─────────────────────────────────────────────────────────────────────
171
#[cfg(test)]
mod tests {
    //! Unit tests for the similarity metrics, dispatch, normalisation, top-k
    //! search and pairwise matrix.

    use super::*;

    /// Tight tolerance for results that should be (near-)exact.
    const EPS: f64 = 1e-9;

    /// Looser tolerance used by `approx_eq` for general floating-point checks.
    const TOL: f64 = 1e-6;

    /// `true` when `a` and `b` differ by less than [`TOL`].
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < TOL
    }

    // ── cosine ────────────────────────────────────────────────────────────────

    #[test]
    fn test_cosine_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::cosine(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_cosine_opposite() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!(approx_eq(s, -1.0));
    }

    #[test]
    fn test_cosine_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!(approx_eq(s, 0.0));
    }

    #[test]
    fn test_cosine_zero_vector() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 2.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert_eq!(s, 0.0);
    }

    #[test]
    fn test_cosine_range() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let s = EmbeddingSimilarity::cosine(&a, &b);
        assert!((-1.0..=1.0).contains(&s));
    }

    // ── dot_product ───────────────────────────────────────────────────────────

    #[test]
    fn test_dot_product_basic() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![4.0, 5.0, 6.0];
        let d = EmbeddingSimilarity::dot_product(&a, &b);
        assert!(approx_eq(d, 32.0));
    }

    #[test]
    fn test_dot_product_zero() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        assert!(approx_eq(EmbeddingSimilarity::dot_product(&a, &b), 0.0));
    }

    #[test]
    fn test_dot_product_negative() {
        let a = vec![1.0, -1.0];
        let b = vec![1.0, 1.0];
        assert!(approx_eq(EmbeddingSimilarity::dot_product(&a, &b), 0.0));
    }

    // ── euclidean ─────────────────────────────────────────────────────────────

    #[test]
    fn test_euclidean_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::euclidean(&v, &v);
        assert!(approx_eq(s, 1.0)); // distance=0 → similarity=1
    }

    #[test]
    fn test_euclidean_unit_apart() {
        let a = vec![0.0];
        let b = vec![1.0];
        // distance = 1 → similarity = 0.5
        let s = EmbeddingSimilarity::euclidean(&a, &b);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_euclidean_positive() {
        let a = vec![1.0, 2.0];
        let b = vec![4.0, 6.0];
        let s = EmbeddingSimilarity::euclidean(&a, &b);
        assert!(s > 0.0 && s < 1.0);
    }

    // ── manhattan ─────────────────────────────────────────────────────────────

    #[test]
    fn test_manhattan_identical() {
        let v = vec![1.0, 2.0];
        let s = EmbeddingSimilarity::manhattan(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_manhattan_unit_apart() {
        let a = vec![0.0];
        let b = vec![1.0];
        assert!(approx_eq(EmbeddingSimilarity::manhattan(&a, &b), 0.5));
    }

    #[test]
    fn test_manhattan_positive() {
        let a = vec![0.0, 0.0];
        let b = vec![3.0, 4.0];
        // L1 distance = 7, similarity = 1/8
        let s = EmbeddingSimilarity::manhattan(&a, &b);
        assert!(approx_eq(s, 1.0 / 8.0));
    }

    // ── chebyshev ─────────────────────────────────────────────────────────────

    #[test]
    fn test_chebyshev_identical() {
        let v = vec![1.0, 2.0, 3.0];
        let s = EmbeddingSimilarity::chebyshev(&v, &v);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_chebyshev_picks_max() {
        let a = vec![0.0, 0.0];
        let b = vec![1.0, 5.0];
        // Chebyshev distance = max(1, 5) = 5, similarity = 1/6
        let s = EmbeddingSimilarity::chebyshev(&a, &b);
        assert!(approx_eq(s, 1.0 / 6.0));
    }

    #[test]
    fn test_chebyshev_positive() {
        let a = vec![1.0, 2.0];
        let b = vec![4.0, 3.0];
        let s = EmbeddingSimilarity::chebyshev(&a, &b);
        assert!(s > 0.0 && s < 1.0);
    }

    // ── compute (dispatch) ────────────────────────────────────────────────────

    #[test]
    fn test_compute_cosine() {
        let a = vec![1.0, 0.0];
        let b = vec![1.0, 0.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Cosine);
        assert!(approx_eq(s, 1.0));
    }

    #[test]
    fn test_compute_dot_product() {
        let a = vec![2.0, 3.0];
        let b = vec![4.0, 5.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::DotProduct);
        assert!(approx_eq(s, 23.0));
    }

    #[test]
    fn test_compute_euclidean() {
        let a = vec![0.0];
        let b = vec![1.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Euclidean);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_compute_manhattan() {
        let a = vec![0.0];
        let b = vec![1.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Manhattan);
        assert!(approx_eq(s, 0.5));
    }

    #[test]
    fn test_compute_chebyshev() {
        let a = vec![0.0, 0.0];
        let b = vec![2.0, 3.0];
        let s = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Chebyshev);
        assert!(approx_eq(s, 1.0 / 4.0)); // max dist = 3, sim = 1/4
    }

    // ── normalize ─────────────────────────────────────────────────────────────

    #[test]
    fn test_normalize_unit_length() {
        let v = vec![3.0, 4.0];
        let n = EmbeddingSimilarity::normalize(&v);
        let norm: f64 = n.iter().map(|x| x * x).sum::<f64>().sqrt();
        assert!(approx_eq(norm, 1.0));
    }

    #[test]
    fn test_normalize_zero_vector() {
        let v = vec![0.0, 0.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(n.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_normalize_already_unit() {
        let v = vec![1.0, 0.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(approx_eq(n[0], 1.0));
        assert!(approx_eq(n[1], 0.0));
    }

    #[test]
    fn test_normalize_preserves_direction() {
        let v = vec![1.0, 1.0];
        let n = EmbeddingSimilarity::normalize(&v);
        assert!(approx_eq(n[0], n[1]));
    }

    // ── top_k ─────────────────────────────────────────────────────────────────

    #[test]
    fn test_top_k_returns_k_results() {
        let query = vec![1.0, 0.0];
        let corpus = vec![
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![-1.0, 0.0],
            vec![0.5, 0.5],
        ];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 2, SimilarityMetric::Cosine);
        assert_eq!(results.len(), 2);
    }

    #[test]
    fn test_top_k_sorted_descending() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![-1.0, 0.0]];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 3, SimilarityMetric::Cosine);
        // `windows(2)` is empty for fewer than two results, so unlike
        // `results.len() - 1` this cannot underflow.
        for pair in results.windows(2) {
            assert!(pair[0].score >= pair[1].score);
        }
    }

    #[test]
    fn test_top_k_best_is_identical() {
        let query = vec![1.0, 2.0, 3.0];
        let corpus = vec![vec![1.0, 2.0, 3.0], vec![0.0, 0.0, 1.0]];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 1, SimilarityMetric::Cosine);
        assert_eq!(results[0].index, 0);
    }

    #[test]
    fn test_top_k_empty_corpus() {
        let query = vec![1.0, 0.0];
        let results = EmbeddingSimilarity::top_k(&query, &[], 5, SimilarityMetric::Euclidean);
        assert!(results.is_empty());
    }

    #[test]
    fn test_top_k_k_larger_than_corpus() {
        let query = vec![1.0];
        let corpus = vec![vec![1.0], vec![2.0]];
        let results = EmbeddingSimilarity::top_k(&query, &corpus, 100, SimilarityMetric::Euclidean);
        assert_eq!(results.len(), 2);
    }

    // ── pairwise ──────────────────────────────────────────────────────────────

    #[test]
    fn test_pairwise_dimensions() {
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![1.0, 1.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        assert_eq!(m.len(), 3);
        assert_eq!(m[0].len(), 3);
    }

    #[test]
    fn test_pairwise_diagonal_is_max_cosine() {
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        // Diagonal should be 1.0 (identical vectors)
        assert!(approx_eq(m[0][0], 1.0));
        assert!(approx_eq(m[1][1], 1.0));
    }

    #[test]
    fn test_pairwise_symmetric_cosine() {
        let corpus = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let m = EmbeddingSimilarity::pairwise(&corpus, SimilarityMetric::Cosine);
        assert!(approx_eq(m[0][1], m[1][0]));
    }

    #[test]
    fn test_pairwise_empty_corpus() {
        let m = EmbeddingSimilarity::pairwise(&[], SimilarityMetric::Cosine);
        assert!(m.is_empty());
    }

    // ── SimilarityResult ──────────────────────────────────────────────────────

    #[test]
    fn test_similarity_result_fields() {
        let r = SimilarityResult {
            index: 5,
            score: 0.95,
            label: Some("example".to_string()),
        };
        assert_eq!(r.index, 5);
        assert!((r.score - 0.95).abs() < EPS);
        assert_eq!(r.label, Some("example".to_string()));
    }

    #[test]
    fn test_similarity_result_no_label() {
        let r = SimilarityResult {
            index: 0,
            score: 1.0,
            label: None,
        };
        assert!(r.label.is_none());
    }

    #[test]
    fn test_similarity_result_clone() {
        let r = SimilarityResult {
            index: 1,
            score: 0.5,
            label: None,
        };
        assert_eq!(r, r.clone());
    }

    // ── SimilarityMetric ──────────────────────────────────────────────────────

    #[test]
    fn test_metric_copy() {
        let m = SimilarityMetric::Cosine;
        let m2 = m;
        assert_eq!(m, m2);
    }

    #[test]
    fn test_metric_debug() {
        let s = format!("{:?}", SimilarityMetric::DotProduct);
        assert!(s.contains("DotProduct"));
    }

    #[test]
    fn test_chebyshev_identical_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let sim = EmbeddingSimilarity::compute(&a, &a, SimilarityMetric::Chebyshev);
        // Distance = 0 → similarity = 1/(1+0) = 1.0
        assert!(approx_eq(sim, 1.0));
    }

    #[test]
    fn test_manhattan_orthogonal() {
        let a = vec![1.0, 0.0];
        let b = vec![0.0, 1.0];
        let sim = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::Manhattan);
        // L1 distance = 2 → similarity = 1/3
        assert!((sim - 1.0 / 3.0).abs() < EPS);
    }

    #[test]
    fn test_dot_product_negative_components() {
        let a = vec![-1.0, -1.0];
        let b = vec![-1.0, -1.0];
        let sim = EmbeddingSimilarity::compute(&a, &b, SimilarityMetric::DotProduct);
        // dot(a, b) = 1+1 = 2
        assert!(approx_eq(sim, 2.0));
    }

    #[test]
    fn test_top_k_all_metrics() {
        let query = vec![1.0, 0.0];
        let corpus = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        for metric in [
            SimilarityMetric::Cosine,
            SimilarityMetric::DotProduct,
            SimilarityMetric::Euclidean,
            SimilarityMetric::Manhattan,
            SimilarityMetric::Chebyshev,
        ] {
            let results = EmbeddingSimilarity::top_k(&query, &corpus, 2, metric);
            assert_eq!(
                results.len(),
                2,
                "metric {:?} should return 2 results",
                metric
            );
        }
    }

    #[test]
    fn test_similarity_result_debug() {
        let r = SimilarityResult {
            index: 0,
            score: 0.9,
            label: None,
        };
        let s = format!("{r:?}");
        assert!(s.contains("SimilarityResult"));
    }
}