Skip to main content

ripvec_core/
similarity.rs

//! Cosine similarity computation and ranking.
//!
//! Since all embeddings are L2-normalized, cosine similarity equals
//! the dot product — no square roots needed at query time.
5
/// Dot product of two equal-length vectors; for L2-normalized inputs this
/// equals their cosine similarity.
///
/// The two slices are expected to have identical lengths. Debug builds verify
/// this with `debug_assert_eq!` and panic on a mismatch. Release builds skip
/// the check. Because elements are paired only up to the shorter slice, a
/// mismatch then silently yields a result over the common prefix, which is
/// incorrect.
#[must_use]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len(), "dot_product: vector length mismatch");
    // Pair elements positionally, multiply, and accumulate left to right.
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum::<f32>()
}
16
17/// Rank all chunks against a query embedding using matrix-vector multiply.
18///
19/// Computes `embeddings @ query` where `embeddings` is `[num_chunks, hidden_dim]`
20/// and `query` is `[hidden_dim]`. Returns similarity scores in chunk order.
21///
22/// Uses ndarray's optimized matmul (SIMD-accelerated via `matrixmultiply` crate).
23#[must_use]
24pub fn rank_all(embeddings: &ndarray::Array2<f32>, query: &ndarray::Array1<f32>) -> Vec<f32> {
25    embeddings.dot(query).to_vec()
26}
27
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{Array1, Array2};

    #[test]
    fn identical_normalized_vectors() {
        // 1/sqrt(3) ≈ 0.5773, so this vector is approximately unit length.
        let v = vec![0.5773, 0.5773, 0.5773];
        let sim = dot_product(&v, &v);
        assert!((sim - 1.0).abs() < 0.01);
    }

    #[test]
    fn orthogonal_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = dot_product(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn opposite_vectors() {
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];
        let sim = dot_product(&a, &b);
        assert!((sim + 1.0).abs() < 1e-6);
    }

    #[test]
    fn empty_vectors_have_zero_dot_product() {
        let empty: [f32; 0] = [];
        assert_eq!(dot_product(&empty, &empty), 0.0);
    }

    #[test]
    fn rank_all_matches_scalar_dot_product() {
        // 4 chunks, 3-dimensional embeddings
        let data = vec![
            1.0, 0.0, 0.0, // chunk 0
            0.0, 1.0, 0.0, // chunk 1
            0.5773, 0.5773, 0.5773, // chunk 2
            -1.0, 0.0, 0.0, // chunk 3
        ];
        let embeddings = Array2::from_shape_vec((4, 3), data).unwrap();
        let query = Array1::from_vec(vec![1.0, 0.0, 0.0]);
        let query_slice = query.as_slice().expect("Array1::from_vec is contiguous");

        let scores = rank_all(&embeddings, &query);

        // Without this, the loop below would pass vacuously on a short result.
        assert_eq!(scores.len(), embeddings.nrows());

        // Compare against scalar dot_product over each row
        for (i, (row, score)) in embeddings.outer_iter().zip(&scores).enumerate() {
            let row = row
                .as_slice()
                .expect("rows of a standard-layout array are contiguous");
            let expected = dot_product(row, query_slice);
            assert!(
                (score - expected).abs() < 1e-6,
                "mismatch at chunk {i}: rank_all={score}, dot_product={expected}"
            );
        }
    }

    #[test]
    fn rank_all_empty_matrix() {
        let embeddings = Array2::from_shape_vec((0, 384), vec![]).unwrap();
        let query = Array1::zeros(384);
        let scores = rank_all(&embeddings, &query);
        assert!(scores.is_empty());
    }

    #[test]
    fn rank_all_known_values() {
        // 2x2 matrix: [[1, 2], [3, 4]] dot [1, 0] = [1, 3]
        let embeddings = Array2::from_shape_vec((2, 2), vec![1.0, 2.0, 3.0, 4.0]).unwrap();
        let query = Array1::from_vec(vec![1.0, 0.0]);
        let scores = rank_all(&embeddings, &query);
        assert_eq!(scores.len(), 2);
        assert!((scores[0] - 1.0).abs() < 1e-6);
        assert!((scores[1] - 3.0).abs() < 1e-6);

        // Same matrix dot [0, 1] = [2, 4]
        let query2 = Array1::from_vec(vec![0.0, 1.0]);
        let scores2 = rank_all(&embeddings, &query2);
        assert_eq!(scores2.len(), 2);
        assert!((scores2[0] - 2.0).abs() < 1e-6);
        assert!((scores2[1] - 4.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic]
    fn rank_all_dimension_mismatch_panics() {
        // 3 columns vs. a 4-element query: ndarray's `dot` rejects the shapes.
        let embeddings = Array2::<f32>::zeros((2, 3));
        let query = Array1::<f32>::zeros(4);
        let _ = rank_all(&embeddings, &query);
    }
}