use std::sync::Arc;
use cognee_core::pipeline::DataIdFn;
use cognee_core::pipeline_run_registry::DbPipelineWatcher;
use cognee_core::task::Value;
use cognee_core::{
CpuPool, Pipeline, PipelineBuilder, PipelineContext, TaskContextBuilder, TypedTask,
};
use cognee_database::{DatabaseConnection, PipelineRunRepository};
use cognee_embedding::EmbeddingEngine;
use cognee_graph::GraphDBTrait;
use cognee_models::Triplet;
use cognee_vector::VectorDB;
use tracing::info;
use uuid::Uuid;
use super::config::MemifyConfig;
use super::error::MemifyError;
use super::extract_triplets::extract_triplets_from_graph_db;
use super::index_triplets::{IndexResult, index_triplets};
use crate::qualification::{Qualification, check_pipeline_run_qualification};
/// Outcome of a [`memify`] invocation.
///
/// Built either by [`MemifyResult::empty`], [`MemifyResult::already_completed`],
/// or directly by `memify` after the indexing pipeline has produced an
/// [`IndexResult`].
#[derive(Debug, Clone)]
pub struct MemifyResult {
/// Number of triplets handed to the indexing pipeline (`0` when nothing was indexed).
pub triplet_count: usize,
/// Indexing statistics reported by the pipeline's single task.
pub index_result: IndexResult,
/// `true` when the call short-circuited because a prior run for the dataset
/// had already completed.
pub already_completed: bool,
/// Identifier of the prior completed run; only set together with
/// `already_completed == true`.
pub prior_pipeline_run_id: Option<Uuid>,
}
impl MemifyResult {
    /// Builds a result describing "nothing was indexed": all counters are zero,
    /// `already_completed` is `false`, and no prior run id is recorded.
    pub fn empty() -> Self {
        let index_result = IndexResult {
            indexed_count: 0,
            batch_count: 0,
        };
        Self {
            triplet_count: 0,
            index_result,
            already_completed: false,
            prior_pipeline_run_id: None,
        }
    }

    /// Builds a result for a short-circuited call: identical to [`Self::empty`]
    /// except that `already_completed` is `true` and the id of the prior run
    /// is stored in `prior_pipeline_run_id`.
    pub fn already_completed(pipeline_run_id: Uuid) -> Self {
        let mut result = Self::empty();
        result.already_completed = true;
        result.prior_pipeline_run_id = Some(pipeline_run_id);
        result
    }
}
/// Wraps [`index_triplets`] in a [`TypedTask`] that takes a `Vec<Triplet>` and
/// yields an [`IndexResult`].
///
/// The vector database, embedding engine and the three optional ids are
/// captured by the task closure and forwarded unchanged to `index_triplets`
/// on every invocation. Any error from `index_triplets` is rendered to its
/// display string and converted into the task's error type.
fn make_index_triplets_task(
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    dataset_id: Option<Uuid>,
    user_id: Option<Uuid>,
    tenant_id: Option<Uuid>,
) -> TypedTask<Vec<Triplet>, IndexResult> {
    TypedTask::async_fn(move |triplets: &Vec<Triplet>, _ctx| {
        // Give the future its own copies so it does not borrow from the
        // closure arguments or from the closure's captured state.
        let owned_triplets = triplets.clone();
        let store = Arc::clone(&vector_db);
        let embedder = Arc::clone(&embedding_engine);

        Box::pin(async move {
            let outcome = index_triplets(
                &owned_triplets,
                &*store,
                &*embedder,
                dataset_id,
                user_id,
                tenant_id,
            )
            .await;

            outcome.map(Box::new).map_err(|e| e.to_string().into())
        })
    })
}
/// Assembles the single-task "memify" pipeline whose only step indexes a
/// batch of triplets.
///
/// Both the task label and the pipeline name are `"memify"`. The pipeline is
/// configured with a data-id resolver that always returns `None`.
pub fn build_memify_index_only_pipeline(
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    dataset_id: Option<Uuid>,
    user_id: Option<Uuid>,
    tenant_id: Option<Uuid>,
) -> Pipeline {
    let index_task =
        make_index_triplets_task(vector_db, embedding_engine, dataset_id, user_id, tenant_id);

    // The input is one whole batch of triplets, so no per-item data id is derived.
    let no_data_id: DataIdFn = Arc::new(|_v: Arc<dyn Value>| None);

    PipelineBuilder::new_with_task("memify", index_task)
        .with_name("memify")
        .with_data_id(no_data_id)
        .build()
}
/// Runs the memify step: obtain triplets, then embed and index them through
/// the single-task "memify" pipeline.
///
/// Flow:
/// 1. Validate `config`.
/// 2. If `dataset_id` is set, consult the pipeline-run repository. A dataset
///    whose prior run completed returns early with
///    [`MemifyResult::already_completed`]; one with a run in flight is
///    rejected.
/// 3. Take triplets from `config.custom_data` when present, otherwise extract
///    them from `graph_db`.
/// 4. If there are no triplets, return [`MemifyResult::empty`] without
///    executing the pipeline.
/// 5. Execute the pipeline and read the [`IndexResult`] from its first output.
///
/// # Errors
/// - whatever `config.validate()` reports (propagated with `?`);
/// - [`MemifyError::Database`] if the qualification lookup fails;
/// - [`MemifyError::PipelineAlreadyRunning`] if a run is already in progress
///   for the dataset;
/// - errors from `extract_triplets_from_graph_db` (propagated with `?`);
/// - [`MemifyError::Context`] if the task context cannot be built;
/// - [`MemifyError::Execute`] if pipeline execution fails;
/// - [`MemifyError::OutputTypeMismatch`] if the pipeline output is missing or
///   is not an `IndexResult`.
// The argument list mirrors the set of backends the step needs; kept flat for callers.
#[allow(clippy::too_many_arguments)]
pub async fn memify(
    graph_db: Arc<dyn GraphDBTrait>,
    vector_db: Arc<dyn VectorDB>,
    embedding_engine: Arc<dyn EmbeddingEngine>,
    thread_pool: Arc<dyn CpuPool>,
    database: Arc<DatabaseConnection>,
    pipeline_run_repo: Arc<dyn PipelineRunRepository>,
    dataset_id: Option<Uuid>,
    user_id: Option<Uuid>,
    tenant_id: Option<Uuid>,
    config: &MemifyConfig,
) -> Result<MemifyResult, MemifyError> {
    config.validate()?;

    let pipeline_name = "memify";

    // Qualification only applies when the call is scoped to a dataset.
    if let Some(ds_id) = dataset_id {
        let qualification =
            check_pipeline_run_qualification(pipeline_run_repo.as_ref(), ds_id, pipeline_name)
                .await
                .map_err(|e| MemifyError::Database(e.to_string()))?;

        match qualification {
            Qualification::Proceed => {}
            Qualification::AlreadyRunning(_prior) => {
                return Err(MemifyError::PipelineAlreadyRunning {
                    pipeline_name: pipeline_name.to_string(),
                    dataset_id: Some(ds_id),
                });
            }
            Qualification::AlreadyCompleted(prior) => {
                info!(
                    dataset_id = %ds_id,
                    pipeline_run_id = %prior.pipeline_run_id,
                    "memify: dataset already completed; short-circuiting (Python parity)"
                );
                return Ok(MemifyResult::already_completed(prior.pipeline_run_id));
            }
        }
    }

    let triplets = match &config.custom_data {
        Some(custom_data) => {
            let mut custom_triplets = Vec::new();
            for value in custom_data {
                // A missing or non-string field falls back to the given placeholder.
                let string_field = |key: &str, fallback: &str| {
                    value
                        .get(key)
                        .and_then(|v| v.as_str())
                        .unwrap_or(fallback)
                        .to_string()
                };
                let source = string_field("source_node", "unknown");
                let relationship = string_field("relationship_name", "related_to");
                let target = string_field("target_node", "unknown");

                // Name-based (v5) ids over the lower-cased name: the same node
                // name, regardless of case, always maps to the same id.
                let source_id =
                    Uuid::new_v5(&Uuid::NAMESPACE_OID, source.to_lowercase().as_bytes());
                let target_id =
                    Uuid::new_v5(&Uuid::NAMESPACE_OID, target.to_lowercase().as_bytes());

                let text = format!("{source}-\u{203A}{relationship}-\u{203A}{target}");
                let triplet = Triplet::new(source_id, target_id, relationship, text)
                    .with_names(source, target);
                custom_triplets.push(triplet);
            }
            info!(
                "Using {} custom triplets instead of graph extraction",
                custom_triplets.len()
            );
            custom_triplets
        }
        None => extract_triplets_from_graph_db(&*graph_db, config).await?,
    };

    let triplet_count = triplets.len();
    if triplets.is_empty() {
        info!("No triplets extracted from graph; nothing to index");
        return Ok(MemifyResult::empty());
    }

    let pipeline = build_memify_index_only_pipeline(
        Arc::clone(&vector_db),
        Arc::clone(&embedding_engine),
        dataset_id,
        user_id,
        tenant_id,
    );

    let provenance_visited = Arc::new(std::sync::Mutex::new(std::collections::HashSet::new()));
    let pipeline_ctx = PipelineContext {
        pipeline_id: pipeline.id,
        pipeline_name: pipeline.name.clone().unwrap_or_default(),
        user_id,
        tenant_id,
        dataset_id,
        current_data: None,
        run_id: None,
        user_email: None,
        provenance_visited,
    };

    // NOTE(review): `_cancel_handle` is a named binding, so it lives until this
    // function returns. Presumably dropping it early could cancel the run;
    // confirm before replacing it with a bare `_` pattern.
    let (_cancel_handle, task_ctx) = TaskContextBuilder::new()
        .thread_pool(thread_pool)
        .database(database)
        .graph_db(Arc::clone(&graph_db))
        .vector_db(Arc::clone(&vector_db))
        .pipeline_context(pipeline_ctx)
        .build()
        .map_err(|e| MemifyError::Context(e.to_string()))?;
    let task_ctx = Arc::new(task_ctx);

    // The whole batch of triplets is submitted as a single pipeline input.
    let inputs: Vec<Arc<dyn Value>> = vec![Arc::new(triplets) as Arc<dyn Value>];
    let watcher = DbPipelineWatcher::new(pipeline_run_repo);

    let outputs = cognee_core::pipeline::execute(&pipeline, inputs, task_ctx, &watcher)
        .await
        .map_err(|e| MemifyError::Execute(e.to_string()))?;
    let index_result = extract_memify_outputs(outputs)?;

    info!(
        "Memify complete: {} triplets extracted, {} indexed",
        triplet_count, index_result.indexed_count
    );

    Ok(MemifyResult {
        triplet_count,
        index_result,
        already_completed: false,
        prior_pipeline_run_id: None,
    })
}
/// Pulls the [`IndexResult`] out of the pipeline's output list.
///
/// Only the first output is inspected; any further outputs are dropped.
///
/// # Errors
/// Returns [`MemifyError::OutputTypeMismatch`] with `actual: "empty"` when
/// there is no output at all, and with `actual: "unknown"` when the first
/// output is not an `IndexResult`.
fn extract_memify_outputs(outputs: Vec<Arc<dyn Value>>) -> Result<IndexResult, MemifyError> {
    let mismatch = |actual: &'static str| MemifyError::OutputTypeMismatch {
        expected: "IndexResult",
        actual,
    };

    match outputs.into_iter().next() {
        None => Err(mismatch("empty")),
        // Deref explicitly so `as_any` is dispatched on the inner `dyn Value`,
        // not on the `Arc` wrapper.
        Some(first) => match (*first).as_any().downcast_ref::<IndexResult>() {
            Some(result) => Ok(result.clone()),
            None => Err(mismatch("unknown")),
        },
    }
}