use std::collections::BTreeMap;
use ipld_core::ipld::Ipld;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::sparse::SparseEmbed;
/// A bucket of sparse embeddings, addressed by vocabulary id through
/// [`SparseBucket::get`], [`SparseBucket::upsert`] and [`SparseBucket::remove`].
///
/// In memory, `entries` keeps insertion order. The hand-written `Serialize` impl
/// emits entries sorted by `vocab_id`, so encoded bytes do not depend on insertion order.
#[derive(Clone, Debug, Default)]
pub struct SparseBucket {
/// The stored `(vocab_id, sparse)` pairs. `upsert` never adds a second entry for an
/// existing id, but this field is public and decoding does not de-duplicate, so
/// callers that mutate it directly are responsible for keeping ids unique.
pub entries: Vec<SparseEntry>,
/// Top-level wire fields other than `_kind` and `entries`. They are captured on decode
/// and written back on encode, so fields unknown to this version survive a round trip.
pub extra: BTreeMap<String, Ipld>,
}
impl SparseBucket {
    /// Discriminator written to, and required in, the `_kind` field of the wire form.
    pub const KIND: &'static str = "sparse_bucket";

    /// Position in `entries` of the first entry whose `vocab_id` equals `vocab_id`.
    fn index_of(&self, vocab_id: &str) -> Option<usize> {
        self.entries
            .iter()
            .position(|entry| entry.vocab_id == vocab_id)
    }

    /// Looks up the sparse embedding stored under `vocab_id`.
    ///
    /// Performs a linear scan and returns the first matching entry, or `None` if absent.
    #[must_use]
    pub fn get(&self, vocab_id: &str) -> Option<&SparseEmbed> {
        let idx = self.index_of(vocab_id)?;
        Some(&self.entries[idx].sparse)
    }

    /// Stores `sparse` under `vocab_id`.
    ///
    /// If an entry with that id already exists, its embedding is replaced in place and
    /// no duplicate is added. Otherwise a new entry is appended, so in-memory order
    /// stays insertion order.
    pub fn upsert(&mut self, vocab_id: String, sparse: SparseEmbed) {
        match self.index_of(&vocab_id) {
            Some(idx) => self.entries[idx].sparse = sparse,
            None => self.entries.push(SparseEntry { vocab_id, sparse }),
        }
    }

    /// Removes the first entry whose `vocab_id` matches. An unknown id is a no-op.
    pub fn remove(&mut self, vocab_id: &str) {
        if let Some(idx) = self.index_of(vocab_id) {
            self.entries.remove(idx);
        }
    }
}
/// One `(vocab_id, sparse)` pair stored in a [`SparseBucket`].
///
/// Uses serde's derived representation, with fields `vocab_id` and `sparse`.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SparseEntry {
/// Vocabulary id this entry is keyed by; also the sort key of the encoded bucket.
/// NOTE(review): `SparseEmbed` carries its own `vocab_id` as well. Nothing here checks
/// that the two agree; confirm whether callers must keep them identical.
pub vocab_id: String,
/// The sparse embedding stored under `vocab_id`.
pub sparse: SparseEmbed,
}
/// Owned wire representation of a [`SparseBucket`]: a `_kind` discriminator, the entry
/// list, and any further top-level fields.
#[derive(Serialize, Deserialize)]
struct SparseBucketWire {
/// Type discriminator; the `Deserialize` impl rejects anything other than
/// [`SparseBucket::KIND`].
#[serde(rename = "_kind")]
kind: String,
/// Bucket entries; a missing `entries` field decodes as an empty list.
#[serde(default)]
entries: Vec<SparseEntry>,
/// All remaining top-level fields, collected via `flatten` on decode and written back
/// inline on encode. Nothing is emitted when the map is empty.
#[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
extra: BTreeMap<String, Ipld>,
}
impl Serialize for SparseBucket {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
let mut sorted = self.entries.clone();
sorted.sort_by(|a, b| a.vocab_id.cmp(&b.vocab_id));
SparseBucketWire {
kind: Self::KIND.into(),
entries: sorted,
extra: self.extra.clone(),
}
.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for SparseBucket {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let w = SparseBucketWire::deserialize(deserializer)?;
if w.kind != Self::KIND {
return Err(serde::de::Error::custom(format!(
"expected _kind='{}', got '{}'",
Self::KIND,
w.kind
)));
}
Ok(Self {
entries: w.entries,
extra: w.extra,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, to_canonical_bytes};
    use crate::sparse::SparseEmbed;

    /// Builds a small, valid three-term sparse embedding tagged with `vocab_id`.
    fn sample_sparse(vocab_id: &str) -> SparseEmbed {
        SparseEmbed::new(vec![1, 5, 9], vec![0.5, 0.2, 0.1], vocab_id).unwrap()
    }

    #[test]
    fn empty_bucket_round_trips() {
        let original = SparseBucket::default();
        let bytes = to_canonical_bytes(&original).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(original.entries.len(), decoded.entries.len());
        let bytes2 = to_canonical_bytes(&decoded).unwrap();
        assert_eq!(bytes, bytes2, "round-trip must be byte-identical");
    }

    #[test]
    fn populated_bucket_round_trips() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("bge-m3".into(), sample_sparse("bge-m3"));
        bucket.upsert(
            "opensearch-distill".into(),
            sample_sparse("opensearch-distill"),
        );
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.entries.len(), decoded.entries.len());

        // A length check alone would miss entries landing under the wrong id or losing data.
        for vocab_id in ["bge-m3", "opensearch-distill"] {
            let before = bucket.get(vocab_id).unwrap();
            let after = decoded.get(vocab_id).unwrap();
            assert_eq!(after.vocab_id, vocab_id);
            assert_eq!(after.indices, before.indices);
        }

        let bytes2 = to_canonical_bytes(&decoded).unwrap();
        assert_eq!(bytes, bytes2, "round-trip must be byte-identical");
    }

    #[test]
    fn wire_form_sorts_by_vocab_id_regardless_of_insert_order() {
        let mut a = SparseBucket::default();
        a.upsert("zzz".into(), sample_sparse("zzz"));
        a.upsert("aaa".into(), sample_sparse("aaa"));
        let mut b = SparseBucket::default();
        b.upsert("aaa".into(), sample_sparse("aaa"));
        b.upsert("zzz".into(), sample_sparse("zzz"));
        let bytes_a = to_canonical_bytes(&a).unwrap();
        assert_eq!(
            bytes_a,
            to_canonical_bytes(&b).unwrap(),
            "encode must sort entries by vocab_id so bucket CIDs are insertion-order-invariant"
        );

        // Decoding keeps wire order, so the sorted order must be visible after a round trip.
        let decoded: SparseBucket = from_canonical_bytes(&bytes_a).unwrap();
        let ids: Vec<&str> = decoded
            .entries
            .iter()
            .map(|e| e.vocab_id.as_str())
            .collect();
        assert_eq!(ids, ["aaa", "zzz"]);
    }

    #[test]
    fn wrong_kind_fails_decode() {
        #[derive(Serialize)]
        struct Wrong {
            #[serde(rename = "_kind")]
            kind: String,
            entries: Vec<SparseEntry>,
        }
        let bytes = serde_ipld_dagcbor::to_vec(&Wrong {
            kind: "node".into(),
            entries: vec![],
        })
        .unwrap();
        let res: Result<SparseBucket, _> = from_canonical_bytes(&bytes);
        assert!(res.is_err(), "decode must reject wrong _kind discriminator");
        let msg = format!("{}", res.unwrap_err());
        assert!(
            msg.contains("sparse_bucket"),
            "error must reference the expected kind; got: {msg}"
        );
    }

    #[test]
    fn upsert_overwrites_existing_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("v0".into(), sample_sparse("v0"));
        let new_sparse = SparseEmbed::new(vec![100], vec![0.9], "v0").unwrap();
        bucket.upsert("v0".into(), new_sparse);
        assert_eq!(bucket.entries.len(), 1);
        assert_eq!(bucket.get("v0").unwrap().indices, vec![100]);
    }

    #[test]
    fn get_finds_inserted_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("v0".into(), sample_sparse("v0"));
        assert_eq!(bucket.get("v0").unwrap().vocab_id, "v0");
        assert!(bucket.get("missing").is_none());
    }

    #[test]
    fn remove_deletes_only_the_matching_entry() {
        let mut bucket = SparseBucket::default();
        bucket.upsert("aaa".into(), sample_sparse("aaa"));
        bucket.upsert("zzz".into(), sample_sparse("zzz"));

        bucket.remove("aaa");
        assert!(bucket.get("aaa").is_none());
        assert!(bucket.get("zzz").is_some());
        assert_eq!(bucket.entries.len(), 1);

        // Removing an id that is not present must be a no-op.
        bucket.remove("missing");
        assert_eq!(bucket.entries.len(), 1);
    }

    #[test]
    fn extra_fields_round_trip() {
        let mut bucket = SparseBucket::default();
        bucket
            .extra
            .insert("future_field".into(), Ipld::String("forward-compat".into()));
        let bytes = to_canonical_bytes(&bucket).unwrap();
        let decoded: SparseBucket = from_canonical_bytes(&bytes).unwrap();
        assert_eq!(bucket.extra.len(), decoded.extra.len());
        assert!(decoded.extra.contains_key("future_field"));
    }
}