nodedb_vector/multivec/scoring.rs
1// SPDX-License-Identifier: Apache-2.0
2
3//! MaxSim aggregation for multi-vector late interaction (ColBERT / MetaEmbed).
4//!
5//! `score(Q, D) = Σᵢ maxⱼ sim(qᵢ, dⱼ)`
6//!
7//! `budgeted_maxsim` restricts the query side to the first `budget` vectors,
8//! enabling Matryoshka ordering for latency-elastic late interaction.
9
10use nodedb_types::vector_distance::DistanceMetric;
11
12use crate::distance::scalar::scalar_distance;
13
14// ---------------------------------------------------------------------------
15// Similarity helpers
16// ---------------------------------------------------------------------------
17
18/// Convert a distance value produced by `scalar_distance` into a similarity
19/// score where higher is better.
20///
21/// | Metric | Conversion |
22/// |----------------|---------------------|
23/// | L2 | `−distance` |
24/// | Cosine | `1 − distance` |
25/// | InnerProduct | `−distance` (scalar_distance returns neg-IP) |
26/// | All others | `−distance` |
27fn dist_to_sim(d: f32, metric: DistanceMetric) -> f32 {
28 match metric {
29 DistanceMetric::Cosine => 1.0 - d,
30 // All other metrics (and unknown future metrics) use negated distance.
31 _ => -d,
32 }
33}
34
35// ---------------------------------------------------------------------------
36// Public API
37// ---------------------------------------------------------------------------
38
39/// Compute the MaxSim score between a query multi-vector and a document
40/// multi-vector.
41///
42/// `score(Q, D) = Σᵢ maxⱼ sim(qᵢ, dⱼ)`
43///
44/// Returns 0.0 if either side is empty.
45pub fn maxsim(query: &[Vec<f32>], doc: &[Vec<f32>], metric: DistanceMetric) -> f32 {
46 if query.is_empty() || doc.is_empty() {
47 return 0.0;
48 }
49 query
50 .iter()
51 .map(|q| {
52 doc.iter()
53 .map(|d| dist_to_sim(scalar_distance(q, d, metric), metric))
54 .fold(f32::NEG_INFINITY, f32::max)
55 })
56 .sum()
57}
58
59/// Budgeted MaxSim: only uses the first `budget` query vectors (Matryoshka
60/// ordering). When `budget` equals or exceeds `query.len()` this is
61/// equivalent to `maxsim`.
62///
63/// Returns 0.0 if either side is empty or budget is 0.
64pub fn budgeted_maxsim(
65 query: &[Vec<f32>],
66 doc: &[Vec<f32>],
67 budget: u8,
68 metric: DistanceMetric,
69) -> f32 {
70 let effective = (budget as usize).min(query.len());
71 maxsim(&query[..effective], doc, metric)
72}
73
74// ---------------------------------------------------------------------------
75// Tests
76// ---------------------------------------------------------------------------
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Dimensionality shared by every fixture vector in this module.
    const DIM: usize = 4;

    /// Return a `dim`-dimensional vector that is 1.0 at index `pos` and 0.0
    /// everywhere else.
    fn unit_vec(dim: usize, pos: usize) -> Vec<f32> {
        let mut basis = vec![0.0f32; dim];
        basis[pos] = 1.0;
        basis
    }

    /// The first `count` axis-aligned unit vectors of the `DIM`-dimensional
    /// space, in axis order.
    fn axes(count: usize) -> Vec<Vec<f32>> {
        (0..count).map(|axis| unit_vec(DIM, axis)).collect()
    }

    #[test]
    fn maxsim_identical_query_doc_l2() {
        // Every query vector has an exact twin in the document, so its best
        // L2 distance is 0 and its similarity is -0 = 0. Three such terms
        // sum to 0.
        let query = axes(3);
        let doc = query.clone();
        let score = maxsim(&query, &doc, DistanceMetric::L2);
        assert!(score.abs() < 1e-6, "score={score}");
    }

    #[test]
    fn maxsim_identical_query_doc_cosine() {
        // Identical unit vectors: cosine distance 0, similarity 1.
        // Three query vectors therefore add up to 3.
        let query = axes(3);
        let doc = query.clone();
        let score = maxsim(&query, &doc, DistanceMetric::Cosine);
        assert!((score - 3.0).abs() < 1e-5, "score={score}");
    }

    #[test]
    fn maxsim_empty_query_returns_zero() {
        let no_query: &[Vec<f32>] = &[];
        let doc = axes(1);
        assert_eq!(maxsim(no_query, &doc, DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn maxsim_empty_doc_returns_zero() {
        let query = axes(1);
        let no_doc: &[Vec<f32>] = &[];
        assert_eq!(maxsim(&query, no_doc, DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn budgeted_vs_full_differ_when_extra_tokens_are_relevant() {
        // The query covers all four axes.
        let query = axes(4);

        // Case 1: the document only covers axes 0 and 1. Query vectors 2 and
        // 3 are orthogonal to both document vectors (cosine dist = 1, so
        // sim = 0) and add nothing, so the full score and the budget-2 score
        // are both 2.0.
        let narrow_doc = axes(2);
        let full = maxsim(&query, &narrow_doc, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&query, &narrow_doc, 2, DistanceMetric::Cosine);

        // Case 2: the document additionally covers axis 2. The full score
        // now gains a match for query vector 2 (3.0 in total, vector 3 still
        // contributes 0), while the budget-2 score keeps seeing only query
        // vectors 0 and 1 (2.0).
        let wide_doc = axes(3);
        let full2 = maxsim(&query, &wide_doc, DistanceMetric::Cosine);
        let budgeted2 = budgeted_maxsim(&query, &wide_doc, 2, DistanceMetric::Cosine);

        assert!(
            (full - budgeted).abs() < 1e-5,
            "first case should be equal: full={full} budgeted={budgeted}"
        );
        assert!(
            full2 > budgeted2,
            "full2 should exceed budgeted2: full2={full2} budgeted2={budgeted2}"
        );
    }

    #[test]
    fn budgeted_zero_budget_returns_zero() {
        let query = axes(1);
        let doc = axes(1);
        assert_eq!(budgeted_maxsim(&query, &doc, 0, DistanceMetric::Cosine), 0.0);
    }

    #[test]
    fn budgeted_exceeds_query_length_equals_full() {
        // A budget of 255 is far larger than the two query vectors, so the
        // budgeted score must match the unrestricted one.
        let query = axes(2);
        let doc = query.clone();
        let full = maxsim(&query, &doc, DistanceMetric::Cosine);
        let budgeted = budgeted_maxsim(&query, &doc, 255, DistanceMetric::Cosine);
        assert!((full - budgeted).abs() < 1e-6);
    }
}
175}