Skip to main content

mnem_core/
sparse.rs

1// SPLADE, BGE-M3, BEIR, WordPiece, OpenSearch are well-known external
2// identifiers; backticking every mention in the module doc degrades
3// rendered rustdoc readability.
4#![allow(clippy::doc_markdown)]
5
//! Sparse (learned) embedding primitives for SPLADE / BGE-M3-sparse
//! integration.
//!
//! # Why
//!
//! Learned-sparse retrievers (SPLADE v3, opensearch-doc-v3-distill,
//! BGE-M3-sparse, granite-embedding-30m-sparse) produce a sparse
//! vector over a WordPiece vocabulary that can be scored via an
//! inverted index with semantic term weights learned end-to-end.
//! BEIR nDCG@10 on sparse neural retrievers lands around +3-5 points
//! over classical lexical keyword scoring on zero-shot domains; this
//! lane replaces that legacy lexical lane entirely.
//!
//! # What this module provides
//!
//! - [`SparseEmbed`] - canonical sparse-vector shape (ascending
//!   `indices` + aligned `values`) with a `vocab_id` tag so two
//!   models with different vocabularies never get mixed in one
//!   posting list.
//! - [`SparseEncoder`] trait - adapter-side hook for ONNX / candle
//!   backends to implement. Mirrors the [`crate::rerank::Reranker`]
//!   trait shape.
//! - `MockSparseEncoder` - deterministic test-only encoder.
//!
//! The actual inverted-index over `SparseEmbed` values lives in
//! [`crate::index::sparse`] so the index stays next to its sibling
//! (brute-force vector index).
//!
//! Storage in [`crate::objects::Node`]: a future `Node.sparse_embed:
//! Option<SparseEmbed>` field. Additive, so existing CIDs stay
//! byte-identical because the serializer omits `None` via
//! `skip_serializing_if`. CBOR canonicality is preserved because
//! `indices` is sorted ascending at construction (checked by
//! [`SparseEmbed::new`]).
40
41use std::fmt::Debug;
42
43use serde::{Deserialize, Serialize};
44use thiserror::Error;
45
46/// Error surface for sparse-encoder adapters. Same shape as
47/// [`crate::llm::LlmError`] and [`crate::rerank::RerankError`].
48#[derive(Debug, Error)]
49#[non_exhaustive]
50pub enum SparseError {
51    /// Network / transport failure when the adapter runs remotely
52    /// (sidecar) or fetches weights.
53    #[error("network error: {0}")]
54    Network(String),
55    /// Adapter config invalid (missing weights file, bad URL, etc.).
56    #[error("config error: {0}")]
57    Config(String),
58    /// Model / tokenizer returned an error.
59    #[error("inference error: {0}")]
60    Inference(String),
61    /// Caller attempted to encode empty text.
62    #[error("empty input")]
63    EmptyInput,
64}
65
66/// A sparse embedding over a fixed vocabulary.
67///
68/// `indices` MUST be strictly ascending; `values` MUST have the same
69/// length as `indices`. Both invariants are checked by [`Self::new`]
70/// and enforced on deserialise in a future CBOR round-trip test.
71/// `vocab_id` pins the model family so two adapters with different
72/// vocabs never fuse posting lists; compare as a string (e.g.
73/// `"bert-base-uncased@30522"`).
74#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
75pub struct SparseEmbed {
76    /// Token IDs in the vocabulary, strictly ascending.
77    pub indices: Vec<u32>,
78    /// Non-zero weights, aligned with `indices`.
79    pub values: Vec<f32>,
80    /// Vocabulary identifier.
81    pub vocab_id: String,
82}
83
84impl SparseEmbed {
85    /// Construct a [`SparseEmbed`]. Panics (debug) / errors (release)
86    /// if the invariants are violated. `indices` is taken as-is; if
87    /// the caller is unsure whether it is sorted, use
88    /// [`Self::from_unsorted`] instead.
89    ///
90    /// # Errors
91    ///
92    /// - [`SparseError::Config`] if `indices.len() != values.len()`
93    ///   or `indices` contains duplicates / non-ascending entries.
94    pub fn new(
95        indices: Vec<u32>,
96        values: Vec<f32>,
97        vocab_id: impl Into<String>,
98    ) -> Result<Self, SparseError> {
99        if indices.len() != values.len() {
100            return Err(SparseError::Config(format!(
101                "indices.len() {} != values.len() {}",
102                indices.len(),
103                values.len()
104            )));
105        }
106        for w in indices.windows(2) {
107            if w[0] >= w[1] {
108                return Err(SparseError::Config(format!(
109                    "indices must be strictly ascending; saw {} then {}",
110                    w[0], w[1]
111                )));
112            }
113        }
114        Ok(Self {
115            indices,
116            values,
117            vocab_id: vocab_id.into(),
118        })
119    }
120
121    /// Construct from an unsorted `(index, value)` pair list.
122    /// Duplicate indices are kept as the maximum value (SPLADE's
123    /// own pooling rule). Useful from ONNX-side decoders that
124    /// produce vectors in token-emission order.
125    pub fn from_unsorted(
126        pairs: impl IntoIterator<Item = (u32, f32)>,
127        vocab_id: impl Into<String>,
128    ) -> Self {
129        use std::collections::BTreeMap;
130        let mut bucket: BTreeMap<u32, f32> = BTreeMap::new();
131        for (i, v) in pairs {
132            let e = bucket.entry(i).or_insert(f32::NEG_INFINITY);
133            if v > *e {
134                *e = v;
135            }
136        }
137        let (indices, values): (Vec<_>, Vec<_>) =
138            bucket.into_iter().filter(|(_, v)| *v > 0.0).unzip();
139        Self {
140            indices,
141            values,
142            vocab_id: vocab_id.into(),
143        }
144    }
145
146    /// Validate the two invariants that [`SparseEmbed::new`] enforces at
147    /// construction time:
148    ///
149    /// 1. `indices.len() == values.len()` (aligned parallel arrays)
150    /// 2. `indices` is strictly ascending (no duplicates, no out-of-order)
151    ///
152    /// Call this after deserialising from untrusted bytes (e.g. IPLD round-
153    /// trips from a pre-G17 `extra["sparse_embed"]` node) to catch corrupt
154    /// data before it reaches the sparse sidecar.
155    ///
156    /// # Errors
157    ///
158    /// Returns a descriptive `String` (via `Err(String)`) if either
159    /// invariant is violated so callers can surface the message in a
160    /// warning without pulling in a new error type.
161    pub fn validate(&self) -> Result<(), String> {
162        if self.indices.len() != self.values.len() {
163            return Err(format!(
164                "SparseEmbed invariant violated: indices.len() {} != values.len() {}",
165                self.indices.len(),
166                self.values.len()
167            ));
168        }
169        for w in self.indices.windows(2) {
170            if w[0] >= w[1] {
171                return Err(format!(
172                    "SparseEmbed invariant violated: indices must be strictly ascending; \
173                     saw {} then {}",
174                    w[0], w[1]
175                ));
176            }
177        }
178        Ok(())
179    }
180
181    /// Number of non-zero entries.
182    #[must_use]
183    pub const fn nnz(&self) -> usize {
184        self.indices.len()
185    }
186
187    /// Dot product with another sparse embedding (must share vocab_id).
188    /// Returns `None` if the vocab_ids differ.
189    #[must_use]
190    pub fn dot(&self, other: &Self) -> Option<f32> {
191        if self.vocab_id != other.vocab_id {
192            return None;
193        }
194        let mut i = 0;
195        let mut j = 0;
196        let mut sum = 0.0f32;
197        while i < self.indices.len() && j < other.indices.len() {
198            use std::cmp::Ordering;
199            match self.indices[i].cmp(&other.indices[j]) {
200                Ordering::Less => i += 1,
201                Ordering::Greater => j += 1,
202                Ordering::Equal => {
203                    sum += self.values[i] * other.values[j];
204                    i += 1;
205                    j += 1;
206                }
207            }
208        }
209        Some(sum)
210    }
211}
212
213/// Learned-sparse encoder: given text, produce a [`SparseEmbed`]
214/// over a fixed vocabulary. Adapter crates implement this over
215/// SPLADE-ONNX, BGE-M3-sparse-ONNX, or a remote sidecar.
216pub trait SparseEncoder: Send + Sync + Debug {
217    /// Provider + model identifier. Lowercase, colon-separated by
218    /// convention (e.g. `"splade:opensearch-doc-v3-distill"`,
219    /// `"bgem3:sparse"`, `"mock:len-inverse"`).
220    fn model(&self) -> &str;
221
222    /// Vocabulary identifier. Passed through to
223    /// [`SparseEmbed::vocab_id`] on every emitted embedding.
224    fn vocab_id(&self) -> &str;
225
226    /// Encode a document-side text string into a sparse vector.
227    /// This is the path run at ingest time.
228    ///
229    /// # Errors
230    ///
231    /// Any [`SparseError`] the adapter surfaces. The caller fallback
232    /// policy matches the rerank / LLM pattern: on error, the sparse
233    /// lane is simply dropped from fusion and the hybrid still runs.
234    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError>;
235
236    /// Encode a query-side text string into a sparse vector.
237    ///
238    /// Default implementation delegates to [`Self::encode`]. Adapters
239    /// with asymmetric inference (OpenSearch
240    /// `neural-sparse-encoding-doc-v3-distill` ships a distilled
241    /// `idf.json` table so the query side is tokenise + IDF-lookup
242    /// with zero neural compute) override this to skip the forward
243    /// pass. The overridden path keeps retrieval latency microsecond-
244    /// level even when documents use a 67M-parameter encoder.
245    fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
246        self.encode(text)
247    }
248}
249
/// 32-bit FNV offset basis (decimal 2 166 136 261). Standard value
/// from the FNV specification (Fowler-Noll-Vo hash, Landon Curt Noll,
/// 1991); see <http://www.isthe.com/chongo/tech/comp/fnv/>. Used as
/// the seed state for the `MockSparseEncoder` token hash.
const FNV_OFFSET_BASIS_32: u32 = 0x811C_9DC5;

/// 32-bit FNV prime (decimal 16 777 619). Standard value from the FNV
/// specification (<http://www.isthe.com/chongo/tech/comp/fnv/>). The
/// `MockSparseEncoder` token hash multiplies its running state by this
/// prime once per input byte.
const FNV_PRIME_32: u32 = 0x0100_0193;

/// Mock vocabulary width (1024 slots). The token hash is reduced
/// modulo this number so the encoder emits indices in `0..1024`,
/// matching the `"mock:1024"` `vocab_id` tag. Kept tiny so tests
/// exercise collision handling cheaply.
const MOCK_VOCAB_SIZE: u32 = 1 << 10;
267
/// Deterministic, test-only stand-in for a learned-sparse model.
///
/// Each whitespace-separated token is hashed into one of 1024
/// vocabulary slots and given a length-inverse weight
/// (1.0 / (1.0 + token_len)), where the length is the token's byte
/// count.
///
/// Not a real SPLADE; do not use in benchmarks. Its purpose is to
/// let `Retriever::with_sparse_ranker(...)` unit-test the fusion
/// lane without pulling ONNX Runtime into `mnem-core`'s test deps.
#[derive(Debug, Clone)]
pub struct MockSparseEncoder {
    /// Vocabulary tag copied onto every embedding this encoder emits.
    vocab_id: String,
}
280
281impl Default for MockSparseEncoder {
282    fn default() -> Self {
283        Self {
284            vocab_id: "mock:1024".into(),
285        }
286    }
287}
288
289impl SparseEncoder for MockSparseEncoder {
290    fn model(&self) -> &str {
291        "mock:len-inverse"
292    }
293
294    fn vocab_id(&self) -> &str {
295        &self.vocab_id
296    }
297
298    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
299        if text.trim().is_empty() {
300            return Err(SparseError::EmptyInput);
301        }
302        let pairs = text.split_whitespace().map(|tok| {
303            let h: u32 = tok.bytes().fold(FNV_OFFSET_BASIS_32, |acc, b| {
304                acc.wrapping_mul(FNV_PRIME_32).wrapping_add(u32::from(b))
305            });
306            let idx = h % MOCK_VOCAB_SIZE;
307            let weight = 1.0f32 / (1.0 + tok.len() as f32);
308            (idx, weight)
309        });
310        Ok(SparseEmbed::from_unsorted(pairs, &self.vocab_id))
311    }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed three-entry embedding used as the starting point for
    /// the `validate()` tests, which then corrupt it in place.
    fn valid_embed() -> SparseEmbed {
        SparseEmbed::new(vec![10, 42, 99], vec![0.8, 0.5, 0.2], "vocab").unwrap()
    }

    #[test]
    fn sparse_embed_rejects_unsorted_indices() {
        let outcome = SparseEmbed::new(vec![5, 3], vec![0.5, 0.5], "v0");
        assert!(matches!(outcome, Err(SparseError::Config(_))));
    }

    #[test]
    fn sparse_embed_rejects_length_mismatch() {
        let outcome = SparseEmbed::new(vec![1, 2], vec![0.5], "v0");
        assert!(matches!(outcome, Err(SparseError::Config(_))));
    }

    #[test]
    fn from_unsorted_sorts_and_max_pools() {
        let pairs = [(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.2)];
        let pooled = SparseEmbed::from_unsorted(pairs, "v0");
        assert_eq!(pooled.indices, vec![1, 3, 5]);
        let weight_for_five = pooled.values[2];
        assert!(
            (weight_for_five - 0.3).abs() < 1e-6,
            "max-pool should keep 0.3 for index 5"
        );
    }

    #[test]
    fn from_unsorted_drops_zero_weights() {
        let pooled = SparseEmbed::from_unsorted([(1, 0.0), (2, 0.5), (3, -0.1)], "v0");
        assert_eq!(pooled.indices, vec![2]);
    }

    #[test]
    fn dot_product_on_disjoint_is_zero() {
        let left = SparseEmbed::new(vec![1, 2], vec![1.0, 1.0], "v").unwrap();
        let right = SparseEmbed::new(vec![3, 4], vec![1.0, 1.0], "v").unwrap();
        assert_eq!(left.dot(&right), Some(0.0));
    }

    #[test]
    fn dot_product_on_overlap() {
        let left = SparseEmbed::new(vec![1, 2, 5], vec![0.5, 0.5, 0.2], "v").unwrap();
        let right = SparseEmbed::new(vec![2, 5, 9], vec![0.4, 0.3, 0.1], "v").unwrap();
        // Shared indices: 2 contributes 0.5*0.4 = 0.2, 5 contributes
        // 0.2*0.3 = 0.06, for a total of 0.26.
        let d = left.dot(&right).unwrap();
        assert!((d - 0.26).abs() < 1e-6, "got {d}");
    }

    #[test]
    fn dot_product_different_vocabs_is_none() {
        let left = SparseEmbed::new(vec![1], vec![1.0], "v0").unwrap();
        let right = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert_eq!(left.dot(&right), None);
    }

    #[test]
    fn mock_encoder_is_deterministic() {
        let encoder = MockSparseEncoder::default();
        let first = encoder.encode("hello world").unwrap();
        let second = encoder.encode("hello world").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn mock_encoder_empty_input_errors() {
        let encoder = MockSparseEncoder::default();
        let outcome = encoder.encode("   ");
        assert!(matches!(outcome, Err(SparseError::EmptyInput)));
    }

    #[test]
    fn mock_encoder_vocab_id_carries_through() {
        let encoder = MockSparseEncoder::default();
        let emb = encoder.encode("hello").unwrap();
        assert_eq!(emb.vocab_id, encoder.vocab_id());
    }

    // --- SparseEmbed::validate() unit tests ---

    #[test]
    fn validate_ok_on_valid_sparse_embed() {
        assert!(valid_embed().validate().is_ok());
    }

    #[test]
    fn validate_rejects_non_ascending_indices() {
        // The fields are pub, so a valid embedding can be corrupted by
        // plain assignment after construction.
        let mut corrupted = valid_embed();
        corrupted.indices = vec![99, 10, 42]; // out of order
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject non-ascending indices"
        );
    }

    #[test]
    fn validate_rejects_duplicate_indices() {
        let mut corrupted = valid_embed();
        corrupted.indices = vec![10, 10, 99]; // 10 repeated, so not strictly ascending
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject duplicate indices"
        );
    }

    #[test]
    fn validate_rejects_mismatched_lengths() {
        let mut corrupted = valid_embed();
        corrupted.values.push(0.1); // four values against three indices
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject mismatched indices/values lengths"
        );
    }
}