use std::collections::BTreeMap;
use bytes::Bytes;
use ipld_core::ipld::Ipld;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::error::ObjectError;
use crate::id::NodeId;
use crate::sparse::SparseEmbed;
/// Element type of the values packed into an [`Embedding`] vector.
///
/// serde encodes each variant under its lowercase name (`"f16"`, `"f32"`,
/// `"f64"`, `"i8"`). [`Dtype::F32`] is the `Default` value.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Dtype {
    /// 16-bit floating-point elements.
    F16,
    /// 32-bit floating-point elements; the default dtype.
    #[default]
    F32,
    /// 64-bit floating-point elements.
    F64,
    /// 8-bit integer elements.
    I8,
}
impl Dtype {
    /// Returns how many bytes one element of this dtype occupies.
    #[must_use]
    pub const fn byte_width(self) -> usize {
        // Each variant name carries its bit size; convert that to whole bytes.
        let bits: usize = match self {
            Self::I8 => 8,
            Self::F16 => 16,
            Self::F32 => 32,
            Self::F64 => 64,
        };
        bits / 8
    }
}
/// A dense embedding vector stored as raw bytes, plus the metadata needed to
/// interpret its length.
///
/// [`Embedding::validate`] checks that `vector.len()` equals
/// `dim * dtype.byte_width()`.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Embedding {
/// Name of the model that produced the vector (free-form string; the tests in
/// this file use `"openai:text-embedding-3-small"`).
pub model: String,
/// Element type of `vector`. Falls back to `Dtype::default()` (`F32`) when the
/// field is missing on decode.
#[serde(default)]
pub dtype: Dtype,
/// Number of elements in the vector.
pub dim: u32,
/// Packed element bytes; expected to be `dim * dtype.byte_width()` bytes long.
pub vector: Bytes,
}
impl Embedding {
    /// Checks that `vector` holds exactly `dim` elements of `dtype`.
    ///
    /// The expected byte length is `dim * dtype.byte_width()`. The product
    /// saturates at `usize::MAX` instead of overflowing, so on targets where
    /// `usize` is narrower than 64 bits an oversized `dim` is reported as a
    /// size mismatch rather than panicking (debug builds) or wrapping to a
    /// small value that a short buffer could accidentally match (release
    /// builds).
    ///
    /// # Errors
    ///
    /// Returns [`ObjectError::EmbeddingSizeMismatch`] carrying the expected and
    /// actual byte lengths when they differ.
    pub const fn validate(&self) -> Result<(), ObjectError> {
        let got = self.vector.len();
        // A saturated `usize::MAX` can never equal a real buffer length
        // (allocations are capped at `isize::MAX`), so overflow always
        // surfaces as a mismatch.
        let expected = (self.dim as usize).saturating_mul(self.dtype.byte_width());
        if got == expected {
            Ok(())
        } else {
            Err(ObjectError::EmbeddingSizeMismatch { expected, got })
        }
    }
}
/// In-memory representation of a graph node.
///
/// Serialization goes through the private `NodeWire` struct, which adds a
/// `_kind` discriminator and omits optional fields that are `None`.
#[derive(Clone, Debug, PartialEq)]
pub struct Node {
/// Identifier of this node.
pub id: NodeId,
/// Type label of the node (e.g. `"Person"`); `Node::DEFAULT_NTYPE` when built
/// with `Node::new_default`.
pub ntype: String,
/// Optional summary text; not written to the wire when `None`.
pub summary: Option<String>,
/// Named properties, each holding an arbitrary IPLD value.
pub props: BTreeMap<String, Ipld>,
/// Optional raw content bytes; not written to the wire when `None`.
pub content: Option<Bytes>,
/// Optional sparse embedding; not written to the wire when `None`.
pub sparse_embed: Option<SparseEmbed>,
/// Optional context sentence; not written to the wire when `None`.
pub context_sentence: Option<String>,
/// Top-level wire keys that no named field claims. They are captured on
/// decode and written back on encode, so unknown fields survive a round trip.
pub extra: BTreeMap<String, Ipld>,
}
impl Node {
    /// Value of the `_kind` discriminator written for, and required of, a node.
    pub const KIND: &'static str = "node";
    /// `ntype` assigned by [`Node::new_default`].
    pub const DEFAULT_NTYPE: &'static str = "Node";

    /// Creates a node with the given id and type label and every other field
    /// empty (`None` / empty map).
    #[must_use]
    pub fn new(id: NodeId, ntype: impl Into<String>) -> Self {
        let ntype = ntype.into();
        Self {
            id,
            ntype,
            summary: None,
            props: BTreeMap::new(),
            content: None,
            sparse_embed: None,
            context_sentence: None,
            extra: BTreeMap::new(),
        }
    }

    /// Creates an otherwise-empty node whose `ntype` is [`Node::DEFAULT_NTYPE`].
    #[must_use]
    pub fn new_default(id: NodeId) -> Self {
        Self::new(id, Self::DEFAULT_NTYPE)
    }

    /// Returns the node with `summary` set, replacing any previous value.
    #[must_use]
    pub fn with_summary(self, summary: impl Into<String>) -> Self {
        Self {
            summary: Some(summary.into()),
            ..self
        }
    }

    /// Returns the node with `key` mapped to `value` in `props`; an existing
    /// entry under the same key is overwritten.
    #[must_use]
    pub fn with_prop(mut self, key: impl Into<String>, value: impl Into<Ipld>) -> Self {
        let (name, payload) = (key.into(), value.into());
        // The displaced value, if any, is intentionally discarded.
        let _previous = self.props.insert(name, payload);
        self
    }

    /// Returns the node with `content` set, replacing any previous value.
    #[must_use]
    pub fn with_content(self, content: Bytes) -> Self {
        Self {
            content: Some(content),
            ..self
        }
    }

    /// Returns the node with `sparse_embed` set, replacing any previous value.
    #[must_use]
    pub fn with_sparse_embed(self, sparse_embed: SparseEmbed) -> Self {
        Self {
            sparse_embed: Some(sparse_embed),
            ..self
        }
    }

    /// Returns the node with `context_sentence` set, replacing any previous value.
    #[must_use]
    pub fn with_context_sentence(self, context: impl Into<String>) -> Self {
        Self {
            context_sentence: Some(context.into()),
            ..self
        }
    }

    /// Looks up `key` in `props` and returns it only if it is an IPLD string.
    #[must_use]
    pub fn get_str(&self, key: &str) -> Option<&str> {
        match self.props.get(key) {
            Some(Ipld::String(text)) => Some(text.as_str()),
            _ => None,
        }
    }

    /// Looks up `key` in `props` and returns it only if it is an IPLD integer.
    #[must_use]
    pub fn get_int(&self, key: &str) -> Option<i128> {
        match self.props.get(key) {
            Some(Ipld::Integer(number)) => Some(*number),
            _ => None,
        }
    }

    /// Looks up `key` in `props` and returns it only if it is an IPLD bool.
    #[must_use]
    pub fn get_bool(&self, key: &str) -> Option<bool> {
        match self.props.get(key) {
            Some(Ipld::Bool(flag)) => Some(*flag),
            _ => None,
        }
    }

    /// Looks up `key` in `props` and returns it only if it is an IPLD float.
    #[must_use]
    pub fn get_float(&self, key: &str) -> Option<f64> {
        match self.props.get(key) {
            Some(Ipld::Float(number)) => Some(*number),
            _ => None,
        }
    }

    /// Looks up `key` in `props` and returns it only if it is an IPLD byte string.
    #[must_use]
    pub fn get_bytes(&self, key: &str) -> Option<&[u8]> {
        match self.props.get(key) {
            Some(Ipld::Bytes(raw)) => Some(raw.as_slice()),
            _ => None,
        }
    }
}
/// Private serde wire form of [`Node`].
///
/// Both the `Serialize` and `Deserialize` impls for `Node` go through this
/// struct. It mirrors `Node`'s fields and adds the `_kind` discriminator.
#[derive(Serialize, Deserialize)]
struct NodeWire {
/// Object-kind discriminator, stored under the key `_kind`.
#[serde(rename = "_kind")]
kind: String,
/// Node identifier.
id: NodeId,
/// Node type label.
ntype: String,
/// Omitted from the output when `None`; decodes to `None` when the key is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
summary: Option<String>,
/// Property map. Has no `default`, so the key must be present on decode.
props: BTreeMap<String, Ipld>,
/// Omitted from the output when `None`; decodes to `None` when the key is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
content: Option<Bytes>,
/// Omitted from the output when `None`; decodes to `None` when the key is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
sparse_embed: Option<SparseEmbed>,
/// Omitted from the output when `None`; decodes to `None` when the key is absent.
#[serde(default, skip_serializing_if = "Option::is_none")]
context_sentence: Option<String>,
/// Flattened catch-all: top-level keys not claimed by a field above are
/// collected here on decode and emitted back at the top level on encode.
#[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
extra: BTreeMap<String, Ipld>,
}
impl Serialize for Node {
    /// Encodes the node via [`NodeWire`], stamping `_kind` with [`Node::KIND`].
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        // Exhaustive destructuring: adding a field to `Node` fails to compile
        // here until the wire mapping is updated as well.
        let Self {
            id,
            ntype,
            summary,
            props,
            content,
            sparse_embed,
            context_sentence,
            extra,
        } = self;

        let wire = NodeWire {
            kind: Self::KIND.to_owned(),
            id: *id,
            ntype: ntype.clone(),
            summary: summary.clone(),
            props: props.clone(),
            content: content.clone(),
            sparse_embed: sparse_embed.clone(),
            context_sentence: context_sentence.clone(),
            extra: extra.clone(),
        };
        wire.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for Node {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
let wire = NodeWire::deserialize(deserializer)?;
if wire.kind != Self::KIND {
return Err(serde::de::Error::custom(format!(
"expected _kind='{}', got '{}'",
Self::KIND,
wire.kind
)));
}
Ok(Self {
id: wire.id,
ntype: wire.ntype,
summary: wire.summary,
props: wire.props,
content: wire.content,
sparse_embed: wire.sparse_embed,
context_sentence: wire.context_sentence,
extra: wire.extra,
})
}
}
// Unit tests for `Node` / `Embedding`: canonical-encoding round trips, CID
// determinism, `_kind` validation, forward-compatibility through `extra`,
// and omission of absent optional fields from the wire.
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::{from_canonical_bytes, hash_to_cid, to_canonical_bytes};
// Fixture: a "Person" node (id = sixteen 0x01 bytes) with `name` and `age` props.
fn alice() -> Node {
Node::new(NodeId::from_bytes_raw([1u8; 16]), "Person")
.with_prop("name", Ipld::String("Alice".into()))
.with_prop("age", Ipld::Integer(30))
}
// Encode -> decode -> re-encode must give back an equal value and identical bytes.
#[test]
fn node_round_trip_byte_identity() {
let original = alice();
let bytes = to_canonical_bytes(&original).expect("encode");
let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
assert_eq!(original, decoded);
let bytes2 = to_canonical_bytes(&decoded).expect("re-encode");
assert_eq!(bytes, bytes2);
}
// Two independently built but equal nodes must hash to the same CID.
#[test]
fn node_cid_is_deterministic() {
let a1 = alice();
let a2 = alice();
let (_, c1) = hash_to_cid(&a1).expect("hash");
let (_, c2) = hash_to_cid(&a2).expect("hash");
assert_eq!(c1, c2);
}
// `new_default` must assign `DEFAULT_NTYPE`, whose literal value is "Node".
#[test]
fn new_default_uses_default_ntype() {
let n = Node::new_default(NodeId::from_bytes_raw([7u8; 16]));
assert_eq!(n.ntype, Node::DEFAULT_NTYPE);
assert_eq!(n.ntype, "Node");
}
// `new_default(id)` and `new(id, DEFAULT_NTYPE)` must produce the same CID.
#[test]
fn new_default_and_explicit_new_match_when_ntype_equal() {
let id = NodeId::from_bytes_raw([9u8; 16]);
let default_node = Node::new_default(id);
let explicit_node = Node::new(id, Node::DEFAULT_NTYPE);
let (_, c_default) = hash_to_cid(&default_node).expect("hash default");
let (_, c_explicit) = hash_to_cid(&explicit_node).expect("hash explicit");
assert_eq!(c_default, c_explicit);
}
// Decoding a payload whose `_kind` is not "node" must fail, and the error
// message must mention `_kind`.
#[test]
fn node_kind_rejection() {
let wire = NodeWire {
kind: "edge".into(),
id: NodeId::from_bytes_raw([1u8; 16]),
ntype: "x".into(),
summary: None,
props: BTreeMap::new(),
content: None,
sparse_embed: None,
context_sentence: None,
extra: BTreeMap::new(),
};
let bytes = serde_ipld_dagcbor::to_vec(&wire).expect("encode wire");
let err = serde_ipld_dagcbor::from_slice::<Node>(&bytes).unwrap_err();
assert!(
err.to_string().contains("_kind"),
"expected _kind rejection, got: {err}"
);
}
// An unknown top-level key must land in `Node::extra` on decode and be
// re-emitted so the re-encoded bytes equal the input bytes.
#[test]
fn node_extra_fields_round_trip() {
let mut wire = NodeWire {
kind: "node".into(),
id: NodeId::from_bytes_raw([2u8; 16]),
ntype: "Future".into(),
summary: None,
props: BTreeMap::new(),
content: None,
sparse_embed: None,
context_sentence: None,
extra: BTreeMap::new(),
};
wire.extra.insert(
"x-future-field".into(),
Ipld::String("value-from-v99".into()),
);
let bytes_in = serde_ipld_dagcbor::to_vec(&wire).expect("encode");
let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode");
assert_eq!(
decoded.extra.get("x-future-field"),
Some(&Ipld::String("value-from-v99".into())),
);
let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
assert_eq!(bytes_in, bytes_out);
}
// A payload carrying an `embed` field (which `Node` has no named field for)
// must decode with `embed` preserved in `extra`, re-encode byte-for-byte,
// and hash to the same CID as the original bytes.
#[test]
fn legacy_embed_field_round_trips_through_extra() {
// Local serialize-only struct that mimics a wire layout with a required
// top-level `embed` field.
#[derive(Serialize)]
struct LegacyNodeWire {
#[serde(rename = "_kind")]
kind: String,
id: NodeId,
ntype: String,
#[serde(skip_serializing_if = "Option::is_none")]
summary: Option<String>,
props: BTreeMap<String, Ipld>,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<Bytes>,
embed: Embedding,
}
let legacy = LegacyNodeWire {
kind: "node".into(),
id: NodeId::from_bytes_raw([42u8; 16]),
ntype: "Doc".into(),
summary: None,
props: BTreeMap::new(),
content: None,
embed: Embedding {
model: "openai:text-embedding-3-small".into(),
dtype: Dtype::F32,
dim: 2,
// 8 bytes = dim 2 x 4-byte F32 elements.
vector: Bytes::from(vec![
0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x00, ]),
},
};
let bytes_in = serde_ipld_dagcbor::to_vec(&legacy).expect("encode legacy");
let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode legacy");
assert!(
decoded.extra.contains_key("embed"),
"legacy embed must land in extra"
);
let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
assert_eq!(bytes_in, bytes_out, "legacy bytes must round-trip exactly");
let (bytes_from_node, cid_from_node) = hash_to_cid(&decoded).expect("hash node");
assert_eq!(
bytes_in.as_slice(),
bytes_from_node.as_ref(),
"a future version re-encode must match legacy bytes byte-for-byte"
);
// Build the CID directly from the input bytes (SHA2-256 multihash,
// DAG-CBOR codec) and compare with the CID computed from the decoded node.
let cid_via_legacy_bytes = {
let mh = crate::id::Multihash::sha2_256(&bytes_in);
crate::id::Cid::new(crate::id::CODEC_DAG_CBOR, mh)
};
assert_eq!(
cid_from_node, cid_via_legacy_bytes,
"NodeCid via a future version reader must equal NodeCid via legacy bytes"
);
}
// `summary` must survive a round trip and must change the CID compared
// with an otherwise identical node that has no summary.
#[test]
fn node_round_trip_with_summary() {
let n = Node::new(NodeId::from_bytes_raw([3u8; 16]), "Person")
.with_summary("Alice, 30, based in Berlin.")
.with_prop("name", Ipld::String("Alice".into()));
let bytes = to_canonical_bytes(&n).expect("encode");
let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
assert_eq!(
decoded.summary.as_deref(),
Some("Alice, 30, based in Berlin.")
);
assert_eq!(n, decoded);
let bare = Node::new(NodeId::from_bytes_raw([3u8; 16]), "Person")
.with_prop("name", Ipld::String("Alice".into()));
let (_, c_with) = hash_to_cid(&n).expect("hash");
let (_, c_without) = hash_to_cid(&bare).expect("hash");
assert_ne!(c_with, c_without);
}
// `sparse_embed` must survive a round trip with byte-identical re-encoding.
#[test]
fn node_sparse_embed_round_trips() {
let s = crate::sparse::SparseEmbed::new(vec![1, 5, 9], vec![0.5, 0.2, 0.1], "test-vocab")
.unwrap();
let n = Node::new(NodeId::from_bytes_raw([6u8; 16]), "Doc").with_sparse_embed(s.clone());
let bytes = to_canonical_bytes(&n).expect("encode");
let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
assert_eq!(decoded.sparse_embed.as_ref(), Some(&s));
let bytes2 = to_canonical_bytes(&decoded).expect("re-encode");
assert_eq!(bytes, bytes2);
}
// `context_sentence` must survive a round trip with byte-identical re-encoding.
#[test]
fn node_context_sentence_round_trips() {
let ctx = "This paragraph is from Section 3 of the 2024 lease.";
let n = Node::new(NodeId::from_bytes_raw([9u8; 16]), "Paragraph")
.with_summary("The tenant shall maintain the premises...")
.with_context_sentence(ctx);
let bytes = to_canonical_bytes(&n).expect("encode");
let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
assert_eq!(decoded.context_sentence.as_deref(), Some(ctx));
let bytes2 = to_canonical_bytes(&decoded).expect("re-encode");
assert_eq!(bytes, bytes2);
}
// When `context_sentence` is `None`, the key name must not occur in the
// encoded bytes. The window size 16 is the byte length of "context_sentence".
#[test]
fn node_context_sentence_absent_not_emitted() {
let n = Node::new(NodeId::from_bytes_raw([10u8; 16]), "Plain");
let bytes = to_canonical_bytes(&n).expect("encode");
assert!(
!bytes.windows(16).any(|w| w == b"context_sentence"),
"absent context_sentence should not appear on the wire"
);
}
// Adding a `context_sentence` must change the CID.
#[test]
fn node_context_sentence_participates_in_cid() {
let base = Node::new(NodeId::from_bytes_raw([11u8; 16]), "P").with_summary("x");
let with_ctx = base.clone().with_context_sentence("cue");
let (_, c1) = hash_to_cid(&base).unwrap();
let (_, c2) = hash_to_cid(&with_ctx).unwrap();
assert_ne!(c1, c2, "context_sentence must participate in the CID");
}
// When `sparse_embed` is `None`, the key name must not occur in the encoded
// bytes. The window size 12 is the byte length of "sparse_embed".
#[test]
fn node_sparse_embed_absent_not_emitted() {
let n = Node::new(NodeId::from_bytes_raw([7u8; 16]), "Thing");
let bytes = to_canonical_bytes(&n).expect("encode");
assert!(
!bytes.windows(12).any(|w| w == b"sparse_embed"),
"absent sparse_embed should not appear on the wire"
);
}
// Adding a `sparse_embed` must change the CID.
#[test]
fn node_sparse_embed_participates_in_cid() {
let s = crate::sparse::SparseEmbed::new(vec![1], vec![1.0], "v").unwrap();
let n_with = Node::new(NodeId::from_bytes_raw([8u8; 16]), "Doc").with_sparse_embed(s);
let n_without = Node::new(NodeId::from_bytes_raw([8u8; 16]), "Doc");
let (_, c_with) = hash_to_cid(&n_with).unwrap();
let (_, c_without) = hash_to_cid(&n_without).unwrap();
assert_ne!(c_with, c_without);
}
// When `summary` is `None`, the key name must not occur in the encoded bytes.
// The window size 7 is the byte length of "summary".
#[test]
fn node_summary_absent_not_emitted() {
let n = Node::new(NodeId::from_bytes_raw([4u8; 16]), "Thing");
let bytes = to_canonical_bytes(&n).expect("encode");
assert!(
!bytes.windows(7).any(|w| w == b"summary"),
"absent summary should not appear on the wire"
);
}
// `Embedding::validate` accepts a 16-byte vector for dim 4 x F32 and reports
// `EmbeddingSizeMismatch { expected: 16, got: 10 }` for a 10-byte vector.
#[test]
fn embedding_validate_ok_and_err() {
let ok = Embedding {
model: "m".into(),
dtype: Dtype::F32,
dim: 4,
vector: Bytes::from(vec![0u8; 16]),
};
ok.validate().unwrap();
let bad = Embedding {
model: "m".into(),
dtype: Dtype::F32,
dim: 4,
vector: Bytes::from(vec![0u8; 10]),
};
let err = bad.validate().unwrap_err();
match err {
ObjectError::EmbeddingSizeMismatch { expected, got } => {
assert_eq!(expected, 16);
assert_eq!(got, 10);
}
e => panic!("wrong variant: {e:?}"),
}
}
}