#![allow(
clippy::unwrap_used,
clippy::expect_used,
reason = "test code — panics are acceptable failures"
)]
#![cfg(feature = "testing")]
use cognee_cognify::graph_integration::{GraphEdgePair, GraphNodePair};
use cognee_cognify::memify::extract_triplets::extract_triplets_from_graph_db;
use cognee_cognify::memify::{MemifyConfig, memify};
use cognee_cognify::triplet_creation::create_triplets_from_graph;
use cognee_core::{CpuPool, RayonThreadPool};
use cognee_database::{DatabaseConnection, connect, initialize};
use cognee_embedding::{EmbeddingEngine, MockEmbeddingEngine};
use cognee_graph::{GraphDBTrait, MockGraphDB};
use cognee_models::{Entity, EntityType};
use cognee_vector::{MockVectorDB, VectorDB};
use serde_json::json;
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Creates the CPU pool and relational database handle passed to `memify`.
///
/// The database is opened at `sqlite::memory:` and run through `initialize`
/// before being wrapped in an `Arc`. The pool is a `RayonThreadPool` built
/// with its default thread count.
///
/// # Panics
///
/// Panics if connecting, initializing, or building the pool fails; in test
/// code that is the desired failure mode.
async fn make_ctx_handles() -> (Arc<dyn CpuPool>, Arc<DatabaseConnection>) {
    let connection = connect("sqlite::memory:")
        .await
        .expect("connect in-memory DB");
    initialize(&connection).await.expect("initialize schema");

    let rayon_pool = RayonThreadPool::with_default_threads().expect("rayon pool");
    let cpu_pool: Arc<dyn CpuPool> = Arc::new(rayon_pool);

    (cpu_pool, Arc::new(connection))
}
/// Inserts a raw JSON node with `id` and `name` keys into `db`.
///
/// A `description` key is added only when `description` is non-empty, so an
/// empty string produces a node without that key at all.
///
/// # Panics
///
/// Panics if `add_node_raw` returns an error.
async fn add_node(db: &dyn GraphDBTrait, id: Uuid, name: &str, description: &str) {
    // Keys are written in the same order as before: id, name, description.
    let mut payload = json!({
        "id": id.to_string(),
        "name": name,
    });
    if !description.is_empty() {
        payload["description"] = json!(description);
    }

    db.add_node_raw(payload).await.unwrap();
}
/// Adds an edge named `relationship` from `source` to `target` in `db`.
///
/// Both UUIDs are passed to the graph DB in their string form, and `None`
/// is supplied as the final `add_edge` argument.
///
/// # Panics
///
/// Panics if `add_edge` returns an error.
async fn add_edge(db: &dyn GraphDBTrait, source: Uuid, target: Uuid, relationship: &str) {
    let source_id = source.to_string();
    let target_id = target.to_string();

    db.add_edge(&source_id, &target_id, relationship, None)
        .await
        .unwrap();
}
/// Seeds `db` with three described nodes and two edges.
///
/// Nodes: Alice, TechCorp, Bob. Edges: Alice `works_at` TechCorp and
/// Alice `knows` Bob. Returns the freshly generated ids in the order
/// `(Alice, TechCorp, Bob)`.
async fn seed_graph(db: &dyn GraphDBTrait) -> (Uuid, Uuid, Uuid) {
    let alice = Uuid::new_v4();
    let tech_corp = Uuid::new_v4();
    let bob = Uuid::new_v4();

    // All nodes are inserted before any edge.
    for (id, name, description) in [
        (alice, "Alice", "Software engineer"),
        (tech_corp, "TechCorp", "Technology company"),
        (bob, "Bob", "Product manager"),
    ] {
        add_node(db, id, name, description).await;
    }

    for (source, target, relationship) in [
        (alice, tech_corp, "works_at"),
        (alice, bob, "knows"),
    ] {
        add_edge(db, source, target, relationship).await;
    }

    (alice, tech_corp, bob)
}
/// Runs `memify` over the graph from `seed_graph` and checks that its two
/// edges are reported as two triplets, indexed in at least one batch, and
/// stored in the `"Triplet"` / `"text"` vector collection.
#[tokio::test]
async fn test_memify_end_to_end() {
    // Arrange: mock backends, real in-memory relational DB, default config.
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    let memify_config = MemifyConfig::default();
    seed_graph(&*graph).await;

    let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
        Arc::new(cognee_database::NoopPipelineRunRepository::new());

    // Act.
    let outcome = memify(
        Arc::clone(&graph),
        Arc::clone(&vectors),
        Arc::clone(&embedder),
        cpu_pool,
        relational_db,
        run_repository,
        Some(Uuid::new_v4()),
        Some(Uuid::new_v4()),
        None,
        &memify_config,
    )
    .await
    .unwrap();

    // Assert: one triplet per seeded edge, all of them indexed.
    assert_eq!(outcome.triplet_count, 2);
    assert_eq!(outcome.index_result.indexed_count, 2);
    assert!(outcome.index_result.batch_count >= 1);

    let collection_exists = vectors.has_collection("Triplet", "text").await.unwrap();
    assert!(collection_exists);

    let stored_points = vectors.collection_size("Triplet", "text").await.unwrap();
    assert_eq!(stored_points, 2);
}
/// Runs `memify` twice against the same seeded graph and vector store and
/// checks that both runs report the same counts and that the collection
/// still holds exactly two points afterwards.
#[tokio::test]
async fn test_memify_idempotent() {
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    let memify_config = MemifyConfig::default();
    seed_graph(&*graph).await;

    // Two back-to-back runs with identical inputs; each gets its own
    // no-op pipeline-run repository.
    let mut outcomes = Vec::with_capacity(2);
    for _ in 0..2 {
        let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
            Arc::new(cognee_database::NoopPipelineRunRepository::new());
        let outcome = memify(
            Arc::clone(&graph),
            Arc::clone(&vectors),
            Arc::clone(&embedder),
            Arc::clone(&cpu_pool),
            Arc::clone(&relational_db),
            run_repository,
            None,
            None,
            None,
            &memify_config,
        )
        .await
        .unwrap();
        outcomes.push(outcome);
    }
    let (first, second) = (&outcomes[0], &outcomes[1]);

    assert_eq!(first.triplet_count, second.triplet_count);
    assert_eq!(
        first.index_result.indexed_count,
        second.index_result.indexed_count
    );

    let stored_points = vectors.collection_size("Triplet", "text").await.unwrap();
    assert_eq!(
        stored_points, 2,
        "idempotent upsert should not duplicate points"
    );
}
/// Runs `memify` on an unseeded mock graph and checks that every count is
/// zero and that no `"Triplet"` / `"text"` collection is created.
#[tokio::test]
async fn test_memify_empty_graph() {
    // Arrange: nothing is added to the graph.
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    let memify_config = MemifyConfig::default();

    let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
        Arc::new(cognee_database::NoopPipelineRunRepository::new());

    // Act.
    let outcome = memify(
        Arc::clone(&graph),
        Arc::clone(&vectors),
        Arc::clone(&embedder),
        cpu_pool,
        relational_db,
        run_repository,
        None,
        None,
        None,
        &memify_config,
    )
    .await
    .unwrap();

    // Assert.
    assert_eq!(outcome.triplet_count, 0);
    assert_eq!(outcome.index_result.indexed_count, 0);
    assert_eq!(outcome.index_result.batch_count, 0);

    let collection_exists = vectors.has_collection("Triplet", "text").await.unwrap();
    assert!(
        !collection_exists,
        "no collection should be created for empty graph"
    );
}
/// Checks that the memify and cognify paths assign the same `id` to each
/// triplet of the same logical graph.
///
/// The memify side extracts triplets from a seeded `MockGraphDB`. The
/// cognify side builds them from in-memory node and edge pairs carrying the
/// same ids, names, descriptions and relationships. Both sides must yield
/// three triplets with identical `(source, relationship, target)` keys and
/// identical ids per key.
#[tokio::test]
async fn test_memify_idempotent_ids_match_cognify() {
    /// Builds a cognify node pair whose entity id is forced to `id` and
    /// whose entity type is the fixed `"Generic"` type.
    fn make_node(id: Uuid, name: &str, description: &str) -> GraphNodePair {
        let mut entity = Entity::new(name, None, description, None);
        entity.base.id = id;
        GraphNodePair {
            entity,
            entity_type: EntityType::new("Generic", "Generic type", None),
        }
    }

    let graph_db = MockGraphDB::new();
    let id_a = Uuid::new_v4();
    let id_b = Uuid::new_v4();
    let id_c = Uuid::new_v4();

    // One seed table per kind, shared by the graph DB and the cognify input.
    let node_seeds = [
        (id_a, "Alice", "Software engineer"),
        (id_b, "TechCorp", "Technology company"),
        (id_c, "Bob", "Product manager"),
    ];
    let edge_seeds = [
        (id_a, id_b, "works_at"),
        (id_a, id_c, "knows"),
        (id_b, id_c, "employs"),
    ];

    for &(id, name, description) in &node_seeds {
        add_node(&graph_db, id, name, description).await;
    }
    for &(source, target, relationship) in &edge_seeds {
        add_edge(&graph_db, source, target, relationship).await;
    }

    let nodes: Vec<GraphNodePair> = node_seeds
        .iter()
        .map(|&(id, name, description)| make_node(id, name, description))
        .collect();
    let edges: Vec<_> = edge_seeds
        .iter()
        .map(|&(source, target, relationship)| GraphEdgePair::new(source, target, relationship))
        .collect();

    let memify_config = MemifyConfig::default();
    let memify_triplets = extract_triplets_from_graph_db(&graph_db, &memify_config)
        .await
        .expect("memify extract should succeed on seeded mock graph");
    let cognify_triplets = create_triplets_from_graph(&nodes, &edges);

    assert_eq!(
        memify_triplets.len(),
        cognify_triplets.len(),
        "memify and cognify should produce the same number of triplets for \
         the same logical graph state (memify={}, cognify={})",
        memify_triplets.len(),
        cognify_triplets.len(),
    );
    assert_eq!(
        memify_triplets.len(),
        3,
        "sanity: all three seeded edges should yield triplets"
    );

    // Index each side's triplet ids by (source, relationship, target).
    let memify_map: HashMap<(Uuid, String, Uuid), Uuid> = memify_triplets
        .iter()
        .map(|triplet| {
            let key = (
                triplet.source_entity_id,
                triplet.relationship_name.clone(),
                triplet.target_entity_id,
            );
            (key, triplet.id)
        })
        .collect();
    let cognify_map: HashMap<(Uuid, String, Uuid), Uuid> = cognify_triplets
        .iter()
        .map(|triplet| {
            let key = (
                triplet.source_entity_id,
                triplet.relationship_name.clone(),
                triplet.target_entity_id,
            );
            (key, triplet.id)
        })
        .collect();

    // A map shorter than its source list means two triplets shared a key.
    assert_eq!(
        memify_map.len(),
        memify_triplets.len(),
        "memify triplets must have unique (source, rel, target) tuples"
    );
    assert_eq!(
        cognify_map.len(),
        cognify_triplets.len(),
        "cognify triplets must have unique (source, rel, target) tuples"
    );

    let memify_keys: std::collections::HashSet<_> = memify_map.keys().collect();
    let cognify_keys: std::collections::HashSet<_> = cognify_map.keys().collect();
    assert_eq!(
        memify_keys, cognify_keys,
        "memify and cognify must cover the same (source, rel, target) tuples"
    );

    for (key, memify_id) in &memify_map {
        let cognify_id = cognify_map
            .get(key)
            .expect("key presence already asserted by set equality above");
        let (source, relationship, target) = key;
        assert_eq!(
            memify_id, cognify_id,
            "Triplet.id diverges between memify and cognify for \
             (source={}, rel={}, target={}): memify={}, cognify={}. \
             This would cause duplicate vector points instead of upsert.",
            source, relationship, target, memify_id, cognify_id,
        );
    }
}
/// Checks that `memify` returns an error mentioning `triplet_batch_size`
/// when the config sets the triplet batch size to zero.
#[tokio::test]
async fn test_memify_rejects_invalid_config() {
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    let invalid_config = MemifyConfig::default().with_triplet_batch_size(0);

    let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
        Arc::new(cognee_database::NoopPipelineRunRepository::new());

    // The handles are moved in directly; nothing is inspected afterwards.
    let err = memify(
        graph,
        vectors,
        embedder,
        cpu_pool,
        relational_db,
        run_repository,
        None,
        None,
        None,
        &invalid_config,
    )
    .await
    .unwrap_err();

    let mentions_batch_size = err.to_string().contains("triplet_batch_size");
    assert!(
        mentions_batch_size,
        "expected config validation error, got: {err}"
    );
}
/// Inserts a raw JSON node with `id`, `name` and `type` keys into `db`.
///
/// A `description` key is added only when `description` is non-empty.
///
/// # Panics
///
/// Panics if `add_node_raw` returns an error.
async fn add_typed_node(
    db: &dyn GraphDBTrait,
    id: Uuid,
    name: &str,
    node_type: &str,
    description: &str,
) {
    // Keys are written in the same order as before: id, name, type, description.
    let mut payload = json!({
        "id": id.to_string(),
        "name": name,
        "type": node_type,
    });
    if !description.is_empty() {
        payload["description"] = json!(description);
    }

    db.add_node_raw(payload).await.unwrap();
}
/// Seeds `db` with the typed graph used by the filter tests.
///
/// Nodes: Alice, Bob and Carol of type `"Entity"`, plus Idea1 of type
/// `"Concept"`. Edges: Alice `knows` Bob, Bob `knows` Carol, and
/// Alice `likes` Idea1. Returns the generated ids in the order
/// `(Alice, Bob, Carol, Idea1)`.
async fn seed_filter_graph(db: &dyn GraphDBTrait) -> (Uuid, Uuid, Uuid, Uuid) {
    let alice = Uuid::new_v4();
    let bob = Uuid::new_v4();
    let carol = Uuid::new_v4();
    let idea1 = Uuid::new_v4();

    // All nodes are inserted before any edge.
    for (id, name, node_type, description) in [
        (alice, "Alice", "Entity", "Person A"),
        (bob, "Bob", "Entity", "Person B"),
        (carol, "Carol", "Entity", "Person C"),
        (idea1, "Idea1", "Concept", "An idea"),
    ] {
        add_typed_node(db, id, name, node_type, description).await;
    }

    for (source, target, relationship) in [
        (alice, bob, "knows"),
        (bob, carol, "knows"),
        (alice, idea1, "likes"),
    ] {
        add_edge(db, source, target, relationship).await;
    }

    (alice, bob, carol, idea1)
}
/// Runs `memify` over the filter graph with node type `"Entity"`, node
/// names `["Alice", "Bob"]` and the `"OR"` name operator, and expects all
/// three seeded edges to be kept and indexed.
#[tokio::test]
async fn test_memify_with_type_and_names_filter_or() {
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    seed_filter_graph(&*graph).await;

    let selected_names = vec!["Alice".to_string(), "Bob".to_string()];
    let filter_config = MemifyConfig::default()
        .with_node_type_filter("Entity".to_string())
        .with_node_name_filter(selected_names)
        .with_node_name_filter_operator("OR".to_string());

    let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
        Arc::new(cognee_database::NoopPipelineRunRepository::new());

    let outcome = memify(
        Arc::clone(&graph),
        Arc::clone(&vectors),
        Arc::clone(&embedder),
        cpu_pool,
        relational_db,
        run_repository,
        Some(Uuid::new_v4()),
        Some(Uuid::new_v4()),
        None,
        &filter_config,
    )
    .await
    .unwrap();

    assert_eq!(
        outcome.triplet_count, 3,
        "OR filter should keep all 3 edges between the included primaries and their neighbors"
    );
    assert_eq!(outcome.index_result.indexed_count, 3);

    let stored_points = vectors.collection_size("Triplet", "text").await.unwrap();
    assert_eq!(stored_points, 3);
}
/// Runs `memify` over the filter graph with node type `"Entity"`, node
/// names `["Alice", "Bob"]` and the `"AND"` name operator, and expects only
/// one triplet (Alice `knows` Bob) to be kept and indexed.
#[tokio::test]
async fn test_memify_with_type_and_names_filter_and() {
    let graph: Arc<dyn GraphDBTrait> = Arc::new(MockGraphDB::new());
    let vectors: Arc<dyn VectorDB> = Arc::new(MockVectorDB::new());
    let embedder: Arc<dyn EmbeddingEngine> = Arc::new(MockEmbeddingEngine::new(8));
    let (cpu_pool, relational_db) = make_ctx_handles().await;
    seed_filter_graph(&*graph).await;

    let selected_names = vec!["Alice".to_string(), "Bob".to_string()];
    let filter_config = MemifyConfig::default()
        .with_node_type_filter("Entity".to_string())
        .with_node_name_filter(selected_names)
        .with_node_name_filter_operator("AND".to_string());

    let run_repository: Arc<dyn cognee_database::PipelineRunRepository> =
        Arc::new(cognee_database::NoopPipelineRunRepository::new());

    let outcome = memify(
        Arc::clone(&graph),
        Arc::clone(&vectors),
        Arc::clone(&embedder),
        cpu_pool,
        relational_db,
        run_repository,
        Some(Uuid::new_v4()),
        Some(Uuid::new_v4()),
        None,
        &filter_config,
    )
    .await
    .unwrap();

    assert_eq!(
        outcome.triplet_count, 1,
        "AND filter should keep only the Alice-knows-Bob edge (both endpoints are primaries)"
    );
    assert_eq!(outcome.index_result.indexed_count, 1);

    let stored_points = vectors.collection_size("Triplet", "text").await.unwrap();
    assert_eq!(stored_points, 1);
}