1use petgraph::graph::{DiGraph, NodeIndex};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10
11#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
35pub enum Relationship {
36 #[serde(alias = "prerequisite")]
38 Prerequisite,
39 #[serde(alias = "leads_to")]
41 LeadsTo,
42 #[default]
44 #[serde(alias = "relates_to")]
45 RelatesTo,
46 #[serde(alias = "extends")]
48 Extends,
49 #[serde(alias = "introduces")]
51 Introduces,
52 #[serde(alias = "covers")]
54 Covers,
55 #[serde(alias = "variant_of")]
57 VariantOf,
58 #[serde(alias = "contrasts_with")]
60 ContrastsWith,
61 #[serde(alias = "answers_question")]
63 AnswersQuestion,
64 Custom(String),
66}
67
68impl Relationship {
69 pub fn default_weight(&self) -> f32 {
74 match self {
75 Self::Prerequisite => 1.0,
76 Self::LeadsTo => 1.0,
77 Self::Extends => 0.9,
78 Self::Introduces => 0.8,
79 Self::Covers => 0.8,
80 Self::VariantOf => 0.9,
81 Self::ContrastsWith => 0.7,
82 Self::AnswersQuestion => 0.6,
83 Self::RelatesTo => 0.7,
84 Self::Custom(_) => 0.5,
85 }
86 }
87
88 pub fn name(&self) -> &str {
90 match self {
91 Self::Prerequisite => "prerequisite",
92 Self::LeadsTo => "leads_to",
93 Self::RelatesTo => "relates_to",
94 Self::Extends => "extends",
95 Self::Introduces => "introduces",
96 Self::Covers => "covers",
97 Self::VariantOf => "variant_of",
98 Self::ContrastsWith => "contrasts_with",
99 Self::AnswersQuestion => "answers_question",
100 Self::Custom(name) => name,
101 }
102 }
103}
104
105#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
113pub enum EdgeOrigin {
114 #[default]
116 #[serde(alias = "extracted", alias = "frontmatter")]
117 Frontmatter,
118 #[serde(alias = "content_body")]
120 ContentBody,
121 #[serde(alias = "manual")]
123 Manual,
124 #[serde(alias = "inferred")]
126 Inferred,
127}
128
129#[derive(Clone, Debug, Default, PartialEq, Eq, Hash, Serialize, Deserialize)]
140#[serde(rename_all = "snake_case")]
141pub enum NodeType {
142 #[default]
144 Domain,
145 UserQuery,
147 Custom(String),
149}
150
151#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
160pub struct Node {
161 pub id: String,
163 pub title: String,
165 pub category: Option<String>,
167 pub source_id: Option<String>,
169 #[serde(default = "default_is_canonical")]
171 pub is_canonical: bool,
172 pub canonical_id: Option<String>,
174 #[serde(default)]
176 pub node_type: NodeType,
177 #[serde(default)]
179 pub metadata: HashMap<String, serde_json::Value>,
180}
181
/// Serde default for [`Node::is_canonical`]: nodes are canonical unless stated otherwise.
fn default_is_canonical() -> bool { true }
185
186impl Node {
187 pub fn new(id: impl Into<String>, title: impl Into<String>) -> Self {
189 Self {
190 id: id.into(),
191 title: title.into(),
192 category: None,
193 source_id: None,
194 is_canonical: true,
195 canonical_id: None,
196 node_type: NodeType::default(),
197 metadata: HashMap::new(),
198 }
199 }
200
201 pub fn with_category(mut self, category: impl Into<String>) -> Self {
203 self.category = Some(category.into());
204 self
205 }
206
207 pub fn with_source(mut self, source_id: impl Into<String>) -> Self {
209 self.source_id = Some(source_id.into());
210 self
211 }
212
213 pub fn as_variant_of(mut self, canonical_id: impl Into<String>) -> Self {
215 self.is_canonical = false;
216 self.canonical_id = Some(canonical_id.into());
217 self
218 }
219
220 pub fn with_node_type(mut self, node_type: NodeType) -> Self {
222 self.node_type = node_type;
223 self
224 }
225
226 pub fn with_metadata(
228 mut self,
229 key: impl Into<String>,
230 value: impl Into<serde_json::Value>,
231 ) -> Self {
232 self.metadata.insert(key.into(), value.into());
233 self
234 }
235}
236
237#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
243pub struct Edge {
244 pub from: String,
246 pub to: String,
248 pub relationship: Relationship,
250 pub weight: f32,
252 pub origin: EdgeOrigin,
254}
255
256impl Edge {
257 pub fn new(from: impl Into<String>, to: impl Into<String>, relationship: Relationship) -> Self {
259 let weight = relationship.default_weight();
260 Self {
261 from: from.into(),
262 to: to.into(),
263 relationship,
264 weight,
265 origin: EdgeOrigin::default(),
266 }
267 }
268
269 pub fn with_weight(mut self, weight: f32) -> Self {
271 self.weight = weight;
272 self
273 }
274
275 pub fn with_origin(mut self, origin: EdgeOrigin) -> Self {
277 self.origin = origin;
278 self
279 }
280}
281
282#[derive(Clone, Debug)]
291pub struct GraphData {
292 pub graph: DiGraph<Node, Edge>,
294 pub node_indices: HashMap<String, NodeIndex>,
296 pub nodes: HashMap<String, Node>,
298 pub edges: Vec<Edge>,
300}
301
302impl GraphData {
303 pub fn new() -> Self {
305 Self {
306 graph: DiGraph::new(),
307 node_indices: HashMap::new(),
308 nodes: HashMap::new(),
309 edges: Vec::new(),
310 }
311 }
312
313 pub fn node_count(&self) -> usize {
315 self.graph.node_count()
316 }
317
318 pub fn edge_count(&self) -> usize {
320 self.graph.edge_count()
321 }
322
323 pub fn get_node(&self, id: &str) -> Option<&Node> {
325 self.nodes.get(id)
326 }
327
328 pub fn get_index(&self, id: &str) -> Option<NodeIndex> {
330 self.node_indices.get(id).copied()
331 }
332
333 pub fn contains_node(&self, id: &str) -> bool {
335 self.nodes.contains_key(id)
336 }
337
338 pub fn node_ids(&self) -> impl Iterator<Item = &str> {
340 self.nodes.keys().map(String::as_str)
341 }
342
343 pub fn iter_nodes(&self) -> impl Iterator<Item = &Node> {
345 self.nodes.values()
346 }
347
348 pub fn iter_edges(&self) -> impl Iterator<Item = &Edge> {
350 self.edges.iter()
351 }
352
353 pub fn add_node(&mut self, node: Node) -> NodeIndex {
362 if let Some(&existing_idx) = self.node_indices.get(&node.id) {
363 return existing_idx;
364 }
365 let id = node.id.clone();
366 let idx = self.graph.add_node(node.clone());
367 self.node_indices.insert(id.clone(), idx);
368 self.nodes.insert(id, node);
369 idx
370 }
371
372 pub fn add_edge(&mut self, edge: Edge) -> fabryk_core::Result<()> {
377 let from_idx = self
378 .node_indices
379 .get(&edge.from)
380 .copied()
381 .ok_or_else(|| fabryk_core::Error::not_found("node", &edge.from))?;
382 let to_idx = self
383 .node_indices
384 .get(&edge.to)
385 .copied()
386 .ok_or_else(|| fabryk_core::Error::not_found("node", &edge.to))?;
387
388 self.graph.add_edge(from_idx, to_idx, edge.clone());
389 self.edges.push(edge);
390 Ok(())
391 }
392
393 pub fn remove_node(&mut self, id: &str) -> Option<Node> {
397 let idx = self.node_indices.remove(id)?;
398 let node = self.nodes.remove(id)?;
399
400 self.graph.remove_node(idx);
402
403 self.edges.retain(|e| e.from != id && e.to != id);
405
406 self.node_indices.clear();
409 for ni in self.graph.node_indices() {
410 let n = &self.graph[ni];
411 self.node_indices.insert(n.id.clone(), ni);
412 }
413
414 Some(node)
415 }
416}
417
418impl Default for GraphData {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Every built-in (non-`Custom`) relationship variant.
    fn builtin_relationships() -> Vec<Relationship> {
        vec![
            Relationship::Prerequisite,
            Relationship::LeadsTo,
            Relationship::RelatesTo,
            Relationship::Extends,
            Relationship::Introduces,
            Relationship::Covers,
            Relationship::VariantOf,
            Relationship::ContrastsWith,
            Relationship::AnswersQuestion,
        ]
    }

    // ------------------------------------------------------------------------
    // Relationship
    // ------------------------------------------------------------------------

    #[test]
    fn test_relationship_default_weights() {
        assert_eq!(Relationship::Prerequisite.default_weight(), 1.0);
        assert_eq!(Relationship::LeadsTo.default_weight(), 1.0);
        assert_eq!(Relationship::Extends.default_weight(), 0.9);
        assert_eq!(Relationship::Introduces.default_weight(), 0.8);
        assert_eq!(Relationship::Covers.default_weight(), 0.8);
        assert_eq!(Relationship::VariantOf.default_weight(), 0.9);
        assert_eq!(Relationship::ContrastsWith.default_weight(), 0.7);
        assert_eq!(Relationship::AnswersQuestion.default_weight(), 0.6);
        assert_eq!(Relationship::RelatesTo.default_weight(), 0.7);
        assert_eq!(
            Relationship::Custom("custom".to_string()).default_weight(),
            0.5
        );
    }

    #[test]
    fn test_relationship_names() {
        assert_eq!(Relationship::Prerequisite.name(), "prerequisite");
        assert_eq!(Relationship::LeadsTo.name(), "leads_to");
        assert_eq!(Relationship::RelatesTo.name(), "relates_to");
        assert_eq!(Relationship::Extends.name(), "extends");
        assert_eq!(Relationship::Introduces.name(), "introduces");
        assert_eq!(Relationship::Covers.name(), "covers");
        assert_eq!(Relationship::VariantOf.name(), "variant_of");
        assert_eq!(Relationship::ContrastsWith.name(), "contrasts_with");
        assert_eq!(Relationship::AnswersQuestion.name(), "answers_question");
        assert_eq!(
            Relationship::Custom("implies".to_string()).name(),
            "implies"
        );
    }

    #[test]
    fn test_relationship_default() {
        assert_eq!(Relationship::default(), Relationship::RelatesTo);
    }

    #[test]
    fn test_relationship_serialization() {
        let rel = Relationship::Custom("implies".to_string());
        let json = serde_json::to_string(&rel).unwrap();
        assert!(json.contains("implies"));

        let parsed: Relationship = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, rel);
    }

    #[test]
    fn test_relationship_all_variants_serialize() {
        let mut variants = builtin_relationships();
        variants.push(Relationship::Custom("test".to_string()));

        for rel in variants {
            let json = serde_json::to_string(&rel).unwrap();
            let parsed: Relationship = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, rel);
        }
    }

    #[test]
    fn test_relationship_snake_case_names_deserialize() {
        // Every built-in `name()` is also accepted as a serde alias.
        for rel in builtin_relationships() {
            let json = format!("\"{}\"", rel.name());
            let parsed: Relationship = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, rel, "alias {json} did not round-trip");
        }
    }

    // ------------------------------------------------------------------------
    // EdgeOrigin
    // ------------------------------------------------------------------------

    #[test]
    fn test_edge_origin_default() {
        assert_eq!(EdgeOrigin::default(), EdgeOrigin::Frontmatter);
    }

    #[test]
    fn test_edge_origin_serialization() {
        let origins = vec![
            EdgeOrigin::Frontmatter,
            EdgeOrigin::ContentBody,
            EdgeOrigin::Manual,
            EdgeOrigin::Inferred,
        ];

        for origin in origins {
            let json = serde_json::to_string(&origin).unwrap();
            let parsed: EdgeOrigin = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, origin);
        }
    }

    #[test]
    fn test_edge_origin_aliases() {
        let cases = [
            ("\"extracted\"", EdgeOrigin::Frontmatter),
            ("\"frontmatter\"", EdgeOrigin::Frontmatter),
            ("\"content_body\"", EdgeOrigin::ContentBody),
            ("\"manual\"", EdgeOrigin::Manual),
            ("\"inferred\"", EdgeOrigin::Inferred),
        ];

        for (json, expected) in cases {
            let parsed: EdgeOrigin = serde_json::from_str(json).unwrap();
            assert_eq!(parsed, expected);
        }
    }

    // ------------------------------------------------------------------------
    // NodeType
    // ------------------------------------------------------------------------

    #[test]
    fn test_node_type_default() {
        assert_eq!(NodeType::default(), NodeType::Domain);
    }

    #[test]
    fn test_node_type_serialization() {
        let types = vec![
            NodeType::Domain,
            NodeType::UserQuery,
            NodeType::Custom("special".to_string()),
        ];

        for nt in types {
            let json = serde_json::to_string(&nt).unwrap();
            let parsed: NodeType = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, nt);
        }
    }

    #[test]
    fn test_node_type_rename_all() {
        let json = serde_json::to_string(&NodeType::UserQuery).unwrap();
        assert_eq!(json, "\"user_query\"");

        let json = serde_json::to_string(&NodeType::Domain).unwrap();
        assert_eq!(json, "\"domain\"");
    }

    // ------------------------------------------------------------------------
    // Node
    // ------------------------------------------------------------------------

    #[test]
    fn test_node_builder() {
        let node = Node::new("test-id", "Test Title")
            .with_category("test-category")
            .with_source("test-source")
            .with_metadata("key", "value");

        assert_eq!(node.id, "test-id");
        assert_eq!(node.title, "Test Title");
        assert_eq!(node.category, Some("test-category".to_string()));
        assert_eq!(node.source_id, Some("test-source".to_string()));
        assert!(node.is_canonical);
        assert!(node.canonical_id.is_none());
        assert_eq!(node.node_type, NodeType::Domain);
        assert!(node.metadata.contains_key("key"));
    }

    #[test]
    fn test_node_variant() {
        let variant =
            Node::new("source-concept", "Source Concept").as_variant_of("canonical-concept");

        assert!(!variant.is_canonical);
        assert_eq!(variant.canonical_id, Some("canonical-concept".to_string()));
    }

    #[test]
    fn test_node_with_node_type() {
        let node = Node::new("query-1", "User Query").with_node_type(NodeType::UserQuery);

        assert_eq!(node.node_type, NodeType::UserQuery);
    }

    #[test]
    fn test_node_serialization() {
        let node = Node::new("test", "Test")
            .with_category("cat")
            .with_node_type(NodeType::UserQuery)
            .with_metadata("foo", "bar");

        let json = serde_json::to_string(&node).unwrap();
        let parsed: Node = serde_json::from_str(&json).unwrap();

        // Every field, including metadata and canonical flags, must round-trip.
        assert_eq!(parsed, node);
    }

    #[test]
    fn test_node_deserialization_defaults() {
        let parsed: Node = serde_json::from_str(r#"{"id":"x","title":"X"}"#).unwrap();

        assert_eq!(parsed, Node::new("x", "X"));
        assert!(parsed.is_canonical);
        assert_eq!(parsed.node_type, NodeType::Domain);
        assert!(parsed.metadata.is_empty());
    }

    // ------------------------------------------------------------------------
    // Edge
    // ------------------------------------------------------------------------

    #[test]
    fn test_edge_builder() {
        let edge = Edge::new("a", "b", Relationship::Prerequisite)
            .with_weight(0.8)
            .with_origin(EdgeOrigin::Manual);

        assert_eq!(edge.from, "a");
        assert_eq!(edge.to, "b");
        assert_eq!(edge.weight, 0.8);
        assert_eq!(edge.origin, EdgeOrigin::Manual);
    }

    #[test]
    fn test_edge_default_weight() {
        let edge = Edge::new("a", "b", Relationship::Prerequisite);
        assert_eq!(edge.weight, 1.0);

        let edge2 = Edge::new("a", "b", Relationship::RelatesTo);
        assert_eq!(edge2.weight, 0.7);
    }

    #[test]
    fn test_edge_default_origin() {
        let edge = Edge::new("a", "b", Relationship::Prerequisite);
        assert_eq!(edge.origin, EdgeOrigin::Frontmatter);
    }

    #[test]
    fn test_edge_serialization() {
        let edge = Edge::new("a", "b", Relationship::LeadsTo)
            .with_weight(0.5)
            .with_origin(EdgeOrigin::Manual);

        let json = serde_json::to_string(&edge).unwrap();
        let parsed: Edge = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.from, edge.from);
        assert_eq!(parsed.to, edge.to);
        assert_eq!(parsed.relationship, edge.relationship);
        assert_eq!(parsed.weight, edge.weight);
        assert_eq!(parsed.origin, edge.origin);
    }

    // ------------------------------------------------------------------------
    // GraphData
    // ------------------------------------------------------------------------

    #[test]
    fn test_graph_data_new() {
        let graph = GraphData::new();
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
        assert!(!graph.contains_node("test"));
    }

    #[test]
    fn test_graph_data_default() {
        let graph = GraphData::default();
        assert_eq!(graph.node_count(), 0);
    }

    #[test]
    fn test_graph_data_iterators_empty() {
        let graph = GraphData::new();
        assert_eq!(graph.node_ids().count(), 0);
        assert_eq!(graph.iter_nodes().count(), 0);
        assert_eq!(graph.iter_edges().count(), 0);
    }

    #[test]
    fn test_graph_data_add_node() {
        let mut graph = GraphData::new();
        let node = Node::new("a", "Node A");
        let idx = graph.add_node(node);

        assert_eq!(graph.node_count(), 1);
        assert!(graph.contains_node("a"));
        assert_eq!(graph.get_index("a"), Some(idx));
        assert_eq!(graph.get_node("a").unwrap().title, "Node A");
    }

    #[test]
    fn test_graph_data_add_node_duplicate() {
        let mut graph = GraphData::new();
        let idx1 = graph.add_node(Node::new("a", "Node A"));
        let idx2 = graph.add_node(Node::new("a", "Node A Again"));

        assert_eq!(idx1, idx2);
        assert_eq!(graph.node_count(), 1);
        // The first node wins; the duplicate is discarded.
        assert_eq!(graph.get_node("a").unwrap().title, "Node A");
    }

    #[test]
    fn test_graph_data_add_edge() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));
        graph.add_node(Node::new("b", "Node B"));

        let edge = Edge::new("a", "b", Relationship::Prerequisite);
        graph.add_edge(edge).unwrap();

        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.edges.len(), 1);
    }

    #[test]
    fn test_graph_data_add_edge_missing_from() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("b", "Node B"));

        let edge = Edge::new("missing", "b", Relationship::Prerequisite);
        let result = graph.add_edge(edge);

        assert!(result.is_err());
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_graph_data_add_edge_missing_to() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));

        let edge = Edge::new("a", "missing", Relationship::Prerequisite);
        let result = graph.add_edge(edge);

        assert!(result.is_err());
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_graph_data_remove_node() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));
        graph.add_node(Node::new("b", "Node B"));
        graph.add_node(Node::new("c", "Node C"));
        graph
            .add_edge(Edge::new("a", "b", Relationship::Prerequisite))
            .unwrap();
        graph
            .add_edge(Edge::new("b", "c", Relationship::LeadsTo))
            .unwrap();

        let removed = graph.remove_node("b");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().id, "b");

        assert_eq!(graph.node_count(), 2);
        assert!(!graph.contains_node("b"));
        assert!(graph.contains_node("a"));
        assert!(graph.contains_node("c"));
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.edges.is_empty());
    }

    #[test]
    fn test_graph_data_remove_nonexistent_node() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));

        let removed = graph.remove_node("nonexistent");
        assert!(removed.is_none());
        assert_eq!(graph.node_count(), 1);
    }

    #[test]
    fn test_graph_data_remove_node_preserves_indices() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));
        graph.add_node(Node::new("b", "Node B"));
        graph.add_node(Node::new("c", "Node C"));

        // Removing the first node makes petgraph move the last node ("c") into slot 0.
        graph.remove_node("a");

        assert!(graph.contains_node("b"));
        assert!(graph.contains_node("c"));
        assert!(graph.get_index("a").is_none());
        for id in ["b", "c"] {
            let idx = graph.get_index(id).expect("index should exist");
            assert_eq!(graph.graph[idx].id, id, "stale index for {id}");
        }
    }

    #[test]
    fn test_graph_data_remove_last_node_keeps_other_indices() {
        let mut graph = GraphData::new();
        let a = graph.add_node(Node::new("a", "Node A"));
        let b = graph.add_node(Node::new("b", "Node B"));
        graph.add_node(Node::new("c", "Node C"));

        // Removing the last node relocates nothing.
        graph.remove_node("c");

        assert_eq!(graph.get_index("a"), Some(a));
        assert_eq!(graph.get_index("b"), Some(b));
        assert_eq!(graph.graph[a].id, "a");
        assert_eq!(graph.graph[b].id, "b");
    }

    #[test]
    fn test_graph_data_add_edge_after_remove_uses_current_indices() {
        let mut graph = GraphData::new();
        graph.add_node(Node::new("a", "Node A"));
        graph.add_node(Node::new("b", "Node B"));
        graph.add_node(Node::new("c", "Node C"));

        graph.remove_node("a");
        graph
            .add_edge(Edge::new("c", "b", Relationship::Extends))
            .unwrap();

        let c = graph.get_index("c").unwrap();
        let b = graph.get_index("b").unwrap();
        assert!(graph.graph.contains_edge(c, b));
        assert!(!graph.graph.contains_edge(b, c));
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_graph_data_full_workflow() {
        let mut graph = GraphData::new();

        graph.add_node(Node::new("intervals", "Intervals").with_category("basics"));
        graph.add_node(Node::new("scales", "Scales").with_category("basics"));
        graph.add_node(Node::new("chords", "Chords").with_category("harmony"));

        graph
            .add_edge(Edge::new("intervals", "scales", Relationship::Prerequisite))
            .unwrap();
        graph
            .add_edge(Edge::new("scales", "chords", Relationship::LeadsTo))
            .unwrap();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);

        let intervals = graph.get_node("intervals").unwrap();
        assert_eq!(intervals.category, Some("basics".to_string()));

        graph.add_node(Node::new("query-1", "User Query").with_node_type(NodeType::UserQuery));
        graph
            .add_edge(Edge::new(
                "query-1",
                "chords",
                Relationship::Custom("queries_about".to_string()),
            ))
            .unwrap();

        assert_eq!(graph.node_count(), 4);
        assert_eq!(graph.edge_count(), 3);

        graph.remove_node("query-1");
        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);
    }
}