Skip to main content

cognee_cognify/
tasks.rs

1//! Cognify pipeline tasks — individual steps of the cognify process.
2//!
3//! Matches the Python SDK task breakdown:
4//! 1. [`classify_documents`] — Data items → Documents
5//! 2. [`extract_chunks_from_documents`] — Documents → DocumentChunks
6//! 3. [`extract_graph_from_data`] — Chunks → Chunks + entities/edges (stored in graph DB)
7//! 4. [`summarize_text`] — + summaries via LLM
8//! 5. [`add_data_points`] — embeddings + vector indexing → [`CognifyResult`]
9//!
10//! Temporal pipeline variant:
11//! 1. [`classify_documents`] — same
12//! 2. [`extract_chunks_from_documents`] — same
13//! 3. [`extract_temporal_events`] — Chunks → TemporalEvents (via two LLM passes)
14//! 4. [`add_temporal_data_points`] — persists events, timestamps, intervals, entities → graph+vector
15//!
16//! Public surface:
17//! - Intermediate types: [`CognifyInput`], [`ClassifiedDocuments`],
18//!   [`ExtractedChunks`], [`ExtractedGraphData`], [`SummarizedData`],
19//!   [`ExtractedTemporalEvents`]
20//! - Task implementations (free functions)
21//! - [`TypedTask`] factories: [`make_classify_documents_task`], etc.
22//! - Pipeline builders: [`build_cognify_pipeline`], [`build_temporal_cognify_pipeline`]
23
24use std::borrow::Cow;
25use std::collections::{HashMap, HashSet};
26use std::sync::Arc;
27
28use chrono::Utc;
29use cognee_chunking::{CutType, NAMESPACE_OID, TokenCounterKind, chunk_by_row, chunk_text};
30use cognee_core::pipeline_run_registry::DbPipelineWatcher;
31use cognee_core::{
32    CpuPool, Pipeline, PipelineBuilder, PipelineContext, TaskContextBuilder, TypedTask, Value,
33};
34use cognee_database::{DatabaseConnection, PipelineRunRepository};
35use cognee_embedding::engine::EmbeddingEngine;
36use cognee_graph::{EdgeData, GraphDBTrait, GraphDBTraitExt};
37#[cfg(feature = "audio-loader")]
38use cognee_ingestion::loaders::audio::AudioLoader;
39#[cfg(feature = "image-loader")]
40use cognee_ingestion::loaders::image::ImageLoader;
41use cognee_ingestion::loaders::{LoaderOutput, LoaderRegistry};
42use cognee_llm::Llm;
43use cognee_models::{
44    Data, Document, DocumentChunk, EdgeType, Embedding, TemporalEvent,
45    classify_documents as model_classify_documents,
46};
47use cognee_ontology::OntologyResolver;
48use cognee_storage::StorageTrait;
49use cognee_vector::{VectorDB, VectorPoint};
50use serde::Serialize;
51use serde_json::json;
52use tokio::sync::Semaphore;
53use tracing::{info, warn};
54use url::Url;
55use uuid::Uuid;
56
57use crate::config::CognifyConfig;
58use crate::error::CognifyError;
59use crate::fact_extraction::{FactExtractor, KnowledgeGraph};
60use crate::graph_integration::{
61    GraphEdgePair, GraphNodePair, deduplicate_nodes_and_edges, expand_with_nodes_and_edges,
62    retrieve_existing_edges,
63};
64use crate::pipeline::{CognifyResult, IndexedFieldsStats};
65use crate::qualification::{Qualification, check_pipeline_run_qualification};
66use crate::summarization::{SummaryExtractor, TextSummary};
67use crate::temporal_extraction::{TemporalEntityEnricher, TemporalEventExtractor};
68use cognee_models::DataPoint;
69
70// ---------------------------------------------------------------------------
71// Intermediate types
72// ---------------------------------------------------------------------------
73
74/// Input to the cognify pipeline.
75///
76/// Wraps all data items for a dataset along with the dataset identifier
77/// and optional user/tenant context.
78#[derive(Debug, Clone)]
79pub struct CognifyInput {
80    pub data_items: Vec<Data>,
81    pub dataset_id: Uuid,
82    /// Optional user ID (owner of the pipeline run).
83    pub user_id: Option<Uuid>,
84    /// Optional tenant ID for multi-tenant isolation.
85    pub tenant_id: Option<Uuid>,
86}
87
88/// Output of [`classify_documents`]: classified documents ready for chunking.
89#[derive(Debug, Clone)]
90pub struct ClassifiedDocuments {
91    pub documents: Vec<Document>,
92    pub dataset_id: Uuid,
93    pub user_id: Option<Uuid>,
94    pub tenant_id: Option<Uuid>,
95}
96
97/// Output of [`extract_chunks_from_documents`]: text chunks ready for graph extraction.
98#[derive(Debug, Clone)]
99pub struct ExtractedChunks {
100    pub chunks: Vec<DocumentChunk>,
101    /// Classified documents — carried forward so downstream tasks (e.g. DLT
102    /// filtering in [`extract_graph_from_data`]) can inspect document metadata.
103    pub documents: Vec<Document>,
104    pub dataset_id: Uuid,
105    pub user_id: Option<Uuid>,
106    pub tenant_id: Option<Uuid>,
107}
108
109/// Output of [`extract_graph_from_data`]: chunks plus extracted entities and edges
110/// (already stored in graph DB).
111#[derive(Debug, Clone)]
112pub struct ExtractedGraphData {
113    pub chunks: Vec<DocumentChunk>,
114    /// Classified documents — carried forward for DLT FK edge extraction.
115    pub documents: Vec<Document>,
116    pub entities: Vec<GraphNodePair>,
117    pub edges: Vec<GraphEdgePair>,
118    pub dataset_id: Uuid,
119    pub user_id: Option<Uuid>,
120    pub tenant_id: Option<Uuid>,
121}
122
123/// Output of [`summarize_text`]: graph data plus generated summaries.
124#[derive(Debug, Clone)]
125pub struct SummarizedData {
126    pub chunks: Vec<DocumentChunk>,
127    /// Classified documents — carried forward for DLT FK edge extraction.
128    pub documents: Vec<Document>,
129    pub entities: Vec<GraphNodePair>,
130    pub edges: Vec<GraphEdgePair>,
131    pub summaries: Vec<TextSummary>,
132    pub dataset_id: Uuid,
133    pub user_id: Option<Uuid>,
134    pub tenant_id: Option<Uuid>,
135}
136
137/// Output of [`extract_temporal_events`]: temporal events extracted from chunks
138/// via two LLM passes (event extraction + entity enrichment).
139///
140/// Used as the intermediate type between Task 3 and Task 4 in the temporal pipeline.
141#[derive(Debug, Clone)]
142pub struct ExtractedTemporalEvents {
143    pub events: Vec<TemporalEvent>,
144    pub dataset_id: Uuid,
145    pub user_id: Option<Uuid>,
146    pub tenant_id: Option<Uuid>,
147}
148
149// ---------------------------------------------------------------------------
150// Task 1: classify_documents
151// ---------------------------------------------------------------------------
152
153/// Classify Data items into typed Documents (Task 1).
154///
155/// Maps each Data item to a Document based on mime_type.
156/// Non-text items are filtered out.
157pub fn classify_documents(input: &CognifyInput) -> Result<ClassifiedDocuments, CognifyError> {
158    let documents: Vec<Document> = model_classify_documents(&input.data_items);
159    info!(doc_count = documents.len(), "documents classified");
160    Ok(ClassifiedDocuments {
161        documents,
162        dataset_id: input.dataset_id,
163        user_id: input.user_id,
164        tenant_id: input.tenant_id,
165    })
166}
167
168// ---------------------------------------------------------------------------
169// Task 2: extract_chunks_from_documents
170// ---------------------------------------------------------------------------
171
172/// Extract text chunks from classified documents (Task 2).
173///
174/// For each document, reads content from storage and applies the
175/// word → sentence → paragraph → text chunker hierarchy.
176///
177/// When `db` is `Some`, the accumulated token count for each document
178/// is written back to the corresponding `Data` record, mirroring
179/// Python's `update_document_token_count()`.
180pub async fn extract_chunks_from_documents(
181    input: &ClassifiedDocuments,
182    storage: &dyn StorageTrait,
183    max_chunk_size: usize,
184    token_counter_kind: TokenCounterKind,
185    db: Option<&DatabaseConnection>,
186    loader_registry: &LoaderRegistry,
187) -> Result<ExtractedChunks, CognifyError> {
188    let counter = token_counter_kind
189        .build()
190        .map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
191    let mut all_chunks = Vec::new();
192
193    for document in &input.documents {
194        let content_bytes = storage
195            .retrieve(&document.raw_data_location)
196            .await
197            .map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
198
199        // ---- DLT short-circuit ----
200        // DLT documents emit exactly one chunk with cut_type="dlt_row".
201        // No word/sentence/paragraph chunking. Mirrors Python DltRowDocument.read().
202        if document.document_type == "dlt_row" {
203            let text = String::from_utf8(content_bytes)
204                .map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
205            let trimmed = text.trim();
206            if !trimmed.is_empty() {
207                let chunk_id =
208                    Uuid::new_v5(&NAMESPACE_OID, format!("{}-0", document.base.id).as_bytes());
209                let word_count = counter.count_tokens(trimmed);
210                let mut chunk = DocumentChunk::new(
211                    chunk_id,
212                    trimmed.to_string(),
213                    word_count,
214                    0, // chunk_index
215                    CutType::DltRow.to_string(),
216                    document.base.id,
217                );
218                if document.base.belongs_to_set.is_some() {
219                    chunk.base.belongs_to_set = document.base.belongs_to_set.clone();
220                }
221                // Token count write-back
222                if let Some(db) = db
223                    && let Err(e) = cognee_database::ops::data::update_data_token_count(
224                        db,
225                        document.data_id,
226                        word_count as i64,
227                    )
228                    .await
229                {
230                    warn!(data_id = %document.data_id, "Failed to update token count: {e}");
231                }
232                all_chunks.push(chunk);
233            }
234            continue;
235        }
236
237        // ---- Loader dispatch ----
238        let loader = loader_registry
239            .get(&document.document_type)
240            .ok_or_else(|| CognifyError::UnsupportedDocumentType(document.document_type.clone()))?;
241
242        let output = loader
243            .extract(&content_bytes, document)
244            .await
245            .map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
246
247        let mut chunks = match output {
248            LoaderOutput::Text(text) => {
249                chunk_text(document.base.id, &text, max_chunk_size, &counter)
250            }
251            LoaderOutput::Rows(rows) => {
252                let joined = rows.join("\n\n");
253                chunk_by_row(document.base.id, &joined, max_chunk_size, &counter)
254            }
255            LoaderOutput::SingleChunk { text, cut_type } => {
256                let chunk_id =
257                    Uuid::new_v5(&NAMESPACE_OID, format!("{}-0", document.base.id).as_bytes());
258                let word_count = counter.count_tokens(&text);
259                vec![DocumentChunk::new(
260                    chunk_id,
261                    text,
262                    word_count,
263                    0,
264                    cut_type.to_string(),
265                    document.base.id,
266                )]
267            }
268        };
269
270        // Propagate belongs_to_set from Document to each DocumentChunk
271        // Mirrors Python: document_chunk.belongs_to_set = document.belongs_to_set
272        if document.base.belongs_to_set.is_some() {
273            for chunk in &mut chunks {
274                chunk.base.belongs_to_set = document.base.belongs_to_set.clone();
275            }
276        }
277
278        // Accumulate token count and write back to the Data record.
279        // Mirrors Python: update_document_token_count(document.id, document_token_count)
280        if let Some(db) = db {
281            let document_token_count: i64 = chunks.iter().map(|c| c.chunk_size as i64).sum();
282            if let Err(e) = cognee_database::ops::data::update_data_token_count(
283                db,
284                document.data_id,
285                document_token_count,
286            )
287            .await
288            {
289                warn!(
290                    data_id = %document.data_id,
291                    "Failed to update token count: {e}"
292                );
293            }
294        }
295
296        all_chunks.extend(chunks);
297    }
298
299    info!(total_chunks = all_chunks.len(), "chunking complete");
300    Ok(ExtractedChunks {
301        chunks: all_chunks,
302        documents: input.documents.clone(),
303        dataset_id: input.dataset_id,
304        user_id: input.user_id,
305        tenant_id: input.tenant_id,
306    })
307}
308
309// ---------------------------------------------------------------------------
310// Task 3: extract_graph_from_data
311// ---------------------------------------------------------------------------
312
313/// Extract knowledge graphs from chunks via LLM, then integrate (Task 3).
314///
315/// For each chunk batch, calls the LLM to extract entities and relationships.
316/// Then integrates: expands to storage-layer types, deduplicates against
317/// existing DB entries and in-memory, and stores nodes/edges in graph DB.
318pub async fn extract_graph_from_data(
319    input: &ExtractedChunks,
320    llm: Arc<dyn Llm>,
321    graph_db: Arc<dyn GraphDBTrait>,
322    ontology_resolver: Arc<dyn OntologyResolver>,
323    config: &CognifyConfig,
324    // Optional caller-supplied provenance user label. When `Some`, used
325    // verbatim for the entity / EntityType / EdgeType pre-stamps inside
326    // `expand_with_nodes_and_edges`. When `None`, falls back to the
327    // string-form `user_id` (the only label `ExtractedChunks` carries).
328    //
329    // The pipeline-driven path threads through
330    // `PipelineContext::user_label()` here so entities arrive at the
331    // task body already stamped with the email-form label that the
332    // provenance E2E test expects (locked decision 4 of
333    // `docs/telemetry/05-datapoint-provenance.md`).
334    user_label_override: Option<&str>,
335) -> Result<ExtractedGraphData, CognifyError> {
336    if input.chunks.is_empty() {
337        return Ok(ExtractedGraphData {
338            chunks: input.chunks.clone(),
339            documents: input.documents.clone(),
340            entities: vec![],
341            edges: vec![],
342            dataset_id: input.dataset_id,
343            user_id: input.user_id,
344            tenant_id: input.tenant_id,
345        });
346    }
347
348    // Filter out DLT chunks — their graph is built deterministically by
349    // extract_dlt_fk_edges from schema metadata, not by LLM extraction.
350    // Mirrors Python: cognee/tasks/graph/extract_graph_from_data.py:148-155
351    let dlt_doc_ids: HashSet<Uuid> = input
352        .documents
353        .iter()
354        .filter(|d| d.document_type == "dlt_row")
355        .map(|d| d.base.id)
356        .collect();
357
358    let (dlt_chunks, non_dlt_chunks): (Vec<&DocumentChunk>, Vec<&DocumentChunk>) = input
359        .chunks
360        .iter()
361        .partition(|c| dlt_doc_ids.contains(&c.document_id));
362
363    if !dlt_chunks.is_empty() {
364        info!(
365            "Skipping {} DLT chunks from LLM extraction ({} non-DLT chunks remain)",
366            dlt_chunks.len(),
367            non_dlt_chunks.len()
368        );
369    }
370
371    // If only DLT chunks remain, return early with all chunks but no entities/edges
372    if non_dlt_chunks.is_empty() {
373        return Ok(ExtractedGraphData {
374            chunks: input.chunks.clone(),
375            documents: input.documents.clone(),
376            entities: vec![],
377            edges: vec![],
378            dataset_id: input.dataset_id,
379            user_id: input.user_id,
380            tenant_id: input.tenant_id,
381        });
382    }
383
384    // Collect non-DLT chunks for LLM processing
385    let chunks_for_extraction: Vec<DocumentChunk> = non_dlt_chunks.into_iter().cloned().collect();
386
387    let batch_size = config.chunks_per_batch;
388    let mut all_graphs: Vec<(Uuid, KnowledgeGraph)> = Vec::new();
389    let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions));
390
391    for (batch_idx, batch) in chunks_for_extraction.chunks(batch_size).enumerate() {
392        let fact_extractor = FactExtractor::new(Arc::clone(&llm));
393        let mut extract_tasks = Vec::new();
394        let mut chunk_ids = Vec::new();
395
396        for chunk in batch {
397            let extractor = fact_extractor.clone();
398            let text = chunk.text.clone();
399            let sem = Arc::clone(&semaphore);
400            let prompt = config.custom_extraction_prompt.clone();
401
402            chunk_ids.push(chunk.base.id);
403            extract_tasks.push(tokio::spawn(async move {
404                #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
405                let _permit = sem
406                    .acquire()
407                    .await
408                    .expect("semaphore is never closed; created locally in this function");
409                extractor.extract_facts(&text, prompt.as_deref()).await
410            }));
411        }
412
413        let batch_results = futures::future::join_all(extract_tasks).await;
414        for (result, chunk_id) in batch_results.into_iter().zip(chunk_ids) {
415            let graph = result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
416            all_graphs.push((chunk_id, graph));
417        }
418
419        info!(
420            "Processed graph extraction batch {}/{} ({} chunks)",
421            batch_idx + 1,
422            chunks_for_extraction.len().div_ceil(batch_size),
423            batch.len()
424        );
425    }
426
427    // Database deduplication — query for existing edges
428    let graphs_only: Vec<KnowledgeGraph> = all_graphs.iter().map(|(_, g)| g.clone()).collect();
429    let existing_edges_set = retrieve_existing_edges(graph_db.as_ref(), &graphs_only).await?;
430
431    // Merge and deduplicate graphs (with DB awareness).
432    //
433    // The string-form `user_id` is the best label we have at this
434    // point in the pipeline-driven path — `ExtractedChunks` does not
435    // carry `user_email`. The executor's downstream walk
436    // (`PipelineContext::user_label()`, task 05-07) fills in the
437    // email-form label later if the run has it; the pre-stamp's
438    // `if dp.source_user.is_none()` guard then skips, so the more
439    // specific value wins.
440    let user_label_owned = user_label_override
441        .map(|s| s.to_string())
442        .or_else(|| input.user_id.as_ref().map(|id| id.to_string()));
443    let (nodes, edges) = expand_with_nodes_and_edges(
444        all_graphs,
445        input.dataset_id,
446        &existing_edges_set,
447        ontology_resolver.as_ref(),
448        user_label_owned.as_deref(),
449    )
450    .await;
451
452    // Final deduplication pass (in-memory only after DB filtering)
453    let dedup_result = deduplicate_nodes_and_edges(nodes, edges);
454
455    // Build chunk_id → entity IDs mapping from the deduplicated nodes.
456    // Each entity stores the chunk_id it was extracted from in its metadata.
457    let mut chunk_entity_map: HashMap<Uuid, Vec<serde_json::Value>> = HashMap::new();
458    for node_pair in &dedup_result.unique_nodes {
459        if let Some(chunk_id_val) = node_pair.entity.base.get_metadata("chunk_id")
460            && let Some(chunk_id_str) = chunk_id_val.as_str()
461            && let Ok(chunk_id) = Uuid::parse_str(chunk_id_str)
462        {
463            chunk_entity_map
464                .entry(chunk_id)
465                .or_default()
466                .push(json!(node_pair.entity.base.id.to_string()));
467        }
468    }
469
470    // Populate DocumentChunk.contains with extracted entity IDs
471    let mut updated_chunks = input.chunks.clone();
472    for chunk in &mut updated_chunks {
473        if let Some(entity_ids) = chunk_entity_map.get(&chunk.base.id) {
474            chunk.contains = entity_ids.clone();
475        }
476    }
477
478    // Store graph data (nodes and edges) in graph database
479    let entity_refs: Vec<&cognee_models::Entity> = dedup_result
480        .unique_nodes
481        .iter()
482        .map(|n| &n.entity)
483        .collect();
484    graph_db
485        .add_nodes(&entity_refs)
486        .await
487        .map_err(CognifyError::from)?;
488
489    let edge_data: Vec<_> = dedup_result
490        .unique_edges
491        .iter()
492        .map(|edge_pair| {
493            let properties: HashMap<std::borrow::Cow<'static, str>, serde_json::Value> = edge_pair
494                .properties
495                .iter()
496                .map(|(k, v)| {
497                    (
498                        std::borrow::Cow::Owned(k.clone()),
499                        serde_json::Value::String(v.clone()),
500                    )
501                })
502                .collect();
503            (
504                edge_pair.source_entity_id.to_string(),
505                edge_pair.target_entity_id.to_string(),
506                edge_pair.relationship_name.clone(),
507                properties,
508            )
509        })
510        .collect();
511
512    graph_db
513        .add_edges(&edge_data)
514        .await
515        .map_err(CognifyError::from)?;
516
517    Ok(ExtractedGraphData {
518        chunks: updated_chunks,
519        documents: input.documents.clone(),
520        entities: dedup_result.unique_nodes,
521        edges: dedup_result.unique_edges,
522        dataset_id: input.dataset_id,
523        user_id: input.user_id,
524        tenant_id: input.tenant_id,
525    })
526}
527
/// URL provenance extracted from a document's external metadata.
#[derive(Debug, Clone, PartialEq, Eq)]
struct WebPageMetadata {
    /// Page URL in the normalized form produced by the URL parser.
    url: String,
    /// Lower-cased host part of `url`.
    domain: String,
    /// Page title, when the metadata carries a non-empty one.
    title: Option<String>,
}
534
535fn parse_web_page_metadata(document: &Document) -> Option<WebPageMetadata> {
536    let metadata = document.external_metadata.as_ref()?;
537    let value: serde_json::Value = serde_json::from_str(metadata).ok()?;
538    let source = value.get("source").and_then(|v| v.as_str())?;
539    if source != "url" {
540        return None;
541    }
542
543    let url = value
544        .get("final_url")
545        .or_else(|| value.get("url"))
546        .and_then(|v| v.as_str())?;
547    let parsed = Url::parse(url).ok()?;
548    if !matches!(parsed.scheme(), "http" | "https") {
549        return None;
550    }
551    let domain = parsed.host_str()?.to_ascii_lowercase();
552    let title = value
553        .get("title")
554        .and_then(|v| v.as_str())
555        .filter(|s| !s.is_empty())
556        .map(str::to_string);
557
558    Some(WebPageMetadata {
559        url: parsed.to_string(),
560        domain,
561        title,
562    })
563}
564
565fn web_page_id(url: &str) -> Uuid {
566    Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("WebPage:{url}").as_bytes())
567}
568
569fn web_site_id(domain: &str) -> Uuid {
570    Uuid::new_v5(
571        &Uuid::NAMESPACE_OID,
572        format!("WebSite:{}", domain.to_ascii_lowercase()).as_bytes(),
573    )
574}
575
/// Return the first `limit` characters (Unicode scalar values, not bytes) of
/// `value`, or all of `value` when it is shorter than that.
fn first_chars(value: &str, limit: usize) -> String {
    // Byte offset of the character just past the limit, if there is one.
    match value.char_indices().nth(limit) {
        Some((cut, _)) => value[..cut].to_owned(),
        None => value.to_owned(),
    }
}
579
580fn document_content_preview(document_id: Uuid, chunks: &[DocumentChunk]) -> String {
581    let mut preview = String::new();
582    for chunk in chunks
583        .iter()
584        .filter(|chunk| chunk.document_id == document_id)
585    {
586        if !preview.is_empty() {
587            preview.push('\n');
588        }
589        preview.push_str(&chunk.text);
590        if preview.chars().count() >= 500 {
591            break;
592        }
593    }
594    first_chars(&preview, 500)
595}
596
597fn empty_edge_props() -> HashMap<Cow<'static, str>, serde_json::Value> {
598    HashMap::new()
599}
600
601/// Create deterministic WebPage/WebSite graph provenance for URL-sourced documents.
602///
603/// Uses only URL metadata carried on [`Document::external_metadata`], produced
604/// by ingestion for URL inputs. Invalid JSON, non-URL metadata, unparsable URLs,
605/// and non-HTTP(S) URLs are skipped.
606pub async fn create_web_page_nodes(
607    documents: &[Document],
608    chunks: &[DocumentChunk],
609    graph_db: Arc<dyn GraphDBTrait>,
610) -> Result<(), CognifyError> {
611    if documents.is_empty() || chunks.is_empty() {
612        return Ok(());
613    }
614
615    let mut nodes_by_id: HashMap<String, serde_json::Value> = HashMap::new();
616    let mut candidate_edges: Vec<EdgeData> = Vec::new();
617    let mut seen_edges: HashSet<(String, String, String)> = HashSet::new();
618
619    for document in documents {
620        let Some(metadata) = parse_web_page_metadata(document) else {
621            continue;
622        };
623
624        let page_id = web_page_id(&metadata.url);
625        let site_id = web_site_id(&metadata.domain);
626        let page_id_str = page_id.to_string();
627        let site_id_str = site_id.to_string();
628
629        nodes_by_id.insert(
630            page_id_str.clone(),
631            json!({
632                "id": page_id_str,
633                "type": "WebPage",
634                "url": metadata.url,
635                "title": metadata.title,
636                "content": document_content_preview(document.base.id, chunks),
637            }),
638        );
639        nodes_by_id.insert(
640            site_id_str.clone(),
641            json!({
642                "id": site_id_str,
643                "type": "WebSite",
644                "domain": metadata.domain,
645            }),
646        );
647
648        push_unique_edge(
649            &mut candidate_edges,
650            &mut seen_edges,
651            page_id_str.clone(),
652            site_id_str,
653            "PART_OF",
654        );
655
656        for chunk in chunks
657            .iter()
658            .filter(|chunk| chunk.document_id == document.base.id)
659        {
660            push_unique_edge(
661                &mut candidate_edges,
662                &mut seen_edges,
663                chunk.base.id.to_string(),
664                page_id_str.clone(),
665                "SOURCED_FROM",
666            );
667        }
668    }
669
670    if !nodes_by_id.is_empty() {
671        graph_db
672            .add_nodes_raw(nodes_by_id.into_values().collect())
673            .await
674            .map_err(CognifyError::from)?;
675    }
676
677    if candidate_edges.is_empty() {
678        return Ok(());
679    }
680
681    let existing_edges = graph_db
682        .has_edges(&candidate_edges)
683        .await
684        .map_err(CognifyError::from)?;
685    let existing_keys: HashSet<(String, String, String)> = existing_edges
686        .into_iter()
687        .map(|(source, target, relationship, _)| (source, target, relationship))
688        .collect();
689    let missing_edges: Vec<EdgeData> = candidate_edges
690        .into_iter()
691        .filter(|(source, target, relationship, _)| {
692            !existing_keys.contains(&(source.clone(), target.clone(), relationship.clone()))
693        })
694        .collect();
695
696    if !missing_edges.is_empty() {
697        graph_db
698            .add_edges(&missing_edges)
699            .await
700            .map_err(CognifyError::from)?;
701    }
702
703    Ok(())
704}
705
706fn push_unique_edge(
707    edges: &mut Vec<EdgeData>,
708    seen: &mut HashSet<(String, String, String)>,
709    source: String,
710    target: String,
711    relationship: &str,
712) {
713    let key = (source.clone(), target.clone(), relationship.to_string());
714    if seen.insert(key) {
715        edges.push((source, target, relationship.to_string(), empty_edge_props()));
716    }
717}
718
719// ---------------------------------------------------------------------------
720// Task 3b: extract_custom_graph_from_data (custom graph model path)
721// ---------------------------------------------------------------------------
722
723/// Extract a custom graph model from chunks via LLM (Task 3 — custom model variant).
724///
725/// Mirrors the Python branching at `extract_graph_from_data.py:99-103`:
726/// when the graph model is **not** the built-in [`KnowledgeGraph`], the LLM
727/// output is serialized to JSON and stored directly in each
728/// [`DocumentChunk::contains`] without entity/edge expansion, deduplication,
729/// or graph DB storage.
730///
731/// This function is the generic counterpart of [`extract_graph_from_data`].
732/// It accepts any type implementing [`GraphModel`].
733///
734/// The returned [`ExtractedGraphData`] will have empty `entities` and `edges`
735/// fields (those only apply to the default KnowledgeGraph flow).
736///
737/// # Type Parameters
738/// * `M` — A type implementing [`GraphModel`]. Must be `Serialize +
739///   DeserializeOwned + JsonSchema + Clone + Send + Sync + 'static`.
740///
741/// # Errors
742/// - [`CognifyError::LlmError`] if any LLM call fails
743/// - [`CognifyError::SerializationError`] if the extracted model cannot be
744///   serialized to JSON
745pub async fn extract_custom_graph_from_data<M: crate::fact_extraction::GraphModel>(
746    input: &ExtractedChunks,
747    llm: Arc<dyn Llm>,
748    config: &CognifyConfig,
749) -> Result<ExtractedGraphData, CognifyError> {
750    if input.chunks.is_empty() {
751        return Ok(ExtractedGraphData {
752            chunks: input.chunks.clone(),
753            documents: input.documents.clone(),
754            entities: vec![],
755            edges: vec![],
756            dataset_id: input.dataset_id,
757            user_id: input.user_id,
758            tenant_id: input.tenant_id,
759        });
760    }
761
762    // Filter out DLT chunks — same as extract_graph_from_data
763    let dlt_doc_ids: HashSet<Uuid> = input
764        .documents
765        .iter()
766        .filter(|d| d.document_type == "dlt_row")
767        .map(|d| d.base.id)
768        .collect();
769
770    let batch_size = config.chunks_per_batch;
771    let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions));
772
773    let mut updated_chunks = input.chunks.clone();
774
775    // Only process non-DLT chunks through LLM
776    let non_dlt_indices: Vec<usize> = updated_chunks
777        .iter()
778        .enumerate()
779        .filter(|(_, c)| !dlt_doc_ids.contains(&c.document_id))
780        .map(|(i, _)| i)
781        .collect();
782
783    if non_dlt_indices.is_empty() {
784        return Ok(ExtractedGraphData {
785            chunks: updated_chunks,
786            documents: input.documents.clone(),
787            entities: vec![],
788            edges: vec![],
789            dataset_id: input.dataset_id,
790            user_id: input.user_id,
791            tenant_id: input.tenant_id,
792        });
793    }
794
795    let total_batches = non_dlt_indices.len().div_ceil(batch_size);
796
797    for (batch_idx, batch_indices) in non_dlt_indices.chunks(batch_size).enumerate() {
798        let mut extract_tasks = Vec::new();
799
800        for &idx in batch_indices {
801            let extractor = FactExtractor::new(Arc::clone(&llm));
802            let text = updated_chunks[idx].text.clone();
803            let sem = Arc::clone(&semaphore);
804            let prompt = config.custom_extraction_prompt.clone();
805
806            extract_tasks.push(tokio::spawn(async move {
807                #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
808                let _permit = sem
809                    .acquire()
810                    .await
811                    .expect("semaphore is never closed; created locally in this function");
812                extractor.extract::<M>(&text, prompt.as_deref()).await
813            }));
814        }
815
816        let batch_results = futures::future::join_all(extract_tasks).await;
817        let batch_len = batch_indices.len();
818
819        for (i, result) in batch_results.into_iter().enumerate() {
820            let model: M =
821                result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
822            let value = serde_json::to_value(&model)
823                .map_err(|e| CognifyError::SerializationError(e.to_string()))?;
824            updated_chunks[batch_indices[i]].contains = vec![value];
825        }
826
827        info!(
828            "Processed custom graph extraction batch {}/{} ({} chunks)",
829            batch_idx + 1,
830            total_batches,
831            batch_len
832        );
833    }
834
835    Ok(ExtractedGraphData {
836        chunks: updated_chunks,
837        documents: input.documents.clone(),
838        entities: vec![],
839        edges: vec![],
840        dataset_id: input.dataset_id,
841        user_id: input.user_id,
842        tenant_id: input.tenant_id,
843    })
844}
845
846// ---------------------------------------------------------------------------
847// Task 4: summarize_text
848// ---------------------------------------------------------------------------
849
850/// Summarize text chunks via LLM (Task 4).
851///
852/// If summarization is enabled in config, generates summaries for each chunk
853/// using batched parallel LLM calls.
854pub async fn summarize_text(
855    input: &ExtractedGraphData,
856    llm: Arc<dyn Llm>,
857    config: &CognifyConfig,
858) -> Result<SummarizedData, CognifyError> {
859    // Filter out DLT chunks — structured data rows should not be summarized.
860    // Mirrors Python: cognee/tasks/summarization/summarize_text.py:52-62
861    let dlt_doc_ids: HashSet<Uuid> = input
862        .documents
863        .iter()
864        .filter(|d| d.document_type == "dlt_row")
865        .map(|d| d.base.id)
866        .collect();
867
868    let non_dlt_chunks: Vec<DocumentChunk> = input
869        .chunks
870        .iter()
871        .filter(|c| !dlt_doc_ids.contains(&c.document_id))
872        .cloned()
873        .collect();
874
875    if non_dlt_chunks.len() < input.chunks.len() {
876        info!(
877            "Skipping {} DLT chunks from summarization ({} non-DLT chunks remain)",
878            input.chunks.len() - non_dlt_chunks.len(),
879            non_dlt_chunks.len()
880        );
881    }
882
883    let summaries = if config.enable_summarization && !non_dlt_chunks.is_empty() {
884        let summary_extractor =
885            SummaryExtractor::new_with_schema(llm, config.summary_schema.clone());
886        let mut all_summaries = Vec::new();
887
888        for batch in non_dlt_chunks.chunks(config.summarization_batch_size) {
889            let batch_summaries = summary_extractor.summarize_chunks(batch, None).await?;
890            all_summaries.extend(batch_summaries);
891        }
892
893        info!("Generated {} summaries", all_summaries.len());
894        all_summaries
895    } else {
896        if !config.enable_summarization {
897            info!("Summarization disabled in config");
898        } else {
899            info!("No non-DLT chunks to summarize");
900        }
901        Vec::new()
902    };
903
904    Ok(SummarizedData {
905        chunks: input.chunks.clone(),
906        documents: input.documents.clone(),
907        entities: input.entities.clone(),
908        edges: input.edges.clone(),
909        summaries,
910        dataset_id: input.dataset_id,
911        user_id: input.user_id,
912        tenant_id: input.tenant_id,
913    })
914}
915
916// ---------------------------------------------------------------------------
917// Task 5: add_data_points
918// ---------------------------------------------------------------------------
919
920/// Generate embeddings and index all data points in vector DB (Task 5).
921///
922/// Generates embeddings for chunks, entities (name + description), summaries,
923/// and optionally triplets. Creates vector collections and indexes points.
924///
925/// When `db` is `Some`, also writes provenance records (nodes/edges) to the
926/// relational database, matching Python's `upsert_nodes` / `upsert_edges`
927/// calls guarded by `if user and dataset and data:`.
928pub async fn add_data_points(
929    input: &SummarizedData,
930    graph_db: Arc<dyn GraphDBTrait>,
931    vector_db: Arc<dyn VectorDB>,
932    embedding_engine: Arc<dyn EmbeddingEngine>,
933    db: Option<Arc<DatabaseConnection>>,
934    config: &CognifyConfig,
935) -> Result<CognifyResult, CognifyError> {
936    // Store all DataPoint types as graph nodes (matches Python's add_data_points behavior).
937    // Python stores DocumentChunks, TextSummaries, and EntityTypes as graph nodes.
938
939    // Store DocumentChunks as graph nodes
940    if !input.chunks.is_empty() {
941        let chunk_refs: Vec<&DocumentChunk> = input.chunks.iter().collect();
942        graph_db
943            .add_nodes(&chunk_refs)
944            .await
945            .map_err(CognifyError::from)?;
946        info!("Stored {} document chunks as graph nodes", chunk_refs.len());
947    }
948
949    // Store TextSummaries as graph nodes
950    if !input.summaries.is_empty() {
951        let summary_refs: Vec<&TextSummary> = input.summaries.iter().collect();
952        graph_db
953            .add_nodes(&summary_refs)
954            .await
955            .map_err(CognifyError::from)?;
956        info!(
957            "Stored {} text summaries as graph nodes",
958            summary_refs.len()
959        );
960    }
961
962    // Store EntityTypes as graph nodes (extract from GraphNodePairs)
963    if !input.entities.is_empty() {
964        let entity_type_refs: Vec<&cognee_models::EntityType> = input
965            .entities
966            .iter()
967            .map(|pair| &pair.entity_type)
968            .collect();
969        graph_db
970            .add_nodes(&entity_type_refs)
971            .await
972            .map_err(CognifyError::from)?;
973        info!(
974            "Stored {} entity types as graph nodes",
975            entity_type_refs.len()
976        );
977    }
978
979    // Store Documents as graph nodes. Python reaches Documents by recursively
980    // walking each DocumentChunk's `is_part_of` field (a full Document
981    // DataPoint) in get_graph_from_model(). Rust's `is_part_of` is just a
982    // `Uuid`, so we store Documents explicitly here. The node `id` equals the
983    // source Data item's id (content-addressed, Python-identical) and the node
984    // `type` is the concrete subclass name (TextDocument, PdfDocument, …), so
985    // the `is_part_of` edge target now resolves to a stored Document node.
986    if !input.documents.is_empty() {
987        let doc_refs: Vec<&Document> = input.documents.iter().collect();
988        graph_db
989            .add_nodes(&doc_refs)
990            .await
991            .map_err(CognifyError::from)?;
992        info!("Stored {} documents as graph nodes", doc_refs.len());
993    }
994
995    // Build EdgeTypes keyed on each edge's retrieval text
996    // (port of Python's create_edge_type_datapoints + index_graph_edges).
997    //
998    // Parity note: Python's `index_graph_edges` only *vector-indexes* these
999    // EdgeType DataPoints (into `EdgeType_relationship_name`) — it never adds
1000    // them to the graph as nodes (see index_graph_edges.py:86-88 →
1001    // index_data_points, which touches the vector engine only). We therefore
1002    // build + vector-index them below but deliberately do NOT call
1003    // `graph_db.add_nodes` on them, so the Rust graph node-set matches Python's
1004    // and they don't surface as untyped/uncolored nodes in the visualization.
1005    //
1006    // Python keys EdgeType IDs and the embedded relationship_name on the
1007    // edge's retrieval text — `get_edge_retrieval_text(edge_text,
1008    // relationship_name)` (index_graph_edges.py:33-53), i.e. the nonblank
1009    // `edge_text` property, falling back to the nonblank relationship_name,
1010    // else dropped. `generate_edge_id(edge_id=text)` then derives the ID from
1011    // that text. We mirror that here so EdgeType UUIDs and the
1012    // EdgeType_relationship_name vector inputs match Python (B2.5).
1013    let mut edge_type_counts: HashMap<String, i32> = HashMap::new();
1014    for edge_pair in &input.edges {
1015        let edge_text = edge_retrieval_text(edge_pair);
1016        if edge_text.is_empty() {
1017            continue;
1018        }
1019        *edge_type_counts.entry(edge_text).or_insert(0) += 1;
1020    }
1021
1022    let mut edge_types: Vec<EdgeType> = edge_type_counts
1023        .into_iter()
1024        .map(|(text, count)| {
1025            let mut et = EdgeType::new_deterministic(&text, Some(input.dataset_id));
1026            et.set_count(count);
1027            et
1028        })
1029        .collect();
1030
1031    // Pre-stamp freshly-built EdgeType DataPoints at construction time so the
1032    // `source_*` provenance keys are populated before they are vector-indexed
1033    // (collection `EdgeType_relationship_name`) and before the Triplet payloads
1034    // copy those keys from the originating EdgeType (gap-05/08 §4.4, below).
1035    // The LLM-derived edge-type names trace back to the entity-extraction task,
1036    // so the `source_pipeline` / `source_task` literals match.
1037    //
1038    // These DataPoints are NOT stored as graph nodes (see parity note above),
1039    // so the stamp only affects vector payloads, not the graph/visualization.
1040    //
1041    // DLT-derived edges (`extract_dlt_fk_edges`) construct
1042    // `GraphEdgePair` instances rather than DataPoints; they carry no
1043    // DataPoint to stamp, so no pre-stamp call is needed there.
1044    {
1045        let user_label = input.user_id.as_ref().map(|id| id.to_string());
1046        let mut local_visited: HashSet<Uuid> = HashSet::new();
1047        for et in &mut edge_types {
1048            crate::graph_integration::expansion::pre_stamp_extraction(
1049                et,
1050                user_label.as_deref(),
1051                &mut local_visited,
1052            );
1053        }
1054    }
1055
1056    // Discover structural edges via GraphExtractable trait
1057    // (port of Python's get_graph_from_model() relationship discovery)
1058    let mut extractable_items: Vec<&dyn crate::graph_extraction::GraphExtractable> = Vec::new();
1059    for chunk in &input.chunks {
1060        extractable_items.push(chunk as &dyn crate::graph_extraction::GraphExtractable);
1061    }
1062    for summary in &input.summaries {
1063        extractable_items.push(summary as &dyn crate::graph_extraction::GraphExtractable);
1064    }
1065    for pair in &input.entities {
1066        extractable_items.push(&pair.entity as &dyn crate::graph_extraction::GraphExtractable);
1067        extractable_items.push(&pair.entity_type as &dyn crate::graph_extraction::GraphExtractable);
1068    }
1069
1070    let structural_edges = crate::graph_extraction::get_graph_from_model(&extractable_items);
1071
1072    if !structural_edges.is_empty() {
1073        graph_db
1074            .add_edges(&structural_edges)
1075            .await
1076            .map_err(CognifyError::from)?;
1077        info!("Created {} structural edges", structural_edges.len());
1078    }
1079
1080    let embeddings = generate_embeddings(
1081        &input.chunks,
1082        &input.entities,
1083        &input.summaries,
1084        embedding_engine.clone(),
1085    )
1086    .await?;
1087
1088    let indexed_fields = index_data_points(
1089        &input.chunks,
1090        &input.entities,
1091        &input.summaries,
1092        &input.documents,
1093        &input.edges,
1094        &edge_types,
1095        input.dataset_id,
1096        input.user_id,
1097        input.tenant_id,
1098        embedding_engine,
1099        vector_db,
1100        config,
1101    )
1102    .await?;
1103
1104    // ── Provenance upsert (mirrors Python's `if user and dataset and data:`) ──
1105    if let (Some(db), Some(user_id)) = (&db, input.user_id) {
1106        upsert_provenance(
1107            db,
1108            input.tenant_id,
1109            user_id,
1110            input.dataset_id,
1111            &input.chunks,
1112            &input.entities,
1113            &input.edges,
1114            &input.summaries,
1115            &input.documents,
1116            &structural_edges,
1117        )
1118        .await?;
1119    }
1120
1121    Ok(CognifyResult {
1122        chunks: input.chunks.clone(),
1123        entities: input.entities.clone(),
1124        edges: input.edges.clone(),
1125        summaries: input.summaries.clone(),
1126        edge_types,
1127        embeddings,
1128        indexed_fields,
1129        documents_for_dlt: input.documents.clone(),
1130        already_completed: false,
1131        prior_pipeline_run_id: None,
1132    })
1133}
1134
1135// ---------------------------------------------------------------------------
1136// Temporal Task 3: extract_temporal_events
1137// ---------------------------------------------------------------------------
1138
1139/// Extract temporal events from text chunks via two LLM passes (Temporal Task 3).
1140///
1141/// Mirrors the Python `get_temporal_tasks` pipeline stage 3:
1142/// `extract_events_and_timestamps` followed by `extract_knowledge_graph_from_events`.
1143///
1144/// Steps:
1145/// 1. Collects all non-DLT [`DocumentChunk`]s from `input`.
1146/// 2. Batches by `config.data_per_batch`.
1147/// 3. For each chunk in a batch, runs [`TemporalEventExtractor::extract_events`]
1148///    in parallel (bounded by `config.max_parallel_extractions`).
1149/// 4. Flattens per-chunk results and enriches each batch with entity attributes
1150///    via [`TemporalEntityEnricher::enrich`].
1151/// 5. Returns all events as [`ExtractedTemporalEvents`].
1152pub async fn extract_temporal_events(
1153    input: &ExtractedChunks,
1154    llm: Arc<dyn Llm>,
1155    config: &CognifyConfig,
1156) -> Result<ExtractedTemporalEvents, CognifyError> {
1157    if input.chunks.is_empty() {
1158        return Ok(ExtractedTemporalEvents {
1159            events: vec![],
1160            dataset_id: input.dataset_id,
1161            user_id: input.user_id,
1162            tenant_id: input.tenant_id,
1163        });
1164    }
1165
1166    // Filter out DLT chunks — same rationale as extract_graph_from_data.
1167    let dlt_doc_ids: HashSet<Uuid> = input
1168        .documents
1169        .iter()
1170        .filter(|d| d.document_type == "dlt_row")
1171        .map(|d| d.base.id)
1172        .collect();
1173
1174    let non_dlt_chunks: Vec<&DocumentChunk> = input
1175        .chunks
1176        .iter()
1177        .filter(|c| !dlt_doc_ids.contains(&c.document_id))
1178        .collect();
1179
1180    if non_dlt_chunks.is_empty() {
1181        return Ok(ExtractedTemporalEvents {
1182            events: vec![],
1183            dataset_id: input.dataset_id,
1184            user_id: input.user_id,
1185            tenant_id: input.tenant_id,
1186        });
1187    }
1188
1189    let batch_size = config.data_per_batch;
1190    let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions));
1191    let extractor = Arc::new(TemporalEventExtractor::new(Arc::clone(&llm)));
1192    let enricher = TemporalEntityEnricher::new(Arc::clone(&llm));
1193
1194    let mut all_events: Vec<TemporalEvent> = Vec::new();
1195
1196    for (batch_idx, batch) in non_dlt_chunks.chunks(batch_size).enumerate() {
1197        let mut extract_tasks = Vec::new();
1198
1199        for chunk in batch {
1200            let ext = Arc::clone(&extractor);
1201            let text = chunk.text.clone();
1202            let sem = Arc::clone(&semaphore);
1203            extract_tasks.push(tokio::spawn(async move {
1204                #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
1205                let _permit = sem
1206                    .acquire()
1207                    .await
1208                    .expect("semaphore is never closed; created locally in this function");
1209                ext.extract_events(&text).await
1210            }));
1211        }
1212
1213        let batch_results = futures::future::join_all(extract_tasks).await;
1214        let mut batch_events: Vec<TemporalEvent> = Vec::new();
1215        for result in batch_results {
1216            let events = result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
1217            batch_events.extend(events);
1218        }
1219
1220        info!(
1221            "Temporal extraction batch {}/{}: {} events extracted",
1222            batch_idx + 1,
1223            non_dlt_chunks.len().div_ceil(batch_size),
1224            batch_events.len()
1225        );
1226
1227        // Entity enrichment pass for the whole batch.
1228        let enriched = enricher.enrich(batch_events).await?;
1229        all_events.extend(enriched);
1230    }
1231
1232    info!(
1233        "Temporal event extraction complete: {} total events",
1234        all_events.len()
1235    );
1236
1237    Ok(ExtractedTemporalEvents {
1238        events: all_events,
1239        dataset_id: input.dataset_id,
1240        user_id: input.user_id,
1241        tenant_id: input.tenant_id,
1242    })
1243}
1244
1245// ---------------------------------------------------------------------------
1246// Temporal Task 4: add_temporal_data_points
1247// ---------------------------------------------------------------------------
1248
1249/// Persist temporal events to graph and vector databases (Temporal Task 4).
1250///
1251/// Mirrors the Python `add_data_points` stage in the temporal pipeline.
1252///
1253/// For each [`TemporalEvent`]:
1254/// 1. Creates an `Event` graph node with a deterministic UUID5 ID.
1255/// 2. For `event.at` — creates a `Timestamp` graph node and an `at` edge.
1256/// 3. For `event.during` — creates `Timestamp` nodes for from/to, an `Interval`
1257///    node, and `during` / `time_from` / `time_to` edges (Python-compatible layout).
1258/// 4. For each [`EventAttribute`] — creates or looks up an entity graph node
1259///    and adds a typed edge from the `Event` to the entity.
1260/// 5. Embeds `event.name` and indexes to the `Event_name` vector collection.
1261pub async fn add_temporal_data_points(
1262    events: &ExtractedTemporalEvents,
1263    graph_db: Arc<dyn GraphDBTrait>,
1264    vector_db: Arc<dyn VectorDB>,
1265    embedding_engine: Arc<dyn EmbeddingEngine>,
1266) -> Result<CognifyResult, CognifyError> {
1267    if events.events.is_empty() {
1268        info!("No temporal events to persist.");
1269        return Ok(CognifyResult::empty());
1270    }
1271
1272    let mut graph_nodes: Vec<serde_json::Value> = Vec::new();
1273    let mut graph_edges: Vec<EdgeData> = Vec::new();
1274
1275    // Deduplicate entity nodes across events to avoid redundant graph inserts.
1276    let mut seen_entity_ids: HashSet<Uuid> = HashSet::new();
1277    // Deduplicate edges: (source_id, target_id, relationship_name)
1278    let mut seen_edge_keys: HashSet<(String, String, String)> = HashSet::new();
1279
1280    let mut event_ids: Vec<Uuid> = Vec::new();
1281    let mut event_names: Vec<String> = Vec::new();
1282
1283    for event in &events.events {
1284        // ── Event node ──────────────────────────────────────────────────────
1285        let event_id = Uuid::new_v5(
1286            &Uuid::NAMESPACE_OID,
1287            format!("event:{}", event.name).as_bytes(),
1288        );
1289        event_ids.push(event_id);
1290        event_names.push(event.name.clone());
1291
1292        let mut event_node = json!({
1293            "id": event_id.to_string(),
1294            "data_type": "Event",
1295            "name": event.name,
1296        });
1297        if let Some(desc) = &event.description {
1298            event_node["description"] = json!(desc);
1299        }
1300        if let Some(loc) = &event.location {
1301            event_node["location"] = json!(loc);
1302        }
1303        graph_nodes.push(event_node);
1304
1305        // ── Timestamp for event.at ──────────────────────────────────────────
1306        if let Some(ts) = &event.at {
1307            let ts_id = Uuid::new_v5(
1308                &Uuid::NAMESPACE_OID,
1309                format!("timestamp:{}", ts.time_at).as_bytes(),
1310            );
1311            graph_nodes.push(json!({
1312                "id": ts_id.to_string(),
1313                "data_type": "Timestamp",
1314                "time_at": ts.time_at,
1315                "timestamp_str": ts.timestamp_str,
1316                "year": ts.year,
1317                "month": ts.month,
1318                "day": ts.day,
1319                "hour": ts.hour,
1320                "minute": ts.minute,
1321                "second": ts.second,
1322            }));
1323
1324            let edge_key = (event_id.to_string(), ts_id.to_string(), "at".to_string());
1325            if seen_edge_keys.insert(edge_key) {
1326                graph_edges.push((
1327                    event_id.to_string(),
1328                    ts_id.to_string(),
1329                    "at".to_string(),
1330                    build_edge_props(&event_id.to_string(), &ts_id.to_string(), "at"),
1331                ));
1332            }
1333        }
1334
1335        // ── Interval for event.during ───────────────────────────────────────
1336        if let Some(interval) = &event.during {
1337            let ts_from = &interval.time_from;
1338            let ts_to = &interval.time_to;
1339
1340            let ts_from_id = Uuid::new_v5(
1341                &Uuid::NAMESPACE_OID,
1342                format!("timestamp:{}", ts_from.time_at).as_bytes(),
1343            );
1344            let ts_to_id = Uuid::new_v5(
1345                &Uuid::NAMESPACE_OID,
1346                format!("timestamp:{}", ts_to.time_at).as_bytes(),
1347            );
1348            let interval_id = Uuid::new_v5(
1349                &Uuid::NAMESPACE_OID,
1350                format!("interval:{}:{}", ts_from.time_at, ts_to.time_at).as_bytes(),
1351            );
1352
1353            graph_nodes.push(json!({
1354                "id": ts_from_id.to_string(),
1355                "data_type": "Timestamp",
1356                "time_at": ts_from.time_at,
1357                "timestamp_str": ts_from.timestamp_str,
1358                "year": ts_from.year,
1359                "month": ts_from.month,
1360                "day": ts_from.day,
1361                "hour": ts_from.hour,
1362                "minute": ts_from.minute,
1363                "second": ts_from.second,
1364            }));
1365            graph_nodes.push(json!({
1366                "id": ts_to_id.to_string(),
1367                "data_type": "Timestamp",
1368                "time_at": ts_to.time_at,
1369                "timestamp_str": ts_to.timestamp_str,
1370                "year": ts_to.year,
1371                "month": ts_to.month,
1372                "day": ts_to.day,
1373                "hour": ts_to.hour,
1374                "minute": ts_to.minute,
1375                "second": ts_to.second,
1376            }));
1377            graph_nodes.push(json!({
1378                "id": interval_id.to_string(),
1379                "data_type": "Interval",
1380            }));
1381
1382            // Event -[during]-> Interval
1383            let during_key = (
1384                event_id.to_string(),
1385                interval_id.to_string(),
1386                "during".to_string(),
1387            );
1388            if seen_edge_keys.insert(during_key) {
1389                graph_edges.push((
1390                    event_id.to_string(),
1391                    interval_id.to_string(),
1392                    "during".to_string(),
1393                    build_edge_props(&event_id.to_string(), &interval_id.to_string(), "during"),
1394                ));
1395            }
1396
1397            // Interval -[time_from]-> Timestamp(from)
1398            let from_key = (
1399                interval_id.to_string(),
1400                ts_from_id.to_string(),
1401                "time_from".to_string(),
1402            );
1403            if seen_edge_keys.insert(from_key) {
1404                graph_edges.push((
1405                    interval_id.to_string(),
1406                    ts_from_id.to_string(),
1407                    "time_from".to_string(),
1408                    build_edge_props(
1409                        &interval_id.to_string(),
1410                        &ts_from_id.to_string(),
1411                        "time_from",
1412                    ),
1413                ));
1414            }
1415
1416            // Interval -[time_to]-> Timestamp(to)
1417            let to_key = (
1418                interval_id.to_string(),
1419                ts_to_id.to_string(),
1420                "time_to".to_string(),
1421            );
1422            if seen_edge_keys.insert(to_key) {
1423                graph_edges.push((
1424                    interval_id.to_string(),
1425                    ts_to_id.to_string(),
1426                    "time_to".to_string(),
1427                    build_edge_props(&interval_id.to_string(), &ts_to_id.to_string(), "time_to"),
1428                ));
1429            }
1430        }
1431
1432        // ── Entity attribute nodes and edges ────────────────────────────────
1433        for attr in &event.attributes {
1434            let entity_id = Uuid::new_v5(
1435                &Uuid::NAMESPACE_OID,
1436                format!("entity:{}", attr.entity).as_bytes(),
1437            );
1438
1439            if seen_entity_ids.insert(entity_id) {
1440                graph_nodes.push(json!({
1441                    "id": entity_id.to_string(),
1442                    "data_type": attr.entity_type,
1443                    "name": attr.entity,
1444                }));
1445            }
1446
1447            let rel_key = (
1448                event_id.to_string(),
1449                entity_id.to_string(),
1450                attr.relationship.clone(),
1451            );
1452            if seen_edge_keys.insert(rel_key) {
1453                graph_edges.push((
1454                    event_id.to_string(),
1455                    entity_id.to_string(),
1456                    attr.relationship.clone(),
1457                    build_edge_props(
1458                        &event_id.to_string(),
1459                        &entity_id.to_string(),
1460                        &attr.relationship,
1461                    ),
1462                ));
1463            }
1464        }
1465    }
1466
1467    // Persist nodes and edges to graph DB.
1468    if !graph_nodes.is_empty() {
1469        let node_count = graph_nodes.len();
1470        graph_db
1471            .add_nodes_raw(graph_nodes)
1472            .await
1473            .map_err(CognifyError::from)?;
1474        info!("Stored {} temporal graph nodes", node_count);
1475    }
1476
1477    if !graph_edges.is_empty() {
1478        let edge_count = graph_edges.len();
1479        graph_db
1480            .add_edges(&graph_edges)
1481            .await
1482            .map_err(CognifyError::from)?;
1483        info!("Stored {} temporal graph edges", edge_count);
1484    }
1485
1486    // ── Vector indexing: Event.name ──────────────────────────────────────────
1487    let mut indexed_fields = IndexedFieldsStats::default();
1488
1489    if !event_ids.is_empty() {
1490        let dimension = embedding_engine.dimension();
1491
1492        if !vector_db
1493            .has_collection("Event", "name")
1494            .await
1495            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
1496        {
1497            vector_db
1498                .create_collection("Event", "name", dimension)
1499                .await
1500                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
1501        }
1502
1503        let name_strs: Vec<&str> = event_names.iter().map(String::as_str).collect();
1504        let vectors = embedding_engine
1505            .embed(&name_strs)
1506            .await
1507            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
1508
1509        let points: Vec<VectorPoint> = event_ids
1510            .iter()
1511            .zip(event_names.iter())
1512            .zip(vectors.iter())
1513            .map(|((id, name), vector)| {
1514                let mut point = VectorPoint::new(*id, vector.clone())
1515                    .with_metadata("type", json!("Event"))
1516                    .with_metadata("field", json!("name"))
1517                    .with_metadata("name", json!(name))
1518                    .with_metadata("dataset_id", json!(events.dataset_id.to_string()));
1519                if let Some(uid) = events.user_id {
1520                    point = point.with_metadata("user_id", json!(uid.to_string()));
1521                }
1522                if let Some(tid) = events.tenant_id {
1523                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
1524                }
1525                point
1526            })
1527            .collect();
1528
1529        vector_db
1530            .index_points("Event", "name", &points)
1531            .await
1532            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
1533
1534        indexed_fields.record("Event", "name", event_ids.len());
1535        info!("Indexed {} event names in vector DB", event_ids.len());
1536    }
1537
1538    Ok(CognifyResult {
1539        chunks: vec![],
1540        entities: vec![],
1541        edges: vec![],
1542        summaries: vec![],
1543        edge_types: vec![],
1544        embeddings: vec![],
1545        indexed_fields,
1546        documents_for_dlt: vec![],
1547        already_completed: false,
1548        prior_pipeline_run_id: None,
1549    })
1550}
1551
1552/// Resolve the retrieval text for an edge, mirroring Python's
1553/// `get_edge_retrieval_text(edge_text, relationship_name)`
1554/// (prepare_edges_for_storage.py:26-28 via index_graph_edges.py:33-53):
1555/// prefer the nonblank `edge_text` property, fall back to the nonblank
1556/// `relationship_name`, else return an empty string (caller drops empties).
1557fn edge_retrieval_text(edge_pair: &GraphEdgePair) -> String {
1558    let from_edge_text = edge_pair
1559        .properties
1560        .get("edge_text")
1561        .map(|s| s.trim())
1562        .filter(|s| !s.is_empty());
1563
1564    if let Some(text) = from_edge_text {
1565        return text.to_string();
1566    }
1567
1568    let rel = edge_pair.relationship_name.trim();
1569    rel.to_string()
1570}
1571
1572/// Build minimal edge properties for graph storage.
1573fn build_edge_props(
1574    source_id: &str,
1575    target_id: &str,
1576    relationship_name: &str,
1577) -> HashMap<std::borrow::Cow<'static, str>, serde_json::Value> {
1578    let mut props = HashMap::new();
1579    props.insert(
1580        std::borrow::Cow::Borrowed("source_node_id"),
1581        json!(source_id),
1582    );
1583    props.insert(
1584        std::borrow::Cow::Borrowed("target_node_id"),
1585        json!(target_id),
1586    );
1587    props.insert(
1588        std::borrow::Cow::Borrowed("relationship_name"),
1589        json!(relationship_name),
1590    );
1591    props
1592}
1593
1594// ---------------------------------------------------------------------------
1595// Task 6: extract_dlt_fk_edges
1596// ---------------------------------------------------------------------------
1597
1598/// Create graph edges and schema nodes from DLT-sourced relational data.
1599///
1600/// Mirrors the Python `cognee/tasks/ingestion/extract_dlt_fk_edges.py`.
1601/// This task runs after `add_data_points` in the cognify pipeline. It:
1602/// 1. Identifies DLT documents from the classified documents list
1603/// 2. Parses `external_metadata` for table info and foreign key definitions
1604/// 3. Creates `is_row_of` edges from DLT document nodes to their source table
1605/// 4. Creates FK-based edges between documents of related rows
1606///
1607/// If no DLT documents are present, this is a no-op.
1608pub async fn extract_dlt_fk_edges(
1609    _chunks: &[DocumentChunk],
1610    documents: &[Document],
1611    graph_db: Arc<dyn GraphDBTrait>,
1612) -> Result<(), CognifyError> {
1613    // Collect DLT documents
1614    let dlt_docs: Vec<&Document> = documents
1615        .iter()
1616        .filter(|d| d.document_type == "dlt_row")
1617        .collect();
1618
1619    if dlt_docs.is_empty() {
1620        return Ok(());
1621    }
1622
1623    info!(
1624        "Processing {} DLT documents for FK edge extraction",
1625        dlt_docs.len()
1626    );
1627
1628    // Parse external_metadata for each DLT document
1629    // Collect table info and FK definitions
1630    let mut tables_seen: HashMap<String, DltTableMeta> = HashMap::new();
1631    let mut dlt_doc_meta: HashMap<Uuid, serde_json::Value> = HashMap::new();
1632    let mut fk_defs_seen: HashSet<(String, String, String, String)> = HashSet::new();
1633
1634    for doc in &dlt_docs {
1635        let ext_metadata = match &doc.external_metadata {
1636            Some(m) => match serde_json::from_str::<serde_json::Value>(m) {
1637                Ok(v) if v.get("source").and_then(|s| s.as_str()) == Some("dlt") => v,
1638                _ => continue,
1639            },
1640            None => continue,
1641        };
1642
1643        dlt_doc_meta.insert(doc.base.id, ext_metadata.clone());
1644
1645        let table_name = ext_metadata
1646            .get("table_name")
1647            .and_then(|v| v.as_str())
1648            .unwrap_or("")
1649            .to_string();
1650
1651        if !table_name.is_empty() && !tables_seen.contains_key(&table_name) {
1652            tables_seen.insert(
1653                table_name.clone(),
1654                DltTableMeta {
1655                    schema_info: ext_metadata.get("schema_info").cloned(),
1656                    foreign_keys: ext_metadata
1657                        .get("foreign_keys")
1658                        .and_then(|v| v.as_array())
1659                        .cloned()
1660                        .unwrap_or_default(),
1661                    dlt_db_name: ext_metadata
1662                        .get("dlt_db_name")
1663                        .and_then(|v| v.as_str())
1664                        .unwrap_or("")
1665                        .to_string(),
1666                },
1667            );
1668        }
1669    }
1670
1671    if dlt_doc_meta.is_empty() {
1672        return Ok(());
1673    }
1674
1675    let mut all_edges: Vec<cognee_graph::EdgeData> = Vec::new();
1676
1677    // Phase 1: Build table node IDs (deterministic via uuid5) and SchemaTable nodes
1678    let mut table_node_ids: HashMap<String, Uuid> = HashMap::new();
1679    let mut schema_nodes: Vec<serde_json::Value> = Vec::new();
1680
1681    for (table_name, table_meta) in &tables_seen {
1682        let id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{table_name}").as_bytes());
1683        table_node_ids.insert(table_name.clone(), id);
1684
1685        let columns_str = table_meta
1686            .schema_info
1687            .as_ref()
1688            .map(|v| v.to_string())
1689            .unwrap_or_else(|| "[]".to_string());
1690        let fk_str =
1691            serde_json::to_string(&table_meta.foreign_keys).unwrap_or_else(|_| "[]".to_string());
1692
1693        let table_node = SchemaTableNode {
1694            id: id.to_string(),
1695            name: table_name.clone(),
1696            columns: columns_str,
1697            primary_key: None,
1698            foreign_keys: fk_str,
1699            sample_rows: "[]".to_string(),
1700            row_count_estimate: None,
1701            description: format!(
1702                "DLT-ingested relational table '{}' from database '{}'.",
1703                table_name, table_meta.dlt_db_name
1704            ),
1705            data_type: "SchemaTable".to_string(),
1706        };
1707        if let Ok(val) = serde_json::to_value(&table_node) {
1708            schema_nodes.push(val);
1709        }
1710    }
1711
1712    // Phase 2: Create FK relationship edges between table nodes
1713    for (table_name, table_meta) in &tables_seen {
1714        for fk in &table_meta.foreign_keys {
1715            let fk_col = fk
1716                .get("column")
1717                .and_then(|v| v.as_str())
1718                .unwrap_or("")
1719                .to_string();
1720            let ref_table = fk
1721                .get("ref_table")
1722                .and_then(|v| v.as_str())
1723                .unwrap_or("")
1724                .to_string();
1725            let ref_col = fk
1726                .get("ref_column")
1727                .and_then(|v| v.as_str())
1728                .unwrap_or("")
1729                .to_string();
1730
1731            if fk_col.is_empty() || ref_table.is_empty() {
1732                continue;
1733            }
1734
1735            let fk_key = (
1736                table_name.clone(),
1737                fk_col.clone(),
1738                ref_table.clone(),
1739                ref_col.clone(),
1740            );
1741            if fk_defs_seen.contains(&fk_key) {
1742                continue;
1743            }
1744            fk_defs_seen.insert(fk_key);
1745
1746            let rel_name = format!("{table_name}:{fk_col}->{ref_table}:{ref_col}");
1747            let rel_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{rel_name}").as_bytes());
1748
1749            // Create SchemaRelationship node for this FK definition
1750            let rel_node = SchemaRelationshipNode {
1751                id: rel_id.to_string(),
1752                name: rel_name.clone(),
1753                source_table: table_name.clone(),
1754                target_table: ref_table.clone(),
1755                relationship_type: "foreign_key".to_string(),
1756                source_column: fk_col.clone(),
1757                target_column: ref_col.clone(),
1758                description: format!("Foreign key: {table_name}.{fk_col} -> {ref_table}.{ref_col}"),
1759                data_type: "SchemaRelationship".to_string(),
1760            };
1761            if let Ok(val) = serde_json::to_value(&rel_node) {
1762                schema_nodes.push(val);
1763            }
1764
1765            // source_table -> relationship (has_foreign_key)
1766            if let Some(&source_table_id) = table_node_ids.get(table_name.as_str()) {
1767                let mut props = HashMap::new();
1768                props.insert(
1769                    std::borrow::Cow::Borrowed("source_node_id"),
1770                    json!(source_table_id.to_string()),
1771                );
1772                props.insert(
1773                    std::borrow::Cow::Borrowed("target_node_id"),
1774                    json!(rel_id.to_string()),
1775                );
1776                props.insert(
1777                    std::borrow::Cow::Borrowed("relationship_name"),
1778                    json!("has_foreign_key"),
1779                );
1780                all_edges.push((
1781                    source_table_id.to_string(),
1782                    rel_id.to_string(),
1783                    "has_foreign_key".to_string(),
1784                    props,
1785                ));
1786            }
1787
1788            // relationship -> target_table (references_table)
1789            if let Some(&target_table_id) = table_node_ids.get(ref_table.as_str()) {
1790                let mut props = HashMap::new();
1791                props.insert(
1792                    std::borrow::Cow::Borrowed("source_node_id"),
1793                    json!(rel_id.to_string()),
1794                );
1795                props.insert(
1796                    std::borrow::Cow::Borrowed("target_node_id"),
1797                    json!(target_table_id.to_string()),
1798                );
1799                props.insert(
1800                    std::borrow::Cow::Borrowed("relationship_name"),
1801                    json!("references_table"),
1802                );
1803                all_edges.push((
1804                    rel_id.to_string(),
1805                    target_table_id.to_string(),
1806                    "references_table".to_string(),
1807                    props,
1808                ));
1809            }
1810        }
1811    }
1812
1813    // Phase 3: Create row-level edges (document -> table, document -> referenced document)
1814    let mut seen_row_edges: HashSet<(String, String, String)> = HashSet::new();
1815
1816    for (doc_id, ext_metadata) in &dlt_doc_meta {
1817        let table_name = ext_metadata
1818            .get("table_name")
1819            .and_then(|v| v.as_str())
1820            .unwrap_or("");
1821
1822        // Link document to its SchemaTable node
1823        if let Some(&table_node_id) = table_node_ids.get(table_name) {
1824            let mut props = HashMap::new();
1825            props.insert(
1826                std::borrow::Cow::Borrowed("source_node_id"),
1827                json!(doc_id.to_string()),
1828            );
1829            props.insert(
1830                std::borrow::Cow::Borrowed("target_node_id"),
1831                json!(table_node_id.to_string()),
1832            );
1833            props.insert(
1834                std::borrow::Cow::Borrowed("relationship_name"),
1835                json!("is_row_of"),
1836            );
1837            all_edges.push((
1838                doc_id.to_string(),
1839                table_node_id.to_string(),
1840                "is_row_of".to_string(),
1841                props,
1842            ));
1843        }
1844
1845        // Create FK row-level edges
1846        let fk_references = ext_metadata
1847            .get("fk_references")
1848            .and_then(|v| v.as_array())
1849            .cloned()
1850            .unwrap_or_default();
1851
1852        for fk_ref in &fk_references {
1853            let target_data_id = match fk_ref.get("target_data_id").and_then(|v| v.as_str()) {
1854                Some(id) => id.to_string(),
1855                None => continue,
1856            };
1857
1858            let relationship_name = fk_ref
1859                .get("relationship_name")
1860                .and_then(|v| v.as_str())
1861                .unwrap_or("references")
1862                .to_string();
1863
1864            let edge_key = (
1865                doc_id.to_string(),
1866                target_data_id.clone(),
1867                relationship_name.clone(),
1868            );
1869            if seen_row_edges.contains(&edge_key) {
1870                continue;
1871            }
1872            seen_row_edges.insert(edge_key);
1873
1874            let mut props = HashMap::new();
1875            props.insert(
1876                std::borrow::Cow::Borrowed("source_node_id"),
1877                json!(doc_id.to_string()),
1878            );
1879            props.insert(
1880                std::borrow::Cow::Borrowed("target_node_id"),
1881                json!(target_data_id.clone()),
1882            );
1883            props.insert(
1884                std::borrow::Cow::Borrowed("relationship_name"),
1885                json!(relationship_name.clone()),
1886            );
1887            props.insert(
1888                std::borrow::Cow::Borrowed("edge_text"),
1889                json!(relationship_name.replace('_', " ")),
1890            );
1891            props.insert(
1892                std::borrow::Cow::Borrowed("source_table"),
1893                json!(table_name),
1894            );
1895            props.insert(
1896                std::borrow::Cow::Borrowed("target_table"),
1897                json!(
1898                    fk_ref
1899                        .get("target_table")
1900                        .and_then(|v| v.as_str())
1901                        .unwrap_or("")
1902                ),
1903            );
1904            props.insert(
1905                std::borrow::Cow::Borrowed("fk_column"),
1906                json!(fk_ref.get("column").and_then(|v| v.as_str()).unwrap_or("")),
1907            );
1908
1909            all_edges.push((doc_id.to_string(), target_data_id, relationship_name, props));
1910        }
1911    }
1912
1913    // Persist schema nodes to graph DB (SchemaTable + SchemaRelationship)
1914    // NOTE: Python also calls `index_data_points(schema_nodes)` to embed these
1915    // into vector DB. That is out of scope for Phase 0; Rust's `add_data_points`
1916    // task handles vector indexing for the main pipeline data.
1917    if !schema_nodes.is_empty() {
1918        let node_count = schema_nodes.len();
1919        graph_db
1920            .add_nodes_raw(schema_nodes)
1921            .await
1922            .map_err(CognifyError::from)?;
1923        info!("Added {} DLT schema nodes to graph", node_count);
1924    }
1925
1926    // Persist edges to graph DB
1927    if !all_edges.is_empty() {
1928        graph_db
1929            .add_edges(&all_edges)
1930            .await
1931            .map_err(CognifyError::from)?;
1932        info!(
1933            "Added {} DLT FK edges to graph ({} tables, {} FK definitions)",
1934            all_edges.len(),
1935            table_node_ids.len(),
1936            fk_defs_seen.len()
1937        );
1938    }
1939
1940    Ok(())
1941}
1942
/// Graph node representing a DLT-ingested relational table.
///
/// Mirrors Python's `SchemaTable` DataPoint model from
/// `cognee/tasks/schema/models.py`.
///
/// Serialized with `serde_json::to_value` and stored through
/// `add_nodes_raw` by `extract_dlt_fk_edges`; the field notes below describe
/// how that function fills each field.
#[derive(Debug, Serialize)]
struct SchemaTableNode {
    /// String form of the node UUID (uuid5 over `dlt:{table_name}`).
    id: String,
    /// Table name taken from the document metadata's `table_name`.
    name: String,
    /// JSON text of the metadata's `schema_info`, or `"[]"` when absent.
    columns: String,
    /// Set to `None` by `extract_dlt_fk_edges`.
    primary_key: Option<String>,
    /// JSON text of the table's foreign-key definition list.
    foreign_keys: String,
    /// Set to the literal `"[]"` by `extract_dlt_fk_edges`.
    sample_rows: String,
    /// Set to `None` by `extract_dlt_fk_edges`.
    row_count_estimate: Option<i64>,
    /// Human-readable sentence naming the table and its DLT database.
    description: String,
    /// Node type tag; `"SchemaTable"` for nodes built in this file.
    data_type: String,
}
1959
/// Graph node representing a foreign-key relationship between two tables.
///
/// Mirrors Python's `SchemaRelationship` DataPoint model from
/// `cognee/tasks/schema/models.py`.
///
/// Serialized with `serde_json::to_value` and stored through
/// `add_nodes_raw` by `extract_dlt_fk_edges`; the field notes below describe
/// how that function fills each field.
#[derive(Debug, Serialize)]
struct SchemaRelationshipNode {
    /// String form of the node UUID (uuid5 over `dlt:{name}`).
    id: String,
    /// `{source_table}:{source_column}->{target_table}:{target_column}`.
    name: String,
    /// Table that declares the foreign key.
    source_table: String,
    /// Table the foreign key points at (`ref_table` in the metadata).
    target_table: String,
    /// Relationship kind; `"foreign_key"` for nodes built in this file.
    relationship_type: String,
    /// Foreign-key column on the source table (`column` in the metadata).
    source_column: String,
    /// Referenced column (`ref_column` in the metadata); may be empty.
    target_column: String,
    /// Human-readable `Foreign key: a.x -> b.y` sentence.
    description: String,
    /// Node type tag; `"SchemaRelationship"` for nodes built in this file.
    data_type: String,
}
1976
/// Internal metadata for a DLT source table.
///
/// Built by `extract_dlt_fk_edges` from the `external_metadata` JSON of the
/// first document seen for each table.
#[derive(Debug)]
struct DltTableMeta {
    /// Raw `schema_info` value from the metadata, if present.
    schema_info: Option<serde_json::Value>,
    /// Entries of the metadata's `foreign_keys` array; empty when the key is
    /// missing or not an array.
    foreign_keys: Vec<serde_json::Value>,
    /// `dlt_db_name` from the metadata; empty when missing or not a string.
    dlt_db_name: String,
}
1984
1985// ---------------------------------------------------------------------------
1986// Provenance stamping helper
1987// ---------------------------------------------------------------------------
1988
1989/// Stamp pipeline provenance fields on a [`DataPoint`].
1990///
1991/// Used by the **convenience [`cognify`] entry point** which bypasses
1992/// `cognee_core::execute()` and therefore does not benefit from the
1993/// executor-driven walk in
1994/// [`cognee_core::provenance::stamp_tree`]. Per locked decision 6 of
1995/// `docs/telemetry/05-datapoint-provenance.md`, both code paths land
1996/// stamping; the `if dp.source_X.is_none()` guards make double-stamping
1997/// a no-op.
1998///
1999/// Pipeline-driven cognify uses the executor walk via
2000/// [`cognee_core::provenance::stamp_tree_dyn`] — see
2001/// `crates/core/src/provenance.rs`.
2002///
2003/// Only sets each field if it is currently `None`, so earlier (more specific)
2004/// stamps are never overwritten.  Mirrors the Python
2005/// `run_tasks_base.py` post-task provenance stamping.
2006fn stamp_provenance(dp: &mut DataPoint, pipeline: &str, task: &str, user: Option<&str>) {
2007    if dp.source_pipeline.is_none() {
2008        dp.source_pipeline = Some(pipeline.to_string());
2009    }
2010    if dp.source_task.is_none() {
2011        dp.source_task = Some(task.to_string());
2012    }
2013    if dp.source_user.is_none() {
2014        dp.source_user = user.map(String::from);
2015    }
2016}
2017
2018// ---------------------------------------------------------------------------
2019// Convenience function: sequential execution of all tasks
2020// ---------------------------------------------------------------------------
2021
2022/// Run the complete cognify pipeline on a set of Data items.
2023///
2024/// Executes each task sequentially: classify → chunk → extract graph →
2025/// summarize → add data points (embed + index).
2026///
2027/// For composable pipeline-based execution (with concurrency, retry, progress
2028/// tracking), use [`build_cognify_pipeline`] + [`cognee_core::execute`].
2029#[allow(clippy::too_many_arguments)]
2030pub async fn cognify(
2031    data_items: Vec<Data>,
2032    dataset_id: Uuid,
2033    user_id: Option<Uuid>,
2034    user_email: Option<String>,
2035    tenant_id: Option<Uuid>,
2036    llm: Arc<dyn Llm>,
2037    storage: Arc<dyn StorageTrait>,
2038    graph_db: Arc<dyn GraphDBTrait>,
2039    vector_db: Arc<dyn VectorDB>,
2040    embedding_engine: Arc<dyn EmbeddingEngine>,
2041    database: Arc<DatabaseConnection>,
2042    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
2043    thread_pool: Arc<dyn CpuPool>,
2044    ontology_resolver: Arc<dyn OntologyResolver>,
2045    config: &CognifyConfig,
2046) -> Result<CognifyResult, CognifyError> {
2047    config
2048        .validate()
2049        .map_err(|e| CognifyError::ConfigError(e.to_string()))?;
2050
2051    // Auto-calculate chunk size when the caller is using the default value.
2052    // Matches Python's `get_max_chunk_tokens()` from
2053    // `cognee/infrastructure/llm/utils.py`. Locked Decision 6: this mutation
2054    // happens **before** `pipeline::execute` so the executor sees a frozen
2055    // config in `build_cognify_pipeline`.
2056    let effective_config = if config.max_chunk_size == CognifyConfig::default().max_chunk_size {
2057        let cfg = config
2058            .clone()
2059            .with_auto_chunk_size(embedding_engine.as_ref(), llm.as_ref());
2060        info!("Auto-calculated max_chunk_size: {}", cfg.max_chunk_size);
2061        cfg
2062    } else {
2063        config.clone()
2064    };
2065
2066    info!(
2067        "Starting cognify pipeline with config: chunks_per_batch={}, max_chunk_size={}",
2068        effective_config.chunks_per_batch, effective_config.max_chunk_size
2069    );
2070
2071    // ── Qualification gate (gap 08-08, locked decision 3) ───────────────────
2072    // Python-parity `check_pipeline_run_qualification`: read the latest
2073    // `pipeline_runs` row for `(dataset_id, pipeline_name)` and decide
2074    // whether to proceed, short-circuit, or reject. The pipeline name used
2075    // here MUST match what `DbPipelineWatcher` will persist on the next
2076    // `pipeline::execute` call — that is the `Pipeline.name` set below
2077    // (`"cognify"` or `"temporal-cognify"`).
2078    let pipeline_name: &str = if effective_config.temporal_cognify {
2079        "temporal-cognify"
2080    } else {
2081        "cognify"
2082    };
2083    match check_pipeline_run_qualification(pipeline_run_repo.as_ref(), dataset_id, pipeline_name)
2084        .await
2085        .map_err(|e| CognifyError::DatabaseError(e.to_string()))?
2086    {
2087        Qualification::AlreadyCompleted(prior) => {
2088            info!(
2089                dataset_id = %dataset_id,
2090                pipeline_run_id = %prior.pipeline_run_id,
2091                "cognify: dataset already completed; short-circuiting (Python parity)"
2092            );
2093            return Ok(CognifyResult::already_completed(prior.pipeline_run_id));
2094        }
2095        Qualification::AlreadyRunning(_prior) => {
2096            return Err(CognifyError::PipelineAlreadyRunning {
2097                pipeline_name: pipeline_name.to_string(),
2098                dataset_id,
2099            });
2100        }
2101        Qualification::Proceed => {}
2102    }
2103
2104    // ── Empty-document short-circuit ────────────────────────────────────────
2105    // Preserved from the pre-executor path: a caller passing zero documents
2106    // gets back an empty result without paying for pipeline / context
2107    // construction or a no-op LLM round-trip.
2108    if data_items.is_empty() {
2109        return Ok(CognifyResult::empty());
2110    }
2111
2112    // ── Branch: temporal vs. standard pipeline ──────────────────────────────
2113    // LIB-06-04: both branches now route through `pipeline::execute`. The
2114    // selection happens *before* `execute()` per locked Decision 2 — temporal
2115    // is a distinct `Pipeline` with its own task DAG. Per locked option (a)
2116    // (user decision 2026-05-15), the shared tasks
2117    // (`make_classify_documents_task`, `make_extract_chunks_task`) stamp
2118    // `Document` / `DocumentChunk` DataPoints with
2119    // `source_pipeline = "cognify"` (the LIB-06-03 constant) on both
2120    // branches; the temporal pipeline keeps its distinct identity at the
2121    // `pipeline_runs` row level via `build_temporal_cognify_pipeline`'s
2122    // `with_name("temporal-cognify")`.
2123    let is_temporal = effective_config.temporal_cognify;
2124    let pipeline = if is_temporal {
2125        build_temporal_cognify_pipeline(
2126            Arc::clone(&storage),
2127            Arc::clone(&graph_db),
2128            Arc::clone(&vector_db),
2129            Arc::clone(&embedding_engine),
2130            Arc::clone(&llm),
2131            Some(Arc::clone(&database)),
2132            effective_config.clone(),
2133        )
2134    } else {
2135        build_cognify_pipeline(
2136            Arc::clone(&storage),
2137            Arc::clone(&graph_db),
2138            Arc::clone(&vector_db),
2139            Arc::clone(&embedding_engine),
2140            Arc::clone(&llm),
2141            Some(Arc::clone(&database)),
2142            Arc::clone(&ontology_resolver),
2143            effective_config.clone(),
2144        )
2145    };
2146
2147    // The executor re-derives `PipelineRunInfo.pipeline_id` from
2148    // `(pipeline.name, user_id, dataset_id)`; we still carry `pipeline.id`
2149    // through `PipelineContext` as the placeholder.
2150    let pipeline_ctx = PipelineContext {
2151        pipeline_id: pipeline.id,
2152        pipeline_name: pipeline.name.clone().unwrap_or_default(),
2153        user_id,
2154        tenant_id,
2155        dataset_id: Some(dataset_id),
2156        current_data: None,
2157        run_id: None,
2158        user_email: user_email.clone(),
2159        provenance_visited: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
2160    };
2161
2162    let (_cancel_handle, ctx) = TaskContextBuilder::new()
2163        .thread_pool(thread_pool)
2164        .database(Arc::clone(&database))
2165        .graph_db(Arc::clone(&graph_db))
2166        .vector_db(Arc::clone(&vector_db))
2167        .pipeline_context(pipeline_ctx)
2168        .build()
2169        .map_err(|e| CognifyError::ContextBuild(e.to_string()))?;
2170    let ctx = Arc::new(ctx);
2171
2172    let input = CognifyInput {
2173        data_items,
2174        dataset_id,
2175        user_id,
2176        tenant_id,
2177    };
2178    let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(input) as Arc<dyn Value>];
2179
2180    // Decision 11 (gap 08-07): `DbPipelineWatcher` persists the four-state
2181    // `pipeline_runs` trail through the caller-supplied repository.
2182    // Embedded callers pass `NoopPipelineRunRepository`; CLI / HTTP callers
2183    // pass a `SeaOrmPipelineRunRepository` to surface rows in the
2184    // `/api/v1/activity/pipeline-runs` endpoint.
2185    let watcher = DbPipelineWatcher::new(pipeline_run_repo);
2186    let outputs = cognee_core::pipeline::execute(&pipeline, inputs, ctx, &watcher)
2187        .await
2188        .map_err(|e| CognifyError::Execute(e.to_string()))?;
2189
2190    let result = extract_cognify_outputs(outputs)?;
2191
2192    // Decision 5: post-pipeline teardown — `extract_dlt_fk_edges` stays
2193    // outside the executor. The pipeline_runs row is already marked
2194    // COMPLETED by the watcher at this point; teardown failure surfaces as
2195    // `Err(...)` to the caller but does not roll back the run state.
2196    //
2197    // LIB-06-04: skip DLT FK extraction on the temporal branch — temporal
2198    // does not propagate `documents_for_dlt` (and Python's temporal cognify
2199    // does not run DLT teardown either).
2200    if !is_temporal {
2201        extract_dlt_fk_edges(
2202            &result.chunks,
2203            &result.documents_for_dlt,
2204            Arc::clone(&graph_db),
2205        )
2206        .await?;
2207    }
2208
2209    Ok(result)
2210}
2211
2212// ---------------------------------------------------------------------------
2213// Output extraction (Decision 9)
2214// ---------------------------------------------------------------------------
2215
2216/// Downcast the executor's [`Arc<dyn Value>`] outputs back to the concrete
2217/// [`CognifyResult`] the convenience function promises.
2218///
2219/// Returns [`CognifyError::OutputTypeMismatch`] when the downcast fails — a
2220/// programmer error indicating the pipeline's last task does not emit
2221/// `CognifyResult`. Mirrors `cognee_ingestion::pipeline::extract_data_outputs`
2222/// (LIB-06-01) and `cognee_cognify::memify::extract_memify_outputs` (LIB-06-02).
2223fn extract_cognify_outputs(outputs: Vec<Arc<dyn Value>>) -> Result<CognifyResult, CognifyError> {
2224    let first = outputs
2225        .into_iter()
2226        .next()
2227        .ok_or(CognifyError::OutputTypeMismatch {
2228            expected: "CognifyResult",
2229            actual: "empty",
2230        })?;
2231    // Explicit deref through `Arc` to reach the inner `dyn Value`, then call
2232    // `as_any` via vtable dispatch — without this, method resolution would
2233    // pick the blanket `<Arc<dyn Value> as Value>::as_any()` which downcasts
2234    // to `Arc<dyn Value>` and never to `CognifyResult`.
2235    (*first)
2236        .as_any()
2237        .downcast_ref::<CognifyResult>()
2238        .cloned()
2239        .ok_or(CognifyError::OutputTypeMismatch {
2240            expected: "CognifyResult",
2241            actual: "unknown",
2242        })
2243}
2244
2245// ---------------------------------------------------------------------------
2246// Internal helpers
2247// ---------------------------------------------------------------------------
2248
2249// ── Provenance helpers ──────────────────────────────────────────────────────
2250
2251/// Deterministic provenance node ID, matching Python's:
2252/// `uuid5(NAMESPACE_OID, str(tenant_id) + str(user_id) + str(dataset_id) + str(data_id) + str(node_id))`
2253///
2254/// When `tenant_id` is `None`, Python's `str(None)` produces `"None"`.
2255fn provenance_node_id(
2256    tenant_id: Option<Uuid>,
2257    user_id: Uuid,
2258    dataset_id: Uuid,
2259    data_id: Uuid,
2260    node_id: Uuid,
2261) -> Uuid {
2262    let tid = tenant_id.map_or("None".to_string(), |t| t.to_string());
2263    let raw = format!("{tid}{user_id}{dataset_id}{data_id}{node_id}");
2264    Uuid::new_v5(&Uuid::NAMESPACE_OID, raw.as_bytes())
2265}
2266
2267/// Deterministic provenance edge ID, matching Python's:
2268/// `uuid5(NAMESPACE_OID, str(tenant_id) + str(user_id) + str(dataset_id) + str(source_id) + str(edge_text) + str(target_id))`
2269fn provenance_edge_id(
2270    tenant_id: Option<Uuid>,
2271    user_id: Uuid,
2272    dataset_id: Uuid,
2273    source_id: Uuid,
2274    edge_text: &str,
2275    target_id: Uuid,
2276) -> Uuid {
2277    let tid = tenant_id.map_or("None".to_string(), |t| t.to_string());
2278    let raw = format!("{tid}{user_id}{dataset_id}{source_id}{edge_text}{target_id}");
2279    Uuid::new_v5(&Uuid::NAMESPACE_OID, raw.as_bytes())
2280}
2281
2282/// Deterministic edge slug, matching Python's `generate_edge_id`:
2283/// `uuid5(NAMESPACE_OID, edge_text.lower().replace(" ", "_").replace("'", ""))`
2284fn edge_slug(edge_text: &str) -> Uuid {
2285    let normalized = edge_text.to_lowercase().replace(' ', "_").replace('\'', "");
2286    Uuid::new_v5(&Uuid::NAMESPACE_OID, normalized.as_bytes())
2287}
2288
2289/// Deterministic triplet slug, matching `Triplet::new`.
2290fn triplet_slug(source_id: Uuid, relationship_name: &str, target_id: Uuid) -> Uuid {
2291    let raw = format!("{source_id}{relationship_name}{target_id}");
2292    let normalized = raw.to_lowercase().replace(' ', "_").replace('\'', "");
2293    Uuid::new_v5(&Uuid::NAMESPACE_OID, normalized.as_bytes())
2294}
2295
2296/// Write provenance node and edge records to the relational database.
2297///
2298/// Mirrors the Python `upsert_nodes()` / `upsert_edges()` calls in
2299/// `add_data_points` (guarded by `if user and dataset and data:`).
2300///
2301/// Provenance records link graph nodes/edges back to the user, tenant,
2302/// dataset, and data item they originated from.
2303#[allow(clippy::too_many_arguments)]
2304async fn upsert_provenance(
2305    db: &DatabaseConnection,
2306    tenant_id: Option<Uuid>,
2307    user_id: Uuid,
2308    dataset_id: Uuid,
2309    chunks: &[DocumentChunk],
2310    entities: &[GraphNodePair],
2311    edges: &[GraphEdgePair],
2312    summaries: &[TextSummary],
2313    documents: &[Document],
2314    structural_edges: &[EdgeData],
2315) -> Result<(), CognifyError> {
2316    use cognee_database::ops::graph_storage;
2317    use cognee_database::{GraphEdge, GraphNode};
2318
2319    // Build chunk_id → document_id map for tracing entity provenance back
2320    // to the originating Data item.
2321    let chunk_data_map: HashMap<Uuid, Uuid> =
2322        chunks.iter().map(|c| (c.base.id, c.document_id)).collect();
2323    let entity_data_map: HashMap<Uuid, Uuid> = entities
2324        .iter()
2325        .filter_map(|pair| {
2326            pair.entity
2327                .base
2328                .get_metadata("chunk_id")
2329                .and_then(|v| v.as_str())
2330                .and_then(|s| Uuid::parse_str(s).ok())
2331                .and_then(|chunk_id| chunk_data_map.get(&chunk_id).copied())
2332                .map(|data_id| (pair.entity.base.id, data_id))
2333        })
2334        .collect();
2335
2336    // ── Provenance nodes ────────────────────────────────────────────────
2337    let mut prov_nodes: Vec<GraphNode> = Vec::new();
2338
2339    // Entities
2340    for pair in entities {
2341        let entity = &pair.entity;
2342
2343        // Resolve data_id by tracing entity → chunk_id → document_id
2344        let data_id = entity
2345            .base
2346            .get_metadata("chunk_id")
2347            .and_then(|v| v.as_str())
2348            .and_then(|s| Uuid::parse_str(s).ok())
2349            .and_then(|chunk_id| chunk_data_map.get(&chunk_id).copied())
2350            .unwrap_or(Uuid::nil());
2351
2352        let indexed_fields = entity
2353            .base
2354            .get_metadata("index_fields")
2355            .cloned()
2356            .unwrap_or(json!(["name"]));
2357
2358        let label = if entity.name.is_empty() {
2359            entity.base.id.to_string()
2360        } else {
2361            entity.name.clone()
2362        };
2363
2364        prov_nodes.push(GraphNode {
2365            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, entity.base.id),
2366            slug: entity.base.id,
2367            user_id,
2368            data_id,
2369            dataset_id,
2370            label: Some(label),
2371            node_type: entity.base.data_type.clone(),
2372            indexed_fields,
2373            attributes: serde_json::to_value(entity).ok(),
2374            created_at: Utc::now(),
2375        });
2376    }
2377
2378    // DocumentChunks
2379    for chunk in chunks {
2380        let data_id = chunk.document_id;
2381
2382        let indexed_fields = chunk
2383            .base
2384            .get_metadata("index_fields")
2385            .cloned()
2386            .unwrap_or(json!(["text"]));
2387
2388        prov_nodes.push(GraphNode {
2389            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, chunk.base.id),
2390            slug: chunk.base.id,
2391            user_id,
2392            data_id,
2393            dataset_id,
2394            label: Some(format!("chunk_{}", chunk.chunk_index)),
2395            node_type: chunk.base.data_type.clone(),
2396            indexed_fields,
2397            attributes: serde_json::to_value(chunk).ok(),
2398            created_at: Utc::now(),
2399        });
2400    }
2401
2402    // TextSummaries
2403    for summary in summaries {
2404        let data_id = summary
2405            .made_from
2406            .and_then(|chunk_id| chunk_data_map.get(&chunk_id).copied())
2407            .unwrap_or(Uuid::nil());
2408
2409        let indexed_fields = summary
2410            .base
2411            .get_metadata("index_fields")
2412            .cloned()
2413            .unwrap_or(json!(["text"]));
2414
2415        prov_nodes.push(GraphNode {
2416            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, summary.base.id),
2417            slug: summary.base.id,
2418            user_id,
2419            data_id,
2420            dataset_id,
2421            label: Some(format!("summary_{}", summary.base.id)),
2422            node_type: summary.base.data_type.clone(),
2423            indexed_fields,
2424            attributes: serde_json::to_value(summary).ok(),
2425            created_at: Utc::now(),
2426        });
2427    }
2428
2429    // EntityTypes
2430    for pair in entities {
2431        let et = &pair.entity_type;
2432        // EntityType is shared across entities; use nil data_id as in Python
2433        prov_nodes.push(GraphNode {
2434            id: provenance_node_id(tenant_id, user_id, dataset_id, Uuid::nil(), et.base.id),
2435            slug: et.base.id,
2436            user_id,
2437            data_id: Uuid::nil(),
2438            dataset_id,
2439            label: Some(et.name.clone()),
2440            node_type: et.base.data_type.clone(),
2441            indexed_fields: et
2442                .base
2443                .get_metadata("index_fields")
2444                .cloned()
2445                .unwrap_or(json!(["name"])),
2446            attributes: serde_json::to_value(et).ok(),
2447            created_at: Utc::now(),
2448        });
2449    }
2450
2451    // Documents. Python reaches the Document node by recursively walking each
2452    // DocumentChunk's `is_part_of` (a full Document DataPoint), so the Document
2453    // lands in `nodes` and `upsert_nodes(nodes, …)` writes its provenance row
2454    // keyed with the ctx `data_item.id`. Rust stores Documents explicitly (see
2455    // `add_data_points`), so we must register their provenance here too —
2456    // otherwise the Document graph node (slug == its id == the source Data
2457    // item's id) is never matched by the delete cleanup and leaks on hard
2458    // delete. The Document's id IS the Data item's id, so `data_id` = its id.
2459    for document in documents {
2460        let data_id = document.base.id;
2461
2462        let indexed_fields = document
2463            .base
2464            .get_metadata("index_fields")
2465            .cloned()
2466            .unwrap_or(json!(["name"]));
2467
2468        let label = if document.name.is_empty() {
2469            document.base.id.to_string()
2470        } else {
2471            document.name.clone()
2472        };
2473
2474        prov_nodes.push(GraphNode {
2475            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, document.base.id),
2476            slug: document.base.id,
2477            user_id,
2478            data_id,
2479            dataset_id,
2480            label: Some(label),
2481            node_type: document.base.data_type.clone(),
2482            indexed_fields,
2483            attributes: serde_json::to_value(document).ok(),
2484            created_at: Utc::now(),
2485        });
2486    }
2487
2488    if !prov_nodes.is_empty() {
2489        graph_storage::upsert_nodes(db, &prov_nodes).await?;
2490        info!("Upserted {} provenance node records", prov_nodes.len());
2491    }
2492
2493    // ── Provenance edges ────────────────────────────────────────────────
2494    let mut prov_edges: Vec<GraphEdge> = Vec::new();
2495
2496    // Semantic edges from graph extraction
2497    for edge_pair in edges {
2498        let edge_text = if edge_pair.relationship_name == "contains" {
2499            edge_pair
2500                .properties
2501                .get("edge_text")
2502                .cloned()
2503                .unwrap_or_else(|| edge_pair.relationship_name.clone())
2504        } else {
2505            edge_pair.relationship_name.clone()
2506        };
2507
2508        let source_data_id = entity_data_map.get(&edge_pair.source_entity_id).copied();
2509        let target_data_id = entity_data_map.get(&edge_pair.target_entity_id).copied();
2510        let data_id = match (source_data_id, target_data_id) {
2511            (Some(source), Some(target)) if source == target => source,
2512            _ => Uuid::nil(),
2513        };
2514
2515        prov_edges.push(GraphEdge {
2516            id: provenance_edge_id(
2517                tenant_id,
2518                user_id,
2519                dataset_id,
2520                edge_pair.source_entity_id,
2521                &edge_text,
2522                edge_pair.target_entity_id,
2523            ),
2524            slug: triplet_slug(
2525                edge_pair.source_entity_id,
2526                &edge_text,
2527                edge_pair.target_entity_id,
2528            ),
2529            user_id,
2530            data_id,
2531            dataset_id,
2532            source_node_id: edge_pair.source_entity_id,
2533            destination_node_id: edge_pair.target_entity_id,
2534            relationship_name: edge_text,
2535            label: Some(edge_pair.relationship_name.clone()),
2536            attributes: serde_json::to_value(&edge_pair.properties).ok(),
2537            created_at: Utc::now(),
2538        });
2539    }
2540
2541    // Structural edges from get_graph_from_model (contains, is_a, made_from, etc.)
2542    // Python writes these to SQLite via upsert_edges() — Rust must match.
2543    for (source_id_str, target_id_str, rel_name, properties) in structural_edges {
2544        let source_id = Uuid::parse_str(source_id_str).unwrap_or(Uuid::nil());
2545        let target_id = Uuid::parse_str(target_id_str).unwrap_or(Uuid::nil());
2546
2547        let attrs = if properties.is_empty() {
2548            None
2549        } else {
2550            let map: serde_json::Map<String, serde_json::Value> = properties
2551                .iter()
2552                .map(|(k, v)| (k.to_string(), v.clone()))
2553                .collect();
2554            Some(serde_json::Value::Object(map))
2555        };
2556
2557        prov_edges.push(GraphEdge {
2558            id: provenance_edge_id(
2559                tenant_id, user_id, dataset_id, source_id, rel_name, target_id,
2560            ),
2561            slug: edge_slug(rel_name),
2562            user_id,
2563            data_id: Uuid::nil(), // structural edges span multiple DataPoints
2564            dataset_id,
2565            source_node_id: source_id,
2566            destination_node_id: target_id,
2567            relationship_name: rel_name.clone(),
2568            label: None,
2569            attributes: attrs,
2570            created_at: Utc::now(),
2571        });
2572    }
2573
2574    if !prov_edges.is_empty() {
2575        graph_storage::upsert_edges(db, &prov_edges).await?;
2576        info!("Upserted {} provenance edge records", prov_edges.len());
2577    }
2578
2579    Ok(())
2580}
2581
2582/// Generate embeddings for chunks, entities, and summaries.
2583async fn generate_embeddings(
2584    chunks: &[DocumentChunk],
2585    entities: &[GraphNodePair],
2586    summaries: &[TextSummary],
2587    engine: Arc<dyn EmbeddingEngine>,
2588) -> Result<Vec<Embedding>, CognifyError> {
2589    let mut embeddings = Vec::new();
2590
2591    if !chunks.is_empty() {
2592        let chunk_texts: Vec<_> = chunks.iter().map(|c| c.text.as_str()).collect();
2593        let chunk_vectors = engine
2594            .embed(&chunk_texts)
2595            .await
2596            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2597
2598        for (chunk, vector) in chunks.iter().zip(chunk_vectors) {
2599            embeddings.push(Embedding::new(
2600                chunk.base.id,
2601                "DocumentChunk",
2602                "text",
2603                vector,
2604            ));
2605        }
2606    }
2607
2608    if !entities.is_empty() {
2609        let entity_names: Vec<_> = entities.iter().map(|e| e.entity.name.as_str()).collect();
2610        let entity_vectors = engine
2611            .embed(&entity_names)
2612            .await
2613            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2614
2615        for (entity, vector) in entities.iter().zip(entity_vectors) {
2616            embeddings.push(Embedding::new(
2617                entity.entity.base.id,
2618                "Entity",
2619                "name",
2620                vector,
2621            ));
2622        }
2623    }
2624
2625    if !summaries.is_empty() {
2626        let summary_texts: Vec<_> = summaries.iter().map(|s| s.text.as_str()).collect();
2627        let summary_vectors = engine
2628            .embed(&summary_texts)
2629            .await
2630            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2631
2632        for (summary, vector) in summaries.iter().zip(summary_vectors) {
2633            embeddings.push(Embedding::new(
2634                summary.base.id,
2635                "TextSummary",
2636                "text",
2637                vector,
2638            ));
2639        }
2640    }
2641
2642    Ok(embeddings)
2643}
2644
2645/// Index data points in vector database.
2646#[allow(clippy::too_many_arguments)]
2647async fn index_data_points(
2648    chunks: &[DocumentChunk],
2649    entities: &[GraphNodePair],
2650    summaries: &[TextSummary],
2651    documents: &[Document],
2652    edges: &[GraphEdgePair],
2653    edge_types: &[EdgeType],
2654    dataset_id: Uuid,
2655    user_id: Option<Uuid>,
2656    tenant_id: Option<Uuid>,
2657    engine: Arc<dyn EmbeddingEngine>,
2658    vector_db: Arc<dyn VectorDB>,
2659    config: &CognifyConfig,
2660) -> Result<IndexedFieldsStats, CognifyError> {
2661    let mut stats = IndexedFieldsStats::default();
2662    let dimension = engine.dimension();
2663
2664    // 1. Index DocumentChunk.text field
2665    if !chunks.is_empty() {
2666        if !vector_db
2667            .has_collection("DocumentChunk", "text")
2668            .await
2669            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
2670        {
2671            vector_db
2672                .create_collection("DocumentChunk", "text", dimension)
2673                .await
2674                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2675        }
2676
2677        let texts: Vec<_> = chunks.iter().map(|c| c.text.as_str()).collect();
2678        let vectors = engine
2679            .embed(&texts)
2680            .await
2681            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2682
2683        let points: Vec<VectorPoint> = chunks
2684            .iter()
2685            .zip(vectors)
2686            .map(|(chunk, vector)| {
2687                let mut point = VectorPoint::new(chunk.base.id, vector);
2688
2689                // 1. Full DataPoint dump (Python parity — see gap-05/08).
2690                //    Provides `type`, `belongs_to_set`, all source_* keys, etc.
2691                for (k, v) in chunk.base.vector_metadata() {
2692                    point = point.with_metadata(k, v);
2693                }
2694
2695                // 2. Context-specific keys not present on the DataPoint.
2696                point = point
2697                    .with_metadata("field", json!("text"))
2698                    .with_metadata("text", json!(chunk.text.clone()))
2699                    .with_metadata("dataset_id", json!(dataset_id.to_string()))
2700                    .with_metadata("document_id", json!(chunk.document_id.to_string()))
2701                    .with_metadata("chunk_index", json!(chunk.chunk_index));
2702                if let Some(uid) = user_id {
2703                    point = point.with_metadata("user_id", json!(uid.to_string()));
2704                }
2705                if let Some(tid) = tenant_id {
2706                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
2707                }
2708                point
2709            })
2710            .collect();
2711
2712        vector_db
2713            .index_points("DocumentChunk", "text", &points)
2714            .await
2715            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2716
2717        stats.record("DocumentChunk", "text", chunks.len());
2718        info!("Indexed {} document chunks", chunks.len());
2719    }
2720
2721    // 2a. Index Entity.name field
2722    if !entities.is_empty() {
2723        if !vector_db
2724            .has_collection("Entity", "name")
2725            .await
2726            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
2727        {
2728            vector_db
2729                .create_collection("Entity", "name", dimension)
2730                .await
2731                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2732        }
2733
2734        let names: Vec<_> = entities.iter().map(|e| e.entity.name.as_str()).collect();
2735        let vectors = engine
2736            .embed(&names)
2737            .await
2738            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2739
2740        let points: Vec<VectorPoint> = entities
2741            .iter()
2742            .zip(vectors)
2743            .map(|(entity, vector)| {
2744                let mut point = VectorPoint::new(entity.entity.base.id, vector);
2745
2746                // 1. Full DataPoint dump (Python parity — see gap-05/08).
2747                for (k, v) in entity.entity.base.vector_metadata() {
2748                    point = point.with_metadata(k, v);
2749                }
2750
2751                // 2. Context-specific keys not present on the DataPoint.
2752                point = point
2753                    .with_metadata("field", json!("name"))
2754                    .with_metadata("dataset_id", json!(dataset_id.to_string()))
2755                    .with_metadata("entity_type", json!(entity.entity_type.name.clone()));
2756                if let Some(uid) = user_id {
2757                    point = point.with_metadata("user_id", json!(uid.to_string()));
2758                }
2759                if let Some(tid) = tenant_id {
2760                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
2761                }
2762                point
2763            })
2764            .collect();
2765
2766        vector_db
2767            .index_points("Entity", "name", &points)
2768            .await
2769            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2770
2771        stats.record("Entity", "name", entities.len());
2772        info!("Indexed {} entity names", entities.len());
2773    }
2774
2775    // 2b. Index EntityType.name field (deduplicated by EntityType ID)
2776    {
2777        let mut seen_ids = std::collections::HashSet::new();
2778        let unique_entity_types: Vec<&cognee_models::EntityType> = entities
2779            .iter()
2780            .map(|pair| &pair.entity_type)
2781            .filter(|et| seen_ids.insert(et.base.id))
2782            .collect();
2783
2784        if !unique_entity_types.is_empty() {
2785            if !vector_db
2786                .has_collection("EntityType", "name")
2787                .await
2788                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
2789            {
2790                vector_db
2791                    .create_collection("EntityType", "name", dimension)
2792                    .await
2793                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2794            }
2795
2796            let type_names: Vec<_> = unique_entity_types
2797                .iter()
2798                .map(|et| et.name.as_str())
2799                .collect();
2800            let vectors = engine
2801                .embed(&type_names)
2802                .await
2803                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2804
2805            let points: Vec<VectorPoint> = unique_entity_types
2806                .iter()
2807                .zip(vectors)
2808                .map(|(et, vector)| {
2809                    let mut point = VectorPoint::new(et.base.id, vector);
2810
2811                    // 1. Full DataPoint dump (Python parity — see gap-05/08).
2812                    for (k, v) in et.base.vector_metadata() {
2813                        point = point.with_metadata(k, v);
2814                    }
2815
2816                    // 2. Context-specific keys not present on the DataPoint.
2817                    point = point
2818                        .with_metadata("field", json!("name"))
2819                        .with_metadata("dataset_id", json!(dataset_id.to_string()));
2820                    if let Some(uid) = user_id {
2821                        point = point.with_metadata("user_id", json!(uid.to_string()));
2822                    }
2823                    if let Some(tid) = tenant_id {
2824                        point = point.with_metadata("tenant_id", json!(tid.to_string()));
2825                    }
2826                    point
2827                })
2828                .collect();
2829
2830            vector_db
2831                .index_points("EntityType", "name", &points)
2832                .await
2833                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2834
2835            stats.record("EntityType", "name", unique_entity_types.len());
2836            info!("Indexed {} entity type names", unique_entity_types.len());
2837        }
2838    }
2839
2840    // 3. Index TextSummary.text field
2841    if !summaries.is_empty() {
2842        if !vector_db
2843            .has_collection("TextSummary", "text")
2844            .await
2845            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
2846        {
2847            vector_db
2848                .create_collection("TextSummary", "text", dimension)
2849                .await
2850                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2851        }
2852
2853        let texts: Vec<_> = summaries.iter().map(|s| s.text.as_str()).collect();
2854        let vectors = engine
2855            .embed(&texts)
2856            .await
2857            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2858
2859        let points: Vec<VectorPoint> = summaries
2860            .iter()
2861            .zip(vectors)
2862            .map(|(summary, vector)| {
2863                let mut point = VectorPoint::new(summary.base.id, vector);
2864
2865                // 1. Full DataPoint dump (Python parity — see gap-05/08).
2866                for (k, v) in summary.base.vector_metadata() {
2867                    point = point.with_metadata(k, v);
2868                }
2869
2870                // 2. Context-specific keys not present on the DataPoint.
2871                point = point
2872                    .with_metadata("field", json!("text"))
2873                    .with_metadata("text", json!(summary.text.clone()))
2874                    .with_metadata("dataset_id", json!(dataset_id.to_string()));
2875                if let Some(made_from) = summary.made_from {
2876                    point = point.with_metadata("chunk_id", json!(made_from.to_string()));
2877                }
2878                if let Some(uid) = user_id {
2879                    point = point.with_metadata("user_id", json!(uid.to_string()));
2880                }
2881                if let Some(tid) = tenant_id {
2882                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
2883                }
2884                point
2885            })
2886            .collect();
2887
2888        vector_db
2889            .index_points("TextSummary", "text", &points)
2890            .await
2891            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2892
2893        stats.record("TextSummary", "text", summaries.len());
2894        info!("Indexed {} summaries", summaries.len());
2895    }
2896
2897    // 4. Index triplets (if enabled in config)
2898    if config.embed_triplets && !edges.is_empty() && !entities.is_empty() {
2899        use crate::triplet_creation::create_triplets_from_graph;
2900
2901        let triplets = create_triplets_from_graph(entities, edges);
2902
2903        if !triplets.is_empty() {
2904            if !vector_db
2905                .has_collection("Triplet", "text")
2906                .await
2907                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
2908            {
2909                vector_db
2910                    .create_collection("Triplet", "text", dimension)
2911                    .await
2912                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2913            }
2914
2915            let triplet_texts: Vec<_> = triplets.iter().map(|t| t.text.as_str()).collect();
2916            let triplet_vectors = engine
2917                .embed(&triplet_texts)
2918                .await
2919                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
2920
2921            // Index the EdgeType DataPoints so each triplet payload can
2922            // inherit its originating edge's provenance (`source_*`) keys per
2923            // gap-05/08 §4.4. Triplet itself has no embedded `DataPoint`, so we
2924            // narrow the dump to just the five `source_*` keys to avoid
2925            // colliding with Triplet's own flat fields (id, type, etc.).
2926            //
2927            // EdgeTypes are now keyed on each edge's *retrieval text*
2928            // (`edge_retrieval_text`: nonblank `edge_text`, else
2929            // `relationship_name`) to match Python's `generate_edge_id`, but a
2930            // Triplet only carries the bare `relationship_name`. We therefore
2931            // map each triplet's (source, target, relationship) tuple to its
2932            // edge's retrieval text via the source edges, then look up the
2933            // EdgeType by that text — so the provenance copy survives the
2934            // Part-3 keying change even when edges carry a description.
2935            let edge_type_by_text: std::collections::HashMap<&str, &EdgeType> = edge_types
2936                .iter()
2937                .map(|et| (et.relationship_name.as_str(), et))
2938                .collect();
2939            let edge_text_by_triple: std::collections::HashMap<(Uuid, Uuid, &str), String> = edges
2940                .iter()
2941                .map(|e| {
2942                    (
2943                        (
2944                            e.source_entity_id,
2945                            e.target_entity_id,
2946                            e.relationship_name.as_str(),
2947                        ),
2948                        edge_retrieval_text(e),
2949                    )
2950                })
2951                .collect();
2952
2953            let triplet_points: Vec<VectorPoint> = triplets
2954                .iter()
2955                .zip(triplet_vectors)
2956                .map(|(triplet, vector)| {
2957                    let mut point = VectorPoint::new(triplet.id, vector)
2958                        .with_metadata("type", json!("Triplet"))
2959                        .with_metadata("field", json!("text"))
2960                        .with_metadata("source_id", json!(triplet.source_entity_id.to_string()))
2961                        .with_metadata("target_id", json!(triplet.target_entity_id.to_string()))
2962                        .with_metadata("relationship", json!(triplet.relationship_name.clone()));
2963
2964                    // Triplet special case (gap-05/08 §4.4): copy only the
2965                    // five `source_*` keys from the originating EdgeType's
2966                    // DataPoint, so Triplet's own flat fields are not
2967                    // overwritten.
2968                    let edge_type = edge_text_by_triple
2969                        .get(&(
2970                            triplet.source_entity_id,
2971                            triplet.target_entity_id,
2972                            triplet.relationship_name.as_str(),
2973                        ))
2974                        .and_then(|text| edge_type_by_text.get(text.as_str()));
2975                    if let Some(edge_type) = edge_type {
2976                        for (k, v) in edge_type.base.vector_metadata() {
2977                            if matches!(
2978                                k.as_str(),
2979                                "source_pipeline"
2980                                    | "source_task"
2981                                    | "source_user"
2982                                    | "source_node_set"
2983                                    | "source_content_hash"
2984                            ) {
2985                                point = point.with_metadata(k, v);
2986                            }
2987                        }
2988                    }
2989                    point
2990                })
2991                .collect();
2992
2993            vector_db
2994                .index_points("Triplet", "text", &triplet_points)
2995                .await
2996                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
2997
2998            stats.triplet_count = triplets.len();
2999            info!("Indexed {} triplets", triplets.len());
3000        }
3001    } else if config.embed_triplets {
3002        info!("Triplet embedding enabled but no edges/entities to index");
3003    }
3004
3005    // 5. Index EdgeType.relationship_name field
3006    if !edge_types.is_empty() {
3007        if !vector_db
3008            .has_collection("EdgeType", "relationship_name")
3009            .await
3010            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
3011        {
3012            vector_db
3013                .create_collection("EdgeType", "relationship_name", dimension)
3014                .await
3015                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
3016        }
3017
3018        let names: Vec<&str> = edge_types
3019            .iter()
3020            .map(|et| et.relationship_name.as_str())
3021            .collect();
3022        let vectors = engine
3023            .embed(&names)
3024            .await
3025            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
3026
3027        let points: Vec<VectorPoint> = edge_types
3028            .iter()
3029            .zip(vectors)
3030            .map(|(et, vector)| {
3031                let mut point = VectorPoint::new(et.base.id, vector);
3032
3033                // 1. Full DataPoint dump (Python parity — see gap-05/08).
3034                for (k, v) in et.base.vector_metadata() {
3035                    point = point.with_metadata(k, v);
3036                }
3037
3038                // 2. Context-specific keys not present on the DataPoint.
3039                point = point
3040                    .with_metadata("field", json!("relationship_name"))
3041                    .with_metadata("relationship_name", json!(et.relationship_name.clone()))
3042                    .with_metadata("number_of_edges", json!(et.number_of_edges))
3043                    .with_metadata("dataset_id", json!(dataset_id.to_string()));
3044                if let Some(uid) = user_id {
3045                    point = point.with_metadata("user_id", json!(uid.to_string()));
3046                }
3047                if let Some(tid) = tenant_id {
3048                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
3049                }
3050                point
3051            })
3052            .collect();
3053
3054        vector_db
3055            .index_points("EdgeType", "relationship_name", &points)
3056            .await
3057            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
3058
3059        stats.record("EdgeType", "relationship_name", edge_types.len());
3060        info!("Indexed {} edge types", edge_types.len());
3061    }
3062
3063    // 6. Index Documents by name into `{ConcreteType}_name` collections
3064    //    (e.g. TextDocument_name, PdfDocument_name). Python indexes every
3065    //    Document subclass via its `index_fields=["name"]`
3066    //    (index_data_points.py:39-52). We group by the concrete subclass
3067    //    `data_type` so the collection names match Python's class names.
3068    if !documents.is_empty() {
3069        // Preserve a stable iteration order so the embed batches are
3070        // deterministic; group documents by their concrete type name.
3071        let mut by_type: std::collections::BTreeMap<&str, Vec<&Document>> =
3072            std::collections::BTreeMap::new();
3073        for d in documents {
3074            by_type
3075                .entry(d.base.data_type.as_str())
3076                .or_default()
3077                .push(d);
3078        }
3079
3080        for (type_name, docs) in by_type {
3081            if !vector_db
3082                .has_collection(type_name, "name")
3083                .await
3084                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
3085            {
3086                vector_db
3087                    .create_collection(type_name, "name", dimension)
3088                    .await
3089                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
3090            }
3091
3092            let names: Vec<&str> = docs.iter().map(|d| d.name.as_str()).collect();
3093            let vectors = engine
3094                .embed(&names)
3095                .await
3096                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
3097
3098            let points: Vec<VectorPoint> = docs
3099                .iter()
3100                .zip(vectors)
3101                .map(|(doc, vector)| {
3102                    let mut point = VectorPoint::new(doc.base.id, vector);
3103
3104                    // 1. Full DataPoint dump (Python parity — see gap-05/08).
3105                    for (k, v) in doc.base.vector_metadata() {
3106                        point = point.with_metadata(k, v);
3107                    }
3108
3109                    // 2. Context-specific keys not present on the DataPoint.
3110                    point = point
3111                        .with_metadata("field", json!("name"))
3112                        .with_metadata("name", json!(doc.name.clone()))
3113                        .with_metadata("dataset_id", json!(dataset_id.to_string()));
3114                    if let Some(uid) = user_id {
3115                        point = point.with_metadata("user_id", json!(uid.to_string()));
3116                    }
3117                    if let Some(tid) = tenant_id {
3118                        point = point.with_metadata("tenant_id", json!(tid.to_string()));
3119                    }
3120                    point
3121                })
3122                .collect();
3123
3124            vector_db
3125                .index_points(type_name, "name", &points)
3126                .await
3127                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
3128
3129            stats.record(type_name, "name", docs.len());
3130            info!("Indexed {} {}", docs.len(), type_name);
3131        }
3132    }
3133
3134    Ok(stats)
3135}
3136
3137// ---------------------------------------------------------------------------
3138// TypedTask factories
3139// ---------------------------------------------------------------------------
3140
/// Name used by the executor's `stamp_tree_dyn` for the `classify_documents` task.
///
/// Kept as a `const` so the inline `stamp_provenance` literals removed in LIB-06-03
/// stay byte-stable with the executor's automatic stamp. Matches the historical
/// inline literal `"classify_documents"` at the convenience function call site.
pub const CLASSIFY_DOCUMENTS_TASK_NAME: &str = "classify_documents";
/// Name of the chunk-extraction task: the label the pipeline builders pass to
/// `add_task_named`, and the `source_task` value stamped by
/// [`make_extract_chunks_task`].
pub const EXTRACT_CHUNKS_TASK_NAME: &str = "extract_chunks_from_documents";
/// Name of the graph-extraction task: the label [`build_cognify_pipeline`]
/// passes to `add_task_named`, and the `source_task` value stamped by
/// [`make_extract_graph_task`].
pub const EXTRACT_GRAPH_TASK_NAME: &str = "extract_graph_from_data";
/// Name of the summarization task: the label [`build_cognify_pipeline`]
/// passes to `add_task_named`, and the `source_task` value stamped by
/// [`make_summarize_text_task`].
pub const SUMMARIZE_TEXT_TASK_NAME: &str = "summarize_text";
/// Name of the final persistence task: the label [`build_cognify_pipeline`]
/// passes to `add_task_named`, and the `source_task` value stamped by
/// [`make_add_data_points_task`].
pub const ADD_DATA_POINTS_TASK_NAME: &str = "add_data_points";

/// Pipeline name carried by cognify task stamps (locked Decision 14 of
/// LIB-06). Used by the per-task in-body stamping below so the test in
/// `crates/cognify/tests/provenance_e2e.rs` sees `source_pipeline =
/// "cognify"` on every produced DataPoint.
const COGNIFY_PIPELINE_STAMP_NAME: &str = "cognify";
3157
3158/// Resolve the user label for in-body stamping from a [`TaskContext`].
3159///
3160/// Mirrors [`cognee_core::PipelineContext::user_label`]: prefer
3161/// `user_email`, fall back to `user_id.to_string()`, else `None`.
3162fn user_label_from_ctx(ctx: &Arc<cognee_core::TaskContext>) -> Option<String> {
3163    ctx.pipeline_ctx.as_ref().and_then(|p| p.user_label())
3164}
3165
3166/// Build a [`TypedTask`] that classifies Data items into Documents.
3167///
3168/// The returned task does **not** carry a name; the pipeline builder
3169/// [`build_cognify_pipeline`] wraps it with [`CLASSIFY_DOCUMENTS_TASK_NAME`].
3170///
3171/// In-body provenance stamping: stamps every emitted `Document` with
3172/// `source_pipeline = "cognify"` and `source_task = "classify_documents"`.
3173/// Necessary because `ClassifiedDocuments` is a non-`HasDataPoint` wrapper
3174/// not walked by the executor's `stamp_tree_dyn` (LIB-06-03 fixup).
3175pub fn make_classify_documents_task() -> TypedTask<CognifyInput, ClassifiedDocuments> {
3176    TypedTask::sync(|input: &CognifyInput, ctx| {
3177        let mut classified = classify_documents(input).map_err(|e| format!("{e}"))?;
3178        let user_label = user_label_from_ctx(&ctx);
3179        for doc in &mut classified.documents {
3180            stamp_provenance(
3181                &mut doc.base,
3182                COGNIFY_PIPELINE_STAMP_NAME,
3183                CLASSIFY_DOCUMENTS_TASK_NAME,
3184                user_label.as_deref(),
3185            );
3186        }
3187        Ok(Box::new(classified))
3188    })
3189}
3190
3191/// Build a [`TypedTask`] that extracts text chunks from classified documents.
3192///
3193/// In-body provenance stamping: stamps every emitted `DocumentChunk`
3194/// with `source_task = "extract_chunks_from_documents"`. Documents
3195/// inherited from the upstream wrapper keep their already-set stamp via
3196/// the `is_none()` guard inside [`stamp_provenance`].
3197pub fn make_extract_chunks_task(
3198    storage: Arc<dyn StorageTrait>,
3199    max_chunk_size: usize,
3200    token_counter_kind: TokenCounterKind,
3201    db: Option<Arc<DatabaseConnection>>,
3202    loader_registry: Arc<LoaderRegistry>,
3203) -> TypedTask<ClassifiedDocuments, ExtractedChunks> {
3204    TypedTask::async_fn(move |input: &ClassifiedDocuments, ctx| {
3205        let input = input.clone();
3206        let storage = Arc::clone(&storage);
3207        let db = db.clone();
3208        let token_counter_kind = token_counter_kind.clone();
3209        let loader_registry = Arc::clone(&loader_registry);
3210        let user_label = user_label_from_ctx(&ctx);
3211        Box::pin(async move {
3212            let mut extracted = extract_chunks_from_documents(
3213                &input,
3214                &*storage,
3215                max_chunk_size,
3216                token_counter_kind,
3217                db.as_deref(),
3218                &loader_registry,
3219            )
3220            .await
3221            .map_err(|e| format!("{e}"))?;
3222            for chunk in &mut extracted.chunks {
3223                stamp_provenance(
3224                    &mut chunk.base,
3225                    COGNIFY_PIPELINE_STAMP_NAME,
3226                    EXTRACT_CHUNKS_TASK_NAME,
3227                    user_label.as_deref(),
3228                );
3229            }
3230            // Documents carried forward keep their earlier stamp from
3231            // `classify_documents`; only stamp any that are still unstamped
3232            // (idempotent via the `is_none` guards).
3233            for doc in &mut extracted.documents {
3234                stamp_provenance(
3235                    &mut doc.base,
3236                    COGNIFY_PIPELINE_STAMP_NAME,
3237                    EXTRACT_CHUNKS_TASK_NAME,
3238                    user_label.as_deref(),
3239                );
3240            }
3241            Ok(Box::new(extracted))
3242        })
3243    })
3244}
3245
3246/// Build a [`TypedTask`] that extracts knowledge graphs from chunks via LLM.
3247///
3248/// In-body provenance stamping: stamps `entities[*].entity`,
3249/// `entities[*].entity_type` with `source_task = "extract_graph_from_data"`.
3250/// Carried-forward chunks/documents keep their earlier stamp via the
3251/// idempotent `is_none()` guards inside [`stamp_provenance`].
3252pub fn make_extract_graph_task(
3253    llm: Arc<dyn Llm>,
3254    graph_db: Arc<dyn GraphDBTrait>,
3255    ontology_resolver: Arc<dyn OntologyResolver>,
3256    config: CognifyConfig,
3257) -> TypedTask<ExtractedChunks, ExtractedGraphData> {
3258    TypedTask::async_fn(move |input: &ExtractedChunks, ctx| {
3259        let input = input.clone();
3260        let llm = Arc::clone(&llm);
3261        let graph_db = Arc::clone(&graph_db);
3262        let ontology_resolver = Arc::clone(&ontology_resolver);
3263        let config = config.clone();
3264        let user_label = user_label_from_ctx(&ctx);
3265        Box::pin(async move {
3266            let mut graph_data = extract_graph_from_data(
3267                &input,
3268                llm,
3269                Arc::clone(&graph_db),
3270                ontology_resolver,
3271                &config,
3272                user_label.as_deref(),
3273            )
3274            .await
3275            .map_err(|e| format!("{e}"))?;
3276            if config.create_web_page_nodes {
3277                create_web_page_nodes(&graph_data.documents, &graph_data.chunks, graph_db)
3278                    .await
3279                    .map_err(|e| format!("{e}"))?;
3280            }
3281            for pair in &mut graph_data.entities {
3282                stamp_provenance(
3283                    &mut pair.entity.base,
3284                    COGNIFY_PIPELINE_STAMP_NAME,
3285                    EXTRACT_GRAPH_TASK_NAME,
3286                    user_label.as_deref(),
3287                );
3288                stamp_provenance(
3289                    &mut pair.entity_type.base,
3290                    COGNIFY_PIPELINE_STAMP_NAME,
3291                    EXTRACT_GRAPH_TASK_NAME,
3292                    user_label.as_deref(),
3293                );
3294            }
3295            // Chunks/documents carried forward — idempotent re-stamp keeps
3296            // their upstream `source_task` intact via the `is_none` guard.
3297            for chunk in &mut graph_data.chunks {
3298                stamp_provenance(
3299                    &mut chunk.base,
3300                    COGNIFY_PIPELINE_STAMP_NAME,
3301                    EXTRACT_GRAPH_TASK_NAME,
3302                    user_label.as_deref(),
3303                );
3304            }
3305            for doc in &mut graph_data.documents {
3306                stamp_provenance(
3307                    &mut doc.base,
3308                    COGNIFY_PIPELINE_STAMP_NAME,
3309                    EXTRACT_GRAPH_TASK_NAME,
3310                    user_label.as_deref(),
3311                );
3312            }
3313            Ok(Box::new(graph_data))
3314        })
3315    })
3316}
3317
3318/// Build a [`TypedTask`] that summarizes text chunks via LLM.
3319///
3320/// In-body provenance stamping: stamps every emitted `TextSummary`
3321/// with `source_task = "summarize_text"`. Carried-forward
3322/// chunks/documents/entities keep their upstream stamps.
3323pub fn make_summarize_text_task(
3324    llm: Arc<dyn Llm>,
3325    config: CognifyConfig,
3326) -> TypedTask<ExtractedGraphData, SummarizedData> {
3327    TypedTask::async_fn(move |input: &ExtractedGraphData, ctx| {
3328        let input = input.clone();
3329        let llm = Arc::clone(&llm);
3330        let config = config.clone();
3331        let user_label = user_label_from_ctx(&ctx);
3332        Box::pin(async move {
3333            let mut summarized = summarize_text(&input, llm, &config)
3334                .await
3335                .map_err(|e| format!("{e}"))?;
3336            for summary in &mut summarized.summaries {
3337                stamp_provenance(
3338                    &mut summary.base,
3339                    COGNIFY_PIPELINE_STAMP_NAME,
3340                    SUMMARIZE_TEXT_TASK_NAME,
3341                    user_label.as_deref(),
3342                );
3343            }
3344            // Idempotent re-stamp of carried-forward DataPoints — only
3345            // ones that somehow escaped earlier stamping get filled in.
3346            for chunk in &mut summarized.chunks {
3347                stamp_provenance(
3348                    &mut chunk.base,
3349                    COGNIFY_PIPELINE_STAMP_NAME,
3350                    SUMMARIZE_TEXT_TASK_NAME,
3351                    user_label.as_deref(),
3352                );
3353            }
3354            for doc in &mut summarized.documents {
3355                stamp_provenance(
3356                    &mut doc.base,
3357                    COGNIFY_PIPELINE_STAMP_NAME,
3358                    SUMMARIZE_TEXT_TASK_NAME,
3359                    user_label.as_deref(),
3360                );
3361            }
3362            for pair in &mut summarized.entities {
3363                stamp_provenance(
3364                    &mut pair.entity.base,
3365                    COGNIFY_PIPELINE_STAMP_NAME,
3366                    SUMMARIZE_TEXT_TASK_NAME,
3367                    user_label.as_deref(),
3368                );
3369                stamp_provenance(
3370                    &mut pair.entity_type.base,
3371                    COGNIFY_PIPELINE_STAMP_NAME,
3372                    SUMMARIZE_TEXT_TASK_NAME,
3373                    user_label.as_deref(),
3374                );
3375            }
3376            Ok(Box::new(summarized))
3377        })
3378    })
3379}
3380
3381/// Build a [`TypedTask`] that generates embeddings and indexes data points.
3382///
3383/// In-body provenance stamping: idempotent re-stamp of every DataPoint
3384/// in the produced `CognifyResult`. Upstream tasks have already stamped
3385/// them with their specific `source_task`; this loop only fills in any
3386/// stragglers (e.g. fresh `EdgeType` entries or DataPoints constructed
3387/// inside `add_data_points` itself) — the `is_none` guards inside
3388/// [`stamp_provenance`] keep upstream stamps intact.
3389pub fn make_add_data_points_task(
3390    graph_db: Arc<dyn GraphDBTrait>,
3391    vector_db: Arc<dyn VectorDB>,
3392    embedding_engine: Arc<dyn EmbeddingEngine>,
3393    db: Option<Arc<DatabaseConnection>>,
3394    config: CognifyConfig,
3395) -> TypedTask<SummarizedData, CognifyResult> {
3396    TypedTask::async_fn(move |input: &SummarizedData, ctx| {
3397        let input = input.clone();
3398        let graph_db = Arc::clone(&graph_db);
3399        let vector_db = Arc::clone(&vector_db);
3400        let embedding_engine = Arc::clone(&embedding_engine);
3401        let db = db.clone();
3402        let config = config.clone();
3403        let user_label = user_label_from_ctx(&ctx);
3404        Box::pin(async move {
3405            let mut result =
3406                add_data_points(&input, graph_db, vector_db, embedding_engine, db, &config)
3407                    .await
3408                    .map_err(|e| format!("{e}"))?;
3409            for chunk in &mut result.chunks {
3410                stamp_provenance(
3411                    &mut chunk.base,
3412                    COGNIFY_PIPELINE_STAMP_NAME,
3413                    ADD_DATA_POINTS_TASK_NAME,
3414                    user_label.as_deref(),
3415                );
3416            }
3417            for pair in &mut result.entities {
3418                stamp_provenance(
3419                    &mut pair.entity.base,
3420                    COGNIFY_PIPELINE_STAMP_NAME,
3421                    ADD_DATA_POINTS_TASK_NAME,
3422                    user_label.as_deref(),
3423                );
3424                stamp_provenance(
3425                    &mut pair.entity_type.base,
3426                    COGNIFY_PIPELINE_STAMP_NAME,
3427                    ADD_DATA_POINTS_TASK_NAME,
3428                    user_label.as_deref(),
3429                );
3430            }
3431            for summary in &mut result.summaries {
3432                stamp_provenance(
3433                    &mut summary.base,
3434                    COGNIFY_PIPELINE_STAMP_NAME,
3435                    ADD_DATA_POINTS_TASK_NAME,
3436                    user_label.as_deref(),
3437                );
3438            }
3439            for edge_type in &mut result.edge_types {
3440                stamp_provenance(
3441                    &mut edge_type.base,
3442                    COGNIFY_PIPELINE_STAMP_NAME,
3443                    ADD_DATA_POINTS_TASK_NAME,
3444                    user_label.as_deref(),
3445                );
3446            }
3447            for doc in &mut result.documents_for_dlt {
3448                stamp_provenance(
3449                    &mut doc.base,
3450                    COGNIFY_PIPELINE_STAMP_NAME,
3451                    ADD_DATA_POINTS_TASK_NAME,
3452                    user_label.as_deref(),
3453                );
3454            }
3455            Ok(Box::new(result))
3456        })
3457    })
3458}
3459
3460// ---------------------------------------------------------------------------
3461// Pipeline builder
3462// ---------------------------------------------------------------------------
3463
3464/// Build a [`LoaderRegistry`] with the default text/pdf/csv loaders plus any
3465/// feature-gated media loaders that have the required handles available.
3466///
3467/// Centralized here so both [`build_cognify_pipeline`] and
3468/// [`build_temporal_cognify_pipeline`] stay in sync.
3469// `llm` is consumed only by the image loader and `config` only by the audio
3470// loader; when neither feature is enabled both are genuinely unused.
3471#[cfg_attr(
3472    not(any(feature = "image-loader", feature = "audio-loader")),
3473    allow(unused_variables)
3474)]
3475fn build_loader_registry(llm: &Arc<dyn Llm>, config: &CognifyConfig) -> LoaderRegistry {
3476    #[allow(unused_mut)]
3477    let mut registry = LoaderRegistry::default_registry();
3478    #[cfg(feature = "image-loader")]
3479    registry.register("image", Arc::new(ImageLoader::new(Arc::clone(llm))));
3480    #[cfg(feature = "audio-loader")]
3481    if let Some(ref transcriber_handle) = config.transcriber {
3482        registry.register(
3483            "audio",
3484            Arc::new(AudioLoader::new(Arc::clone(&transcriber_handle.0))),
3485        );
3486    }
3487    registry
3488}
3489
3490/// Build a complete cognify [`Pipeline`]:
3491/// [`CognifyInput`] → classify → chunk → extract_graph → summarize → add_data_points → [`CognifyResult`].
3492///
3493/// The `user_id` and `tenant_id` parameters are threaded through all pipeline
3494/// stages and included as metadata on vector points and graph nodes.
3495///
3496/// For composable pipeline-based execution (with concurrency, retry, progress
3497/// tracking, etc.), pass the result to [`cognee_core::execute`].
3498#[allow(clippy::too_many_arguments)]
3499pub fn build_cognify_pipeline(
3500    storage: Arc<dyn StorageTrait>,
3501    graph_db: Arc<dyn GraphDBTrait>,
3502    vector_db: Arc<dyn VectorDB>,
3503    embedding_engine: Arc<dyn EmbeddingEngine>,
3504    llm: Arc<dyn Llm>,
3505    db: Option<Arc<DatabaseConnection>>,
3506    ontology_resolver: Arc<dyn OntologyResolver>,
3507    config: CognifyConfig,
3508) -> Pipeline {
3509    let loader_registry = Arc::new(build_loader_registry(&llm, &config));
3510    PipelineBuilder::new_with_task("cognify", make_classify_documents_task())
3511        .with_first_task_name(CLASSIFY_DOCUMENTS_TASK_NAME)
3512        .add_task_named(
3513            make_extract_chunks_task(
3514                storage,
3515                config.max_chunk_size,
3516                config.token_counter_kind.clone(),
3517                db.clone(),
3518                loader_registry,
3519            ),
3520            EXTRACT_CHUNKS_TASK_NAME,
3521        )
3522        .add_task_named(
3523            make_extract_graph_task(
3524                Arc::clone(&llm),
3525                Arc::clone(&graph_db),
3526                ontology_resolver,
3527                config.clone(),
3528            ),
3529            EXTRACT_GRAPH_TASK_NAME,
3530        )
3531        .add_task_named(
3532            make_summarize_text_task(llm, config.clone()),
3533            SUMMARIZE_TEXT_TASK_NAME,
3534        )
3535        .add_task_named(
3536            make_add_data_points_task(graph_db, vector_db, embedding_engine, db, config),
3537            ADD_DATA_POINTS_TASK_NAME,
3538        )
3539        .with_name("cognify")
3540        .build()
3541}
3542
3543/// Build a [`TypedTask`] that extracts temporal events from chunks via LLM.
3544pub fn make_extract_temporal_events_task(
3545    llm: Arc<dyn Llm>,
3546    config: CognifyConfig,
3547) -> TypedTask<ExtractedChunks, ExtractedTemporalEvents> {
3548    TypedTask::async_fn(move |input: &ExtractedChunks, _ctx| {
3549        let input = input.clone();
3550        let llm = Arc::clone(&llm);
3551        let config = config.clone();
3552        Box::pin(async move {
3553            extract_temporal_events(&input, llm, &config)
3554                .await
3555                .map(Box::new)
3556                .map_err(|e| format!("{e}").into())
3557        })
3558    })
3559}
3560
3561/// Build a [`TypedTask`] that persists temporal events to graph and vector DBs.
3562pub fn make_add_temporal_data_points_task(
3563    graph_db: Arc<dyn GraphDBTrait>,
3564    vector_db: Arc<dyn VectorDB>,
3565    embedding_engine: Arc<dyn EmbeddingEngine>,
3566) -> TypedTask<ExtractedTemporalEvents, CognifyResult> {
3567    TypedTask::async_fn(move |input: &ExtractedTemporalEvents, _ctx| {
3568        let input = input.clone();
3569        let graph_db = Arc::clone(&graph_db);
3570        let vector_db = Arc::clone(&vector_db);
3571        let embedding_engine = Arc::clone(&embedding_engine);
3572        Box::pin(async move {
3573            add_temporal_data_points(&input, graph_db, vector_db, embedding_engine)
3574                .await
3575                .map(Box::new)
3576                .map_err(|e| format!("{e}").into())
3577        })
3578    })
3579}
3580
3581/// Build a complete temporal cognify [`Pipeline`]:
3582/// [`CognifyInput`] → classify → chunk → extract_temporal_events → add_temporal_data_points → [`CognifyResult`].
3583///
3584/// This pipeline runs instead of the standard cognify pipeline when
3585/// `CognifyConfig::temporal_cognify` is `true`. It mirrors the Python
3586/// `get_temporal_tasks()` pipeline that replaces the default stages with
3587/// event/timestamp extraction and temporal graph construction.
3588pub fn build_temporal_cognify_pipeline(
3589    storage: Arc<dyn StorageTrait>,
3590    graph_db: Arc<dyn GraphDBTrait>,
3591    vector_db: Arc<dyn VectorDB>,
3592    embedding_engine: Arc<dyn EmbeddingEngine>,
3593    llm: Arc<dyn Llm>,
3594    db: Option<Arc<DatabaseConnection>>,
3595    config: CognifyConfig,
3596) -> Pipeline {
3597    let loader_registry = Arc::new(build_loader_registry(&llm, &config));
3598    PipelineBuilder::new_with_task("temporal-cognify", make_classify_documents_task())
3599        .with_first_task_name(CLASSIFY_DOCUMENTS_TASK_NAME)
3600        .add_task_named(
3601            make_extract_chunks_task(
3602                storage,
3603                config.max_chunk_size,
3604                config.token_counter_kind.clone(),
3605                db,
3606                loader_registry,
3607            ),
3608            EXTRACT_CHUNKS_TASK_NAME,
3609        )
3610        .add_task_named(
3611            make_extract_temporal_events_task(llm, config),
3612            "extract_temporal_events",
3613        )
3614        .add_task_named(
3615            make_add_temporal_data_points_task(graph_db, vector_db, embedding_engine),
3616            "add_temporal_data_points",
3617        )
3618        .with_name("temporal-cognify")
3619        .build()
3620}
3621
3622#[cfg(test)]
3623#[allow(
3624    clippy::unwrap_used,
3625    clippy::expect_used,
3626    reason = "test code — panics are acceptable failures"
3627)]
3628mod tests {
3629    use super::*;
3630    use cognee_models::DataPoint;
3631    use cognee_storage::MockStorage;
3632
3633    #[test]
3634    fn test_classify_documents_empty() {
3635        let input = CognifyInput {
3636            data_items: vec![],
3637            dataset_id: Uuid::new_v4(),
3638            user_id: None,
3639            tenant_id: None,
3640        };
3641        let result = classify_documents(&input).unwrap();
3642        assert!(result.documents.is_empty());
3643    }
3644
3645    #[test]
3646    fn test_classify_documents_text_data() {
3647        let data = Data::builder(
3648            Uuid::new_v4(),
3649            "test.txt",
3650            "/storage/test.txt",
3651            "text://test",
3652            "txt",
3653            "text/plain",
3654            "hash123",
3655            Uuid::new_v4(),
3656        )
3657        .build();
3658
3659        let input = CognifyInput {
3660            data_items: vec![data],
3661            dataset_id: Uuid::new_v4(),
3662            user_id: None,
3663            tenant_id: None,
3664        };
3665        let result = classify_documents(&input).unwrap();
3666        assert_eq!(result.documents.len(), 1);
3667    }
3668
3669    #[test]
3670    fn test_classify_documents_skips_unknown_extension() {
3671        let data = Data::builder(
3672            Uuid::new_v4(),
3673            "data.xyz",
3674            "/storage/data.xyz",
3675            "file://data.xyz",
3676            "xyz",
3677            "application/octet-stream",
3678            "hash456",
3679            Uuid::new_v4(),
3680        )
3681        .build();
3682
3683        let input = CognifyInput {
3684            data_items: vec![data],
3685            dataset_id: Uuid::new_v4(),
3686            user_id: None,
3687            tenant_id: None,
3688        };
3689        let result = classify_documents(&input).unwrap();
3690        assert!(result.documents.is_empty());
3691    }
3692
3693    #[tokio::test]
3694    async fn test_extract_chunks_from_documents() {
3695        let storage = Arc::new(MockStorage::new());
3696        let location = storage
3697            .store(b"Hello world. This is a test.", "test.txt")
3698            .await
3699            .unwrap();
3700
3701        let doc_id = Uuid::new_v4();
3702        let mut base = DataPoint::new("TextDocument", None);
3703        base.id = doc_id;
3704        base.set_metadata("index_fields", serde_json::json!(["name"]));
3705        let doc = Document {
3706            base,
3707            document_type: "text".to_string(),
3708            name: "test.txt".to_string(),
3709            raw_data_location: location,
3710            mime_type: "text/plain".to_string(),
3711            extension: "txt".to_string(),
3712            data_id: doc_id,
3713            external_metadata: None,
3714        };
3715
3716        let input = ClassifiedDocuments {
3717            documents: vec![doc],
3718            dataset_id: Uuid::new_v4(),
3719            user_id: None,
3720            tenant_id: None,
3721        };
3722
3723        let registry = LoaderRegistry::default();
3724        let result = extract_chunks_from_documents(
3725            &input,
3726            &*storage,
3727            100,
3728            TokenCounterKind::Word,
3729            None,
3730            &registry,
3731        )
3732        .await
3733        .unwrap();
3734        assert!(!result.chunks.is_empty());
3735    }
3736
3737    #[tokio::test]
3738    async fn test_extract_chunks_empty_documents() {
3739        let storage = Arc::new(MockStorage::new());
3740        let input = ClassifiedDocuments {
3741            documents: vec![],
3742            dataset_id: Uuid::new_v4(),
3743            user_id: None,
3744            tenant_id: None,
3745        };
3746
3747        let registry = LoaderRegistry::default();
3748        let result = extract_chunks_from_documents(
3749            &input,
3750            &*storage,
3751            100,
3752            TokenCounterKind::Word,
3753            None,
3754            &registry,
3755        )
3756        .await
3757        .unwrap();
3758        assert!(result.chunks.is_empty());
3759    }
3760
3761    #[tokio::test]
3762    async fn test_dlt_short_circuit() {
3763        let storage = Arc::new(MockStorage::new());
3764        let location = storage
3765            .store(b"  some dlt row content  ", "dlt.txt")
3766            .await
3767            .unwrap();
3768
3769        let doc_id = Uuid::new_v4();
3770        let mut base = DataPoint::new("DltRowDocument", None);
3771        base.id = doc_id;
3772        base.set_metadata("index_fields", serde_json::json!(["text"]));
3773        let doc = Document {
3774            base,
3775            document_type: "dlt_row".to_string(),
3776            name: "dlt.txt".to_string(),
3777            raw_data_location: location,
3778            mime_type: "text/plain".to_string(),
3779            extension: "txt".to_string(),
3780            data_id: doc_id,
3781            external_metadata: None,
3782        };
3783
3784        let input = ClassifiedDocuments {
3785            documents: vec![doc],
3786            dataset_id: Uuid::new_v4(),
3787            user_id: None,
3788            tenant_id: None,
3789        };
3790
3791        let registry = LoaderRegistry::default();
3792        let result = extract_chunks_from_documents(
3793            &input,
3794            &*storage,
3795            100,
3796            TokenCounterKind::Word,
3797            None,
3798            &registry,
3799        )
3800        .await
3801        .unwrap();
3802
3803        assert_eq!(result.chunks.len(), 1);
3804        let chunk = &result.chunks[0];
3805        assert_eq!(chunk.text, "some dlt row content");
3806        assert_eq!(chunk.cut_type, "dlt_row");
3807        assert_eq!(chunk.chunk_index, 0);
3808        assert_eq!(chunk.document_id, doc_id);
3809    }
3810
3811    #[tokio::test]
3812    async fn test_unsupported_document_type() {
3813        // Use a document_type that is intentionally never registered in
3814        // LoaderRegistry::default(). The previous fixture used "pdf", but the
3815        // PDF loader added in phase2/task1 made that type supported, causing
3816        // this test to invoke the real PDFium loader on garbage bytes.
3817        const UNSUPPORTED: &str = "no_such_loader_type_for_test";
3818
3819        let storage = Arc::new(MockStorage::new());
3820        let location = storage.store(b"some content", "test.bin").await.unwrap();
3821
3822        let doc_id = Uuid::new_v4();
3823        let mut base = DataPoint::new("UnknownDocument", None);
3824        base.id = doc_id;
3825        base.set_metadata("index_fields", serde_json::json!(["text"]));
3826        let doc = Document {
3827            base,
3828            document_type: UNSUPPORTED.to_string(),
3829            name: "test.bin".to_string(),
3830            raw_data_location: location,
3831            mime_type: "application/octet-stream".to_string(),
3832            extension: "bin".to_string(),
3833            data_id: doc_id,
3834            external_metadata: None,
3835        };
3836
3837        let input = ClassifiedDocuments {
3838            documents: vec![doc],
3839            dataset_id: Uuid::new_v4(),
3840            user_id: None,
3841            tenant_id: None,
3842        };
3843
3844        let registry = LoaderRegistry::default();
3845        let result = extract_chunks_from_documents(
3846            &input,
3847            &*storage,
3848            100,
3849            TokenCounterKind::Word,
3850            None,
3851            &registry,
3852        )
3853        .await;
3854
3855        assert!(result.is_err());
3856        let err = result.unwrap_err();
3857        assert!(
3858            matches!(err, CognifyError::UnsupportedDocumentType(ref t) if t == UNSUPPORTED),
3859            "expected UnsupportedDocumentType({UNSUPPORTED:?}), got: {err:?}"
3860        );
3861    }
3862
3863    #[test]
3864    fn test_classify_documents_preserves_dataset_id() {
3865        let dataset_id = Uuid::new_v4();
3866        let input = CognifyInput {
3867            data_items: vec![],
3868            dataset_id,
3869            user_id: None,
3870            tenant_id: None,
3871        };
3872        let result = classify_documents(&input).unwrap();
3873        assert_eq!(result.dataset_id, dataset_id);
3874    }
3875
3876    // ── Provenance guard and ID tests ───────────────────────────────────
3877
3878    #[test]
3879    fn provenance_node_id_works_with_none_tenant() {
3880        let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
3881        let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
3882        let data_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
3883        let node_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
3884
3885        // Must not panic with None tenant
3886        let id = provenance_node_id(None, user_id, dataset_id, data_id, node_id);
3887
3888        // Matches Python's str(None) → "None" in the UUID5 input
3889        let expected_input = format!("None{user_id}{dataset_id}{data_id}{node_id}");
3890        let expected = Uuid::new_v5(&Uuid::NAMESPACE_OID, expected_input.as_bytes());
3891        assert_eq!(id, expected);
3892    }
3893
3894    #[test]
3895    fn provenance_node_id_with_real_tenant_differs_from_none() {
3896        let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
3897        let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
3898        let data_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
3899        let node_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
3900        let tenant_id = Uuid::parse_str("00000000-0000-0000-0000-000000000005").unwrap();
3901
3902        let id_none = provenance_node_id(None, user_id, dataset_id, data_id, node_id);
3903        let id_real = provenance_node_id(Some(tenant_id), user_id, dataset_id, data_id, node_id);
3904        assert_ne!(id_none, id_real);
3905    }
3906
3907    #[test]
3908    fn provenance_edge_id_works_with_none_tenant() {
3909        let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
3910        let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
3911        let source_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
3912        let target_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
3913
3914        let id = provenance_edge_id(
3915            None,
3916            user_id,
3917            dataset_id,
3918            source_id,
3919            "relates_to",
3920            target_id,
3921        );
3922
3923        let expected_input = format!("None{user_id}{dataset_id}{source_id}relates_to{target_id}");
3924        let expected = Uuid::new_v5(&Uuid::NAMESPACE_OID, expected_input.as_bytes());
3925        assert_eq!(id, expected);
3926    }
3927
3928    /// The provenance guard must fire when db + user_id are present,
3929    /// even if tenant_id is None.  This matches Python's
3930    /// `if user and dataset and data:` which doesn't check tenant.
3931    #[test]
3932    fn dlt_fk_rel_name_always_includes_ref_col_separator() {
3933        // Python: rel_name = f"{table_name}:{fk_col}->{ref_table}:{ref_col}"
3934        // This always includes the colon before ref_col, even when ref_col is empty.
3935
3936        // Case 1: non-empty ref_col
3937        let table_name = "orders";
3938        let fk_col = "customer_id";
3939        let ref_table = "customers";
3940        let ref_col = "id";
3941        let rel_name = format!("{table_name}:{fk_col}->{ref_table}:{ref_col}");
3942        assert_eq!(rel_name, "orders:customer_id->customers:id");
3943
3944        let rel_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{rel_name}").as_bytes());
3945        let expected_id = Uuid::new_v5(
3946            &Uuid::NAMESPACE_OID,
3947            b"dlt:orders:customer_id->customers:id",
3948        );
3949        assert_eq!(rel_id, expected_id);
3950
3951        // Case 2: empty ref_col -- must still include trailing colon
3952        let ref_col_empty = "";
3953        let rel_name_empty = format!("{table_name}:{fk_col}->{ref_table}:{ref_col_empty}");
3954        assert_eq!(
3955            rel_name_empty, "orders:customer_id->customers:",
3956            "rel_name must include trailing colon even when ref_col is empty"
3957        );
3958
3959        let rel_id_empty = Uuid::new_v5(
3960            &Uuid::NAMESPACE_OID,
3961            format!("dlt:{rel_name_empty}").as_bytes(),
3962        );
3963        let expected_id_empty =
3964            Uuid::new_v5(&Uuid::NAMESPACE_OID, b"dlt:orders:customer_id->customers:");
3965        assert_eq!(rel_id_empty, expected_id_empty);
3966
3967        // Verify the two IDs differ (trailing colon changes the UUID5 seed)
3968        assert_ne!(
3969            rel_id, rel_id_empty,
3970            "non-empty and empty ref_col must produce different UUIDs"
3971        );
3972    }
3973
3974    #[test]
3975    fn provenance_guard_does_not_require_tenant_id() {
3976        // Simulate the guard condition from cognify():
3977        //   if let (Some(db), Some(user_id)) = (&db, input.user_id)
3978        let db: Option<u8> = Some(1); // stand-in for Some(db)
3979        let user_id: Option<Uuid> = Some(Uuid::new_v4());
3980        let tenant_id: Option<Uuid> = None;
3981
3982        let guard_fires = matches!((&db, user_id), (Some(_), Some(_)));
3983        assert!(
3984            guard_fires,
3985            "Provenance guard must fire when db + user_id are present, regardless of tenant_id"
3986        );
3987
3988        // Also verify the old (broken) guard would NOT fire
3989        let old_guard_fires = matches!((&db, user_id, tenant_id), (Some(_), Some(_), Some(_)));
3990        assert!(
3991            !old_guard_fires,
3992            "The old 3-way guard should NOT fire when tenant_id is None"
3993        );
3994    }
3995
3996    fn test_document_with_metadata(doc_id: Uuid, external_metadata: Option<String>) -> Document {
3997        let mut base = DataPoint::new("TextDocument", None);
3998        base.id = doc_id;
3999        Document {
4000            base,
4001            document_type: "text".to_string(),
4002            name: "test.txt".to_string(),
4003            raw_data_location: "file:///tmp/test.txt".to_string(),
4004            mime_type: "text/plain".to_string(),
4005            extension: "txt".to_string(),
4006            data_id: doc_id,
4007            external_metadata,
4008        }
4009    }
4010
4011    fn test_chunk(chunk_id: Uuid, doc_id: Uuid, text: &str) -> DocumentChunk {
4012        DocumentChunk::new(
4013            chunk_id,
4014            text.to_string(),
4015            text.split_whitespace().count(),
4016            0,
4017            "paragraph_end".to_string(),
4018            doc_id,
4019        )
4020    }
4021
4022    fn url_metadata(url: &str, final_url: &str, title: &str) -> String {
4023        json!({
4024            "source": "url",
4025            "url": url,
4026            "final_url": final_url,
4027            "content_type": "text/html",
4028            "title": title,
4029        })
4030        .to_string()
4031    }
4032
4033    #[tokio::test]
4034    async fn add_data_points_stores_document_node_and_indexes_document_name() {
4035        use cognee_embedding::MockEmbeddingEngine;
4036        use cognee_vector::MockVectorDB;
4037
4038        let graph: Arc<dyn GraphDBTrait> = Arc::new(cognee_graph::MockGraphDB::new());
4039        let vector: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
4040        let engine: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
4041
4042        let doc_id = Uuid::parse_str("00000000-0000-0000-0000-0000000000a1").unwrap();
4043        let chunk_id = Uuid::parse_str("00000000-0000-0000-0000-0000000000b1").unwrap();
4044        let document = test_document_with_metadata(doc_id, None);
4045        let chunk = test_chunk(chunk_id, doc_id, "Hello world");
4046
4047        let input = SummarizedData {
4048            chunks: vec![chunk],
4049            documents: vec![document],
4050            entities: vec![],
4051            edges: vec![],
4052            summaries: vec![],
4053            dataset_id: Uuid::new_v4(),
4054            user_id: None,
4055            tenant_id: None,
4056        };
4057
4058        let config = CognifyConfig::default();
4059        add_data_points(
4060            &input,
4061            Arc::clone(&graph),
4062            Arc::clone(&vector),
4063            Arc::clone(&engine),
4064            None,
4065            &config,
4066        )
4067        .await
4068        .unwrap();
4069
4070        // (a) Document stored as a graph node with id == data id and the
4071        //     concrete subclass type.
4072        let node = graph
4073            .get_node(&doc_id.to_string())
4074            .await
4075            .unwrap()
4076            .expect("document node should exist");
4077        assert_eq!(
4078            node.get("type").and_then(|v| v.as_str()),
4079            Some("TextDocument")
4080        );
4081
4082        // (b) A TextDocument_name collection exists with exactly one point.
4083        assert!(vector.has_collection("TextDocument", "name").await.unwrap());
4084        assert_eq!(
4085            vector
4086                .collection_size("TextDocument", "name")
4087                .await
4088                .unwrap(),
4089            1
4090        );
4091    }
4092
4093    #[tokio::test]
4094    async fn extracted_edge_description_persists_as_edge_text_property() {
4095        use crate::fact_extraction::{Edge, KnowledgeGraph, Node};
4096        use cognee_ontology::NoOpOntologyResolver;
4097
4098        let graph = KnowledgeGraph {
4099            nodes: vec![
4100                Node {
4101                    id: "alice".to_string(),
4102                    name: "Alice".to_string(),
4103                    node_type: "PERSON".to_string(),
4104                    description: "A person".to_string(),
4105                },
4106                Node {
4107                    id: "acme".to_string(),
4108                    name: "Acme".to_string(),
4109                    node_type: "ORGANIZATION".to_string(),
4110                    description: "A company".to_string(),
4111                },
4112            ],
4113            edges: vec![Edge {
4114                source_node_id: "alice".to_string(),
4115                target_node_id: "acme".to_string(),
4116                relationship_name: "founded".to_string(),
4117                // Leading/trailing whitespace exercises the trim semantics.
4118                description: Some("  Alice founded Acme  ".to_string()),
4119            }],
4120        };
4121
4122        let chunk_id = Uuid::new_v4();
4123        let dataset_id = Uuid::new_v4();
4124        let resolver = NoOpOntologyResolver::new();
4125
4126        let (_nodes, edges) = expand_with_nodes_and_edges(
4127            vec![(chunk_id, graph)],
4128            dataset_id,
4129            &HashSet::new(),
4130            &resolver,
4131            None,
4132        )
4133        .await;
4134
4135        assert_eq!(edges.len(), 1);
4136        let edge_text = edges[0]
4137            .properties
4138            .get("edge_text")
4139            .expect("edge_text property should be set");
4140        // Trimmed, matching Python _strip_nonblank_text.
4141        assert_eq!(edge_text, "Alice founded Acme");
4142    }
4143
4144    #[test]
4145    fn cognify_config_creates_web_page_nodes_by_default() {
4146        assert!(CognifyConfig::default().create_web_page_nodes);
4147        assert!(
4148            !CognifyConfig::default()
4149                .with_web_page_nodes(false)
4150                .create_web_page_nodes
4151        );
4152    }
4153
4154    #[tokio::test]
4155    async fn create_web_page_nodes_creates_deterministic_page_site_and_edges() {
4156        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4157        let doc_id = Uuid::parse_str("00000000-0000-0000-0000-000000000101").unwrap();
4158        let chunk_id = Uuid::parse_str("00000000-0000-0000-0000-000000000201").unwrap();
4159        let final_url = "https://Example.com/path?q=1";
4160        let documents = vec![test_document_with_metadata(
4161            doc_id,
4162            Some(url_metadata(
4163                "https://example.com/start",
4164                final_url,
4165                "Example title",
4166            )),
4167        )];
4168        let chunks = vec![test_chunk(chunk_id, doc_id, "Visible page content")];
4169
4170        create_web_page_nodes(&documents, &chunks, graph.clone())
4171            .await
4172            .unwrap();
4173
4174        let page_id = web_page_id("https://example.com/path?q=1").to_string();
4175        let site_id = web_site_id("example.com").to_string();
4176        let (nodes, edges) = graph.get_graph_data().await.unwrap();
4177        assert_eq!(nodes.len(), 2);
4178
4179        let page = graph.get_node(&page_id).await.unwrap().unwrap();
4180        assert_eq!(page.get("type").and_then(|v| v.as_str()), Some("WebPage"));
4181        assert_eq!(
4182            page.get("url").and_then(|v| v.as_str()),
4183            Some("https://example.com/path?q=1")
4184        );
4185        assert_eq!(
4186            page.get("title").and_then(|v| v.as_str()),
4187            Some("Example title")
4188        );
4189        assert_eq!(
4190            page.get("content").and_then(|v| v.as_str()),
4191            Some("Visible page content")
4192        );
4193        assert!(
4194            !page.contains_key("created_at"),
4195            "WebPage node payload should be deterministic"
4196        );
4197
4198        let site = graph.get_node(&site_id).await.unwrap().unwrap();
4199        assert_eq!(site.get("type").and_then(|v| v.as_str()), Some("WebSite"));
4200        assert_eq!(
4201            site.get("domain").and_then(|v| v.as_str()),
4202            Some("example.com")
4203        );
4204
4205        assert_eq!(edges.len(), 2);
4206        assert!(edges.iter().any(|(source, target, rel, _)| {
4207            source == &page_id && target == &site_id && rel == "PART_OF"
4208        }));
4209        assert!(edges.iter().any(|(source, target, rel, _)| {
4210            source == &chunk_id.to_string() && target == &page_id && rel == "SOURCED_FROM"
4211        }));
4212    }
4213
4214    #[tokio::test]
4215    async fn create_web_page_nodes_truncates_content_to_500_chars() {
4216        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4217        let doc_id = Uuid::new_v4();
4218        let long_text = "a".repeat(650);
4219        let documents = vec![test_document_with_metadata(
4220            doc_id,
4221            Some(url_metadata(
4222                "https://example.com/long",
4223                "https://example.com/long",
4224                "Long",
4225            )),
4226        )];
4227        let chunks = vec![test_chunk(Uuid::new_v4(), doc_id, &long_text)];
4228
4229        create_web_page_nodes(&documents, &chunks, graph.clone())
4230            .await
4231            .unwrap();
4232
4233        let page_id = web_page_id("https://example.com/long").to_string();
4234        let page = graph.get_node(&page_id).await.unwrap().unwrap();
4235        assert_eq!(
4236            page.get("content")
4237                .and_then(|v| v.as_str())
4238                .unwrap()
4239                .chars()
4240                .count(),
4241            500
4242        );
4243    }
4244
4245    #[tokio::test]
4246    async fn create_web_page_nodes_skips_invalid_and_non_url_metadata() {
4247        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4248        let doc_with_invalid_json =
4249            test_document_with_metadata(Uuid::new_v4(), Some("{not valid json".to_string()));
4250        let non_url_doc = test_document_with_metadata(
4251            Uuid::new_v4(),
4252            Some(json!({"source": "dlt", "url": "https://example.com"}).to_string()),
4253        );
4254        let bad_url_doc = test_document_with_metadata(
4255            Uuid::new_v4(),
4256            Some(json!({"source": "url", "final_url": "not a url"}).to_string()),
4257        );
4258        let chunks = vec![
4259            test_chunk(Uuid::new_v4(), doc_with_invalid_json.base.id, "a"),
4260            test_chunk(Uuid::new_v4(), non_url_doc.base.id, "b"),
4261            test_chunk(Uuid::new_v4(), bad_url_doc.base.id, "c"),
4262        ];
4263
4264        create_web_page_nodes(
4265            &[doc_with_invalid_json, non_url_doc, bad_url_doc],
4266            &chunks,
4267            graph.clone(),
4268        )
4269        .await
4270        .unwrap();
4271
4272        assert_eq!(graph.node_count(), 0);
4273        assert_eq!(graph.edge_count(), 0);
4274    }
4275
4276    #[tokio::test]
4277    async fn create_web_page_nodes_is_idempotent_for_edges() {
4278        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4279        let doc_id = Uuid::new_v4();
4280        let documents = vec![test_document_with_metadata(
4281            doc_id,
4282            Some(url_metadata(
4283                "https://example.com/idempotent",
4284                "https://example.com/idempotent",
4285                "Idempotent",
4286            )),
4287        )];
4288        let chunks = vec![test_chunk(Uuid::new_v4(), doc_id, "content")];
4289
4290        create_web_page_nodes(&documents, &chunks, graph.clone())
4291            .await
4292            .unwrap();
4293        create_web_page_nodes(&documents, &chunks, graph.clone())
4294            .await
4295            .unwrap();
4296
4297        assert_eq!(graph.node_count(), 2);
4298        assert_eq!(graph.edge_count(), 2);
4299    }
4300
4301    #[tokio::test]
4302    async fn make_extract_graph_task_wires_web_page_nodes_and_respects_opt_out() {
4303        use cognee_ontology::NoOpOntologyResolver;
4304        use cognee_test_utils::{MockLlm, test_task_context};
4305
4306        let doc_id = Uuid::new_v4();
4307        let input = ExtractedChunks {
4308            chunks: vec![test_chunk(Uuid::new_v4(), doc_id, "content")],
4309            documents: vec![test_document_with_metadata(
4310                doc_id,
4311                Some(url_metadata(
4312                    "https://example.com/wired",
4313                    "https://example.com/wired",
4314                    "Wired",
4315                )),
4316            )],
4317            dataset_id: Uuid::new_v4(),
4318            user_id: None,
4319            tenant_id: None,
4320        };
4321
4322        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4323        let (_, ctx, _) = test_task_context().await;
4324        let task = make_extract_graph_task(
4325            Arc::new(MockLlm::empty()),
4326            graph.clone(),
4327            Arc::new(NoOpOntologyResolver::new()),
4328            CognifyConfig::default(),
4329        );
4330        let TypedTask::Async(run) = task else {
4331            panic!("extract graph task should be async");
4332        };
4333        run(&input, ctx.clone()).await.unwrap();
4334        assert_eq!(graph.node_count(), 2);
4335        assert_eq!(graph.edge_count(), 2);
4336
4337        let graph = Arc::new(cognee_graph::MockGraphDB::new());
4338        let task = make_extract_graph_task(
4339            Arc::new(MockLlm::empty()),
4340            graph.clone(),
4341            Arc::new(NoOpOntologyResolver::new()),
4342            CognifyConfig::default().with_web_page_nodes(false),
4343        );
4344        let TypedTask::Async(run) = task else {
4345            panic!("extract graph task should be async");
4346        };
4347        run(&input, ctx).await.unwrap();
4348        assert_eq!(graph.node_count(), 0);
4349        assert_eq!(graph.edge_count(), 0);
4350    }
4351
4352    #[tokio::test]
4353    async fn test_summarize_text_skips_dlt_chunks() {
4354        use cognee_test_utils::MockLlm;
4355
4356        let doc_id_text = Uuid::new_v4();
4357        let doc_id_dlt = Uuid::new_v4();
4358
4359        let mut base_text = DataPoint::new("TextDocument", None);
4360        base_text.id = doc_id_text;
4361        let text_doc = Document {
4362            base: base_text,
4363            document_type: "text".to_string(),
4364            name: "test.txt".to_string(),
4365            raw_data_location: "file:///tmp/test.txt".to_string(),
4366            mime_type: "text/plain".to_string(),
4367            extension: "txt".to_string(),
4368            data_id: doc_id_text,
4369            external_metadata: None,
4370        };
4371
4372        let mut base_dlt = DataPoint::new("DltRowDocument", None);
4373        base_dlt.id = doc_id_dlt;
4374        let dlt_doc = Document {
4375            base: base_dlt,
4376            document_type: "dlt_row".to_string(),
4377            name: "dlt_row.json".to_string(),
4378            raw_data_location: "file:///tmp/dlt_row.json".to_string(),
4379            mime_type: "application/json".to_string(),
4380            extension: "json".to_string(),
4381            data_id: doc_id_dlt,
4382            external_metadata: None,
4383        };
4384
4385        let text_chunk = DocumentChunk::new(
4386            Uuid::new_v4(),
4387            "Some meaningful text to summarize.".to_string(),
4388            5,
4389            0,
4390            "paragraph_end".to_string(),
4391            doc_id_text,
4392        );
4393
4394        let dlt_chunk = DocumentChunk::new(
4395            Uuid::new_v4(),
4396            r#"{"id": 1, "name": "row"}"#.to_string(),
4397            3,
4398            0,
4399            "paragraph_end".to_string(),
4400            doc_id_dlt,
4401        );
4402
4403        let input = ExtractedGraphData {
4404            chunks: vec![text_chunk, dlt_chunk],
4405            documents: vec![text_doc, dlt_doc],
4406            entities: vec![],
4407            edges: vec![],
4408            dataset_id: Uuid::new_v4(),
4409            user_id: None,
4410            tenant_id: None,
4411        };
4412
4413        // With summarization disabled, verify we get zero summaries and no panic.
4414        let config = CognifyConfig::default().with_summarization(false);
4415        let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());
4416        let result = summarize_text(&input, llm, &config).await.unwrap();
4417        assert!(result.summaries.is_empty());
4418        // All chunks (both DLT and non-DLT) are still passed through.
4419        assert_eq!(result.chunks.len(), 2);
4420    }
4421
4422    /// Regression guard: an image document must produce ≥1 chunk and must NOT
4423    /// return `CognifyError::UnsupportedDocumentType`.
4424    #[cfg(feature = "image-loader")]
4425    #[tokio::test]
4426    async fn test_image_document_produces_chunks() {
4427        use cognee_ingestion::loaders::image::ImageLoader;
4428        use cognee_test_utils::MockLlm;
4429
4430        let storage = Arc::new(MockStorage::new());
4431        // Store fake image bytes so the loader can retrieve them.
4432        let location = storage
4433            .store(b"fake-image-bytes", "test.jpg")
4434            .await
4435            .expect("MockStorage store should succeed");
4436
4437        let doc_id = Uuid::new_v4();
4438        let mut base = DataPoint::new("ImageDocument", None);
4439        base.id = doc_id;
4440        base.set_metadata("index_fields", serde_json::json!(["name"]));
4441        let doc = Document {
4442            base,
4443            document_type: "image".to_string(),
4444            name: "test.jpg".to_string(),
4445            raw_data_location: location,
4446            mime_type: "image/jpeg".to_string(),
4447            extension: "jpg".to_string(),
4448            data_id: doc_id,
4449            external_metadata: None,
4450        };
4451
4452        let input = ClassifiedDocuments {
4453            documents: vec![doc],
4454            dataset_id: Uuid::new_v4(),
4455            user_id: None,
4456            tenant_id: None,
4457        };
4458
4459        // Build a registry that contains an ImageLoader backed by a MockLlm
4460        // that returns a vision description.
4461        let mock_llm = Arc::new(
4462            MockLlm::new(vec![])
4463                .with_vision_responses(vec!["An image description for testing.".to_string()]),
4464        );
4465        let mut registry = LoaderRegistry::default();
4466        registry.register("image", Arc::new(ImageLoader::new(mock_llm)));
4467
4468        let result = extract_chunks_from_documents(
4469            &input,
4470            &*storage,
4471            100,
4472            TokenCounterKind::Word,
4473            None,
4474            &registry,
4475        )
4476        .await;
4477
4478        // Must not be UnsupportedDocumentType — that is the regression we guard.
4479        assert!(
4480            !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
4481            "image document must not produce UnsupportedDocumentType"
4482        );
4483        let chunks = result.expect("extract_chunks_from_documents should succeed for image docs");
4484        assert!(
4485            !chunks.chunks.is_empty(),
4486            "image document should produce at least one chunk"
4487        );
4488    }
4489
4490    /// Regression guard: an audio document must produce ≥1 chunk and must NOT
4491    /// return `CognifyError::UnsupportedDocumentType`.
4492    #[cfg(feature = "audio-loader")]
4493    #[tokio::test]
4494    async fn test_audio_document_produces_chunks() {
4495        use cognee_ingestion::loaders::audio::AudioLoader;
4496        use cognee_llm::TranscriptionOutput;
4497        use cognee_test_utils::MockTranscriber;
4498
4499        let storage = Arc::new(MockStorage::new());
4500        // Store fake audio bytes so the loader can retrieve them.
4501        let location = storage
4502            .store(b"fake-audio-bytes", "test.mp3")
4503            .await
4504            .expect("MockStorage store should succeed");
4505
4506        let doc_id = Uuid::new_v4();
4507        let mut base = DataPoint::new("AudioDocument", None);
4508        base.id = doc_id;
4509        base.set_metadata("index_fields", serde_json::json!(["name"]));
4510        let doc = Document {
4511            base,
4512            document_type: "audio".to_string(),
4513            name: "test.mp3".to_string(),
4514            raw_data_location: location,
4515            mime_type: "audio/mpeg".to_string(),
4516            extension: "mp3".to_string(),
4517            data_id: doc_id,
4518            external_metadata: None,
4519        };
4520
4521        let input = ClassifiedDocuments {
4522            documents: vec![doc],
4523            dataset_id: Uuid::new_v4(),
4524            user_id: None,
4525            tenant_id: None,
4526        };
4527
4528        // Build a registry that contains an AudioLoader backed by a MockTranscriber.
4529        let mock_transcriber = Arc::new(MockTranscriber::new(
4530            "mock-whisper",
4531            vec![TranscriptionOutput {
4532                text: "Test transcript.".to_string(),
4533                language: None,
4534                duration: None,
4535            }],
4536        ));
4537        let mut registry = LoaderRegistry::default();
4538        registry.register("audio", Arc::new(AudioLoader::new(mock_transcriber)));
4539
4540        let result = extract_chunks_from_documents(
4541            &input,
4542            &*storage,
4543            100,
4544            TokenCounterKind::Word,
4545            None,
4546            &registry,
4547        )
4548        .await;
4549
4550        // Must not be UnsupportedDocumentType — that is the regression we guard.
4551        assert!(
4552            !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
4553            "audio document must not produce UnsupportedDocumentType"
4554        );
4555        let chunks = result.expect("extract_chunks_from_documents should succeed for audio docs");
4556        assert!(
4557            !chunks.chunks.is_empty(),
4558            "audio document should produce at least one chunk"
4559        );
4560    }
4561
4562    /// Regression guard: `.html`/`.htm` files must be classified (not silently
4563    /// dropped).  Before the `html-loader` feature was added,
4564    /// `extension_to_doc_type("html")` returned `None` so `classify_documents`
4565    /// produced an empty Vec — this test would have failed then.
4566    #[test]
4567    fn classify_html_extension_not_dropped() {
4568        for ext in ["html", "htm"] {
4569            let data = Data::builder(
4570                Uuid::new_v4(),
4571                format!("page.{ext}"),
4572                format!("/storage/page.{ext}"),
4573                format!("file:///page.{ext}"),
4574                ext,
4575                "text/html",
4576                "hash_html",
4577                Uuid::new_v4(),
4578            )
4579            .build();
4580
4581            let input = CognifyInput {
4582                data_items: vec![data],
4583                dataset_id: Uuid::new_v4(),
4584                user_id: None,
4585                tenant_id: None,
4586            };
4587            let result = classify_documents(&input).expect("classify should not error");
4588            assert_eq!(
4589                result.documents.len(),
4590                1,
4591                ".{ext} file must not be dropped by classify_documents"
4592            );
4593            assert_eq!(
4594                result.documents[0].document_type, "html",
4595                ".{ext} must classify as document_type=\"html\""
4596            );
4597            // Cross-SDK parity: Python's BeautifulSoupLoader stores TextDocument nodes.
4598            assert_eq!(
4599                result.documents[0].base.data_type, "TextDocument",
4600                ".{ext} must carry data_type=\"TextDocument\" for Python DB parity"
4601            );
4602        }
4603    }
4604
4605    /// Regression guard: the classify → load → chunk pipeline for an HTML file
4606    /// must produce text chunks (not an `UnsupportedDocumentType` error).
4607    ///
4608    /// Before this feature:
4609    ///  1. `classify_documents` would return an empty Vec for `.html` files
4610    ///     (extension was not mapped).
4611    ///  2. Even if the document type was forced to "html", `extract_chunks_from_documents`
4612    ///     would return `CognifyError::UnsupportedDocumentType("html")` because no
4613    ///     loader was registered.
4614    /// Both regressions are guarded here end-to-end.
4615    #[cfg(feature = "html-loader")]
4616    #[tokio::test]
4617    async fn classify_then_chunk_html_end_to_end() {
4618        let storage = Arc::new(MockStorage::new());
4619        let html = b"<html><head><title>Guide</title></head><body><p>The quick brown fox.</p></body></html>";
4620        let location = storage
4621            .store(html, "guide.html")
4622            .await
4623            .expect("MockStorage store should succeed");
4624
4625        let data = Data::builder(
4626            Uuid::new_v4(),
4627            "guide.html",
4628            &location, // raw_data_location == storage path so retrieve() can find it
4629            "file:///guide.html",
4630            "html",
4631            "text/html",
4632            "hash_guide_html",
4633            Uuid::new_v4(),
4634        )
4635        .build();
4636
4637        let input = CognifyInput {
4638            data_items: vec![data],
4639            dataset_id: Uuid::new_v4(),
4640            user_id: None,
4641            tenant_id: None,
4642        };
4643
4644        // Regression 1: classify must not drop the HTML file.
4645        let classified =
4646            classify_documents(&input).expect("classify_documents must succeed for html");
4647        assert_eq!(
4648            classified.documents.len(),
4649            1,
4650            "classify_documents must not drop the .html file"
4651        );
4652        assert_eq!(classified.documents[0].document_type, "html");
4653
4654        // Regression 2: the HtmlLoader must be dispatched and produce chunks.
4655        let registry = LoaderRegistry::default();
4656        let result = extract_chunks_from_documents(
4657            &classified,
4658            &*storage,
4659            100,
4660            TokenCounterKind::Word,
4661            None,
4662            &registry,
4663        )
4664        .await;
4665
4666        assert!(
4667            !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
4668            "html loader must be registered (UnsupportedDocumentType must not occur)"
4669        );
4670        let chunks = result.expect("extract_chunks_from_documents must succeed for html");
4671        assert!(
4672            !chunks.chunks.is_empty(),
4673            "html file must produce at least one chunk"
4674        );
4675        assert!(
4676            chunks
4677                .chunks
4678                .iter()
4679                .any(|c| c.text.contains("quick brown fox")),
4680            "extracted text must appear in chunks (HTML tags must be stripped)"
4681        );
4682    }
4683
/// Regression guard for chunk extraction from HTML documents.
///
/// Stores a small HTML page in `MockStorage`, wraps it in a hand-built
/// `Document` whose `document_type` is `"html"`, and runs
/// `extract_chunks_from_documents` against `LoaderRegistry::default()`.
/// The call must NOT fail with `CognifyError::UnsupportedDocumentType`, must
/// yield at least one chunk, and some chunk must contain the page's body text.
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_html_document_produces_chunks() {
    let storage = Arc::new(MockStorage::new());
    let page_bytes =
        b"<html><head><title>T</title></head><body><h1>Heading</h1><p>Body text here.</p></body></html>";
    let raw_data_location = storage
        .store(page_bytes, "test.html")
        .await
        .expect("MockStorage store should succeed");

    // One UUID is used for both the data point id and the document's `data_id`.
    let shared_id = Uuid::new_v4();

    // Cross-SDK parity: HTML docs carry the TextDocument data_type.
    let mut data_point = DataPoint::new("TextDocument", None);
    data_point.id = shared_id;
    data_point.set_metadata("index_fields", serde_json::json!(["name"]));

    let classified = ClassifiedDocuments {
        documents: vec![Document {
            base: data_point,
            document_type: "html".to_string(),
            name: "test.html".to_string(),
            raw_data_location,
            mime_type: "text/html".to_string(),
            extension: "html".to_string(),
            data_id: shared_id,
            external_metadata: None,
        }],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };

    // With the feature on, the default registry is expected to carry the HtmlLoader.
    let loaders = LoaderRegistry::default();

    let outcome = extract_chunks_from_documents(
        &classified,
        &*storage,
        100,
        TokenCounterKind::Word,
        None,
        &loaders,
    )
    .await;

    // Check the specific failure mode first so a missing loader is reported clearly.
    let is_unsupported = matches!(outcome, Err(CognifyError::UnsupportedDocumentType(_)));
    assert!(
        !is_unsupported,
        "html document must not produce UnsupportedDocumentType"
    );

    let extracted = outcome.expect("extract_chunks_from_documents should succeed for html docs");
    let produced = &extracted.chunks;
    assert!(
        !produced.is_empty(),
        "html document should produce at least one chunk"
    );

    // The extracted text (not raw HTML tags) should reach the chunk.
    let mentions_body = produced
        .iter()
        .any(|chunk| chunk.text.contains("Body text"));
    assert!(mentions_body, "extracted HTML text should appear in chunks");
}
4749
4750    // ── build_loader_registry wiring tests ────────────────────────────────────
4751
/// With the `image-loader` feature enabled, the registry returned by
/// `build_loader_registry` for a default `CognifyConfig` must resolve the
/// `"image"` key.
#[cfg(feature = "image-loader")]
#[test]
fn test_build_loader_registry_includes_image() {
    use cognee_test_utils::MockLlm;

    let default_config = CognifyConfig::default();
    let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());

    let loaders = build_loader_registry(&llm, &default_config);
    let image_loader = loaders.get("image");

    assert!(
        image_loader.is_some(),
        "build_loader_registry must include \"image\" loader when image-loader feature is on"
    );
}
4767
/// With the `audio-loader` feature enabled and a transcriber attached to the
/// config via `with_transcriber`, the registry returned by
/// `build_loader_registry` must resolve the `"audio"` key.
#[cfg(feature = "audio-loader")]
#[test]
fn test_build_loader_registry_includes_audio_when_transcriber_set() {
    use cognee_llm::TranscriptionOutput;
    use cognee_test_utils::{MockLlm, MockTranscriber};

    let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());

    // Single canned transcription handed to the mock transcriber.
    let canned = vec![TranscriptionOutput {
        text: "hi".to_string(),
        language: None,
        duration: None,
    }];
    let transcriber: Arc<dyn cognee_llm::Transcriber> =
        Arc::new(MockTranscriber::new("mock", canned));

    let config_with_transcriber = CognifyConfig::default().with_transcriber(transcriber);
    let loaders = build_loader_registry(&llm, &config_with_transcriber);
    let audio_loader = loaders.get("audio");

    assert!(
        audio_loader.is_some(),
        "build_loader_registry must include \"audio\" loader when transcriber is set"
    );
}
4792
/// A default `CognifyConfig` carries no transcriber, so the registry returned
/// by `build_loader_registry` must not resolve the `"audio"` key — audio stays
/// gracefully unsupported (D5).
#[cfg(feature = "audio-loader")]
#[test]
fn test_build_loader_registry_no_audio_without_transcriber() {
    // Default config: no transcriber attached.
    let bare_config = CognifyConfig::default();
    let llm: Arc<dyn Llm> = Arc::new(cognee_test_utils::MockLlm::empty());

    let loaders = build_loader_registry(&llm, &bare_config);
    let audio_loader = loaders.get("audio");

    assert!(
        audio_loader.is_none(),
        "build_loader_registry must NOT include \"audio\" loader when transcriber is None"
    );
}
4806}