1use std::collections::BinaryHeap;
23
24use crate::distance::distance;
25use nodedb_types::vector_distance::DistanceMetric;
26
/// The set of dimensionalities a vector may be truncated to.
///
/// [`MatryoshkaSpec::new`] stores the dimensions sorted ascending and
/// de-duplicated. The field is public, so a spec can also be built directly;
/// the lookup methods therefore do not rely on any particular ordering.
#[derive(Debug, Clone)]
pub struct MatryoshkaSpec {
    /// Allowed truncation dimensions.
    pub truncation_dims: Vec<u32>,
}

impl MatryoshkaSpec {
    /// Builds a spec from `truncation_dims`, sorting them ascending and
    /// dropping duplicates.
    pub fn new(mut truncation_dims: Vec<u32>) -> Self {
        truncation_dims.sort_unstable();
        truncation_dims.dedup();
        Self { truncation_dims }
    }

    /// Resolves a requested dimensionality to one the spec allows.
    ///
    /// * `None` — the largest allowed dimension, or `0` for an empty spec.
    /// * `Some(req)` — the largest allowed dimension `<= req`; if every
    ///   allowed dimension exceeds `req`, the smallest allowed dimension;
    ///   if the spec is empty, `req` itself.
    pub fn pick(&self, requested: Option<u32>) -> u32 {
        let dims = &self.truncation_dims;
        match requested {
            None => dims.iter().copied().max().unwrap_or(0),
            Some(req) => dims
                .iter()
                .copied()
                .filter(|&d| d <= req)
                .max()
                // Nothing fits under `req`: fall back to the smallest dimension.
                .or_else(|| dims.iter().copied().min())
                .unwrap_or(req),
        }
    }

    /// Returns `true` when `dim` is exactly one of the allowed dimensions.
    pub fn is_valid(&self, dim: u32) -> bool {
        self.truncation_dims.contains(&dim)
    }
}
66
/// Borrows the leading `dim` components of `v`.
///
/// When `dim` is larger than `v.len()` the whole slice is returned, so the
/// result never has more than `v.len()` elements and the call never panics.
#[inline]
pub fn truncate(v: &[f32], dim: usize) -> &[f32] {
    match v.get(..dim) {
        Some(prefix) => prefix,
        // `dim` runs past the end: the full vector is the longest prefix available.
        None => v,
    }
}
74
/// Parameters for [`matryoshka_search`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct MatryoshkaSearchOptions {
    /// Number of leading components compared during the first (coarse) pass.
    pub coarse_dim: u32,
    /// Number of leading components compared when re-ranking the survivors.
    pub full_dim: u32,
    /// Multiplier applied to `k` to size the coarse candidate pool
    /// (a value of `0` is treated as `1`).
    pub oversample: u8,
    /// Maximum number of results returned.
    pub k: usize,
}
87
/// Entry of the coarse-pass max-heap, ordered by `dist` only.
///
/// `id` and `vec_idx` are payload and take no part in comparisons. The largest
/// (worst) distance sits at the top of a `BinaryHeap<HeapEntry>`; a NaN
/// distance compares greater than every real distance, so it is treated as
/// the worst entry and the ordering stays total.
struct HeapEntry {
    /// Distance used as the ordering key.
    dist: f32,
    /// Caller-supplied candidate identifier.
    id: u32,
    /// Index of the candidate's vector in the side table kept by the search.
    vec_idx: usize,
}

// Equality is defined through `cmp` so that `a == b` holds exactly when
// `a.cmp(&b) == Equal`, as the `Ord` contract requires.
impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .partial_cmp(&other.dist)
            // `partial_cmp` is `None` only when a NaN is involved: rank NaN
            // above every number and equal to any other NaN.
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()))
    }
}
115
116pub fn matryoshka_search<'a, I>(
132 candidates: I,
133 query: &[f32],
134 options: &MatryoshkaSearchOptions,
135 metric: DistanceMetric,
136) -> Vec<(u32, f32)>
137where
138 I: Iterator<Item = (u32, &'a [f32])>,
139{
140 let coarse = options.coarse_dim as usize;
141 let full = options.full_dim as usize;
142 let pool_size = (options.oversample as usize).max(1) * options.k.max(1);
143
144 let query_coarse = truncate(query, coarse);
145
146 let mut coarse_heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pool_size + 1);
150 let mut survivor_vecs: Vec<Vec<f32>> = Vec::with_capacity(pool_size);
152
153 for (id, vec) in candidates {
154 let vec_coarse = truncate(vec, coarse);
155 let d = distance(query_coarse, vec_coarse, metric);
156
157 let should_insert = coarse_heap.len() < pool_size
159 || coarse_heap
160 .peek()
161 .map(|worst| d < worst.dist)
162 .unwrap_or(true);
163
164 if should_insert {
165 let vec_idx = survivor_vecs.len();
166 survivor_vecs.push(vec[..full.min(vec.len())].to_vec());
167
168 coarse_heap.push(HeapEntry {
169 dist: d,
170 id,
171 vec_idx,
172 });
173
174 if coarse_heap.len() > pool_size {
175 coarse_heap.pop();
178 }
179 }
180 }
181
182 let query_full = truncate(query, full);
184
185 let mut reranked: Vec<(u32, f32)> = coarse_heap
186 .into_iter()
187 .map(|entry| {
188 let full_vec = &survivor_vecs[entry.vec_idx];
189 let d_full = distance(query_full, full_vec.as_slice(), metric);
190 (entry.id, d_full)
191 })
192 .collect();
193
194 reranked.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
195 reranked.truncate(options.k);
196 reranked
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec shared by the `pick` and `is_valid` tests.
    fn spec() -> MatryoshkaSpec {
        MatryoshkaSpec::new(vec![256, 512, 1024])
    }

    /// Deterministic 128-component query used by the search tests.
    fn query_128() -> Vec<f32> {
        (0..128).map(|i| (i as f32 * 0.007).cos()).collect()
    }

    /// Builds `n` deterministic vectors of `dim` components each.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut rows = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..dim)
                .map(|j| ((i * dim + j) as f32 * 0.01).sin())
                .collect();
            rows.push(row);
        }
        rows
    }

    /// Pairs every vector with its position as the candidate id.
    fn indexed(vecs: &[Vec<f32>]) -> impl Iterator<Item = (u32, &[f32])> + '_ {
        vecs.iter()
            .enumerate()
            .map(|(i, v)| (i as u32, v.as_slice()))
    }

    /// Extracts the ids from a result list, preserving order.
    fn ids(results: &[(u32, f32)]) -> Vec<u32> {
        results.iter().map(|(id, _)| *id).collect()
    }

    // --- MatryoshkaSpec::pick ---

    #[test]
    fn pick_largest_leq_requested() {
        assert_eq!(spec().pick(Some(300)), 256);
    }

    #[test]
    fn pick_exact_match() {
        assert_eq!(spec().pick(Some(512)), 512);
    }

    #[test]
    fn pick_none_returns_full_dim() {
        assert_eq!(spec().pick(None), 1024);
    }

    #[test]
    fn pick_smaller_than_all_returns_smallest() {
        assert_eq!(spec().pick(Some(10)), 256);
    }

    // --- MatryoshkaSpec::is_valid ---

    #[test]
    fn is_valid_known_dim() {
        assert!(spec().is_valid(512));
    }

    #[test]
    fn is_valid_unknown_dim() {
        assert!(!spec().is_valid(100));
    }

    // --- truncate ---

    #[test]
    fn truncate_clips_to_requested_dim() {
        let v: Vec<f32> = (0..1536).map(|i| i as f32).collect();
        let t = truncate(&v, 256);
        assert_eq!(t.len(), 256);
        assert_eq!(t.first(), Some(&0.0));
        assert_eq!(t.last(), Some(&255.0));
    }

    #[test]
    fn truncate_does_not_exceed_vec_len() {
        let v = vec![1.0f32; 10];
        assert_eq!(truncate(&v, 9999).len(), 10);
    }

    // --- matryoshka_search ---

    #[test]
    fn search_returns_k_results() {
        let vecs = make_vecs(100, 128);
        let query = query_128();
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 64,
            full_dim: 128,
            oversample: 3,
            k: 10,
        };

        let results = matryoshka_search(indexed(&vecs), &query, &opts, DistanceMetric::L2);
        assert_eq!(results.len(), 10, "expected exactly k=10 results");
    }

    #[test]
    fn coarse_equal_to_full_matches_direct_search() {
        let vecs = make_vecs(100, 128);
        let query = query_128();

        // Reference: brute-force top-10 over the full vectors.
        let mut direct: Vec<(u32, f32)> = indexed(&vecs)
            .map(|(id, v)| (id, distance(&query, v, DistanceMetric::L2)))
            .collect();
        direct.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        direct.truncate(10);

        // With coarse == full and no oversampling, both passes see the same
        // distances, so the two-stage search should pick the same ids.
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 128,
            full_dim: 128,
            oversample: 1,
            k: 10,
        };
        let mrl = matryoshka_search(indexed(&vecs), &query, &opts, DistanceMetric::L2);

        let direct_ids = ids(&direct);
        let mrl_ids = ids(&mrl);
        assert_eq!(
            direct_ids, mrl_ids,
            "coarse==full should equal direct search"
        );
    }
}