cognee_cognify/graph_extraction/
extractable.rs1use std::borrow::Cow;
8use std::collections::{HashMap, HashSet};
9
10use chrono::Utc;
11use cognee_graph::EdgeData;
12use cognee_models::{Document, DocumentChunk, Entity, EntityType};
13use serde_json::json;
14use uuid::Uuid;
15
16use crate::summarization::TextSummary;
17
18#[derive(Debug, Clone, PartialEq, Eq, Hash)]
24pub struct Relationship {
25 pub field_name: String,
27 pub target_id: Uuid,
29}
30
31pub trait GraphExtractable: Send + Sync {
39 fn data_point_id(&self) -> Uuid;
41
42 fn data_point_type(&self) -> &str;
44
45 fn relationships(&self) -> Vec<Relationship>;
47}
48
49impl GraphExtractable for DocumentChunk {
54 fn data_point_id(&self) -> Uuid {
55 self.base.id
56 }
57
58 fn data_point_type(&self) -> &str {
59 &self.base.data_type
60 }
61
62 fn relationships(&self) -> Vec<Relationship> {
63 let mut rels = Vec::new();
64
65 if let Some(doc_id) = self.is_part_of {
67 rels.push(Relationship {
68 field_name: "is_part_of".to_string(),
69 target_id: doc_id,
70 });
71 }
72
73 for entity_ref in &self.contains {
75 if let Some(id_str) = entity_ref.as_str()
76 && let Ok(id) = Uuid::parse_str(id_str)
77 {
78 rels.push(Relationship {
79 field_name: "contains".to_string(),
80 target_id: id,
81 });
82 }
83 }
84
85 rels
86 }
87}
88
89impl GraphExtractable for Document {
90 fn data_point_id(&self) -> Uuid {
91 self.base.id
94 }
95
96 fn data_point_type(&self) -> &str {
97 &self.base.data_type
99 }
100
101 fn relationships(&self) -> Vec<Relationship> {
102 Vec::new()
106 }
107}
108
109impl GraphExtractable for Entity {
110 fn data_point_id(&self) -> Uuid {
111 self.base.id
112 }
113
114 fn data_point_type(&self) -> &str {
115 &self.base.data_type
116 }
117
118 fn relationships(&self) -> Vec<Relationship> {
119 let mut rels = Vec::new();
120
121 if let Some(type_id) = self.is_a {
123 rels.push(Relationship {
124 field_name: "is_a".to_string(),
125 target_id: type_id,
126 });
127 }
128
129 rels
130 }
131}
132
133impl GraphExtractable for EntityType {
134 fn data_point_id(&self) -> Uuid {
135 self.base.id
136 }
137
138 fn data_point_type(&self) -> &str {
139 &self.base.data_type
140 }
141
142 fn relationships(&self) -> Vec<Relationship> {
143 Vec::new()
145 }
146}
147
148impl GraphExtractable for TextSummary {
149 fn data_point_id(&self) -> Uuid {
150 self.base.id
151 }
152
153 fn data_point_type(&self) -> &str {
154 &self.base.data_type
155 }
156
157 fn relationships(&self) -> Vec<Relationship> {
158 let mut rels = Vec::new();
159
160 if let Some(chunk_id) = self.made_from {
162 rels.push(Relationship {
163 field_name: "made_from".to_string(),
164 target_id: chunk_id,
165 });
166 }
167
168 rels
169 }
170}
171
172pub fn get_graph_from_model(items: &[&dyn GraphExtractable]) -> Vec<EdgeData> {
185 let mut edges: Vec<EdgeData> = Vec::new();
186 let mut seen: HashSet<(String, String, String)> = HashSet::new();
187 let now = Utc::now().to_rfc3339();
188
189 for item in items {
190 for rel in item.relationships() {
191 let source = item.data_point_id().to_string();
192 let target = rel.target_id.to_string();
193 let key = (source.clone(), target.clone(), rel.field_name.clone());
194
195 if seen.insert(key) {
196 edges.push((
197 source,
198 target,
199 rel.field_name,
200 HashMap::from([(Cow::from("updated_at"), json!(now.clone()))]),
201 ));
202 }
203 }
204 }
205
206 edges
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built chunk links only to its parent document.
    #[test]
    fn test_document_chunk_relationships() {
        let parent_id = Uuid::new_v4();
        let chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test text"),
            2,
            0,
            String::from("paragraph_end"),
            parent_id,
        );

        let expected = vec![Relationship {
            field_name: String::from("is_part_of"),
            target_id: parent_id,
        }];
        assert_eq!(chunk.relationships(), expected);
    }

    /// A valid UUID string in `contains` yields a `contains` link after the
    /// parent link.
    #[test]
    fn test_document_chunk_with_contains() {
        let parent_id = Uuid::new_v4();
        let contained_id = Uuid::new_v4();
        let mut chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test text"),
            2,
            0,
            String::from("paragraph_end"),
            parent_id,
        );
        chunk.contains = vec![json!(contained_id.to_string())];

        let links = chunk.relationships();
        assert_eq!(links.len(), 2);

        let (first, second) = (&links[0], &links[1]);
        assert_eq!(first.field_name, "is_part_of");
        assert_eq!(second.field_name, "contains");
        assert_eq!(second.target_id, contained_id);
    }

    /// A classified document reuses the id of its `Data` record and exposes
    /// no links.
    #[test]
    fn test_document_has_no_relationships_and_id_matches_data() {
        use cognee_models::{Data, classify_documents};

        let data = Data::builder(
            Uuid::new_v4(),
            "test.txt",
            "/storage/test",
            "file:///storage/test.txt",
            "txt",
            "text/plain",
            "hash123",
            Uuid::new_v4(),
        )
        .build();
        let expected_id = data.id;

        let documents = classify_documents(std::slice::from_ref(&data));
        assert_eq!(documents.len(), 1);

        let document = &documents[0];
        assert_eq!(document.data_point_id(), expected_id);
        assert_eq!(document.data_point_type(), "TextDocument");
        assert!(document.relationships().is_empty());
    }

    /// An entity with a type yields exactly one `is_a` link to that type.
    #[test]
    fn test_entity_relationships() {
        let entity_type_id = Uuid::new_v4();
        let entity = Entity::new("TechCorp", Some(entity_type_id), "A company", None);

        let expected = vec![Relationship {
            field_name: String::from("is_a"),
            target_id: entity_type_id,
        }];
        assert_eq!(entity.relationships(), expected);
    }

    /// An entity without a type yields no links.
    #[test]
    fn test_entity_no_type_no_relationships() {
        let entity = Entity::new("TechCorp", None, "A company", None);

        assert!(entity.relationships().is_empty());
    }

    /// Entity types never yield links.
    #[test]
    fn test_entity_type_no_relationships() {
        let entity_type = EntityType::new("Organization", "A company type", None);

        assert!(entity_type.relationships().is_empty());
    }

    /// A summary yields exactly one `made_from` link to its source chunk.
    #[test]
    fn test_text_summary_relationships() {
        let source_chunk_id = Uuid::new_v4();
        let summary = TextSummary::new(
            source_chunk_id,
            String::from("Summary text"),
            None,
            String::from("gpt-4"),
        );

        let expected = vec![Relationship {
            field_name: String::from("made_from"),
            target_id: source_chunk_id,
        }];
        assert_eq!(summary.relationships(), expected);
    }

    /// A single chunk produces one fully populated `is_part_of` edge.
    #[test]
    fn test_get_graph_from_model_basic() {
        let parent_id = Uuid::new_v4();
        let chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );

        let items: [&dyn GraphExtractable; 1] = [&chunk];
        let edges = get_graph_from_model(&items);
        assert_eq!(edges.len(), 1);

        let (source, target, label, properties) = &edges[0];
        assert_eq!(*source, chunk.base.id.to_string());
        assert_eq!(*target, parent_id.to_string());
        assert_eq!(label.as_str(), "is_part_of");
        assert!(properties.contains_key(&Cow::from("updated_at")));
    }

    /// Passing the same item twice must not duplicate its edge.
    #[test]
    fn test_get_graph_from_model_deduplication() {
        let parent_id = Uuid::new_v4();
        let chunk_id = Uuid::new_v4();
        let chunk = DocumentChunk::new(
            chunk_id,
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );

        let items: [&dyn GraphExtractable; 2] = [&chunk, &chunk];
        let edges = get_graph_from_model(&items);

        assert_eq!(edges.len(), 1);
    }

    /// Mixed item kinds each contribute their own edge label.
    #[test]
    fn test_get_graph_from_model_multiple_types() {
        let parent_id = Uuid::new_v4();
        let entity_type_id = Uuid::new_v4();
        let chunk_id = Uuid::new_v4();

        let chunk = DocumentChunk::new(
            chunk_id,
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );
        let entity = Entity::new("TechCorp", Some(entity_type_id), "A company", None);
        let entity_type = EntityType::new("Organization", "A type", None);
        let summary =
            TextSummary::new(chunk_id, String::from("Summary"), None, String::from("gpt-4"));

        let items: [&dyn GraphExtractable; 4] = [&chunk, &entity, &entity_type, &summary];
        let edges = get_graph_from_model(&items);

        // chunk -> document, entity -> type, summary -> chunk.
        assert_eq!(edges.len(), 3);
        for expected_label in ["is_part_of", "is_a", "made_from"] {
            assert!(edges.iter().any(|edge| edge.2 == expected_label));
        }
    }

    /// No items means no edges.
    #[test]
    fn test_get_graph_from_model_empty() {
        let items: [&dyn GraphExtractable; 0] = [];

        assert!(get_graph_from_model(&items).is_empty());
    }

    /// Every entry in `contains` becomes its own edge alongside the parent
    /// edge.
    #[test]
    fn test_get_graph_from_model_contains_edges() {
        let parent_id = Uuid::new_v4();
        let first_entity_id = Uuid::new_v4();
        let second_entity_id = Uuid::new_v4();

        let mut chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );
        chunk.contains = vec![
            json!(first_entity_id.to_string()),
            json!(second_entity_id.to_string()),
        ];

        let items: [&dyn GraphExtractable; 1] = [&chunk];
        let edges = get_graph_from_model(&items);

        // One `is_part_of` edge plus two `contains` edges.
        assert_eq!(edges.len(), 3);
        let contains_count = edges.iter().filter(|edge| edge.2 == "contains").count();
        assert_eq!(contains_count, 2);
    }

    /// Two relationships with equal fields compare equal.
    #[test]
    fn test_relationship_equality() {
        let shared_target = Uuid::new_v4();
        let build = || Relationship {
            field_name: String::from("is_a"),
            target_id: shared_target,
        };

        assert_eq!(build(), build());
    }

    /// Each model reports its own type name.
    #[test]
    fn test_data_point_type_names() {
        let chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("t"),
            1,
            0,
            String::from("word"),
            Uuid::new_v4(),
        );
        let entity = Entity::new("Test", None, "desc", None);
        let entity_type = EntityType::new("Type", "desc", None);
        let summary =
            TextSummary::new(Uuid::new_v4(), String::from("s"), None, String::from("model"));

        assert_eq!(chunk.data_point_type(), "DocumentChunk");
        assert_eq!(entity.data_point_type(), "Entity");
        assert_eq!(entity_type.data_point_type(), "EntityType");
        assert_eq!(summary.data_point_type(), "TextSummary");
    }

    /// A string that is not a UUID is dropped from `contains`.
    #[test]
    fn test_invalid_uuid_in_contains_is_skipped() {
        let parent_id = Uuid::new_v4();
        let mut chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );
        chunk.contains = vec![json!("not-a-valid-uuid")];

        let links = chunk.relationships();
        assert_eq!(links.len(), 1);
        assert_eq!(links[0].field_name, "is_part_of");
    }

    /// A non-string JSON value is dropped from `contains`.
    #[test]
    fn test_non_string_in_contains_is_skipped() {
        let parent_id = Uuid::new_v4();
        let mut chunk = DocumentChunk::new(
            Uuid::new_v4(),
            String::from("test"),
            1,
            0,
            String::from("paragraph_end"),
            parent_id,
        );
        chunk.contains = vec![json!(42)];

        let links = chunk.relationships();
        assert_eq!(links.len(), 1);
        assert_eq!(links[0].field_name, "is_part_of");
    }
}