use std::collections::BTreeMap;
use ipld_core::ipld::Ipld;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use super::node::Embedding;
/// A set of embeddings looked up by model identifier, plus any additional
/// wire fields that are carried along untouched.
///
/// The derived `PartialEq` compares `entries` in `Vec` order, whereas the
/// `Serialize` impl in this file emits them sorted by `model`. Two buckets
/// holding the same entries in a different order therefore compare unequal
/// even though they serialize identically.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct EmbeddingBucket {
/// Stored entries, in insertion order (`upsert` appends new models at the end).
pub entries: Vec<EmbeddingEntry>,
/// Keys from the wire map other than the named fields; they are flattened
/// back into the same map on serialization.
pub extra: BTreeMap<String, Ipld>,
}
impl EmbeddingBucket {
    /// Discriminator written to the `_kind` wire field and required on decode.
    pub const KIND: &'static str = "embedding_bucket";

    /// Returns the embedding of the first entry whose model equals `model`,
    /// or `None` when no entry matches.
    #[must_use]
    pub fn get(&self, model: &str) -> Option<&Embedding> {
        for entry in &self.entries {
            if entry.model == model {
                return Some(&entry.embedding);
            }
        }
        None
    }

    /// Stores `embedding` under `model`.
    ///
    /// If an entry for `model` already exists, its embedding is swapped out
    /// and the previous one is returned. Otherwise a new entry is appended
    /// and `None` is returned.
    pub fn upsert(&mut self, model: String, embedding: Embedding) -> Option<Embedding> {
        for entry in &mut self.entries {
            if entry.model == model {
                let previous = std::mem::replace(&mut entry.embedding, embedding);
                return Some(previous);
            }
        }
        self.entries.push(EmbeddingEntry { model, embedding });
        None
    }

    /// Removes the first entry whose model equals `model` and returns its
    /// embedding, or `None` when no entry matches.
    ///
    /// The remaining entries keep their relative order.
    pub fn remove(&mut self, model: &str) -> Option<Embedding> {
        let found = self.entries.iter().position(|entry| entry.model == model);
        found.map(|index| self.entries.remove(index).embedding)
    }
}
/// One `(model, embedding)` pair stored in an [`EmbeddingBucket`].
///
/// The serde derives make this the element type of the `entries` list in the
/// bucket's wire form.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct EmbeddingEntry {
/// Model identifier; the key that the bucket's `get`, `upsert` and `remove` match on.
pub model: String,
/// The embedding stored for `model`.
// NOTE(review): `Embedding` carries its own `model` field (the tests in this
// file populate it); nothing here checks that it agrees with the key above.
// Confirm whether that invariant is expected to hold.
pub embedding: Embedding,
}
/// Private serde shape used to encode and decode an [`EmbeddingBucket`].
///
/// It adds the `_kind` discriminator and flattens `extra` into the same map
/// as the named fields.
#[derive(Serialize, Deserialize)]
struct EmbeddingBucketWire {
// Discriminator, stored under the key `_kind`. It has no serde default, so
// a payload without it is expected to fail decoding.
#[serde(rename = "_kind")]
kind: String,
// A missing `entries` key decodes as an empty list.
#[serde(default)]
entries: Vec<EmbeddingEntry>,
// Remaining map keys are collected here on decode and written back inline
// on encode; an empty map is skipped entirely.
#[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
extra: BTreeMap<String, Ipld>,
}
impl Serialize for EmbeddingBucket {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut sorted = self.entries.clone();
sorted.sort_by(|a, b| a.model.cmp(&b.model));
EmbeddingBucketWire {
kind: Self::KIND.into(),
entries: sorted,
extra: self.extra.clone(),
}
.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for EmbeddingBucket {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let w = EmbeddingBucketWire::deserialize(deserializer)?;
if w.kind != Self::KIND {
return Err(serde::de::Error::custom(format!(
"expected _kind='{}', got '{}'",
Self::KIND,
w.kind
)));
}
Ok(Self {
entries: w.entries,
extra: w.extra,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, to_canonical_bytes};
    use crate::objects::node::Dtype;

    /// Builds an `F32` embedding for `model` whose vector is `dim` components
    /// of zero bytes.
    fn sample_embedding(model: &str, dim: u32) -> Embedding {
        let bytes_len = (dim as usize) * Dtype::F32.byte_width();
        Embedding {
            model: model.into(),
            dtype: Dtype::F32,
            dim,
            vector: bytes::Bytes::from(vec![0u8; bytes_len]),
        }
    }

    /// An empty bucket decodes to an equal value and re-encodes to the same bytes.
    #[test]
    fn empty_bucket_round_trips() {
        let original = EmbeddingBucket::default();
        let bytes = to_canonical_bytes(&original).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(original, decoded);
        let bytes2 = to_canonical_bytes(&decoded).unwrap();
        assert_eq!(bytes, bytes2, "round-trip must be byte-identical");
    }

    /// A bucket with several entries survives a round trip; the decoded
    /// entries come back in model-sorted order.
    #[test]
    fn populated_bucket_round_trips() {
        let mut bucket = EmbeddingBucket::default();
        bucket.upsert(
            "openai:text-embedding-3-small".into(),
            sample_embedding("openai:text-embedding-3-small", 1536),
        );
        bucket.upsert(
            "onnx:all-MiniLM-L6-v2".into(),
            sample_embedding("onnx:all-MiniLM-L6-v2", 384),
        );
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.entries.len(), decoded.entries.len());
        // The in-memory bucket keeps insertion order, so sort before comparing.
        let mut sorted_orig = bucket.entries.clone();
        sorted_orig.sort_by(|a, b| a.model.cmp(&b.model));
        assert_eq!(sorted_orig, decoded.entries);
    }

    /// Insertion order must not leak into the encoded bytes.
    #[test]
    fn wire_form_sorts_by_model_regardless_of_insert_order() {
        let mut a = EmbeddingBucket::default();
        a.upsert("zzz".into(), sample_embedding("zzz", 4));
        a.upsert("aaa".into(), sample_embedding("aaa", 4));
        let mut b = EmbeddingBucket::default();
        b.upsert("aaa".into(), sample_embedding("aaa", 4));
        b.upsert("zzz".into(), sample_embedding("zzz", 4));
        assert_eq!(
            to_canonical_bytes(&a).unwrap(),
            to_canonical_bytes(&b).unwrap(),
            "encode must sort entries by model so bucket CIDs are insertion-order-invariant"
        );
    }

    /// A payload with the right shape but a different `_kind` is rejected.
    #[test]
    fn wrong_kind_fails_decode() {
        #[derive(Serialize)]
        struct Wrong {
            #[serde(rename = "_kind")]
            kind: String,
            entries: Vec<EmbeddingEntry>,
        }
        let bytes = serde_ipld_dagcbor::to_vec(&Wrong {
            kind: "node".into(),
            entries: vec![],
        })
        .unwrap();
        let res: Result<EmbeddingBucket, _> = from_canonical_bytes(&bytes);
        assert!(res.is_err(), "decode must reject wrong _kind discriminator");
        let msg = format!("{}", res.unwrap_err());
        assert!(
            msg.contains("embedding_bucket"),
            "error must reference the expected kind; got: {msg}"
        );
    }

    /// Replacing an entry hands back the old embedding and stores the new one.
    #[test]
    fn upsert_returns_previous_value_on_replace() {
        let mut bucket = EmbeddingBucket::default();
        let first = sample_embedding("m", 4);
        // Use a different dimension so the old and new values are
        // distinguishable; otherwise returning the wrong one would go unnoticed.
        let second = sample_embedding("m", 8);
        assert_ne!(first, second);
        assert_eq!(bucket.upsert("m".into(), first.clone()), None);
        assert_eq!(bucket.upsert("m".into(), second.clone()), Some(first));
        assert_eq!(bucket.get("m"), Some(&second));
        assert_eq!(bucket.entries.len(), 1);
    }

    /// `get` finds a stored model and reports `None` for an unknown one.
    #[test]
    fn get_finds_inserted_entry() {
        let mut bucket = EmbeddingBucket::default();
        let emb = sample_embedding("m", 4);
        bucket.upsert("m".into(), emb.clone());
        assert_eq!(bucket.get("m"), Some(&emb));
        assert_eq!(bucket.get("missing"), None);
    }

    /// `remove` returns the stored embedding once, then `None`.
    #[test]
    fn remove_removes_existing_entry() {
        let mut bucket = EmbeddingBucket::default();
        let emb = sample_embedding("m", 4);
        bucket.upsert("m".into(), emb.clone());
        assert_eq!(bucket.remove("m"), Some(emb));
        assert_eq!(bucket.get("m"), None);
        assert_eq!(bucket.remove("m"), None);
    }

    /// Keys held in `extra` are preserved across encode and decode.
    #[test]
    fn extra_fields_round_trip() {
        let mut bucket = EmbeddingBucket::default();
        bucket
            .extra
            .insert("future_field".into(), Ipld::String("forward-compat".into()));
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: EmbeddingBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket, decoded, "extra fields must survive round-trip");
        assert!(decoded.extra.contains_key("future_field"));
    }
}