// mnem_core/index/sparse.rs

// SPLADE, BGE-M3, WordPiece, OpenSearch are proper-noun external
// identifiers; per-mention backticking adds no signal here.
#![allow(clippy::doc_markdown)]

//! Sparse-retrieval index for learned-sparse encoders (SPLADE,
//! BGE-M3-sparse, opensearch-doc-v3-distill). Pair with
//! [`crate::sparse::SparseEncoder`] adapters.
//!
//! # Why an inverted index (not brute force)
//!
//! SPLADE vectors are ~100-300 non-zero entries over a 30K-ish
//! WordPiece vocabulary. A brute-force `O(N_docs * nnz)` walk is
//! tolerable at a few thousand docs but collapses past 100K. The
//! inverted index turns query-time scoring into: for each non-zero
//! query token, look up the posting list of (doc_id, weight) pairs
//! and accumulate `query_weight * doc_weight` into a per-doc score
//! map. Total work is `O(sum(nnz(doc_i)))` summed over docs that
//! share at least one token with the query - typically far less than
//! `O(N * nnz)`.
//!
//! # Canonicality
//!
//! Posting lists sort by `(NodeId ASC)` at build time so the search
//! result's tie-break order is deterministic across runs (matches the
//! pattern used by [`crate::index::vector::BruteForceVectorIndex`]).
//!
//! # Model scoping
//!
//! Every index binds to a single `vocab_id` string. A query sparse
//! vector whose `vocab_id` differs from the index's returns an empty
//! result (and a debug log in a future instrumentation pass). This
//! prevents accidentally fusing incompatible models under RRF.
//!
//! # Future work
//!
//! WAND / MaxScore pruning, block-max posting-list skipping, and
//! disk-persisted postings. The current in-memory implementation is
//! the correctness baseline; optimisations are opt-in features
//! deferred to a follow-up.
40
use std::collections::HashMap;
use std::sync::Arc;

use crate::error::{Error, RepoError};
use crate::id::NodeId;
use crate::index::vector::VectorHit;
use crate::objects::Node;
use crate::prolly::Cursor;
use crate::repo::readonly::decode_from_store;
use crate::sparse::SparseEmbed;
use crate::store::Blockstore;
52
53/// One posting list entry: `(NodeId, weight)`.
54#[derive(Debug, Clone, Copy)]
55struct Posting {
56    node: NodeId,
57    weight: f32,
58}
59
60/// A sparse inverted index over [`SparseEmbed`] values.
61///
62/// Build incrementally via [`Self::new`] + [`Self::add`], or in bulk
63/// via [`Self::build_from_repo`]. Query via [`Self::search`].
64///
65/// Posting lists are stored as `HashMap<u32 token_id, Vec<Posting>>`,
66/// where every `Vec<Posting>` is sorted by `NodeId ASC` for
67/// deterministic tie-break behaviour matching the rest of mnem-core's
68/// indexes.
69#[derive(Debug, Clone)]
70pub struct SparseInvertedIndex {
71    postings: HashMap<u32, Vec<Posting>>,
72    vocab_id: String,
73    doc_count: u32,
74}
75
76impl SparseInvertedIndex {
77    /// Construct an empty index bound to `vocab_id`. Nodes added
78    /// via [`Self::add`] whose own `vocab_id` disagrees are silently
79    /// skipped - mirrors [`BruteForceVectorIndex`][crate::index::vector::BruteForceVectorIndex]
80    /// behaviour for cross-model documents.
81    #[must_use]
82    pub fn new(vocab_id: impl Into<String>) -> Self {
83        Self {
84            postings: HashMap::new(),
85            vocab_id: vocab_id.into(),
86            doc_count: 0,
87        }
88    }
89
90    /// Vocabulary identifier this index is bound to.
91    #[must_use]
92    pub fn vocab_id(&self) -> &str {
93        &self.vocab_id
94    }
95
96    /// Number of documents indexed.
97    #[must_use]
98    pub const fn doc_count(&self) -> u32 {
99        self.doc_count
100    }
101
102    /// Feed one (node, sparse_embed) pair. Silently skips when the
103    /// embed's `vocab_id` disagrees with the index's or when the
104    /// embed has zero non-zero entries.
105    pub fn add(&mut self, node: NodeId, embed: &SparseEmbed) {
106        if embed.vocab_id != self.vocab_id {
107            return;
108        }
109        if embed.indices.is_empty() {
110            return;
111        }
112        for (i, w) in embed.indices.iter().zip(embed.values.iter()) {
113            self.postings
114                .entry(*i)
115                .or_default()
116                .push(Posting { node, weight: *w });
117        }
118        self.doc_count = self.doc_count.saturating_add(1);
119    }
120
121    /// Finalise the index: sort each posting list by `NodeId ASC` so
122    /// search results tie-break deterministically. Call once after
123    /// all `add()` calls; idempotent.
124    pub fn finalize(&mut self) {
125        for list in self.postings.values_mut() {
126            list.sort_by(|a, b| a.node.cmp(&b.node));
127        }
128    }
129
130    /// Search the index for the top-`k` documents by sparse-dot-product
131    /// score against `query`. Returns [`VectorHit`] (same shape as the
132    /// dense index so callers can fuse results without a custom type).
133    ///
134    /// On `vocab_id` mismatch returns an empty vec - the caller
135    /// receives no scores to fuse, same semantics as a disjoint
136    /// vocabulary.
137    pub fn search(&self, query: &SparseEmbed, k: usize) -> Result<Vec<VectorHit>, Error> {
138        if query.vocab_id != self.vocab_id {
139            return Ok(Vec::new());
140        }
141        if query.indices.is_empty() || k == 0 {
142            return Ok(Vec::new());
143        }
144        let mut scores: HashMap<NodeId, f32> = HashMap::new();
145        for (tid, qw) in query.indices.iter().zip(query.values.iter()) {
146            let Some(list) = self.postings.get(tid) else {
147                continue;
148            };
149            for p in list {
150                let e = scores.entry(p.node).or_insert(0.0);
151                *e += qw * p.weight;
152            }
153        }
154        let mut ranked: Vec<(NodeId, f32)> = scores.into_iter().collect();
155        ranked.sort_by(|a, b| {
156            b.1.partial_cmp(&a.1)
157                .unwrap_or(std::cmp::Ordering::Equal)
158                .then_with(|| a.0.cmp(&b.0))
159        });
160        ranked.truncate(k);
161        Ok(ranked
162            .into_iter()
163            .map(|(node_id, score)| VectorHit { node_id, score })
164            .collect())
165    }
166
167    /// Build an index from all nodes in the current commit whose
168    /// `sparse_embed` field matches `vocab_id`. Requires the nodes to
169    /// have been indexed by an adapter at write time.
170    ///
171    /// # Errors
172    ///
173    /// - [`RepoError::Uninitialized`] if the repo has no head commit.
174    /// - Store / codec errors while walking the Prolly tree.
175    pub fn build_from_repo(
176        repo: &crate::repo::ReadonlyRepo,
177        vocab_id: impl Into<String>,
178    ) -> Result<Self, Error> {
179        let vocab_id = vocab_id.into();
180        let mut idx = Self::new(&vocab_id);
181        let bs: Arc<dyn Blockstore> = repo.blockstore().clone();
182        let Some(commit) = repo.head_commit() else {
183            return Err(RepoError::Uninitialized.into());
184        };
185        let cursor = Cursor::new(&*bs, &commit.nodes)?;
186        for entry in cursor {
187            let (_k, node_cid) = entry?;
188            let node: Node = decode_from_store(&*bs, &node_cid)?;
189            let Some(sparse) = &node.sparse_embed else {
190                continue;
191            };
192            if sparse.vocab_id == vocab_id {
193                idx.add(node.id, sparse);
194            }
195        }
196        idx.finalize();
197        Ok(idx)
198    }
199}
200
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse::SparseEmbed;

    /// Deterministic 16-byte NodeId built by repeating one byte.
    fn nid(b: u8) -> NodeId {
        NodeId::from_bytes_raw([b; 16])
    }

    /// Shorthand for a "v0"-vocabulary sparse embedding.
    fn emb(indices: Vec<u32>, values: Vec<f32>) -> SparseEmbed {
        SparseEmbed::new(indices, values, "v0").unwrap()
    }

    #[test]
    fn empty_index_returns_empty_results() {
        let index = SparseInvertedIndex::new("v0");
        let query = emb(vec![1], vec![1.0]);
        assert!(index.search(&query, 10).unwrap().is_empty());
    }

    #[test]
    fn add_and_search_single_doc() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![10, 20], vec![0.5, 0.5]));
        index.finalize();
        let hits = index.search(&emb(vec![10], vec![1.0]), 10).unwrap();
        assert_eq!(hits.len(), 1);
        assert!((hits[0].score - 0.5).abs() < 1e-6);
    }

    #[test]
    fn search_ranks_by_dot_product_descending() {
        let mut index = SparseInvertedIndex::new("v0");
        // doc1 shares token 10 strongly; doc2 shares tokens 10 + 20 but weakly.
        index.add(nid(1), &emb(vec![10], vec![2.0]));
        index.add(nid(2), &emb(vec![10, 20], vec![0.1, 0.1]));
        index.add(nid(3), &emb(vec![99], vec![5.0])); // disjoint
        index.finalize();
        let query = emb(vec![10, 20], vec![1.0, 1.0]);
        let hits = index.search(&query, 10).unwrap();
        assert_eq!(hits.len(), 2, "doc3 has disjoint tokens; must not appear");
        assert_eq!(hits[0].node_id, nid(1));
        assert_eq!(hits[1].node_id, nid(2));
        assert!(hits[0].score > hits[1].score);
    }

    #[test]
    fn k_caps_result_count() {
        let mut index = SparseInvertedIndex::new("v0");
        for b in 1..=5 {
            index.add(nid(b), &emb(vec![1], vec![f32::from(b)]));
        }
        index.finalize();
        let hits = index.search(&emb(vec![1], vec![1.0]), 3).unwrap();
        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn vocab_mismatch_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();
        let foreign = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert!(index.search(&foreign, 10).unwrap().is_empty());
    }

    #[test]
    fn add_with_wrong_vocab_is_silently_skipped() {
        let mut index = SparseInvertedIndex::new("v0");
        let foreign = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        index.add(nid(1), &foreign);
        assert_eq!(index.doc_count(), 0);
    }

    #[test]
    fn zero_k_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();
        assert!(index.search(&emb(vec![1], vec![1.0]), 0).unwrap().is_empty());
    }

    #[test]
    fn tie_breaks_on_node_id_ascending() {
        let mut index = SparseInvertedIndex::new("v0");
        for b in [5, 2, 9] {
            index.add(nid(b), &emb(vec![1], vec![1.0]));
        }
        index.finalize();
        let hits = index.search(&emb(vec![1], vec![1.0]), 10).unwrap();
        // All scores equal 1.0; tie-break should be NodeId ASC.
        assert_eq!(hits.len(), 3);
        assert_eq!(hits[0].node_id, nid(2));
        assert_eq!(hits[1].node_id, nid(5));
        assert_eq!(hits[2].node_id, nid(9));
    }

    #[test]
    fn empty_query_returns_empty() {
        let mut index = SparseInvertedIndex::new("v0");
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        index.finalize();
        let empty_query = SparseEmbed::new(vec![], vec![], "v0").unwrap();
        assert!(index.search(&empty_query, 10).unwrap().is_empty());
    }

    #[test]
    fn doc_count_tracks_adds() {
        let mut index = SparseInvertedIndex::new("v0");
        assert_eq!(index.doc_count(), 0);
        index.add(nid(1), &emb(vec![1], vec![1.0]));
        assert_eq!(index.doc_count(), 1);
        index.add(nid(2), &emb(vec![1], vec![1.0]));
        assert_eq!(index.doc_count(), 2);
    }

    #[test]
    fn search_is_deterministic_across_build_orders() {
        // Same documents, opposite insertion order.
        let mut forward = SparseInvertedIndex::new("v0");
        forward.add(nid(1), &emb(vec![1, 2], vec![1.0, 0.5]));
        forward.add(nid(2), &emb(vec![1, 3], vec![0.5, 1.0]));
        forward.finalize();

        let mut reversed = SparseInvertedIndex::new("v0");
        reversed.add(nid(2), &emb(vec![1, 3], vec![0.5, 1.0]));
        reversed.add(nid(1), &emb(vec![1, 2], vec![1.0, 0.5]));
        reversed.finalize();

        let query = emb(vec![1, 2, 3], vec![1.0, 1.0, 1.0]);
        let ids_forward: Vec<NodeId> = forward
            .search(&query, 10)
            .unwrap()
            .iter()
            .map(|h| h.node_id)
            .collect();
        let ids_reversed: Vec<NodeId> = reversed
            .search(&query, 10)
            .unwrap()
            .iter()
            .map(|h| h.node_id)
            .collect();
        assert_eq!(ids_forward, ids_reversed);
    }
}