Skip to main content

cognee_cognify/fact_extraction/
models.rs

//! Knowledge graph data models.
//!
//! Port of Python's `cognee/shared/data_models.py`.
//! These models represent the extracted knowledge graph structure:
//! - [`Node`]: entities and concepts in the graph
//! - [`Edge`]: relationships between nodes
//! - [`KnowledgeGraph`]: collection of nodes and edges

9use schemars::JsonSchema;
10use serde::{Deserialize, Serialize, de::DeserializeOwned};
11
12/// Marker trait for types that can be used as graph extraction models.
13///
14/// Types implementing this trait can be extracted from text via LLM
15/// structured output. The LLM generates JSON conforming to the type's
16/// [`JsonSchema`], which is then deserialized into the concrete type.
17///
18/// The built-in [`KnowledgeGraph`] model implements this trait with
19/// `is_default_knowledge_graph() == true`, which triggers additional
20/// post-processing (entity/edge expansion, deduplication, graph DB storage).
21/// Custom models return `false`, causing the extracted value to be stored
22/// directly in [`DocumentChunk::contains`] as serialized JSON — mirroring
23/// the Python branching at `extract_graph_from_data.py:99-103`.
24///
25/// # Required bounds
26/// `Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static`
27pub trait GraphModel:
28    Serialize + DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static
29{
30    /// Returns `true` if this is the built-in [`KnowledgeGraph`] model.
31    ///
32    /// Custom models should leave the default (`false`), which changes
33    /// the processing flow: extracted data is stored as-is in chunk metadata
34    /// instead of being expanded into graph nodes and edges.
35    fn is_default_knowledge_graph() -> bool {
36        false
37    }
38}
39
40/// Node in a knowledge graph.
41///
42/// Represents an entity or concept extracted from text.
43/// Nodes are akin to Wikipedia nodes - they represent distinct entities.
44///
45/// # Fields
46/// * `id` - Unique identifier (human-readable, not an integer)
47/// * `name` - Display name of the entity
48/// * `node_type` - Type classification (e.g., "PERSON", "ORGANIZATION", "CONCEPT")
49/// * `description` - Brief description of the entity
50#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
51pub struct Node {
52    /// Unique identifier for the node (human-readable, e.g., "Albert Einstein")
53    pub id: String,
54
55    /// Display name of the entity
56    pub name: String,
57
58    /// Entity type (e.g., "PERSON", "ORGANIZATION", "CONCEPT")
59    /// Use uppercase for consistency with Python
60    #[serde(rename = "type")]
61    pub node_type: String,
62
63    /// Brief description of the entity (1-2 sentences)
64    pub description: String,
65}
66
67/// Edge in a knowledge graph.
68///
69/// Represents a relationship between two nodes.
70/// Edges are akin to Wikipedia links - they connect related concepts.
71///
72/// # Fields
73/// * `source_node_id` - ID of the source node
74/// * `target_node_id` - ID of the target node
75/// * `relationship_name` - Type of relationship (use snake_case, e.g., "works_at")
76/// * `description` - Concrete one-sentence fact expressed by this edge
77#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
78pub struct Edge {
79    /// ID of the source node
80    pub source_node_id: String,
81
82    /// ID of the target node
83    pub target_node_id: String,
84
85    /// Type of relationship (snake_case, e.g., "works_at", "founded", "located_in")
86    pub relationship_name: String,
87
88    /// Concrete one-sentence fact expressed by this edge, using endpoint names.
89    /// Mirrors Python `KnowledgeGraph.Edge.description` (data_models.py:62-71).
90    /// Becomes the `edge_text` graph-edge property, feeding EdgeType + Triplet
91    /// embeddings. Optional because older/custom outputs may omit it.
92    #[serde(default)]
93    pub description: Option<String>,
94}
95
96/// Knowledge graph extracted from text.
97///
98/// Contains nodes (entities/concepts) and edges (relationships).
99/// This is the primary output of fact extraction.
100///
101/// # Fields
102/// * `nodes` - List of extracted entities and concepts
103/// * `edges` - List of relationships between nodes
104#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
105pub struct KnowledgeGraph {
106    /// List of nodes (entities and concepts)
107    #[serde(default)]
108    pub nodes: Vec<Node>,
109
110    /// List of edges (relationships between nodes)
111    #[serde(default)]
112    pub edges: Vec<Edge>,
113}
114
115impl KnowledgeGraph {
116    /// Create a new empty knowledge graph.
117    pub fn new() -> Self {
118        Self {
119            nodes: Vec::new(),
120            edges: Vec::new(),
121        }
122    }
123
124    /// Check if the graph is empty (no nodes or edges).
125    pub fn is_empty(&self) -> bool {
126        self.nodes.is_empty() && self.edges.is_empty()
127    }
128
129    /// Get the number of nodes in the graph.
130    pub fn node_count(&self) -> usize {
131        self.nodes.len()
132    }
133
134    /// Get the number of edges in the graph.
135    pub fn edge_count(&self) -> usize {
136        self.edges.len()
137    }
138}
139
140impl Default for KnowledgeGraph {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146impl GraphModel for KnowledgeGraph {
147    fn is_default_knowledge_graph() -> bool {
148        true
149    }
150}
151
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    #[test]
    fn test_node_serialization() {
        let node = Node {
            id: "alice_johnson".to_string(),
            name: "Alice Johnson".to_string(),
            node_type: "PERSON".to_string(),
            description: "Software engineer at TechCorp".to_string(),
        };

        let json = serde_json::to_string(&node).unwrap();
        assert!(json.contains("\"type\":\"PERSON\""));

        let deserialized: Node = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.node_type, "PERSON");
    }

    #[test]
    fn test_node_requires_type_key() {
        // `node_type` exists on the wire only as `type`; the Rust field name
        // must not be accepted in its place.
        let json = r#"{
            "id": "alice",
            "name": "Alice",
            "node_type": "PERSON",
            "description": "A person"
        }"#;
        assert!(serde_json::from_str::<Node>(json).is_err());
    }

    #[test]
    fn test_edge_creation() {
        let edge = Edge {
            source_node_id: "alice_johnson".to_string(),
            target_node_id: "techcorp".to_string(),
            relationship_name: "works_at".to_string(),
            description: None,
        };

        assert_eq!(edge.relationship_name, "works_at");
    }

    #[test]
    fn test_edge_serializes_description() {
        let edge = Edge {
            source_node_id: "alice".to_string(),
            target_node_id: "acme".to_string(),
            relationship_name: "founded".to_string(),
            description: Some("Alice founded Acme".to_string()),
        };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains("\"description\":\"Alice founded Acme\""));
    }

    #[test]
    fn test_edge_serializes_missing_description_as_null() {
        // No `skip_serializing_if`, so `None` is emitted explicitly.
        let edge = Edge {
            source_node_id: "alice".to_string(),
            target_node_id: "acme".to_string(),
            relationship_name: "founded".to_string(),
            description: None,
        };

        let json = serde_json::to_string(&edge).unwrap();
        assert!(json.contains("\"description\":null"));
    }

    #[test]
    fn test_edge_deserializes_without_description() {
        // Back-compat: JSON omitting `description` defaults to None.
        let json = r#"{
            "source_node_id": "alice",
            "target_node_id": "acme",
            "relationship_name": "founded"
        }"#;
        let edge: Edge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.relationship_name, "founded");
        assert_eq!(edge.description, None);
    }

    #[test]
    fn test_edge_deserializes_with_description() {
        let json = r#"{
            "source_node_id": "alice",
            "target_node_id": "acme",
            "relationship_name": "founded",
            "description": "Alice founded Acme"
        }"#;
        let edge: Edge = serde_json::from_str(json).unwrap();
        assert_eq!(edge.description.as_deref(), Some("Alice founded Acme"));
    }

    #[test]
    fn test_knowledge_graph() {
        let mut graph = KnowledgeGraph::new();
        assert!(graph.is_empty());
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);

        graph.nodes.push(Node {
            id: "alice".to_string(),
            name: "Alice".to_string(),
            node_type: "PERSON".to_string(),
            description: "A person".to_string(),
        });

        graph.edges.push(Edge {
            source_node_id: "alice".to_string(),
            target_node_id: "techcorp".to_string(),
            relationship_name: "works_at".to_string(),
            description: None,
        });

        assert!(!graph.is_empty());
        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.edge_count(), 1);
    }

    #[test]
    fn test_knowledge_graph_deserializes_with_missing_fields() {
        // `#[serde(default)]` on both collections: `{}` is an empty graph.
        let empty: KnowledgeGraph = serde_json::from_str("{}").unwrap();
        assert!(empty.is_empty());

        // Nodes without an `edges` key is also accepted.
        let json = r#"{
            "nodes": [
                {"id": "alice", "name": "Alice", "type": "PERSON", "description": "A person"}
            ]
        }"#;
        let graph: KnowledgeGraph = serde_json::from_str(json).unwrap();
        assert_eq!(graph.node_count(), 1);
        assert_eq!(graph.edge_count(), 0);
        assert_eq!(
            graph.nodes.first().map(|node| node.node_type.as_str()),
            Some("PERSON")
        );
    }

    #[test]
    fn test_knowledge_graph_default_is_empty() {
        let graph = KnowledgeGraph::default();
        assert!(graph.is_empty());
        assert_eq!(graph.node_count(), 0);
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_knowledge_graph_is_default() {
        assert!(KnowledgeGraph::is_default_knowledge_graph());
    }

    /// A custom graph model for testing the `GraphModel` trait.
    #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
    struct CustomModel {
        items: Vec<String>,
    }

    impl GraphModel for CustomModel {}

    #[test]
    fn test_custom_model_is_not_default() {
        assert!(!CustomModel::is_default_knowledge_graph());
    }

    #[test]
    fn test_custom_model_roundtrip() {
        let model = CustomModel {
            items: vec!["a".to_string(), "b".to_string()],
        };
        let json = serde_json::to_string(&model).unwrap();
        let deserialized: CustomModel = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.items, vec!["a", "b"]);
    }
}