1use nodedb_types::vector_distance::DistanceMetric;
16
17use super::plaid::PlaidPruner;
18use super::scoring::budgeted_maxsim;
19use super::storage::MultiVectorStore;
20
21pub fn meta_embed_search(
35 store: &MultiVectorStore,
36 plaid: Option<&PlaidPruner>,
37 query: &[Vec<f32>],
38 budget: u8,
39 k: usize,
40 metric: DistanceMetric,
41) -> Vec<(u32, f32)> {
42 if k == 0 || query.is_empty() {
43 return Vec::new();
44 }
45
46 let effective_budget = if budget == 0 {
48 query.len() as u8
49 } else {
50 budget
51 };
52
53 let candidate_ids: Vec<u32> = match plaid {
55 Some(pruner) => pruner.candidates(query),
56 None => store.iter().map(|doc| doc.doc_id).collect(),
57 };
58
59 let mut scored: Vec<(u32, f32)> = candidate_ids
61 .into_iter()
62 .filter_map(|doc_id| {
63 store.get(doc_id).map(|doc| {
64 let score = budgeted_maxsim(query, &doc.vectors, effective_budget, metric);
65 (doc_id, score)
66 })
67 })
68 .collect();
69
70 scored.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
72 scored.truncate(k);
73 scored
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use crate::multivec::storage::{MultiVecMode, MultiVectorDoc, MultiVectorStore};

    /// Builds a store of `n` documents (ids `0..n`), each holding `k` one-hot
    /// vectors of length `dim`. Vector `j` of document `i` has its single
    /// `1.0` at index `(i + j) % dim`.
    fn build_store(n: u32, dim: usize, k: u8) -> MultiVectorStore {
        let mut store = MultiVectorStore::new(dim, MultiVecMode::MetaToken { k });
        for doc_id in 0..n {
            let vectors: Vec<Vec<f32>> = (0..usize::from(k))
                .map(|offset| {
                    let mut one_hot = vec![0.0f32; dim];
                    one_hot[(doc_id as usize + offset) % dim] = 1.0;
                    one_hot
                })
                .collect();
            store.insert(MultiVectorDoc { doc_id, vectors }).unwrap();
        }
        store
    }

    /// Asking for 3 results out of 10 documents never yields more than 3.
    #[test]
    fn search_returns_at_most_k_results() {
        let store = build_store(10, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 2, 3, DistanceMetric::Cosine);
        assert!(hits.len() <= 3);
    }

    /// Scores in the returned list never increase from one entry to the next.
    #[test]
    fn search_results_sorted_descending() {
        let store = build_store(8, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let results = meta_embed_search(&store, None, &query, 2, 8, DistanceMetric::Cosine);
        for pair in results.windows(2) {
            let (earlier, later) = (pair[0].1, pair[1].1);
            assert!(earlier >= later, "not sorted: {:?}", results);
        }
    }

    /// Every id returned with a pruner must also be returned without one.
    #[test]
    fn plaid_filtered_results_are_subset_of_unfiltered() {
        use std::collections::HashSet;

        let store = build_store(9, 2, 2);

        // NOTE(review): the meaning of the numeric arguments to `train` is
        // not visible from this file.
        let pruner = PlaidPruner::train(&store, 3, 10, 99);

        let query = vec![vec![1.0f32, 0.0f32]];
        let unfiltered = meta_embed_search(&store, None, &query, 2, 9, DistanceMetric::Cosine);
        let filtered =
            meta_embed_search(&store, Some(&pruner), &query, 2, 9, DistanceMetric::Cosine);

        let mut unfiltered_ids: HashSet<u32> = HashSet::new();
        for (doc_id, _) in &unfiltered {
            unfiltered_ids.insert(*doc_id);
        }

        for (id, _) in &filtered {
            assert!(
                unfiltered_ids.contains(id),
                "filtered result {id} not in unfiltered set"
            );
        }
    }

    /// An empty query short-circuits to an empty result.
    #[test]
    fn search_empty_query_returns_empty() {
        let store = build_store(5, 4, 2);
        let hits = meta_embed_search(&store, None, &[], 2, 5, DistanceMetric::Cosine);
        assert!(hits.is_empty());
    }

    /// `k == 0` short-circuits to an empty result.
    #[test]
    fn search_k_zero_returns_empty() {
        let store = build_store(5, 4, 2);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 2, 0, DistanceMetric::Cosine);
        assert!(hits.is_empty());
    }

    /// Document 0's only vector is identical to the query, so it must rank
    /// first.
    #[test]
    fn top_result_is_best_matching_doc() {
        let store = build_store(4, 4, 1);
        let query = vec![vec![1.0f32, 0.0, 0.0, 0.0]];
        let hits = meta_embed_search(&store, None, &query, 1, 1, DistanceMetric::Cosine);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, 0, "expected doc_id=0 to be top result");
    }
}