mnem_core/objects/embedding_set.rs
1//! `EmbeddingBucket` - per-node leaf object inside the Prolly sidecar
2//! that lifts the embedding vector out of the
3//! [`Node`](super::Node) canonical bytes.
4//!
5//! # Why this exists
6//!
7//! When the embedding vector lives inline on `Node`:
8//!
9//! ```text
10//! NodeCid = blake3(canonical_bytes(Node)) // includes embed.vector
11//! ```
12//!
//! ORT may sum `f32` values in a different order depending on the
//! thread count (`f32` addition is not associative, so TBB-style
//! work-stealing reductions are not bit-reproducible), so two machines
//! re-deriving the same source text on different core counts produce
//! vectors that differ in the last bit. Different vector → different
//! Node bytes → different `NodeCid` for embed-bearing chunks. That
//! breaks mnem's "two machines indexing the same logical event produce
//! identical Node CIDs" federated-dedup promise as soon as the runtime
//! uses `available_parallelism()` instead of a single thread.
21//!
22//! Fix: vectors live in a separate Prolly tree referenced by
23//! `Commit.embeddings: Option<Cid>` (the sibling slot to
24//! `Commit.indexes`). The tree is keyed by 32-byte `NodeCid` digest;
25//! values are `EmbeddingBucket`s carrying one `(model, Embedding)`
26//! pair per simultaneously-indexed embedder. Identity bytes (Node)
27//! and derived bytes (Embedding) are content-addressed independently.
28//! Multi-thread ORT no longer leaks into Node CIDs.
29//!
30//! # Pattern source
31//!
32//! Mirrors the [`AdjacencyBucket`](super::AdjacencyBucket) shape from
33//! the existing [`IndexSet`](super::IndexSet) sidecar: sorted entry
34//! list inside each leaf, hand-rolled `Serialize`/`Deserialize`
35//! carrying a `_kind` discriminator and a `#[serde(flatten)] extra`
36//! forward-compat carrier so unrelated schema bumps stay
37//! round-trippable.
38
39use std::collections::BTreeMap;
40
41use ipld_core::ipld::Ipld;
42use serde::{Deserialize, Deserializer, Serialize, Serializer};
43
44use super::node::Embedding;
45
/// Per-node bucket of embeddings inside the
/// [`Commit.embeddings`](super::Commit::embeddings) Prolly tree.
///
/// One bucket per node. Each bucket holds a list of
/// `(model, Embedding)` pairs so a node may carry multiple
/// embeddings simultaneously - e.g. one local MiniLM vector plus
/// one OpenAI vector for the same chunk text. Lookups scan the
/// bucket by `model` string after the outer Prolly walk has
/// returned the bucket itself.
///
/// Equality is the derived, order-sensitive comparison of `entries`:
/// two buckets holding the same pairs in a different in-memory order
/// compare unequal even though encoding sorts both into the same
/// entry sequence.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct EmbeddingBucket {
    /// `(model, Embedding)` entries. The in-memory order is
    /// unspecified: `upsert` appends new models at the end and decode
    /// keeps whatever order the wire bytes carried. The canonical,
    /// byte-stable order (lexicographic by `model`) is produced on
    /// every serialize from a sorted copy, so callers may push entries
    /// in any order without breaking CID determinism on the bucket
    /// itself.
    pub entries: Vec<EmbeddingEntry>,
    /// Forward-compat extension carrier. Unknown CBOR fields land
    /// here on decode and are emitted verbatim on re-encode, so a
    /// future schema bump that adds a per-bucket field stays
    /// round-trippable on today's reader.
    pub extra: BTreeMap<String, Ipld>,
}
68
69impl EmbeddingBucket {
70 /// On-wire `_kind` discriminator. Every content-addressed object
71 /// in mnem/0.x carries a `_kind` field as the first canonical key
72 /// so a corrupt bucket / wrong-type decode fails fast with an
73 /// actionable error instead of silently mis-decoding.
74 pub const KIND: &'static str = "embedding_bucket";
75
76 /// Look up an embedding by model string. Returns `None` when this
77 /// bucket has no entry for the requested embedder; the caller
78 /// typically falls back to lazy compute via the configured
79 /// embed provider.
80 #[must_use]
81 pub fn get(&self, model: &str) -> Option<&Embedding> {
82 self.entries
83 .iter()
84 .find(|e| e.model == model)
85 .map(|e| &e.embedding)
86 }
87
88 /// Insert or replace an entry by `model`. Returns the previous
89 /// embedding for that model when one existed (so callers can
90 /// detect a refresh vs first write).
91 pub fn upsert(&mut self, model: String, embedding: Embedding) -> Option<Embedding> {
92 if let Some(slot) = self.entries.iter_mut().find(|e| e.model == model) {
93 return Some(std::mem::replace(&mut slot.embedding, embedding));
94 }
95 self.entries.push(EmbeddingEntry { model, embedding });
96 None
97 }
98
99 /// Remove an entry by `model`. Returns the removed embedding when
100 /// one existed.
101 pub fn remove(&mut self, model: &str) -> Option<Embedding> {
102 let i = self.entries.iter().position(|e| e.model == model)?;
103 Some(self.entries.remove(i).embedding)
104 }
105}
106
/// One `(model, Embedding)` pair inside an [`EmbeddingBucket`].
///
/// Kept as a separate type rather than a tuple so future schema bumps
/// can add per-entry fields (provenance, deprecation, signature) under
/// the same canonical-form contract every other mnem object uses.
///
/// Serialization is the plain serde derive: the two fields are written
/// under their Rust names, `model` and `embedding`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EmbeddingEntry {
    /// Embedder identifier. Conventionally `"<provider>:<model>"`
    /// (matches the `model` string inside `Embedding`); the bucket
    /// indexes on this exact string for `get` / `upsert` / `remove`.
    /// `upsert` stores the key as given and does not compare it
    /// against `embedding.model`, so keeping the two in sync is the
    /// caller's job.
    pub model: String,
    /// The embedding vector and dim/dtype metadata. Its own
    /// `validate()` invariant (`vector.len() == dim * dtype.byte_width()`)
    /// is enforced where embeddings are constructed; the bucket does
    /// not re-validate on decode (cheap-decode contract). Untrusted
    /// callers (HTTP / MCP / replication) are expected to call
    /// `Embedding::validate()` themselves before storing.
    pub embedding: Embedding,
}
126
127// ---------------- Serde wire shape ----------------
128//
129// Same hand-rolled pattern as `Node`/`Commit`/`AdjacencyBucket`:
130// internal `*Wire` mirror with explicit field defaults +
131// `_kind` discriminator + `extra` flatten. Encode sorts entries by
132// `model` so bucket bytes (and therefore the bucket CID) are
133// independent of insertion order. Decode rejects wrong `_kind`
134// values up front.
135
/// Private serde mirror of [`EmbeddingBucket`]: the same payload plus
/// the `_kind` discriminator. The hand-written `Serialize` /
/// `Deserialize` impls for `EmbeddingBucket` convert to and from this
/// shape so the derive does the field-level work.
#[derive(Serialize, Deserialize)]
struct EmbeddingBucketWire {
    /// Type discriminator, stored under the `_kind` key. It has no
    /// default, so input without `_kind` fails to decode; the value
    /// itself is checked by the `EmbeddingBucket` decoder.
    #[serde(rename = "_kind")]
    kind: String,
    /// Bucket entries. `default` lets input that omits the field
    /// decode as an empty list.
    #[serde(default)]
    entries: Vec<EmbeddingEntry>,
    /// Catch-all for every key other than `_kind` / `entries`.
    /// `flatten` inlines these keys at the top level of the map on
    /// encode and collects unrecognised keys here on decode; an empty
    /// map is skipped on encode.
    #[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
    extra: BTreeMap<String, Ipld>,
}
145
146impl Serialize for EmbeddingBucket {
147 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
148 // Canonical order: sorted by `model`. We clone-then-sort
149 // rather than mutate the borrowed field - the public API
150 // does not promise any particular insertion order, but the
151 // wire form is contract-bound to be deterministic.
152 let mut sorted = self.entries.clone();
153 sorted.sort_by(|a, b| a.model.cmp(&b.model));
154 EmbeddingBucketWire {
155 kind: Self::KIND.into(),
156 entries: sorted,
157 extra: self.extra.clone(),
158 }
159 .serialize(serializer)
160 }
161}
162
163impl<'de> Deserialize<'de> for EmbeddingBucket {
164 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
165 let w = EmbeddingBucketWire::deserialize(deserializer)?;
166 if w.kind != Self::KIND {
167 return Err(serde::de::Error::custom(format!(
168 "expected _kind='{}', got '{}'",
169 Self::KIND,
170 w.kind
171 )));
172 }
173 Ok(Self {
174 entries: w.entries,
175 extra: w.extra,
176 })
177 }
178}
179
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, to_canonical_bytes};
    use crate::objects::node::Dtype;

    /// Build a deterministic dummy embedding for `model`: `dim` f32
    /// slots, all zero bytes. The byte length is derived from `dim`
    /// and the dtype width, so the value stays well-formed while
    /// remaining independent of any real embedder. Two calls with
    /// different `dim` yield embeddings that compare unequal.
    fn sample_embedding(model: &str, dim: u32) -> Embedding {
        let bytes_len = (dim as usize) * Dtype::F32.byte_width();
        Embedding {
            model: model.into(),
            dtype: Dtype::F32,
            dim,
            vector: bytes::Bytes::from(vec![0u8; bytes_len]),
        }
    }

    #[test]
    fn empty_bucket_round_trips() {
        let original = EmbeddingBucket::default();
        let bytes = to_canonical_bytes(&original).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(original, decoded);
        let bytes2 = to_canonical_bytes(&decoded).unwrap();
        assert_eq!(bytes, bytes2, "round-trip must be byte-identical");
    }

    #[test]
    fn populated_bucket_round_trips() {
        let mut bucket = EmbeddingBucket::default();
        bucket.upsert(
            "openai:text-embedding-3-small".into(),
            sample_embedding("openai:text-embedding-3-small", 1536),
        );
        bucket.upsert(
            "onnx:all-MiniLM-L6-v2".into(),
            sample_embedding("onnx:all-MiniLM-L6-v2", 384),
        );
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.entries.len(), decoded.entries.len());
        // Decoded copy is canonical (sorted by model). Sort the
        // original before equating because the public API does not
        // promise input order is preserved.
        let mut sorted_orig = bucket.entries.clone();
        sorted_orig.sort_by(|a, b| a.model.cmp(&b.model));
        assert_eq!(sorted_orig, decoded.entries);
    }

    #[test]
    fn wire_form_sorts_by_model_regardless_of_insert_order() {
        // Insert in z-then-a order; canonical bytes must equal the
        // alphabetical order's canonical bytes.
        let mut a = EmbeddingBucket::default();
        a.upsert("zzz".into(), sample_embedding("zzz", 4));
        a.upsert("aaa".into(), sample_embedding("aaa", 4));
        let mut b = EmbeddingBucket::default();
        b.upsert("aaa".into(), sample_embedding("aaa", 4));
        b.upsert("zzz".into(), sample_embedding("zzz", 4));
        assert_eq!(
            to_canonical_bytes(&a).unwrap(),
            to_canonical_bytes(&b).unwrap(),
            "encode must sort entries by model so bucket CIDs are insertion-order-invariant"
        );
    }

    #[test]
    fn wrong_kind_fails_decode() {
        // Manually craft a CBOR map with `_kind = "node"` and verify
        // the EmbeddingBucket decoder rejects it. Uses
        // serde_ipld_dagcbor::to_vec on a small inline struct rather
        // than hand-encoding bytes.
        #[derive(Serialize)]
        struct Wrong {
            #[serde(rename = "_kind")]
            kind: String,
            entries: Vec<EmbeddingEntry>,
        }
        let bytes = serde_ipld_dagcbor::to_vec(&Wrong {
            kind: "node".into(),
            entries: vec![],
        })
        .unwrap();
        let res: Result<EmbeddingBucket, _> = from_canonical_bytes(&bytes);
        assert!(res.is_err(), "decode must reject wrong _kind discriminator");
        let msg = res.unwrap_err().to_string();
        assert!(
            msg.contains("embedding_bucket"),
            "error must reference the expected kind; got: {msg}"
        );
    }

    #[test]
    fn upsert_returns_previous_value_on_replace() {
        let mut bucket = EmbeddingBucket::default();
        // The two embeddings must differ; with identical values the
        // assertions below could not tell "returned the old value"
        // from "returned the new value" or "never replaced the slot".
        let first = sample_embedding("m", 4);
        let second = sample_embedding("m", 8);
        assert_ne!(first, second, "fixture embeddings must be distinguishable");
        assert_eq!(bucket.upsert("m".into(), first.clone()), None);
        assert_eq!(bucket.upsert("m".into(), second.clone()), Some(first));
        assert_eq!(
            bucket.get("m"),
            Some(&second),
            "replace must store the new embedding"
        );
        assert_eq!(
            bucket.entries.len(),
            1,
            "replace must not add a second entry for the same model"
        );
    }

    #[test]
    fn get_finds_inserted_entry() {
        let mut bucket = EmbeddingBucket::default();
        let emb = sample_embedding("m", 4);
        bucket.upsert("m".into(), emb.clone());
        assert_eq!(bucket.get("m"), Some(&emb));
        assert_eq!(bucket.get("missing"), None);
    }

    #[test]
    fn remove_removes_existing_entry() {
        let mut bucket = EmbeddingBucket::default();
        let emb = sample_embedding("m", 4);
        bucket.upsert("m".into(), emb.clone());
        assert_eq!(bucket.remove("m"), Some(emb));
        assert_eq!(bucket.get("m"), None);
        assert_eq!(bucket.remove("m"), None);
    }

    #[test]
    fn extra_fields_round_trip() {
        // Forward-compat: a future schema bump adding a sidecar field
        // (e.g. `provenance`) on a bucket should round-trip through
        // today's reader. Simulate by manually injecting an `extra`
        // entry, encoding, and asserting the decoded bucket carries it.
        let mut bucket = EmbeddingBucket::default();
        bucket
            .extra
            .insert("future_field".into(), Ipld::String("forward-compat".into()));
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket, decoded, "extra fields must survive round-trip");
        assert!(decoded.extra.contains_key("future_field"));
    }
}
317}