mnem_core/objects/sparse_set.rs
1//! `SparseBucket` - per-node leaf object inside the Prolly sidecar
2//! that lifts the sparse embedding out of the
3//! [`Node`](super::Node) canonical bytes.
4//!
5//! # Why this exists
6//!
7//! When the sparse embedding lives inline on `Node`:
8//!
9//! ```text
10//! NodeCid = blake3(canonical_bytes(Node)) // includes sparse_embed
11//! ```
12//!
//! Different sparse encoders, encoder versions, and vocabularies produce
//! different byte representations for the same text, so two machines
//! indexing the same logical source text with different encoders produce
//! different `NodeCid` values. That breaks mnem's federated-dedup promise.
17//!
18//! Fix: sparse embeddings live in a separate Prolly tree referenced by
19//! `Commit.sparse: Option<Cid>` (the sibling slot to
20//! `Commit.embeddings`). The tree is keyed by 16-byte truncated blake3
21//! of the `NodeCid` wire form; values are `SparseBucket`s carrying one
22//! `(vocab_id, SparseEmbed)` pair per indexed vocabulary. Identity bytes
23//! (Node) and derived bytes (SparseEmbed) are content-addressed
24//! independently. Vocab differences no longer leak into Node CIDs.
25//!
26//! # Pattern source
27//!
28//! Mirrors the [`EmbeddingBucket`](super::EmbeddingBucket) shape from
29//! G16 and the [`AdjacencyBucket`](super::AdjacencyBucket) shape from
30//! the existing [`IndexSet`](super::IndexSet) sidecar: sorted entry
31//! list inside each leaf, hand-rolled `Serialize`/`Deserialize`
32//! carrying a `_kind` discriminator and a `#[serde(flatten)] extra`
33//! forward-compat carrier so unrelated schema bumps stay
34//! round-trippable.
35
36use std::collections::BTreeMap;
37
38use ipld_core::ipld::Ipld;
39use serde::{Deserialize, Deserializer, Serialize, Serializer};
40
41use crate::sparse::SparseEmbed;
42
43/// Per-node bucket of sparse embeddings inside the
44/// [`Commit.sparse`](super::Commit::sparse) Prolly tree.
45///
46/// One bucket per node. Each bucket holds a sorted
47/// `(vocab_id, SparseEmbed)` list so a node may carry multiple
48/// sparse embeddings simultaneously - e.g. one BGE-M3 vector plus
49/// one OpenSearch-distill vector for the same chunk text. Lookups
50/// index into the bucket by `vocab_id` string after the outer Prolly
51/// walk has returned the bucket itself.
52#[derive(Clone, Debug, Default)]
53pub struct SparseBucket {
54 /// Entries sorted lexicographically by `vocab_id` for byte-stable
55 /// canonical form. The sort is enforced on every serialize, so
56 /// callers may push entries in any order without breaking CID
57 /// determinism on the bucket itself.
58 pub entries: Vec<SparseEntry>,
59 /// Forward-compat extension carrier. Unknown CBOR fields land
60 /// here on decode and are emitted verbatim on re-encode, so a
61 /// future schema bump that adds a per-bucket field stays
62 /// round-trippable on today's reader.
63 pub extra: BTreeMap<String, Ipld>,
64}
65
66impl SparseBucket {
67 /// On-wire `_kind` discriminator. Every content-addressed object
68 /// in mnem/0.x carries a `_kind` field as the first canonical key
69 /// so a corrupt bucket / wrong-type decode fails fast with an
70 /// actionable error instead of silently mis-decoding.
71 pub const KIND: &'static str = "sparse_bucket";
72
73 /// Look up a sparse embedding by vocab_id string. Returns `None` when
74 /// this bucket has no entry for the requested vocabulary; the caller
75 /// typically falls back to lazy compute via the configured sparse
76 /// encoder adapter.
77 #[must_use]
78 pub fn get(&self, vocab_id: &str) -> Option<&SparseEmbed> {
79 self.entries
80 .iter()
81 .find(|e| e.vocab_id == vocab_id)
82 .map(|e| &e.sparse)
83 }
84
85 /// Insert or replace an entry by `vocab_id`. The bucket does not
86 /// return the previous value (unlike `EmbeddingBucket::upsert`)
87 /// because `SparseEmbed` contains `Vec<f32>` which is not `PartialEq`
88 /// in a meaningful sense for callers here.
89 pub fn upsert(&mut self, vocab_id: String, sparse: SparseEmbed) {
90 if let Some(slot) = self.entries.iter_mut().find(|e| e.vocab_id == vocab_id) {
91 slot.sparse = sparse;
92 return;
93 }
94 self.entries.push(SparseEntry { vocab_id, sparse });
95 }
96
97 /// Remove an entry by `vocab_id`.
98 pub fn remove(&mut self, vocab_id: &str) {
99 if let Some(i) = self.entries.iter().position(|e| e.vocab_id == vocab_id) {
100 self.entries.remove(i);
101 }
102 }
103}
104
105/// One `(vocab_id, SparseEmbed)` pair inside a [`SparseBucket`].
106///
107/// Kept as a separate type rather than a tuple so future schema bumps
108/// can add per-entry fields (provenance, deprecation, signature) under
109/// the same canonical-form contract every other mnem object uses.
110#[derive(Clone, Debug, Serialize, Deserialize)]
111pub struct SparseEntry {
112 /// Vocabulary identifier. Conventionally a short string identifying
113 /// the encoder and vocab (e.g. `"bge-m3"`, `"opensearch-distill-v3"`).
114 /// The bucket indexes on this exact string for `get` / `upsert` / `remove`.
115 pub vocab_id: String,
116 /// The sparse embedding produced by the encoder. Contains `indices`
117 /// (token ids) and `values` (non-zero weights) alongside `vocab_id`.
118 pub sparse: SparseEmbed,
119}
120
121// ---------------- Serde wire shape ----------------
122//
123// Same hand-rolled pattern as `Node`/`Commit`/`EmbeddingBucket`:
124// internal `*Wire` mirror with explicit field defaults +
125// `_kind` discriminator + `extra` flatten. Encode sorts entries by
126// `vocab_id` so bucket bytes (and therefore the bucket CID) are
127// independent of insertion order. Decode rejects wrong `_kind`
128// values up front.
129
130#[derive(Serialize, Deserialize)]
131struct SparseBucketWire {
132 #[serde(rename = "_kind")]
133 kind: String,
134 #[serde(default)]
135 entries: Vec<SparseEntry>,
136 #[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
137 extra: BTreeMap<String, Ipld>,
138}
139
140impl Serialize for SparseBucket {
141 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
142 // Canonical order: sorted by `vocab_id`. We clone-then-sort
143 // rather than mutate the borrowed field - the public API
144 // does not promise any particular insertion order, but the
145 // wire form is contract-bound to be deterministic.
146 let mut sorted = self.entries.clone();
147 sorted.sort_by(|a, b| a.vocab_id.cmp(&b.vocab_id));
148 SparseBucketWire {
149 kind: Self::KIND.into(),
150 entries: sorted,
151 extra: self.extra.clone(),
152 }
153 .serialize(serializer)
154 }
155}
156
157impl<'de> Deserialize<'de> for SparseBucket {
158 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
159 let w = SparseBucketWire::deserialize(deserializer)?;
160 if w.kind != Self::KIND {
161 return Err(serde::de::Error::custom(format!(
162 "expected _kind='{}', got '{}'",
163 Self::KIND,
164 w.kind
165 )));
166 }
167 Ok(Self {
168 entries: w.entries,
169 extra: w.extra,
170 })
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, to_canonical_bytes};
    use crate::sparse::SparseEmbed;

    /// Fixed three-token embedding tagged with `vocab_id`.
    fn sample_sparse(vocab_id: &str) -> SparseEmbed {
        SparseEmbed::new(vec![1, 5, 9], vec![0.5, 0.2, 0.1], vocab_id).unwrap()
    }

    #[test]
    fn empty_bucket_round_trips() {
        let original = SparseBucket::default();
        let bytes = to_canonical_bytes(&original).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(original.entries.len(), decoded.entries.len());
        let bytes2 = to_canonical_bytes(&decoded).unwrap();
        assert_eq!(bytes, bytes2, "round-trip must be byte-identical");
    }

    #[test]
    fn populated_bucket_round_trips() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("bge-m3".into(), sample_sparse("bge-m3"));
        bucket.upsert(
            "opensearch-distill".into(),
            sample_sparse("opensearch-distill"),
        );
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.entries.len(), decoded.entries.len());
        // A length match alone would not notice a decode that dropped
        // or mixed up payloads, so check each entry's content too.
        for vocab_id in ["bge-m3", "opensearch-distill"] {
            let entry = decoded
                .get(vocab_id)
                .expect("every encoded vocab_id must survive the round-trip");
            assert_eq!(entry.vocab_id, vocab_id);
            assert_eq!(entry.indices, vec![1, 5, 9]);
        }
    }

    #[test]
    fn wire_form_sorts_by_vocab_id_regardless_of_insert_order() {
        // Insert in z-then-a order; canonical bytes must equal the
        // alphabetical order's canonical bytes.
        let mut a = SparseBucket::default();
        a.upsert("zzz".into(), sample_sparse("zzz"));
        a.upsert("aaa".into(), sample_sparse("aaa"));
        let mut b = SparseBucket::default();
        b.upsert("aaa".into(), sample_sparse("aaa"));
        b.upsert("zzz".into(), sample_sparse("zzz"));
        assert_eq!(
            to_canonical_bytes(&a).unwrap(),
            to_canonical_bytes(&b).unwrap(),
            "encode must sort entries by vocab_id so bucket CIDs are insertion-order-invariant"
        );
    }

    #[test]
    fn wrong_kind_fails_decode() {
        // Same shape as the real wire struct, but with a foreign
        // `_kind` value.
        #[derive(Serialize)]
        struct Wrong {
            #[serde(rename = "_kind")]
            kind: String,
            entries: Vec<SparseEntry>,
        }
        let bytes = serde_ipld_dagcbor::to_vec(&Wrong {
            kind: "node".into(),
            entries: vec![],
        })
        .unwrap();
        let res: Result<SparseBucket, _> = from_canonical_bytes(&bytes);
        let err = res.expect_err("decode must reject wrong _kind discriminator");
        let msg = err.to_string();
        assert!(
            msg.contains("sparse_bucket"),
            "error must reference the expected kind; got: {msg}"
        );
    }

    #[test]
    fn upsert_overwrites_existing_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("v0".into(), sample_sparse("v0"));
        // Upsert again with different data.
        let new_sparse = SparseEmbed::new(vec![100], vec![0.9], "v0").unwrap();
        bucket.upsert("v0".into(), new_sparse);
        // Only one entry should exist.
        assert_eq!(bucket.entries.len(), 1);
        assert_eq!(bucket.get("v0").unwrap().indices, vec![100]);
    }

    #[test]
    fn get_finds_inserted_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("v0".into(), sample_sparse("v0"));
        assert_eq!(bucket.get("v0").unwrap().vocab_id, "v0");
        assert!(bucket.get("missing").is_none());
    }

    #[test]
    fn remove_drops_only_the_named_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("keep".into(), sample_sparse("keep"));
        bucket.upsert("drop".into(), sample_sparse("drop"));
        bucket.remove("drop");
        assert!(bucket.get("drop").is_none());
        assert!(bucket.get("keep").is_some());
        assert_eq!(bucket.entries.len(), 1);
        // Removing an id that is not present leaves the bucket alone.
        bucket.remove("missing");
        assert_eq!(bucket.entries.len(), 1);
    }

    #[test]
    fn extra_fields_round_trip() {
        let mut bucket = SparseBucket::default();
        bucket
            .extra
            .insert("future_field".into(), Ipld::String("forward-compat".into()));
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.extra.len(), decoded.extra.len());
        // The value, not just the key, has to come back unchanged.
        assert_eq!(
            decoded.extra.get("future_field"),
            Some(&Ipld::String("forward-compat".into()))
        );
    }
}