Skip to main content

nodedb_vector/
matryoshka.rs

// SPDX-License-Identifier: Apache-2.0

//! Matryoshka adaptive-dim querying.
//!
//! For embedding models trained with Matryoshka Representation Learning
//! (MRL), prefix-truncated vectors retain semantic structure. We exploit
//! this for memory-bandwidth-bound HNSW: coarse pass at low dim → exact
//! rerank at full dim.
//!
//! # Cache-efficiency
//!
//! A 256-dim coarse pass reads 256 × 4 B = 1 KiB per vector vs 1536 × 4 B = 6 KiB
//! for full-dim traversal — a 6× reduction in L1/L2 memory bandwidth during
//! the dominant HNSW graph traversal phase.
//!
//! # Supported models
//!
//! - `text-embedding-3` (OpenAI): `[256, 512, 1024, 1536]`
//! - `Gemini Embedding`: `[256, 512, 768, 3072]`
//! - `Nomic Embed`: `[64, 128, 256, 512, 768]`
21
22use std::collections::BinaryHeap;
23
24use crate::distance::distance;
25use nodedb_types::vector_distance::DistanceMetric;
26
/// Per-collection Matryoshka configuration.
#[derive(Debug, Clone)]
pub struct MatryoshkaSpec {
    /// Valid truncation dimensions.
    /// E.g. for text-embedding-3 (1536-dim base): `[256, 512, 1024, 1536]`.
    ///
    /// [`MatryoshkaSpec::new`] stores the list sorted ascending and
    /// de-duplicated. [`MatryoshkaSpec::pick`] scans it back-to-front and
    /// relies on that ordering, so keep it sorted if the field is written
    /// directly.
    pub truncation_dims: Vec<u32>,
}

impl MatryoshkaSpec {
    /// Create a new spec from `truncation_dims`.
    ///
    /// The dims are sorted ascending and duplicates are removed; the input
    /// may be in any order.
    pub fn new(mut truncation_dims: Vec<u32>) -> Self {
        truncation_dims.sort_unstable();
        truncation_dims.dedup();
        Self { truncation_dims }
    }

    /// Pick the largest supported dim ≤ `requested`.
    ///
    /// - `requested == None`: returns the full dim (last entry in the list).
    /// - `requested` smaller than every supported dim: returns the smallest
    ///   supported dim (first entry in the list).
    ///
    /// For a spec with no dims at all there is nothing to choose from:
    /// `pick(None)` returns `0` and `pick(Some(r))` echoes `r` back.
    pub fn pick(&self, requested: Option<u32>) -> u32 {
        let Some(req) = requested else {
            return self.truncation_dims.last().copied().unwrap_or(0);
        };
        // Walk from the largest dim downwards; the first one that fits wins.
        self.truncation_dims
            .iter()
            .rev()
            .copied()
            .find(|&d| d <= req)
            // Nothing fits: fall back to the smallest dim we have.
            .or_else(|| self.truncation_dims.first().copied())
            // Empty spec: hand the request back unchanged.
            .unwrap_or(req)
    }

    /// Return `true` if `dim` is one of the supported truncation dims.
    pub fn is_valid(&self, dim: u32) -> bool {
        self.truncation_dims.contains(&dim)
    }
}
66
/// Borrow the first `dim` components of `v` (prefix truncation).
///
/// No data is copied. If `dim` is greater than or equal to `v.len()` the
/// whole slice is returned, so this never panics.
#[inline]
pub fn truncate(v: &[f32], dim: usize) -> &[f32] {
    // `get` yields `None` only when `dim > v.len()`; fall back to all of `v`.
    v.get(..dim).unwrap_or(v)
}
74
/// Options for a two-stage Matryoshka search.
#[derive(Debug, Clone)]
pub struct MatryoshkaSearchOptions {
    /// Dimensionality used for the coarse pass (number of leading components
    /// compared).
    pub coarse_dim: u32,
    /// Full dimensionality used for the rerank pass.
    pub full_dim: u32,
    /// Oversample factor: coarse pass collects `oversample × k` candidates.
    /// Typical values: 3–5. `matryoshka_search` treats `0` as `1`.
    pub oversample: u8,
    /// Final result count.
    pub k: usize,
}
87
/// Entry in the coarse-pass max-heap (ordered by distance descending so we
/// can evict the worst candidate when the heap overflows).
///
/// Ordering and equality look at `dist` only; `id` and `vec_idx` are payload.
struct HeapEntry {
    /// Distance (higher = worse, evict first).
    dist: f32,
    id: u32,
    /// Index into the collected vectors buffer for full-dim rerank.
    vec_idx: usize,
}

impl PartialEq for HeapEntry {
    /// Defined through [`Ord::cmp`] so that `a == b` holds exactly when
    /// `a.cmp(&b)` is `Equal`, as the `Ord`/`Eq` contracts require (a derived
    /// impl would make a NaN entry unequal to itself).
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    /// Max-heap order by distance, with NaN treated as the largest value.
    ///
    /// `partial_cmp` is `None` only when at least one side is NaN. Comparing
    /// the `is_nan` flags then ranks NaN above every real distance and equal
    /// to another NaN, which keeps this a proper total order.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .partial_cmp(&other.dist)
            .unwrap_or_else(|| self.dist.is_nan().cmp(&other.dist.is_nan()))
    }
}
115
116/// Two-stage Matryoshka search.
117///
118/// # Phase 1 — coarse pass
119/// Iterates `candidates`, computes distance on the first `options.coarse_dim`
120/// components, and keeps a bounded max-heap of size `oversample × k`.
121///
122/// # Phase 2 — rerank
123/// Recomputes full-dim distance for all coarse survivors and returns the top-k
124/// by ascending full-dim distance.
125///
126/// # Notes
127/// - `candidates` yields `(id, full_dim_vector)` pairs. Vectors shorter than
128///   `full_dim` are accepted; truncation clips to available length.
129/// - When `coarse_dim == full_dim` the method degenerates to a single-pass
130///   top-k scan with one distance call per candidate (no duplicated work).
131pub fn matryoshka_search<'a, I>(
132    candidates: I,
133    query: &[f32],
134    options: &MatryoshkaSearchOptions,
135    metric: DistanceMetric,
136) -> Vec<(u32, f32)>
137where
138    I: Iterator<Item = (u32, &'a [f32])>,
139{
140    let coarse = options.coarse_dim as usize;
141    let full = options.full_dim as usize;
142    let pool_size = (options.oversample as usize).max(1) * options.k.max(1);
143
144    let query_coarse = truncate(query, coarse);
145
146    // --- Phase 1: coarse pass ---
147    // We collect full-dim vectors for survivors so Phase 2 doesn't need the
148    // original iterator again.
149    let mut coarse_heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pool_size + 1);
150    // Store owned copies of the surviving vectors for rerank.
151    // no-governor: search pool_size bounded by ef_search (small, typically ≤ 200); hot query path
152    let mut survivor_vecs: Vec<Vec<f32>> = Vec::with_capacity(pool_size);
153
154    for (id, vec) in candidates {
155        let vec_coarse = truncate(vec, coarse);
156        let d = distance(query_coarse, vec_coarse, metric);
157
158        // Maintain a max-heap of capacity `pool_size`.
159        let should_insert = coarse_heap.len() < pool_size
160            || coarse_heap
161                .peek()
162                .map(|worst| d < worst.dist)
163                .unwrap_or(true);
164
165        if should_insert {
166            let vec_idx = survivor_vecs.len();
167            survivor_vecs.push(vec[..full.min(vec.len())].to_vec());
168
169            coarse_heap.push(HeapEntry {
170                dist: d,
171                id,
172                vec_idx,
173            });
174
175            if coarse_heap.len() > pool_size {
176                // Evict worst; its survivor_vec slot becomes orphaned but that
177                // is acceptable — we only rerank heap survivors.
178                coarse_heap.pop();
179            }
180        }
181    }
182
183    // --- Phase 2: rerank ---
184    let query_full = truncate(query, full);
185
186    let mut reranked: Vec<(u32, f32)> = coarse_heap
187        .into_iter()
188        .map(|entry| {
189            let full_vec = &survivor_vecs[entry.vec_idx];
190            let d_full = distance(query_full, full_vec.as_slice(), metric);
191            (entry.id, d_full)
192        })
193        .collect();
194
195    reranked.sort_unstable_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
196    reranked.truncate(options.k);
197    reranked
198}
199
#[cfg(test)]
mod tests {
    use super::*;

    // ── fixtures ─────────────────────────────────────────────────────────────

    /// Spec shared by the `pick` / `is_valid` tests.
    fn spec() -> MatryoshkaSpec {
        MatryoshkaSpec::new(vec![256, 512, 1024])
    }

    /// Build `n` deterministic, random-ish vectors of dimension `dim`.
    fn make_vecs(n: usize, dim: usize) -> Vec<Vec<f32>> {
        (0..n)
            .map(|i| {
                (0..dim)
                    .map(|j| ((i * dim + j) as f32 * 0.01).sin())
                    .collect()
            })
            .collect()
    }

    /// Build a deterministic query vector of dimension `dim`.
    fn make_query(dim: usize) -> Vec<f32> {
        (0..dim).map(|i| (i as f32 * 0.007).cos()).collect()
    }

    /// Adapt owned vectors into the `(id, slice)` stream the search consumes;
    /// the id is the vector's position in `vecs`.
    fn as_candidates<'a>(vecs: &'a [Vec<f32>]) -> impl Iterator<Item = (u32, &'a [f32])> + 'a {
        vecs.iter()
            .enumerate()
            .map(|(i, v)| (i as u32, v.as_slice()))
    }

    // ── MatryoshkaSpec::pick ──────────────────────────────────────────────────

    #[test]
    fn pick_largest_leq_requested() {
        assert_eq!(spec().pick(Some(300)), 256);
    }

    #[test]
    fn pick_exact_match() {
        assert_eq!(spec().pick(Some(512)), 512);
    }

    #[test]
    fn pick_none_returns_full_dim() {
        assert_eq!(spec().pick(None), 1024);
    }

    #[test]
    fn pick_smaller_than_all_returns_smallest() {
        assert_eq!(spec().pick(Some(10)), 256);
    }

    // ── MatryoshkaSpec::is_valid ──────────────────────────────────────────────

    #[test]
    fn is_valid_known_dim() {
        assert!(spec().is_valid(512));
    }

    #[test]
    fn is_valid_unknown_dim() {
        assert!(!spec().is_valid(100));
    }

    // ── truncate ─────────────────────────────────────────────────────────────

    #[test]
    fn truncate_clips_to_requested_dim() {
        let v: Vec<f32> = (0..1536).map(|i| i as f32).collect();
        let t = truncate(&v, 256);
        assert_eq!(t.len(), 256);
        assert_eq!(t[0], 0.0);
        assert_eq!(t[255], 255.0);
    }

    #[test]
    fn truncate_does_not_exceed_vec_len() {
        let v = vec![1.0f32; 10];
        let t = truncate(&v, 9999);
        assert_eq!(t.len(), 10);
    }

    // ── matryoshka_search ────────────────────────────────────────────────────

    #[test]
    fn search_returns_k_results() {
        let vecs = make_vecs(100, 128);
        let query = make_query(128);

        let opts = MatryoshkaSearchOptions {
            coarse_dim: 64,
            full_dim: 128,
            oversample: 3,
            k: 10,
        };

        let results = matryoshka_search(as_candidates(&vecs), &query, &opts, DistanceMetric::L2);
        assert_eq!(results.len(), 10, "expected exactly k=10 results");
        assert!(
            results.windows(2).all(|pair| pair[0].1 <= pair[1].1),
            "results must be sorted by ascending distance"
        );
    }

    #[test]
    fn search_with_zero_k_returns_empty() {
        let vecs = make_vecs(20, 32);
        let query = make_query(32);

        let opts = MatryoshkaSearchOptions {
            coarse_dim: 16,
            full_dim: 32,
            oversample: 3,
            k: 0,
        };

        let results = matryoshka_search(as_candidates(&vecs), &query, &opts, DistanceMetric::L2);
        assert!(results.is_empty(), "k=0 must yield no results");
    }

    #[test]
    fn search_with_fewer_candidates_than_k_returns_all() {
        let vecs = make_vecs(5, 32);
        let query = make_query(32);

        let opts = MatryoshkaSearchOptions {
            coarse_dim: 16,
            full_dim: 32,
            oversample: 3,
            k: 10,
        };

        let results = matryoshka_search(as_candidates(&vecs), &query, &opts, DistanceMetric::L2);
        assert_eq!(results.len(), 5, "every candidate should be returned");
    }

    #[test]
    fn coarse_equal_to_full_matches_direct_search() {
        let vecs = make_vecs(100, 128);
        let query = make_query(128);

        // Direct top-10 search for reference.
        let mut direct: Vec<(u32, f32)> = as_candidates(&vecs)
            .map(|(id, v)| (id, distance(&query, v, DistanceMetric::L2)))
            .collect();
        direct.sort_unstable_by(|a, b| a.1.total_cmp(&b.1));
        direct.truncate(10);

        // Matryoshka with coarse_dim == full_dim — no truncation effect.
        let opts = MatryoshkaSearchOptions {
            coarse_dim: 128,
            full_dim: 128,
            oversample: 1,
            k: 10,
        };
        let mrl = matryoshka_search(as_candidates(&vecs), &query, &opts, DistanceMetric::L2);

        // Same IDs in same order.
        let direct_ids: Vec<u32> = direct.iter().map(|(id, _)| *id).collect();
        let mrl_ids: Vec<u32> = mrl.iter().map(|(id, _)| *id).collect();
        assert_eq!(
            direct_ids, mrl_ids,
            "coarse==full should equal direct search"
        );
    }
}
331}