use std::collections::{HashMap, HashSet};
use cognee_graph::{EdgeData, GraphDBTrait};
use cognee_utils::generate_node_id;
use crate::error::CognifyError;
use crate::fact_extraction::KnowledgeGraph;
/// Looks up which of the edges proposed by `graphs` already exist in `graph_db`.
///
/// Each edge's `source_node_id` and `target_node_id` is mapped to a node UUID with
/// [`generate_node_id`]. All resulting `(source, target, relationship)` triples are then
/// checked in a single `has_edges` call on `graph_db`.
///
/// # Returns
///
/// A set of keys of the form `"{source_uuid}_{target_uuid}_{relationship_name}"`, one per
/// edge that the database reports as existing. The set is empty when `graphs` is empty or
/// contains no edges; the database is not queried in that case.
///
/// # Errors
///
/// Returns [`CognifyError::GraphDatabaseError`] if the `has_edges` query fails.
pub async fn retrieve_existing_edges(
    graph_db: &dyn GraphDBTrait,
    graphs: &[KnowledgeGraph],
) -> Result<HashSet<String>, CognifyError> {
    // One (source, target, relationship, properties) tuple per proposed edge.
    // Properties are left empty: only the edge's identity is being checked.
    let edges_to_check: Vec<EdgeData> = graphs
        .iter()
        .flat_map(|graph| &graph.edges)
        .map(|edge| {
            (
                generate_node_id(&edge.source_node_id).to_string(),
                generate_node_id(&edge.target_node_id).to_string(),
                edge.relationship_name.clone(),
                HashMap::new(),
            )
        })
        .collect();

    // Skip the round-trip entirely when there is nothing to ask about.
    if edges_to_check.is_empty() {
        return Ok(HashSet::new());
    }

    let existing_edges = graph_db
        .has_edges(&edges_to_check)
        .await
        .map_err(|e| CognifyError::GraphDatabaseError(e.to_string()))?;

    Ok(existing_edges
        .into_iter()
        .map(|(source_id, target_id, relationship_name, _)| {
            format!("{source_id}_{target_id}_{relationship_name}")
        })
        .collect())
}
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;
    use crate::fact_extraction::{Edge, Node};
    use cognee_graph::MockGraphDB;

    /// Builds a [`Node`] from borrowed string fields.
    fn node(id: &str, name: &str, node_type: &str, description: &str) -> Node {
        Node {
            id: id.to_string(),
            name: name.to_string(),
            node_type: node_type.to_string(),
            description: description.to_string(),
        }
    }

    /// Builds an [`Edge`] with no description.
    fn edge(source: &str, target: &str, relationship: &str) -> Edge {
        Edge {
            source_node_id: source.to_string(),
            target_node_id: target.to_string(),
            relationship_name: relationship.to_string(),
            description: None,
        }
    }

    /// A single-edge graph: `alice -works_at-> techcorp`.
    fn create_test_graph() -> KnowledgeGraph {
        KnowledgeGraph {
            nodes: vec![
                node("alice", "Alice", "Person", "A person"),
                node("techcorp", "TechCorp", "Organization", "A company"),
            ],
            edges: vec![edge("alice", "techcorp", "works_at")],
        }
    }

    /// Inserts `source -relationship-> target` into the mock, addressed by the node UUIDs
    /// derived with `generate_node_id`. Returns the key that `retrieve_existing_edges`
    /// should report for that edge.
    ///
    /// Panics if the mock rejects the insert, so a broken setup is reported directly
    /// instead of surfacing as a confusing assertion failure later.
    async fn seed_edge(
        graph_db: &MockGraphDB,
        source: &str,
        target: &str,
        relationship: &str,
    ) -> String {
        let source_uuid = generate_node_id(source);
        let target_uuid = generate_node_id(target);
        graph_db
            .add_edge(
                &source_uuid.to_string(),
                &target_uuid.to_string(),
                relationship,
                None,
            )
            .await
            .expect("seeding the mock graph with an edge should succeed");
        format!("{source_uuid}_{target_uuid}_{relationship}")
    }

    #[tokio::test]
    async fn test_retrieve_existing_edges_empty() {
        let graph_db = MockGraphDB::new();
        let result = retrieve_existing_edges(&graph_db, &[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_existing_edges_no_existing() {
        let graph_db = MockGraphDB::new();
        let graph = create_test_graph();
        let result = retrieve_existing_edges(&graph_db, &[graph]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn test_retrieve_existing_edges_with_existing() {
        let graph_db = MockGraphDB::new();
        let graph = create_test_graph();
        let expected_key = seed_edge(&graph_db, "alice", "techcorp", "works_at").await;

        let result = retrieve_existing_edges(&graph_db, &[graph]).await.unwrap();

        assert!(result.contains(&expected_key));
    }

    #[tokio::test]
    async fn test_retrieve_existing_edges_partial_match() {
        let graph_db = MockGraphDB::new();
        let graph1 = create_test_graph();
        let graph2 = KnowledgeGraph {
            nodes: vec![
                node("bob", "Bob", "Person", "Another person"),
                node("acmecorp", "AcmeCorp", "Organization", "Another company"),
            ],
            edges: vec![edge("bob", "acmecorp", "works_at")],
        };
        // Only the alice -> techcorp edge is stored; bob -> acmecorp is new.
        let alice_edge_key = seed_edge(&graph_db, "alice", "techcorp", "works_at").await;

        let result = retrieve_existing_edges(&graph_db, &[graph1, graph2])
            .await
            .unwrap();

        assert!(result.contains(&alice_edge_key));
        assert_eq!(result.len(), 1);
    }

    /// Several edges sharing a source node, none of them stored, yield no existing edges.
    #[tokio::test]
    async fn test_retrieve_existing_edges_shared_source_node() {
        let graph_db = MockGraphDB::new();
        let graph = KnowledgeGraph {
            nodes: vec![
                node("alice", "Alice", "Person", "A person"),
                node("techcorp", "TechCorp", "Organization", "A company"),
                node("london", "London", "Location", "A city"),
            ],
            edges: vec![
                edge("alice", "techcorp", "works_at"),
                edge("alice", "london", "lives_in"),
            ],
        };

        let result = retrieve_existing_edges(&graph_db, &[graph]).await.unwrap();

        assert!(result.is_empty());
    }
}