1#![allow(clippy::doc_markdown)]
4
5use std::collections::HashMap;
42use std::sync::Arc;
43
44use crate::error::{Error, RepoError};
45use crate::id::NodeId;
46use crate::index::vector::VectorHit;
47use crate::objects::Node;
48use crate::prolly::Cursor;
49use crate::repo::readonly::decode_from_store;
50use crate::sparse::SparseEmbed;
51use crate::store::Blockstore;
52
53#[derive(Debug, Clone, Copy)]
55struct Posting {
56 node: NodeId,
57 weight: f32,
58}
59
60#[derive(Debug, Clone)]
70pub struct SparseInvertedIndex {
71 postings: HashMap<u32, Vec<Posting>>,
72 vocab_id: String,
73 doc_count: u32,
74}
75
76impl SparseInvertedIndex {
77 #[must_use]
82 pub fn new(vocab_id: impl Into<String>) -> Self {
83 Self {
84 postings: HashMap::new(),
85 vocab_id: vocab_id.into(),
86 doc_count: 0,
87 }
88 }
89
90 #[must_use]
92 pub fn vocab_id(&self) -> &str {
93 &self.vocab_id
94 }
95
96 #[must_use]
98 pub const fn doc_count(&self) -> u32 {
99 self.doc_count
100 }
101
102 pub fn add(&mut self, node: NodeId, embed: &SparseEmbed) {
106 if embed.vocab_id != self.vocab_id {
107 return;
108 }
109 if embed.indices.is_empty() {
110 return;
111 }
112 for (i, w) in embed.indices.iter().zip(embed.values.iter()) {
113 self.postings
114 .entry(*i)
115 .or_default()
116 .push(Posting { node, weight: *w });
117 }
118 self.doc_count = self.doc_count.saturating_add(1);
119 }
120
121 pub fn finalize(&mut self) {
125 for list in self.postings.values_mut() {
126 list.sort_by(|a, b| a.node.cmp(&b.node));
127 }
128 }
129
130 pub fn search(&self, query: &SparseEmbed, k: usize) -> Result<Vec<VectorHit>, Error> {
138 if query.vocab_id != self.vocab_id {
139 return Ok(Vec::new());
140 }
141 if query.indices.is_empty() || k == 0 {
142 return Ok(Vec::new());
143 }
144 let mut scores: HashMap<NodeId, f32> = HashMap::new();
145 for (tid, qw) in query.indices.iter().zip(query.values.iter()) {
146 let Some(list) = self.postings.get(tid) else {
147 continue;
148 };
149 for p in list {
150 let e = scores.entry(p.node).or_insert(0.0);
151 *e += qw * p.weight;
152 }
153 }
154 let mut ranked: Vec<(NodeId, f32)> = scores.into_iter().collect();
155 ranked.sort_by(|a, b| {
156 b.1.partial_cmp(&a.1)
157 .unwrap_or(std::cmp::Ordering::Equal)
158 .then_with(|| a.0.cmp(&b.0))
159 });
160 ranked.truncate(k);
161 Ok(ranked
162 .into_iter()
163 .map(|(node_id, score)| VectorHit { node_id, score })
164 .collect())
165 }
166
167 pub fn build_from_repo(
176 repo: &crate::repo::ReadonlyRepo,
177 vocab_id: impl Into<String>,
178 ) -> Result<Self, Error> {
179 let vocab_id = vocab_id.into();
180 let mut idx = Self::new(&vocab_id);
181 let bs: Arc<dyn Blockstore> = repo.blockstore().clone();
182 let Some(commit) = repo.head_commit() else {
183 return Err(RepoError::Uninitialized.into());
184 };
185 let cursor = Cursor::new(&*bs, &commit.nodes)?;
186 for entry in cursor {
187 let (_k, node_cid) = entry?;
188 let node: Node = decode_from_store(&*bs, &node_cid)?;
189 let Some(sparse) = &node.sparse_embed else {
190 continue;
191 };
192 if sparse.vocab_id == vocab_id {
193 idx.add(node.id, sparse);
194 }
195 }
196 idx.finalize();
197 Ok(idx)
198 }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseEmbed;

    /// Node id whose 16 raw bytes are all `b`.
    fn nid(b: u8) -> NodeId {
        NodeId::from_bytes_raw([b; 16])
    }

    /// Sparse embedding in the "v0" vocabulary.
    fn emb(indices: Vec<u32>, values: Vec<f32>) -> SparseEmbed {
        SparseEmbed::new(indices, values, "v0").unwrap()
    }

    /// Node ids of `hits`, in result order.
    fn hit_ids(hits: &[VectorHit]) -> Vec<NodeId> {
        hits.iter().map(|hit| hit.node_id).collect()
    }

    #[test]
    fn empty_index_returns_empty_results() {
        let index = SparseInvertedIndex::new("v0");
        let query = emb(vec![1], vec![1.0]);
        let hits = index.search(&query, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn add_and_search_single_doc() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![10, 20], vec![0.5, 0.5]));
        index.finalize();

        let query = emb(vec![10], vec![1.0]);
        let hits = index.search(&query, 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn search_ranks_by_dot_product_descending() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![10], vec![2.0]));
        index.add(nid(2), &emb(vec![10, 20], vec![0.1, 0.1]));
        // Shares no token with the query below.
        index.add(nid(3), &emb(vec![99], vec![5.0]));
        index.finalize();

        let query = emb(vec![10, 20], vec![1.0, 1.0]);
        let hits = index.search(&query, 10).unwrap();
        assert_eq!(hits.len(), 2, "doc3 has disjoint tokens; must not appear");
        assert_eq!(hits[0].node_id, nid(1));
        assert_eq!(hits[1].node_id, nid(2));
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn k_caps_result_count() {
        let mut index = SparseInvertedIndex::new("v0");
        for byte in 1u8..=5 {
            let doc = emb(vec![1], vec![f32::from(byte)]);
            index.add(nid(byte), &doc);
        }
        index.finalize();

        let hits = index.search(&emb(vec![1], vec![1.0]), 3).unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn vocab_mismatch_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();

        let foreign_query = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        let hits = index.search(&foreign_query, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn add_with_wrong_vocab_is_silently_skipped() {
        let foreign_doc = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &foreign_doc);
        assert_eq!(index.doc_count(), 0);
    }

    #[test]
    fn zero_k_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();

        let query = emb(vec![1], vec![1.0]);
        let hits = index.search(&query, 0).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn tie_breaks_on_node_id_ascending() {
        let mut index = SparseInvertedIndex::new("v0");
        // Inserted out of id order on purpose; all three score the same.
        for byte in [5u8, 2, 9] {
            index.add(nid(byte), &emb(vec![1], vec![1.0]));
        }
        index.finalize();

        let hits = index.search(&emb(vec![1], vec![1.0]), 10).unwrap();
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].node_id, nid(2));
        assert_eq!(hits[1].node_id, nid(5));
        assert_eq!(hits[2].node_id, nid(9));
    }

    #[test]
    fn empty_query_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();

        let empty_query = SparseEmbed::new(vec![], vec![], "v0").unwrap();
        let hits = index.search(&empty_query, 10).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn doc_count_tracks_adds() {
        let mut index = SparseInvertedIndex::new("v0");
        assert_eq!(index.doc_count(), 0);

        index.add(nid(1), &emb(vec![1], vec![1.0]));
        assert_eq!(index.doc_count(), 1);

        index.add(nid(2), &emb(vec![1], vec![1.0]));
        assert_eq!(index.doc_count(), 2);
    }

    #[test]
    fn search_is_deterministic_across_build_orders() {
        let mut forward = SparseInvertedIndex::new("v0");
        forward.add(nid(1), &emb(vec![1, 2], vec![1.0, 0.5]));
        forward.add(nid(2), &emb(vec![1, 3], vec![0.5, 1.0]));
        forward.finalize();

        let mut reversed = SparseInvertedIndex::new("v0");
        reversed.add(nid(2), &emb(vec![1, 3], vec![0.5, 1.0]));
        reversed.add(nid(1), &emb(vec![1, 2], vec![1.0, 0.5]));
        reversed.finalize();

        let query = emb(vec![1, 2, 3], vec![1.0, 1.0, 1.0]);
        let forward_ids = hit_ids(&forward.search(&query, 10).unwrap());
        let reversed_ids = hit_ids(&reversed.search(&query, 10).unwrap());
        assert_eq!(forward_ids, reversed_ids);
    }
}