Skip to main content

cognee_cognify/graph_extraction/
extractable.rs

1//! GraphExtractable trait and get_graph_from_model function.
2//!
3//! Mirrors Python's `get_graph_from_model()` which uses runtime reflection
4//! to discover DataPoint fields that reference other DataPoints. In Rust we
5//! use a trait that each type explicitly implements.
6
7use std::borrow::Cow;
8use std::collections::{HashMap, HashSet};
9
10use chrono::Utc;
11use cognee_graph::EdgeData;
12use cognee_models::{Document, DocumentChunk, Entity, EntityType};
13use serde_json::json;
14use uuid::Uuid;
15
16use crate::summarization::TextSummary;
17
18// ---------------------------------------------------------------------------
19// Trait and Relationship
20// ---------------------------------------------------------------------------
21
22/// A directed relationship from a DataPoint to another DataPoint.
23#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct Relationship {
25    /// The field name that declares this relationship (e.g. "is_part_of", "contains").
26    pub field_name: String,
27    /// The UUID of the target DataPoint.
28    pub target_id: Uuid,
29}
30
31/// Declares how a DataPoint type participates in the knowledge graph.
32///
33/// Each concrete DataPoint struct (DocumentChunk, Entity, TextSummary, etc.)
34/// implements this trait to declare its outgoing structural relationships.
35///
36/// `belongs_to_set` is intentionally excluded — it is a metadata property,
37/// not a graph edge.
38pub trait GraphExtractable: Send + Sync {
39    /// The DataPoint ID of this instance.
40    fn data_point_id(&self) -> Uuid;
41
42    /// The DataPoint type name (e.g., "DocumentChunk", "Entity").
43    fn data_point_type(&self) -> &str;
44
45    /// Outgoing structural relationships from this instance.
46    fn relationships(&self) -> Vec<Relationship>;
47}
48
49// ---------------------------------------------------------------------------
50// Implementations for built-in types
51// ---------------------------------------------------------------------------
52
53impl GraphExtractable for DocumentChunk {
54    fn data_point_id(&self) -> Uuid {
55        self.base.id
56    }
57
58    fn data_point_type(&self) -> &str {
59        &self.base.data_type
60    }
61
62    fn relationships(&self) -> Vec<Relationship> {
63        let mut rels = Vec::new();
64
65        // is_part_of: DocumentChunk → Document
66        if let Some(doc_id) = self.is_part_of {
67            rels.push(Relationship {
68                field_name: "is_part_of".to_string(),
69                target_id: doc_id,
70            });
71        }
72
73        // contains: DocumentChunk → Entity (from chunk.contains populated in graph extraction)
74        for entity_ref in &self.contains {
75            if let Some(id_str) = entity_ref.as_str()
76                && let Ok(id) = Uuid::parse_str(id_str)
77            {
78                rels.push(Relationship {
79                    field_name: "contains".to_string(),
80                    target_id: id,
81                });
82            }
83        }
84
85        rels
86    }
87}
88
89impl GraphExtractable for Document {
90    fn data_point_id(&self) -> Uuid {
91        // Document node ID = source Data item's id (content-addressed,
92        // Python-identical via classify_documents). See document.rs:117.
93        self.base.id
94    }
95
96    fn data_point_type(&self) -> &str {
97        // Concrete Python subclass name (e.g. "TextDocument", "PdfDocument").
98        &self.base.data_type
99    }
100
101    fn relationships(&self) -> Vec<Relationship> {
102        // Document has no outgoing DataPoint relationships of its own — the
103        // `is_part_of` edge originates from the DocumentChunk (see
104        // DocumentChunk::relationships) and points back at this Document.
105        Vec::new()
106    }
107}
108
109impl GraphExtractable for Entity {
110    fn data_point_id(&self) -> Uuid {
111        self.base.id
112    }
113
114    fn data_point_type(&self) -> &str {
115        &self.base.data_type
116    }
117
118    fn relationships(&self) -> Vec<Relationship> {
119        let mut rels = Vec::new();
120
121        // is_a: Entity → EntityType
122        if let Some(type_id) = self.is_a {
123            rels.push(Relationship {
124                field_name: "is_a".to_string(),
125                target_id: type_id,
126            });
127        }
128
129        rels
130    }
131}
132
133impl GraphExtractable for EntityType {
134    fn data_point_id(&self) -> Uuid {
135        self.base.id
136    }
137
138    fn data_point_type(&self) -> &str {
139        &self.base.data_type
140    }
141
142    fn relationships(&self) -> Vec<Relationship> {
143        // EntityType has no outgoing DataPoint relationships
144        Vec::new()
145    }
146}
147
148impl GraphExtractable for TextSummary {
149    fn data_point_id(&self) -> Uuid {
150        self.base.id
151    }
152
153    fn data_point_type(&self) -> &str {
154        &self.base.data_type
155    }
156
157    fn relationships(&self) -> Vec<Relationship> {
158        let mut rels = Vec::new();
159
160        // made_from: TextSummary → DocumentChunk
161        if let Some(chunk_id) = self.made_from {
162            rels.push(Relationship {
163                field_name: "made_from".to_string(),
164                target_id: chunk_id,
165            });
166        }
167
168        rels
169    }
170}
171
172// ---------------------------------------------------------------------------
173// get_graph_from_model
174// ---------------------------------------------------------------------------
175
176/// Discover all structural edges from a set of graph-extractable items.
177///
178/// Returns a deduplicated list of [`EdgeData`] tuples, each with an
179/// `updated_at` property matching Python's format.
180///
181/// Port of Python's `get_graph_from_model()` — simplified because our
182/// current types don't have nested DataPoint fields that require recursive
183/// DFS traversal.
184pub fn get_graph_from_model(items: &[&dyn GraphExtractable]) -> Vec<EdgeData> {
185    let mut edges: Vec<EdgeData> = Vec::new();
186    let mut seen: HashSet<(String, String, String)> = HashSet::new();
187    let now = Utc::now().to_rfc3339();
188
189    for item in items {
190        for rel in item.relationships() {
191            let source = item.data_point_id().to_string();
192            let target = rel.target_id.to_string();
193            let key = (source.clone(), target.clone(), rel.field_name.clone());
194
195            if seen.insert(key) {
196                edges.push((
197                    source,
198                    target,
199                    rel.field_name,
200                    HashMap::from([(Cow::from("updated_at"), json!(now.clone()))]),
201                ));
202            }
203        }
204    }
205
206    edges
207}
208
209// ---------------------------------------------------------------------------
210// Tests
211// ---------------------------------------------------------------------------
212
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal `DocumentChunk` with id `chunk_id` that is part of
    /// `doc_id`. Only `is_part_of` and `contains` feed
    /// `DocumentChunk::relationships`, so the remaining constructor arguments
    /// are fixed here.
    fn chunk_with_id(chunk_id: Uuid, doc_id: Uuid) -> DocumentChunk {
        DocumentChunk::new(
            chunk_id,
            "test".to_string(),
            1,
            0,
            "paragraph_end".to_string(),
            doc_id,
        )
    }

    /// Builds a minimal `DocumentChunk` with a fresh random id that is part
    /// of `doc_id`.
    fn chunk_of(doc_id: Uuid) -> DocumentChunk {
        chunk_with_id(Uuid::new_v4(), doc_id)
    }

    #[test]
    fn test_document_chunk_relationships() {
        let doc_id = Uuid::new_v4();
        let chunk = chunk_of(doc_id);

        let rels = chunk.relationships();
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].field_name, "is_part_of");
        assert_eq!(rels[0].target_id, doc_id);
    }

    #[test]
    fn test_document_chunk_with_contains() {
        let doc_id = Uuid::new_v4();
        let entity_id = Uuid::new_v4();
        let mut chunk = chunk_of(doc_id);
        chunk.contains = vec![json!(entity_id.to_string())];

        let rels = chunk.relationships();
        assert_eq!(rels.len(), 2);
        assert_eq!(rels[0].field_name, "is_part_of");
        assert_eq!(rels[1].field_name, "contains");
        assert_eq!(rels[1].target_id, entity_id);
    }

    #[test]
    fn test_document_has_no_relationships_and_id_matches_data() {
        use cognee_models::{Data, classify_documents};

        let data = Data::builder(
            Uuid::new_v4(),
            "test.txt",
            "/storage/test",
            "file:///storage/test.txt",
            "txt",
            "text/plain",
            "hash123",
            Uuid::new_v4(),
        )
        .build();
        let data_id = data.id;
        let docs = classify_documents(std::slice::from_ref(&data));
        assert_eq!(docs.len(), 1);
        let doc = &docs[0];

        // Document node ID equals the source Data item's id.
        assert_eq!(doc.data_point_id(), data_id);
        // Concrete Python subclass name carried through.
        assert_eq!(doc.data_point_type(), "TextDocument");
        // No outgoing DataPoint relationships.
        assert!(doc.relationships().is_empty());
    }

    #[test]
    fn test_entity_relationships() {
        let type_id = Uuid::new_v4();
        let entity = Entity::new("TechCorp", Some(type_id), "A company", None);

        let rels = entity.relationships();
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].field_name, "is_a");
        assert_eq!(rels[0].target_id, type_id);
    }

    #[test]
    fn test_entity_no_type_no_relationships() {
        let entity = Entity::new("TechCorp", None, "A company", None);

        let rels = entity.relationships();
        assert!(rels.is_empty());
    }

    #[test]
    fn test_entity_type_no_relationships() {
        let et = EntityType::new("Organization", "A company type", None);

        let rels = et.relationships();
        assert!(rels.is_empty());
    }

    #[test]
    fn test_text_summary_relationships() {
        let chunk_id = Uuid::new_v4();
        let summary = TextSummary::new(
            chunk_id,
            "Summary text".to_string(),
            None,
            "gpt-4".to_string(),
        );

        let rels = summary.relationships();
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].field_name, "made_from");
        assert_eq!(rels[0].target_id, chunk_id);
    }

    #[test]
    fn test_get_graph_from_model_basic() {
        let doc_id = Uuid::new_v4();
        let chunk = chunk_of(doc_id);

        let items: Vec<&dyn GraphExtractable> = vec![&chunk];
        let edges = get_graph_from_model(&items);

        assert_eq!(edges.len(), 1);
        assert_eq!(edges[0].0, chunk.base.id.to_string());
        assert_eq!(edges[0].1, doc_id.to_string());
        assert_eq!(edges[0].2, "is_part_of");
        assert!(edges[0].3.contains_key(&Cow::from("updated_at")));
    }

    #[test]
    fn test_get_graph_from_model_deduplication() {
        let chunk = chunk_of(Uuid::new_v4());

        // Pass the same item twice — edges should be deduplicated.
        let items: Vec<&dyn GraphExtractable> = vec![&chunk, &chunk];
        let edges = get_graph_from_model(&items);

        assert_eq!(edges.len(), 1);
    }

    #[test]
    fn test_get_graph_from_model_deduplicates_repeated_contains() {
        let entity_id = Uuid::new_v4();
        let mut chunk = chunk_of(Uuid::new_v4());
        chunk.contains = vec![
            json!(entity_id.to_string()),
            json!(entity_id.to_string()),
        ];

        // relationships() reports the repeated entry verbatim...
        assert_eq!(chunk.relationships().len(), 3);

        // ...but the extracted graph keeps a single `contains` edge.
        let items: Vec<&dyn GraphExtractable> = vec![&chunk];
        let edges = get_graph_from_model(&items);
        assert_eq!(edges.len(), 2);

        let contains_edges: Vec<_> = edges.iter().filter(|e| e.2 == "contains").collect();
        assert_eq!(contains_edges.len(), 1);
        assert_eq!(contains_edges[0].1, entity_id.to_string());
    }

    #[test]
    fn test_get_graph_from_model_multiple_types() {
        let doc_id = Uuid::new_v4();
        let type_id = Uuid::new_v4();
        let chunk_id = Uuid::new_v4();

        let chunk = chunk_with_id(chunk_id, doc_id);
        let entity = Entity::new("TechCorp", Some(type_id), "A company", None);
        let entity_type = EntityType::new("Organization", "A type", None);
        let summary = TextSummary::new(chunk_id, "Summary".to_string(), None, "gpt-4".to_string());

        let items: Vec<&dyn GraphExtractable> = vec![&chunk, &entity, &entity_type, &summary];
        let edges = get_graph_from_model(&items);

        // chunk: is_part_of → doc_id (1)
        // entity: is_a → type_id (1)
        // entity_type: (0)
        // summary: made_from → chunk_id (1)
        assert_eq!(edges.len(), 3);

        let edge_names: Vec<&str> = edges.iter().map(|e| e.2.as_str()).collect();
        assert!(edge_names.contains(&"is_part_of"));
        assert!(edge_names.contains(&"is_a"));
        assert!(edge_names.contains(&"made_from"));
    }

    #[test]
    fn test_get_graph_from_model_shares_updated_at() {
        let type_id = Uuid::new_v4();
        let chunk = chunk_of(Uuid::new_v4());
        let entity = Entity::new("TechCorp", Some(type_id), "A company", None);

        let items: Vec<&dyn GraphExtractable> = vec![&chunk, &entity];
        let edges = get_graph_from_model(&items);
        assert_eq!(edges.len(), 2);

        // One timestamp is captured per call and stamped on every edge.
        let stamps: HashSet<String> = edges
            .iter()
            .map(|e| {
                e.3.get(&Cow::from("updated_at"))
                    .expect("every edge carries updated_at")
                    .to_string()
            })
            .collect();
        assert_eq!(stamps.len(), 1);
    }

    #[test]
    fn test_get_graph_from_model_empty() {
        let items: Vec<&dyn GraphExtractable> = vec![];
        let edges = get_graph_from_model(&items);
        assert!(edges.is_empty());
    }

    #[test]
    fn test_get_graph_from_model_contains_edges() {
        let entity_id_1 = Uuid::new_v4();
        let entity_id_2 = Uuid::new_v4();

        let mut chunk = chunk_of(Uuid::new_v4());
        chunk.contains = vec![
            json!(entity_id_1.to_string()),
            json!(entity_id_2.to_string()),
        ];

        let items: Vec<&dyn GraphExtractable> = vec![&chunk];
        let edges = get_graph_from_model(&items);

        // is_part_of + 2 contains
        assert_eq!(edges.len(), 3);

        let contains_edges: Vec<_> = edges.iter().filter(|e| e.2 == "contains").collect();
        assert_eq!(contains_edges.len(), 2);
    }

    #[test]
    fn test_relationship_equality() {
        let id = Uuid::new_v4();
        let r1 = Relationship {
            field_name: "is_a".to_string(),
            target_id: id,
        };
        let r2 = Relationship {
            field_name: "is_a".to_string(),
            target_id: id,
        };
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_data_point_type_names() {
        let chunk = chunk_of(Uuid::new_v4());
        assert_eq!(chunk.data_point_type(), "DocumentChunk");

        let entity = Entity::new("Test", None, "desc", None);
        assert_eq!(entity.data_point_type(), "Entity");

        let et = EntityType::new("Type", "desc", None);
        assert_eq!(et.data_point_type(), "EntityType");

        let summary = TextSummary::new(Uuid::new_v4(), "s".to_string(), None, "model".to_string());
        assert_eq!(summary.data_point_type(), "TextSummary");
    }

    #[test]
    fn test_invalid_uuid_in_contains_is_skipped() {
        let mut chunk = chunk_of(Uuid::new_v4());
        // An invalid UUID string should be silently skipped.
        chunk.contains = vec![json!("not-a-valid-uuid")];

        let rels = chunk.relationships();
        // Only is_part_of; the invalid contains entry is skipped.
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].field_name, "is_part_of");
    }

    #[test]
    fn test_non_string_in_contains_is_skipped() {
        let mut chunk = chunk_of(Uuid::new_v4());
        // A non-string JSON value should be silently skipped.
        chunk.contains = vec![json!(42)];

        let rels = chunk.relationships();
        assert_eq!(rels.len(), 1);
        assert_eq!(rels[0].field_name, "is_part_of");
    }
}