1use super::entity::{Entity, Relation};
84use serde::{Deserialize, Serialize};
85use std::collections::HashMap;
86
87#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct GraphNode {
94 pub id: String,
96 pub node_type: String,
98 pub name: String,
100 #[serde(default)]
102 pub properties: HashMap<String, serde_json::Value>,
103}
104
105impl GraphNode {
106 #[must_use]
108 pub fn new(
109 id: impl Into<String>,
110 node_type: impl Into<String>,
111 name: impl Into<String>,
112 ) -> Self {
113 Self {
114 id: id.into(),
115 node_type: node_type.into(),
116 name: name.into(),
117 properties: HashMap::new(),
118 }
119 }
120
121 #[must_use]
123 pub fn with_property(
124 mut self,
125 key: impl Into<String>,
126 value: impl Into<serde_json::Value>,
127 ) -> Self {
128 self.properties.insert(key.into(), value.into());
129 self
130 }
131
132 #[must_use]
134 pub fn with_mentions_count(self, count: usize) -> Self {
135 self.with_property("mentions_count", count)
136 }
137
138 #[must_use]
140 pub fn with_first_seen(self, offset: usize) -> Self {
141 self.with_property("first_seen", offset)
142 }
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize)]
147pub struct GraphEdge {
148 pub source: String,
150 pub target: String,
152 pub relation: String,
154 #[serde(default)]
156 pub confidence: f64,
157 #[serde(default)]
159 pub properties: HashMap<String, serde_json::Value>,
160}
161
162impl GraphEdge {
163 #[must_use]
165 pub fn new(
166 source: impl Into<String>,
167 target: impl Into<String>,
168 relation: impl Into<String>,
169 ) -> Self {
170 Self {
171 source: source.into(),
172 target: target.into(),
173 relation: relation.into(),
174 confidence: 1.0,
175 properties: HashMap::new(),
176 }
177 }
178
179 #[must_use]
181 pub fn with_confidence(mut self, confidence: f64) -> Self {
182 self.confidence = confidence;
183 self
184 }
185
186 #[must_use]
188 pub fn with_property(
189 mut self,
190 key: impl Into<String>,
191 value: impl Into<serde_json::Value>,
192 ) -> Self {
193 self.properties.insert(key.into(), value.into());
194 self
195 }
196
197 #[must_use]
199 pub fn with_trigger(self, trigger: impl Into<String>) -> Self {
200 self.with_property("trigger", trigger.into())
201 }
202}
203
204#[derive(Debug, Clone, Default, Serialize, Deserialize)]
206pub struct GraphDocument {
207 pub nodes: Vec<GraphNode>,
209 pub edges: Vec<GraphEdge>,
211 #[serde(default)]
213 pub metadata: HashMap<String, serde_json::Value>,
214}
215
216impl GraphDocument {
217 #[must_use]
219 pub fn new() -> Self {
220 Self::default()
221 }
222
223 #[must_use]
233 pub fn from_extraction(
234 entities: &[Entity],
235 relations: &[Relation],
236 _coref_chains: Option<()>,
239 ) -> Self {
240 let mut doc = Self::new();
241
242 let canonical_mentions: HashMap<super::types::CanonicalId, (&str, usize)> = HashMap::new();
245
246 let mut seen_nodes: HashMap<String, usize> = HashMap::new();
248 let mut entity_to_node: HashMap<usize, String> = HashMap::new();
249
250 for (idx, entity) in entities.iter().enumerate() {
252 let node_id = get_node_id(entity);
253
254 if let Some(&existing_idx) = seen_nodes.get(&node_id) {
256 if let Some(count) = doc.nodes[existing_idx].properties.get_mut("mentions_count") {
258 if let Some(n) = count.as_u64() {
259 *count = serde_json::Value::from(n + 1);
260 }
261 }
262 entity_to_node.insert(idx, node_id);
263 continue;
264 }
265
266 let (name, mentions_count) = if let Some(canonical_id) = entity.canonical_id {
268 canonical_mentions
269 .get(&canonical_id)
270 .map(|(text, count)| (text.to_string(), *count))
271 .unwrap_or_else(|| (entity.text.clone(), 1))
272 } else {
273 (entity.text.clone(), 1)
274 };
275
276 let mut node = GraphNode::new(&node_id, entity.entity_type.as_label(), name)
277 .with_mentions_count(mentions_count)
278 .with_first_seen(entity.start);
279
280 if let Some(valid_from) = &entity.valid_from {
282 node = node.with_property("valid_from", valid_from.to_rfc3339());
283 }
284 if let Some(valid_until) = &entity.valid_until {
285 node = node.with_property("valid_until", valid_until.to_rfc3339());
286 }
287
288 if let Some(viewport) = &entity.viewport {
290 node = node.with_property("viewport", viewport.as_str());
291 }
292
293 seen_nodes.insert(node_id.clone(), doc.nodes.len());
294 entity_to_node.insert(idx, node_id);
295 doc.nodes.push(node);
296 }
297
298 let mut seen_edges: HashMap<(String, String, String), usize> = HashMap::new();
300 for relation in relations {
301 let source_node_id = get_node_id(&relation.head);
303 let target_node_id = get_node_id(&relation.tail);
304
305 let source_exists = seen_nodes.contains_key(&source_node_id);
307 let target_exists = seen_nodes.contains_key(&target_node_id);
308
309 if source_exists && target_exists {
310 let key = (
311 source_node_id.clone(),
312 target_node_id.clone(),
313 relation.relation_type.clone(),
314 );
315 if let Some(&idx) = seen_edges.get(&key) {
316 if let Some(existing) = doc.edges.get_mut(idx) {
317 existing.confidence = existing.confidence.max(relation.confidence);
318 }
319 } else {
320 let edge =
321 GraphEdge::new(&source_node_id, &target_node_id, &relation.relation_type)
322 .with_confidence(relation.confidence);
323 doc.edges.push(edge);
324 seen_edges.insert(key, doc.edges.len().saturating_sub(1));
325 }
326 }
327 }
328
329 doc
330 }
331
332 #[must_use]
338 pub fn from_entities_cooccurrence(entities: &[Entity], max_distance: usize) -> Self {
339 let mut doc = Self::new();
340 let mut entity_to_node: HashMap<usize, String> = HashMap::new();
341 let mut seen_nodes: HashMap<String, usize> = HashMap::new();
342
343 for (idx, entity) in entities.iter().enumerate() {
345 let node_id = get_node_id(entity);
346
347 if seen_nodes.contains_key(&node_id) {
348 entity_to_node.insert(idx, node_id);
349 continue;
350 }
351
352 let mut node = GraphNode::new(&node_id, entity.entity_type.as_label(), &entity.text)
353 .with_first_seen(entity.start);
354
355 if let Some(valid_from) = &entity.valid_from {
357 node = node.with_property("valid_from", valid_from.to_rfc3339());
358 }
359 if let Some(valid_until) = &entity.valid_until {
360 node = node.with_property("valid_until", valid_until.to_rfc3339());
361 }
362
363 if let Some(viewport) = &entity.viewport {
365 node = node.with_property("viewport", viewport.as_str());
366 }
367
368 seen_nodes.insert(node_id.clone(), doc.nodes.len());
369 entity_to_node.insert(idx, node_id);
370 doc.nodes.push(node);
371 }
372
373 for (i, entity_a) in entities.iter().enumerate() {
375 for (j, entity_b) in entities.iter().enumerate().skip(i + 1) {
376 let distance = if entity_a.end <= entity_b.start {
377 entity_b.start.saturating_sub(entity_a.end)
378 } else if entity_b.end <= entity_a.start {
379 entity_a.start.saturating_sub(entity_b.end)
380 } else {
381 0 };
383
384 if distance <= max_distance {
385 if let (Some(source), Some(target)) =
386 (entity_to_node.get(&i), entity_to_node.get(&j))
387 {
388 if source != target {
390 let edge = GraphEdge::new(source, target, "RELATED_TO")
391 .with_property("distance", distance);
392 doc.edges.push(edge);
393 }
394 }
395 }
396 }
397 }
398
399 doc
400 }
401
402 #[must_use]
404 pub fn to_cypher(&self) -> String {
405 let mut cypher = String::new();
406
407 for node in &self.nodes {
409 let props = format_cypher_props(&node.properties, &node.name);
410 cypher.push_str(&format!(
411 "CREATE (n{}:{} {{id: '{}'{}}});\n",
412 sanitize_cypher_name(&node.id),
413 sanitize_cypher_name(&node.node_type),
414 escape_cypher_string(&node.id),
415 props
416 ));
417 }
418
419 cypher.push('\n');
420
421 for edge in &self.edges {
423 let props = if edge.confidence < 1.0 {
424 format!(" {{confidence: {:.3}}}", edge.confidence)
425 } else {
426 String::new()
427 };
428
429 cypher.push_str(&format!(
430 "MATCH (a {{id: '{}'}}), (b {{id: '{}'}}) CREATE (a)-[:{}{}]->(b);\n",
431 escape_cypher_string(&edge.source),
432 escape_cypher_string(&edge.target),
433 sanitize_cypher_name(&edge.relation),
434 props
435 ));
436 }
437
438 cypher
439 }
440
441 #[must_use]
452 pub fn to_networkx_json(&self) -> String {
453 #[derive(Serialize)]
454 struct NetworkXGraph<'a> {
455 directed: bool,
456 multigraph: bool,
457 graph: HashMap<String, serde_json::Value>,
458 nodes: Vec<NetworkXNode<'a>>,
459 links: Vec<NetworkXLink<'a>>,
460 }
461
462 #[derive(Serialize)]
463 struct NetworkXNode<'a> {
464 id: &'a str,
465 #[serde(rename = "type")]
466 node_type: &'a str,
467 name: &'a str,
468 #[serde(flatten)]
469 properties: &'a HashMap<String, serde_json::Value>,
470 }
471
472 #[derive(Serialize)]
473 struct NetworkXLink<'a> {
474 source: &'a str,
475 target: &'a str,
476 relation: &'a str,
477 #[serde(skip_serializing_if = "is_default_confidence")]
478 confidence: f64,
479 #[serde(flatten)]
480 properties: &'a HashMap<String, serde_json::Value>,
481 }
482
483 fn is_default_confidence(c: &f64) -> bool {
484 (*c - 1.0).abs() < f64::EPSILON
485 }
486
487 let graph = NetworkXGraph {
488 directed: true,
489 multigraph: false,
490 graph: self.metadata.clone(),
491 nodes: self
492 .nodes
493 .iter()
494 .map(|n| NetworkXNode {
495 id: &n.id,
496 node_type: &n.node_type,
497 name: &n.name,
498 properties: &n.properties,
499 })
500 .collect(),
501 links: self
502 .edges
503 .iter()
504 .map(|e| NetworkXLink {
505 source: &e.source,
506 target: &e.target,
507 relation: &e.relation,
508 confidence: e.confidence,
509 properties: &e.properties,
510 })
511 .collect(),
512 };
513
514 serde_json::to_string_pretty(&graph).unwrap_or_else(|_| "{}".to_string())
515 }
516
517 #[must_use]
519 pub fn to_json_ld(&self) -> String {
520 #[derive(Serialize)]
521 struct JsonLd<'a> {
522 #[serde(rename = "@context")]
523 context: JsonLdContext,
524 #[serde(rename = "@graph")]
525 graph: Vec<JsonLdNode<'a>>,
526 }
527
528 #[derive(Serialize)]
529 struct JsonLdContext {
530 #[serde(rename = "@vocab")]
531 vocab: &'static str,
532 name: &'static str,
533 #[serde(rename = "type")]
534 type_: &'static str,
535 }
536
537 #[derive(Serialize)]
538 struct JsonLdNode<'a> {
539 #[serde(rename = "@id")]
540 id: &'a str,
541 #[serde(rename = "@type")]
542 node_type: &'a str,
543 name: &'a str,
544 #[serde(skip_serializing_if = "Vec::is_empty")]
545 relations: Vec<JsonLdRelation<'a>>,
546 }
547
548 #[derive(Serialize)]
549 struct JsonLdRelation<'a> {
550 #[serde(rename = "@type")]
551 relation_type: &'a str,
552 target: &'a str,
553 }
554
555 let mut node_edges: HashMap<&str, Vec<&GraphEdge>> = HashMap::new();
557 for edge in &self.edges {
558 node_edges.entry(&edge.source).or_default().push(edge);
559 }
560
561 let doc = JsonLd {
562 context: JsonLdContext {
563 vocab: "http://schema.org/",
564 name: "http://schema.org/name",
565 type_: "http://www.w3.org/1999/02/22-rdf-syntax-ns#type",
566 },
567 graph: self
568 .nodes
569 .iter()
570 .map(|n| JsonLdNode {
571 id: &n.id,
572 node_type: &n.node_type,
573 name: &n.name,
574 relations: node_edges
575 .get(n.id.as_str())
576 .map(|edges| {
577 edges
578 .iter()
579 .map(|e| JsonLdRelation {
580 relation_type: &e.relation,
581 target: &e.target,
582 })
583 .collect()
584 })
585 .unwrap_or_default(),
586 })
587 .collect(),
588 };
589
590 serde_json::to_string_pretty(&doc).unwrap_or_else(|_| "{}".to_string())
591 }
592
593 pub fn with_metadata(
595 mut self,
596 key: impl Into<String>,
597 value: impl Into<serde_json::Value>,
598 ) -> Self {
599 self.metadata.insert(key.into(), value.into());
600 self
601 }
602
603 #[must_use]
605 pub fn node_count(&self) -> usize {
606 self.nodes.len()
607 }
608
609 #[must_use]
611 pub fn edge_count(&self) -> usize {
612 self.edges.len()
613 }
614
615 #[must_use]
617 pub fn is_empty(&self) -> bool {
618 self.nodes.is_empty()
619 }
620
621 #[must_use]
646 pub fn from_grounded_document(doc: &super::grounded::GroundedDocument) -> Self {
647 let entities: Vec<super::Entity> = doc.to_entities();
651
652 let relations: Vec<super::entity::Relation> = Vec::new();
658
659 Self::from_extraction(&entities, &relations, None)
660 }
661}
662
/// Output formats supported by [`GraphDocument::export`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GraphExportFormat {
    /// Cypher script, as produced by `GraphDocument::to_cypher`.
    Cypher,
    /// JSON with `nodes` / `links`, as produced by `GraphDocument::to_networkx_json`.
    NetworkXJson,
    /// JSON-LD, as produced by `GraphDocument::to_json_ld`.
    JsonLd,
}
677
678impl GraphDocument {
679 #[must_use]
681 pub fn export(&self, format: GraphExportFormat) -> String {
682 match format {
683 GraphExportFormat::Cypher => self.to_cypher(),
684 GraphExportFormat::NetworkXJson => self.to_networkx_json(),
685 GraphExportFormat::JsonLd => self.to_json_ld(),
686 }
687 }
688}
689
690fn get_node_id(entity: &Entity) -> String {
696 if let Some(ref kb_id) = entity.kb_id {
698 return kb_id.clone();
699 }
700 if let Some(canonical_id) = entity.canonical_id {
701 return format!("coref_{}", canonical_id);
702 }
703 fn normalize_for_id(text: &str) -> String {
706 let mut s = text.trim().to_lowercase();
707 if s.is_empty() {
708 return s;
709 }
710
711 fn is_edge_punct(c: char) -> bool {
712 matches!(
713 c,
714 '.' | ','
715 | ';'
716 | ':'
717 | '!'
718 | '?'
719 | '"'
720 | '\''
721 | '’'
722 | '“'
723 | '”'
724 | '('
725 | ')'
726 | '['
727 | ']'
728 | '{'
729 | '}'
730 )
731 }
732
733 while s.chars().next().is_some_and(is_edge_punct) {
734 s.remove(0);
735 }
736 while s.chars().last().is_some_and(is_edge_punct) {
737 s.pop();
738 }
739
740 if s.ends_with("'s") || s.ends_with("’s") {
741 s.pop();
742 s.pop();
743 } else if s.ends_with("s'") || s.ends_with("s’") {
744 s.pop();
745 }
746
747 while s.chars().last().is_some_and(is_edge_punct) {
748 s.pop();
749 }
750
751 s.split_whitespace()
752 .collect::<Vec<_>>()
753 .join(" ")
754 .replace(' ', "_")
755 }
756
757 format!(
758 "{}:{}",
759 entity.entity_type.as_label().to_lowercase(),
760 normalize_for_id(&entity.text)
761 )
762}
763
764fn format_cypher_props(props: &HashMap<String, serde_json::Value>, name: &str) -> String {
766 let mut parts = vec![format!("name: '{}'", escape_cypher_string(name))];
767
768 for (key, value) in props {
769 let formatted = match value {
770 serde_json::Value::String(s) => format!("{}: '{}'", key, escape_cypher_string(s)),
771 serde_json::Value::Number(n) => format!("{}: {}", key, n),
772 serde_json::Value::Bool(b) => format!("{}: {}", key, b),
773 _ => continue,
774 };
775 parts.push(formatted);
776 }
777
778 if parts.len() > 1 {
779 format!(", {}", parts[1..].join(", "))
780 } else {
781 String::new()
782 }
783}
784
/// Escapes `s` for use inside a single-quoted Cypher string literal by
/// prefixing every backslash and single quote with a backslash.
fn escape_cypher_string(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if matches!(ch, '\\' | '\'') {
            escaped.push('\\');
        }
        escaped.push(ch);
    }
    escaped
}
789
/// Turns `s` into a string usable as an unquoted Cypher identifier.
///
/// Every character that is neither alphanumeric nor `_` is replaced by `_`.
/// If the input is empty or starts with a non-alphabetic alphanumeric
/// character (such as a digit), a leading `_` is added, since an unquoted
/// identifier cannot be empty or begin with a digit.
fn sanitize_cypher_name(s: &str) -> String {
    let needs_prefix = s
        .chars()
        .next()
        .map_or(true, |first| first.is_alphanumeric() && !first.is_alphabetic());

    let mut name = String::with_capacity(s.len() + 1);
    if needs_prefix {
        name.push('_');
    }
    name.extend(
        s.chars()
            .map(|c| if c.is_alphanumeric() || c == '_' { c } else { '_' }),
    );
    name
}
802
#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used)]
    use super::*;
    use crate::EntityType;

    /// Builds an entity with confidence 0.9 whose span starts at `start` and
    /// is as long as `text`.
    fn make_entity(text: &str, entity_type: EntityType, start: usize) -> Entity {
        let end = start + text.len();
        Entity::new(text, entity_type, start, end, 0.9)
    }

    /// Two entities plus one relation give two nodes and one edge.
    #[test]
    fn test_graph_from_entities() {
        let founder = make_entity("Elon Musk", EntityType::Person, 0).with_canonical_id(1);
        let company = make_entity("Tesla", EntityType::Organization, 19).with_canonical_id(2);
        let founded =
            Relation::with_trigger(founder.clone(), company.clone(), "FOUNDED", 10, 17, 0.85);

        let graph = GraphDocument::from_extraction(&[founder, company], &[founded], None);

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.edges[0].relation, "FOUNDED");
    }

    /// Duplicate relations collapse into one edge with the highest confidence.
    #[test]
    fn test_graph_edge_deduplication_keeps_max_confidence() {
        let person = make_entity("A", EntityType::Person, 0);
        let employer = make_entity("B", EntityType::Organization, 10);

        let weak = Relation::new(person.clone(), employer.clone(), "WORKS_AT", 0.2);
        let strong = Relation::new(person.clone(), employer.clone(), "WORKS_AT", 0.9);

        let graph = GraphDocument::from_extraction(&[person, employer], &[weak, strong], None);

        assert_eq!(graph.edge_count(), 1);
        assert!((graph.edges[0].confidence - 0.9).abs() < 1e-9);
    }

    /// Possessive and trailing-punctuation variants share a node.
    #[test]
    fn test_graph_node_id_normalization_possessive_and_punct() {
        let mentions = [
            make_entity("OpenAI", EntityType::Organization, 0),
            make_entity("OpenAI's", EntityType::Organization, 10),
            make_entity("Cupertino,", EntityType::Location, 20),
            make_entity("Cupertino", EntityType::Location, 30),
        ];

        let graph = GraphDocument::from_extraction(&mentions, &[], None);

        assert_eq!(graph.node_count(), 2);
    }

    /// Entities sharing a canonical id share a node.
    #[test]
    fn test_graph_deduplication() {
        let mentions = [
            make_entity("Elon Musk", EntityType::Person, 0).with_canonical_id(1),
            make_entity("Musk", EntityType::Person, 50).with_canonical_id(1),
            make_entity("Tesla", EntityType::Organization, 100).with_canonical_id(2),
        ];

        let graph = GraphDocument::from_extraction(&mentions, &[], None);

        assert_eq!(graph.node_count(), 2);
    }

    /// The Cypher script creates nodes labelled with the entity type.
    #[test]
    fn test_cypher_export() {
        let mentions = [make_entity("Apple", EntityType::Organization, 0)];
        let script = GraphDocument::from_extraction(&mentions, &[], None).to_cypher();

        assert!(script.contains("CREATE"));
        assert!(script.contains(":ORG"));
    }

    /// The NetworkX export is valid JSON with the expected top-level keys.
    #[test]
    fn test_networkx_json() {
        let person = make_entity("A", EntityType::Person, 0);
        let employer = make_entity("B", EntityType::Organization, 10);
        let works_at = Relation::new(person.clone(), employer.clone(), "WORKS_AT", 0.9);

        let graph = GraphDocument::from_extraction(&[person, employer], &[works_at], None);
        let exported = graph.to_networkx_json();

        let parsed: serde_json::Value = serde_json::from_str(&exported).unwrap();
        assert!(parsed.get("nodes").is_some());
        assert!(parsed.get("links").is_some());
        assert_eq!(parsed["directed"], true);
    }

    /// Only entities within `max_distance` of each other get linked.
    #[test]
    fn test_cooccurrence_graph() {
        let mentions = [
            make_entity("A", EntityType::Person, 0),
            make_entity("B", EntityType::Organization, 20),
            make_entity("C", EntityType::Location, 100),
        ];

        let graph = GraphDocument::from_entities_cooccurrence(&mentions, 50);

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 1);
    }

    /// The JSON-LD export is valid JSON with `@context` and `@graph`.
    #[test]
    fn test_json_ld_export() {
        let mentions = [make_entity("Test", EntityType::Person, 0)];
        let exported = GraphDocument::from_extraction(&mentions, &[], None).to_json_ld();

        let parsed: serde_json::Value = serde_json::from_str(&exported).unwrap();

        assert!(parsed.get("@context").is_some());
        assert!(parsed.get("@graph").is_some());
    }

    /// Validity bounds and viewport end up as node properties and in the export.
    #[test]
    fn test_temporal_validity_export() {
        use crate::EntityViewport;
        use chrono::{TimeZone, Utc};

        /// Finds the node with the given name, panicking if there is none.
        fn node_named<'g>(graph: &'g GraphDocument, name: &str) -> &'g GraphNode {
            graph.nodes.iter().find(|n| n.name == name).unwrap()
        }

        let handover = Utc.with_ymd_and_hms(2014, 2, 4, 0, 0, 0).unwrap();

        let mut current_ceo = make_entity("Satya Nadella", EntityType::Person, 0);
        current_ceo.valid_from = Some(handover);
        current_ceo.viewport = Some(EntityViewport::Business);

        let mut former_ceo = make_entity("Steve Ballmer", EntityType::Person, 50);
        former_ceo.valid_from = Some(Utc.with_ymd_and_hms(2000, 1, 13, 0, 0, 0).unwrap());
        former_ceo.valid_until = Some(Utc.with_ymd_and_hms(2014, 2, 4, 0, 0, 0).unwrap());
        former_ceo.viewport = Some(EntityViewport::Historical);

        let graph = GraphDocument::from_extraction(&[current_ceo, former_ceo], &[], None);

        assert_eq!(graph.node_count(), 2);

        let current_node = node_named(&graph, "Satya Nadella");
        assert!(current_node.properties.contains_key("valid_from"));
        assert!(!current_node.properties.contains_key("valid_until"));
        assert_eq!(current_node.properties.get("viewport").unwrap(), "business");

        let former_node = node_named(&graph, "Steve Ballmer");
        assert!(former_node.properties.contains_key("valid_from"));
        assert!(former_node.properties.contains_key("valid_until"));
        assert_eq!(
            former_node.properties.get("viewport").unwrap(),
            "historical"
        );

        let exported = graph.to_networkx_json();
        assert!(exported.contains("valid_from"));
        assert!(exported.contains("valid_until"));
        assert!(exported.contains("2014-02-04"));
        assert!(exported.contains("2000-01-13"));
    }
}