mnem_core/sparse.rs
// SPLADE, BGE-M3, BEIR, WordPiece, OpenSearch are well-known external
// identifiers; backticking every mention in the module doc degrades
// rendered rustdoc readability.
#![allow(clippy::doc_markdown)]

//! Sparse (learned) embedding primitives for SPLADE / BGE-M3-sparse
//! integration.
//!
//! # Why
//!
//! Learned-sparse retrievers (SPLADE v3, opensearch-doc-v3-distill,
//! BGE-M3-sparse, granite-embedding-30m-sparse) produce a sparse
//! vector over a WordPiece vocabulary that can be scored via an
//! inverted index with semantic term weights learned end-to-end.
//! On zero-shot domains, BEIR nDCG@10 for sparse neural retrievers
//! lands around 3-5 points above classical lexical keyword scoring;
//! this lane replaces the legacy lexical lane entirely.
//!
//! # What this module provides
//!
//! - [`SparseEmbed`] - canonical sparse-vector shape (ascending
//!   `indices` + aligned `values`) with a `vocab_id` tag so two
//!   models with different vocabularies never get mixed in one
//!   posting list.
//! - [`SparseEncoder`] trait - adapter-side hook for ONNX / candle
//!   backends to implement. Mirrors the [`crate::rerank::Reranker`]
//!   trait shape.
//! - [`MockSparseEncoder`] - deterministic test-only encoder.
//!
//! The actual inverted index over `SparseEmbed` values lives in
//! [`crate::index::sparse`] so the index stays next to its sibling
//! (brute-force vector index).
//!
//! Storage in [`crate::objects::Node`]: a future `Node.sparse_embed:
//! Option<SparseEmbed>` field. Additive, so existing CIDs stay
//! byte-identical because the serializer omits `None` via
//! `skip_serializing_if`. CBOR canonicality is preserved because
//! `indices` is sorted ascending at construction (checked by
//! [`SparseEmbed::new`]).
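//!
//! # Example
//!
//! A minimal sketch of the pieces above wired together (fenced as
//! `ignore`; the import path assumes this module is exposed as
//! `mnem_core::sparse`):
//!
//! ```ignore
//! use mnem_core::sparse::{MockSparseEncoder, SparseEncoder};
//!
//! let enc = MockSparseEncoder::default();
//! // Ingest side: full encode. Query side: encode_query (here the
//! // default, which just delegates to encode).
//! let doc = enc.encode("hybrid retrieval with a sparse lane").unwrap();
//! let query = enc.encode_query("sparse retrieval").unwrap();
//! // Same vocab_id on both sides, so the dot product is defined.
//! let score = doc.dot(&query).unwrap();
//! assert!(score > 0.0, "shared tokens should overlap");
//! ```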

use std::fmt::Debug;

use serde::{Deserialize, Serialize};
use thiserror::Error;

/// Error surface for sparse-encoder adapters. Same shape as
/// [`crate::llm::LlmError`] and [`crate::rerank::RerankError`].
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum SparseError {
    /// Network / transport failure when the adapter runs remotely
    /// (sidecar) or fetches weights.
    #[error("network error: {0}")]
    Network(String),
    /// Adapter config invalid (missing weights file, bad URL, etc.).
    #[error("config error: {0}")]
    Config(String),
    /// Model / tokenizer returned an error.
    #[error("inference error: {0}")]
    Inference(String),
    /// Caller attempted to encode empty text.
    #[error("empty input")]
    EmptyInput,
}

/// A sparse embedding over a fixed vocabulary.
///
/// `indices` MUST be strictly ascending; `values` MUST have the same
/// length as `indices`. Both invariants are checked by [`Self::new`]
/// and will be exercised on deserialisation by a future CBOR
/// round-trip test. `vocab_id` pins the model family so two adapters
/// with different vocabs never fuse posting lists; compare it as a
/// string (e.g. `"bert-base-uncased@30522"`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SparseEmbed {
    /// Token IDs in the vocabulary, strictly ascending.
    pub indices: Vec<u32>,
    /// Non-zero weights, aligned with `indices`.
    pub values: Vec<f32>,
    /// Vocabulary identifier.
    pub vocab_id: String,
}

impl SparseEmbed {
    /// Construct a [`SparseEmbed`], returning an error if the
    /// invariants are violated. `indices` is taken as-is; if the
    /// caller is unsure whether it is sorted, use
    /// [`Self::from_unsorted`] instead.
    ///
    /// # Errors
    ///
    /// - [`SparseError::Config`] if `indices.len() != values.len()`
    ///   or `indices` contains duplicates / non-ascending entries.
    pub fn new(
        indices: Vec<u32>,
        values: Vec<f32>,
        vocab_id: impl Into<String>,
    ) -> Result<Self, SparseError> {
        if indices.len() != values.len() {
            return Err(SparseError::Config(format!(
                "indices.len() {} != values.len() {}",
                indices.len(),
                values.len()
            )));
        }
        for w in indices.windows(2) {
            if w[0] >= w[1] {
                return Err(SparseError::Config(format!(
                    "indices must be strictly ascending; saw {} then {}",
                    w[0], w[1]
                )));
            }
        }
        Ok(Self {
            indices,
            values,
            vocab_id: vocab_id.into(),
        })
    }

    /// Construct from an unsorted `(index, value)` pair list.
    /// Duplicate indices are max-pooled (SPLADE's own pooling rule)
    /// and entries with non-positive weights are dropped. Useful for
    /// ONNX-side decoders that produce vectors in token-emission
    /// order.
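    ///
    /// # Example
    ///
    /// A small sketch of the pooling rule (fenced as `ignore`):
    ///
    /// ```ignore
    /// // Duplicate index 5 max-pools to 0.3; the zero weight at
    /// // index 1 is dropped.
    /// let s = SparseEmbed::from_unsorted(
    ///     [(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.0)],
    ///     "v0",
    /// );
    /// assert_eq!(s.indices, vec![3, 5]);
    /// assert_eq!(s.values, vec![0.9, 0.3]);
    /// ```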
    pub fn from_unsorted(
        pairs: impl IntoIterator<Item = (u32, f32)>,
        vocab_id: impl Into<String>,
    ) -> Self {
        use std::collections::BTreeMap;
        // BTreeMap yields keys in ascending order, which gives the
        // sorted `indices` invariant for free.
        let mut bucket: BTreeMap<u32, f32> = BTreeMap::new();
        for (i, v) in pairs {
            let e = bucket.entry(i).or_insert(f32::NEG_INFINITY);
            if v > *e {
                *e = v;
            }
        }
        let (indices, values): (Vec<_>, Vec<_>) =
            bucket.into_iter().filter(|(_, v)| *v > 0.0).unzip();
        Self {
            indices,
            values,
            vocab_id: vocab_id.into(),
        }
    }


    /// Number of non-zero entries.
    #[must_use]
    pub const fn nnz(&self) -> usize {
        self.indices.len()
    }

    /// Dot product with another sparse embedding (must share
    /// `vocab_id`). Returns `None` if the `vocab_id`s differ.
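    ///
    /// # Example
    ///
    /// A small sketch (fenced as `ignore`):
    ///
    /// ```ignore
    /// let a = SparseEmbed::new(vec![1, 2], vec![0.5, 0.5], "v").unwrap();
    /// let b = SparseEmbed::new(vec![2, 9], vec![0.4, 0.1], "v").unwrap();
    /// // Only index 2 overlaps: 0.5 * 0.4 = 0.2.
    /// let d = a.dot(&b).unwrap();
    /// assert!((d - 0.2).abs() < 1e-6);
    /// ```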
    #[must_use]
    pub fn dot(&self, other: &Self) -> Option<f32> {
        if self.vocab_id != other.vocab_id {
            return None;
        }
        use std::cmp::Ordering;
        // Classic sorted-list merge: advance whichever side has the
        // smaller index, multiply-accumulate on matches.
        let mut i = 0;
        let mut j = 0;
        let mut sum = 0.0f32;
        while i < self.indices.len() && j < other.indices.len() {
            match self.indices[i].cmp(&other.indices[j]) {
                Ordering::Less => i += 1,
                Ordering::Greater => j += 1,
                Ordering::Equal => {
                    sum += self.values[i] * other.values[j];
                    i += 1;
                    j += 1;
                }
            }
        }
        Some(sum)
    }
}

/// Learned-sparse encoder: given text, produce a [`SparseEmbed`]
/// over a fixed vocabulary. Adapter crates implement this over
/// SPLADE-ONNX, BGE-M3-sparse-ONNX, or a remote sidecar.
pub trait SparseEncoder: Send + Sync + Debug {
    /// Provider + model identifier. Lowercase, colon-separated by
    /// convention (e.g. `"splade:opensearch-doc-v3-distill"`,
    /// `"bgem3:sparse"`, `"mock:len-inverse"`).
    fn model(&self) -> &str;

    /// Vocabulary identifier. Passed through to
    /// [`SparseEmbed::vocab_id`] on every emitted embedding.
    fn vocab_id(&self) -> &str;

    /// Encode a document-side text string into a sparse vector.
    /// This is the path run at ingest time.
    ///
    /// # Errors
    ///
    /// Any [`SparseError`] the adapter surfaces. The caller's
    /// fallback policy matches the rerank / LLM pattern: on error,
    /// the sparse lane is simply dropped from fusion and the hybrid
    /// query still runs.
    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError>;

    /// Encode a query-side text string into a sparse vector.
    ///
    /// The default implementation delegates to [`Self::encode`].
    /// Adapters with asymmetric inference override this to skip the
    /// forward pass: OpenSearch's
    /// `neural-sparse-encoding-doc-v3-distill`, for example, ships a
    /// distilled `idf.json` table so the query side is tokenise +
    /// IDF-lookup with zero neural compute. The overridden path keeps
    /// retrieval latency at the microsecond level even when documents
    /// go through a 67M-parameter encoder.
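    ///
    /// # Example
    ///
    /// A sketch of such an asymmetric adapter (fenced as `ignore`;
    /// `IdfSparseEncoder` and its `idf` table are illustrative, not
    /// types shipped by this crate):
    ///
    /// ```ignore
    /// // Illustrative adapter; not part of mnem-core.
    /// #[derive(Debug)]
    /// struct IdfSparseEncoder {
    ///     /// token -> (vocab index, IDF weight), loaded from idf.json.
    ///     idf: std::collections::HashMap<String, (u32, f32)>,
    ///     vocab: String,
    /// }
    ///
    /// impl SparseEncoder for IdfSparseEncoder {
    ///     fn model(&self) -> &str { "splade:idf-sketch" }
    ///     fn vocab_id(&self) -> &str { &self.vocab }
    ///
    ///     fn encode(&self, _text: &str) -> Result<SparseEmbed, SparseError> {
    ///         // Document side: full neural forward pass (elided).
    ///         unimplemented!()
    ///     }
    ///
    ///     fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
    ///         // Query side: tokenise + IDF lookup, no forward pass.
    ///         let pairs = text
    ///             .split_whitespace()
    ///             .filter_map(|tok| self.idf.get(tok).copied());
    ///         Ok(SparseEmbed::from_unsorted(pairs, &self.vocab))
    ///     }
    /// }
    /// ```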
    fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
        self.encode(text)
    }
}

/// FNV-1a 32-bit offset basis. Standard value from the FNV
/// specification (Fowler/Noll/Vo hash, 1991); see
/// <http://www.isthe.com/chongo/tech/comp/fnv/>. Used as the seed
/// state for the [`MockSparseEncoder`] token hash.
const FNV_OFFSET_BASIS_32: u32 = 2_166_136_261;

/// FNV-1a 32-bit prime. Standard value from the FNV specification
/// (<http://www.isthe.com/chongo/tech/comp/fnv/>). Each byte of the
/// input is XOR-then-multiplied by this prime to diffuse bits
/// across the output word.
const FNV_PRIME_32: u32 = 16_777_619;

/// Mock vocabulary width. The token hash is reduced modulo this
/// number of slots so the encoder emits indices in `0..1024`,
/// matching the `"mock:1024"` `vocab_id` tag. Kept tiny so tests
/// exercise collision handling cheaply.
const MOCK_VOCAB_SIZE: u32 = 1024;

/// Deterministic test-only encoder. Produces a [`SparseEmbed`] by
/// hashing each whitespace-separated token into the first 1024
/// vocabulary slots with a length-inverse weight
/// (`1.0 / (1.0 + token_len)`).
///
/// Not a real SPLADE; do not use in benchmarks. Its purpose is to
/// let `Retriever::with_sparse_ranker(...)` unit-test the fusion
/// lane without pulling ONNX Runtime into `mnem-core`'s test deps.
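///
/// # Example
///
/// A small sketch (fenced as `ignore`):
///
/// ```ignore
/// let enc = MockSparseEncoder::default();
/// let emb = enc.encode("hello world").unwrap();
/// assert_eq!(emb.vocab_id, "mock:1024");
/// assert!(emb.nnz() <= 2); // two tokens, possibly colliding
/// ```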
#[derive(Debug, Clone)]
pub struct MockSparseEncoder {
    vocab_id: String,
}

impl Default for MockSparseEncoder {
    fn default() -> Self {
        Self {
            vocab_id: "mock:1024".into(),
        }
    }
}

impl SparseEncoder for MockSparseEncoder {
    fn model(&self) -> &str {
        "mock:len-inverse"
    }

    fn vocab_id(&self) -> &str {
        &self.vocab_id
    }

    fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
        if text.trim().is_empty() {
            return Err(SparseError::EmptyInput);
        }
        let pairs = text.split_whitespace().map(|tok| {
            // FNV-1a: XOR the byte in, then multiply by the prime
            // (matches the constant docs above).
            let h: u32 = tok.bytes().fold(FNV_OFFSET_BASIS_32, |acc, b| {
                (acc ^ u32::from(b)).wrapping_mul(FNV_PRIME_32)
            });
            let idx = h % MOCK_VOCAB_SIZE;
            let weight = 1.0f32 / (1.0 + tok.len() as f32);
            (idx, weight)
        });
        Ok(SparseEmbed::from_unsorted(pairs, &self.vocab_id))
    }
}


#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn sparse_embed_rejects_unsorted_indices() {
        let e = SparseEmbed::new(vec![5, 3], vec![0.5, 0.5], "v0").unwrap_err();
        assert!(matches!(e, SparseError::Config(_)));
    }

    #[test]
    fn sparse_embed_rejects_length_mismatch() {
        let e = SparseEmbed::new(vec![1, 2], vec![0.5], "v0").unwrap_err();
        assert!(matches!(e, SparseError::Config(_)));
    }

    #[test]
    fn from_unsorted_sorts_and_max_pools() {
        let s = SparseEmbed::from_unsorted([(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.2)], "v0");
        assert_eq!(s.indices, vec![1, 3, 5]);
        assert!(
            (s.values[2] - 0.3).abs() < 1e-6,
            "max-pool should keep 0.3 for index 5"
        );
    }

    #[test]
    fn from_unsorted_drops_zero_weights() {
        let s = SparseEmbed::from_unsorted([(1, 0.0), (2, 0.5), (3, -0.1)], "v0");
        assert_eq!(s.indices, vec![2]);
    }

    #[test]
    fn dot_product_on_disjoint_is_zero() {
        let a = SparseEmbed::new(vec![1, 2], vec![1.0, 1.0], "v").unwrap();
        let b = SparseEmbed::new(vec![3, 4], vec![1.0, 1.0], "v").unwrap();
        assert_eq!(a.dot(&b), Some(0.0));
    }

    #[test]
    fn dot_product_on_overlap() {
        let a = SparseEmbed::new(vec![1, 2, 5], vec![0.5, 0.5, 0.2], "v").unwrap();
        let b = SparseEmbed::new(vec![2, 5, 9], vec![0.4, 0.3, 0.1], "v").unwrap();
        // Overlap at 2 (0.5*0.4=0.2) and 5 (0.2*0.3=0.06) -> 0.26.
        let d = a.dot(&b).unwrap();
        assert!((d - 0.26).abs() < 1e-6, "got {d}");
    }

    #[test]
    fn dot_product_different_vocabs_is_none() {
        let a = SparseEmbed::new(vec![1], vec![1.0], "v0").unwrap();
        let b = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert_eq!(a.dot(&b), None);
    }

    #[test]
    fn mock_encoder_is_deterministic() {
        let e = MockSparseEncoder::default();
        let a = e.encode("hello world").unwrap();
        let b = e.encode("hello world").unwrap();
        assert_eq!(a, b);
    }

    #[test]
    fn mock_encoder_empty_input_errors() {
        let e = MockSparseEncoder::default();
        assert!(matches!(
            e.encode(" ").unwrap_err(),
            SparseError::EmptyInput
        ));
    }

    #[test]
    fn mock_encoder_vocab_id_carries_through() {
        let e = MockSparseEncoder::default();
        let emb = e.encode("hello").unwrap();
        assert_eq!(emb.vocab_id, e.vocab_id());
    }
}
357}