mnem_core/sparse.rs
1// SPLADE, BGE-M3, BEIR, WordPiece, OpenSearch are well-known external
2// identifiers; backticking every mention in the module doc degrades
3// rendered rustdoc readability.
4#![allow(clippy::doc_markdown)]
5
//! Sparse (learned) embedding primitives for SPLADE / BGE-M3-sparse
//! integration.
8//!
9//! # Why
10//!
11//! Learned-sparse retrievers (SPLADE v3, opensearch-doc-v3-distill,
12//! BGE-M3-sparse, granite-embedding-30m-sparse) produce a sparse
13//! vector over a WordPiece vocabulary that can be scored via an
14//! inverted index with semantic term weights learned end-to-end.
//! BEIR nDCG@10 on sparse neural retrievers lands around +3-5 points
//! over classical lexical keyword scoring on zero-shot domains; this
//! lane replaces that legacy lexical lane entirely.
18//!
19//! # What this module provides
20//!
21//! - [`SparseEmbed`] - canonical sparse-vector shape (ascending
22//! `indices` + aligned `values`) with a `vocab_id` tag so two
23//! models with different vocabularies never get mixed in one
24//! posting list.
25//! - [`SparseEncoder`] trait - adapter-side hook for ONNX / candle
26//! backends to implement. Mirrors the [`crate::rerank::Reranker`]
27//! trait shape.
28//! - `MockSparseEncoder` - deterministic test-only encoder.
29//!
30//! The actual inverted-index over `SparseEmbed` values lives in
31//! [`crate::index::sparse`] so the index stays next to its sibling
32//! (brute-force vector index).
33//!
34//! Storage in [`crate::objects::Node`]: a future `Node.sparse_embed:
35//! Option<SparseEmbed>` field. Additive, so existing CIDs stay
36//! byte-identical because the serializer omits `None` via
37//! `skip_serializing_if`. CBOR canonicality is preserved because
38//! `indices` is sorted ascending at construction (checked by
39//! [`SparseEmbed::new`]).
40
41use std::fmt::Debug;
42
43use serde::{Deserialize, Serialize};
44use thiserror::Error;
45
46/// Error surface for sparse-encoder adapters. Same shape as
47/// [`crate::llm::LlmError`] and [`crate::rerank::RerankError`].
48#[derive(Debug, Error)]
49#[non_exhaustive]
50pub enum SparseError {
51 /// Network / transport failure when the adapter runs remotely
52 /// (sidecar) or fetches weights.
53 #[error("network error: {0}")]
54 Network(String),
55 /// Adapter config invalid (missing weights file, bad URL, etc.).
56 #[error("config error: {0}")]
57 Config(String),
58 /// Model / tokenizer returned an error.
59 #[error("inference error: {0}")]
60 Inference(String),
61 /// Caller attempted to encode empty text.
62 #[error("empty input")]
63 EmptyInput,
64}
65
66/// A sparse embedding over a fixed vocabulary.
67///
68/// `indices` MUST be strictly ascending; `values` MUST have the same
69/// length as `indices`. Both invariants are checked by [`Self::new`]
70/// and enforced on deserialise in a future CBOR round-trip test.
71/// `vocab_id` pins the model family so two adapters with different
72/// vocabs never fuse posting lists; compare as a string (e.g.
73/// `"bert-base-uncased@30522"`).
74#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
75pub struct SparseEmbed {
76 /// Token IDs in the vocabulary, strictly ascending.
77 pub indices: Vec<u32>,
78 /// Non-zero weights, aligned with `indices`.
79 pub values: Vec<f32>,
80 /// Vocabulary identifier.
81 pub vocab_id: String,
82}
83
84impl SparseEmbed {
85 /// Construct a [`SparseEmbed`]. Panics (debug) / errors (release)
86 /// if the invariants are violated. `indices` is taken as-is; if
87 /// the caller is unsure whether it is sorted, use
88 /// [`Self::from_unsorted`] instead.
89 ///
90 /// # Errors
91 ///
92 /// - [`SparseError::Config`] if `indices.len() != values.len()`
93 /// or `indices` contains duplicates / non-ascending entries.
94 pub fn new(
95 indices: Vec<u32>,
96 values: Vec<f32>,
97 vocab_id: impl Into<String>,
98 ) -> Result<Self, SparseError> {
99 if indices.len() != values.len() {
100 return Err(SparseError::Config(format!(
101 "indices.len() {} != values.len() {}",
102 indices.len(),
103 values.len()
104 )));
105 }
106 for w in indices.windows(2) {
107 if w[0] >= w[1] {
108 return Err(SparseError::Config(format!(
109 "indices must be strictly ascending; saw {} then {}",
110 w[0], w[1]
111 )));
112 }
113 }
114 Ok(Self {
115 indices,
116 values,
117 vocab_id: vocab_id.into(),
118 })
119 }
120
121 /// Construct from an unsorted `(index, value)` pair list.
122 /// Duplicate indices are kept as the maximum value (SPLADE's
123 /// own pooling rule). Useful from ONNX-side decoders that
124 /// produce vectors in token-emission order.
125 pub fn from_unsorted(
126 pairs: impl IntoIterator<Item = (u32, f32)>,
127 vocab_id: impl Into<String>,
128 ) -> Self {
129 use std::collections::BTreeMap;
130 let mut bucket: BTreeMap<u32, f32> = BTreeMap::new();
131 for (i, v) in pairs {
132 let e = bucket.entry(i).or_insert(f32::NEG_INFINITY);
133 if v > *e {
134 *e = v;
135 }
136 }
137 let (indices, values): (Vec<_>, Vec<_>) =
138 bucket.into_iter().filter(|(_, v)| *v > 0.0).unzip();
139 Self {
140 indices,
141 values,
142 vocab_id: vocab_id.into(),
143 }
144 }
145
146 /// Validate the two invariants that [`SparseEmbed::new`] enforces at
147 /// construction time:
148 ///
149 /// 1. `indices.len() == values.len()` (aligned parallel arrays)
150 /// 2. `indices` is strictly ascending (no duplicates, no out-of-order)
151 ///
152 /// Call this after deserialising from untrusted bytes (e.g. IPLD round-
153 /// trips from a pre-G17 `extra["sparse_embed"]` node) to catch corrupt
154 /// data before it reaches the sparse sidecar.
155 ///
156 /// # Errors
157 ///
158 /// Returns a descriptive `String` (via `Err(String)`) if either
159 /// invariant is violated so callers can surface the message in a
160 /// warning without pulling in a new error type.
161 pub fn validate(&self) -> Result<(), String> {
162 if self.indices.len() != self.values.len() {
163 return Err(format!(
164 "SparseEmbed invariant violated: indices.len() {} != values.len() {}",
165 self.indices.len(),
166 self.values.len()
167 ));
168 }
169 for w in self.indices.windows(2) {
170 if w[0] >= w[1] {
171 return Err(format!(
172 "SparseEmbed invariant violated: indices must be strictly ascending; \
173 saw {} then {}",
174 w[0], w[1]
175 ));
176 }
177 }
178 Ok(())
179 }
180
181 /// Number of non-zero entries.
182 #[must_use]
183 pub const fn nnz(&self) -> usize {
184 self.indices.len()
185 }
186
187 /// Dot product with another sparse embedding (must share vocab_id).
188 /// Returns `None` if the vocab_ids differ.
189 #[must_use]
190 pub fn dot(&self, other: &Self) -> Option<f32> {
191 if self.vocab_id != other.vocab_id {
192 return None;
193 }
194 let mut i = 0;
195 let mut j = 0;
196 let mut sum = 0.0f32;
197 while i < self.indices.len() && j < other.indices.len() {
198 use std::cmp::Ordering;
199 match self.indices[i].cmp(&other.indices[j]) {
200 Ordering::Less => i += 1,
201 Ordering::Greater => j += 1,
202 Ordering::Equal => {
203 sum += self.values[i] * other.values[j];
204 i += 1;
205 j += 1;
206 }
207 }
208 }
209 Some(sum)
210 }
211}
212
213/// Learned-sparse encoder: given text, produce a [`SparseEmbed`]
214/// over a fixed vocabulary. Adapter crates implement this over
215/// SPLADE-ONNX, BGE-M3-sparse-ONNX, or a remote sidecar.
216pub trait SparseEncoder: Send + Sync + Debug {
217 /// Provider + model identifier. Lowercase, colon-separated by
218 /// convention (e.g. `"splade:opensearch-doc-v3-distill"`,
219 /// `"bgem3:sparse"`, `"mock:len-inverse"`).
220 fn model(&self) -> &str;
221
222 /// Vocabulary identifier. Passed through to
223 /// [`SparseEmbed::vocab_id`] on every emitted embedding.
224 fn vocab_id(&self) -> &str;
225
226 /// Encode a document-side text string into a sparse vector.
227 /// This is the path run at ingest time.
228 ///
229 /// # Errors
230 ///
231 /// Any [`SparseError`] the adapter surfaces. The caller fallback
232 /// policy matches the rerank / LLM pattern: on error, the sparse
233 /// lane is simply dropped from fusion and the hybrid still runs.
234 fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError>;
235
236 /// Encode a query-side text string into a sparse vector.
237 ///
238 /// Default implementation delegates to [`Self::encode`]. Adapters
239 /// with asymmetric inference (OpenSearch
240 /// `neural-sparse-encoding-doc-v3-distill` ships a distilled
241 /// `idf.json` table so the query side is tokenise + IDF-lookup
242 /// with zero neural compute) override this to skip the forward
243 /// pass. The overridden path keeps retrieval latency microsecond-
244 /// level even when documents use a 67M-parameter encoder.
245 fn encode_query(&self, text: &str) -> Result<SparseEmbed, SparseError> {
246 self.encode(text)
247 }
248}
249
/// FNV 32-bit offset basis. Standard value from the FNV specification
/// (Fowler-Noll-Vo hash, Landon Curt Noll, 1991); see
/// <http://www.isthe.com/chongo/tech/comp/fnv/>. Used as the seed
/// state for the `MockSparseEncoder` token hash.
const FNV_OFFSET_BASIS_32: u32 = 2_166_136_261;

/// FNV 32-bit prime. Standard value from the FNV specification
/// (<http://www.isthe.com/chongo/tech/comp/fnv/>).
///
/// The `MockSparseEncoder` token hash multiplies its running state by
/// this prime and then *adds* the next input byte, both with wrapping
/// arithmetic. That mixing step borrows the FNV constants but is not
/// canonical FNV-1a, which XORs the byte into the state before
/// multiplying, so the resulting hashes do not match a reference
/// FNV-1a implementation.
const FNV_PRIME_32: u32 = 16_777_619;

/// Mock vocabulary width. The token hash is reduced modulo this
/// number of slots so the encoder emits indices in `0..1024`,
/// matching the `"mock:1024"` `vocab_id` tag. Kept tiny so tests
/// exercise collision handling cheaply.
const MOCK_VOCAB_SIZE: u32 = 1024;
267
/// Test-only encoder with fully deterministic output.
///
/// Each whitespace-separated token is hashed into one of the first 1024
/// vocabulary slots and weighted by `1.0 / (1.0 + token_len)`, so longer
/// tokens get smaller weights.
///
/// This is not a real SPLADE model and must not be used in benchmarks.
/// It exists so `Retriever::with_sparse_ranker(...)` can unit-test the
/// fusion lane without adding ONNX Runtime to `mnem-core`'s test
/// dependencies.
#[derive(Clone, Debug)]
pub struct MockSparseEncoder {
    /// Vocabulary tag reported by the encoder and stamped on its output.
    vocab_id: String,
}

impl Default for MockSparseEncoder {
    /// Build the mock with its standard `"mock:1024"` vocabulary tag.
    fn default() -> Self {
        let vocab_id = String::from("mock:1024");
        Self { vocab_id }
    }
}
288
289impl SparseEncoder for MockSparseEncoder {
290 fn model(&self) -> &str {
291 "mock:len-inverse"
292 }
293
294 fn vocab_id(&self) -> &str {
295 &self.vocab_id
296 }
297
298 fn encode(&self, text: &str) -> Result<SparseEmbed, SparseError> {
299 if text.trim().is_empty() {
300 return Err(SparseError::EmptyInput);
301 }
302 let pairs = text.split_whitespace().map(|tok| {
303 let h: u32 = tok.bytes().fold(FNV_OFFSET_BASIS_32, |acc, b| {
304 acc.wrapping_mul(FNV_PRIME_32).wrapping_add(u32::from(b))
305 });
306 let idx = h % MOCK_VOCAB_SIZE;
307 let weight = 1.0f32 / (1.0 + tok.len() as f32);
308 (idx, weight)
309 });
310 Ok(SparseEmbed::from_unsorted(pairs, &self.vocab_id))
311 }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed three-entry embedding shared by the `validate()` tests.
    fn well_formed() -> SparseEmbed {
        SparseEmbed::new(vec![10, 42, 99], vec![0.8, 0.5, 0.2], "vocab").unwrap()
    }

    // --- SparseEmbed::new ---

    #[test]
    fn sparse_embed_rejects_unsorted_indices() {
        let outcome = SparseEmbed::new(vec![5, 3], vec![0.5, 0.5], "v0");
        assert!(matches!(outcome, Err(SparseError::Config(_))));
    }

    #[test]
    fn sparse_embed_rejects_length_mismatch() {
        let outcome = SparseEmbed::new(vec![1, 2], vec![0.5], "v0");
        assert!(matches!(outcome, Err(SparseError::Config(_))));
    }

    // --- SparseEmbed::from_unsorted ---

    #[test]
    fn from_unsorted_sorts_and_max_pools() {
        let pairs = [(5, 0.1), (3, 0.9), (5, 0.3), (1, 0.2)];
        let pooled = SparseEmbed::from_unsorted(pairs, "v0");
        assert_eq!(pooled.indices, vec![1, 3, 5]);
        let weight_for_5 = pooled.values[2];
        assert!(
            (weight_for_5 - 0.3).abs() < 1e-6,
            "max-pool should keep 0.3 for index 5"
        );
    }

    #[test]
    fn from_unsorted_drops_zero_weights() {
        let pairs = [(1, 0.0), (2, 0.5), (3, -0.1)];
        let kept = SparseEmbed::from_unsorted(pairs, "v0");
        assert_eq!(kept.indices, vec![2]);
    }

    // --- SparseEmbed::dot ---

    #[test]
    fn dot_product_on_disjoint_is_zero() {
        let left = SparseEmbed::new(vec![1, 2], vec![1.0, 1.0], "v").unwrap();
        let right = SparseEmbed::new(vec![3, 4], vec![1.0, 1.0], "v").unwrap();
        assert_eq!(left.dot(&right), Some(0.0));
    }

    #[test]
    fn dot_product_on_overlap() {
        let left = SparseEmbed::new(vec![1, 2, 5], vec![0.5, 0.5, 0.2], "v").unwrap();
        let right = SparseEmbed::new(vec![2, 5, 9], vec![0.4, 0.3, 0.1], "v").unwrap();
        // Shared indices: 2 contributes 0.5*0.4=0.2, 5 contributes
        // 0.2*0.3=0.06, for a total of 0.26.
        let d = left.dot(&right).unwrap();
        assert!((d - 0.26).abs() < 1e-6, "got {d}");
    }

    #[test]
    fn dot_product_different_vocabs_is_none() {
        let left = SparseEmbed::new(vec![1], vec![1.0], "v0").unwrap();
        let right = SparseEmbed::new(vec![1], vec![1.0], "v1").unwrap();
        assert_eq!(left.dot(&right), None);
    }

    // --- MockSparseEncoder ---

    #[test]
    fn mock_encoder_is_deterministic() {
        let encoder = MockSparseEncoder::default();
        let first = encoder.encode("hello world").unwrap();
        let second = encoder.encode("hello world").unwrap();
        assert_eq!(first, second);
    }

    #[test]
    fn mock_encoder_empty_input_errors() {
        let encoder = MockSparseEncoder::default();
        let outcome = encoder.encode(" ");
        assert!(matches!(outcome, Err(SparseError::EmptyInput)));
    }

    #[test]
    fn mock_encoder_vocab_id_carries_through() {
        let encoder = MockSparseEncoder::default();
        let emb = encoder.encode("hello").unwrap();
        assert_eq!(emb.vocab_id, encoder.vocab_id());
    }

    // --- SparseEmbed::validate() unit tests ---

    #[test]
    fn validate_ok_on_valid_sparse_embed() {
        assert!(well_formed().validate().is_ok());
    }

    #[test]
    fn validate_rejects_non_ascending_indices() {
        // The fields are public, so a valid value can be corrupted in
        // place after construction.
        let mut corrupted = well_formed();
        corrupted.indices = vec![99, 10, 42]; // out of order
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject non-ascending indices"
        );
    }

    #[test]
    fn validate_rejects_duplicate_indices() {
        let mut corrupted = well_formed();
        corrupted.indices = vec![10, 10, 99]; // repeated 10 breaks strict ordering
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject duplicate indices"
        );
    }

    #[test]
    fn validate_rejects_mismatched_lengths() {
        let mut corrupted = well_formed();
        corrupted.values.push(0.1); // four values against three indices
        assert!(
            corrupted.validate().is_err(),
            "validate() must reject mismatched indices/values lengths"
        );
    }
}
432}