// mnem_core/sparse.rs
1// SPLADE, BGE-M3, BEIR, WordPiece, OpenSearch are well-known external
2// identifiers; backticking every mention in the module doc degrades
3// rendered rustdoc readability.
4#![allow(clippy::doc_markdown)]
5
//! Sparse (learned) embedding primitives for SPLADE / BGE-M3-sparse
//! integration.
8//!
9//! # Why
10//!
11//! Learned-sparse retrievers (SPLADE v3, opensearch-doc-v3-distill,
12//! BGE-M3-sparse, granite-embedding-30m-sparse) produce a sparse
13//! vector over a WordPiece vocabulary that can be scored via an
14//! inverted index with semantic term weights learned end-to-end.
15//! BEIR nDCG@10 on sparse neural retrievers lands around +3-5 points
16//! over classical lexical keyword scoring on zero-shot domains; this
//! lane replaces that legacy lexical lane entirely.
18//!
19//! # What this module provides
20//!
21//! - [`SparseEmbed`] - canonical sparse-vector shape (ascending
22//!   `indices` + aligned `values`) with a `vocab_id` tag so two
23//!   models with different vocabularies never get mixed in one
24//!   posting list.
25//! - [`SparseEncoder`] trait - adapter-side hook for ONNX / candle
26//!   backends to implement. Mirrors the [`crate::rerank::Reranker`]
27//!   trait shape.
28//! - `MockSparseEncoder` - deterministic test-only encoder.
29//!
30//! The actual inverted-index over `SparseEmbed` values lives in
31//! [`crate::index::sparse`] so the index stays next to its sibling
32//! (brute-force vector index).
33//!
34//! Storage in [`crate::objects::Node`]: a future `Node.sparse_embed:
35//! Option<SparseEmbed>` field. Additive, so existing CIDs stay
36//! byte-identical because the serializer omits `None` via
37//! `skip_serializing_if`. CBOR canonicality is preserved because
38//! `indices` is sorted ascending at construction (checked by
39//! [`SparseEmbed::new`]).
40
41use std::fmt::Debug;
42
43use serde::{Deserialize, Serialize};
44use thiserror::Error;
45
/// Error surface for sparse-encoder adapters. Same shape as
/// [`crate::llm::LlmError`] and [`crate::rerank::RerankError`] so
/// callers can apply one fallback policy across all three lanes.
#[derive(Debug, Error)]
#[non_exhaustive] // adapters may grow failure modes without a semver break
pub enum SparseError {
    /// Network / transport failure when the adapter runs remotely
    /// (sidecar) or fetches weights.
    #[error("network error: {0}")]
    Network(String),
    /// Adapter config invalid (missing weights file, bad URL, etc.).
    /// Also returned by [`SparseEmbed::new`] on invariant violations.
    #[error("config error: {0}")]
    Config(String),
    /// Model / tokenizer returned an error.
    #[error("inference error: {0}")]
    Inference(String),
    /// Caller attempted to encode empty (or whitespace-only) text.
    #[error("empty input")]
    EmptyInput,
}
65
/// A sparse embedding over a fixed vocabulary.
///
/// `indices` MUST be strictly ascending; `values` MUST have the same
/// length as `indices`. Both invariants are checked by [`Self::new`]
/// and enforced on deserialise in a future CBOR round-trip test.
/// `vocab_id` pins the model family so two adapters with different
/// vocabs never fuse posting lists; compare as a string (e.g.
/// `"bert-base-uncased@30522"`).
///
/// NOTE(review): `PartialEq` derives through `f32`, so an embed
/// containing a NaN weight never compares equal to anything —
/// [`Self::from_unsorted`] filters out non-positive (and NaN)
/// values, but [`Self::new`] accepts any `f32`s as-is.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseEmbed {
    /// Token IDs in the vocabulary, strictly ascending.
    pub indices: Vec<u32>,
    /// Non-zero weights, aligned index-for-index with `indices`.
    pub values: Vec<f32>,
    /// Vocabulary identifier (model-family tag, e.g. `"mock:1024"`).
    pub vocab_id: String,
}
83
84impl SparseEmbed {
85    /// Construct a [`SparseEmbed`]. Panics (debug) / errors (release)
86    /// if the invariants are violated. `indices` is taken as-is; if
87    /// the caller is unsure whether it is sorted, use
88    /// [`Self::from_unsorted`] instead.
89    ///
90    /// # Errors
91    ///
92    /// - [`SparseError::Config`] if `indices.len() != values.len()`
93    ///   or `indices` contains duplicates / non-ascending entries.
94    pub fn new(
95        indices: Vec<u32>,
96        values: Vec<f32>,
97        vocab_id: impl Into<String>,
98    ) -> Result<Self, SparseError> {
99        if indices.len() != values.len() {
100            return Err(SparseError::Config(format!(
101                "indices.len() {} != values.len() {}",
102                indices.len(),
103                values.len()
104            )));
105        }
106        for w in indices.windows(2) {
107            if w[0] >= w[1] {
108                return Err(SparseError::Config(format!(
109                    "indices must be strictly ascending; saw {} then {}",
110                    w[0], w[1]
111                )));
112            }
113        }
114        Ok(Self {
115            indices,
116            values,
117            vocab_id: vocab_id.into(),
118        })
119    }
120
121    /// Construct from an unsorted `(index, value)` pair list.
122    /// Duplicate indices are kept as the maximum value (SPLADE's
123    /// own pooling rule). Useful from ONNX-side decoders that
124    /// produce vectors in token-emission order.
125    pub fn from_unsorted(
126        pairs: impl IntoIterator<Item = (u32, f32)>,
127        vocab_id: impl Into<String>,
128    ) -> Self {
129        use std::collections::BTreeMap;
130        let mut bucket: BTreeMap<u32, f32> = BTreeMap::new();
131        for (i, v) in pairs {
132            let e = bucket.entry(i).or_insert(f32::NEG_INFINITY);
133            if v > *e {
134                *e = v;
135            }
136        }
137        let (indices, values): (Vec<_>, Vec<_>) =
138            bucket.into_iter().filter(|(_, v)| *v > 0.0).unzip();
139        Self {
140            indices,
141            values,
142            vocab_id: vocab_id.into(),
143        }
144    }
145
146    /// Number of non-zero entries.
147    #[must_use]
148    pub const fn nnz(&self) -> usize {
149        self.indices.len()
150    }
151
152    /// Dot product with another sparse embedding (must share vocab_id).
153    /// Returns `None` if the vocab_ids differ.
154    #[must_use]
155    pub fn dot(&self, other: &Self) -> Option<f32> {
156        if self.vocab_id != other.vocab_id {
157            return None;
158        }
159        let mut i = 0;
160        let mut j = 0;
161        let mut sum = 0.0f32;
162        while i < self.indices.len() && j < other.indices.len() {
163            use std::cmp::Ordering;
164            match self.indices[i].cmp(&other.indices[j]) {
165                Ordering::Less => i += 1,
166                Ordering::Greater => j += 1,
167                Ordering::Equal => {
168                    sum += self.values[i] * other.values[j];
169                    i += 1;
170                    j += 1;
171                }
172            }
173        }
174        Some(sum)
175    }
176}
177
/// Learned-sparse encoder: given text, produce a [`SparseEmbed`]
/// over a fixed vocabulary. Adapter crates implement this over
/// SPLADE-ONNX, BGE-M3-sparse-ONNX, or a remote sidecar.
///
/// `Send + Sync + Debug` bounds let one boxed encoder be shared
/// across threads and logged, mirroring the other adapter traits.
pub trait SparseEncoder: Send + Sync + Debug {
    /// Provider + model identifier. Lowercase, colon-separated by
    /// convention (e.g. `"splade:opensearch-doc-v3-distill"`,
    /// `"bgem3:sparse"`, `"mock:len-inverse"`).
    fn model(&self) -> &str;

    /// Vocabulary identifier. Passed through to
    /// [`SparseEmbed::vocab_id`] on every emitted embedding so
    /// embeddings from different encoders can be told apart.
    fn vocab_id(&self) -> &str;

    /// Encode a document-side text string into a sparse vector.
    /// This is the path run at ingest time.
    ///
    /// # Errors
    ///
    /// Any [`SparseError`] the adapter surfaces. The caller fallback
    /// policy matches the rerank / LLM pattern: on error, the sparse
    /// lane is simply dropped from fusion and the hybrid still runs.
    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError>;

    /// Encode a query-side text string into a sparse vector.
    ///
    /// Default implementation delegates to [`Self::encode`]. Adapters
    /// with asymmetric inference (OpenSearch
    /// `neural-sparse-encoding-doc-v3-distill` ships a distilled
    /// `idf.json` table so the query side is tokenise + IDF-lookup
    /// with zero neural compute) override this to skip the forward
    /// pass. The overridden path keeps retrieval latency microsecond-
    /// level even when documents use a 67M-parameter encoder.
    ///
    /// # Errors
    ///
    /// Same error surface as [`Self::encode`].
    fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
        self.encode(text)
    }
}
214
/// FNV-1a 32-bit offset basis. Standard value from the FNV
/// specification (Fowler-Noll-Vo hash, Landon Curt Noll, 1991);
/// see <http://www.isthe.com/chongo/tech/comp/fnv/>. Used as the
/// seed state for the `MockSparseEncoder` token hash.
const FNV_OFFSET_BASIS_32: u32 = 2_166_136_261;

/// FNV-1a 32-bit prime. Standard value from the FNV specification
/// (<http://www.isthe.com/chongo/tech/comp/fnv/>). Per the spec,
/// each input byte is XOR-then-multiplied by this prime to diffuse
/// bits across the output word.
const FNV_PRIME_32: u32 = 16_777_619;

/// Mock vocabulary width. The token hash is reduced modulo this
/// number of slots so the encoder emits indices in `0..1024`,
/// matching the `"mock:1024"` `vocab_id` tag. Kept tiny so tests
/// exercise collision handling cheaply.
const MOCK_VOCAB_SIZE: u32 = 1024;
232
/// Deterministic test-only encoder. Produces a `SparseEmbed` by
/// hashing each whitespace-separated token into the first 1024
/// vocabulary slots with a length-inverse weight
/// (1.0 / (1.0 + token_len)).
///
/// Not a real SPLADE; do not use in benchmarks. Its purpose is to
/// let `Retriever::with_sparse_ranker(...)` unit-test the fusion
/// lane without pulling ONNX Runtime into `mnem-core`'s test deps.
#[derive(Debug, Clone)]
pub struct MockSparseEncoder {
    // Vocabulary tag stamped on every embedding; "mock:1024" via Default.
    vocab_id: String,
}
245
246impl Default for MockSparseEncoder {
247    fn default() -> Self {
248        Self {
249            vocab_id: "mock:1024".into(),
250        }
251    }
252}
253
254impl SparseEncoder for MockSparseEncoder {
255    fn model(&self) -> &str {
256        "mock:len-inverse"
257    }
258
259    fn vocab_id(&self) -> &str {
260        &self.vocab_id
261    }
262
263    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
264        if text.trim().is_empty() {
265            return Err(SparseError::EmptyInput);
266        }
267        let pairs = text.split_whitespace().map(|tok| {
268            let h: u32 = tok.bytes().fold(FNV_OFFSET_BASIS_32, |acc, b| {
269                acc.wrapping_mul(FNV_PRIME_32).wrapping_add(u32::from(b))
270            });
271            let idx = h % MOCK_VOCAB_SIZE;
272            let weight = 1.0f32 / (1.0 + tok.len() as f32);
273            (idx, weight)
274        });
275        Ok(SparseEmbed::from_unsorted(pairs, &self.vocab_id))
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_embed_rejects_unsorted_indices() {
        let err = SparseEmbed::new(vec![5, 3], vec![0.5, 0.5], "v0").unwrap_err();
        assert!(matches!(err, SparseError::Config(_)));
    }

    #[test]
    fn sparse_embed_rejects_length_mismatch() {
        let err = SparseEmbed::new(vec![1, 2], vec![0.5], "v0").unwrap_err();
        assert!(matches!(err, SparseError::Config(_)));
    }

    #[test]
    fn from_unsorted_sorts_and_max_pools() {
        let pairs = [(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.2)];
        let embed = SparseEmbed::from_unsorted(pairs, "v0");
        assert_eq!(embed.indices, [1, 3, 5]);
        let pooled = embed.values[2];
        assert!(
            (pooled - 0.3).abs() < 1e-6,
            "max-pool should keep 0.3 for index 5"
        );
    }

    #[test]
    fn from_unsorted_drops_zero_weights() {
        let embed = SparseEmbed::from_unsorted([(1, 0.0), (2, 0.5), (3, -0.1)], "v0");
        assert_eq!(embed.indices, [2]);
    }

    #[test]
    fn dot_product_on_disjoint_is_zero() {
        let left = SparseEmbed::new(vec![1, 2], vec![1.0, 1.0], "v").unwrap();
        let right = SparseEmbed::new(vec![3, 4], vec![1.0, 1.0], "v").unwrap();
        assert_eq!(left.dot(&right), Some(0.0));
    }

    #[test]
    fn dot_product_on_overlap() {
        let left = SparseEmbed::new(vec![1, 2, 5], vec![0.5, 0.5, 0.2], "v").unwrap();
        let right = SparseEmbed::new(vec![2, 5, 9], vec![0.4, 0.3, 0.1], "v").unwrap();
        // Shared terms: 2 (0.5 * 0.4 = 0.2) and 5 (0.2 * 0.3 = 0.06).
        let score = left.dot(&right).unwrap();
        assert!((score - 0.26).abs() < 1e-6, "got {score}");
    }

    #[test]
    fn dot_product_different_vocabs_is_none() {
        let left = SparseEmbed::new(vec![1], vec![1.0], "v0").unwrap();
        let right = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert!(left.dot(&right).is_none());
    }

    #[test]
    fn mock_encoder_is_deterministic() {
        let enc = MockSparseEncoder::default();
        let first = enc.encode("hello world").unwrap();
        let second = enc.encode("hello world").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn mock_encoder_empty_input_errors() {
        let enc = MockSparseEncoder::default();
        let err = enc.encode("   ").unwrap_err();
        assert!(matches!(err, SparseError::EmptyInput));
    }

    #[test]
    fn mock_encoder_vocab_id_carries_through() {
        let enc = MockSparseEncoder::default();
        let embed = enc.encode("hello").unwrap();
        assert_eq!(embed.vocab_id, enc.vocab_id());
    }
}
357}