Skip to main content

nodedb_vector/multivec/
scoring.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! MaxSim aggregation for multi-vector late interaction (ColBERT / MetaEmbed).
4//!
5//! `score(Q, D) = Σᵢ maxⱼ sim(qᵢ, dⱼ)`
6//!
7//! `budgeted_maxsim` restricts the query side to the first `budget` vectors,
8//! enabling Matryoshka ordering for latency-elastic late interaction.
9
10use nodedb_types::vector_distance::DistanceMetric;
11
12use crate::distance::scalar::scalar_distance;
13
14// ---------------------------------------------------------------------------
15// Similarity helpers
16// ---------------------------------------------------------------------------
17
18/// Convert a distance value produced by `scalar_distance` into a similarity
19/// score where higher is better.
20///
21/// | Metric         | Conversion          |
22/// |----------------|---------------------|
23/// | L2             | `−distance`         |
24/// | Cosine         | `1 − distance`      |
25/// | InnerProduct   | `−distance` (scalar_distance returns neg-IP) |
26/// | All others     | `−distance`         |
27fn dist_to_sim(d: f32, metric: DistanceMetric) -> f32 {
28    match metric {
29        DistanceMetric::Cosine => 1.0 - d,
30        // All other metrics (and unknown future metrics) use negated distance.
31        _ => -d,
32    }
33}
34
35// ---------------------------------------------------------------------------
36// Public API
37// ---------------------------------------------------------------------------
38
39/// Compute the MaxSim score between a query multi-vector and a document
40/// multi-vector.
41///
42/// `score(Q, D) = Σᵢ maxⱼ sim(qᵢ, dⱼ)`
43///
44/// Returns 0.0 if either side is empty.
45pub fn maxsim(query: &[Vec<f32>], doc: &[Vec<f32>], metric: DistanceMetric) -> f32 {
46    if query.is_empty() || doc.is_empty() {
47        return 0.0;
48    }
49    query
50        .iter()
51        .map(|q| {
52            doc.iter()
53                .map(|d| dist_to_sim(scalar_distance(q, d, metric), metric))
54                .fold(f32::NEG_INFINITY, f32::max)
55        })
56        .sum()
57}
58
59/// Budgeted MaxSim: only uses the first `budget` query vectors (Matryoshka
60/// ordering).  When `budget` equals or exceeds `query.len()` this is
61/// equivalent to `maxsim`.
62///
63/// Returns 0.0 if either side is empty or budget is 0.
64pub fn budgeted_maxsim(
65    query: &[Vec<f32>],
66    doc: &[Vec<f32>],
67    budget: u8,
68    metric: DistanceMetric,
69) -> f32 {
70    let effective = (budget as usize).min(query.len());
71    maxsim(&query[..effective], doc, metric)
72}
73
74// ---------------------------------------------------------------------------
75// Tests
76// ---------------------------------------------------------------------------
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Standard basis vector: `dim` components, all zero except a `1.0` at
    /// index `pos`.
    fn unit_vec(dim: usize, pos: usize) -> Vec<f32> {
        let mut basis = vec![0.0_f32; dim];
        basis[pos] = 1.0;
        basis
    }

    /// Multi-vector holding the `dim`-dimensional basis vectors at the given
    /// `positions`, in the order listed.
    fn unit_vecs(dim: usize, positions: &[usize]) -> Vec<Vec<f32>> {
        positions.iter().map(|&pos| unit_vec(dim, pos)).collect()
    }

    #[test]
    fn maxsim_identical_query_doc_l2() {
        // Each query vector has an identical counterpart in the document, so
        // its best L2 match is at distance 0 and contributes sim = -0 = 0.
        // Three such contributions still sum to 0.
        let q = unit_vecs(4, &[0, 1, 2]);
        let d = q.clone();
        let score = maxsim(&q, &d, DistanceMetric::L2);
        assert!(score.abs() < 1e-6, "score={score}");
    }

    #[test]
    fn maxsim_identical_query_doc_cosine() {
        // Identical unit vectors: cosine distance 0 → sim 1 per query vector,
        // so three query vectors are expected to sum to 3.
        let q = unit_vecs(4, &[0, 1, 2]);
        let d = q.clone();
        let score = maxsim(&q, &d, DistanceMetric::Cosine);
        assert!((score - 3.0).abs() < 1e-5, "score={score}");
    }

    #[test]
    fn maxsim_empty_query_returns_zero() {
        let d = unit_vecs(4, &[0]);
        assert_eq!(maxsim(&[], &d, DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn maxsim_empty_doc_returns_zero() {
        let q = unit_vecs(4, &[0]);
        assert_eq!(maxsim(&q, &[], DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn budgeted_vs_full_differ_when_extra_tokens_are_relevant() {
        // Four query vectors, one per axis.
        let q = unit_vecs(4, &[0, 1, 2, 3]);

        // Case 1: the document only covers axes 0 and 1.  q[2] and q[3] are
        // orthogonal to every document vector (cosine dist = 1 → sim = 0), so
        // dropping them through the budget leaves the score unchanged:
        // full = budgeted = 2.0.
        let d = unit_vecs(4, &[0, 1]);
        let full = maxsim(&q, &d, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&q, &d, 2, DistanceMetric::Cosine);
        assert!(
            (full - budgeted).abs() < 1e-5,
            "first case should be equal: full={full} budgeted={budgeted}"
        );

        // Case 2: the document additionally covers axis 2, so q[2] now has a
        // perfect match.  The full score rises to 3.0 (q[3] still adds 0),
        // while the budgeted score only sees q[0..2] and stays at 2.0.
        let d2 = unit_vecs(4, &[0, 1, 2]);
        let full2 = maxsim(&q, &d2, DistanceMetric::Cosine);
        let budgeted2 = budgeted_maxsim(&q, &d2, 2, DistanceMetric::Cosine);
        assert!(
            full2 > budgeted2,
            "full2 should exceed budgeted2: full2={full2} budgeted2={budgeted2}"
        );
    }

    #[test]
    fn budgeted_zero_budget_returns_zero() {
        let q = unit_vecs(4, &[0]);
        let d = unit_vecs(4, &[0]);
        assert_eq!(budgeted_maxsim(&q, &d, 0, DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn budgeted_exceeds_query_length_equals_full() {
        // A budget larger than the query must behave exactly like `maxsim`.
        let q = unit_vecs(4, &[0, 1]);
        let d = q.clone();
        let full = maxsim(&q, &d, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&q, &d, 255, DistanceMetric::Cosine);
        assert!((full - budgeted).abs() < 1e-6);
    }
}