1use std::collections::BTreeMap;
28
29use bytes::Bytes;
30use ipld_core::ipld::Ipld;
31use serde::{Deserialize, Deserializer, Serialize, Serializer};
32
33use crate::error::ObjectError;
34use crate::id::NodeId;
35use crate::sparse::SparseEmbed;
36
37#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
41#[serde(rename_all = "lowercase")]
42#[derive(Default)]
43pub enum Dtype {
44 F16,
46 #[default]
48 F32,
49 F64,
51 I8,
53}
54
55impl Dtype {
56 #[must_use]
58 pub const fn byte_width(self) -> usize {
59 match self {
60 Self::F16 => 2,
61 Self::F32 => 4,
62 Self::F64 => 8,
63 Self::I8 => 1,
64 }
65 }
66}
67
68#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
75pub struct Embedding {
76 pub model: String,
79 #[serde(default)]
81 pub dtype: Dtype,
82 pub dim: u32,
84 pub vector: Bytes,
87}
88
89impl Embedding {
90 pub const fn validate(&self) -> Result<(), ObjectError> {
97 let expected = (self.dim as usize) * self.dtype.byte_width();
98 if self.vector.len() == expected {
99 Ok(())
100 } else {
101 Err(ObjectError::EmbeddingSizeMismatch {
102 expected,
103 got: self.vector.len(),
104 })
105 }
106 }
107}
108
109#[derive(Clone, Debug, PartialEq)]
121pub struct Node {
122 pub id: NodeId,
124 pub ntype: String,
126 pub summary: Option<String>,
132 pub props: BTreeMap<String, Ipld>,
134 pub content: Option<Bytes>,
136 pub sparse_embed: Option<SparseEmbed>,
145 pub context_sentence: Option<String>,
162 pub extra: BTreeMap<String, Ipld>,
166}
167
168impl Node {
169 pub const KIND: &'static str = "node";
171
172 pub const DEFAULT_NTYPE: &'static str = "Node";
178
179 #[must_use]
181 pub fn new(id: NodeId, ntype: impl Into<String>) -> Self {
182 Self {
183 id,
184 ntype: ntype.into(),
185 summary: None,
186 props: BTreeMap::new(),
187 content: None,
188 sparse_embed: None,
189 context_sentence: None,
190 extra: BTreeMap::new(),
191 }
192 }
193
194 #[must_use]
198 pub fn new_default(id: NodeId) -> Self {
199 Self::new(id, Self::DEFAULT_NTYPE)
200 }
201
202 #[must_use]
204 pub fn with_summary(mut self, summary: impl Into<String>) -> Self {
205 self.summary = Some(summary.into());
206 self
207 }
208
209 #[must_use]
211 pub fn with_prop(mut self, key: impl Into<String>, value: impl Into<Ipld>) -> Self {
212 self.props.insert(key.into(), value.into());
213 self
214 }
215
216 #[must_use]
218 pub fn with_content(mut self, content: Bytes) -> Self {
219 self.content = Some(content);
220 self
221 }
222
223 #[must_use]
226 pub fn with_sparse_embed(mut self, sparse_embed: SparseEmbed) -> Self {
227 self.sparse_embed = Some(sparse_embed);
228 self
229 }
230
231 #[must_use]
242 pub fn with_context_sentence(mut self, context: impl Into<String>) -> Self {
243 self.context_sentence = Some(context.into());
244 self
245 }
246
247 #[must_use]
256 pub fn get_str(&self, key: &str) -> Option<&str> {
257 match self.props.get(key)? {
258 Ipld::String(s) => Some(s.as_str()),
259 _ => None,
260 }
261 }
262
263 #[must_use]
265 pub fn get_int(&self, key: &str) -> Option<i128> {
266 match self.props.get(key)? {
267 Ipld::Integer(n) => Some(*n),
268 _ => None,
269 }
270 }
271
272 #[must_use]
274 pub fn get_bool(&self, key: &str) -> Option<bool> {
275 match self.props.get(key)? {
276 Ipld::Bool(b) => Some(*b),
277 _ => None,
278 }
279 }
280
281 #[must_use]
283 pub fn get_float(&self, key: &str) -> Option<f64> {
284 match self.props.get(key)? {
285 Ipld::Float(f) => Some(*f),
286 _ => None,
287 }
288 }
289
290 #[must_use]
292 pub fn get_bytes(&self, key: &str) -> Option<&[u8]> {
293 match self.props.get(key)? {
294 Ipld::Bytes(b) => Some(b.as_slice()),
295 _ => None,
296 }
297 }
298}
299
300#[derive(Serialize, Deserialize)]
309struct NodeWire {
310 #[serde(rename = "_kind")]
311 kind: String,
312 id: NodeId,
313 ntype: String,
314 #[serde(default, skip_serializing_if = "Option::is_none")]
315 summary: Option<String>,
316 props: BTreeMap<String, Ipld>,
317 #[serde(default, skip_serializing_if = "Option::is_none")]
318 content: Option<Bytes>,
319 #[serde(default, skip_serializing_if = "Option::is_none")]
320 sparse_embed: Option<SparseEmbed>,
321 #[serde(default, skip_serializing_if = "Option::is_none")]
322 context_sentence: Option<String>,
323 #[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
327 extra: BTreeMap<String, Ipld>,
328}
329
330impl Serialize for Node {
331 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
332 NodeWire {
333 kind: Self::KIND.into(),
334 id: self.id,
335 ntype: self.ntype.clone(),
336 summary: self.summary.clone(),
337 props: self.props.clone(),
338 content: self.content.clone(),
339 sparse_embed: self.sparse_embed.clone(),
340 context_sentence: self.context_sentence.clone(),
341 extra: self.extra.clone(),
342 }
343 .serialize(serializer)
344 }
345}
346
347impl<'de> Deserialize<'de> for Node {
348 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
349 let wire = NodeWire::deserialize(deserializer)?;
350 if wire.kind != Self::KIND {
351 return Err(serde::de::Error::custom(format!(
352 "expected _kind='{}', got '{}'",
353 Self::KIND,
354 wire.kind
355 )));
356 }
357 Ok(Self {
358 id: wire.id,
359 ntype: wire.ntype,
360 summary: wire.summary,
361 props: wire.props,
362 content: wire.content,
363 sparse_embed: wire.sparse_embed,
364 context_sentence: wire.context_sentence,
365 extra: wire.extra,
366 })
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, hash_to_cid, to_canonical_bytes};

    /// Builds a `NodeId` whose 16 raw bytes all equal `fill`.
    fn id_filled(fill: u8) -> NodeId {
        NodeId::from_bytes_raw([fill; 16])
    }

    /// Reports whether `needle` occurs as a contiguous run inside `haystack`.
    fn contains_bytes(haystack: &[u8], needle: &[u8]) -> bool {
        haystack
            .windows(needle.len())
            .any(|window| window == needle)
    }

    /// A `NodeWire` carrying only the discriminator, id and type label.
    fn bare_wire(kind: &str, id: NodeId, ntype: &str) -> NodeWire {
        NodeWire {
            kind: kind.into(),
            id,
            ntype: ntype.into(),
            summary: None,
            props: BTreeMap::new(),
            content: None,
            sparse_embed: None,
            context_sentence: None,
            extra: BTreeMap::new(),
        }
    }

    /// Fixture: a "Person" node with a string and an integer property.
    fn alice() -> Node {
        Node::new(id_filled(1), "Person")
            .with_prop("name", Ipld::String("Alice".into()))
            .with_prop("age", Ipld::Integer(30))
    }

    // Encode -> decode -> encode yields an equal node and identical bytes.
    #[test]
    fn node_round_trip_byte_identity() {
        let node = alice();
        let first_pass = to_canonical_bytes(&node).expect("encode");
        let restored: Node = from_canonical_bytes(&first_pass).expect("decode");
        assert_eq!(node, restored);
        let second_pass = to_canonical_bytes(&restored).expect("re-encode");
        assert_eq!(first_pass, second_pass);
    }

    // Two independently built but equal nodes hash to the same CID.
    #[test]
    fn node_cid_is_deterministic() {
        let (left, right) = (alice(), alice());
        let (_, left_cid) = hash_to_cid(&left).expect("hash");
        let (_, right_cid) = hash_to_cid(&right).expect("hash");
        assert_eq!(left_cid, right_cid);
    }

    // `new_default` stamps the documented default type label.
    #[test]
    fn new_default_uses_default_ntype() {
        let node = Node::new_default(id_filled(7));
        assert_eq!(node.ntype, Node::DEFAULT_NTYPE);
        assert_eq!(node.ntype, "Node");
    }

    // `new_default(id)` and `new(id, DEFAULT_NTYPE)` are indistinguishable by CID.
    #[test]
    fn new_default_and_explicit_new_match_when_ntype_equal() {
        let id = id_filled(9);
        let via_default = Node::new_default(id);
        let via_explicit = Node::new(id, Node::DEFAULT_NTYPE);
        let (_, cid_default) = hash_to_cid(&via_default).expect("hash default");
        let (_, cid_explicit) = hash_to_cid(&via_explicit).expect("hash explicit");
        assert_eq!(cid_default, cid_explicit);
    }

    // A wire object whose `_kind` is not "node" must be refused on decode.
    #[test]
    fn node_kind_rejection() {
        let wrong_kind = bare_wire("edge", id_filled(1), "x");
        let encoded = serde_ipld_dagcbor::to_vec(&wrong_kind).expect("encode wire");
        let err = serde_ipld_dagcbor::from_slice::<Node>(&encoded).unwrap_err();
        assert!(
            err.to_string().contains("_kind"),
            "expected _kind rejection, got: {err}"
        );
    }

    // Unknown wire keys land in `extra` and are re-emitted byte-for-byte.
    #[test]
    fn node_extra_fields_round_trip() {
        let mut future_wire = bare_wire("node", id_filled(2), "Future");
        future_wire.extra.insert(
            "x-future-field".into(),
            Ipld::String("value-from-v99".into()),
        );
        let bytes_in = serde_ipld_dagcbor::to_vec(&future_wire).expect("encode");

        let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode");
        assert_eq!(
            decoded.extra.get("x-future-field"),
            Some(&Ipld::String("value-from-v99".into())),
        );

        let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes_in, bytes_out);
    }

    // A wire object carrying an `embed` field (no longer a named `Node` field)
    // must survive decode/encode unchanged and keep the same CID.
    #[test]
    fn legacy_embed_field_round_trips_through_extra() {
        /// Wire shape that includes a dense `embed` field.
        #[derive(Serialize)]
        struct LegacyNodeWire {
            #[serde(rename = "_kind")]
            kind: String,
            id: NodeId,
            ntype: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            summary: Option<String>,
            props: BTreeMap<String, Ipld>,
            #[serde(skip_serializing_if = "Option::is_none")]
            content: Option<Bytes>,
            embed: Embedding,
        }

        let embed = Embedding {
            model: "openai:text-embedding-3-small".into(),
            dtype: Dtype::F32,
            dim: 2,
            vector: Bytes::from(vec![0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x00]),
        };
        let legacy = LegacyNodeWire {
            kind: "node".into(),
            id: id_filled(42),
            ntype: "Doc".into(),
            summary: None,
            props: BTreeMap::new(),
            content: None,
            embed,
        };
        let bytes_in = serde_ipld_dagcbor::to_vec(&legacy).expect("encode legacy");

        let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode legacy");
        assert!(
            decoded.extra.contains_key("embed"),
            "legacy embed must land in extra"
        );

        let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes_in, bytes_out, "legacy bytes must round-trip exactly");

        let (bytes_from_node, cid_from_node) = hash_to_cid(&decoded).expect("hash node");
        assert_eq!(
            bytes_in.as_slice(),
            bytes_from_node.as_ref(),
            "a future version re-encode must match legacy bytes byte-for-byte"
        );

        // Hash the original bytes directly to get the reference CID.
        let digest = crate::id::Multihash::sha2_256(&bytes_in);
        let cid_via_legacy_bytes = crate::id::Cid::new(crate::id::CODEC_DAG_CBOR, digest);
        assert_eq!(
            cid_from_node, cid_via_legacy_bytes,
            "NodeCid via a future version reader must equal NodeCid via legacy bytes"
        );
    }

    // `summary` round-trips and changes the CID relative to a summary-less twin.
    #[test]
    fn node_round_trip_with_summary() {
        let summarized = Node::new(id_filled(3), "Person")
            .with_summary("Alice, 30, based in Berlin.")
            .with_prop("name", Ipld::String("Alice".into()));
        let encoded = to_canonical_bytes(&summarized).expect("encode");
        let restored: Node = from_canonical_bytes(&encoded).expect("decode");
        assert_eq!(
            restored.summary.as_deref(),
            Some("Alice, 30, based in Berlin.")
        );
        assert_eq!(summarized, restored);

        let unsummarized =
            Node::new(id_filled(3), "Person").with_prop("name", Ipld::String("Alice".into()));
        let (_, cid_with) = hash_to_cid(&summarized).expect("hash");
        let (_, cid_without) = hash_to_cid(&unsummarized).expect("hash");
        assert_ne!(cid_with, cid_without);
    }

    // `sparse_embed` survives a round trip and re-encodes to the same bytes.
    #[test]
    fn node_sparse_embed_round_trips() {
        let sparse = SparseEmbed::new(vec![1, 5, 9], vec![0.5, 0.2, 0.1], "test-vocab").unwrap();
        let node = Node::new(id_filled(6), "Doc").with_sparse_embed(sparse.clone());
        let first_pass = to_canonical_bytes(&node).expect("encode");
        let restored: Node = from_canonical_bytes(&first_pass).expect("decode");
        assert_eq!(restored.sparse_embed.as_ref(), Some(&sparse));
        let second_pass = to_canonical_bytes(&restored).expect("re-encode");
        assert_eq!(first_pass, second_pass);
    }

    // `context_sentence` survives a round trip and re-encodes to the same bytes.
    #[test]
    fn node_context_sentence_round_trips() {
        let ctx = "This paragraph is from Section 3 of the 2024 lease.";
        let node = Node::new(id_filled(9), "Paragraph")
            .with_summary("The tenant shall maintain the premises...")
            .with_context_sentence(ctx);
        let first_pass = to_canonical_bytes(&node).expect("encode");
        let restored: Node = from_canonical_bytes(&first_pass).expect("decode");
        assert_eq!(restored.context_sentence.as_deref(), Some(ctx));
        let second_pass = to_canonical_bytes(&restored).expect("re-encode");
        assert_eq!(first_pass, second_pass);
    }

    // With no context sentence set, the key must not appear in the encoding.
    #[test]
    fn node_context_sentence_absent_not_emitted() {
        let plain = Node::new(id_filled(10), "Plain");
        let encoded = to_canonical_bytes(&plain).expect("encode");
        assert!(
            !contains_bytes(&encoded, b"context_sentence"),
            "absent context_sentence should not appear on the wire"
        );
    }

    // Adding a context sentence must change the CID.
    #[test]
    fn node_context_sentence_participates_in_cid() {
        let base = Node::new(id_filled(11), "P").with_summary("x");
        let with_ctx = base.clone().with_context_sentence("cue");
        let (_, cid_base) = hash_to_cid(&base).unwrap();
        let (_, cid_ctx) = hash_to_cid(&with_ctx).unwrap();
        assert_ne!(
            cid_base, cid_ctx,
            "context_sentence must participate in the CID"
        );
    }

    // With no sparse embedding set, the key must not appear in the encoding.
    #[test]
    fn node_sparse_embed_absent_not_emitted() {
        let plain = Node::new(id_filled(7), "Thing");
        let encoded = to_canonical_bytes(&plain).expect("encode");
        assert!(
            !contains_bytes(&encoded, b"sparse_embed"),
            "absent sparse_embed should not appear on the wire"
        );
    }

    // Adding a sparse embedding must change the CID.
    #[test]
    fn node_sparse_embed_participates_in_cid() {
        let sparse = SparseEmbed::new(vec![1], vec![1.0], "v").unwrap();
        let with_sparse = Node::new(id_filled(8), "Doc").with_sparse_embed(sparse);
        let without_sparse = Node::new(id_filled(8), "Doc");
        let (_, cid_with) = hash_to_cid(&with_sparse).unwrap();
        let (_, cid_without) = hash_to_cid(&without_sparse).unwrap();
        assert_ne!(cid_with, cid_without);
    }

    // With no summary set, the key must not appear in the encoding.
    #[test]
    fn node_summary_absent_not_emitted() {
        let plain = Node::new(id_filled(4), "Thing");
        let encoded = to_canonical_bytes(&plain).expect("encode");
        assert!(
            !contains_bytes(&encoded, b"summary"),
            "absent summary should not appear on the wire"
        );
    }

    // `validate` accepts dim * byte_width bytes and reports both sizes otherwise.
    #[test]
    fn embedding_validate_ok_and_err() {
        // Four f32 elements need 16 bytes; only the buffer length varies.
        let four_f32_with = |byte_len: usize| Embedding {
            model: "m".into(),
            dtype: Dtype::F32,
            dim: 4,
            vector: Bytes::from(vec![0u8; byte_len]),
        };

        four_f32_with(16).validate().unwrap();

        match four_f32_with(10).validate().unwrap_err() {
            ObjectError::EmbeddingSizeMismatch { expected, got } => {
                assert_eq!(expected, 16);
                assert_eq!(got, 10);
            }
            e => panic!("wrong variant: {e:?}"),
        }
    }
}