Skip to main content

cognee_models/
document.rs

1use serde::{Deserialize, Serialize};
2use serde_json::json;
3use uuid::Uuid;
4
5use crate::Data;
6use crate::DataPoint;
7use crate::has_datapoint::HasDataPoint;
8
/// A classified document derived from a Data item.
///
/// Mirrors the Python `Document` class hierarchy. In Python, each document type
/// is a separate class (TextDocument, PdfDocument, etc.). In Rust we use a single
/// struct with a `document_type` field and the `base.data_type` discriminator
/// set to the class name (e.g. "TextDocument", "PdfDocument").
///
/// Instances are produced by [`classify_documents`], which copies the
/// descriptive fields from the source `Data` record and reuses its id for
/// both `base.id` and `data_id`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// DataPoint base — carries id, timestamps, metadata, data_type discriminator.
    /// Flattened by serde, so its fields are (de)serialized alongside the
    /// fields below instead of under a nested `base` key.
    #[serde(flatten)]
    pub base: DataPoint,
    /// Document type category: "text", "pdf", "csv", "html", "image", "audio", "unstructured", "dlt_row".
    pub document_type: String,
    /// Name of the source Data record.
    pub name: String,
    /// Raw data location of the source Data record.
    pub raw_data_location: String,
    /// MIME type of the source Data record.
    pub mime_type: String,
    /// File extension of the source Data record, stored exactly as it
    /// appears on that record (classification lowercases only its own copy).
    pub extension: String,
    /// Reference back to the source Data record.
    pub data_id: Uuid,
    /// External metadata, if any. Pretty-printed when the source string
    /// parses as JSON; otherwise the source string is carried over unchanged.
    pub external_metadata: Option<String>,
}
31
/// Map a file extension to a document type string, returning `None` for
/// unrecognised extensions.
///
/// Public wrapper over the private `extension_to_doc_type` mapping so the
/// ingestion pipeline can pick the right loader at ADD time using the same
/// extension → document-type table that [`classify_documents`] uses at cognify
/// time.
///
/// Matching is case-insensitive because the underlying mapping lowercases
/// `ext` before comparing. The extension is expected without a leading dot
/// (e.g. `"pdf"`); nothing is stripped here, so `".pdf"` is unrecognised.
pub fn doc_type_for_extension(ext: &str) -> Option<&'static str> {
    extension_to_doc_type(ext)
}
42
/// Resolve a file extension (no leading dot) to its document type category.
///
/// The comparison is case-insensitive: the extension is lowercased before the
/// lookup. Returns `None` when the extension is not in the table.
///
/// The table holds 39 extensions. `html` and `htm` are Rust-only additions
/// (see the note in the body); the remaining 37 are the ones mirrored from the
/// Python SDK's `EXTENSION_TO_DOCUMENT_CLASS` mapping
/// (`cognee/tasks/documents/classify_documents.py`).
fn extension_to_doc_type(ext: &str) -> Option<&'static str> {
    // Office-style formats that go through the "unstructured" loader.
    const UNSTRUCTURED: &[&str] = &[
        "docx", "doc", "odt", "xls", "xlsx", "ppt", "pptx", "odp", "ods",
    ];
    const IMAGE: &[&str] = &[
        "png", "dwg", "xcf", "jpg", "jpx", "apng", "gif", "webp", "cr2", "tif", "bmp", "jxr",
        "psd", "ico", "heic", "avif",
    ];
    const AUDIO: &[&str] = &[
        "aac", "mid", "mp3", "m4a", "ogg", "flac", "wav", "amr", "aiff",
    ];

    let lowered = ext.to_lowercase();
    let key = lowered.as_str();

    let doc_type = match key {
        "pdf" => "pdf",
        "txt" => "text",
        "csv" => "csv",
        // HTML has no counterpart in Python's `EXTENSION_TO_DOCUMENT_CLASS`:
        // there the BeautifulSoupLoader runs at add-time and the extracted
        // text is stored as a TextDocument. Rust picks its loader at
        // cognify-time from `document_type`, so html/htm get their own "html"
        // category while `doc_type_to_class_name` still reports
        // `TextDocument` for cross-SDK DB parity.
        "html" | "htm" => "html",
        _ if UNSTRUCTURED.contains(&key) => "unstructured",
        _ if IMAGE.contains(&key) => "image",
        _ if AUDIO.contains(&key) => "audio",
        _ => return None,
    };
    Some(doc_type)
}
69
/// Translate a document type category into the `data_type` discriminator,
/// i.e. the name of the corresponding Python document class.
///
/// Categories that are not listed fall back to the generic `"Document"`.
/// The lookup is exact (case-sensitive).
fn doc_type_to_class_name(doc_type: &str) -> &'static str {
    const CLASS_NAMES: &[(&str, &str)] = &[
        ("text", "TextDocument"),
        ("pdf", "PdfDocument"),
        ("csv", "CsvDocument"),
        ("image", "ImageDocument"),
        ("audio", "AudioDocument"),
        // HTML content ends up as text, and Python stores it as a
        // TextDocument; reuse that node `data_type` for cross-SDK graph parity.
        ("html", "TextDocument"),
        ("unstructured", "UnstructuredDocument"),
        ("dlt_row", "DltRowDocument"),
    ];

    CLASS_NAMES
        .iter()
        .find(|&&(category, _)| category == doc_type)
        .map_or("Document", |&(_, class_name)| class_name)
}
86
87/// Check whether the `external_metadata` JSON indicates a DLT source.
88///
89/// Mirrors Python `cognee/tasks/ingestion/dlt_utils.py:is_dlt_sourced`.
90fn is_dlt_sourced(external_metadata: &Option<String>) -> bool {
91    external_metadata
92        .as_ref()
93        .and_then(|m| serde_json::from_str::<serde_json::Value>(m).ok())
94        .map(|v| metadata_value_is_dlt_sourced(&v))
95        .unwrap_or(false)
96}
97
98fn metadata_value_is_dlt_sourced(value: &serde_json::Value) -> bool {
99    value
100        .get("source")
101        .and_then(|source| source.as_str())
102        .map(|source| source == "dlt")
103        .unwrap_or(false)
104        || value
105            .get("data_item_external_metadata")
106            .map(metadata_value_is_dlt_sourced)
107            .unwrap_or(false)
108}
109
110/// Classify Data items into Documents based on file extension.
111///
112/// Mirrors the Python `classify_documents` function. DLT-sourced items are
113/// classified as `DltRowDocument`; all others use the extension-to-document-type
114/// mapping. Items with unrecognised extensions are silently skipped.
115pub fn classify_documents(data_items: &[Data]) -> Vec<Document> {
116    data_items
117        .iter()
118        .filter_map(|data| {
119            // DLT detection takes priority
120            let doc_type = if is_dlt_sourced(&data.external_metadata) {
121                "dlt_row"
122            } else {
123                extension_to_doc_type(&data.extension)?
124            };
125
126            let class_name = doc_type_to_class_name(doc_type);
127            let mut base = DataPoint::new(class_name, None);
128            base.id = data.id; // use Data's deterministic ID
129            base.set_metadata("index_fields", json!(["name"]));
130
131            // Format external_metadata as indented JSON (Python does json.dumps(..., indent=4))
132            let formatted_metadata = data.external_metadata.as_ref().and_then(|m| {
133                let v: serde_json::Value = serde_json::from_str(m).ok()?;
134                serde_json::to_string_pretty(&v).ok()
135            });
136
137            let mut doc = Document {
138                base,
139                document_type: doc_type.to_string(),
140                name: data.name.clone(),
141                raw_data_location: data.raw_data_location.clone(),
142                mime_type: data.mime_type.clone(),
143                extension: data.extension.clone(),
144                data_id: data.id,
145                external_metadata: formatted_metadata.or(data.external_metadata.clone()),
146            };
147
148            // update_node_set: parse external_metadata for node_set array
149            // Mirrors Python cognee/tasks/documents/classify_documents.py:update_node_set()
150            if let Some(ref meta_str) = doc.external_metadata
151                && let Ok(meta_val) = serde_json::from_str::<serde_json::Value>(meta_str)
152                && let Some(node_set_array) = meta_val.get("node_set").and_then(|v| v.as_array())
153            {
154                // Build NodeSet-like JSON values with deterministic IDs
155                // Python: NodeSet(id=generate_node_id(f"NodeSet:{name}"), name=name)
156                let node_set_values: Vec<serde_json::Value> = node_set_array
157                    .iter()
158                    .filter_map(|v| {
159                        let name = v.as_str()?;
160                        let key = format!("NodeSet:{name}")
161                            .to_lowercase()
162                            .replace(' ', "_")
163                            .replace('\'', "");
164                        let id = uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, key.as_bytes());
165                        Some(json!({
166                            "id": id.to_string(),
167                            "name": name,
168                            "type": "NodeSet"
169                        }))
170                    })
171                    .collect();
172
173                if !node_set_values.is_empty() {
174                    // source_node_set = comma-separated names (Python: ", ".join(node_set))
175                    let names: Vec<&str> =
176                        node_set_array.iter().filter_map(|v| v.as_str()).collect();
177                    doc.base.source_node_set = Some(names.join(", "));
178                    doc.base.belongs_to_set = Some(node_set_values);
179                }
180            }
181
182            Some(doc)
183        })
184        .collect()
185}
186
/// Exposes the embedded `base` as this document's DataPoint.
impl HasDataPoint for Document {
    /// Shared access to the embedded DataPoint base.
    fn data_point(&self) -> &DataPoint {
        &self.base
    }
    /// Mutable access to the embedded DataPoint base.
    fn data_point_mut(&mut self) -> &mut DataPoint {
        &mut self.base
    }
    // for_each_child_mut: default no-op — Document has no nested
    // DataPoint-bearing fields (links to its source `Data` by `data_id: Uuid`).
}
197
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Build a `Data` fixture with the given MIME type and extension and no
    /// external metadata.
    fn make_data(mime_type: &str, extension: &str) -> Data {
        Data::builder(
            Uuid::new_v4(),
            format!("test.{extension}"),
            "/storage/test",
            "text://test",
            extension,
            mime_type,
            "hash123",
            Uuid::new_v4(),
        )
        .build()
    }

    /// Same fixture as `make_data`, with `metadata` attached as the raw
    /// external metadata string.
    fn make_data_with_metadata(mime_type: &str, extension: &str, metadata: &str) -> Data {
        Data::builder(
            Uuid::new_v4(),
            format!("test.{extension}"),
            "/storage/test",
            "text://test",
            extension,
            mime_type,
            "hash123",
            Uuid::new_v4(),
        )
        .external_metadata(metadata)
        .build()
    }

    // ----- Extension-based classification tests -----

    #[test]
    fn classifies_text_plain() {
        let data = vec![make_data("text/plain", "txt")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "text");
        assert_eq!(docs[0].base.data_type, "TextDocument");
        assert_eq!(docs[0].mime_type, "text/plain");
        assert_eq!(docs[0].data_id, data[0].id);
        assert_eq!(docs[0].base.id, data[0].id);
        assert_eq!(
            docs[0].base.get_metadata("index_fields"),
            Some(&serde_json::json!(["name"]))
        );
    }

    #[test]
    fn classifies_extracted_html_url_text_as_text_document() {
        let data = vec![
            Data::builder(
                Uuid::new_v4(),
                "text_hash",
                "file:///storage/text_hash.txt",
                "file:///storage/source.html",
                "txt",
                "text/plain",
                "hash123",
                Uuid::new_v4(),
            )
            .original_extension("html")
            .original_mime_type("text/html")
            .loader_engine("beautiful_soup_loader")
            .external_metadata(
                r#"{"source":"url","url":"https://example.test","content_type":"text/html"}"#,
            )
            .build(),
        ];

        let docs = classify_documents(&data);

        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "text");
        assert_eq!(docs[0].base.data_type, "TextDocument");
        assert_eq!(docs[0].extension, "txt");
        assert_eq!(docs[0].mime_type, "text/plain");
    }

    #[test]
    fn classifies_pdf() {
        let data = vec![make_data("application/pdf", "pdf")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "pdf");
        assert_eq!(docs[0].base.data_type, "PdfDocument");
    }

    #[test]
    fn classifies_csv() {
        let data = vec![make_data("text/csv", "csv")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "csv");
        assert_eq!(docs[0].base.data_type, "CsvDocument");
    }

    #[test]
    fn classifies_image_extensions() {
        for ext in &[
            "png", "dwg", "xcf", "jpg", "jpx", "apng", "gif", "webp", "cr2", "tif", "bmp", "jxr",
            "psd", "ico", "heic", "avif",
        ] {
            let data = vec![make_data(&format!("image/{ext}"), ext)];
            let docs = classify_documents(&data);
            assert_eq!(docs.len(), 1, "failed for extension: {ext}");
            assert_eq!(
                docs[0].document_type, "image",
                "failed for extension: {ext}"
            );
            assert_eq!(
                docs[0].base.data_type, "ImageDocument",
                "failed for extension: {ext}"
            );
        }
    }

    #[test]
    fn classifies_audio_extensions() {
        for ext in &[
            "aac", "mid", "mp3", "m4a", "ogg", "flac", "wav", "amr", "aiff",
        ] {
            let data = vec![make_data(&format!("audio/{ext}"), ext)];
            let docs = classify_documents(&data);
            assert_eq!(docs.len(), 1, "failed for extension: {ext}");
            assert_eq!(
                docs[0].document_type, "audio",
                "failed for extension: {ext}"
            );
            assert_eq!(
                docs[0].base.data_type, "AudioDocument",
                "failed for extension: {ext}"
            );
        }
    }

    #[test]
    fn classifies_unstructured_extensions() {
        for ext in &[
            "docx", "doc", "odt", "xls", "xlsx", "ppt", "pptx", "odp", "ods",
        ] {
            let data = vec![make_data("application/octet-stream", ext)];
            let docs = classify_documents(&data);
            assert_eq!(docs.len(), 1, "failed for extension: {ext}");
            assert_eq!(
                docs[0].document_type, "unstructured",
                "failed for extension: {ext}"
            );
            assert_eq!(
                docs[0].base.data_type, "UnstructuredDocument",
                "failed for extension: {ext}"
            );
        }
    }

    #[test]
    fn classifies_html_extensions() {
        for ext in &["html", "htm"] {
            let data = vec![make_data("text/html", ext)];
            let docs = classify_documents(&data);
            assert_eq!(docs.len(), 1, "failed for extension: {ext}");
            assert_eq!(docs[0].document_type, "html", "failed for extension: {ext}");
            // Cross-SDK parity: Python's BeautifulSoupLoader produces a
            // TextDocument, so HTML documents carry the TextDocument data_type.
            assert_eq!(
                docs[0].base.data_type, "TextDocument",
                "failed for extension: {ext}"
            );
        }
    }

    // ----- Unknown extensions are skipped -----

    #[test]
    fn skips_unknown_extensions() {
        let data = vec![make_data("application/octet-stream", "xyz")];
        let docs = classify_documents(&data);
        assert!(docs.is_empty());
    }

    #[test]
    fn source_code_extensions_are_not_classified() {
        for ext in &["py", "rs", "js", "ts", "c", "cpp", "go", "java", "rb", "sh"] {
            let data = vec![make_data("text/plain", ext)];
            let docs = classify_documents(&data);
            assert!(docs.is_empty(), "extension .{ext} should not be classified");
        }
    }

    // ----- Mixed input: only known extensions pass through -----

    #[test]
    fn mixed_input_filters_correctly() {
        let data = vec![
            make_data("text/plain", "txt"),
            make_data("application/octet-stream", "xyz"),
            make_data("application/pdf", "pdf"),
            make_data("image/png", "png"),
            make_data("audio/mp3", "mp3"),
        ];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 4);
        assert_eq!(docs[0].document_type, "text");
        assert_eq!(docs[1].document_type, "pdf");
        assert_eq!(docs[2].document_type, "image");
        assert_eq!(docs[3].document_type, "audio");
    }

    // ----- DLT detection -----

    #[test]
    fn classifies_dlt_sourced_data() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"source": "dlt"}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "dlt_row");
        assert_eq!(docs[0].base.data_type, "DltRowDocument");
    }

    #[test]
    fn dlt_detection_with_unknown_extension() {
        // DLT sourced items should be classified even with unknown extensions
        let data = vec![make_data_with_metadata(
            "application/octet-stream",
            "xyz",
            r#"{"source": "dlt"}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "dlt_row");
    }

    #[test]
    fn dlt_detection_survives_url_metadata_merge_conflict() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"source":"url","url":"https://example.test","data_item_external_metadata":{"source":"dlt","table":"events"}}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "dlt_row");
        assert_eq!(docs[0].base.data_type, "DltRowDocument");
    }

    #[test]
    fn non_dlt_metadata_does_not_affect_classification() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"source": "other"}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert_eq!(docs[0].document_type, "text");
    }

    // ----- External metadata formatting -----

    #[test]
    fn formats_external_metadata_as_pretty_json() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"key":"value","nested":{"a":1}}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        let meta = docs[0].external_metadata.as_ref().unwrap();
        // Pretty-printed JSON should contain newlines and indentation
        assert!(meta.contains('\n'));
        assert!(meta.contains("  "));
        // Reformatting must not change the content.
        let reparsed: serde_json::Value = serde_json::from_str(meta).unwrap();
        assert_eq!(reparsed, json!({"key": "value", "nested": {"a": 1}}));
    }

    #[test]
    fn invalid_json_metadata_passed_through_as_is() {
        let data = vec![make_data_with_metadata("text/plain", "txt", "not-json")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        // Invalid JSON can't be pretty-printed, so original is kept
        assert_eq!(docs[0].external_metadata.as_ref().unwrap(), "not-json");
    }

    // ----- DataPoint base -----

    #[test]
    fn document_has_index_fields_metadata() {
        let data = vec![make_data("text/plain", "txt")];
        let docs = classify_documents(&data);
        assert_eq!(
            docs[0].base.get_metadata("index_fields"),
            Some(&json!(["name"]))
        );
    }

    #[test]
    fn document_id_matches_data_id() {
        let data = vec![make_data("text/plain", "txt")];
        let docs = classify_documents(&data);
        assert_eq!(docs[0].base.id, data[0].id);
        assert_eq!(docs[0].data_id, data[0].id);
    }

    // ----- Empty input -----

    #[test]
    fn empty_input() {
        let docs = classify_documents(&[]);
        assert!(docs.is_empty());
    }

    // ----- Case insensitivity -----

    #[test]
    fn extension_matching_is_case_insensitive() {
        assert_eq!(extension_to_doc_type("PDF"), Some("pdf"));
        assert_eq!(extension_to_doc_type("Txt"), Some("text"));
        assert_eq!(extension_to_doc_type("PNG"), Some("image"));
        assert_eq!(extension_to_doc_type("MP3"), Some("audio"));
    }

    // ----- Public wrapper -----

    #[test]
    fn public_wrapper_matches_internal_mapping() {
        for ext in &["pdf", "PDF", "txt", "htm", "docx", "png", "mp3", "xyz", ""] {
            assert_eq!(
                doc_type_for_extension(ext),
                extension_to_doc_type(ext),
                "failed for extension: {ext:?}"
            );
        }
        assert_eq!(doc_type_for_extension("htm"), Some("html"));
        assert_eq!(doc_type_for_extension("xyz"), None);
    }

    // ----- NodeSet handling (update_node_set) -----

    #[test]
    fn node_set_populates_belongs_to_set_and_source_node_set() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": ["setA", "setB"]}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);

        // belongs_to_set should have two NodeSet entries
        let bts = docs[0].base.belongs_to_set.as_ref().unwrap();
        assert_eq!(bts.len(), 2);

        // Each entry should have id, name, type
        assert_eq!(bts[0]["name"], "setA");
        assert_eq!(bts[0]["type"], "NodeSet");
        assert_eq!(bts[1]["name"], "setB");
        assert_eq!(bts[1]["type"], "NodeSet");

        // IDs should be deterministic UUID5
        let key_a = "nodeset:seta"; // lowercased, spaces→underscores
        let expected_id_a =
            uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, key_a.as_bytes()).to_string();
        assert_eq!(bts[0]["id"], expected_id_a);

        // source_node_set should be comma-separated names
        assert_eq!(docs[0].base.source_node_set.as_ref().unwrap(), "setA, setB");
    }

    #[test]
    fn node_set_ids_normalise_spaces_and_apostrophes() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": ["My Set's"]}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);

        let bts = docs[0].base.belongs_to_set.as_ref().unwrap();
        assert_eq!(bts.len(), 1);
        // The stored name keeps its original spelling...
        assert_eq!(bts[0]["name"], "My Set's");
        assert_eq!(docs[0].base.source_node_set.as_ref().unwrap(), "My Set's");

        // ...while the ID key is lowercased, spaces→underscores, apostrophes removed.
        let expected_id =
            uuid::Uuid::new_v5(&uuid::Uuid::NAMESPACE_OID, b"nodeset:my_sets").to_string();
        assert_eq!(bts[0]["id"], expected_id);
    }

    #[test]
    fn node_set_single_entry() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": ["only_one"]}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);

        let bts = docs[0].base.belongs_to_set.as_ref().unwrap();
        assert_eq!(bts.len(), 1);
        assert_eq!(bts[0]["name"], "only_one");

        assert_eq!(docs[0].base.source_node_set.as_ref().unwrap(), "only_one");
    }

    #[test]
    fn node_set_skips_non_string_entries() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": ["keep", 42, null, "also"]}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);

        let bts = docs[0].base.belongs_to_set.as_ref().unwrap();
        assert_eq!(bts.len(), 2);
        assert_eq!(bts[0]["name"], "keep");
        assert_eq!(bts[1]["name"], "also");
        assert_eq!(docs[0].base.source_node_set.as_ref().unwrap(), "keep, also");
    }

    #[test]
    fn node_set_with_only_non_string_entries_leaves_belongs_to_set_unset() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": [1, 2, null]}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert!(docs[0].base.belongs_to_set.is_none());
        assert!(docs[0].base.source_node_set.is_none());
    }

    #[test]
    fn no_node_set_key_leaves_belongs_to_set_unset() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"other_key": "value"}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert!(docs[0].base.belongs_to_set.is_none());
        assert!(docs[0].base.source_node_set.is_none());
    }

    #[test]
    fn node_set_not_array_leaves_belongs_to_set_unset() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": "not_an_array"}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert!(docs[0].base.belongs_to_set.is_none());
        assert!(docs[0].base.source_node_set.is_none());
    }

    #[test]
    fn node_set_empty_array_leaves_belongs_to_set_unset() {
        let data = vec![make_data_with_metadata(
            "text/plain",
            "txt",
            r#"{"node_set": []}"#,
        )];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        // Empty array produces no NodeSet values, so stays unset
        assert!(docs[0].base.belongs_to_set.is_none());
        assert!(docs[0].base.source_node_set.is_none());
    }

    #[test]
    fn node_set_with_no_metadata_leaves_belongs_to_set_unset() {
        let data = vec![make_data("text/plain", "txt")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        assert!(docs[0].base.belongs_to_set.is_none());
        assert!(docs[0].base.source_node_set.is_none());
    }

    #[test]
    fn document_implements_has_datapoint() {
        let data = vec![make_data("text/plain", "txt")];
        let docs = classify_documents(&data);
        assert_eq!(docs.len(), 1);
        let dp_id = docs[0].base.id;
        assert_eq!(docs[0].data_point().id, dp_id);
        let mut doc = docs[0].clone();
        assert_eq!(doc.data_point_mut().id, dp_id);
    }
}