1use std::collections::BinaryHeap;
23
24use crate::distance::distance;
25use nodedb_types::vector_distance::DistanceMetric;
26
/// The set of prefix lengths ("truncation dimensions") a vector may be cut
/// down to.
#[derive(Debug, Clone)]
pub struct MatryoshkaSpec {
    /// Allowed truncation dimensions.
    ///
    /// [`MatryoshkaSpec::new`] stores them sorted ascending and de-duplicated.
    /// The field is public, so a value built with a struct literal may hold
    /// them in any order; the methods below do not rely on the ordering.
    pub truncation_dims: Vec<u32>,
}

impl MatryoshkaSpec {
    /// Builds a spec from `truncation_dims`, sorting the values ascending and
    /// removing duplicates.
    pub fn new(mut truncation_dims: Vec<u32>) -> Self {
        truncation_dims.sort_unstable();
        truncation_dims.dedup();
        Self { truncation_dims }
    }

    /// Chooses the dimension to use for a request.
    ///
    /// * `None` — the largest configured dimension, or `0` if none are
    ///   configured.
    /// * `Some(req)` — the largest configured dimension that is `<= req`.
    ///   If every configured dimension is larger than `req`, the smallest
    ///   configured dimension is returned instead. If nothing is configured
    ///   at all, `req` is returned unchanged.
    pub fn pick(&self, requested: Option<u32>) -> u32 {
        // Select by value (max / min) rather than by position so the answer
        // is right even when the public field was filled in unsorted.
        let dims = || self.truncation_dims.iter().copied();

        match requested {
            None => dims().max().unwrap_or(0),
            Some(req) => dims()
                .filter(|&d| d <= req)
                .max()
                .or_else(|| dims().min())
                .unwrap_or(req),
        }
    }

    /// Returns `true` when `dim` is exactly one of the configured dimensions.
    pub fn is_valid(&self, dim: u32) -> bool {
        self.truncation_dims.contains(&dim)
    }
}
66
/// Returns the leading `dim` components of `v`.
///
/// When `dim` is larger than `v.len()` the whole slice is returned, so the
/// call never panics.
#[inline]
pub fn truncate(v: &[f32], dim: usize) -> &[f32] {
    // `get(..dim)` is `None` only when `dim` runs past the end of `v`.
    v.get(..dim).unwrap_or(v)
}
74
/// Tuning knobs for [`matryoshka_search`].
#[derive(Debug, Clone)]
pub struct MatryoshkaSearchOptions {
    /// Number of leading components compared during the first (coarse) pass.
    pub coarse_dim: u32,
    /// Number of leading components compared when the coarse-pass survivors
    /// are re-scored.
    pub full_dim: u32,
    /// Multiplier applied to `k` to size the pool of coarse-pass survivors.
    /// A value of `0` is treated as `1`.
    pub oversample: u8,
    /// Maximum number of results returned.
    pub k: usize,
}
87
/// Entry stored in the coarse-pass `BinaryHeap`.
///
/// Ordering and equality look at `dist` only, so that `==` agrees with
/// `Ord::cmp` as the `Eq`/`Ord` contracts require. `BinaryHeap` is a max-heap,
/// so the entry with the largest `dist` sits on top.
struct HeapEntry {
    /// Distance used as the heap key.
    dist: f32,
    /// Caller-supplied candidate identifier (payload, not compared).
    id: u32,
    /// Index of the candidate's vector in the search's side table (payload,
    /// not compared).
    vec_idx: usize,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    /// Total order on `dist`: ordinary numeric order for non-NaN values, with
    /// every NaN (of either sign) equal to every other NaN and greater than
    /// all non-NaN values.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist.partial_cmp(&other.dist).unwrap_or_else(|| {
            // `partial_cmp` is `None` only when at least one side is NaN;
            // `false < true` places the NaN side last.
            self.dist.is_nan().cmp(&other.dist.is_nan())
        })
    }
}
115
116pub fn matryoshka_search<'a, I>(
132 candidates: I,
133 query: &[f32],
134 options: &MatryoshkaSearchOptions,
135 metric: DistanceMetric,
136) -> Vec<(u32, f32)>
137where
138 I: Iterator<Item = (u32, &'a [f32])>,
139{
140 let coarse = options.coarse_dim as usize;
141 let full = options.full_dim as usize;
142 let pool_size = (options.oversample as usize).max(1) * options.k.max(1);
143
144 let query_coarse = truncate(query, coarse);
145
146 let mut coarse_heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pool_size + 1);
150 let mut survivor_vecs: Vec<Vec<f32>> = Vec::with_capacity(pool_size);
153
154 for (id, vec) in candidates {
155 let vec_coarse = truncate(vec, coarse);
156 let d = distance(query_coarse, vec_coarse, metric);
157
158 let should_insert = coarse_heap.len() < pool_size
160 || coarse_heap
161 .peek()
162 .map(|worst| d < worst.dist)
163 .unwrap_or(true);
164
165 if should_insert {
166 let vec_idx = survivor_vecs.len();
167 survivor_vecs.push(vec[..full.min(vec.len())].to_vec());
168
169 coarse_heap.push(HeapEntry {
170 dist: d,
171 id,
172 vec_idx,
173 });
174
175 if coarse_heap.len() > pool_size {
176 coarse_heap.pop();
179 }
180 }
181 }
182
183 let query_full = truncate(query, full);
185
186 let mut reranked: Vec<(u32, f32)> = coarse_heap
187 .into_iter()
188 .map(|entry| {
189 let full_vec = &survivor_vecs[entry.vec_idx];
190 let d_full = distance(query_full, full_vec.as_slice(), metric);
191 (entry.id, d_full)
192 })
193 .collect();
194
195 reranked.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
196 reranked.truncate(options.k);
197 reranked
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    /// Spec shared by the `pick` / `is_valid` tests.
    fn sample_spec() -> MatryoshkaSpec {
        MatryoshkaSpec::new(vec![256, 512, 1024])
    }

    /// Deterministic 128-component query vector used by the search tests.
    fn sample_query() -> Vec<f32> {
        (0..128).map(|i| (i as f32 * 0.007).cos()).collect()
    }

    /// Deterministic corpus of `n` vectors with `dim` components each.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        let mut corpus = Vec::with_capacity(n);
        for i in 0..n {
            let mut row = Vec::with_capacity(dim);
            for j in 0..dim {
                row.push(((i * dim + j) as f32 * 0.01).sin());
            }
            corpus.push(row);
        }
        corpus
    }

    /// Pairs every vector with its position in `vecs` as a `u32` id.
    fn with_ids<'v>(vecs: &'v [Vec<f32>]) -> impl Iterator<Item = (u32, &'v [f32])> + 'v {
        (0u32..).zip(vecs.iter().map(Vec::as_slice))
    }

    // ---- MatryoshkaSpec::pick -------------------------------------------

    #[test]
    fn pick_largest_leq_requested() {
        assert_eq!(sample_spec().pick(Some(300)), 256);
    }

    #[test]
    fn pick_exact_match() {
        assert_eq!(sample_spec().pick(Some(512)), 512);
    }

    #[test]
    fn pick_none_returns_full_dim() {
        assert_eq!(sample_spec().pick(None), 1024);
    }

    #[test]
    fn pick_smaller_than_all_returns_smallest() {
        assert_eq!(sample_spec().pick(Some(10)), 256);
    }

    // ---- MatryoshkaSpec::is_valid ---------------------------------------

    #[test]
    fn is_valid_known_dim() {
        assert!(sample_spec().is_valid(512));
    }

    #[test]
    fn is_valid_unknown_dim() {
        assert!(!sample_spec().is_valid(100));
    }

    // ---- truncate -------------------------------------------------------

    #[test]
    fn truncate_clips_to_requested_dim() {
        let v: Vec<f32> = (0..1536u16).map(f32::from).collect();
        let clipped = truncate(&v, 256);
        assert_eq!(clipped.len(), 256);
        assert_eq!(clipped[0], 0.0);
        assert_eq!(clipped[255], 255.0);
    }

    #[test]
    fn truncate_does_not_exceed_vec_len() {
        let v = vec![1.0f32; 10];
        assert_eq!(truncate(&v, 9999).len(), 10);
    }

    // ---- matryoshka_search ----------------------------------------------

    #[test]
    fn search_returns_k_results() {
        let vecs = make_vecs(100, 128);
        let query = sample_query();
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 64,
            full_dim: 128,
            oversample: 3,
            k: 10,
        };

        let results = matryoshka_search(with_ids(&vecs), &query, &opts, DistanceMetric::L2);

        assert_eq!(results.len(), 10, "expected exactly k=10 results");
    }

    #[test]
    fn coarse_equal_to_full_matches_direct_search() {
        let vecs = make_vecs(100, 128);
        let query = sample_query();

        // Reference answer: score every vector at full width and keep the
        // ten closest.
        let mut direct: Vec<(u32, f32)> = with_ids(&vecs)
            .map(|(id, v)| (id, distance(&query, v, DistanceMetric::L2)))
            .collect();
        direct.sort_unstable_by(|(_, da), (_, db)| {
            da.partial_cmp(db).unwrap_or(std::cmp::Ordering::Equal)
        });
        direct.truncate(10);

        let opts = MatryoshkaSearchOptions {
            coarse_dim: 128,
            full_dim: 128,
            oversample: 1,
            k: 10,
        };
        let mrl = matryoshka_search(with_ids(&vecs), &query, &opts, DistanceMetric::L2);

        let direct_ids: Vec<u32> = direct.iter().map(|pair| pair.0).collect();
        let mrl_ids: Vec<u32> = mrl.iter().map(|pair| pair.0).collect();
        assert_eq!(
            direct_ids, mrl_ids,
            "coarse==full should equal direct search"
        );
    }
}