use std::borrow::Cow;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use chrono::Utc;
use cognee_chunking::{CutType, NAMESPACE_OID, TokenCounterKind, chunk_by_row, chunk_text};
use cognee_core::pipeline_run_registry::DbPipelineWatcher;
use cognee_core::{
CpuPool, Pipeline, PipelineBuilder, PipelineContext, TaskContextBuilder, TypedTask, Value,
};
use cognee_database::{DatabaseConnection, PipelineRunRepository};
use cognee_embedding::engine::EmbeddingEngine;
use cognee_graph::{EdgeData, GraphDBTrait, GraphDBTraitExt};
#[cfg(feature = "audio-loader")]
use cognee_ingestion::loaders::audio::AudioLoader;
#[cfg(feature = "image-loader")]
use cognee_ingestion::loaders::image::ImageLoader;
use cognee_ingestion::loaders::{LoaderOutput, LoaderRegistry};
use cognee_llm::Llm;
use cognee_models::{
Data, Document, DocumentChunk, EdgeType, Embedding, TemporalEvent,
classify_documents as model_classify_documents,
};
use cognee_ontology::OntologyResolver;
use cognee_storage::StorageTrait;
use cognee_vector::{VectorDB, VectorPoint};
use serde::Serialize;
use serde_json::json;
use tokio::sync::Semaphore;
use tracing::{info, warn};
use url::Url;
use uuid::Uuid;
use crate::config::CognifyConfig;
use crate::error::CognifyError;
use crate::fact_extraction::{FactExtractor, KnowledgeGraph};
use crate::graph_integration::{
GraphEdgePair, GraphNodePair, deduplicate_nodes_and_edges, expand_with_nodes_and_edges,
retrieve_existing_edges,
};
use crate::pipeline::{CognifyResult, IndexedFieldsStats};
use crate::qualification::{Qualification, check_pipeline_run_qualification};
use crate::summarization::{SummaryExtractor, TextSummary};
use crate::temporal_extraction::{TemporalEntityEnricher, TemporalEventExtractor};
use cognee_models::DataPoint;
/// Entry payload of the cognify pipeline: the raw data items of one dataset plus the
/// ownership identifiers that every later stage carries forward unchanged.
#[derive(Debug, Clone)]
pub struct CognifyInput {
/// Raw data items, turned into documents by [`classify_documents`].
pub data_items: Vec<Data>,
/// Dataset the data items belong to.
pub dataset_id: Uuid,
/// Owning user, when known.
pub user_id: Option<Uuid>,
/// Owning tenant, when known.
pub tenant_id: Option<Uuid>,
}
/// Output of [`classify_documents`]: one [`Document`] per classified data item, plus the
/// dataset/user/tenant identifiers copied from the [`CognifyInput`].
#[derive(Debug, Clone)]
pub struct ClassifiedDocuments {
/// Documents to be chunked by [`extract_chunks_from_documents`].
pub documents: Vec<Document>,
pub dataset_id: Uuid,
pub user_id: Option<Uuid>,
pub tenant_id: Option<Uuid>,
}
/// Output of [`extract_chunks_from_documents`]: the chunks of every document together
/// with the documents they came from.
#[derive(Debug, Clone)]
pub struct ExtractedChunks {
/// All chunks, in document order; each chunk's `document_id` points into `documents`.
pub chunks: Vec<DocumentChunk>,
/// The classified documents the chunks were produced from.
pub documents: Vec<Document>,
pub dataset_id: Uuid,
pub user_id: Option<Uuid>,
pub tenant_id: Option<Uuid>,
}
/// Output of graph extraction.
///
/// From [`extract_graph_from_data`], `entities`/`edges` hold the deduplicated nodes and
/// edges that were written to the graph store, and each chunk's `contains` lists the ids of
/// entities extracted from it. From [`extract_custom_graph_from_data`], `entities`/`edges`
/// are empty and each processed chunk's `contains` holds the serialized custom model.
#[derive(Debug, Clone)]
pub struct ExtractedGraphData {
pub chunks: Vec<DocumentChunk>,
pub documents: Vec<Document>,
pub entities: Vec<GraphNodePair>,
pub edges: Vec<GraphEdgePair>,
pub dataset_id: Uuid,
pub user_id: Option<Uuid>,
pub tenant_id: Option<Uuid>,
}
/// Output of [`summarize_text`]: the graph-extraction result plus text summaries.
#[derive(Debug, Clone)]
pub struct SummarizedData {
pub chunks: Vec<DocumentChunk>,
pub documents: Vec<Document>,
pub entities: Vec<GraphNodePair>,
pub edges: Vec<GraphEdgePair>,
/// Summaries of non-`dlt_row` chunks; empty when summarization is disabled in the config.
pub summaries: Vec<TextSummary>,
pub dataset_id: Uuid,
pub user_id: Option<Uuid>,
pub tenant_id: Option<Uuid>,
}
/// Output of [`extract_temporal_events`]: enriched events, persisted later by
/// [`add_temporal_data_points`].
#[derive(Debug, Clone)]
pub struct ExtractedTemporalEvents {
pub events: Vec<TemporalEvent>,
pub dataset_id: Uuid,
pub user_id: Option<Uuid>,
pub tenant_id: Option<Uuid>,
}
pub fn classify_documents(input: &CognifyInput) -> Result<ClassifiedDocuments, CognifyError> {
let documents: Vec<Document> = model_classify_documents(&input.data_items);
info!(doc_count = documents.len(), "documents classified");
Ok(ClassifiedDocuments {
documents,
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
})
}
pub async fn extract_chunks_from_documents(
input: &ClassifiedDocuments,
storage: &dyn StorageTrait,
max_chunk_size: usize,
token_counter_kind: TokenCounterKind,
db: Option<&DatabaseConnection>,
loader_registry: &LoaderRegistry,
) -> Result<ExtractedChunks, CognifyError> {
let counter = token_counter_kind
.build()
.map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
let mut all_chunks = Vec::new();
for document in &input.documents {
let content_bytes = storage
.retrieve(&document.raw_data_location)
.await
.map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
if document.document_type == "dlt_row" {
let text = String::from_utf8(content_bytes)
.map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
let trimmed = text.trim();
if !trimmed.is_empty() {
let chunk_id =
Uuid::new_v5(&NAMESPACE_OID, format!("{}-0", document.base.id).as_bytes());
let word_count = counter.count_tokens(trimmed);
let mut chunk = DocumentChunk::new(
chunk_id,
trimmed.to_string(),
word_count,
0, CutType::DltRow.to_string(),
document.base.id,
);
if document.base.belongs_to_set.is_some() {
chunk.base.belongs_to_set = document.base.belongs_to_set.clone();
}
if let Some(db) = db
&& let Err(e) = cognee_database::ops::data::update_data_token_count(
db,
document.data_id,
word_count as i64,
)
.await
{
warn!(data_id = %document.data_id, "Failed to update token count: {e}");
}
all_chunks.push(chunk);
}
continue;
}
let loader = loader_registry
.get(&document.document_type)
.ok_or_else(|| CognifyError::UnsupportedDocumentType(document.document_type.clone()))?;
let output = loader
.extract(&content_bytes, document)
.await
.map_err(|e| CognifyError::ChunkingError(e.to_string()))?;
let mut chunks = match output {
LoaderOutput::Text(text) => {
chunk_text(document.base.id, &text, max_chunk_size, &counter)
}
LoaderOutput::Rows(rows) => {
let joined = rows.join("\n\n");
chunk_by_row(document.base.id, &joined, max_chunk_size, &counter)
}
LoaderOutput::SingleChunk { text, cut_type } => {
let chunk_id =
Uuid::new_v5(&NAMESPACE_OID, format!("{}-0", document.base.id).as_bytes());
let word_count = counter.count_tokens(&text);
vec![DocumentChunk::new(
chunk_id,
text,
word_count,
0,
cut_type.to_string(),
document.base.id,
)]
}
};
if document.base.belongs_to_set.is_some() {
for chunk in &mut chunks {
chunk.base.belongs_to_set = document.base.belongs_to_set.clone();
}
}
if let Some(db) = db {
let document_token_count: i64 = chunks.iter().map(|c| c.chunk_size as i64).sum();
if let Err(e) = cognee_database::ops::data::update_data_token_count(
db,
document.data_id,
document_token_count,
)
.await
{
warn!(
data_id = %document.data_id,
"Failed to update token count: {e}"
);
}
}
all_chunks.extend(chunks);
}
info!(total_chunks = all_chunks.len(), "chunking complete");
Ok(ExtractedChunks {
chunks: all_chunks,
documents: input.documents.clone(),
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
})
}
pub async fn extract_graph_from_data(
input: &ExtractedChunks,
llm: Arc<dyn Llm>,
graph_db: Arc<dyn GraphDBTrait>,
ontology_resolver: Arc<dyn OntologyResolver>,
config: &CognifyConfig,
user_label_override: Option<&str>,
) -> Result<ExtractedGraphData, CognifyError> {
if input.chunks.is_empty() {
return Ok(ExtractedGraphData {
chunks: input.chunks.clone(),
documents: input.documents.clone(),
entities: vec![],
edges: vec![],
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
});
}
let dlt_doc_ids: HashSet<Uuid> = input
.documents
.iter()
.filter(|d| d.document_type == "dlt_row")
.map(|d| d.base.id)
.collect();
let (dlt_chunks, non_dlt_chunks): (Vec<&DocumentChunk>, Vec<&DocumentChunk>) = input
.chunks
.iter()
.partition(|c| dlt_doc_ids.contains(&c.document_id));
if !dlt_chunks.is_empty() {
info!(
"Skipping {} DLT chunks from LLM extraction ({} non-DLT chunks remain)",
dlt_chunks.len(),
non_dlt_chunks.len()
);
}
if non_dlt_chunks.is_empty() {
return Ok(ExtractedGraphData {
chunks: input.chunks.clone(),
documents: input.documents.clone(),
entities: vec![],
edges: vec![],
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
});
}
let chunks_for_extraction: Vec<DocumentChunk> = non_dlt_chunks.into_iter().cloned().collect();
let batch_size = config.chunks_per_batch;
let mut all_graphs: Vec<(Uuid, KnowledgeGraph)> = Vec::new();
let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions));
for (batch_idx, batch) in chunks_for_extraction.chunks(batch_size).enumerate() {
let fact_extractor = FactExtractor::new(Arc::clone(&llm));
let mut extract_tasks = Vec::new();
let mut chunk_ids = Vec::new();
for chunk in batch {
let extractor = fact_extractor.clone();
let text = chunk.text.clone();
let sem = Arc::clone(&semaphore);
let prompt = config.custom_extraction_prompt.clone();
chunk_ids.push(chunk.base.id);
extract_tasks.push(tokio::spawn(async move {
#[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
let _permit = sem
.acquire()
.await
.expect("semaphore is never closed; created locally in this function");
extractor.extract_facts(&text, prompt.as_deref()).await
}));
}
let batch_results = futures::future::join_all(extract_tasks).await;
for (result, chunk_id) in batch_results.into_iter().zip(chunk_ids) {
let graph = result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
all_graphs.push((chunk_id, graph));
}
info!(
"Processed graph extraction batch {}/{} ({} chunks)",
batch_idx + 1,
chunks_for_extraction.len().div_ceil(batch_size),
batch.len()
);
}
let graphs_only: Vec<KnowledgeGraph> = all_graphs.iter().map(|(_, g)| g.clone()).collect();
let existing_edges_set = retrieve_existing_edges(graph_db.as_ref(), &graphs_only).await?;
let user_label_owned = user_label_override
.map(|s| s.to_string())
.or_else(|| input.user_id.as_ref().map(|id| id.to_string()));
let (nodes, edges) = expand_with_nodes_and_edges(
all_graphs,
input.dataset_id,
&existing_edges_set,
ontology_resolver.as_ref(),
user_label_owned.as_deref(),
)
.await;
let dedup_result = deduplicate_nodes_and_edges(nodes, edges);
let mut chunk_entity_map: HashMap<Uuid, Vec<serde_json::Value>> = HashMap::new();
for node_pair in &dedup_result.unique_nodes {
if let Some(chunk_id_val) = node_pair.entity.base.get_metadata("chunk_id")
&& let Some(chunk_id_str) = chunk_id_val.as_str()
&& let Ok(chunk_id) = Uuid::parse_str(chunk_id_str)
{
chunk_entity_map
.entry(chunk_id)
.or_default()
.push(json!(node_pair.entity.base.id.to_string()));
}
}
let mut updated_chunks = input.chunks.clone();
for chunk in &mut updated_chunks {
if let Some(entity_ids) = chunk_entity_map.get(&chunk.base.id) {
chunk.contains = entity_ids.clone();
}
}
let entity_refs: Vec<&cognee_models::Entity> = dedup_result
.unique_nodes
.iter()
.map(|n| &n.entity)
.collect();
graph_db
.add_nodes(&entity_refs)
.await
.map_err(CognifyError::from)?;
let edge_data: Vec<_> = dedup_result
.unique_edges
.iter()
.map(|edge_pair| {
let properties: HashMap<std::borrow::Cow<'static, str>, serde_json::Value> = edge_pair
.properties
.iter()
.map(|(k, v)| {
(
std::borrow::Cow::Owned(k.clone()),
serde_json::Value::String(v.clone()),
)
})
.collect();
(
edge_pair.source_entity_id.to_string(),
edge_pair.target_entity_id.to_string(),
edge_pair.relationship_name.clone(),
properties,
)
})
.collect();
graph_db
.add_edges(&edge_data)
.await
.map_err(CognifyError::from)?;
Ok(ExtractedGraphData {
chunks: updated_chunks,
documents: input.documents.clone(),
entities: dedup_result.unique_nodes,
edges: dedup_result.unique_edges,
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
})
}
/// Web-page facts parsed from a document's `external_metadata` by
/// [`parse_web_page_metadata`].
#[derive(Debug, Clone, PartialEq, Eq)]
struct WebPageMetadata {
/// Absolute `http`/`https` URL, as re-serialized by `Url::parse`.
url: String,
/// Lower-cased host name of `url`.
domain: String,
/// Page title, when the metadata carries a non-empty one.
title: Option<String>,
}
/// Extracts [`WebPageMetadata`] from a document that was ingested from a URL.
///
/// `document.external_metadata` must be a JSON object whose `"source"` is `"url"`.
///
/// - **Address:** taken from `"final_url"`. It falls back to `"url"` when `"final_url"` is
///   absent, `null`, not a string, or empty.
/// - **Validation:** only absolute `http`/`https` URLs with a host are accepted.
/// - **Normalization:** the host is lower-cased, and an empty `"title"` is treated as
///   missing.
///
/// Returns `None` for any other document, including ones with malformed metadata.
fn parse_web_page_metadata(document: &Document) -> Option<WebPageMetadata> {
    let metadata = document.external_metadata.as_ref()?;
    let value: serde_json::Value = serde_json::from_str(metadata).ok()?;
    if value.get("source").and_then(|v| v.as_str())? != "url" {
        return None;
    }
    // `final_url` can be present yet unusable (`null`, non-string, empty). Fall back to
    // `url` in that case instead of dropping the page altogether.
    let url = ["final_url", "url"].into_iter().find_map(|key| {
        value
            .get(key)
            .and_then(|v| v.as_str())
            .filter(|s| !s.is_empty())
    })?;
    let parsed = Url::parse(url).ok()?;
    if !matches!(parsed.scheme(), "http" | "https") {
        return None;
    }
    let domain = parsed.host_str()?.to_ascii_lowercase();
    let title = value
        .get("title")
        .and_then(|v| v.as_str())
        .filter(|s| !s.is_empty())
        .map(str::to_string);
    Some(WebPageMetadata {
        url: parsed.to_string(),
        domain,
        title,
    })
}
/// Deterministic graph-node id for a web page: UUIDv5 (OID namespace) of `"WebPage:<url>"`.
fn web_page_id(url: &str) -> Uuid {
    let name = format!("WebPage:{url}");
    Uuid::new_v5(&Uuid::NAMESPACE_OID, name.as_bytes())
}
/// Deterministic graph-node id for a web site: UUIDv5 (OID namespace) of
/// `"WebSite:<domain>"`, with the domain lower-cased so casing variants map to one node.
fn web_site_id(domain: &str) -> Uuid {
    let normalized = domain.to_ascii_lowercase();
    let name = format!("WebSite:{normalized}");
    Uuid::new_v5(&Uuid::NAMESPACE_OID, name.as_bytes())
}
/// Returns the first `limit` characters of `value`.
///
/// Characters are counted as Unicode scalar values, so a multi-byte character is never
/// split. The whole string is returned when it is shorter than `limit`.
fn first_chars(value: &str, limit: usize) -> String {
    match value.char_indices().nth(limit) {
        // Byte offset of the first character past the limit: everything before it stays.
        Some((cut, _)) => value[..cut].to_string(),
        None => value.to_string(),
    }
}
/// Builds a preview of at most 500 characters from the text of `document_id`'s chunks.
///
/// Matching chunk texts are appended in slice order. A `'\n'` separator is inserted only
/// once the preview is non-empty. Collection stops as soon as the preview reaches 500
/// characters, and the result is truncated to that many characters (Unicode scalar values).
fn document_content_preview(document_id: Uuid, chunks: &[DocumentChunk]) -> String {
    const LIMIT: usize = 500;
    let mut joined = String::new();
    // Running character count of `joined`, so the buffer is not rescanned per chunk.
    let mut joined_chars = 0usize;
    for chunk in chunks.iter().filter(|c| c.document_id == document_id) {
        if joined_chars > 0 {
            joined.push('\n');
            joined_chars += 1;
        }
        joined.push_str(&chunk.text);
        joined_chars += chunk.text.chars().count();
        if joined_chars >= LIMIT {
            break;
        }
    }
    first_chars(&joined, LIMIT)
}
/// Property map for edges that carry no attributes.
fn empty_edge_props() -> HashMap<Cow<'static, str>, serde_json::Value> {
    HashMap::default()
}
/// Links web-sourced documents into the graph as `WebPage` / `WebSite` nodes.
///
/// For every document whose metadata parses via [`parse_web_page_metadata`]:
/// - a `WebPage` node is written, holding the URL, the title, and a content preview built
///   from the document's chunks;
/// - a `WebSite` node is written for the page's domain;
/// - edges `page -PART_OF-> site` and, for each of the document's chunks,
///   `chunk -SOURCED_FROM-> page` are queued.
///
/// Node ids come from [`web_page_id`] / [`web_site_id`]. Queued edges that `graph_db`
/// reports as already existing are not added again. Nothing happens when `documents` or
/// `chunks` is empty.
///
/// # Errors
/// Propagates graph-store failures as [`CognifyError`].
pub async fn create_web_page_nodes(
    documents: &[Document],
    chunks: &[DocumentChunk],
    graph_db: Arc<dyn GraphDBTrait>,
) -> Result<(), CognifyError> {
    if documents.is_empty() || chunks.is_empty() {
        return Ok(());
    }

    let mut nodes_by_id: HashMap<String, serde_json::Value> = HashMap::new();
    let mut candidate_edges: Vec<EdgeData> = Vec::new();
    let mut seen_edges: HashSet<(String, String, String)> = HashSet::new();

    let web_documents = documents
        .iter()
        .filter_map(|doc| parse_web_page_metadata(doc).map(|meta| (doc, meta)));
    for (document, metadata) in web_documents {
        let page_key = web_page_id(&metadata.url).to_string();
        let site_key = web_site_id(&metadata.domain).to_string();

        let page_node = json!({
            "id": page_key,
            "type": "WebPage",
            "url": metadata.url,
            "title": metadata.title,
            "content": document_content_preview(document.base.id, chunks),
        });
        let site_node = json!({
            "id": site_key,
            "type": "WebSite",
            "domain": metadata.domain,
        });
        nodes_by_id.insert(page_key.clone(), page_node);
        nodes_by_id.insert(site_key.clone(), site_node);

        push_unique_edge(
            &mut candidate_edges,
            &mut seen_edges,
            page_key.clone(),
            site_key,
            "PART_OF",
        );
        let own_chunks = chunks
            .iter()
            .filter(|chunk| chunk.document_id == document.base.id);
        for chunk in own_chunks {
            push_unique_edge(
                &mut candidate_edges,
                &mut seen_edges,
                chunk.base.id.to_string(),
                page_key.clone(),
                "SOURCED_FROM",
            );
        }
    }

    if !nodes_by_id.is_empty() {
        let nodes: Vec<serde_json::Value> = nodes_by_id.into_values().collect();
        graph_db
            .add_nodes_raw(nodes)
            .await
            .map_err(CognifyError::from)?;
    }
    if candidate_edges.is_empty() {
        return Ok(());
    }

    let existing: HashSet<(String, String, String)> = graph_db
        .has_edges(&candidate_edges)
        .await
        .map_err(CognifyError::from)?
        .into_iter()
        .map(|(source, target, relationship, _)| (source, target, relationship))
        .collect();
    let mut missing_edges = candidate_edges;
    missing_edges.retain(|(source, target, relationship, _)| {
        !existing.contains(&(source.clone(), target.clone(), relationship.clone()))
    });
    if missing_edges.is_empty() {
        return Ok(());
    }
    graph_db
        .add_edges(&missing_edges)
        .await
        .map_err(CognifyError::from)?;
    Ok(())
}
/// Queues an attribute-less `(source, target, relationship)` edge unless that exact triple
/// is already recorded in `seen`.
///
/// `seen` is updated as a side effect, so later calls with the same triple are no-ops.
fn push_unique_edge(
    edges: &mut Vec<EdgeData>,
    seen: &mut HashSet<(String, String, String)>,
    source: String,
    target: String,
    relationship: &str,
) {
    let relationship = relationship.to_owned();
    let is_new = seen.insert((source.clone(), target.clone(), relationship.clone()));
    if is_new {
        edges.push((source, target, relationship, empty_edge_props()));
    }
}
pub async fn extract_custom_graph_from_data<M: crate::fact_extraction::GraphModel>(
input: &ExtractedChunks,
llm: Arc<dyn Llm>,
config: &CognifyConfig,
) -> Result<ExtractedGraphData, CognifyError> {
if input.chunks.is_empty() {
return Ok(ExtractedGraphData {
chunks: input.chunks.clone(),
documents: input.documents.clone(),
entities: vec![],
edges: vec![],
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
});
}
let dlt_doc_ids: HashSet<Uuid> = input
.documents
.iter()
.filter(|d| d.document_type == "dlt_row")
.map(|d| d.base.id)
.collect();
let batch_size = config.chunks_per_batch;
let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions));
let mut updated_chunks = input.chunks.clone();
let non_dlt_indices: Vec<usize> = updated_chunks
.iter()
.enumerate()
.filter(|(_, c)| !dlt_doc_ids.contains(&c.document_id))
.map(|(i, _)| i)
.collect();
if non_dlt_indices.is_empty() {
return Ok(ExtractedGraphData {
chunks: updated_chunks,
documents: input.documents.clone(),
entities: vec![],
edges: vec![],
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
});
}
let total_batches = non_dlt_indices.len().div_ceil(batch_size);
for (batch_idx, batch_indices) in non_dlt_indices.chunks(batch_size).enumerate() {
let mut extract_tasks = Vec::new();
for &idx in batch_indices {
let extractor = FactExtractor::new(Arc::clone(&llm));
let text = updated_chunks[idx].text.clone();
let sem = Arc::clone(&semaphore);
let prompt = config.custom_extraction_prompt.clone();
extract_tasks.push(tokio::spawn(async move {
#[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
let _permit = sem
.acquire()
.await
.expect("semaphore is never closed; created locally in this function");
extractor.extract::<M>(&text, prompt.as_deref()).await
}));
}
let batch_results = futures::future::join_all(extract_tasks).await;
let batch_len = batch_indices.len();
for (i, result) in batch_results.into_iter().enumerate() {
let model: M =
result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
let value = serde_json::to_value(&model)
.map_err(|e| CognifyError::SerializationError(e.to_string()))?;
updated_chunks[batch_indices[i]].contains = vec![value];
}
info!(
"Processed custom graph extraction batch {}/{} ({} chunks)",
batch_idx + 1,
total_batches,
batch_len
);
}
Ok(ExtractedGraphData {
chunks: updated_chunks,
documents: input.documents.clone(),
entities: vec![],
edges: vec![],
dataset_id: input.dataset_id,
user_id: input.user_id,
tenant_id: input.tenant_id,
})
}
/// Generates text summaries for every non-`dlt_row` chunk.
///
/// When `config.enable_summarization` is off, or no eligible chunks exist, no LLM calls are
/// made and `summaries` is empty. Otherwise chunks are summarized in batches of
/// `config.summarization_batch_size`; a value of 0 is treated as 1. Chunks, documents,
/// entities and edges are passed through unchanged.
///
/// # Errors
/// Propagates summarization failures.
pub async fn summarize_text(
    input: &ExtractedGraphData,
    llm: Arc<dyn Llm>,
    config: &CognifyConfig,
) -> Result<SummarizedData, CognifyError> {
    let dlt_doc_ids: HashSet<Uuid> = input
        .documents
        .iter()
        .filter(|d| d.document_type == "dlt_row")
        .map(|d| d.base.id)
        .collect();
    let non_dlt_chunks: Vec<DocumentChunk> = input
        .chunks
        .iter()
        .filter(|c| !dlt_doc_ids.contains(&c.document_id))
        .cloned()
        .collect();
    if non_dlt_chunks.len() < input.chunks.len() {
        info!(
            "Skipping {} DLT chunks from summarization ({} non-DLT chunks remain)",
            input.chunks.len() - non_dlt_chunks.len(),
            non_dlt_chunks.len()
        );
    }

    let summaries = if !config.enable_summarization {
        info!("Summarization disabled in config");
        Vec::new()
    } else if non_dlt_chunks.is_empty() {
        info!("No non-DLT chunks to summarize");
        Vec::new()
    } else {
        let summary_extractor =
            SummaryExtractor::new_with_schema(llm, config.summary_schema.clone());
        // `slice::chunks(0)` panics; treat a zero batch size as one chunk per call.
        let batch_size = config.summarization_batch_size.max(1);
        let mut all_summaries = Vec::new();
        for batch in non_dlt_chunks.chunks(batch_size) {
            let batch_summaries = summary_extractor.summarize_chunks(batch, None).await?;
            all_summaries.extend(batch_summaries);
        }
        info!("Generated {} summaries", all_summaries.len());
        all_summaries
    };

    Ok(SummarizedData {
        chunks: input.chunks.clone(),
        documents: input.documents.clone(),
        entities: input.entities.clone(),
        edges: input.edges.clone(),
        summaries,
        dataset_id: input.dataset_id,
        user_id: input.user_id,
        tenant_id: input.tenant_id,
    })
}
pub async fn add_data_points(
input: &SummarizedData,
graph_db: Arc<dyn GraphDBTrait>,
vector_db: Arc<dyn VectorDB>,
embedding_engine: Arc<dyn EmbeddingEngine>,
db: Option<Arc<DatabaseConnection>>,
config: &CognifyConfig,
) -> Result<CognifyResult, CognifyError> {
if !input.chunks.is_empty() {
let chunk_refs: Vec<&DocumentChunk> = input.chunks.iter().collect();
graph_db
.add_nodes(&chunk_refs)
.await
.map_err(CognifyError::from)?;
info!("Stored {} document chunks as graph nodes", chunk_refs.len());
}
if !input.summaries.is_empty() {
let summary_refs: Vec<&TextSummary> = input.summaries.iter().collect();
graph_db
.add_nodes(&summary_refs)
.await
.map_err(CognifyError::from)?;
info!(
"Stored {} text summaries as graph nodes",
summary_refs.len()
);
}
if !input.entities.is_empty() {
let entity_type_refs: Vec<&cognee_models::EntityType> = input
.entities
.iter()
.map(|pair| &pair.entity_type)
.collect();
graph_db
.add_nodes(&entity_type_refs)
.await
.map_err(CognifyError::from)?;
info!(
"Stored {} entity types as graph nodes",
entity_type_refs.len()
);
}
if !input.documents.is_empty() {
let doc_refs: Vec<&Document> = input.documents.iter().collect();
graph_db
.add_nodes(&doc_refs)
.await
.map_err(CognifyError::from)?;
info!("Stored {} documents as graph nodes", doc_refs.len());
}
let mut edge_type_counts: HashMap<String, i32> = HashMap::new();
for edge_pair in &input.edges {
let edge_text = edge_retrieval_text(edge_pair);
if edge_text.is_empty() {
continue;
}
*edge_type_counts.entry(edge_text).or_insert(0) += 1;
}
let mut edge_types: Vec<EdgeType> = edge_type_counts
.into_iter()
.map(|(text, count)| {
let mut et = EdgeType::new_deterministic(&text, Some(input.dataset_id));
et.set_count(count);
et
})
.collect();
{
let user_label = input.user_id.as_ref().map(|id| id.to_string());
let mut local_visited: HashSet<Uuid> = HashSet::new();
for et in &mut edge_types {
crate::graph_integration::expansion::pre_stamp_extraction(
et,
user_label.as_deref(),
&mut local_visited,
);
}
}
let mut extractable_items: Vec<&dyn crate::graph_extraction::GraphExtractable> = Vec::new();
for chunk in &input.chunks {
extractable_items.push(chunk as &dyn crate::graph_extraction::GraphExtractable);
}
for summary in &input.summaries {
extractable_items.push(summary as &dyn crate::graph_extraction::GraphExtractable);
}
for pair in &input.entities {
extractable_items.push(&pair.entity as &dyn crate::graph_extraction::GraphExtractable);
extractable_items.push(&pair.entity_type as &dyn crate::graph_extraction::GraphExtractable);
}
let structural_edges = crate::graph_extraction::get_graph_from_model(&extractable_items);
if !structural_edges.is_empty() {
graph_db
.add_edges(&structural_edges)
.await
.map_err(CognifyError::from)?;
info!("Created {} structural edges", structural_edges.len());
}
let embeddings = generate_embeddings(
&input.chunks,
&input.entities,
&input.summaries,
embedding_engine.clone(),
)
.await?;
let indexed_fields = index_data_points(
&input.chunks,
&input.entities,
&input.summaries,
&input.documents,
&input.edges,
&edge_types,
input.dataset_id,
input.user_id,
input.tenant_id,
embedding_engine,
vector_db,
config,
&embeddings,
)
.await?;
if let (Some(db), Some(user_id)) = (&db, input.user_id) {
upsert_provenance(
db,
input.tenant_id,
user_id,
input.dataset_id,
&input.chunks,
&input.entities,
&input.edges,
&input.summaries,
&input.documents,
&structural_edges,
)
.await?;
}
Ok(CognifyResult {
chunks: input.chunks.clone(),
entities: input.entities.clone(),
edges: input.edges.clone(),
summaries: input.summaries.clone(),
edge_types,
embeddings,
indexed_fields,
documents_for_dlt: input.documents.clone(),
already_completed: false,
prior_pipeline_run_id: None,
})
}
/// Extracts temporal events from every non-`dlt_row` chunk and enriches them.
///
/// Chunks are processed in batches of `config.data_per_batch`. Each chunk in a batch is
/// handled by its own spawned task, with at most `config.max_parallel_extractions`
/// extractions running concurrently; both settings are treated as at least 1. After each
/// batch completes, its events are passed through the [`TemporalEntityEnricher`] before
/// being collected.
///
/// Returns no events, without touching the LLM, when there are no eligible chunks.
///
/// # Errors
/// Propagates extraction and enrichment errors. A panicked or cancelled extraction task is
/// reported as [`CognifyError::FactExtractionError`].
pub async fn extract_temporal_events(
    input: &ExtractedChunks,
    llm: Arc<dyn Llm>,
    config: &CognifyConfig,
) -> Result<ExtractedTemporalEvents, CognifyError> {
    let dlt_doc_ids: HashSet<Uuid> = input
        .documents
        .iter()
        .filter(|d| d.document_type == "dlt_row")
        .map(|d| d.base.id)
        .collect();
    let non_dlt_chunks: Vec<&DocumentChunk> = input
        .chunks
        .iter()
        .filter(|c| !dlt_doc_ids.contains(&c.document_id))
        .collect();
    // Also covers an empty `input.chunks`.
    if non_dlt_chunks.is_empty() {
        return Ok(ExtractedTemporalEvents {
            events: vec![],
            dataset_id: input.dataset_id,
            user_id: input.user_id,
            tenant_id: input.tenant_id,
        });
    }

    // `slice::chunks(0)` panics, and a zero-permit semaphore never grants a permit (every
    // task would wait forever). Clamp both knobs to at least 1.
    let batch_size = config.data_per_batch.max(1);
    let total_batches = non_dlt_chunks.len().div_ceil(batch_size);
    let semaphore = Arc::new(Semaphore::new(config.max_parallel_extractions.max(1)));
    let extractor = Arc::new(TemporalEventExtractor::new(Arc::clone(&llm)));
    let enricher = TemporalEntityEnricher::new(Arc::clone(&llm));
    let mut all_events: Vec<TemporalEvent> = Vec::new();

    for (batch_idx, batch) in non_dlt_chunks.chunks(batch_size).enumerate() {
        let mut extract_tasks = Vec::with_capacity(batch.len());
        for chunk in batch {
            let ext = Arc::clone(&extractor);
            let text = chunk.text.clone();
            let sem = Arc::clone(&semaphore);
            extract_tasks.push(tokio::spawn(async move {
                #[allow(clippy::expect_used, reason = "invariant is upheld by construction")]
                let _permit = sem
                    .acquire()
                    .await
                    .expect("semaphore is never closed; created locally in this function");
                ext.extract_events(&text).await
            }));
        }
        let mut batch_events: Vec<TemporalEvent> = Vec::new();
        for result in futures::future::join_all(extract_tasks).await {
            let events = result.map_err(|e| CognifyError::FactExtractionError(e.to_string()))??;
            batch_events.extend(events);
        }
        info!(
            "Temporal extraction batch {}/{}: {} events extracted",
            batch_idx + 1,
            total_batches,
            batch_events.len()
        );
        let enriched = enricher.enrich(batch_events).await?;
        all_events.extend(enriched);
    }

    info!(
        "Temporal event extraction complete: {} total events",
        all_events.len()
    );
    Ok(ExtractedTemporalEvents {
        events: all_events,
        dataset_id: input.dataset_id,
        user_id: input.user_id,
        tenant_id: input.tenant_id,
    })
}
pub async fn add_temporal_data_points(
events: &ExtractedTemporalEvents,
graph_db: Arc<dyn GraphDBTrait>,
vector_db: Arc<dyn VectorDB>,
embedding_engine: Arc<dyn EmbeddingEngine>,
) -> Result<CognifyResult, CognifyError> {
if events.events.is_empty() {
info!("No temporal events to persist.");
return Ok(CognifyResult::empty());
}
let mut graph_nodes: Vec<serde_json::Value> = Vec::new();
let mut graph_edges: Vec<EdgeData> = Vec::new();
let mut seen_entity_ids: HashSet<Uuid> = HashSet::new();
let mut seen_edge_keys: HashSet<(String, String, String)> = HashSet::new();
let mut event_ids: Vec<Uuid> = Vec::new();
let mut event_names: Vec<String> = Vec::new();
for event in &events.events {
let event_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("event:{}", event.name).as_bytes(),
);
event_ids.push(event_id);
event_names.push(event.name.clone());
let mut event_node = json!({
"id": event_id.to_string(),
"data_type": "Event",
"name": event.name,
});
if let Some(desc) = &event.description {
event_node["description"] = json!(desc);
}
if let Some(loc) = &event.location {
event_node["location"] = json!(loc);
}
graph_nodes.push(event_node);
if let Some(ts) = &event.at {
let ts_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("timestamp:{}", ts.time_at).as_bytes(),
);
graph_nodes.push(json!({
"id": ts_id.to_string(),
"data_type": "Timestamp",
"time_at": ts.time_at,
"timestamp_str": ts.timestamp_str,
"year": ts.year,
"month": ts.month,
"day": ts.day,
"hour": ts.hour,
"minute": ts.minute,
"second": ts.second,
}));
let edge_key = (event_id.to_string(), ts_id.to_string(), "at".to_string());
if seen_edge_keys.insert(edge_key) {
graph_edges.push((
event_id.to_string(),
ts_id.to_string(),
"at".to_string(),
build_edge_props(&event_id.to_string(), &ts_id.to_string(), "at"),
));
}
}
if let Some(interval) = &event.during {
let ts_from = &interval.time_from;
let ts_to = &interval.time_to;
let ts_from_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("timestamp:{}", ts_from.time_at).as_bytes(),
);
let ts_to_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("timestamp:{}", ts_to.time_at).as_bytes(),
);
let interval_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("interval:{}:{}", ts_from.time_at, ts_to.time_at).as_bytes(),
);
graph_nodes.push(json!({
"id": ts_from_id.to_string(),
"data_type": "Timestamp",
"time_at": ts_from.time_at,
"timestamp_str": ts_from.timestamp_str,
"year": ts_from.year,
"month": ts_from.month,
"day": ts_from.day,
"hour": ts_from.hour,
"minute": ts_from.minute,
"second": ts_from.second,
}));
graph_nodes.push(json!({
"id": ts_to_id.to_string(),
"data_type": "Timestamp",
"time_at": ts_to.time_at,
"timestamp_str": ts_to.timestamp_str,
"year": ts_to.year,
"month": ts_to.month,
"day": ts_to.day,
"hour": ts_to.hour,
"minute": ts_to.minute,
"second": ts_to.second,
}));
graph_nodes.push(json!({
"id": interval_id.to_string(),
"data_type": "Interval",
}));
let during_key = (
event_id.to_string(),
interval_id.to_string(),
"during".to_string(),
);
if seen_edge_keys.insert(during_key) {
graph_edges.push((
event_id.to_string(),
interval_id.to_string(),
"during".to_string(),
build_edge_props(&event_id.to_string(), &interval_id.to_string(), "during"),
));
}
let from_key = (
interval_id.to_string(),
ts_from_id.to_string(),
"time_from".to_string(),
);
if seen_edge_keys.insert(from_key) {
graph_edges.push((
interval_id.to_string(),
ts_from_id.to_string(),
"time_from".to_string(),
build_edge_props(
&interval_id.to_string(),
&ts_from_id.to_string(),
"time_from",
),
));
}
let to_key = (
interval_id.to_string(),
ts_to_id.to_string(),
"time_to".to_string(),
);
if seen_edge_keys.insert(to_key) {
graph_edges.push((
interval_id.to_string(),
ts_to_id.to_string(),
"time_to".to_string(),
build_edge_props(&interval_id.to_string(), &ts_to_id.to_string(), "time_to"),
));
}
}
for attr in &event.attributes {
let entity_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("entity:{}", attr.entity).as_bytes(),
);
if seen_entity_ids.insert(entity_id) {
graph_nodes.push(json!({
"id": entity_id.to_string(),
"data_type": attr.entity_type,
"name": attr.entity,
}));
}
let rel_key = (
event_id.to_string(),
entity_id.to_string(),
attr.relationship.clone(),
);
if seen_edge_keys.insert(rel_key) {
graph_edges.push((
event_id.to_string(),
entity_id.to_string(),
attr.relationship.clone(),
build_edge_props(
&event_id.to_string(),
&entity_id.to_string(),
&attr.relationship,
),
));
}
}
}
if !graph_nodes.is_empty() {
let node_count = graph_nodes.len();
graph_db
.add_nodes_raw(graph_nodes)
.await
.map_err(CognifyError::from)?;
info!("Stored {} temporal graph nodes", node_count);
}
if !graph_edges.is_empty() {
let edge_count = graph_edges.len();
graph_db
.add_edges(&graph_edges)
.await
.map_err(CognifyError::from)?;
info!("Stored {} temporal graph edges", edge_count);
}
let mut indexed_fields = IndexedFieldsStats::default();
if !event_ids.is_empty() {
let dimension = embedding_engine.dimension();
if !vector_db
.has_collection("Event", "name")
.await
.map_err(|e| CognifyError::VectorDBError(e.to_string()))?
{
vector_db
.create_collection("Event", "name", dimension)
.await
.map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
}
let name_strs: Vec<&str> = event_names.iter().map(String::as_str).collect();
let vectors = embedding_engine
.embed(&name_strs)
.await
.map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
let points: Vec<VectorPoint> = event_ids
.iter()
.zip(event_names.iter())
.zip(vectors.iter())
.map(|((id, name), vector)| {
let mut point = VectorPoint::new(*id, vector.clone())
.with_metadata("type", json!("Event"))
.with_metadata("field", json!("name"))
.with_metadata("name", json!(name))
.with_metadata("dataset_id", json!(events.dataset_id.to_string()));
if let Some(uid) = events.user_id {
point = point.with_metadata("user_id", json!(uid.to_string()));
}
if let Some(tid) = events.tenant_id {
point = point.with_metadata("tenant_id", json!(tid.to_string()));
}
point
})
.collect();
vector_db
.index_points("Event", "name", &points)
.await
.map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
indexed_fields.record("Event", "name", event_ids.len());
info!("Indexed {} event names in vector DB", event_ids.len());
}
Ok(CognifyResult {
chunks: vec![],
entities: vec![],
edges: vec![],
summaries: vec![],
edge_types: vec![],
embeddings: vec![],
indexed_fields,
documents_for_dlt: vec![],
already_completed: false,
prior_pipeline_run_id: None,
})
}
/// Text an edge is grouped by when building [`EdgeType`]s.
///
/// Returns the trimmed `edge_text` property when it is non-blank. Otherwise returns the
/// trimmed relationship name, which may itself be empty.
fn edge_retrieval_text(edge_pair: &GraphEdgePair) -> String {
    let explicit = edge_pair.properties.get("edge_text").map(|text| text.trim());
    match explicit {
        Some(text) if !text.is_empty() => text.to_owned(),
        _ => edge_pair.relationship_name.trim().to_owned(),
    }
}
/// Builds the base property map of a graph edge: `source_node_id`,
/// `target_node_id` and `relationship_name`, each stored as a JSON string.
fn build_edge_props(
    source_id: &str,
    target_id: &str,
    relationship_name: &str,
) -> HashMap<std::borrow::Cow<'static, str>, serde_json::Value> {
    let mut props = HashMap::with_capacity(3);
    props.insert(Cow::Borrowed("source_node_id"), json!(source_id));
    props.insert(Cow::Borrowed("target_node_id"), json!(target_id));
    props.insert(Cow::Borrowed("relationship_name"), json!(relationship_name));
    props
}

/// Writes DLT relational structure (tables, foreign keys, row links) into the
/// graph for the `dlt_row` documents in `documents`.
///
/// Only documents whose `document_type` is `"dlt_row"` and whose
/// `external_metadata` parses as JSON with `"source": "dlt"` are used. From
/// that metadata this function:
///
/// * adds one `SchemaTable` node per distinct non-empty `table_name`, built
///   from the first document seen for that table;
/// * adds one `SchemaRelationship` node per distinct foreign-key definition
///   (`column`, `ref_table`, `ref_column`) whose column and referenced table
///   are non-empty, linked as
///   `table -has_foreign_key-> relationship -references_table-> table`
///   whenever the corresponding table node exists;
/// * adds an `is_row_of` edge from each row document to its table node; and
/// * adds one edge per distinct `fk_references` entry, from the row document
///   to its `target_data_id`, labelled with the entry's `relationship_name`
///   (default `"references"`).
///
/// Table and relationship node ids are UUIDv5 values of `"dlt:{name}"`, so
/// re-running on the same input yields the same ids. `_chunks` is unused.
///
/// # Errors
/// Returns a [`CognifyError`] if adding the nodes or edges to `graph_db` fails.
pub async fn extract_dlt_fk_edges(
    _chunks: &[DocumentChunk],
    documents: &[Document],
    graph_db: Arc<dyn GraphDBTrait>,
) -> Result<(), CognifyError> {
    /// Reads `key` from a JSON object as a string, defaulting to `""`.
    fn str_field<'a>(value: &'a serde_json::Value, key: &str) -> &'a str {
        value.get(key).and_then(serde_json::Value::as_str).unwrap_or("")
    }

    let dlt_docs: Vec<&Document> = documents
        .iter()
        .filter(|d| d.document_type == "dlt_row")
        .collect();
    if dlt_docs.is_empty() {
        return Ok(());
    }
    info!(
        "Processing {} DLT documents for FK edge extraction",
        dlt_docs.len()
    );

    // Pass 1: keep the parsed DLT metadata of every row document and the
    // schema description of the first row seen for each table.
    let mut tables_seen: HashMap<String, DltTableMeta> = HashMap::new();
    let mut dlt_doc_meta: HashMap<Uuid, serde_json::Value> = HashMap::new();
    for doc in &dlt_docs {
        let ext_metadata = match &doc.external_metadata {
            Some(m) => match serde_json::from_str::<serde_json::Value>(m) {
                Ok(v) if v.get("source").and_then(|s| s.as_str()) == Some("dlt") => v,
                _ => continue,
            },
            None => continue,
        };
        let table_name = str_field(&ext_metadata, "table_name");
        if !table_name.is_empty() {
            tables_seen
                .entry(table_name.to_string())
                .or_insert_with(|| DltTableMeta {
                    schema_info: ext_metadata.get("schema_info").cloned(),
                    foreign_keys: ext_metadata
                        .get("foreign_keys")
                        .and_then(|v| v.as_array())
                        .cloned()
                        .unwrap_or_default(),
                    dlt_db_name: str_field(&ext_metadata, "dlt_db_name").to_string(),
                });
        }
        dlt_doc_meta.insert(doc.base.id, ext_metadata);
    }
    if dlt_doc_meta.is_empty() {
        return Ok(());
    }

    let mut all_edges: Vec<cognee_graph::EdgeData> = Vec::new();
    let mut table_node_ids: HashMap<String, Uuid> = HashMap::new();
    let mut schema_nodes: Vec<serde_json::Value> = Vec::new();

    // Pass 2: one SchemaTable node per table.
    for (table_name, table_meta) in &tables_seen {
        let id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{table_name}").as_bytes());
        table_node_ids.insert(table_name.clone(), id);
        let columns_str = table_meta
            .schema_info
            .as_ref()
            .map(|v| v.to_string())
            .unwrap_or_else(|| "[]".to_string());
        let fk_str =
            serde_json::to_string(&table_meta.foreign_keys).unwrap_or_else(|_| "[]".to_string());
        let table_node = SchemaTableNode {
            id: id.to_string(),
            name: table_name.clone(),
            columns: columns_str,
            primary_key: None,
            foreign_keys: fk_str,
            sample_rows: "[]".to_string(),
            row_count_estimate: None,
            description: format!(
                "DLT-ingested relational table '{}' from database '{}'.",
                table_name, table_meta.dlt_db_name
            ),
            data_type: "SchemaTable".to_string(),
        };
        if let Ok(val) = serde_json::to_value(&table_node) {
            schema_nodes.push(val);
        }
    }

    // Pass 3: one SchemaRelationship node per distinct FK definition, wired
    // between the owning table node and the referenced table node.
    let mut fk_defs_seen: HashSet<(String, String, String, String)> = HashSet::new();
    for (table_name, table_meta) in &tables_seen {
        for fk in &table_meta.foreign_keys {
            let fk_col = str_field(fk, "column");
            let ref_table = str_field(fk, "ref_table");
            let ref_col = str_field(fk, "ref_column");
            if fk_col.is_empty() || ref_table.is_empty() {
                continue;
            }
            let fk_key = (
                table_name.clone(),
                fk_col.to_string(),
                ref_table.to_string(),
                ref_col.to_string(),
            );
            // `insert` returns false when the definition was already handled.
            if !fk_defs_seen.insert(fk_key) {
                continue;
            }
            let rel_name = format!("{table_name}:{fk_col}->{ref_table}:{ref_col}");
            let rel_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{rel_name}").as_bytes());
            let rel_id_str = rel_id.to_string();
            let rel_node = SchemaRelationshipNode {
                id: rel_id_str.clone(),
                name: rel_name,
                source_table: table_name.clone(),
                target_table: ref_table.to_string(),
                relationship_type: "foreign_key".to_string(),
                source_column: fk_col.to_string(),
                target_column: ref_col.to_string(),
                description: format!("Foreign key: {table_name}.{fk_col} -> {ref_table}.{ref_col}"),
                data_type: "SchemaRelationship".to_string(),
            };
            if let Ok(val) = serde_json::to_value(&rel_node) {
                schema_nodes.push(val);
            }
            if let Some(source_table_id) = table_node_ids.get(table_name.as_str()) {
                let source = source_table_id.to_string();
                let props = build_edge_props(&source, &rel_id_str, "has_foreign_key");
                all_edges.push((
                    source,
                    rel_id_str.clone(),
                    "has_foreign_key".to_string(),
                    props,
                ));
            }
            if let Some(target_table_id) = table_node_ids.get(ref_table) {
                let target = target_table_id.to_string();
                let props = build_edge_props(&rel_id_str, &target, "references_table");
                all_edges.push((
                    rel_id_str.clone(),
                    target,
                    "references_table".to_string(),
                    props,
                ));
            }
        }
    }

    // Pass 4: link every row document to its table node and to the rows it
    // references through foreign keys.
    let mut seen_row_edges: HashSet<(String, String, String)> = HashSet::new();
    for (doc_id, ext_metadata) in &dlt_doc_meta {
        let doc_id_str = doc_id.to_string();
        let table_name = str_field(ext_metadata, "table_name");
        if let Some(table_node_id) = table_node_ids.get(table_name) {
            let table_id_str = table_node_id.to_string();
            let props = build_edge_props(&doc_id_str, &table_id_str, "is_row_of");
            all_edges.push((
                doc_id_str.clone(),
                table_id_str,
                "is_row_of".to_string(),
                props,
            ));
        }
        let fk_references: &[serde_json::Value] = ext_metadata
            .get("fk_references")
            .and_then(|v| v.as_array())
            .map(Vec::as_slice)
            .unwrap_or_default();
        for fk_ref in fk_references {
            let target_data_id = match fk_ref.get("target_data_id").and_then(|v| v.as_str()) {
                Some(id) => id.to_string(),
                None => continue,
            };
            let relationship_name = fk_ref
                .get("relationship_name")
                .and_then(|v| v.as_str())
                .unwrap_or("references")
                .to_string();
            let edge_key = (
                doc_id_str.clone(),
                target_data_id.clone(),
                relationship_name.clone(),
            );
            if !seen_row_edges.insert(edge_key) {
                continue;
            }
            let mut props = build_edge_props(&doc_id_str, &target_data_id, &relationship_name);
            props.insert(
                Cow::Borrowed("edge_text"),
                json!(relationship_name.replace('_', " ")),
            );
            props.insert(Cow::Borrowed("source_table"), json!(table_name));
            props.insert(
                Cow::Borrowed("target_table"),
                json!(str_field(fk_ref, "target_table")),
            );
            props.insert(Cow::Borrowed("fk_column"), json!(str_field(fk_ref, "column")));
            all_edges.push((doc_id_str.clone(), target_data_id, relationship_name, props));
        }
    }

    if !schema_nodes.is_empty() {
        let node_count = schema_nodes.len();
        graph_db
            .add_nodes_raw(schema_nodes)
            .await
            .map_err(CognifyError::from)?;
        info!("Added {} DLT schema nodes to graph", node_count);
    }
    if !all_edges.is_empty() {
        graph_db
            .add_edges(&all_edges)
            .await
            .map_err(CognifyError::from)?;
        info!(
            "Added {} DLT FK edges to graph ({} tables, {} FK definitions)",
            all_edges.len(),
            table_node_ids.len(),
            fk_defs_seen.len()
        );
    }
    Ok(())
}
/// Graph node describing one DLT-ingested table.
///
/// `extract_dlt_fk_edges` serialises it with `serde_json::to_value` and
/// writes it with `add_nodes_raw`. Field order is the declared serialisation
/// order, so do not reorder fields casually.
#[derive(Debug, Serialize)]
struct SchemaTableNode {
    /// Node id as a string; `extract_dlt_fk_edges` uses UUIDv5 of `"dlt:{table_name}"`.
    id: String,
    /// Table name.
    name: String,
    /// JSON text of the table's `schema_info` metadata, or `"[]"` when absent.
    columns: String,
    /// Primary-key column; `extract_dlt_fk_edges` always leaves this `None`.
    primary_key: Option<String>,
    /// JSON-encoded array of the table's foreign-key definitions.
    foreign_keys: String,
    /// JSON-encoded sample rows; `extract_dlt_fk_edges` always sets `"[]"`.
    sample_rows: String,
    /// Estimated row count; `extract_dlt_fk_edges` always leaves this `None`.
    row_count_estimate: Option<i64>,
    /// Human-readable description of the table.
    description: String,
    /// Node type tag; `extract_dlt_fk_edges` sets `"SchemaTable"`.
    data_type: String,
}
/// Graph node describing one foreign-key relationship between DLT tables.
///
/// `extract_dlt_fk_edges` serialises it with `serde_json::to_value` and
/// writes it with `add_nodes_raw`. Field order is the declared serialisation
/// order.
#[derive(Debug, Serialize)]
struct SchemaRelationshipNode {
    /// Node id as a string; UUIDv5 of `"dlt:{name}"` in `extract_dlt_fk_edges`.
    id: String,
    /// `"{source_table}:{source_column}->{target_table}:{target_column}"`.
    name: String,
    /// Table that owns the foreign-key column.
    source_table: String,
    /// Table referenced by the foreign key.
    target_table: String,
    /// Relationship kind; `extract_dlt_fk_edges` sets `"foreign_key"`.
    relationship_type: String,
    /// Foreign-key column in `source_table`.
    source_column: String,
    /// Referenced column in `target_table` (may be empty).
    target_column: String,
    /// Human-readable description of the relationship.
    description: String,
    /// Node type tag; `extract_dlt_fk_edges` sets `"SchemaRelationship"`.
    data_type: String,
}
/// Per-table metadata collected from the first DLT row document seen for a
/// table in `extract_dlt_fk_edges`.
#[derive(Debug)]
struct DltTableMeta {
    /// Raw `schema_info` value from the document's external metadata, if any.
    schema_info: Option<serde_json::Value>,
    /// Entries of the `foreign_keys` array (empty when missing or not an array).
    foreign_keys: Vec<serde_json::Value>,
    /// `dlt_db_name` from the metadata, or `""` when missing.
    dlt_db_name: String,
}
/// Fills the provenance fields of `dp` that are still unset, never
/// overwriting an existing value.
///
/// `source_pipeline` and `source_task` receive `pipeline` / `task`;
/// an unset `source_user` receives `user`, which may itself be `None`.
fn stamp_provenance(dp: &mut DataPoint, pipeline: &str, task: &str, user: Option<&str>) {
    dp.source_pipeline
        .get_or_insert_with(|| pipeline.to_string());
    dp.source_task.get_or_insert_with(|| task.to_string());
    dp.source_user = dp.source_user.take().or_else(|| user.map(String::from));
}
/// Runs the cognify pipeline over `data_items` for one dataset.
///
/// Steps:
/// 1. Validates `config`. If `max_chunk_size` equals the
///    `CognifyConfig::default()` value, it is replaced by the value computed
///    by `with_auto_chunk_size` for `embedding_engine` and `llm`.
/// 2. Checks earlier runs of this pipeline for the dataset (`"temporal-cognify"`
///    when `temporal_cognify` is set, otherwise `"cognify"`). A completed run
///    short-circuits with [`CognifyResult::already_completed`]. A run still in
///    progress yields [`CognifyError::PipelineAlreadyRunning`].
/// 3. Returns [`CognifyResult::empty`] when `data_items` is empty.
/// 4. Builds the temporal or standard pipeline, executes it under a
///    `DbPipelineWatcher`, and extracts the [`CognifyResult`].
/// 5. For non-temporal runs, adds DLT schema / foreign-key structure to the
///    graph via [`extract_dlt_fk_edges`].
///
/// # Errors
/// - `ConfigError` for an invalid config.
/// - `DatabaseError` if the run-qualification lookup fails.
/// - `PipelineAlreadyRunning` if a run is already in progress.
/// - `ContextBuild`, `Execute` or `OutputTypeMismatch` from pipeline execution.
/// - Any error from `extract_dlt_fk_edges`.
// Every collaborator is injected explicitly, hence the long argument list.
#[allow(clippy::too_many_arguments)]
pub async fn cognify(
    data_items: Vec<Data>,
    dataset_id: Uuid,
    user_id: Option<Uuid>,
    user_email: Option<String>,
    tenant_id: Option<Uuid>,
    llm: Arc<dyn Llm>,
    storage: Arc<dyn StorageTrait>,
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    database: Arc<DatabaseConnection>,
    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
    thread_pool: Arc<dyn CpuPool>,
    ontology_resolver: Arc<dyn OntologyResolver>,
    config: &CognifyConfig,
) -> Result<CognifyResult, CognifyError> {
    config
        .validate()
        .map_err(|e| CognifyError::ConfigError(e.to_string()))?;
    // A max_chunk_size left at its default is treated as "not chosen" and is
    // recomputed by `with_auto_chunk_size`.
    let effective_config = if config.max_chunk_size == CognifyConfig::default().max_chunk_size {
        let cfg = config
            .clone()
            .with_auto_chunk_size(embedding_engine.as_ref(), llm.as_ref());
        info!("Auto-calculated max_chunk_size: {}", cfg.max_chunk_size);
        cfg
    } else {
        config.clone()
    };
    info!(
        "Starting cognify pipeline with config: chunks_per_batch={}, max_chunk_size={}",
        effective_config.chunks_per_batch, effective_config.max_chunk_size
    );
    let is_temporal = effective_config.temporal_cognify;
    let pipeline_name: &str = if is_temporal {
        "temporal-cognify"
    } else {
        "cognify"
    };
    match check_pipeline_run_qualification(pipeline_run_repo.as_ref(), dataset_id, pipeline_name)
        .await
        .map_err(|e| CognifyError::DatabaseError(e.to_string()))?
    {
        Qualification::AlreadyCompleted(prior) => {
            info!(
                dataset_id = %dataset_id,
                pipeline_run_id = %prior.pipeline_run_id,
                "cognify: dataset already completed; short-circuiting (Python parity)"
            );
            return Ok(CognifyResult::already_completed(prior.pipeline_run_id));
        }
        Qualification::AlreadyRunning(_prior) => {
            return Err(CognifyError::PipelineAlreadyRunning {
                pipeline_name: pipeline_name.to_string(),
                dataset_id,
            });
        }
        Qualification::Proceed => {}
    }
    if data_items.is_empty() {
        return Ok(CognifyResult::empty());
    }
    let pipeline = if is_temporal {
        build_temporal_cognify_pipeline(
            Arc::clone(&storage),
            Arc::clone(&graph_db),
            Arc::clone(&vector_db),
            Arc::clone(&embedding_engine),
            Arc::clone(&llm),
            Some(Arc::clone(&database)),
            effective_config.clone(),
        )
    } else {
        build_cognify_pipeline(
            Arc::clone(&storage),
            Arc::clone(&graph_db),
            Arc::clone(&vector_db),
            Arc::clone(&embedding_engine),
            Arc::clone(&llm),
            Some(Arc::clone(&database)),
            Arc::clone(&ontology_resolver),
            effective_config.clone(),
        )
    };
    let pipeline_ctx = PipelineContext {
        pipeline_id: pipeline.id,
        pipeline_name: pipeline.name.clone().unwrap_or_default(),
        user_id,
        tenant_id,
        dataset_id: Some(dataset_id),
        current_data: None,
        run_id: None,
        user_email,
        provenance_visited: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
    };
    // `_cancel_handle` (unlike `_`) keeps the handle alive until this function
    // returns. Whether dropping it cancels the run is up to
    // TaskContextBuilder (not visible here).
    let (_cancel_handle, ctx) = TaskContextBuilder::new()
        .thread_pool(thread_pool)
        .database(Arc::clone(&database))
        .graph_db(Arc::clone(&graph_db))
        .vector_db(Arc::clone(&vector_db))
        .pipeline_context(pipeline_ctx)
        .build()
        .map_err(|e| CognifyError::ContextBuild(e.to_string()))?;
    let ctx = Arc::new(ctx);
    let input = CognifyInput {
        data_items,
        dataset_id,
        user_id,
        tenant_id,
    };
    let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(input) as Arc<dyn Value>];
    let watcher = DbPipelineWatcher::new(pipeline_run_repo);
    let outputs = cognee_core::pipeline::execute(&pipeline, inputs, ctx, &watcher)
        .await
        .map_err(|e| CognifyError::Execute(e.to_string()))?;
    let result = extract_cognify_outputs(outputs)?;
    if !is_temporal {
        extract_dlt_fk_edges(
            &result.chunks,
            &result.documents_for_dlt,
            Arc::clone(&graph_db),
        )
        .await?;
    }
    Ok(result)
}
/// Extracts the [`CognifyResult`] from the pipeline's outputs.
///
/// Only the first output is inspected, and it is cloned out of its `Arc`.
///
/// # Errors
/// [`CognifyError::OutputTypeMismatch`] with `actual: "empty"` when there
/// are no outputs, or `actual: "unknown"` when the first output is not a
/// `CognifyResult`.
fn extract_cognify_outputs(outputs: Vec<Arc<dyn Value>>) -> Result<CognifyResult, CognifyError> {
    match outputs.into_iter().next() {
        None => Err(CognifyError::OutputTypeMismatch {
            expected: "CognifyResult",
            actual: "empty",
        }),
        Some(first) => {
            // Deref explicitly so `as_any` is called on the inner value, not on
            // the `Arc` wrapper.
            let downcast: Option<&CognifyResult> = (*first).as_any().downcast_ref();
            match downcast {
                Some(result) => Ok(result.clone()),
                None => Err(CognifyError::OutputTypeMismatch {
                    expected: "CognifyResult",
                    actual: "unknown",
                }),
            }
        }
    }
}
/// Deterministic id of a provenance node row.
///
/// UUIDv5 in the OID namespace, over the concatenation of the tenant id
/// (the literal `"None"` when absent) and the user, dataset, data and node ids.
fn provenance_node_id(
    tenant_id: Option<Uuid>,
    user_id: Uuid,
    dataset_id: Uuid,
    data_id: Uuid,
    node_id: Uuid,
) -> Uuid {
    let tenant = match tenant_id {
        Some(t) => t.to_string(),
        None => String::from("None"),
    };
    let key = format!("{}{}{}{}{}", tenant, user_id, dataset_id, data_id, node_id);
    Uuid::new_v5(&Uuid::NAMESPACE_OID, key.as_bytes())
}
/// Deterministic id of a provenance edge row.
///
/// UUIDv5 in the OID namespace, over the concatenation of the tenant id
/// (the literal `"None"` when absent), user id, dataset id, source id, edge
/// text and target id.
fn provenance_edge_id(
    tenant_id: Option<Uuid>,
    user_id: Uuid,
    dataset_id: Uuid,
    source_id: Uuid,
    edge_text: &str,
    target_id: Uuid,
) -> Uuid {
    let tenant = tenant_id.map_or_else(|| String::from("None"), |t| t.to_string());
    let key = [
        tenant,
        user_id.to_string(),
        dataset_id.to_string(),
        source_id.to_string(),
        edge_text.to_string(),
        target_id.to_string(),
    ]
    .concat();
    Uuid::new_v5(&Uuid::NAMESPACE_OID, key.as_bytes())
}
/// Slug UUID of a relationship label.
///
/// UUIDv5 in the OID namespace, over the label lower-cased with spaces turned
/// into underscores and apostrophes removed.
fn edge_slug(edge_text: &str) -> Uuid {
    let slug: String = edge_text
        .to_lowercase()
        .chars()
        .filter(|&c| c != '\'')
        .map(|c| if c == ' ' { '_' } else { c })
        .collect();
    Uuid::new_v5(&Uuid::NAMESPACE_OID, slug.as_bytes())
}
/// Slug UUID of a `(source, relationship, target)` triplet.
///
/// UUIDv5 in the OID namespace, over the concatenated source id, relationship
/// name and target id, lower-cased with spaces turned into underscores and
/// apostrophes removed.
fn triplet_slug(source_id: Uuid, relationship_name: &str, target_id: Uuid) -> Uuid {
    let mut raw = source_id.to_string();
    raw.push_str(relationship_name);
    raw.push_str(&target_id.to_string());
    let normalized: String = raw
        .to_lowercase()
        .chars()
        .filter(|&c| c != '\'')
        .map(|c| if c == ' ' { '_' } else { c })
        .collect();
    Uuid::new_v5(&Uuid::NAMESPACE_OID, normalized.as_bytes())
}
/// Mirrors cognify output into the relational provenance tables.
///
/// Builds one [`GraphNode`] row per entity, chunk, summary, distinct entity
/// type and document. Builds one [`GraphEdge`] row per extracted edge and per
/// structural edge. It then upserts the rows with `graph_storage::upsert_nodes`
/// and `graph_storage::upsert_edges`; each call is skipped when there is
/// nothing to write. Row ids come from `provenance_node_id` /
/// `provenance_edge_id`, so they are deterministic for a given tenant, user
/// and dataset.
///
/// `data_id` attribution:
/// * chunks use `document_id`; documents use their own id;
/// * entities use the `document_id` of the chunk named by their `chunk_id`
///   metadata; summaries use that of their `made_from` chunk;
/// * extracted edges use the endpoints' `data_id` when both endpoints
///   resolve to the same one;
/// * entity types, structural edges and anything unattributable use
///   `Uuid::nil()`.
///
/// Structural-edge endpoints that are not valid UUIDs are recorded as
/// `Uuid::nil()`.
///
/// # Errors
/// Propagates failures from the node and edge upserts.
#[allow(clippy::too_many_arguments)]
async fn upsert_provenance(
    db: &DatabaseConnection,
    tenant_id: Option<Uuid>,
    user_id: Uuid,
    dataset_id: Uuid,
    chunks: &[DocumentChunk],
    entities: &[GraphNodePair],
    edges: &[GraphEdgePair],
    summaries: &[TextSummary],
    documents: &[Document],
    structural_edges: &[EdgeData],
) -> Result<(), CognifyError> {
    use cognee_database::ops::graph_storage;
    use cognee_database::{GraphEdge, GraphNode};

    let chunk_data_map: HashMap<Uuid, Uuid> =
        chunks.iter().map(|c| (c.base.id, c.document_id)).collect();
    // Document an entity was extracted from, resolved via its `chunk_id` metadata.
    let entity_source_data = |pair: &GraphNodePair| -> Option<Uuid> {
        pair.entity
            .base
            .get_metadata("chunk_id")
            .and_then(|v| v.as_str())
            .and_then(|s| Uuid::parse_str(s).ok())
            .and_then(|chunk_id| chunk_data_map.get(&chunk_id).copied())
    };
    let entity_data_map: HashMap<Uuid, Uuid> = entities
        .iter()
        .filter_map(|pair| {
            entity_source_data(pair).map(|data_id| (pair.entity.base.id, data_id))
        })
        .collect();

    let mut prov_nodes: Vec<GraphNode> = Vec::new();
    for pair in entities {
        let entity = &pair.entity;
        let data_id = entity_source_data(pair).unwrap_or(Uuid::nil());
        let indexed_fields = entity
            .base
            .get_metadata("index_fields")
            .cloned()
            .unwrap_or(json!(["name"]));
        let label = if entity.name.is_empty() {
            entity.base.id.to_string()
        } else {
            entity.name.clone()
        };
        prov_nodes.push(GraphNode {
            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, entity.base.id),
            slug: entity.base.id,
            user_id,
            data_id,
            dataset_id,
            label: Some(label),
            node_type: entity.base.data_type.clone(),
            indexed_fields,
            attributes: serde_json::to_value(entity).ok(),
            created_at: Utc::now(),
        });
    }
    for chunk in chunks {
        let data_id = chunk.document_id;
        let indexed_fields = chunk
            .base
            .get_metadata("index_fields")
            .cloned()
            .unwrap_or(json!(["text"]));
        prov_nodes.push(GraphNode {
            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, chunk.base.id),
            slug: chunk.base.id,
            user_id,
            data_id,
            dataset_id,
            label: Some(format!("chunk_{}", chunk.chunk_index)),
            node_type: chunk.base.data_type.clone(),
            indexed_fields,
            attributes: serde_json::to_value(chunk).ok(),
            created_at: Utc::now(),
        });
    }
    for summary in summaries {
        let data_id = summary
            .made_from
            .and_then(|chunk_id| chunk_data_map.get(&chunk_id).copied())
            .unwrap_or(Uuid::nil());
        let indexed_fields = summary
            .base
            .get_metadata("index_fields")
            .cloned()
            .unwrap_or(json!(["text"]));
        prov_nodes.push(GraphNode {
            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, summary.base.id),
            slug: summary.base.id,
            user_id,
            data_id,
            dataset_id,
            label: Some(format!("summary_{}", summary.base.id)),
            node_type: summary.base.data_type.clone(),
            indexed_fields,
            attributes: serde_json::to_value(summary).ok(),
            created_at: Utc::now(),
        });
    }
    // Many entities share one entity type. Emit each type once so the batch
    // holds a single row per provenance id; `index_data_points` de-duplicates
    // entity types by id the same way.
    let mut seen_entity_types: HashSet<Uuid> = HashSet::new();
    for pair in entities {
        let et = &pair.entity_type;
        if !seen_entity_types.insert(et.base.id) {
            continue;
        }
        prov_nodes.push(GraphNode {
            id: provenance_node_id(tenant_id, user_id, dataset_id, Uuid::nil(), et.base.id),
            slug: et.base.id,
            user_id,
            data_id: Uuid::nil(),
            dataset_id,
            label: Some(et.name.clone()),
            node_type: et.base.data_type.clone(),
            indexed_fields: et
                .base
                .get_metadata("index_fields")
                .cloned()
                .unwrap_or(json!(["name"])),
            attributes: serde_json::to_value(et).ok(),
            created_at: Utc::now(),
        });
    }
    for document in documents {
        let data_id = document.base.id;
        let indexed_fields = document
            .base
            .get_metadata("index_fields")
            .cloned()
            .unwrap_or(json!(["name"]));
        let label = if document.name.is_empty() {
            document.base.id.to_string()
        } else {
            document.name.clone()
        };
        prov_nodes.push(GraphNode {
            id: provenance_node_id(tenant_id, user_id, dataset_id, data_id, document.base.id),
            slug: document.base.id,
            user_id,
            data_id,
            dataset_id,
            label: Some(label),
            node_type: document.base.data_type.clone(),
            indexed_fields,
            attributes: serde_json::to_value(document).ok(),
            created_at: Utc::now(),
        });
    }
    if !prov_nodes.is_empty() {
        graph_storage::upsert_nodes(db, &prov_nodes).await?;
        info!("Upserted {} provenance node records", prov_nodes.len());
    }

    let mut prov_edges: Vec<GraphEdge> = Vec::new();
    for edge_pair in edges {
        // Only `contains` edges take their stored `edge_text`; all other edges
        // are keyed by relationship name.
        let edge_text = if edge_pair.relationship_name == "contains" {
            edge_pair
                .properties
                .get("edge_text")
                .cloned()
                .unwrap_or_else(|| edge_pair.relationship_name.clone())
        } else {
            edge_pair.relationship_name.clone()
        };
        let source_data_id = entity_data_map.get(&edge_pair.source_entity_id).copied();
        let target_data_id = entity_data_map.get(&edge_pair.target_entity_id).copied();
        let data_id = match (source_data_id, target_data_id) {
            (Some(source), Some(target)) if source == target => source,
            _ => Uuid::nil(),
        };
        prov_edges.push(GraphEdge {
            id: provenance_edge_id(
                tenant_id,
                user_id,
                dataset_id,
                edge_pair.source_entity_id,
                &edge_text,
                edge_pair.target_entity_id,
            ),
            slug: triplet_slug(
                edge_pair.source_entity_id,
                &edge_text,
                edge_pair.target_entity_id,
            ),
            user_id,
            data_id,
            dataset_id,
            source_node_id: edge_pair.source_entity_id,
            destination_node_id: edge_pair.target_entity_id,
            relationship_name: edge_text,
            label: Some(edge_pair.relationship_name.clone()),
            attributes: serde_json::to_value(&edge_pair.properties).ok(),
            created_at: Utc::now(),
        });
    }
    for (source_id_str, target_id_str, rel_name, properties) in structural_edges {
        let source_id = Uuid::parse_str(source_id_str).unwrap_or(Uuid::nil());
        let target_id = Uuid::parse_str(target_id_str).unwrap_or(Uuid::nil());
        let attrs = if properties.is_empty() {
            None
        } else {
            let map: serde_json::Map<String, serde_json::Value> = properties
                .iter()
                .map(|(k, v)| (k.to_string(), v.clone()))
                .collect();
            Some(serde_json::Value::Object(map))
        };
        prov_edges.push(GraphEdge {
            id: provenance_edge_id(
                tenant_id, user_id, dataset_id, source_id, rel_name, target_id,
            ),
            slug: edge_slug(rel_name),
            user_id,
            data_id: Uuid::nil(),
            dataset_id,
            source_node_id: source_id,
            destination_node_id: target_id,
            relationship_name: rel_name.clone(),
            label: None,
            attributes: attrs,
            created_at: Utc::now(),
        });
    }
    if !prov_edges.is_empty() {
        graph_storage::upsert_edges(db, &prov_edges).await?;
        info!("Upserted {} provenance edge records", prov_edges.len());
    }
    Ok(())
}
/// Embeds chunk texts, entity names and summary texts with `engine`.
///
/// Returns one [`Embedding`] per input item, in input order: chunks first,
/// then entities, then summaries. They are tagged `DocumentChunk`/`text`,
/// `Entity`/`name` and `TextSummary`/`text` respectively. An empty group is
/// skipped without calling the engine.
///
/// # Errors
/// [`CognifyError::EmbeddingError`] if the engine fails, or if it returns a
/// different number of vectors than texts requested. Without that check,
/// `zip` would silently drop the unmatched items.
async fn generate_embeddings(
    chunks: &[DocumentChunk],
    entities: &[GraphNodePair],
    summaries: &[TextSummary],
    engine: Arc<dyn EmbeddingEngine>,
) -> Result<Vec<Embedding>, CognifyError> {
    /// Fails when the engine's vector count differs from the request size.
    fn ensure_vector_count(
        kind: &str,
        expected: usize,
        actual: usize,
    ) -> Result<(), CognifyError> {
        if expected == actual {
            Ok(())
        } else {
            Err(CognifyError::EmbeddingError(format!(
                "embedding engine returned {actual} vectors for {expected} {kind} texts"
            )))
        }
    }

    let mut embeddings = Vec::with_capacity(chunks.len() + entities.len() + summaries.len());
    if !chunks.is_empty() {
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        let vectors = engine
            .embed(&texts)
            .await
            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
        ensure_vector_count("DocumentChunk", texts.len(), vectors.len())?;
        embeddings.extend(chunks.iter().zip(vectors).map(|(chunk, vector)| {
            Embedding::new(chunk.base.id, "DocumentChunk", "text", vector)
        }));
    }
    if !entities.is_empty() {
        let names: Vec<&str> = entities.iter().map(|e| e.entity.name.as_str()).collect();
        let vectors = engine
            .embed(&names)
            .await
            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
        ensure_vector_count("Entity", names.len(), vectors.len())?;
        embeddings.extend(entities.iter().zip(vectors).map(|(entity, vector)| {
            Embedding::new(entity.entity.base.id, "Entity", "name", vector)
        }));
    }
    if !summaries.is_empty() {
        let texts: Vec<&str> = summaries.iter().map(|s| s.text.as_str()).collect();
        let vectors = engine
            .embed(&texts)
            .await
            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
        ensure_vector_count("TextSummary", texts.len(), vectors.len())?;
        embeddings.extend(summaries.iter().zip(vectors).map(|(summary, vector)| {
            Embedding::new(summary.base.id, "TextSummary", "text", vector)
        }));
    }
    Ok(embeddings)
}
/// Returns one vector per `ids[i]` / `texts[i]` pair, in the order of `ids`.
///
/// Vectors found in `precomputed` (keyed by id) are reused. The texts of all
/// other ids are embedded with a single engine call, in input order. The
/// engine is not called when every id is precomputed.
///
/// # Errors
/// [`CognifyError::EmbeddingError`] if the engine fails, or if it returns a
/// different number of vectors than texts requested. A surplus or shortfall
/// would otherwise misalign or drop vectors.
async fn reuse_or_embed(
    engine: &Arc<dyn EmbeddingEngine>,
    precomputed: &std::collections::HashMap<Uuid, Vec<f32>>,
    ids: &[Uuid],
    texts: &[&str],
) -> Result<Vec<Vec<f32>>, CognifyError> {
    debug_assert_eq!(ids.len(), texts.len(), "ids and texts must be parallel");
    let missing_texts: Vec<&str> = ids
        .iter()
        .zip(texts)
        .filter(|(id, _)| !precomputed.contains_key(*id))
        .map(|(_, text)| *text)
        .collect();
    let fresh = if missing_texts.is_empty() {
        Vec::new()
    } else {
        engine
            .embed(&missing_texts)
            .await
            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?
    };
    if fresh.len() != missing_texts.len() {
        return Err(CognifyError::EmbeddingError(format!(
            "embedding engine returned {} vectors for {} texts",
            fresh.len(),
            missing_texts.len()
        )));
    }
    let mut fresh = fresh.into_iter();
    ids.iter()
        .map(|id| match precomputed.get(id) {
            Some(vector) => Ok(vector.clone()),
            // Reachable only if `ids` is longer than `texts` (release builds).
            None => fresh
                .next()
                .ok_or_else(|| CognifyError::EmbeddingError("missing fresh embedding".into())),
        })
        .collect()
}
/// Embeds and indexes the cognify outputs in the vector database.
///
/// Each collection is created with `engine.dimension()` when it does not
/// exist yet. The collections written are:
/// * `DocumentChunk`/`text`, `Entity`/`name` and `TextSummary`/`text`.
///   Vectors are reused from `precomputed_embeddings` by data-point id when
///   present, and embedded otherwise.
/// * `EntityType`/`name`, with one point per distinct entity-type id.
/// * `Triplet`/`text`, only when `config.embed_triplets` is set and both
///   `edges` and `entities` are non-empty.
/// * `EdgeType`/`relationship_name`.
/// * One `<base.data_type>`/`name` collection per document type.
///
/// Every non-triplet point carries the data point's `vector_metadata()`, a
/// `field` tag, `dataset_id`, and, when set, `user_id` / `tenant_id`.
///
/// # Errors
/// - `VectorDBError` for collection or indexing failures.
/// - `EmbeddingError` if the engine fails or returns a different number of
///   vectors than texts requested.
// Every collaborator is injected explicitly, hence the long argument list.
#[allow(clippy::too_many_arguments)]
async fn index_data_points(
    chunks: &[DocumentChunk],
    entities: &[GraphNodePair],
    summaries: &[TextSummary],
    documents: &[Document],
    edges: &[GraphEdgePair],
    edge_types: &[EdgeType],
    dataset_id: Uuid,
    user_id: Option<Uuid>,
    tenant_id: Option<Uuid>,
    engine: Arc<dyn EmbeddingEngine>,
    vector_db: Arc<dyn VectorDB>,
    config: &CognifyConfig,
    precomputed_embeddings: &[Embedding],
) -> Result<IndexedFieldsStats, CognifyError> {
    /// Fails when the engine returned a different number of vectors than
    /// texts, instead of letting `zip` silently drop the unmatched items.
    fn ensure_vector_count(
        kind: &str,
        expected: usize,
        actual: usize,
    ) -> Result<(), CognifyError> {
        if expected == actual {
            Ok(())
        } else {
            Err(CognifyError::EmbeddingError(format!(
                "embedding engine returned {actual} vectors for {expected} {kind} texts"
            )))
        }
    }

    let mut stats = IndexedFieldsStats::default();
    let dimension = engine.dimension();
    // Caller-supplied vectors keyed by data point id; reused via `reuse_or_embed`.
    let precomputed: std::collections::HashMap<Uuid, Vec<f32>> = precomputed_embeddings
        .iter()
        .map(|e| (e.data_point_id, e.vector.clone()))
        .collect();

    if !chunks.is_empty() {
        if !vector_db
            .has_collection("DocumentChunk", "text")
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
        {
            vector_db
                .create_collection("DocumentChunk", "text", dimension)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        }
        let ids: Vec<Uuid> = chunks.iter().map(|c| c.base.id).collect();
        let texts: Vec<_> = chunks.iter().map(|c| c.text.as_str()).collect();
        let vectors = reuse_or_embed(&engine, &precomputed, &ids, &texts).await?;
        let points: Vec<VectorPoint> = chunks
            .iter()
            .zip(vectors)
            .map(|(chunk, vector)| {
                let mut point = VectorPoint::new(chunk.base.id, vector);
                for (k, v) in chunk.base.vector_metadata() {
                    point = point.with_metadata(k, v);
                }
                point = point
                    .with_metadata("field", json!("text"))
                    .with_metadata("text", json!(chunk.text.clone()))
                    .with_metadata("dataset_id", json!(dataset_id.to_string()))
                    .with_metadata("document_id", json!(chunk.document_id.to_string()))
                    .with_metadata("chunk_index", json!(chunk.chunk_index));
                if let Some(uid) = user_id {
                    point = point.with_metadata("user_id", json!(uid.to_string()));
                }
                if let Some(tid) = tenant_id {
                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
                }
                point
            })
            .collect();
        vector_db
            .index_points("DocumentChunk", "text", &points)
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        stats.record("DocumentChunk", "text", chunks.len());
        info!("Indexed {} document chunks", chunks.len());
    }

    if !entities.is_empty() {
        if !vector_db
            .has_collection("Entity", "name")
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
        {
            vector_db
                .create_collection("Entity", "name", dimension)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        }
        let ids: Vec<Uuid> = entities.iter().map(|e| e.entity.base.id).collect();
        let names: Vec<_> = entities.iter().map(|e| e.entity.name.as_str()).collect();
        let vectors = reuse_or_embed(&engine, &precomputed, &ids, &names).await?;
        let points: Vec<VectorPoint> = entities
            .iter()
            .zip(vectors)
            .map(|(entity, vector)| {
                let mut point = VectorPoint::new(entity.entity.base.id, vector);
                for (k, v) in entity.entity.base.vector_metadata() {
                    point = point.with_metadata(k, v);
                }
                point = point
                    .with_metadata("field", json!("name"))
                    .with_metadata("dataset_id", json!(dataset_id.to_string()))
                    .with_metadata("entity_type", json!(entity.entity_type.name.clone()));
                if let Some(uid) = user_id {
                    point = point.with_metadata("user_id", json!(uid.to_string()));
                }
                if let Some(tid) = tenant_id {
                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
                }
                point
            })
            .collect();
        vector_db
            .index_points("Entity", "name", &points)
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        stats.record("Entity", "name", entities.len());
        info!("Indexed {} entity names", entities.len());
    }

    {
        // Entity types are shared across entities; index each id once,
        // keeping the first occurrence.
        let mut seen_ids = std::collections::HashSet::new();
        let unique_entity_types: Vec<&cognee_models::EntityType> = entities
            .iter()
            .map(|pair| &pair.entity_type)
            .filter(|et| seen_ids.insert(et.base.id))
            .collect();
        if !unique_entity_types.is_empty() {
            if !vector_db
                .has_collection("EntityType", "name")
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
            {
                vector_db
                    .create_collection("EntityType", "name", dimension)
                    .await
                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            }
            let type_names: Vec<_> = unique_entity_types
                .iter()
                .map(|et| et.name.as_str())
                .collect();
            let vectors = engine
                .embed(&type_names)
                .await
                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
            ensure_vector_count("EntityType", type_names.len(), vectors.len())?;
            let points: Vec<VectorPoint> = unique_entity_types
                .iter()
                .zip(vectors)
                .map(|(et, vector)| {
                    let mut point = VectorPoint::new(et.base.id, vector);
                    for (k, v) in et.base.vector_metadata() {
                        point = point.with_metadata(k, v);
                    }
                    point = point
                        .with_metadata("field", json!("name"))
                        .with_metadata("dataset_id", json!(dataset_id.to_string()));
                    if let Some(uid) = user_id {
                        point = point.with_metadata("user_id", json!(uid.to_string()));
                    }
                    if let Some(tid) = tenant_id {
                        point = point.with_metadata("tenant_id", json!(tid.to_string()));
                    }
                    point
                })
                .collect();
            vector_db
                .index_points("EntityType", "name", &points)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            stats.record("EntityType", "name", unique_entity_types.len());
            info!("Indexed {} entity type names", unique_entity_types.len());
        }
    }

    if !summaries.is_empty() {
        if !vector_db
            .has_collection("TextSummary", "text")
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
        {
            vector_db
                .create_collection("TextSummary", "text", dimension)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        }
        let ids: Vec<Uuid> = summaries.iter().map(|s| s.base.id).collect();
        let texts: Vec<_> = summaries.iter().map(|s| s.text.as_str()).collect();
        let vectors = reuse_or_embed(&engine, &precomputed, &ids, &texts).await?;
        let points: Vec<VectorPoint> = summaries
            .iter()
            .zip(vectors)
            .map(|(summary, vector)| {
                let mut point = VectorPoint::new(summary.base.id, vector);
                for (k, v) in summary.base.vector_metadata() {
                    point = point.with_metadata(k, v);
                }
                point = point
                    .with_metadata("field", json!("text"))
                    .with_metadata("text", json!(summary.text.clone()))
                    .with_metadata("dataset_id", json!(dataset_id.to_string()));
                if let Some(made_from) = summary.made_from {
                    point = point.with_metadata("chunk_id", json!(made_from.to_string()));
                }
                if let Some(uid) = user_id {
                    point = point.with_metadata("user_id", json!(uid.to_string()));
                }
                if let Some(tid) = tenant_id {
                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
                }
                point
            })
            .collect();
        vector_db
            .index_points("TextSummary", "text", &points)
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        stats.record("TextSummary", "text", summaries.len());
        info!("Indexed {} summaries", summaries.len());
    }

    if config.embed_triplets && !edges.is_empty() && !entities.is_empty() {
        use crate::triplet_creation::create_triplets_from_graph;
        let triplets = create_triplets_from_graph(entities, edges);
        if !triplets.is_empty() {
            if !vector_db
                .has_collection("Triplet", "text")
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
            {
                vector_db
                    .create_collection("Triplet", "text", dimension)
                    .await
                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            }
            let triplet_texts: Vec<_> = triplets.iter().map(|t| t.text.as_str()).collect();
            let triplet_vectors = engine
                .embed(&triplet_texts)
                .await
                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
            ensure_vector_count("Triplet", triplet_texts.len(), triplet_vectors.len())?;
            // Look up each triplet's EdgeType via the edge's retrieval text so
            // its provenance metadata can be copied onto the triplet point.
            let edge_type_by_text: std::collections::HashMap<&str, &EdgeType> = edge_types
                .iter()
                .map(|et| (et.relationship_name.as_str(), et))
                .collect();
            let edge_text_by_triple: std::collections::HashMap<(Uuid, Uuid, &str), String> = edges
                .iter()
                .map(|e| {
                    (
                        (
                            e.source_entity_id,
                            e.target_entity_id,
                            e.relationship_name.as_str(),
                        ),
                        edge_retrieval_text(e),
                    )
                })
                .collect();
            // NOTE(review): unlike every other collection here, triplet points
            // carry no dataset_id/user_id/tenant_id. Confirm that is intended
            // before relying on scoped triplet search.
            let triplet_points: Vec<VectorPoint> = triplets
                .iter()
                .zip(triplet_vectors)
                .map(|(triplet, vector)| {
                    let mut point = VectorPoint::new(triplet.id, vector)
                        .with_metadata("type", json!("Triplet"))
                        .with_metadata("field", json!("text"))
                        .with_metadata("source_id", json!(triplet.source_entity_id.to_string()))
                        .with_metadata("target_id", json!(triplet.target_entity_id.to_string()))
                        .with_metadata("relationship", json!(triplet.relationship_name.clone()));
                    let edge_type = edge_text_by_triple
                        .get(&(
                            triplet.source_entity_id,
                            triplet.target_entity_id,
                            triplet.relationship_name.as_str(),
                        ))
                        .and_then(|text| edge_type_by_text.get(text.as_str()));
                    if let Some(edge_type) = edge_type {
                        for (k, v) in edge_type.base.vector_metadata() {
                            if matches!(
                                k.as_str(),
                                "source_pipeline"
                                    | "source_task"
                                    | "source_user"
                                    | "source_node_set"
                                    | "source_content_hash"
                            ) {
                                point = point.with_metadata(k, v);
                            }
                        }
                    }
                    point
                })
                .collect();
            vector_db
                .index_points("Triplet", "text", &triplet_points)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            stats.triplet_count = triplets.len();
            info!("Indexed {} triplets", triplets.len());
        }
    } else if config.embed_triplets {
        info!("Triplet embedding enabled but no edges/entities to index");
    }

    if !edge_types.is_empty() {
        if !vector_db
            .has_collection("EdgeType", "relationship_name")
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
        {
            vector_db
                .create_collection("EdgeType", "relationship_name", dimension)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        }
        let names: Vec<&str> = edge_types
            .iter()
            .map(|et| et.relationship_name.as_str())
            .collect();
        let vectors = engine
            .embed(&names)
            .await
            .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
        ensure_vector_count("EdgeType", names.len(), vectors.len())?;
        let points: Vec<VectorPoint> = edge_types
            .iter()
            .zip(vectors)
            .map(|(et, vector)| {
                let mut point = VectorPoint::new(et.base.id, vector);
                for (k, v) in et.base.vector_metadata() {
                    point = point.with_metadata(k, v);
                }
                point = point
                    .with_metadata("field", json!("relationship_name"))
                    .with_metadata("relationship_name", json!(et.relationship_name.clone()))
                    .with_metadata("number_of_edges", json!(et.number_of_edges))
                    .with_metadata("dataset_id", json!(dataset_id.to_string()));
                if let Some(uid) = user_id {
                    point = point.with_metadata("user_id", json!(uid.to_string()));
                }
                if let Some(tid) = tenant_id {
                    point = point.with_metadata("tenant_id", json!(tid.to_string()));
                }
                point
            })
            .collect();
        vector_db
            .index_points("EdgeType", "relationship_name", &points)
            .await
            .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
        stats.record("EdgeType", "relationship_name", edge_types.len());
        info!("Indexed {} edge types", edge_types.len());
    }

    if !documents.is_empty() {
        // One collection per document type; BTreeMap gives a stable order.
        let mut by_type: std::collections::BTreeMap<&str, Vec<&Document>> =
            std::collections::BTreeMap::new();
        for d in documents {
            by_type
                .entry(d.base.data_type.as_str())
                .or_default()
                .push(d);
        }
        for (type_name, docs) in by_type {
            if !vector_db
                .has_collection(type_name, "name")
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?
            {
                vector_db
                    .create_collection(type_name, "name", dimension)
                    .await
                    .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            }
            let names: Vec<&str> = docs.iter().map(|d| d.name.as_str()).collect();
            let vectors = engine
                .embed(&names)
                .await
                .map_err(|e| CognifyError::EmbeddingError(e.to_string()))?;
            ensure_vector_count(type_name, names.len(), vectors.len())?;
            let points: Vec<VectorPoint> = docs
                .iter()
                .zip(vectors)
                .map(|(doc, vector)| {
                    let mut point = VectorPoint::new(doc.base.id, vector);
                    for (k, v) in doc.base.vector_metadata() {
                        point = point.with_metadata(k, v);
                    }
                    point = point
                        .with_metadata("field", json!("name"))
                        .with_metadata("name", json!(doc.name.clone()))
                        .with_metadata("dataset_id", json!(dataset_id.to_string()));
                    if let Some(uid) = user_id {
                        point = point.with_metadata("user_id", json!(uid.to_string()));
                    }
                    if let Some(tid) = tenant_id {
                        point = point.with_metadata("tenant_id", json!(tid.to_string()));
                    }
                    point
                })
                .collect();
            vector_db
                .index_points(type_name, "name", &points)
                .await
                .map_err(|e| CognifyError::VectorDBError(e.to_string()))?;
            stats.record(type_name, "name", docs.len());
            info!("Indexed {} {}", docs.len(), type_name);
        }
    }
    Ok(stats)
}
/// Name of the document-classification task: the first task of both the `cognify` and
/// `temporal-cognify` pipelines, and the task name stamped into provenance by
/// `make_classify_documents_task`.
pub const CLASSIFY_DOCUMENTS_TASK_NAME: &str = "classify_documents";
/// Name of the chunk-extraction task (registered in both pipelines built in this module and
/// stamped into provenance by `make_extract_chunks_task`).
pub const EXTRACT_CHUNKS_TASK_NAME: &str = "extract_chunks_from_documents";
/// Name of the graph-extraction task in the `cognify` pipeline; also its provenance task name.
pub const EXTRACT_GRAPH_TASK_NAME: &str = "extract_graph_from_data";
/// Name of the summarization task in the `cognify` pipeline; also its provenance task name.
pub const SUMMARIZE_TEXT_TASK_NAME: &str = "summarize_text";
/// Name of the final persistence/indexing task in the `cognify` pipeline; also its
/// provenance task name.
pub const ADD_DATA_POINTS_TASK_NAME: &str = "add_data_points";
/// Pipeline name passed to `stamp_provenance` by every cognify task factory in this module.
// NOTE(review): the classify and chunk-extraction factories are also used by
// `build_temporal_cognify_pipeline`, so data produced there is stamped with "cognify" rather
// than "temporal-cognify" — confirm whether that is intended.
const COGNIFY_PIPELINE_STAMP_NAME: &str = "cognify";
/// Returns the user label carried by the task's pipeline context.
///
/// Yields `None` when the task has no pipeline context, or when the pipeline context
/// itself reports no label.
fn user_label_from_ctx(ctx: &Arc<cognee_core::TaskContext>) -> Option<String> {
    let pipeline_ctx = ctx.pipeline_ctx.as_ref()?;
    pipeline_ctx.user_label()
}
/// Builds the synchronous `classify_documents` task.
///
/// Classifies the input data items into documents, then stamps every resulting document
/// with the cognify pipeline name, this task's name and the pipeline context's user label
/// (when one is present). Classification errors are surfaced as their `Display` text.
pub fn make_classify_documents_task() -> TypedTask<CognifyInput, ClassifiedDocuments> {
    TypedTask::sync(|input: &CognifyInput, ctx| {
        let mut classified = classify_documents(input).map_err(|e| format!("{e}"))?;
        let user_label = user_label_from_ctx(&ctx);
        let label = user_label.as_deref();
        let (pipeline, task) = (COGNIFY_PIPELINE_STAMP_NAME, CLASSIFY_DOCUMENTS_TASK_NAME);
        for doc in classified.documents.iter_mut() {
            stamp_provenance(&mut doc.base, pipeline, task, label);
        }
        Ok(Box::new(classified))
    })
}
/// Builds the async chunk-extraction task.
///
/// Each invocation reads the classified documents from `storage` (using `loader_registry`
/// for type-specific loading), chunks them to at most `max_chunk_size` tokens as counted by
/// `token_counter_kind`, and stamps the produced chunks and documents with cognify
/// provenance and the pipeline context's user label.
pub fn make_extract_chunks_task(
    storage: Arc<dyn StorageTrait>,
    max_chunk_size: usize,
    token_counter_kind: TokenCounterKind,
    db: Option<Arc<DatabaseConnection>>,
    loader_registry: Arc<LoaderRegistry>,
) -> TypedTask<ClassifiedDocuments, ExtractedChunks> {
    TypedTask::async_fn(move |input: &ClassifiedDocuments, ctx| {
        // Take owned copies of everything the returned future touches.
        let owned_input = input.clone();
        let storage = Arc::clone(&storage);
        let db = db.clone();
        let counter = token_counter_kind.clone();
        let registry = Arc::clone(&loader_registry);
        let user_label = user_label_from_ctx(&ctx);
        Box::pin(async move {
            let mut extracted = extract_chunks_from_documents(
                &owned_input,
                &*storage,
                max_chunk_size,
                counter,
                db.as_deref(),
                &registry,
            )
            .await
            .map_err(|e| format!("{e}"))?;
            let label = user_label.as_deref();
            let (pipeline, task) = (COGNIFY_PIPELINE_STAMP_NAME, EXTRACT_CHUNKS_TASK_NAME);
            for chunk in extracted.chunks.iter_mut() {
                stamp_provenance(&mut chunk.base, pipeline, task, label);
            }
            for doc in extracted.documents.iter_mut() {
                stamp_provenance(&mut doc.base, pipeline, task, label);
            }
            Ok(Box::new(extracted))
        })
    })
}
/// Builds the async knowledge-graph extraction task.
///
/// Runs `extract_graph_from_data` over the chunks. When `config.create_web_page_nodes` is
/// set, it additionally writes web-page nodes for the documents/chunks into `graph_db`.
/// Finally it stamps entities, entity types, chunks and documents with cognify provenance
/// and the pipeline context's user label.
pub fn make_extract_graph_task(
    llm: Arc<dyn Llm>,
    graph_db: Arc<dyn GraphDBTrait>,
    ontology_resolver: Arc<dyn OntologyResolver>,
    config: CognifyConfig,
) -> TypedTask<ExtractedChunks, ExtractedGraphData> {
    TypedTask::async_fn(move |input: &ExtractedChunks, ctx| {
        let chunks_in = input.clone();
        let llm = Arc::clone(&llm);
        let graph = Arc::clone(&graph_db);
        let resolver = Arc::clone(&ontology_resolver);
        let cfg = config.clone();
        let user_label = user_label_from_ctx(&ctx);
        Box::pin(async move {
            let mut graph_data = extract_graph_from_data(
                &chunks_in,
                llm,
                Arc::clone(&graph),
                resolver,
                &cfg,
                user_label.as_deref(),
            )
            .await
            .map_err(|e| format!("{e}"))?;

            if cfg.create_web_page_nodes {
                create_web_page_nodes(&graph_data.documents, &graph_data.chunks, graph)
                    .await
                    .map_err(|e| format!("{e}"))?;
            }

            let label = user_label.as_deref();
            let (pipeline, task) = (COGNIFY_PIPELINE_STAMP_NAME, EXTRACT_GRAPH_TASK_NAME);
            for pair in graph_data.entities.iter_mut() {
                stamp_provenance(&mut pair.entity.base, pipeline, task, label);
                stamp_provenance(&mut pair.entity_type.base, pipeline, task, label);
            }
            for chunk in graph_data.chunks.iter_mut() {
                stamp_provenance(&mut chunk.base, pipeline, task, label);
            }
            for doc in graph_data.documents.iter_mut() {
                stamp_provenance(&mut doc.base, pipeline, task, label);
            }
            Ok(Box::new(graph_data))
        })
    })
}
/// Builds the async text-summarization task.
///
/// Runs `summarize_text` on the extracted graph data. It then stamps summaries, chunks,
/// documents, entities and entity types with cognify provenance and the pipeline
/// context's user label.
pub fn make_summarize_text_task(
    llm: Arc<dyn Llm>,
    config: CognifyConfig,
) -> TypedTask<ExtractedGraphData, SummarizedData> {
    TypedTask::async_fn(move |input: &ExtractedGraphData, ctx| {
        let graph_in = input.clone();
        let llm = Arc::clone(&llm);
        let cfg = config.clone();
        let user_label = user_label_from_ctx(&ctx);
        Box::pin(async move {
            let mut summarized = summarize_text(&graph_in, llm, &cfg)
                .await
                .map_err(|e| format!("{e}"))?;
            let label = user_label.as_deref();
            let (pipeline, task) = (COGNIFY_PIPELINE_STAMP_NAME, SUMMARIZE_TEXT_TASK_NAME);
            for summary in summarized.summaries.iter_mut() {
                stamp_provenance(&mut summary.base, pipeline, task, label);
            }
            for chunk in summarized.chunks.iter_mut() {
                stamp_provenance(&mut chunk.base, pipeline, task, label);
            }
            for doc in summarized.documents.iter_mut() {
                stamp_provenance(&mut doc.base, pipeline, task, label);
            }
            for pair in summarized.entities.iter_mut() {
                stamp_provenance(&mut pair.entity.base, pipeline, task, label);
                stamp_provenance(&mut pair.entity_type.base, pipeline, task, label);
            }
            Ok(Box::new(summarized))
        })
    })
}
/// Builds the final async persistence task of the cognify pipeline.
///
/// Calls `add_data_points` to write the summarized data to the graph and vector stores.
/// It then stamps the returned chunks, entities, entity types, summaries, edge types and
/// DLT documents with cognify provenance and the pipeline context's user label. The
/// stamping happens after `add_data_points` returns, so it applies to the returned
/// `CognifyResult`.
pub fn make_add_data_points_task(
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    db: Option<Arc<DatabaseConnection>>,
    config: CognifyConfig,
) -> TypedTask<SummarizedData, CognifyResult> {
    TypedTask::async_fn(move |input: &SummarizedData, ctx| {
        let summarized = input.clone();
        let graph = Arc::clone(&graph_db);
        let vectors = Arc::clone(&vector_db);
        let engine = Arc::clone(&embedding_engine);
        let db = db.clone();
        let cfg = config.clone();
        let user_label = user_label_from_ctx(&ctx);
        Box::pin(async move {
            let mut result = add_data_points(&summarized, graph, vectors, engine, db, &cfg)
                .await
                .map_err(|e| format!("{e}"))?;
            let label = user_label.as_deref();
            let (pipeline, task) = (COGNIFY_PIPELINE_STAMP_NAME, ADD_DATA_POINTS_TASK_NAME);
            for chunk in result.chunks.iter_mut() {
                stamp_provenance(&mut chunk.base, pipeline, task, label);
            }
            for pair in result.entities.iter_mut() {
                stamp_provenance(&mut pair.entity.base, pipeline, task, label);
                stamp_provenance(&mut pair.entity_type.base, pipeline, task, label);
            }
            for summary in result.summaries.iter_mut() {
                stamp_provenance(&mut summary.base, pipeline, task, label);
            }
            for edge_type in result.edge_types.iter_mut() {
                stamp_provenance(&mut edge_type.base, pipeline, task, label);
            }
            for doc in result.documents_for_dlt.iter_mut() {
                stamp_provenance(&mut doc.base, pipeline, task, label);
            }
            Ok(Box::new(result))
        })
    })
}
/// Builds the loader registry used by the chunk-extraction task.
///
/// Starts from `LoaderRegistry::default_registry()` and then registers optional loaders
/// depending on enabled cargo features:
/// - `"image"`: an `ImageLoader` backed by `llm` (feature `image-loader`).
/// - `"audio"`: an `AudioLoader` backed by the configured transcriber (feature
///   `audio-loader`), registered only when `config.transcriber` is `Some`.
// `llm` is read only under `image-loader` and `config` only under `audio-loader`. Unless
// BOTH features are enabled, at least one parameter is unused. Gating the allow on
// `not(any(..))` only covered the case where neither feature was on, which left a warning
// (and a `-D warnings` failure) in the single-feature builds.
#[cfg_attr(
    not(all(feature = "image-loader", feature = "audio-loader")),
    allow(unused_variables)
)]
fn build_loader_registry(llm: &Arc<dyn Llm>, config: &CognifyConfig) -> LoaderRegistry {
    // `mut` is only exercised when at least one optional loader feature is enabled.
    #[allow(unused_mut)]
    let mut registry = LoaderRegistry::default_registry();
    #[cfg(feature = "image-loader")]
    registry.register("image", Arc::new(ImageLoader::new(Arc::clone(llm))));
    #[cfg(feature = "audio-loader")]
    if let Some(transcriber_handle) = &config.transcriber {
        registry.register(
            "audio",
            Arc::new(AudioLoader::new(Arc::clone(&transcriber_handle.0))),
        );
    }
    registry
}
/// Assembles the five-stage `cognify` pipeline.
///
/// The stages are: classify documents → extract chunks → extract graph →
/// summarize text → add data points. Each stage is registered under its
/// `*_TASK_NAME` constant.
#[allow(clippy::too_many_arguments)]
pub fn build_cognify_pipeline(
    storage: Arc<dyn StorageTrait>,
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    llm: Arc<dyn Llm>,
    db: Option<Arc<DatabaseConnection>>,
    ontology_resolver: Arc<dyn OntologyResolver>,
    config: CognifyConfig,
) -> Pipeline {
    let loader_registry = Arc::new(build_loader_registry(&llm, &config));

    // Construct every stage up front; shared handles are cloned before their final move.
    let extract_chunks = make_extract_chunks_task(
        storage,
        config.max_chunk_size,
        config.token_counter_kind.clone(),
        db.clone(),
        loader_registry,
    );
    let extract_graph = make_extract_graph_task(
        Arc::clone(&llm),
        Arc::clone(&graph_db),
        ontology_resolver,
        config.clone(),
    );
    let summarize = make_summarize_text_task(llm, config.clone());
    let add_points =
        make_add_data_points_task(graph_db, vector_db, embedding_engine, db, config);

    PipelineBuilder::new_with_task("cognify", make_classify_documents_task())
        .with_first_task_name(CLASSIFY_DOCUMENTS_TASK_NAME)
        .add_task_named(extract_chunks, EXTRACT_CHUNKS_TASK_NAME)
        .add_task_named(extract_graph, EXTRACT_GRAPH_TASK_NAME)
        .add_task_named(summarize, SUMMARIZE_TEXT_TASK_NAME)
        .add_task_named(add_points, ADD_DATA_POINTS_TASK_NAME)
        .with_name("cognify")
        .build()
}
/// Builds the async temporal-event extraction task.
///
/// Runs `extract_temporal_events` over the extracted chunks using the given LLM and
/// config. Errors are surfaced as their `Display` text.
// NOTE(review): unlike the cognify task factories in this module, this task ignores its
// context and applies no provenance stamping — confirm whether that is intentional.
pub fn make_extract_temporal_events_task(
    llm: Arc<dyn Llm>,
    config: CognifyConfig,
) -> TypedTask<ExtractedChunks, ExtractedTemporalEvents> {
    TypedTask::async_fn(move |input: &ExtractedChunks, _ctx| {
        let chunks_in = input.clone();
        let llm = Arc::clone(&llm);
        let cfg = config.clone();
        Box::pin(async move {
            let events = extract_temporal_events(&chunks_in, llm, &cfg)
                .await
                .map_err(|e| format!("{e}"))?;
            Ok(Box::new(events))
        })
    })
}
/// Builds the final async task of the temporal pipeline.
///
/// Calls `add_temporal_data_points` to write the extracted temporal events to the graph
/// and vector stores, embedding with `embedding_engine`. Errors are surfaced as their
/// `Display` text.
pub fn make_add_temporal_data_points_task(
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
) -> TypedTask<ExtractedTemporalEvents, CognifyResult> {
    TypedTask::async_fn(move |input: &ExtractedTemporalEvents, _ctx| {
        let events = input.clone();
        let graph = Arc::clone(&graph_db);
        let vectors = Arc::clone(&vector_db);
        let engine = Arc::clone(&embedding_engine);
        Box::pin(async move {
            let result = add_temporal_data_points(&events, graph, vectors, engine)
                .await
                .map_err(|e| format!("{e}"))?;
            Ok(Box::new(result))
        })
    })
}
/// Assembles the `temporal-cognify` pipeline.
///
/// The stages are: classify documents → extract chunks → extract temporal events →
/// add temporal data points. The first two stages are the same task factories used by
/// the regular cognify pipeline.
pub fn build_temporal_cognify_pipeline(
    storage: Arc<dyn StorageTrait>,
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    llm: Arc<dyn Llm>,
    db: Option<Arc<DatabaseConnection>>,
    config: CognifyConfig,
) -> Pipeline {
    let loader_registry = Arc::new(build_loader_registry(&llm, &config));

    let extract_chunks = make_extract_chunks_task(
        storage,
        config.max_chunk_size,
        config.token_counter_kind.clone(),
        db,
        loader_registry,
    );
    let extract_events = make_extract_temporal_events_task(llm, config);
    let add_points = make_add_temporal_data_points_task(graph_db, vector_db, embedding_engine);

    PipelineBuilder::new_with_task("temporal-cognify", make_classify_documents_task())
        .with_first_task_name(CLASSIFY_DOCUMENTS_TASK_NAME)
        .add_task_named(extract_chunks, EXTRACT_CHUNKS_TASK_NAME)
        .add_task_named(extract_events, "extract_temporal_events")
        .add_task_named(add_points, "add_temporal_data_points")
        .with_name("temporal-cognify")
        .build()
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
mod tests {
use super::*;
use cognee_models::DataPoint;
use cognee_storage::MockStorage;
#[test]
fn test_classify_documents_empty() {
let input = CognifyInput {
data_items: vec![],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let result = classify_documents(&input).unwrap();
assert!(result.documents.is_empty());
}
#[test]
fn test_classify_documents_text_data() {
let data = Data::builder(
Uuid::new_v4(),
"test.txt",
"/storage/test.txt",
"text://test",
"txt",
"text/plain",
"hash123",
Uuid::new_v4(),
)
.build();
let input = CognifyInput {
data_items: vec![data],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let result = classify_documents(&input).unwrap();
assert_eq!(result.documents.len(), 1);
}
#[test]
fn test_classify_documents_skips_unknown_extension() {
let data = Data::builder(
Uuid::new_v4(),
"data.xyz",
"/storage/data.xyz",
"file://data.xyz",
"xyz",
"application/octet-stream",
"hash456",
Uuid::new_v4(),
)
.build();
let input = CognifyInput {
data_items: vec![data],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let result = classify_documents(&input).unwrap();
assert!(result.documents.is_empty());
}
#[tokio::test]
async fn test_extract_chunks_from_documents() {
let storage = Arc::new(MockStorage::new());
let location = storage
.store(b"Hello world. This is a test.", "test.txt")
.await
.unwrap();
let doc_id = Uuid::new_v4();
let mut base = DataPoint::new("TextDocument", None);
base.id = doc_id;
base.set_metadata("index_fields", serde_json::json!(["name"]));
let doc = Document {
base,
document_type: "text".to_string(),
name: "test.txt".to_string(),
raw_data_location: location,
mime_type: "text/plain".to_string(),
extension: "txt".to_string(),
data_id: doc_id,
external_metadata: None,
};
let input = ClassifiedDocuments {
documents: vec![doc],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let registry = LoaderRegistry::default();
let result = extract_chunks_from_documents(
&input,
&*storage,
100,
TokenCounterKind::Word,
None,
®istry,
)
.await
.unwrap();
assert!(!result.chunks.is_empty());
}
#[tokio::test]
async fn test_extract_chunks_empty_documents() {
let storage = Arc::new(MockStorage::new());
let input = ClassifiedDocuments {
documents: vec![],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let registry = LoaderRegistry::default();
let result = extract_chunks_from_documents(
&input,
&*storage,
100,
TokenCounterKind::Word,
None,
®istry,
)
.await
.unwrap();
assert!(result.chunks.is_empty());
}
#[tokio::test]
async fn test_dlt_short_circuit() {
let storage = Arc::new(MockStorage::new());
let location = storage
.store(b" some dlt row content ", "dlt.txt")
.await
.unwrap();
let doc_id = Uuid::new_v4();
let mut base = DataPoint::new("DltRowDocument", None);
base.id = doc_id;
base.set_metadata("index_fields", serde_json::json!(["text"]));
let doc = Document {
base,
document_type: "dlt_row".to_string(),
name: "dlt.txt".to_string(),
raw_data_location: location,
mime_type: "text/plain".to_string(),
extension: "txt".to_string(),
data_id: doc_id,
external_metadata: None,
};
let input = ClassifiedDocuments {
documents: vec![doc],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let registry = LoaderRegistry::default();
let result = extract_chunks_from_documents(
&input,
&*storage,
100,
TokenCounterKind::Word,
None,
®istry,
)
.await
.unwrap();
assert_eq!(result.chunks.len(), 1);
let chunk = &result.chunks[0];
assert_eq!(chunk.text, "some dlt row content");
assert_eq!(chunk.cut_type, "dlt_row");
assert_eq!(chunk.chunk_index, 0);
assert_eq!(chunk.document_id, doc_id);
}
#[tokio::test]
async fn test_unsupported_document_type() {
const UNSUPPORTED: &str = "no_such_loader_type_for_test";
let storage = Arc::new(MockStorage::new());
let location = storage.store(b"some content", "test.bin").await.unwrap();
let doc_id = Uuid::new_v4();
let mut base = DataPoint::new("UnknownDocument", None);
base.id = doc_id;
base.set_metadata("index_fields", serde_json::json!(["text"]));
let doc = Document {
base,
document_type: UNSUPPORTED.to_string(),
name: "test.bin".to_string(),
raw_data_location: location,
mime_type: "application/octet-stream".to_string(),
extension: "bin".to_string(),
data_id: doc_id,
external_metadata: None,
};
let input = ClassifiedDocuments {
documents: vec![doc],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let registry = LoaderRegistry::default();
let result = extract_chunks_from_documents(
&input,
&*storage,
100,
TokenCounterKind::Word,
None,
®istry,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err();
assert!(
matches!(err, CognifyError::UnsupportedDocumentType(ref t) if t == UNSUPPORTED),
"expected UnsupportedDocumentType({UNSUPPORTED:?}), got: {err:?}"
);
}
#[test]
fn test_classify_documents_preserves_dataset_id() {
let dataset_id = Uuid::new_v4();
let input = CognifyInput {
data_items: vec![],
dataset_id,
user_id: None,
tenant_id: None,
};
let result = classify_documents(&input).unwrap();
assert_eq!(result.dataset_id, dataset_id);
}
#[test]
fn provenance_node_id_works_with_none_tenant() {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
let data_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
let node_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
let id = provenance_node_id(None, user_id, dataset_id, data_id, node_id);
let expected_input = format!("None{user_id}{dataset_id}{data_id}{node_id}");
let expected = Uuid::new_v5(&Uuid::NAMESPACE_OID, expected_input.as_bytes());
assert_eq!(id, expected);
}
#[test]
fn provenance_node_id_with_real_tenant_differs_from_none() {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
let data_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
let node_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
let tenant_id = Uuid::parse_str("00000000-0000-0000-0000-000000000005").unwrap();
let id_none = provenance_node_id(None, user_id, dataset_id, data_id, node_id);
let id_real = provenance_node_id(Some(tenant_id), user_id, dataset_id, data_id, node_id);
assert_ne!(id_none, id_real);
}
#[test]
fn provenance_edge_id_works_with_none_tenant() {
let user_id = Uuid::parse_str("00000000-0000-0000-0000-000000000001").unwrap();
let dataset_id = Uuid::parse_str("00000000-0000-0000-0000-000000000002").unwrap();
let source_id = Uuid::parse_str("00000000-0000-0000-0000-000000000003").unwrap();
let target_id = Uuid::parse_str("00000000-0000-0000-0000-000000000004").unwrap();
let id = provenance_edge_id(
None,
user_id,
dataset_id,
source_id,
"relates_to",
target_id,
);
let expected_input = format!("None{user_id}{dataset_id}{source_id}relates_to{target_id}");
let expected = Uuid::new_v5(&Uuid::NAMESPACE_OID, expected_input.as_bytes());
assert_eq!(id, expected);
}
#[test]
fn dlt_fk_rel_name_always_includes_ref_col_separator() {
let table_name = "orders";
let fk_col = "customer_id";
let ref_table = "customers";
let ref_col = "id";
let rel_name = format!("{table_name}:{fk_col}->{ref_table}:{ref_col}");
assert_eq!(rel_name, "orders:customer_id->customers:id");
let rel_id = Uuid::new_v5(&Uuid::NAMESPACE_OID, format!("dlt:{rel_name}").as_bytes());
let expected_id = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
b"dlt:orders:customer_id->customers:id",
);
assert_eq!(rel_id, expected_id);
let ref_col_empty = "";
let rel_name_empty = format!("{table_name}:{fk_col}->{ref_table}:{ref_col_empty}");
assert_eq!(
rel_name_empty, "orders:customer_id->customers:",
"rel_name must include trailing colon even when ref_col is empty"
);
let rel_id_empty = Uuid::new_v5(
&Uuid::NAMESPACE_OID,
format!("dlt:{rel_name_empty}").as_bytes(),
);
let expected_id_empty =
Uuid::new_v5(&Uuid::NAMESPACE_OID, b"dlt:orders:customer_id->customers:");
assert_eq!(rel_id_empty, expected_id_empty);
assert_ne!(
rel_id, rel_id_empty,
"non-empty and empty ref_col must produce different UUIDs"
);
}
#[test]
fn provenance_guard_does_not_require_tenant_id() {
let db: Option<u8> = Some(1); let user_id: Option<Uuid> = Some(Uuid::new_v4());
let tenant_id: Option<Uuid> = None;
let guard_fires = matches!((&db, user_id), (Some(_), Some(_)));
assert!(
guard_fires,
"Provenance guard must fire when db + user_id are present, regardless of tenant_id"
);
let old_guard_fires = matches!((&db, user_id, tenant_id), (Some(_), Some(_), Some(_)));
assert!(
!old_guard_fires,
"The old 3-way guard should NOT fire when tenant_id is None"
);
}
fn test_document_with_metadata(doc_id: Uuid, external_metadata: Option<String>) -> Document {
let mut base = DataPoint::new("TextDocument", None);
base.id = doc_id;
Document {
base,
document_type: "text".to_string(),
name: "test.txt".to_string(),
raw_data_location: "file:///tmp/test.txt".to_string(),
mime_type: "text/plain".to_string(),
extension: "txt".to_string(),
data_id: doc_id,
external_metadata,
}
}
fn test_chunk(chunk_id: Uuid, doc_id: Uuid, text: &str) -> DocumentChunk {
DocumentChunk::new(
chunk_id,
text.to_string(),
text.split_whitespace().count(),
0,
"paragraph_end".to_string(),
doc_id,
)
}
fn test_entity(name: &str, entity_type_id: Uuid) -> GraphNodePair {
let mut entity_base = DataPoint::new("Entity", None);
entity_base.id = Uuid::new_v4();
let entity = cognee_models::Entity {
base: entity_base,
name: name.to_string(),
is_a: None,
description: format!("description of {name}"),
};
let mut type_base = DataPoint::new("EntityType", None);
type_base.id = entity_type_id;
let entity_type = cognee_models::EntityType {
base: type_base,
name: "Generic".to_string(),
description: "Generic type".to_string(),
};
GraphNodePair {
entity,
entity_type,
}
}
#[tokio::test]
async fn embedding_reuse_avoids_double_pass() {
use cognee_embedding::MockEmbeddingEngine;
use cognee_vector::MockVectorDB;
let engine = Arc::new(MockEmbeddingEngine::new(8));
let engine_dyn: Arc<dyn EmbeddingEngine> = engine.clone();
let vector: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
let doc_id = Uuid::new_v4();
let chunks = vec![
test_chunk(Uuid::new_v4(), doc_id, "first chunk text"),
test_chunk(Uuid::new_v4(), doc_id, "second chunk text"),
];
let shared_type_id = Uuid::new_v4();
let entities = vec![
test_entity("Alice", shared_type_id),
test_entity("Bob", shared_type_id),
];
let summaries = vec![TextSummary::new(
chunks[0].base.id,
"a summary".to_string(),
None,
"mock-model".to_string(),
)];
let dataset_id = Uuid::new_v4();
let config = CognifyConfig::default();
let embeddings = generate_embeddings(&chunks, &entities, &summaries, engine_dyn.clone())
.await
.unwrap();
assert_eq!(embeddings.len(), 5);
assert_eq!(engine.embedded_text_count(), 5);
index_data_points(
&chunks,
&entities,
&summaries,
&[],
&[],
&[],
dataset_id,
None,
None,
engine_dyn,
vector,
&config,
&embeddings,
)
.await
.unwrap();
assert_eq!(engine.embedded_text_count(), 6);
}
#[tokio::test]
async fn report_embedding_reuse_savings() {
use cognee_embedding::MockEmbeddingEngine;
use cognee_vector::MockVectorDB;
let doc_id = Uuid::new_v4();
let chunks: Vec<DocumentChunk> = (0..24)
.map(|i| test_chunk(Uuid::new_v4(), doc_id, &format!("chunk text number {i}")))
.collect();
let type_ids: Vec<Uuid> = (0..4).map(|_| Uuid::new_v4()).collect();
let entities: Vec<GraphNodePair> = (0..16)
.map(|i| test_entity(&format!("Entity {i}"), type_ids[i % 4]))
.collect();
let summaries: Vec<TextSummary> = (0..10)
.map(|i| {
TextSummary::new(
Uuid::new_v4(),
format!("summary number {i}"),
None,
"mock-model".to_string(),
)
})
.collect();
let overlap = chunks.len() + entities.len() + summaries.len();
let dataset_id = Uuid::new_v4();
let config = CognifyConfig::default();
async fn measure(
reuse: bool,
chunks: &[DocumentChunk],
entities: &[GraphNodePair],
summaries: &[TextSummary],
dataset_id: Uuid,
config: &CognifyConfig,
) -> (usize, usize) {
let engine = Arc::new(MockEmbeddingEngine::new(8));
let engine_dyn: Arc<dyn EmbeddingEngine> = engine.clone();
let vector: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
let embeddings = generate_embeddings(chunks, entities, summaries, engine_dyn.clone())
.await
.unwrap();
let precomputed: &[Embedding] = if reuse { &embeddings } else { &[] };
index_data_points(
chunks,
entities,
summaries,
&[],
&[],
&[],
dataset_id,
None,
None,
engine_dyn,
vector,
config,
precomputed,
)
.await
.unwrap();
(engine.call_count(), engine.embedded_text_count())
}
let (before_calls, before_texts) =
measure(false, &chunks, &entities, &summaries, dataset_id, &config).await;
let (after_calls, after_texts) =
measure(true, &chunks, &entities, &summaries, dataset_id, &config).await;
println!(
"\n Embedding work per cognify ({} chunks / {} entities / {} summaries):",
chunks.len(),
entities.len(),
summaries.len()
);
println!(" BEFORE (double pass): {before_calls} embed() calls, {before_texts} texts");
println!(" AFTER (reuse) : {after_calls} embed() calls, {after_texts} texts");
println!(
" Saved: {} texts ({:.0}% fewer)\n",
before_texts - after_texts,
100.0 * (before_texts - after_texts) as f64 / before_texts as f64,
);
assert_eq!(before_texts - after_texts, overlap);
}
fn url_metadata(url: &str, final_url: &str, title: &str) -> String {
json!({
"source": "url",
"url": url,
"final_url": final_url,
"content_type": "text/html",
"title": title,
})
.to_string()
}
#[tokio::test]
async fn add_data_points_stores_document_node_and_indexes_document_name() {
use cognee_embedding::MockEmbeddingEngine;
use cognee_vector::MockVectorDB;
let graph: Arc<dyn GraphDBTrait> = Arc::new(cognee_graph::MockGraphDB::new());
let vector: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
let engine: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
let doc_id = Uuid::parse_str("00000000-0000-0000-0000-0000000000a1").unwrap();
let chunk_id = Uuid::parse_str("00000000-0000-0000-0000-0000000000b1").unwrap();
let document = test_document_with_metadata(doc_id, None);
let chunk = test_chunk(chunk_id, doc_id, "Hello world");
let input = SummarizedData {
chunks: vec![chunk],
documents: vec![document],
entities: vec![],
edges: vec![],
summaries: vec![],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let config = CognifyConfig::default();
add_data_points(
&input,
Arc::clone(&graph),
Arc::clone(&vector),
Arc::clone(&engine),
None,
&config,
)
.await
.unwrap();
let node = graph
.get_node(&doc_id.to_string())
.await
.unwrap()
.expect("document node should exist");
assert_eq!(
node.get("type").and_then(|v| v.as_str()),
Some("TextDocument")
);
assert!(vector.has_collection("TextDocument", "name").await.unwrap());
assert_eq!(
vector
.collection_size("TextDocument", "name")
.await
.unwrap(),
1
);
}
#[tokio::test]
async fn extracted_edge_description_persists_as_edge_text_property() {
use crate::fact_extraction::{Edge, KnowledgeGraph, Node};
use cognee_ontology::NoOpOntologyResolver;
let graph = KnowledgeGraph {
nodes: vec![
Node {
id: "alice".to_string(),
name: "Alice".to_string(),
node_type: "PERSON".to_string(),
description: "A person".to_string(),
},
Node {
id: "acme".to_string(),
name: "Acme".to_string(),
node_type: "ORGANIZATION".to_string(),
description: "A company".to_string(),
},
],
edges: vec![Edge {
source_node_id: "alice".to_string(),
target_node_id: "acme".to_string(),
relationship_name: "founded".to_string(),
description: Some(" Alice founded Acme ".to_string()),
}],
};
let chunk_id = Uuid::new_v4();
let dataset_id = Uuid::new_v4();
let resolver = NoOpOntologyResolver::new();
let (_nodes, edges) = expand_with_nodes_and_edges(
vec![(chunk_id, graph)],
dataset_id,
&HashSet::new(),
&resolver,
None,
)
.await;
assert_eq!(edges.len(), 1);
let edge_text = edges[0]
.properties
.get("edge_text")
.expect("edge_text property should be set");
assert_eq!(edge_text, "Alice founded Acme");
}
#[test]
fn cognify_config_creates_web_page_nodes_by_default() {
assert!(CognifyConfig::default().create_web_page_nodes);
assert!(
!CognifyConfig::default()
.with_web_page_nodes(false)
.create_web_page_nodes
);
}
#[tokio::test]
async fn create_web_page_nodes_creates_deterministic_page_site_and_edges() {
let graph = Arc::new(cognee_graph::MockGraphDB::new());
let doc_id = Uuid::parse_str("00000000-0000-0000-0000-000000000101").unwrap();
let chunk_id = Uuid::parse_str("00000000-0000-0000-0000-000000000201").unwrap();
let final_url = "https://Example.com/path?q=1";
let documents = vec![test_document_with_metadata(
doc_id,
Some(url_metadata(
"https://example.com/start",
final_url,
"Example title",
)),
)];
let chunks = vec![test_chunk(chunk_id, doc_id, "Visible page content")];
create_web_page_nodes(&documents, &chunks, graph.clone())
.await
.unwrap();
let page_id = web_page_id("https://example.com/path?q=1").to_string();
let site_id = web_site_id("example.com").to_string();
let (nodes, edges) = graph.get_graph_data().await.unwrap();
assert_eq!(nodes.len(), 2);
let page = graph.get_node(&page_id).await.unwrap().unwrap();
assert_eq!(page.get("type").and_then(|v| v.as_str()), Some("WebPage"));
assert_eq!(
page.get("url").and_then(|v| v.as_str()),
Some("https://example.com/path?q=1")
);
assert_eq!(
page.get("title").and_then(|v| v.as_str()),
Some("Example title")
);
assert_eq!(
page.get("content").and_then(|v| v.as_str()),
Some("Visible page content")
);
assert!(
!page.contains_key("created_at"),
"WebPage node payload should be deterministic"
);
let site = graph.get_node(&site_id).await.unwrap().unwrap();
assert_eq!(site.get("type").and_then(|v| v.as_str()), Some("WebSite"));
assert_eq!(
site.get("domain").and_then(|v| v.as_str()),
Some("example.com")
);
assert_eq!(edges.len(), 2);
assert!(edges.iter().any(|(source, target, rel, _)| {
source == &page_id && target == &site_id && rel == "PART_OF"
}));
assert!(edges.iter().any(|(source, target, rel, _)| {
source == &chunk_id.to_string() && target == &page_id && rel == "SOURCED_FROM"
}));
}
#[tokio::test]
async fn create_web_page_nodes_truncates_content_to_500_chars() {
let graph = Arc::new(cognee_graph::MockGraphDB::new());
let doc_id = Uuid::new_v4();
let long_text = "a".repeat(650);
let documents = vec![test_document_with_metadata(
doc_id,
Some(url_metadata(
"https://example.com/long",
"https://example.com/long",
"Long",
)),
)];
let chunks = vec![test_chunk(Uuid::new_v4(), doc_id, &long_text)];
create_web_page_nodes(&documents, &chunks, graph.clone())
.await
.unwrap();
let page_id = web_page_id("https://example.com/long").to_string();
let page = graph.get_node(&page_id).await.unwrap().unwrap();
assert_eq!(
page.get("content")
.and_then(|v| v.as_str())
.unwrap()
.chars()
.count(),
500
);
}
#[tokio::test]
async fn create_web_page_nodes_skips_invalid_and_non_url_metadata() {
let graph = Arc::new(cognee_graph::MockGraphDB::new());
let doc_with_invalid_json =
test_document_with_metadata(Uuid::new_v4(), Some("{not valid json".to_string()));
let non_url_doc = test_document_with_metadata(
Uuid::new_v4(),
Some(json!({"source": "dlt", "url": "https://example.com"}).to_string()),
);
let bad_url_doc = test_document_with_metadata(
Uuid::new_v4(),
Some(json!({"source": "url", "final_url": "not a url"}).to_string()),
);
let chunks = vec![
test_chunk(Uuid::new_v4(), doc_with_invalid_json.base.id, "a"),
test_chunk(Uuid::new_v4(), non_url_doc.base.id, "b"),
test_chunk(Uuid::new_v4(), bad_url_doc.base.id, "c"),
];
create_web_page_nodes(
&[doc_with_invalid_json, non_url_doc, bad_url_doc],
&chunks,
graph.clone(),
)
.await
.unwrap();
assert_eq!(graph.node_count(), 0);
assert_eq!(graph.edge_count(), 0);
}
#[tokio::test]
async fn create_web_page_nodes_is_idempotent_for_edges() {
let graph = Arc::new(cognee_graph::MockGraphDB::new());
let doc_id = Uuid::new_v4();
let documents = vec![test_document_with_metadata(
doc_id,
Some(url_metadata(
"https://example.com/idempotent",
"https://example.com/idempotent",
"Idempotent",
)),
)];
let chunks = vec![test_chunk(Uuid::new_v4(), doc_id, "content")];
create_web_page_nodes(&documents, &chunks, graph.clone())
.await
.unwrap();
create_web_page_nodes(&documents, &chunks, graph.clone())
.await
.unwrap();
assert_eq!(graph.node_count(), 2);
assert_eq!(graph.edge_count(), 2);
}
#[tokio::test]
async fn make_extract_graph_task_wires_web_page_nodes_and_respects_opt_out() {
    use cognee_ontology::NoOpOntologyResolver;
    use cognee_test_utils::{MockLlm, test_task_context};
    let url = "https://example.com/wired";
    let doc_id = Uuid::new_v4();
    let input = ExtractedChunks {
        chunks: vec![test_chunk(Uuid::new_v4(), doc_id, "content")],
        documents: vec![test_document_with_metadata(
            doc_id,
            Some(url_metadata(url, url, "Wired")),
        )],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };
    let (_, ctx, _) = test_task_context().await;
    // (config, expected node count, expected edge count). Each case gets a fresh graph.
    let cases = [
        (CognifyConfig::default(), 2, 2),
        (CognifyConfig::default().with_web_page_nodes(false), 0, 0),
    ];
    for (config, expected_nodes, expected_edges) in cases {
        let graph = Arc::new(cognee_graph::MockGraphDB::new());
        let task = make_extract_graph_task(
            Arc::new(MockLlm::empty()),
            graph.clone(),
            Arc::new(NoOpOntologyResolver::new()),
            config,
        );
        let TypedTask::Async(run) = task else {
            panic!("extract graph task should be async");
        };
        run(&input, ctx.clone()).await.unwrap();
        assert_eq!(graph.node_count(), expected_nodes);
        assert_eq!(graph.edge_count(), expected_edges);
    }
}
#[tokio::test]
async fn test_summarize_text_skips_dlt_chunks() {
use cognee_test_utils::MockLlm;
let doc_id_text = Uuid::new_v4();
let doc_id_dlt = Uuid::new_v4();
let mut base_text = DataPoint::new("TextDocument", None);
base_text.id = doc_id_text;
let text_doc = Document {
base: base_text,
document_type: "text".to_string(),
name: "test.txt".to_string(),
raw_data_location: "file:///tmp/test.txt".to_string(),
mime_type: "text/plain".to_string(),
extension: "txt".to_string(),
data_id: doc_id_text,
external_metadata: None,
};
let mut base_dlt = DataPoint::new("DltRowDocument", None);
base_dlt.id = doc_id_dlt;
let dlt_doc = Document {
base: base_dlt,
document_type: "dlt_row".to_string(),
name: "dlt_row.json".to_string(),
raw_data_location: "file:///tmp/dlt_row.json".to_string(),
mime_type: "application/json".to_string(),
extension: "json".to_string(),
data_id: doc_id_dlt,
external_metadata: None,
};
let text_chunk = DocumentChunk::new(
Uuid::new_v4(),
"Some meaningful text to summarize.".to_string(),
5,
0,
"paragraph_end".to_string(),
doc_id_text,
);
let dlt_chunk = DocumentChunk::new(
Uuid::new_v4(),
r#"{"id": 1, "name": "row"}"#.to_string(),
3,
0,
"paragraph_end".to_string(),
doc_id_dlt,
);
let input = ExtractedGraphData {
chunks: vec![text_chunk, dlt_chunk],
documents: vec![text_doc, dlt_doc],
entities: vec![],
edges: vec![],
dataset_id: Uuid::new_v4(),
user_id: None,
tenant_id: None,
};
let config = CognifyConfig::default().with_summarization(false);
let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());
let result = summarize_text(&input, llm, &config).await.unwrap();
assert!(result.summaries.is_empty());
assert_eq!(result.chunks.len(), 2);
}
/// Registers an `ImageLoader` (backed by a mock vision LLM) under "image" and checks
/// that extracting an image document neither fails with `UnsupportedDocumentType`
/// nor yields zero chunks.
#[cfg(feature = "image-loader")]
#[tokio::test]
async fn test_image_document_produces_chunks() {
    use cognee_ingestion::loaders::image::ImageLoader;
    use cognee_test_utils::MockLlm;
    let storage = Arc::new(MockStorage::new());
    let location = storage
        .store(b"fake-image-bytes", "test.jpg")
        .await
        .expect("MockStorage store should succeed");
    let doc_id = Uuid::new_v4();
    let mut base = DataPoint::new("ImageDocument", None);
    base.id = doc_id;
    base.set_metadata("index_fields", serde_json::json!(["name"]));
    let doc = Document {
        base,
        document_type: "image".to_string(),
        name: "test.jpg".to_string(),
        raw_data_location: location,
        mime_type: "image/jpeg".to_string(),
        extension: "jpg".to_string(),
        data_id: doc_id,
        external_metadata: None,
    };
    let input = ClassifiedDocuments {
        documents: vec![doc],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };
    let mock_llm = Arc::new(
        MockLlm::new(vec![])
            .with_vision_responses(vec!["An image description for testing.".to_string()]),
    );
    let mut registry = LoaderRegistry::default();
    registry.register("image", Arc::new(ImageLoader::new(mock_llm)));
    let result = extract_chunks_from_documents(
        &input,
        &*storage,
        100,
        TokenCounterKind::Word,
        None,
        &registry,
    )
    .await;
    assert!(
        !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
        "image document must not produce UnsupportedDocumentType"
    );
    let chunks = result.expect("extract_chunks_from_documents should succeed for image docs");
    assert!(
        !chunks.chunks.is_empty(),
        "image document should produce at least one chunk"
    );
}
/// Registers an `AudioLoader` backed by a mock transcriber under "audio" and checks
/// that extracting an audio document neither fails with `UnsupportedDocumentType`
/// nor yields zero chunks.
#[cfg(feature = "audio-loader")]
#[tokio::test]
async fn test_audio_document_produces_chunks() {
    use cognee_ingestion::loaders::audio::AudioLoader;
    use cognee_llm::TranscriptionOutput;
    use cognee_test_utils::MockTranscriber;
    let storage = Arc::new(MockStorage::new());
    let location = storage
        .store(b"fake-audio-bytes", "test.mp3")
        .await
        .expect("MockStorage store should succeed");
    let doc_id = Uuid::new_v4();
    let mut base = DataPoint::new("AudioDocument", None);
    base.id = doc_id;
    base.set_metadata("index_fields", serde_json::json!(["name"]));
    let doc = Document {
        base,
        document_type: "audio".to_string(),
        name: "test.mp3".to_string(),
        raw_data_location: location,
        mime_type: "audio/mpeg".to_string(),
        extension: "mp3".to_string(),
        data_id: doc_id,
        external_metadata: None,
    };
    let input = ClassifiedDocuments {
        documents: vec![doc],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };
    let mock_transcriber = Arc::new(MockTranscriber::new(
        "mock-whisper",
        vec![TranscriptionOutput {
            text: "Test transcript.".to_string(),
            language: None,
            duration: None,
        }],
    ));
    let mut registry = LoaderRegistry::default();
    registry.register("audio", Arc::new(AudioLoader::new(mock_transcriber)));
    let result = extract_chunks_from_documents(
        &input,
        &*storage,
        100,
        TokenCounterKind::Word,
        None,
        &registry,
    )
    .await;
    assert!(
        !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
        "audio document must not produce UnsupportedDocumentType"
    );
    let chunks = result.expect("extract_chunks_from_documents should succeed for audio docs");
    assert!(
        !chunks.chunks.is_empty(),
        "audio document should produce at least one chunk"
    );
}
#[test]
fn classify_html_extension_not_dropped() {
    // Both common HTML extensions must survive classification as a single "html" document.
    for ext in ["html", "htm"] {
        let input = CognifyInput {
            data_items: vec![Data::builder(
                Uuid::new_v4(),
                format!("page.{ext}"),
                format!("/storage/page.{ext}"),
                format!("file:///page.{ext}"),
                ext,
                "text/html",
                "hash_html",
                Uuid::new_v4(),
            )
            .build()],
            dataset_id: Uuid::new_v4(),
            user_id: None,
            tenant_id: None,
        };
        let classified = classify_documents(&input).expect("classify should not error");
        assert_eq!(
            classified.documents.len(),
            1,
            ".{ext} file must not be dropped by classify_documents"
        );
        let document = &classified.documents[0];
        assert_eq!(
            document.document_type, "html",
            ".{ext} must classify as document_type=\"html\""
        );
        assert_eq!(
            document.base.data_type, "TextDocument",
            ".{ext} must carry data_type=\"TextDocument\" for Python DB parity"
        );
    }
}
/// End-to-end check: a stored `.html` file survives `classify_documents`, is chunked
/// by the default `LoaderRegistry`, and its visible text appears in the chunks.
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn classify_then_chunk_html_end_to_end() {
    let storage = Arc::new(MockStorage::new());
    let html = b"<html><head><title>Guide</title></head><body><p>The quick brown fox.</p></body></html>";
    let location = storage
        .store(html, "guide.html")
        .await
        .expect("MockStorage store should succeed");
    let data = Data::builder(
        Uuid::new_v4(),
        "guide.html",
        &location,
        "file:///guide.html",
        "html",
        "text/html",
        "hash_guide_html",
        Uuid::new_v4(),
    )
    .build();
    let input = CognifyInput {
        data_items: vec![data],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };
    let classified =
        classify_documents(&input).expect("classify_documents must succeed for html");
    assert_eq!(
        classified.documents.len(),
        1,
        "classify_documents must not drop the .html file"
    );
    assert_eq!(classified.documents[0].document_type, "html");
    let registry = LoaderRegistry::default();
    let result = extract_chunks_from_documents(
        &classified,
        &*storage,
        100,
        TokenCounterKind::Word,
        None,
        &registry,
    )
    .await;
    assert!(
        !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
        "html loader must be registered (UnsupportedDocumentType must not occur)"
    );
    let chunks = result.expect("extract_chunks_from_documents must succeed for html");
    assert!(
        !chunks.chunks.is_empty(),
        "html file must produce at least one chunk"
    );
    assert!(
        chunks
            .chunks
            .iter()
            .any(|c| c.text.contains("quick brown fox")),
        "extracted text must appear in chunks (HTML tags must be stripped)"
    );
}
/// A pre-classified "html" document must be chunked by the default `LoaderRegistry`
/// without `UnsupportedDocumentType`, and its body text must appear in the chunks.
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_html_document_produces_chunks() {
    let storage = Arc::new(MockStorage::new());
    let html =
        b"<html><head><title>T</title></head><body><h1>Heading</h1><p>Body text here.</p></body></html>";
    let location = storage
        .store(html, "test.html")
        .await
        .expect("MockStorage store should succeed");
    let doc_id = Uuid::new_v4();
    let mut base = DataPoint::new("TextDocument", None);
    base.id = doc_id;
    base.set_metadata("index_fields", serde_json::json!(["name"]));
    let doc = Document {
        base,
        document_type: "html".to_string(),
        name: "test.html".to_string(),
        raw_data_location: location,
        mime_type: "text/html".to_string(),
        extension: "html".to_string(),
        data_id: doc_id,
        external_metadata: None,
    };
    let input = ClassifiedDocuments {
        documents: vec![doc],
        dataset_id: Uuid::new_v4(),
        user_id: None,
        tenant_id: None,
    };
    let registry = LoaderRegistry::default();
    let result = extract_chunks_from_documents(
        &input,
        &*storage,
        100,
        TokenCounterKind::Word,
        None,
        &registry,
    )
    .await;
    assert!(
        !matches!(result, Err(CognifyError::UnsupportedDocumentType(_))),
        "html document must not produce UnsupportedDocumentType"
    );
    let chunks = result.expect("extract_chunks_from_documents should succeed for html docs");
    assert!(
        !chunks.chunks.is_empty(),
        "html document should produce at least one chunk"
    );
    assert!(
        chunks.chunks.iter().any(|c| c.text.contains("Body text")),
        "extracted HTML text should appear in chunks"
    );
}
#[cfg(feature = "image-loader")]
#[test]
fn test_build_loader_registry_includes_image() {
    use cognee_test_utils::MockLlm;
    // With the image-loader feature compiled in, a default config must still yield "image".
    let config = CognifyConfig::default();
    let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());
    let has_image_loader = build_loader_registry(&llm, &config).get("image").is_some();
    assert!(
        has_image_loader,
        "build_loader_registry must include \"image\" loader when image-loader feature is on"
    );
}
#[cfg(feature = "audio-loader")]
#[test]
fn test_build_loader_registry_includes_audio_when_transcriber_set() {
    use cognee_llm::TranscriptionOutput;
    use cognee_test_utils::{MockLlm, MockTranscriber};
    // A single canned transcription is enough; the test only inspects registry contents.
    let canned = TranscriptionOutput {
        text: "hi".to_string(),
        language: None,
        duration: None,
    };
    let transcriber: Arc<dyn cognee_llm::Transcriber> =
        Arc::new(MockTranscriber::new("mock", vec![canned]));
    let llm: Arc<dyn Llm> = Arc::new(MockLlm::empty());
    let config = CognifyConfig::default().with_transcriber(transcriber);
    let has_audio_loader = build_loader_registry(&llm, &config).get("audio").is_some();
    assert!(
        has_audio_loader,
        "build_loader_registry must include \"audio\" loader when transcriber is set"
    );
}
/// Without a transcriber in the config, `build_loader_registry` must not register an
/// "audio" loader even when the audio-loader feature is compiled in.
#[cfg(feature = "audio-loader")]
#[test]
fn test_build_loader_registry_no_audio_without_transcriber() {
    let llm: Arc<dyn Llm> = Arc::new(cognee_test_utils::MockLlm::empty());
    let config = CognifyConfig::default();
    let registry = build_loader_registry(&llm, &config);
    assert!(
        registry.get("audio").is_none(),
        "build_loader_registry must NOT include \"audio\" loader when transcriber is None"
    );
}
}