1use std::collections::BTreeMap;
28
29use bytes::Bytes;
30use ipld_core::ipld::Ipld;
31use serde::{Deserialize, Deserializer, Serialize, Serializer};
32
33use crate::error::ObjectError;
34use crate::id::NodeId;
35
/// Element type of the values packed into an [`Embedding`] vector.
///
/// serde encodes each variant under its lowercase name
/// (`"f16"`, `"f32"`, `"f64"`, `"i8"`) because of `rename_all = "lowercase"`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[derive(Default)]
pub enum Dtype {
    /// Two bytes per element.
    F16,
    /// Four bytes per element. This is the `Default` variant.
    #[default]
    F32,
    /// Eight bytes per element.
    F64,
    /// One byte per element.
    I8,
}
53
54impl Dtype {
55 #[must_use]
57 pub const fn byte_width(self) -> usize {
58 match self {
59 Self::F16 => 2,
60 Self::F32 => 4,
61 Self::F64 => 8,
62 Self::I8 => 1,
63 }
64 }
65}
66
/// An embedding vector stored as raw bytes together with the metadata needed
/// to interpret its length.
///
/// Construction does not check that `vector` matches `dim` and `dtype`;
/// call [`Embedding::validate`] for that.
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct Embedding {
    /// Name of the model that produced the vector
    /// (presumably a provider-qualified identifier — confirm with callers).
    pub model: String,
    /// Element type of `vector`. Falls back to `Dtype::default()` when the
    /// field is missing during deserialization.
    #[serde(default)]
    pub dtype: Dtype,
    /// Number of elements in the vector.
    pub dim: u32,
    /// Packed element bytes; `validate` expects exactly
    /// `dim * dtype.byte_width()` bytes.
    pub vector: Bytes,
}
87
88impl Embedding {
89 pub const fn validate(&self) -> Result<(), ObjectError> {
96 let expected = (self.dim as usize) * self.dtype.byte_width();
97 if self.vector.len() == expected {
98 Ok(())
99 } else {
100 Err(ObjectError::EmbeddingSizeMismatch {
101 expected,
102 got: self.vector.len(),
103 })
104 }
105 }
106}
107
/// A typed graph node with free-form properties.
///
/// `Serialize` and `Deserialize` are implemented by hand in this file through
/// the private `NodeWire` struct rather than derived here.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Node {
    /// Identifier of this node.
    pub id: NodeId,
    /// Type label of the node (for example `"Person"`); see
    /// `Node::DEFAULT_NTYPE` for the label used by `Node::new_default`.
    pub ntype: String,
    /// Optional summary text; left off the wire when `None`.
    pub summary: Option<String>,
    /// Named property values, kept in key order by the `BTreeMap`.
    pub props: BTreeMap<String, Ipld>,
    /// Optional raw content bytes; left off the wire when `None`.
    pub content: Option<Bytes>,
    /// Optional context sentence; left off the wire when `None`.
    pub context_sentence: Option<String>,
    /// Wire fields that `NodeWire` does not name. They are collected here on
    /// decode and written back on encode.
    pub extra: BTreeMap<String, Ipld>,
}
160
161impl Node {
162 pub const KIND: &'static str = "node";
164
165 pub const DEFAULT_NTYPE: &'static str = "Node";
171
172 #[must_use]
174 pub fn new(id: NodeId, ntype: impl Into<String>) -> Self {
175 Self {
176 id,
177 ntype: ntype.into(),
178 summary: None,
179 props: BTreeMap::new(),
180 content: None,
181 context_sentence: None,
182 extra: BTreeMap::new(),
183 }
184 }
185
186 #[must_use]
190 pub fn new_default(id: NodeId) -> Self {
191 Self::new(id, Self::DEFAULT_NTYPE)
192 }
193
194 #[must_use]
196 pub fn with_summary(mut self, summary: impl Into<String>) -> Self {
197 self.summary = Some(summary.into());
198 self
199 }
200
201 #[must_use]
203 pub fn with_prop(mut self, key: impl Into<String>, value: impl Into<Ipld>) -> Self {
204 self.props.insert(key.into(), value.into());
205 self
206 }
207
208 #[must_use]
210 pub fn with_content(mut self, content: Bytes) -> Self {
211 self.content = Some(content);
212 self
213 }
214
215 #[must_use]
226 pub fn with_context_sentence(mut self, context: impl Into<String>) -> Self {
227 self.context_sentence = Some(context.into());
228 self
229 }
230
231 #[must_use]
240 pub fn get_str(&self, key: &str) -> Option<&str> {
241 match self.props.get(key)? {
242 Ipld::String(s) => Some(s.as_str()),
243 _ => None,
244 }
245 }
246
247 #[must_use]
249 pub fn get_int(&self, key: &str) -> Option<i128> {
250 match self.props.get(key)? {
251 Ipld::Integer(n) => Some(*n),
252 _ => None,
253 }
254 }
255
256 #[must_use]
258 pub fn get_bool(&self, key: &str) -> Option<bool> {
259 match self.props.get(key)? {
260 Ipld::Bool(b) => Some(*b),
261 _ => None,
262 }
263 }
264
265 #[must_use]
267 pub fn get_float(&self, key: &str) -> Option<f64> {
268 match self.props.get(key)? {
269 Ipld::Float(f) => Some(*f),
270 _ => None,
271 }
272 }
273
274 #[must_use]
276 pub fn get_bytes(&self, key: &str) -> Option<&[u8]> {
277 match self.props.get(key)? {
278 Ipld::Bytes(b) => Some(b.as_slice()),
279 _ => None,
280 }
281 }
282}
283
/// Serde-facing mirror of `Node` that defines the encoded field layout.
///
/// `Node`'s hand-written `Serialize` / `Deserialize` impls convert to and
/// from this struct. Do not rename or reorder fields casually: tests in this
/// file pin the exact encoded bytes.
#[derive(Serialize, Deserialize)]
struct NodeWire {
    /// Object discriminator, encoded under the key `_kind`.
    /// `Node` decoding rejects any value other than `Node::KIND`.
    #[serde(rename = "_kind")]
    kind: String,
    id: NodeId,
    ntype: String,
    /// Omitted from the output when `None`; decodes to `None` when missing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    summary: Option<String>,
    /// Has no `default` and no skip attribute: always written, even when
    /// empty, and required on decode.
    props: BTreeMap<String, Ipld>,
    /// Omitted from the output when `None`; decodes to `None` when missing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    content: Option<Bytes>,
    /// Omitted from the output when `None`; decodes to `None` when missing.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    context_sentence: Option<String>,
    /// Flattened into the top-level map: keys not matched by the fields above
    /// are collected here on decode and written back as siblings on encode.
    #[serde(flatten, default, skip_serializing_if = "BTreeMap::is_empty")]
    extra: BTreeMap<String, Ipld>,
}
311
312impl Serialize for Node {
313 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
314 NodeWire {
315 kind: Self::KIND.into(),
316 id: self.id,
317 ntype: self.ntype.clone(),
318 summary: self.summary.clone(),
319 props: self.props.clone(),
320 content: self.content.clone(),
321 context_sentence: self.context_sentence.clone(),
322 extra: self.extra.clone(),
323 }
324 .serialize(serializer)
325 }
326}
327
328impl<'de> Deserialize<'de> for Node {
329 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
330 let wire = NodeWire::deserialize(deserializer)?;
331 if wire.kind != Self::KIND {
332 return Err(serde::de::Error::custom(format!(
333 "expected _kind='{}', got '{}'",
334 Self::KIND,
335 wire.kind
336 )));
337 }
338 Ok(Self {
339 id: wire.id,
340 ntype: wire.ntype,
341 summary: wire.summary,
342 props: wire.props,
343 content: wire.content,
344 context_sentence: wire.context_sentence,
345 extra: wire.extra,
346 })
347 }
348}
349
/// Unit tests for `Node` encoding, CID stability, forward-compatible `extra`
/// handling, and `Embedding::validate`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::{from_canonical_bytes, hash_to_cid, to_canonical_bytes};

    /// Fixture: a `Person` node with a fixed id and two props (`name`, `age`).
    fn alice() -> Node {
        Node::new(NodeId::from_bytes_raw([1u8; 16]), "Person")
            .with_prop("name", Ipld::String("Alice".into()))
            .with_prop("age", Ipld::Integer(30))
    }

    /// Encode -> decode yields an equal node, and re-encoding yields the
    /// same bytes.
    #[test]
    fn node_round_trip_byte_identity() {
        let original = alice();
        let bytes = to_canonical_bytes(&original).expect("encode");
        let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
        assert_eq!(original, decoded);
        let bytes2 = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes, bytes2);
    }

    /// Two independently built but equal nodes hash to the same CID.
    #[test]
    fn node_cid_is_deterministic() {
        let a1 = alice();
        let a2 = alice();
        let (_, c1) = hash_to_cid(&a1).expect("hash");
        let (_, c2) = hash_to_cid(&a2).expect("hash");
        assert_eq!(c1, c2);
    }

    /// `new_default` labels the node with `DEFAULT_NTYPE`, and that constant
    /// is pinned to the literal `"Node"`.
    #[test]
    fn new_default_uses_default_ntype() {
        let n = Node::new_default(NodeId::from_bytes_raw([7u8; 16]));
        assert_eq!(n.ntype, Node::DEFAULT_NTYPE);
        assert_eq!(n.ntype, "Node");
    }

    /// `new_default(id)` and `new(id, DEFAULT_NTYPE)` produce the same CID.
    #[test]
    fn new_default_and_explicit_new_match_when_ntype_equal() {
        let id = NodeId::from_bytes_raw([9u8; 16]);
        let default_node = Node::new_default(id);
        let explicit_node = Node::new(id, Node::DEFAULT_NTYPE);
        let (_, c_default) = hash_to_cid(&default_node).expect("hash default");
        let (_, c_explicit) = hash_to_cid(&explicit_node).expect("hash explicit");
        assert_eq!(c_default, c_explicit);
    }

    /// Decoding as `Node` fails when `_kind` is not `"node"`, and the error
    /// message mentions `_kind`.
    #[test]
    fn node_kind_rejection() {
        let wire = NodeWire {
            // Deliberately wrong discriminator.
            kind: "edge".into(),
            id: NodeId::from_bytes_raw([1u8; 16]),
            ntype: "x".into(),
            summary: None,
            props: BTreeMap::new(),
            content: None,
            context_sentence: None,
            extra: BTreeMap::new(),
        };
        let bytes = serde_ipld_dagcbor::to_vec(&wire).expect("encode wire");
        let err = serde_ipld_dagcbor::from_slice::<Node>(&bytes).unwrap_err();
        assert!(
            err.to_string().contains("_kind"),
            "expected _kind rejection, got: {err}"
        );
    }

    /// A wire field unknown to `NodeWire` lands in `Node::extra` on decode
    /// and is re-emitted so the output bytes equal the input bytes.
    #[test]
    fn node_extra_fields_round_trip() {
        let mut wire = NodeWire {
            kind: "node".into(),
            id: NodeId::from_bytes_raw([2u8; 16]),
            ntype: "Future".into(),
            summary: None,
            props: BTreeMap::new(),
            content: None,
            context_sentence: None,
            extra: BTreeMap::new(),
        };
        // Stands in for a field written by some other schema version.
        wire.extra.insert(
            "x-future-field".into(),
            Ipld::String("value-from-v99".into()),
        );
        let bytes_in = serde_ipld_dagcbor::to_vec(&wire).expect("encode");

        let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode");
        assert_eq!(
            decoded.extra.get("x-future-field"),
            Some(&Ipld::String("value-from-v99".into())),
        );

        let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes_in, bytes_out);
    }

    /// A wire layout carrying a top-level `embed` field (which `NodeWire`
    /// does not name) decodes with `embed` in `extra`, re-encodes to
    /// identical bytes, and therefore keeps the same CID.
    #[test]
    fn legacy_embed_field_round_trips_through_extra() {
        /// Local stand-in for the older wire layout: same fields as
        /// `NodeWire` minus `context_sentence` and `extra`, plus `embed`.
        #[derive(Serialize)]
        struct LegacyNodeWire {
            #[serde(rename = "_kind")]
            kind: String,
            id: NodeId,
            ntype: String,
            #[serde(skip_serializing_if = "Option::is_none")]
            summary: Option<String>,
            props: BTreeMap<String, Ipld>,
            #[serde(skip_serializing_if = "Option::is_none")]
            content: Option<Bytes>,
            embed: Embedding,
        }

        let legacy = LegacyNodeWire {
            kind: "node".into(),
            id: NodeId::from_bytes_raw([42u8; 16]),
            ntype: "Doc".into(),
            summary: None,
            props: BTreeMap::new(),
            content: None,
            embed: Embedding {
                model: "openai:text-embedding-3-small".into(),
                dtype: Dtype::F32,
                dim: 2,
                // 8 bytes = dim 2 x 4-byte F32 elements
                // (reads as 1.0, 0.0 if interpreted as little-endian f32).
                vector: Bytes::from(vec![
                    0x00, 0x00, 0x80, 0x3f, 0x00, 0x00, 0x00, 0x00,
                ]),
            },
        };
        let bytes_in = serde_ipld_dagcbor::to_vec(&legacy).expect("encode legacy");

        let decoded: Node = serde_ipld_dagcbor::from_slice(&bytes_in).expect("decode legacy");
        assert!(
            decoded.extra.contains_key("embed"),
            "legacy embed must land in extra"
        );

        let bytes_out = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes_in, bytes_out, "legacy bytes must round-trip exactly");

        // `hash_to_cid` returns the encoded bytes alongside the CID; both
        // must agree with the legacy input.
        let (bytes_from_node, cid_from_node) = hash_to_cid(&decoded).expect("hash node");
        assert_eq!(
            bytes_in.as_slice(),
            bytes_from_node.as_ref(),
            "a future version re-encode must match legacy bytes byte-for-byte"
        );
        // Recompute the CID directly from the legacy bytes for comparison.
        let cid_via_legacy_bytes = {
            let mh = crate::id::Multihash::sha2_256(&bytes_in);
            crate::id::Cid::new(crate::id::CODEC_DAG_CBOR, mh)
        };
        assert_eq!(
            cid_from_node, cid_via_legacy_bytes,
            "NodeCid via a future version reader must equal NodeCid via legacy bytes"
        );
    }

    /// `summary` survives a round trip and changes the CID relative to an
    /// otherwise identical node without one.
    #[test]
    fn node_round_trip_with_summary() {
        let n = Node::new(NodeId::from_bytes_raw([3u8; 16]), "Person")
            .with_summary("Alice, 30, based in Berlin.")
            .with_prop("name", Ipld::String("Alice".into()));
        let bytes = to_canonical_bytes(&n).expect("encode");
        let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
        assert_eq!(
            decoded.summary.as_deref(),
            Some("Alice, 30, based in Berlin.")
        );
        assert_eq!(n, decoded);

        // Same id, type and props, but no summary.
        let bare = Node::new(NodeId::from_bytes_raw([3u8; 16]), "Person")
            .with_prop("name", Ipld::String("Alice".into()));
        let (_, c_with) = hash_to_cid(&n).expect("hash");
        let (_, c_without) = hash_to_cid(&bare).expect("hash");
        assert_ne!(c_with, c_without);
    }

    /// `context_sentence` survives a round trip and re-encodes to the same
    /// bytes.
    #[test]
    fn node_context_sentence_round_trips() {
        let ctx = "This paragraph is from Section 3 of the 2024 lease.";
        let n = Node::new(NodeId::from_bytes_raw([9u8; 16]), "Paragraph")
            .with_summary("The tenant shall maintain the premises...")
            .with_context_sentence(ctx);
        let bytes = to_canonical_bytes(&n).expect("encode");
        let decoded: Node = from_canonical_bytes(&bytes).expect("decode");
        assert_eq!(decoded.context_sentence.as_deref(), Some(ctx));
        let bytes2 = to_canonical_bytes(&decoded).expect("re-encode");
        assert_eq!(bytes, bytes2);
    }

    /// When `context_sentence` is `None`, its key does not appear anywhere
    /// in the encoded bytes.
    #[test]
    fn node_context_sentence_absent_not_emitted() {
        let n = Node::new(NodeId::from_bytes_raw([10u8; 16]), "Plain");
        let bytes = to_canonical_bytes(&n).expect("encode");
        // Window size 16 == byte length of "context_sentence".
        assert!(
            !bytes.windows(16).any(|w| w == b"context_sentence"),
            "absent context_sentence should not appear on the wire"
        );
    }

    /// Adding a `context_sentence` changes the node's CID.
    #[test]
    fn node_context_sentence_participates_in_cid() {
        let base = Node::new(NodeId::from_bytes_raw([11u8; 16]), "P").with_summary("x");
        let with_ctx = base.clone().with_context_sentence("cue");
        let (_, c1) = hash_to_cid(&base).unwrap();
        let (_, c2) = hash_to_cid(&with_ctx).unwrap();
        assert_ne!(c1, c2, "context_sentence must participate in the CID");
    }

    /// When `summary` is `None`, its key does not appear anywhere in the
    /// encoded bytes.
    #[test]
    fn node_summary_absent_not_emitted() {
        let n = Node::new(NodeId::from_bytes_raw([4u8; 16]), "Thing");
        let bytes = to_canonical_bytes(&n).expect("encode");
        // Window size 7 == byte length of "summary".
        assert!(
            !bytes.windows(7).any(|w| w == b"summary"),
            "absent summary should not appear on the wire"
        );
    }

    /// `validate` accepts a vector of `dim * byte_width` bytes and reports
    /// both lengths when the size is wrong.
    #[test]
    fn embedding_validate_ok_and_err() {
        // 4 elements x 4 bytes (F32) = 16 bytes: valid.
        let ok = Embedding {
            model: "m".into(),
            dtype: Dtype::F32,
            dim: 4,
            vector: Bytes::from(vec![0u8; 16]),
        };
        ok.validate().unwrap();

        // Same shape but only 10 bytes supplied: invalid.
        let bad = Embedding {
            model: "m".into(),
            dtype: Dtype::F32,
            dim: 4,
            vector: Bytes::from(vec![0u8; 10]),
        };
        let err = bad.validate().unwrap_err();
        match err {
            ObjectError::EmbeddingSizeMismatch { expected, got } => {
                assert_eq!(expected, 16);
                assert_eq!(got, 10);
            }
            e => panic!("wrong variant: {e:?}"),
        }
    }
}