use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tracing::{info, instrument};
use uuid::Uuid;
use cognee_core::CpuPool;
#[cfg(test)]
use cognee_core::RayonThreadPool;
use cognee_core::pipeline::DataIdFn;
use cognee_core::pipeline_run_registry::DbPipelineWatcher;
use cognee_core::task::Value;
use cognee_core::{Pipeline, PipelineBuilder, PipelineContext, TaskContextBuilder, TypedTask};
use cognee_database::{AclDb, DatabaseConnection, IngestDb, PipelineRunRepository};
use cognee_graph::GraphDBTrait;
use cognee_models::{Data, DataInput, DataPoint, Dataset, Document};
use cognee_storage::StorageTrait;
use cognee_vector::VectorDB;
use crate::content_hasher::HashAlgorithm;
use crate::id_generation::{generate_data_id, generate_dataset_id};
use crate::loader_registry::get_loader_name;
use crate::loaders::{LoaderOutput, LoaderRegistry};
use crate::url_resolver::UrlMetadata;
#[cfg(feature = "html-loader")]
use crate::url_resolver::resolve_url_input;
/// Optional knobs for [`AddPipeline::add_with_params`].
#[derive(Debug, Clone, Default)]
pub struct AddParams {
    /// Node-set names; serialized to a JSON array and stored on each newly created `Data`
    /// record (records that already exist are returned unchanged).
    pub node_set: Option<Vec<String>>,
    /// Existing dataset to ingest into. When set, the dataset is looked up by id (ingestion
    /// fails if it does not exist) instead of being resolved/created by name; it is also
    /// placed on the pipeline context.
    pub dataset_id: Option<Uuid>,
    /// NOTE(review): not read by `add_with_params` in this file — confirm whether it is
    /// consumed elsewhere.
    pub preferred_loaders: Option<HashMap<String, String>>,
    /// Importance weight stored on each newly created `Data` record.
    pub importance_weight: Option<f64>,
    /// NOTE(review): not read by `add_with_params` in this file — confirm whether it is
    /// consumed elsewhere.
    pub incremental_loading: bool,
}
/// Result of [`process_input`]: the payload has been stored and described, but not yet
/// persisted as a `Data` record.
#[derive(Debug, Clone)]
pub struct ProcessedInput {
    /// Hash of the full input bytes; also the basis of `data_id`.
    pub content_hash: String,
    /// Hash of the bytes actually stored: the extracted text when a loader ran, otherwise
    /// equal to `content_hash`.
    pub raw_content_hash: String,
    /// Identifier derived from `content_hash`, owner and tenant.
    pub data_id: Uuid,
    /// Location returned by the storage backend for the stored payload.
    pub storage_location: String,
    /// Label from a wrapping `DataInput::DataItem`, if any.
    pub label: Option<String>,
    /// Extension of the stored payload (`"txt"` once text has been extracted).
    pub stored_extension: String,
    /// MIME type of the stored payload (`"text/plain"` once text has been extracted).
    pub stored_mime_type: String,
    /// Extension of the original source.
    pub original_extension: String,
    /// MIME type of the original source.
    pub original_mime_type: String,
    /// Name of the loader that produced (or will read) the stored payload.
    pub loader_engine: String,
    /// Size in bytes of the input as read.
    pub data_size: i64,
    /// Human-readable name for the data record.
    pub name: String,
    /// Stored payload location as produced by `storage_location_to_uri`: a `file://` URI when
    /// the storage has a non-empty base path, otherwise the raw storage location.
    pub raw_data_uri: String,
    /// Where the data came from (stored raw HTML copy, file path, URL, ...).
    pub original_location: String,
    /// URI of a stored verbatim copy of fetched HTML/XHTML, when one was kept.
    pub raw_source_uri: Option<String>,
    /// Owner recorded on the data record.
    pub owner_id: Uuid,
    /// Tenant recorded on the data record, if any.
    pub tenant_id: Option<Uuid>,
    /// JSON string of caller and/or URL provenance metadata.
    pub external_metadata: Option<String>,
    /// JSON-encoded node set; `None` from `process_input`, filled in by the persist task.
    pub node_set: Option<String>,
    /// Importance weight; `None` from `process_input`, filled in by the persist task.
    pub importance_weight: Option<f64>,
}
/// Reads a single [`DataInput`], stores its payload, and describes it as a [`ProcessedInput`].
///
/// Steps, in order:
/// 1. Rejects S3 inputs (bare or wrapped in a `DataItem`) with
///    [`IngestionError::S3IngestionUnavailable`] before doing any other work.
/// 2. Resolves URL inputs (bare or wrapped in a `DataItem`) through the URL resolver. This
///    requires the `html-loader` feature; without it,
///    [`IngestionError::UrlIngestionUnavailable`] is returned.
/// 3. For HTML/XHTML URLs, stores the fetched raw bytes as a separate copy and records its URI
///    in `raw_source_uri`, which is also used as `original_location`.
/// 4. Reads the effective input fully into memory and hashes it with `hash_algorithm`. It then
///    derives `data_id` from that hash plus `owner_id`/`tenant_id`.
/// 5. Stores the payload:
///    - URL inputs: the resolved payload is stored as-is under the resolver's extension.
///    - Other inputs whose document type is `"text"` (including unknown extensions) but whose
///      bytes are not valid UTF-8: stored verbatim as `text_<hash>.bin`.
///    - Everything else: text is extracted with the registered loader and stored as
///      `text_<hash>.txt`. The stored extension/MIME become `txt`/`text/plain`, and
///      `raw_content_hash` is the hash of the extracted text.
///
/// `node_set` and `importance_weight` are always `None` here; the persist task fills them in.
///
/// # Errors
/// Returns an error in these cases:
/// - S3 inputs, or URL inputs without the `html-loader` feature.
/// - URL resolution, chunk-reading, storage, or loader-extraction failures.
/// - A document type with no registered loader ([`IngestionError::UnsupportedDocumentType`]).
/// - Metadata serialization failures.
#[instrument(name = "ingestion.process_input", skip(input, storage))]
pub async fn process_input(
    input: &DataInput,
    storage: &dyn StorageTrait,
    hash_algorithm: HashAlgorithm,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
) -> Result<ProcessedInput, Box<dyn std::error::Error>> {
    use tokio::sync::Mutex;

    // Fail fast: S3 is unsupported, so there is no point resolving or reading anything.
    if matches!(unwrap_data_item(input), DataInput::S3Path(_)) {
        return Err(Box::new(IngestionError::S3IngestionUnavailable));
    }

    let (effective_input, resolved_url_metadata, resolved_label, data_item_metadata) =
        resolve_input_for_processing(input).await?;

    // URL inputs take their file metadata from the resolver; everything else derives it from
    // the input itself. `stored_*` and `loader_engine` may be overwritten once text is extracted.
    let (
        file_name,
        mut stored_extension,
        mut stored_mime_type,
        original_extension,
        original_mime_type,
        label,
        mut loader_engine,
    ) = if let Some(metadata) = resolved_url_metadata.as_ref() {
        (
            format!("text_placeholder.{}", metadata.stored_extension),
            metadata.stored_extension.clone(),
            metadata.stored_mime_type.clone(),
            metadata.source_extension.clone(),
            metadata.source_mime_type.clone(),
            resolved_label,
            metadata.loader_engine.clone(),
        )
    } else {
        let (fname, ext, mime, lbl) = extract_file_metadata(input);
        let loader = get_loader_name(&ext).to_string();
        (fname, ext.clone(), mime.clone(), ext, mime, lbl, loader)
    };

    // HTML pages keep a verbatim copy of the fetched markup next to the extracted text.
    let raw_source_uri = if let Some(metadata) = resolved_url_metadata.as_ref()
        && is_html_url_metadata(metadata)
    {
        let raw_file_name = format!("source_placeholder.{}", metadata.source_extension);
        let raw_location = storage.store(&metadata.raw_bytes, &raw_file_name).await?;
        Some(storage_location_to_uri(storage.base_path(), &raw_location))
    } else {
        None
    };

    // Buffer the whole input: both the hasher and the loaders need the complete payload.
    let buffer: Arc<Mutex<Vec<u8>>> = Arc::new(Mutex::new(Vec::new()));
    let sink = Arc::clone(&buffer);
    effective_input
        .process_by_chunks(move |chunk| {
            let target = Arc::clone(&sink);
            let chunk_owned = chunk.to_vec();
            async move {
                target.lock().await.extend_from_slice(&chunk_owned);
                Ok::<(), Box<dyn std::error::Error>>(())
            }
        })
        .await?;
    let collected = Arc::try_unwrap(buffer)
        .map_err(|_| "Failed to unwrap bytes")?
        .into_inner();
    // The sum of all chunk lengths is exactly the buffered length; no separate counter needed.
    let data_size = i64::try_from(collected.len())?;

    let content_hash =
        crate::content_hasher::ContentHasher::hash_content(&collected, hash_algorithm);
    let data_id = generate_data_id(&content_hash, owner_id, tenant_id);
    let name = extract_name(input, &content_hash);

    let (storage_location, raw_content_hash) = if resolved_url_metadata.is_some() {
        let location = storage.store(&collected, &file_name).await?;
        (location, content_hash.clone())
    } else {
        let doc_type = cognee_models::doc_type_for_extension(&original_extension)
            .unwrap_or("text")
            .to_string();
        if doc_type == "text" && std::str::from_utf8(&collected).is_err() {
            // Undecodable "text" is kept verbatim rather than run through a text loader.
            let stored_name = format!("text_{content_hash}.bin");
            let location = storage.store(&collected, &stored_name).await?;
            (location, content_hash.clone())
        } else {
            let registry = LoaderRegistry::default_registry();
            let loader = registry.get(&doc_type).ok_or_else(|| {
                Box::new(IngestionError::UnsupportedDocumentType {
                    document_type: doc_type.clone(),
                }) as Box<dyn std::error::Error>
            })?;
            let descriptor = build_loader_descriptor(
                data_id,
                &name,
                &original_extension,
                &original_mime_type,
            );
            let extracted_text = match loader.extract(&collected, &descriptor).await? {
                LoaderOutput::Text(t) => t,
                LoaderOutput::Rows(rows) => rows.join("\n\n"),
                LoaderOutput::SingleChunk { text, .. } => text,
            };
            let extracted_bytes = extracted_text.into_bytes();
            let stored_name = format!("text_{content_hash}.txt");
            let location = storage.store(&extracted_bytes, &stored_name).await?;
            stored_extension = "txt".to_string();
            stored_mime_type = "text/plain".to_string();
            loader_engine = loader.engine_name().to_string();
            let extracted_hash = crate::content_hasher::ContentHasher::hash_content(
                &extracted_bytes,
                hash_algorithm,
            );
            (location, extracted_hash)
        }
    };

    let raw_data_uri = storage_location_to_uri(storage.base_path(), &storage_location);
    // Prefer the stored raw HTML copy; bare inline text points at its own stored payload.
    let original_location = match &raw_source_uri {
        Some(uri) => uri.clone(),
        None if matches!(input, DataInput::Text(_)) => raw_data_uri.clone(),
        None => extract_original_location(input),
    };
    let external_metadata =
        merge_external_metadata(data_item_metadata, resolved_url_metadata.as_ref())?;

    Ok(ProcessedInput {
        content_hash,
        raw_content_hash,
        data_id,
        storage_location,
        label,
        stored_extension,
        stored_mime_type,
        original_extension,
        original_mime_type,
        loader_engine,
        data_size,
        name,
        raw_data_uri,
        original_location,
        raw_source_uri,
        owner_id,
        tenant_id,
        external_metadata,
        node_set: None,
        importance_weight: None,
    })
}
#[instrument(
name = "ingestion.persist_data",
skip(processed, database),
fields(data_id = %processed.data_id)
)]
pub async fn persist_data(
processed: &ProcessedInput,
database: &dyn IngestDb,
dataset_name: &str,
owner_id: Uuid,
tenant_id: Option<Uuid>,
) -> Result<Data, Box<dyn std::error::Error>> {
persist_data_with_acl(
processed,
database,
dataset_name,
owner_id,
tenant_id,
None,
None,
)
.await
}
/// Persists `processed` as a `Data` record and links it to a dataset.
///
/// The dataset is chosen as follows:
/// - With `target_dataset_id`, that dataset is fetched, and an error is returned if it does not
///   exist.
/// - Otherwise the dataset is looked up by `dataset_name` for `owner_id`/`tenant_id`. If it is
///   missing, it is created with the id from `generate_dataset_id`.
///
/// When a dataset was just created and `acl_db` is supplied, all permissions on it are granted
/// to `owner_id`.
///
/// If a `Data` record with `processed.data_id` already exists, it is attached to the dataset and
/// returned unchanged (node set / importance weight are not applied). Otherwise a new record
/// is built from `processed`, created, and attached.
///
/// NOTE(review): the new record's owner is `processed.owner_id`, while dataset resolution uses
/// the `owner_id` argument.
///
/// # Errors
/// Returns an error when the target dataset id does not exist, or when any database/ACL
/// operation fails.
#[instrument(
    name = "ingestion.persist_data_with_acl",
    skip(processed, database, acl_db),
    fields(data_id = %processed.data_id)
)]
pub async fn persist_data_with_acl(
    processed: &ProcessedInput,
    database: &dyn IngestDb,
    dataset_name: &str,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
    acl_db: Option<&dyn AclDb>,
    target_dataset_id: Option<Uuid>,
) -> Result<Data, Box<dyn std::error::Error>> {
    let (dataset, created) = match target_dataset_id {
        Some(ds_id) => {
            let Some(found) = database.get_dataset(ds_id).await? else {
                return Err(format!("Dataset with id {ds_id} not found").into());
            };
            (found, false)
        }
        None => {
            let generated_id = generate_dataset_id(dataset_name, owner_id, tenant_id);
            let lookup = database
                .get_dataset_by_name(dataset_name, owner_id, tenant_id)
                .await?;
            match lookup {
                Some(found) => (found, false),
                None => {
                    let fresh =
                        Dataset::new(dataset_name.to_string(), owner_id, tenant_id, generated_id);
                    (database.create_dataset(fresh).await?, true)
                }
            }
        }
    };
    info!(dataset_id = %dataset.id, "dataset resolved");

    // Only a dataset created by this call gets fresh permission grants.
    if let (true, Some(acl)) = (created, acl_db) {
        cognee_database::ops::acl::grant_all_permissions_on_dataset_via_trait(
            acl, owner_id, dataset.id,
        )
        .await?;
        info!(
            dataset_id = %dataset.id,
            owner_id = %owner_id,
            "ACL permissions granted on new dataset"
        );
    }

    let data_id = processed.data_id;
    // Content-addressed dedup: an existing record is just linked to this dataset.
    if let Some(already_stored) = database.get_data(data_id).await? {
        database.attach_data_to_dataset(dataset.id, data_id).await?;
        info!(data_id = %data_id, is_duplicate = true, "input processed");
        return Ok(already_stored);
    }

    let mut builder = Data::builder(
        data_id,
        processed.name.clone(),
        processed.raw_data_uri.clone(),
        processed.original_location.clone(),
        processed.stored_extension.clone(),
        processed.stored_mime_type.clone(),
        processed.content_hash.clone(),
        processed.owner_id,
    )
    .original_extension(processed.original_extension.clone())
    .original_mime_type(processed.original_mime_type.clone())
    .loader_engine(processed.loader_engine.clone())
    .raw_content_hash(processed.raw_content_hash.clone())
    .data_size(processed.data_size);
    if let Some(tenant) = processed.tenant_id {
        builder = builder.tenant_id(tenant);
    }
    if let Some(label) = &processed.label {
        builder = builder.label(label.clone());
    }
    if let Some(metadata) = &processed.external_metadata {
        builder = builder.external_metadata(metadata.clone());
    }
    if let Some(node_set) = &processed.node_set {
        builder = builder.node_set(node_set.clone());
    }
    if let Some(weight) = processed.importance_weight {
        builder = builder.importance_weight(weight);
    }

    let saved = database.create_data(builder.build()).await?;
    database.attach_data_to_dataset(dataset.id, data_id).await?;
    info!(data_id = %data_id, is_duplicate = false, "input processed");
    Ok(saved)
}
/// Picks the MIME type for a file with `extension`, located at `path_for_guess`.
///
/// Extensions routed to the `"text_loader"` are always reported as `text/plain`. Anything else
/// is guessed from the path, falling back to `application/octet-stream`.
fn resolve_mime(extension: &str, path_for_guess: &str) -> String {
    let handled_as_plain_text = get_loader_name(extension) == "text_loader";
    if !handled_as_plain_text {
        let guess = mime_guess::from_path(path_for_guess);
        return guess.first_or_octet_stream().to_string();
    }
    "text/plain".to_string()
}
/// Peels off every `DataInput::DataItem` wrapper and returns the innermost input.
fn unwrap_data_item(input: &DataInput) -> &DataInput {
    let mut current = input;
    while let DataInput::DataItem { data, .. } = current {
        current = &**data;
    }
    current
}
/// Builds the `Document` descriptor handed to a loader's `extract` call.
///
/// The document type comes from `extension` (defaulting to `"text"` when unknown). The
/// descriptor's base id and `data_id` are both `data_id`. It has no raw data location and no
/// external metadata.
fn build_loader_descriptor(
    data_id: Uuid,
    name: &str,
    extension: &str,
    mime_type: &str,
) -> Document {
    let document_type = cognee_models::doc_type_for_extension(extension)
        .unwrap_or("text")
        .to_owned();

    let mut base = DataPoint::new("Document", None);
    base.id = data_id;

    Document {
        base,
        document_type,
        name: name.to_owned(),
        raw_data_location: String::new(),
        mime_type: mime_type.to_owned(),
        extension: extension.to_owned(),
        data_id,
        external_metadata: None,
    }
}
/// Derives `(file_name, extension, mime_type, label)` for a non-URL-resolved input.
///
/// - `FilePath`: the `file://` prefix is stripped. The name falls back to `"file.bin"` and the
///   extension to `""`.
/// - `Binary`: the extension comes from the name and falls back to `"bin"`.
/// - `Text`, `Url`, `S3Path`: fixed placeholder values.
/// - `DataItem`: metadata comes from the wrapped input, with the wrapper's label attached.
fn extract_file_metadata(input: &DataInput) -> (String, String, String, Option<String>) {
    match input {
        DataInput::DataItem { data, label, .. } => {
            // The wrapper contributes only its label; everything else comes from the payload.
            let (inner_name, inner_ext, inner_mime, _) = extract_file_metadata(data);
            (inner_name, inner_ext, inner_mime, Some(label.clone()))
        }
        DataInput::FilePath(path) => {
            let local = path.strip_prefix("file://").unwrap_or(path);
            let as_path = Path::new(local);
            let file_name = match as_path.file_name().and_then(|n| n.to_str()) {
                Some(found) => found.to_owned(),
                None => "file.bin".to_owned(),
            };
            let extension = as_path
                .extension()
                .and_then(|e| e.to_str())
                .unwrap_or_default()
                .to_owned();
            let mime = resolve_mime(&extension, local);
            (file_name, extension, mime, None)
        }
        DataInput::Binary { name, .. } => {
            let extension = Path::new(name)
                .extension()
                .and_then(|e| e.to_str())
                .unwrap_or("bin")
                .to_owned();
            let mime = resolve_mime(&extension, name);
            (name.clone(), extension, mime, None)
        }
        DataInput::Text(_) => (
            "text_placeholder.txt".to_owned(),
            "txt".to_owned(),
            "text/plain".to_owned(),
            None,
        ),
        DataInput::Url(_) => (
            "text_placeholder.txt".to_owned(),
            "html".to_owned(),
            "text/html".to_owned(),
            None,
        ),
        DataInput::S3Path(_) => (
            "s3_file.bin".to_owned(),
            "bin".to_owned(),
            "application/octet-stream".to_owned(),
            None,
        ),
    }
}
/// Converts a storage-relative `location` into a URI.
///
/// With an empty `base_path`, `location` is returned unchanged. Otherwise the location is
/// joined onto the base (an absolute `location` replaces it), made absolute against the current
/// directory if needed, and prefixed with `file://`.
///
/// NOTE(review): the path is not percent-encoded. If the current directory cannot be read, the
/// result is `file://` followed by a relative path.
fn storage_location_to_uri(base_path: &str, location: &str) -> String {
    if base_path.is_empty() {
        return location.to_string();
    }
    let mut resolved = Path::new(base_path).join(location);
    if !resolved.is_absolute() {
        resolved = std::env::current_dir()
            .unwrap_or_default()
            .join(resolved);
    }
    format!("file://{}", resolved.display())
}
/// Resolves the input that will actually be read, plus URL metadata, label and user metadata.
///
/// Returns `(effective_input, url_metadata, label, data_item_metadata)`:
/// - A `Url`, bare or directly wrapped in a `DataItem`, is fetched via the URL resolver (needs
///   the `html-loader` feature; otherwise `IngestionError::UrlIngestionUnavailable`). The
///   resolved input and its metadata are returned.
/// - Any other input is returned as a clone of `input` itself (including the `DataItem`
///   wrapper) with no URL metadata.
///
/// The label and user metadata are taken from a top-level `DataItem` wrapper, if present.
async fn resolve_input_for_processing(
    input: &DataInput,
) -> Result<
    (
        DataInput,
        Option<UrlMetadata>,
        Option<String>,
        Option<String>,
    ),
    Box<dyn std::error::Error>,
> {
    // Peel at most one DataItem wrapper, remembering its label and metadata.
    let (payload, label, user_metadata): (&DataInput, Option<String>, Option<String>) =
        match input {
            DataInput::DataItem {
                data,
                label,
                external_metadata,
            } => (&**data, Some(label.clone()), external_metadata.clone()),
            other => (other, None, None),
        };

    let DataInput::Url(url) = payload else {
        return Ok((input.clone(), None, label, user_metadata));
    };

    #[cfg(feature = "html-loader")]
    {
        let resolved = resolve_url_input(url).await?;
        Ok((resolved.input, Some(resolved.metadata), label, user_metadata))
    }
    #[cfg(not(feature = "html-loader"))]
    {
        let _ = (url, label, user_metadata);
        Err(Box::new(IngestionError::UrlIngestionUnavailable))
    }
}
/// True when the fetched resource's MIME essence is HTML or XHTML.
fn is_html_url_metadata(metadata: &UrlMetadata) -> bool {
    const HTML_ESSENCES: [&str; 2] = ["text/html", "application/xhtml+xml"];
    HTML_ESSENCES
        .iter()
        .any(|essence| metadata.essence == *essence)
}
/// Builds the `external_metadata` JSON string stored on a `Data` record.
///
/// Without URL metadata, `data_item_metadata` is returned unchanged. With URL metadata, the
/// result is a JSON object containing:
/// - The keys of `data_item_metadata` if it parses as a JSON object. Otherwise the original
///   string is kept under `"data_item_external_metadata_raw"`.
/// - `source` (always `"url"`), `url`, `final_url`, `content_type`, and `title` when the
///   resolver found one. These overwrite user keys of the same name.
/// - `data_item_external_metadata`: an untouched copy of the user object. It is added only when
///   one of the URL fields overwrote a user key, so no caller data is lost.
///
/// # Errors
/// Propagates `serde_json` serialization failures.
fn merge_external_metadata(
    data_item_metadata: Option<String>,
    url_metadata: Option<&UrlMetadata>,
) -> Result<Option<String>, serde_json::Error> {
    let Some(metadata) = url_metadata else {
        return Ok(data_item_metadata);
    };

    let mut merged = serde_json::Map::new();
    let mut user_object: Option<serde_json::Map<String, serde_json::Value>> = None;
    if let Some(user_metadata) = data_item_metadata {
        match serde_json::from_str::<serde_json::Value>(&user_metadata) {
            Ok(serde_json::Value::Object(object)) => user_object = Some(object),
            _ => {
                merged.insert(
                    "data_item_external_metadata_raw".to_string(),
                    serde_json::Value::String(user_metadata),
                );
            }
        }
    }

    // `json!` serializes by reference, so no clones of the metadata fields are needed.
    let mut url_fields = vec![
        ("source", serde_json::json!("url")),
        ("url", serde_json::json!(metadata.requested_url)),
        ("final_url", serde_json::json!(metadata.final_url)),
        ("content_type", serde_json::json!(metadata.content_type)),
    ];
    if let Some(title) = &metadata.title {
        url_fields.push(("title", serde_json::json!(title)));
    }

    // Copy the user object only if a URL field is about to overwrite one of its keys.
    let preserved_user_object = user_object
        .as_ref()
        .filter(|object| url_fields.iter().any(|(key, _)| object.contains_key(*key)))
        .cloned();

    if let Some(object) = user_object {
        merged.extend(object);
    }
    for (key, value) in url_fields {
        merged.insert(key.to_string(), value);
    }
    if let Some(object) = preserved_user_object {
        merged.insert(
            "data_item_external_metadata".to_string(),
            serde_json::Value::Object(object),
        );
    }
    serde_json::to_string(&serde_json::Value::Object(merged)).map(Some)
}
/// Derives the human-readable name of a data record.
///
/// - `Text` / `Url`: `text_<content_hash>`.
/// - `FilePath`: the file stem, after stripping a `file://` prefix (`"unknown"` if none).
/// - `S3Path`: the last non-empty `/`-separated segment. A path that is empty or ends in `/`
///   falls back to `"s3_content"`.
/// - `Binary`: its name, unchanged.
/// - `DataItem`: the name of the wrapped input.
fn extract_name(input: &DataInput, content_hash: &str) -> String {
    match input {
        DataInput::Text(_) | DataInput::Url(_) => format!("text_{content_hash}"),
        DataInput::FilePath(path) => {
            let clean_path = path.strip_prefix("file://").unwrap_or(path);
            Path::new(clean_path)
                .file_stem()
                .and_then(|n| n.to_str())
                .unwrap_or("unknown")
                .to_string()
        }
        // `split` always yields at least one (possibly empty) segment, so the fallback must be
        // triggered by filtering out the empty case explicitly.
        DataInput::S3Path(path) => path
            .rsplit('/')
            .next()
            .filter(|segment| !segment.is_empty())
            .unwrap_or("s3_content")
            .to_string(),
        DataInput::Binary { name, .. } => name.clone(),
        DataInput::DataItem { data, .. } => extract_name(data, content_hash),
    }
}
/// Describes where an input originally came from, as a URI-like string.
///
/// File paths gain a `file://` prefix when missing. Inline text is `text://inline`, binaries
/// are `binary://<name>`, and URLs / S3 paths are returned as given. `DataItem` wrappers
/// delegate to their payload.
fn extract_original_location(input: &DataInput) -> String {
    match input {
        DataInput::DataItem { data, .. } => extract_original_location(data),
        DataInput::Text(_) => String::from("text://inline"),
        DataInput::FilePath(path) if path.starts_with("file://") => path.clone(),
        DataInput::FilePath(path) => format!("file://{path}"),
        DataInput::Url(url) => url.clone(),
        DataInput::S3Path(s3_path) => s3_path.clone(),
        DataInput::Binary { name, .. } => format!("binary://{name}"),
    }
}
/// Wraps [`process_input`] as a pipeline task mapping `DataInput` to `ProcessedInput`.
///
/// Errors are flattened to their display string before being handed back to the task runtime.
pub fn make_process_input_task(
    storage: Arc<dyn StorageTrait>,
    hash_algorithm: HashAlgorithm,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
) -> TypedTask<DataInput, ProcessedInput> {
    TypedTask::async_fn(move |input: &DataInput, _ctx| {
        let owned_input = input.clone();
        let storage_handle = Arc::clone(&storage);
        Box::pin(async move {
            let outcome = process_input(
                &owned_input,
                &*storage_handle,
                hash_algorithm,
                owner_id,
                tenant_id,
            )
            .await;
            outcome.map(Box::new).map_err(|e| e.to_string().into())
        })
    })
}
pub fn make_persist_data_task(
database: Arc<dyn IngestDb>,
dataset_name: String,
owner_id: Uuid,
tenant_id: Option<Uuid>,
) -> TypedTask<ProcessedInput, Data> {
make_persist_data_task_with_acl(database, dataset_name, owner_id, tenant_id, None)
}
/// Builds the persist task with an optional ACL backend and no injected add parameters.
pub fn make_persist_data_task_with_acl(
    database: Arc<dyn IngestDb>,
    dataset_name: String,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
    acl_db: Option<Arc<dyn AclDb>>,
) -> TypedTask<ProcessedInput, Data> {
    let no_overrides = AddParamsInjection::default();
    make_persist_data_task_with_acl_and_params(
        database,
        dataset_name,
        owner_id,
        tenant_id,
        acl_db,
        no_overrides,
    )
}
/// Per-call overrides injected into the persist task from `AddParams`.
#[derive(Debug, Clone, Default)]
struct AddParamsInjection {
    /// JSON-encoded node set; when `Some`, replaces `ProcessedInput::node_set`.
    node_set_json: Option<String>,
    /// When `Some`, replaces `ProcessedInput::importance_weight`.
    importance_weight: Option<f64>,
    /// When `Some`, data is persisted into this existing dataset instead of one resolved by name.
    target_dataset_id: Option<Uuid>,
}
fn make_persist_data_task_with_acl_and_params(
database: Arc<dyn IngestDb>,
dataset_name: String,
owner_id: Uuid,
tenant_id: Option<Uuid>,
acl_db: Option<Arc<dyn AclDb>>,
add_params: AddParamsInjection,
) -> TypedTask<ProcessedInput, Data> {
TypedTask::async_fn(move |processed: &ProcessedInput, _ctx| {
let mut processed = processed.clone();
if let Some(ref ns) = add_params.node_set_json {
processed.node_set = Some(ns.clone());
}
if let Some(w) = add_params.importance_weight {
processed.importance_weight = Some(w);
}
let override_ds = add_params.target_dataset_id;
let database = Arc::clone(&database);
let dataset_name = dataset_name.clone();
let acl_db = acl_db.clone();
Box::pin(async move {
persist_data_with_acl(
&processed,
&*database,
&dataset_name,
owner_id,
tenant_id,
acl_db.as_deref(),
override_ds,
)
.await
.map(Box::new)
.map_err(|e| format!("{e}").into())
})
})
}
pub fn build_add_pipeline(
storage: Arc<dyn StorageTrait>,
database: Arc<dyn IngestDb>,
hash_algorithm: HashAlgorithm,
dataset_name: &str,
owner_id: Uuid,
tenant_id: Option<Uuid>,
) -> Pipeline {
build_add_pipeline_with_acl(
storage,
database,
hash_algorithm,
dataset_name,
owner_id,
tenant_id,
None,
)
}
/// Builds the two-stage add pipeline with an optional ACL backend and no add-parameter
/// overrides.
pub fn build_add_pipeline_with_acl(
    storage: Arc<dyn StorageTrait>,
    database: Arc<dyn IngestDb>,
    hash_algorithm: HashAlgorithm,
    dataset_name: &str,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
    acl_db: Option<Arc<dyn AclDb>>,
) -> Pipeline {
    let no_overrides = AddParamsInjection::default();
    build_add_pipeline_internal(
        storage,
        database,
        hash_algorithm,
        dataset_name,
        owner_id,
        tenant_id,
        acl_db,
        no_overrides,
    )
}
/// Assembles the `"ingestion.add"` pipeline: a process-input task followed by a persist task.
///
/// The pipeline is named `"add_pipeline"`, and its data-id function always returns `None`.
// Eight plain configuration values are threaded straight through to the two task builders.
#[allow(clippy::too_many_arguments)]
fn build_add_pipeline_internal(
    storage: Arc<dyn StorageTrait>,
    database: Arc<dyn IngestDb>,
    hash_algorithm: HashAlgorithm,
    dataset_name: &str,
    owner_id: Uuid,
    tenant_id: Option<Uuid>,
    acl_db: Option<Arc<dyn AclDb>>,
    add_params: AddParamsInjection,
) -> Pipeline {
    let no_data_id: DataIdFn = Arc::new(|_value: Arc<dyn Value>| None);
    let process_task = make_process_input_task(storage, hash_algorithm, owner_id, tenant_id);
    let persist_task = make_persist_data_task_with_acl_and_params(
        database,
        dataset_name.to_string(),
        owner_id,
        tenant_id,
        acl_db,
        add_params,
    );
    PipelineBuilder::new_with_task("ingestion.add", process_task)
        .add_task(persist_task)
        .with_name("add_pipeline")
        .with_data_id(no_data_id)
        .build()
}
/// High-level entry point for ingesting `DataInput`s into a dataset.
///
/// Construct with `AddPipeline::new` or `AddPipeline::new_with_algorithm`, then attach backends
/// with the `with_*` methods. `add` / `add_with_params` fail with
/// `IngestionError::MissingBackend` unless the thread pool, graph DB, vector DB and database
/// connection are all set.
pub struct AddPipeline {
    /// Storage for ingested payloads (and raw HTML copies).
    storage: Arc<dyn StorageTrait>,
    /// Database used to resolve datasets and persist `Data` records.
    database: Arc<dyn IngestDb>,
    /// Hash algorithm used for content hashes.
    hash_algorithm: HashAlgorithm,
    /// Optional ACL backend; when set, newly created datasets get owner permissions.
    acl_db: Option<Arc<dyn AclDb>>,
    /// Required for `add`: CPU pool for the task context.
    thread_pool: Option<Arc<dyn CpuPool>>,
    /// Required for `add`: graph backend for the task context.
    graph_db: Option<Arc<dyn GraphDBTrait>>,
    /// Required for `add`: vector backend for the task context.
    vector_db: Option<Arc<dyn VectorDB>>,
    /// Required for `add`: database connection for the task context.
    db_connection: Option<Arc<DatabaseConnection>>,
    /// Optional pipeline-run repository; a no-op repository is used when absent.
    pipeline_run_repo: Option<Arc<dyn PipelineRunRepository>>,
}
impl AddPipeline {
pub fn new(storage: Arc<dyn StorageTrait>, database: Arc<dyn IngestDb>) -> Self {
Self {
storage,
database,
hash_algorithm: HashAlgorithm::default(),
acl_db: None,
thread_pool: None,
graph_db: None,
vector_db: None,
db_connection: None,
pipeline_run_repo: None,
}
}
pub fn new_with_algorithm(
storage: Arc<dyn StorageTrait>,
database: Arc<dyn IngestDb>,
hash_algorithm: HashAlgorithm,
) -> Self {
Self {
storage,
database,
hash_algorithm,
acl_db: None,
thread_pool: None,
graph_db: None,
vector_db: None,
db_connection: None,
pipeline_run_repo: None,
}
}
pub fn with_acl_db(mut self, acl_db: Arc<dyn AclDb>) -> Self {
self.acl_db = Some(acl_db);
self
}
pub fn with_thread_pool(mut self, pool: Arc<dyn CpuPool>) -> Self {
self.thread_pool = Some(pool);
self
}
pub fn with_graph_db(mut self, graph: Arc<dyn GraphDBTrait>) -> Self {
self.graph_db = Some(graph);
self
}
pub fn with_vector_db(mut self, vectors: Arc<dyn VectorDB>) -> Self {
self.vector_db = Some(vectors);
self
}
pub fn with_database(mut self, db: Arc<DatabaseConnection>) -> Self {
self.db_connection = Some(db);
self
}
pub fn with_pipeline_run_repo(mut self, repo: Arc<dyn PipelineRunRepository>) -> Self {
self.pipeline_run_repo = Some(repo);
self
}
#[instrument(
name = "ingestion.add",
skip(self, inputs),
fields(dataset_name, owner_id = %owner_id, inputs_count = inputs.len())
)]
pub async fn add(
&self,
inputs: Vec<DataInput>,
dataset_name: &str,
owner_id: Uuid,
tenant_id: Option<Uuid>,
) -> Result<Vec<Data>, Box<dyn std::error::Error>> {
self.add_with_params(
inputs,
dataset_name,
owner_id,
tenant_id,
&AddParams::default(),
)
.await
}
#[instrument(
name = "ingestion.add_with_params",
skip(self, inputs, params),
fields(dataset_name, owner_id = %owner_id, inputs_count = inputs.len())
)]
pub async fn add_with_params(
&self,
inputs: Vec<DataInput>,
dataset_name: &str,
owner_id: Uuid,
tenant_id: Option<Uuid>,
params: &AddParams,
) -> Result<Vec<Data>, Box<dyn std::error::Error>> {
let thread_pool = self
.thread_pool
.clone()
.ok_or(IngestionError::MissingBackend {
which: "thread_pool",
})?;
let graph_db = self
.graph_db
.clone()
.ok_or(IngestionError::MissingBackend { which: "graph_db" })?;
let vector_db = self
.vector_db
.clone()
.ok_or(IngestionError::MissingBackend { which: "vector_db" })?;
let db_connection = self
.db_connection
.clone()
.ok_or(IngestionError::MissingBackend { which: "database" })?;
let node_set_json = params
.node_set
.as_ref()
.map(serde_json::to_string)
.transpose()
.map_err(|e| format!("Failed to serialize node_set: {e}"))?;
let add_params_inj = AddParamsInjection {
node_set_json,
importance_weight: params.importance_weight,
target_dataset_id: params.dataset_id,
};
let pipeline = build_add_pipeline_internal(
Arc::clone(&self.storage),
Arc::clone(&self.database),
self.hash_algorithm,
dataset_name,
owner_id,
tenant_id,
self.acl_db.clone(),
add_params_inj,
);
let pipeline_ctx = PipelineContext {
pipeline_id: pipeline.id,
pipeline_name: pipeline.name.clone().unwrap_or_default(),
user_id: Some(owner_id),
tenant_id,
dataset_id: params.dataset_id,
current_data: None,
run_id: None,
user_email: None,
provenance_visited: Arc::new(std::sync::Mutex::new(std::collections::HashSet::new())),
};
let (_cancel_handle, ctx) = TaskContextBuilder::new()
.thread_pool(thread_pool)
.database(db_connection)
.graph_db(graph_db)
.vector_db(vector_db)
.pipeline_context(pipeline_ctx)
.build()
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
let ctx = Arc::new(ctx);
let typed_inputs: Vec<Arc<dyn Value>> = inputs
.into_iter()
.map(|i| Arc::new(i) as Arc<dyn Value>)
.collect();
let pipeline_run_repo = self
.pipeline_run_repo
.clone()
.unwrap_or_else(cognee_database::NoopPipelineRunRepository::arc);
let watcher = DbPipelineWatcher::new(pipeline_run_repo);
let outputs = cognee_core::pipeline::execute(&pipeline, typed_inputs, ctx, &watcher)
.await
.map_err(|e| Box::new(e) as Box<dyn std::error::Error>)?;
extract_data_outputs(outputs)
}
}
fn extract_data_outputs(
outputs: Vec<Arc<dyn Value>>,
) -> Result<Vec<Data>, Box<dyn std::error::Error>> {
let mut data_vec = Vec::with_capacity(outputs.len());
for o in outputs {
let d = (*o).as_any().downcast_ref::<Data>().cloned().ok_or(
IngestionError::OutputTypeMismatch {
expected: "Data",
actual: "unknown",
},
)?;
data_vec.push(d);
}
Ok(data_vec)
}
/// Errors raised by the ingestion ("add") pipeline in this module.
#[derive(Debug, thiserror::Error)]
pub enum IngestionError {
    /// A backend required by `AddPipeline::add_with_params` was not attached; `which` names it.
    #[error("AddPipeline missing required backend: {which}")]
    MissingBackend { which: &'static str },
    /// A pipeline output could not be downcast to the expected type.
    #[error("AddPipeline output type mismatch: expected {expected}, actual {actual}")]
    OutputTypeMismatch {
        expected: &'static str,
        actual: &'static str,
    },
    /// A URL input was given but the crate was built without the `html-loader` feature.
    #[error("URL ingestion requires the `html-loader` feature to be enabled")]
    UrlIngestionUnavailable,
    /// An S3 path input was given; S3 ingestion is not implemented.
    #[error("S3 ingestion is not yet implemented")]
    S3IngestionUnavailable,
    /// No loader is registered for the document type derived from the input's extension.
    #[error("Unsupported document type at ingest: {document_type}")]
    UnsupportedDocumentType { document_type: String },
}
#[cfg(test)]
#[allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
mod tests {
use super::*;
use cognee_database::{connect, initialize, ops};
use cognee_graph::MockGraphDB;
#[cfg(feature = "html-loader")]
use cognee_storage::LocalStorage;
use cognee_storage::MockStorage;
use cognee_vector::MockVectorDB;
#[cfg(feature = "html-loader")]
use mockito::{Server, ServerGuard};
use std::io::Write;
use tempfile::NamedTempFile;
async fn make_pipeline() -> (AddPipeline, Arc<cognee_database::DatabaseConnection>) {
let db = connect("sqlite::memory:").await.unwrap();
initialize(&db).await.unwrap();
let db = Arc::new(db);
let storage: Arc<dyn StorageTrait> = Arc::new(MockStorage::new());
let pipeline = AddPipeline::new(storage, db.clone() as Arc<dyn IngestDb>)
.with_thread_pool(Arc::new(RayonThreadPool::with_default_threads().unwrap()))
.with_graph_db(Arc::new(MockGraphDB::new()))
.with_vector_db(Arc::new(MockVectorDB::new()))
.with_database(Arc::clone(&db));
(pipeline, db)
}
#[cfg(feature = "html-loader")]
async fn server_with_robots() -> ServerGuard {
let mut server = Server::new_async().await;
server
.mock("GET", "/robots.txt")
.with_status(404)
.create_async()
.await;
server
}
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_process_input_url_html_stores_text_with_source_metadata() {
let mut server = server_with_robots().await;
let html = "<html><head><title>Example</title><style>.x{display:none}</style></head><body><h1>Visible text</h1><script>hidden()</script></body></html>";
let url = format!("{}/page.html", server.url());
let _mock = server
.mock("GET", "/page.html")
.with_header("content-type", "text/html; charset=utf-8")
.with_body(html)
.create_async()
.await;
let temp_dir = tempfile::tempdir().unwrap();
let storage = LocalStorage::new(temp_dir.path().to_path_buf());
let processed = process_input(
&DataInput::Url(url.clone()),
&storage,
HashAlgorithm::Md5,
Uuid::new_v4(),
None,
)
.await
.unwrap();
let raw_source_uri = processed.raw_source_uri.as_ref().unwrap();
let raw_html = storage.retrieve(raw_source_uri).await.unwrap();
assert_eq!(raw_html, html.as_bytes());
assert!(raw_source_uri.ends_with(".html"));
assert!(processed.raw_data_uri.ends_with(".txt"));
assert_ne!(processed.raw_data_uri, *raw_source_uri);
let stored = storage.retrieve(&processed.raw_data_uri).await.unwrap();
let stored_text = String::from_utf8(stored).unwrap();
assert!(stored_text.contains("Visible text"));
assert!(!stored_text.contains("<html>"));
assert!(!stored_text.contains("hidden()"));
assert!(!stored_text.contains("display:none"));
assert_eq!(processed.stored_extension, "txt");
assert_eq!(processed.stored_mime_type, "text/plain");
assert_eq!(processed.original_extension, "html");
assert_eq!(processed.original_mime_type, "text/html");
assert_eq!(processed.loader_engine, "beautiful_soup_loader");
assert_eq!(processed.original_location, *raw_source_uri);
let metadata: serde_json::Value =
serde_json::from_str(processed.external_metadata.as_ref().unwrap()).unwrap();
assert_eq!(metadata["source"], "url");
assert_eq!(metadata["url"], url);
assert_eq!(metadata["final_url"], url);
assert_eq!(metadata["content_type"], "text/html; charset=utf-8");
assert_eq!(metadata["title"], "Example");
}
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_persist_data_url_html_uses_text_payload_and_raw_html_original_location() {
let mut server = server_with_robots().await;
let url = format!("{}/page", server.url());
let html = "<html><head><title>XHTML</title></head><body>XHTML body</body></html>";
let _mock = server
.mock("GET", "/page")
.with_header("content-type", "application/xhtml+xml")
.with_body(html)
.create_async()
.await;
let db = connect("sqlite::memory:").await.unwrap();
initialize(&db).await.unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let storage = LocalStorage::new(temp_dir.path().to_path_buf());
let owner_id = Uuid::new_v4();
let processed = process_input(
&DataInput::Url(url),
&storage,
HashAlgorithm::Md5,
owner_id,
None,
)
.await
.unwrap();
let data = persist_data(&processed, &db, "url-html", owner_id, None)
.await
.unwrap();
assert_eq!(data.extension, "txt");
assert_eq!(data.mime_type, "text/plain");
assert!(data.raw_data_location.ends_with(".txt"));
assert!(data.original_data_location.ends_with(".html"));
assert_ne!(data.raw_data_location, data.original_data_location);
assert_eq!(
storage
.retrieve(&data.original_data_location)
.await
.unwrap(),
html.as_bytes()
);
assert_eq!(data.original_extension.as_deref(), Some("html"));
assert_eq!(
data.original_mime_type.as_deref(),
Some("application/xhtml+xml")
);
assert_eq!(data.loader_engine.as_deref(), Some("beautiful_soup_loader"));
}
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_data_item_url_merges_metadata_and_preserves_label() {
let mut server = server_with_robots().await;
let url = format!("{}/wrapped", server.url());
let _mock = server
.mock("GET", "/wrapped")
.with_header("content-type", "text/html")
.with_body("<html><head><title>Wrapped</title></head><body>Wrapped body</body></html>")
.create_async()
.await;
let db = connect("sqlite::memory:").await.unwrap();
initialize(&db).await.unwrap();
let temp_dir = tempfile::tempdir().unwrap();
let storage = LocalStorage::new(temp_dir.path().to_path_buf());
let owner_id = Uuid::new_v4();
let processed = process_input(
&DataInput::DataItem {
data: Box::new(DataInput::Url(url.clone())),
label: "wrapped-label".to_string(),
external_metadata: Some(r#"{"custom":"keep","rank":7}"#.to_string()),
},
&storage,
HashAlgorithm::Md5,
owner_id,
None,
)
.await
.unwrap();
let data = persist_data(&processed, &db, "wrapped-url", owner_id, None)
.await
.unwrap();
assert_eq!(data.label.as_deref(), Some("wrapped-label"));
assert_eq!(data.extension, "txt");
assert!(data.original_data_location.ends_with(".html"));
let metadata: serde_json::Value =
serde_json::from_str(data.external_metadata.as_ref().unwrap()).unwrap();
assert_eq!(metadata["custom"], "keep");
assert_eq!(metadata["rank"], 7);
assert_eq!(metadata["source"], "url");
assert_eq!(metadata["url"], url);
assert_eq!(metadata["final_url"], url);
assert_eq!(metadata["content_type"], "text/html");
assert_eq!(metadata["title"], "Wrapped");
}
/// A `DataItem` wrapping a URL whose user-supplied metadata is not a JSON object (either
/// unparseable text or a JSON array) must still ingest: the original string is kept verbatim
/// under `data_item_external_metadata_raw`, alongside the URL-derived `source`/`url` keys.
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_data_item_url_invalid_or_non_object_metadata_preserved_under_raw_field() {
    let mut server = server_with_robots().await;
    let metadata_cases = [
        ("/invalid-meta", "not-json"),
        ("/non-object-meta", r#"["not","an","object"]"#),
    ];
    for (route, raw_meta) in metadata_cases {
        let target_url = format!("{}{}", server.url(), route);
        // Bound to a named variable so the mock lives until the end of this iteration,
        // i.e. past the fetch performed by `process_input` below.
        let _route_mock = server
            .mock("GET", route)
            .with_header("content-type", "text/html")
            .with_body("<html><body>Metadata body</body></html>")
            .create_async()
            .await;
        let scratch = tempfile::tempdir().unwrap();
        let storage = LocalStorage::new(scratch.path().to_path_buf());
        let wrapped = DataInput::DataItem {
            data: Box::new(DataInput::Url(target_url.clone())),
            label: "invalid-meta".to_string(),
            external_metadata: Some(raw_meta.to_string()),
        };
        let processed =
            process_input(&wrapped, &storage, HashAlgorithm::Md5, Uuid::new_v4(), None)
                .await
                .unwrap();
        let merged: serde_json::Value =
            serde_json::from_str(processed.external_metadata.as_ref().unwrap()).unwrap();
        assert_eq!(merged["data_item_external_metadata_raw"], raw_meta);
        assert_eq!(merged["source"], "url");
        assert_eq!(merged["url"], target_url);
    }
}
/// Non-HTML URL responses (plain text, JSON, CSV, PDF) are stored as fetched. No separate raw
/// source copy is kept, the original location is the URL itself, and the stored extension and
/// MIME type follow the response content type.
#[cfg(feature = "html-loader")]
#[tokio::test]
async fn test_non_html_url_inputs_do_not_store_raw_source_copy() {
    let mut server = server_with_robots().await;
    // (route, response content-type, response body, expected extension, expected mime)
    let fixtures = [
        ("/plain", "text/plain", "plain body", "txt", "text/plain"),
        (
            "/json",
            "application/json",
            r#"{"hello":"world"}"#,
            "json",
            "application/json",
        ),
        ("/csv", "text/csv", "a,b\n1,2\n", "csv", "text/csv"),
        (
            "/pdf",
            "application/pdf",
            "%PDF-1.7\n",
            "pdf",
            "application/pdf",
        ),
    ];
    for (route, content_type, payload, want_ext, want_mime) in fixtures {
        let target_url = format!("{}{}", server.url(), route);
        let _route_mock = server
            .mock("GET", route)
            .with_header("content-type", content_type)
            .with_body(payload)
            .create_async()
            .await;
        let scratch = tempfile::tempdir().unwrap();
        let storage = LocalStorage::new(scratch.path().to_path_buf());
        let url_input = DataInput::Url(target_url.clone());
        let processed =
            process_input(&url_input, &storage, HashAlgorithm::Md5, Uuid::new_v4(), None)
                .await
                .unwrap();

        assert_eq!(processed.raw_source_uri, None);
        assert_eq!(processed.original_location, target_url);
        assert_eq!(processed.stored_extension, want_ext);
        assert_eq!(processed.stored_mime_type, want_mime);
        let stored_bytes = storage.retrieve(&processed.raw_data_uri).await.unwrap();
        assert_eq!(stored_bytes, payload.as_bytes());
    }
}
/// Adding a single text input yields one `text_`-named, `text/plain`/`txt` record and a single
/// dataset for the owner that contains exactly that record.
#[tokio::test]
async fn test_add_text_input() {
    let (pipeline, db) = make_pipeline().await;
    let owner = Uuid::new_v4();
    let outcome = pipeline
        .add(
            vec![DataInput::Text("Hello, world!".to_string())],
            "test_dataset",
            owner,
            None,
        )
        .await;
    assert!(outcome.is_ok(), "add should succeed: {:?}", outcome.err());
    let added = outcome.unwrap();
    assert_eq!(added.len(), 1);

    let record = &added[0];
    assert!(
        record.name.starts_with("text_"),
        "name should start with text_"
    );
    assert_eq!(record.mime_type, "text/plain");
    assert_eq!(record.extension, "txt");

    let owned = ops::datasets::list_datasets_by_owner(&db, owner)
        .await
        .unwrap();
    assert_eq!(owned.len(), 1);
    let members = ops::datasets::get_dataset_data(&db, owned[0].id)
        .await
        .unwrap();
    assert_eq!(members.len(), 1);
}
/// Adding a plain file path ingests exactly one record with a non-empty name and creates a
/// single dataset for the owner.
#[tokio::test]
async fn test_add_file_input() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    // `temp_file` stays alive (and on disk) until the end of the test, past the `add` call.
    let mut temp_file = NamedTempFile::new().unwrap();
    writeln!(temp_file, "Test file content").unwrap();
    let file_path = temp_file.path().to_str().unwrap().to_string();
    let inputs = vec![DataInput::FilePath(file_path)];
    let result = pipeline.add(inputs, "test_dataset", owner_id, None).await;
    // Report the error on failure. A bare `is_ok()` assertion panics before `unwrap()`
    // could print it, which hides why `add` failed.
    assert!(result.is_ok(), "add should succeed: {:?}", result.err());
    let data = result.unwrap();
    assert_eq!(data.len(), 1);
    assert!(!data[0].name.is_empty());
    let datasets = ops::datasets::list_datasets_by_owner(&db, owner_id)
        .await
        .unwrap();
    assert_eq!(datasets.len(), 1);
}
/// Adding two distinct text inputs in one call yields two records, both placed in the same
/// (single) dataset for the owner.
#[tokio::test]
async fn test_add_multiple_inputs() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    let inputs = vec![
        DataInput::Text("First text".to_string()),
        DataInput::Text("Second text".to_string()),
    ];
    let result = pipeline.add(inputs, "test_dataset", owner_id, None).await;
    // Report the error on failure. A bare `is_ok()` assertion panics before `unwrap()`
    // could print it, which hides why `add` failed.
    assert!(result.is_ok(), "add should succeed: {:?}", result.err());
    let data = result.unwrap();
    assert_eq!(data.len(), 2);
    let datasets = ops::datasets::list_datasets_by_owner(&db, owner_id)
        .await
        .unwrap();
    assert_eq!(datasets.len(), 1);
    let ds_data = ops::datasets::get_dataset_data(&db, datasets[0].id)
        .await
        .unwrap();
    assert_eq!(ds_data.len(), 2);
}
/// Adding identical text twice to the same dataset for the same owner returns the same record
/// id and content hash, and the dataset ends up holding only one data row.
#[tokio::test]
async fn test_deduplication_same_content() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    let content = "Duplicate content";

    let mut rounds = Vec::with_capacity(2);
    for _ in 0..2 {
        let added = pipeline
            .add(
                vec![DataInput::Text(content.to_string())],
                "test_dataset",
                owner_id,
                None,
            )
            .await
            .unwrap();
        rounds.push(added);
    }
    let (first, second) = (&rounds[0][0], &rounds[1][0]);
    assert_eq!(first.id, second.id);
    assert_eq!(first.content_hash, second.content_hash);

    let dataset = ops::datasets::get_dataset_by_name(&db, "test_dataset", owner_id, None)
        .await
        .unwrap()
        .unwrap();
    let rows = ops::datasets::get_dataset_data(&db, dataset.id)
        .await
        .unwrap();
    assert_eq!(rows.len(), 1);
}
/// The same text added by two different owners shares a content hash but gets distinct data
/// ids, because the id is scoped by owner.
#[tokio::test]
async fn test_different_owners_same_hash_different_ids() {
    let (pipeline, _db) = make_pipeline().await;

    let mut per_owner = Vec::with_capacity(2);
    for dataset in ["ds1", "ds2"] {
        let owner = Uuid::new_v4();
        let added = pipeline
            .add(
                vec![DataInput::Text("Same content".to_string())],
                dataset,
                owner,
                None,
            )
            .await
            .unwrap();
        per_owner.push(added);
    }

    let (a, b) = (&per_owner[0][0], &per_owner[1][0]);
    assert_eq!(
        a.content_hash, b.content_hash,
        "content hash is owner-independent"
    );
    assert_ne!(a.id, b.id, "data_id must differ by owner");
}
/// Adding content under two different dataset names creates two datasets for the owner.
#[tokio::test]
async fn test_multiple_datasets() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    for (body, dataset) in [("Content 1", "dataset1"), ("Content 2", "dataset2")] {
        pipeline
            .add(vec![DataInput::Text(body.to_string())], dataset, owner_id, None)
            .await
            .unwrap();
    }
    let owned = ops::datasets::list_datasets_by_owner(&db, owner_id)
        .await
        .unwrap();
    assert_eq!(owned.len(), 2);
}
/// Two separate `add` calls that name the same dataset reuse it: one dataset, two data rows.
#[tokio::test]
async fn test_reuse_dataset() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    for body in ["Content 1", "Content 2"] {
        pipeline
            .add(
                vec![DataInput::Text(body.to_string())],
                "same_dataset",
                owner_id,
                None,
            )
            .await
            .unwrap();
    }
    let owned = ops::datasets::list_datasets_by_owner(&db, owner_id)
        .await
        .unwrap();
    assert_eq!(owned.len(), 1);
    let rows = ops::datasets::get_dataset_data(&db, owned[0].id)
        .await
        .unwrap();
    assert_eq!(rows.len(), 2);
}
/// Re-adding the same text (same owner, same dataset) yields the same content hash and id.
#[tokio::test]
async fn test_content_hash_deterministic() {
    let (pipeline, _db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();

    let mut outputs = Vec::with_capacity(2);
    for _attempt in 0..2 {
        outputs.push(
            pipeline
                .add(
                    vec![DataInput::Text("Test content".to_string())],
                    "dataset1",
                    owner_id,
                    None,
                )
                .await
                .unwrap(),
        );
    }

    let (first, second) = (&outputs[0][0], &outputs[1][0]);
    assert_eq!(first.content_hash, second.content_hash);
    assert_eq!(first.id, second.id);
}
/// Every `DataInput` variant exercised here (text, file path, binary, wrapped data item) must
/// produce a non-empty `content_hash`. The text record's hash must also survive a read-back
/// from the database.
#[tokio::test]
async fn test_content_hash_non_empty_across_variants_and_db_roundtrip() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();

    // The file-path variant needs a file on disk; it is kept alive until the test ends.
    let mut audit_file = NamedTempFile::new().unwrap();
    writeln!(audit_file, "Provenance audit file").unwrap();
    let audit_path = audit_file.path().to_str().unwrap().to_string();

    // (input, dataset name, failure message), processed in this order.
    let variants = vec![
        (
            DataInput::Text("Provenance audit text".to_string()),
            "audit_text",
            "Text input must populate content_hash",
        ),
        (
            DataInput::FilePath(audit_path),
            "audit_file",
            "FilePath input must populate content_hash",
        ),
        (
            DataInput::Binary {
                name: "audit.bin".to_string(),
                data: b"provenance audit binary".to_vec(),
            },
            "audit_binary",
            "Binary input must populate content_hash",
        ),
        (
            DataInput::DataItem {
                data: Box::new(DataInput::Text("Wrapped audit text".to_string())),
                label: "wrapped".to_string(),
                external_metadata: None,
            },
            "audit_wrapped",
            "DataItem(Text) must populate content_hash",
        ),
    ];

    let mut added_per_variant = Vec::with_capacity(variants.len());
    for (input, dataset, failure_msg) in variants {
        let added = pipeline
            .add(vec![input], dataset, owner_id, None)
            .await
            .unwrap();
        assert!(!added[0].content_hash.is_empty(), "{failure_msg}");
        added_per_variant.push(added);
    }

    // The first variant is the plain-text input; read it back from the database.
    let text_record = &added_per_variant[0][0];
    let reread = ops::data::get_data(&db, text_record.id)
        .await
        .unwrap()
        .expect("data row exists immediately after add()");
    assert_eq!(
        reread.content_hash, text_record.content_hash,
        "content_hash must round-trip through SeaORM <-> Data conversion"
    );
}
/// Two `add` calls with the same dataset name and owner but different content must not create
/// a second dataset.
#[tokio::test]
async fn test_dataset_id_is_deterministic() {
    let (pipeline, db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    for body in ["any content", "other content"] {
        pipeline
            .add(
                vec![DataInput::Text(body.to_string())],
                "my_dataset",
                owner_id,
                None,
            )
            .await
            .unwrap();
    }
    let owned = ops::datasets::list_datasets_by_owner(&db, owner_id)
        .await
        .unwrap();
    assert_eq!(
        owned.len(),
        1,
        "deterministic dataset ID should prevent duplicate creation"
    );
}
#[tokio::test]
async fn test_loader_engine_populated() {
let (pipeline, _db) = make_pipeline().await;
let owner_id = Uuid::new_v4();
let mut temp_file = NamedTempFile::new().unwrap();
writeln!(temp_file, "content").unwrap();
let txt_path = temp_file.path().with_extension("txt");
std::fs::copy(temp_file.path(), &txt_path).unwrap();
let result = pipeline
.add(
vec![DataInput::FilePath(txt_path.to_str().unwrap().to_string())],
"ds",
owner_id,
None,
)
.await
.unwrap();
assert_eq!(result[0].loader_engine.as_deref(), Some("text_loader"));
assert_eq!(result[0].extension, "txt");
let _ = std::fs::remove_file(&txt_path);
}
#[cfg(not(any(feature = "pdf-pdfium", feature = "pdf-pure-rust")))]
#[tokio::test]
async fn test_unsupported_loader_type_errors_at_add() {
let (pipeline, _db) = make_pipeline().await;
let owner_id = Uuid::new_v4();
let mut temp_file = NamedTempFile::new().unwrap();
writeln!(temp_file, "%PDF-1.7").unwrap();
let pdf_path = temp_file.path().with_extension("pdf");
std::fs::copy(temp_file.path(), &pdf_path).unwrap();
let result = pipeline
.add(
vec![DataInput::FilePath(pdf_path.to_str().unwrap().to_string())],
"ds",
owner_id,
None,
)
.await;
assert!(
result.is_err(),
"pdf input must error when the pdf loader feature is off"
);
let _ = std::fs::remove_file(&pdf_path);
}
/// Regression guard for plain-text ingestion. It pins the MD5 of "hello world", the
/// owner-scoped data id derived from it, the text loader/extension/mime choices, and that the
/// stored bytes are the input unchanged.
#[tokio::test]
async fn test_text_path_no_regression_hashes_and_stored_bytes() {
    use cognee_storage::LocalStorage;
    const HELLO_WORLD: &str = "hello world";
    const HELLO_WORLD_MD5: &str = "5eb63bbbe01eeed093cb22bb8f5acdc3";

    let storage_root = tempfile::tempdir().unwrap();
    let storage = LocalStorage::new(storage_root.path().to_path_buf());
    let owner = Uuid::new_v4();
    let text_input = DataInput::Text(HELLO_WORLD.to_string());
    let processed = process_input(&text_input, &storage, HashAlgorithm::Md5, owner, None)
        .await
        .unwrap();

    // Hash and identity invariants.
    assert_eq!(processed.content_hash, HELLO_WORLD_MD5);
    assert_eq!(
        processed.data_id,
        generate_data_id(HELLO_WORLD_MD5, owner, None),
        "data_id must remain uuid5(content_hash + owner + tenant)"
    );
    assert_eq!(
        processed.raw_content_hash, processed.content_hash,
        "plain-text raw_content_hash must equal content_hash"
    );

    // Loader / format selection.
    assert_eq!(processed.stored_extension, "txt");
    assert_eq!(processed.stored_mime_type, "text/plain");
    assert_eq!(processed.loader_engine, "text_loader");

    // Stored payload is the untouched input.
    let stored_bytes = storage.retrieve(&processed.raw_data_uri).await.unwrap();
    assert_eq!(stored_bytes, HELLO_WORLD.as_bytes());
}
/// A `.csv` file path is hashed on its raw bytes (`content_hash`), but what gets stored is the
/// loader's extracted text (as `txt`/`text/plain`), and `raw_content_hash` is the hash of that
/// stored text.
#[cfg(feature = "csv-loader")]
#[tokio::test]
async fn test_csv_path_stores_extracted_text() {
    use cognee_storage::LocalStorage;
    let csv = "name,age\nAlice,30\nBob,25\n";
    let temp_dir = tempfile::tempdir().unwrap();
    let storage = LocalStorage::new(temp_dir.path().to_path_buf());
    let owner_id = Uuid::new_v4();
    // The input file lives in its own scratch directory, separate from whatever the storage
    // backend writes under `temp_dir`. The `TempDir` guard deletes it on drop, including when
    // an assertion panics. A trailing `remove_file` would be skipped in that case.
    let input_dir = tempfile::tempdir().unwrap();
    let csv_path = input_dir.path().join("people.csv");
    std::fs::write(&csv_path, csv).unwrap();
    let processed = process_input(
        &DataInput::FilePath(csv_path.to_str().unwrap().to_string()),
        &storage,
        HashAlgorithm::Md5,
        owner_id,
        None,
    )
    .await
    .unwrap();
    assert_eq!(
        processed.content_hash,
        crate::content_hasher::ContentHasher::hash_content(csv.as_bytes(), HashAlgorithm::Md5)
    );
    assert_eq!(processed.stored_extension, "txt");
    assert_eq!(processed.stored_mime_type, "text/plain");
    assert_eq!(processed.original_extension, "csv");
    assert_eq!(processed.loader_engine, "csv_loader");
    assert_ne!(
        processed.raw_content_hash, processed.content_hash,
        "extracted csv text differs from raw bytes"
    );
    let stored = storage.retrieve(&processed.raw_data_uri).await.unwrap();
    let stored_text = String::from_utf8(stored).unwrap();
    assert!(stored_text.contains("name: Alice, age: 30"));
    assert!(!stored_text.contains("name,age"));
    assert_eq!(
        processed.raw_content_hash,
        crate::content_hasher::ContentHasher::hash_content(
            stored_text.as_bytes(),
            HashAlgorithm::Md5
        )
    );
}
/// `process_input` must reject an S3 path input with an error.
#[tokio::test]
async fn test_s3_input_errors_at_add() {
    use cognee_storage::LocalStorage;
    let scratch = tempfile::tempdir().unwrap();
    let storage = LocalStorage::new(scratch.path().to_path_buf());
    let s3_input = DataInput::S3Path(String::from("s3://bucket/key.txt"));
    let outcome =
        process_input(&s3_input, &storage, HashAlgorithm::Md5, Uuid::new_v4(), None).await;
    assert!(outcome.is_err(), "S3 input must error at ingest");
}
/// A tenant id passed to `add` is recorded on the returned `Data` record.
#[tokio::test]
async fn test_tenant_id_stored() {
    let (pipeline, _db) = make_pipeline().await;
    let (owner_id, tenant_id) = (Uuid::new_v4(), Uuid::new_v4());
    let added = pipeline
        .add(
            vec![DataInput::Text("tenant content".to_string())],
            "ds",
            owner_id,
            Some(tenant_id),
        )
        .await
        .unwrap();
    assert_eq!(added[0].tenant_id, Some(tenant_id));
}
/// The `label` of a `DataItem` wrapper is carried through to the returned `Data` record.
#[tokio::test]
async fn test_data_item_label_stored() {
    let (pipeline, _db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    let labelled = DataInput::DataItem {
        data: Box::new(DataInput::Text(String::from("labelled content"))),
        label: String::from("my-label"),
        external_metadata: None,
    };
    let added = pipeline
        .add(vec![labelled], "ds", owner_id, None)
        .await
        .unwrap();
    assert_eq!(added.len(), 1);
    assert_eq!(
        added[0].label.as_deref(),
        Some("my-label"),
        "DataItem label must be stored in the Data record"
    );
}
/// For a relative file path, `extract_name` returns the file stem without the extension.
#[test]
fn extract_name_file_path_uses_stem_not_full_name() {
    let extracted = super::extract_name(
        &DataInput::FilePath(String::from("documents/report.txt")),
        "abc123",
    );
    assert_eq!(
        extracted, "report",
        "file path name should be stem (no extension)"
    );
}
/// For a `file://` URI, `extract_name` also returns just the file stem.
#[test]
fn extract_name_file_path_with_file_uri() {
    let uri_input = DataInput::FilePath(String::from("file:///tmp/data/notes.pdf"));
    let extracted = super::extract_name(&uri_input, "abc123");
    assert_eq!(extracted, "notes");
}
/// Text inputs have no file name, so `extract_name` builds one as `text_<hash>`.
#[test]
fn extract_name_text_input_uses_hash() {
    let text_input = DataInput::Text(String::from("hello world"));
    assert_eq!(
        super::extract_name(&text_input, "5eb63bbbe01eeed093cb22bb8f5acdc3"),
        "text_5eb63bbbe01eeed093cb22bb8f5acdc3"
    );
}
/// A binary input named `*.md` is classified as `text/plain`, not `text/markdown`.
#[test]
fn binary_md_file_gets_text_plain_mime() {
    let markdown_blob = DataInput::Binary {
        name: String::from("notes.md"),
        data: b"# Heading\nSome markdown".to_vec(),
    };
    let (_, _, mime, _) = super::extract_file_metadata(&markdown_blob);
    assert_eq!(
        mime, "text/plain",
        ".md binary should produce text/plain, not text/markdown"
    );
}
/// A `.md` file path is classified as `text/plain`, not `text/markdown`.
#[test]
fn file_path_md_gets_text_plain_mime() {
    let (_, _, mime, _) =
        super::extract_file_metadata(&DataInput::FilePath(String::from("/tmp/notes.md")));
    assert_eq!(
        mime, "text/plain",
        ".md file path should produce text/plain, not text/markdown"
    );
}
/// A `.json` file path is classified as `text/plain`, not `application/json`.
#[test]
fn file_path_json_gets_text_plain_mime() {
    let (_, _, mime, _) =
        super::extract_file_metadata(&DataInput::FilePath(String::from("/tmp/data.json")));
    assert_eq!(
        mime, "text/plain",
        ".json file path should produce text/plain, not application/json"
    );
}
/// A `.pdf` file path is not forced to `text/plain`; it keeps a non-text MIME type.
#[test]
fn file_path_pdf_keeps_original_mime() {
    let (_, _, mime, _) =
        super::extract_file_metadata(&DataInput::FilePath(String::from("/tmp/doc.pdf")));
    assert_ne!(
        mime, "text/plain",
        ".pdf should NOT be overridden to text/plain"
    );
}
/// A `DataItem` wrapping plain text keeps its `external_metadata` JSON string verbatim on the
/// returned `Data` record.
#[tokio::test]
async fn test_data_item_external_metadata_stored() {
    let (pipeline, _db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    let meta_json = String::from(r#"{"source":"dlt","table_name":"orders"}"#);
    let item = DataInput::DataItem {
        data: Box::new(DataInput::Text(String::from("dlt content"))),
        label: String::from("dlt-label"),
        external_metadata: Some(meta_json.clone()),
    };
    let stored = pipeline
        .add(vec![item], "ds", owner_id, None)
        .await
        .unwrap();
    assert_eq!(stored.len(), 1);
    assert_eq!(
        stored[0].external_metadata.as_deref(),
        Some(meta_json.as_str()),
        "DataItem external_metadata must be stored in the Data record"
    );
}
/// A `DataItem` without `external_metadata` produces a `Data` record whose metadata is `None`.
#[tokio::test]
async fn test_data_item_without_metadata() {
    let (pipeline, _db) = make_pipeline().await;
    let owner_id = Uuid::new_v4();
    let bare_item = DataInput::DataItem {
        data: Box::new(DataInput::Text(String::from("no metadata"))),
        label: String::from("plain-label"),
        external_metadata: None,
    };
    let stored = pipeline
        .add(vec![bare_item], "ds", owner_id, None)
        .await
        .unwrap();
    assert_eq!(stored.len(), 1);
    assert_eq!(
        stored[0].external_metadata, None,
        "DataItem with no external_metadata should produce None on Data"
    );
}
}