Skip to main content

nodedb_vector/multivec/
meta_embed.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! MetaEmbed-specific search helpers.
4//!
5//! MetaEmbed (ICLR 2026) stores K pre-computed Meta Token vectors per document
6//! instead of one vector per token.  At query time the Matryoshka ordering of
7//! the query Meta Tokens allows a `budget` parameter to trade recall for
8//! latency: `budget=k` is full accuracy; `budget=1` is fastest.
9//!
10//! `meta_embed_search` wires together:
11//! 1. Optional PLAID candidate pruning.
12//! 2. Budgeted MaxSim scoring over the remaining candidates.
13//! 3. Top-k selection.
14
15use nodedb_types::vector_distance::DistanceMetric;
16
17use super::plaid::PlaidPruner;
18use super::scoring::budgeted_maxsim;
19use super::storage::MultiVectorStore;
20
21/// Search a `MultiVectorStore` using budgeted MaxSim with optional PLAID
22/// candidate pruning.
23///
24/// # Parameters
25/// * `store`  — the document collection.
26/// * `plaid`  — optional PLAID pruner (pass `None` to scan all docs).
27/// * `query`  — query Meta Token vectors (Matryoshka ordering).
28/// * `budget` — number of leading query tokens to use; 0 falls back to all.
29/// * `k`      — number of top documents to return.
30/// * `metric` — distance metric (Cosine recommended for MetaEmbed).
31///
32/// # Returns
33/// A `Vec<(doc_id, score)>` sorted descending by score, length ≤ `k`.
34pub fn meta_embed_search(
35    store: &MultiVectorStore,
36    plaid: Option<&PlaidPruner>,
37    query: &[Vec<f32>],
38    budget: u8,
39    k: usize,
40    metric: DistanceMetric,
41) -> Vec<(u32, f32)> {
42    if k == 0 || query.is_empty() {
43        return Vec::new();
44    }
45
46    // Effective budget: 0 means use all query vectors.
47    let effective_budget = if budget == 0 {
48        query.len() as u8
49    } else {
50        budget
51    };
52
53    // Determine candidate set.
54    let candidate_ids: Vec<u32> = match plaid {
55        Some(pruner) => pruner.candidates(query),
56        None => store.iter().map(|doc| doc.doc_id).collect(),
57    };
58
59    // Score each candidate.
60    let mut scored: Vec<(u32, f32)> = candidate_ids
61        .into_iter()
62        .filter_map(|doc_id| {
63            store.get(doc_id).map(|doc| {
64                let score = budgeted_maxsim(query, &doc.vectors, effective_budget, metric);
65                (doc_id, score)
66            })
67        })
68        .collect();
69
70    // Sort descending by score.
71    scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
72    scored.truncate(k);
73    scored
74}
75
76// ---------------------------------------------------------------------------
77// Tests
78// ---------------------------------------------------------------------------
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multivec::storage::{MultiVecMode, MultiVectorDoc, MultiVectorStore};

    /// Build a `MetaToken { k }` store holding `n` documents with ids `0..n`.
    ///
    /// Every document gets `k` one-hot vectors of length `dim`; the `j`-th
    /// Meta Token of document `i` has its single `1.0` at index
    /// `(i + j) % dim`.
    fn build_store(n: u32, dim: usize, k: u8) -> MultiVectorStore {
        let mut store = MultiVectorStore::new(dim, MultiVecMode::MetaToken { k });
        for doc_id in 0..n {
            let vectors: Vec<Vec<f32>> = (0..k as usize)
                .map(|token| {
                    let mut one_hot = vec![0.0f32; dim];
                    one_hot[(doc_id as usize + token) % dim] = 1.0;
                    one_hot
                })
                .collect();
            store.insert(MultiVectorDoc { doc_id, vectors }).unwrap();
        }
        store
    }

    /// The result list never exceeds the requested `k`.
    #[test]
    fn search_returns_at_most_k_results() {
        let store = build_store(10, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 2, 3, DistanceMetric::Cosine);
        assert!(hits.len() <= 3);
    }

    /// Scores are non-increasing from the first result to the last.
    #[test]
    fn search_results_sorted_descending() {
        let store = build_store(8, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let results = meta_embed_search(&store, None, &query, 2, 8, DistanceMetric::Cosine);
        assert!(
            results.windows(2).all(|pair| pair[0].1 >= pair[1].1),
            "not sorted: {:?}",
            results
        );
    }

    /// Every doc returned with PLAID pruning also appears in the full scan.
    #[test]
    fn plaid_filtered_results_are_subset_of_unfiltered() {
        let store = build_store(9, 2, 2);

        // Train PLAID on the store.
        let pruner = PlaidPruner::train(&store, 3, 10, 99);

        let query = vec![vec![1.0f32, 0.0f32]];
        let full_scan = meta_embed_search(&store, None, &query, 2, 9, DistanceMetric::Cosine);
        let pruned =
            meta_embed_search(&store, Some(&pruner), &query, 2, 9, DistanceMetric::Cosine);

        let mut full_scan_ids = std::collections::HashSet::new();
        for hit in &full_scan {
            full_scan_ids.insert(hit.0);
        }

        for (id, _) in &pruned {
            assert!(
                full_scan_ids.contains(id),
                "filtered result {id} not in unfiltered set"
            );
        }
    }

    /// An empty query short-circuits to an empty result.
    #[test]
    fn search_empty_query_returns_empty() {
        let store = build_store(5, 4, 2);
        let hits = meta_embed_search(&store, None, &[], 2, 5, DistanceMetric::Cosine);
        assert!(hits.is_empty());
    }

    /// Requesting zero results yields an empty result.
    #[test]
    fn search_k_zero_returns_empty() {
        let store = build_store(5, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 2, 0, DistanceMetric::Cosine);
        assert!(hits.is_empty());
    }

    /// With one Meta Token per doc, doc 0 is the only one whose vector points
    /// along index 0, the same direction as the query, so it must rank first.
    #[test]
    fn top_result_is_best_matching_doc() {
        let store = build_store(4, 4, 1);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 1, 1, DistanceMetric::Cosine);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0, "expected doc_id=0 to be top result");
    }
}