Skip to main content

cognee_graph/
traits.rs

1//! Graph database trait interface.
2//!
3//! Defines the complete async API for graph database operations.
4
5use async_trait::async_trait;
6use serde::Serialize;
7use serde_json::Value;
8use std::borrow::Cow;
9use std::collections::{HashMap, HashSet};
10
11use crate::{EdgeData, GraphDBResult, GraphNode, NodeData};
12
/// Identifies one edge of the graph by its two endpoints and its label.
///
/// Used as the map key for per-edge feedback weights. Field order is
/// `(source_id, target_id, relationship_name)`.
pub type EdgeKey = (
    String, // source_id
    String, // target_id
    String, // relationship_name
);
16
17/// Graph database interface trait.
18///
19/// This trait defines the complete set of operations for graph database interaction,
20/// providing a consistent API for any graph database backend.
21///
22/// # Methods
23///
24/// ## Core Operations
25/// - `initialize()` - Set up database schema
26/// - `is_empty()` - Check if database is empty
27/// - `query()` - Execute raw query
28/// - `delete_graph()` - Remove all data
29///
30/// ## Node Operations
31/// - `add_node()` - Add single node
32/// - `add_nodes()` - Add multiple nodes
33/// - `delete_node()` - Delete single node
34/// - `delete_nodes()` - Delete multiple nodes
35/// - `get_node()` - Get single node
36/// - `get_nodes()` - Get multiple nodes
37/// - `has_node()` - Check node existence
38///
39/// ## Edge Operations
40/// - `add_edge()` - Add single edge
41/// - `add_edges()` - Add multiple edges
42/// - `has_edge()` - Check edge existence
43/// - `has_edges()` - Check multiple edges existence
44/// - `get_edges()` - Get all edges for a node
45///
46/// ## Graph Queries
47/// - `get_neighbors()` - Get neighboring nodes
48/// - `get_connections()` - Get all connections (nodes + edges)
49/// - `get_graph_data()` - Get all nodes and edges
50/// - `get_graph_metrics()` - Get graph statistics
51/// - `get_filtered_graph_data()` - Get filtered subgraph
52/// - `get_nodeset_subgraph()` - Get subgraph for specific nodes
53#[async_trait]
54pub trait GraphDBTrait: Send + Sync {
55    /// Initialize the database schema.
56    ///
57    /// Creates necessary tables, indexes, and constraints.
58    ///
59    async fn initialize(&self) -> GraphDBResult<()>;
60
61    /// Check if the database is empty (no nodes).
62    ///
63    async fn is_empty(&self) -> GraphDBResult<bool>;
64
65    /// Execute a raw database query.
66    ///
67    /// # Arguments
68    /// * `query` - Query string (Cypher-like for Ladybug)
69    /// * `params` - Query parameters
70    ///
71    async fn query(
72        &self,
73        query: &str,
74        params: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
75    ) -> GraphDBResult<Vec<Vec<serde_json::Value>>>;
76
77    /// Delete the entire graph (all nodes and edges).
78    ///
79    async fn delete_graph(&self) -> GraphDBResult<()>;
80
81    /// Check if a node exists by ID.
82    ///
83    async fn has_node(&self, node_id: &str) -> GraphDBResult<bool>;
84
85    /// Add a single node (type-erased). Takes a pre-serialized JSON value.
86    /// Prefer [`GraphDBTraitExt::add_node`] for typed access.
87    async fn add_node_raw(&self, node: Value) -> GraphDBResult<()>;
88
89    /// Add multiple nodes (type-erased). Takes pre-serialized JSON values.
90    /// Prefer [`GraphDBTraitExt::add_nodes`] for typed access.
91    async fn add_nodes_raw(&self, nodes: Vec<Value>) -> GraphDBResult<()>;
92
93    /// Delete a node by ID.
94    ///
95    async fn delete_node(&self, node_id: &str) -> GraphDBResult<()>;
96
97    /// Delete multiple nodes by IDs.
98    ///
99    async fn delete_nodes(&self, node_ids: &[String]) -> GraphDBResult<()>;
100
101    /// Get a single node by ID.
102    ///
103    /// Returns None if node doesn't exist.
104    ///
105    async fn get_node(&self, node_id: &str) -> GraphDBResult<Option<NodeData>>;
106
107    /// Get multiple nodes by IDs.
108    ///
109    async fn get_nodes(&self, node_ids: &[String]) -> GraphDBResult<Vec<NodeData>>;
110
111    /// Check if an edge exists between two nodes.
112    ///
113    /// # Arguments
114    /// * `source_id` - Source node ID
115    /// * `target_id` - Target node ID
116    /// * `relationship_name` - Edge label/relationship type
117    ///
118    async fn has_edge(
119        &self,
120        source_id: &str,
121        target_id: &str,
122        relationship_name: &str,
123    ) -> GraphDBResult<bool>;
124
125    /// Check which edges exist from a list.
126    ///
127    /// Returns only edges that exist in the database.
128    ///
129    async fn has_edges(&self, edges: &[EdgeData]) -> GraphDBResult<Vec<EdgeData>>;
130
131    /// Add a single edge between two nodes.
132    ///
133    /// # Arguments
134    /// * `source_id` - Source node ID
135    /// * `target_id` - Target node ID
136    /// * `relationship_name` - Edge label/relationship type
137    /// * `properties` - Optional edge properties
138    ///
139    async fn add_edge(
140        &self,
141        source_id: &str,
142        target_id: &str,
143        relationship_name: &str,
144        properties: Option<HashMap<Cow<'static, str>, serde_json::Value>>,
145    ) -> GraphDBResult<()>;
146
147    /// Add multiple edges in a batch operation.
148    ///
149    /// # Arguments
150    /// * `edges` - Vector of EdgeData tuples
151    ///
152    async fn add_edges(&self, edges: &[EdgeData]) -> GraphDBResult<()>;
153
154    /// Get all edges connected to a node.
155    ///
156    /// Returns edges in format: (source_id, target_id, relationship_name, properties)
157    ///
158    async fn get_edges(&self, node_id: &str) -> GraphDBResult<Vec<EdgeData>>;
159
160    /// Get all neighboring nodes (directly connected).
161    ///
162    async fn get_neighbors(&self, node_id: &str) -> GraphDBResult<Vec<NodeData>>;
163
164    /// Get all connections (nodes + edges) for a node.
165    ///
166    /// Returns: Vec<(source_node, edge_properties, target_node)>
167    ///
168    async fn get_connections(
169        &self,
170        node_id: &str,
171    ) -> GraphDBResult<
172        Vec<(
173            NodeData,
174            HashMap<Cow<'static, str>, serde_json::Value>,
175            NodeData,
176        )>,
177    >;
178
179    /// Get all nodes and edges in the graph.
180    ///
181    /// Returns: (nodes, edges) where:
182    /// - nodes: Vec<(node_id, properties)>
183    /// - edges: Vec<(source_id, target_id, relationship_name, properties)>
184    ///
185    async fn get_graph_data(&self) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
186
187    /// Get graph metrics and statistics.
188    ///
189    /// Returns metrics like node count, edge count, density, etc.
190    ///
191    async fn get_graph_metrics(
192        &self,
193        include_optional: bool,
194    ) -> GraphDBResult<HashMap<Cow<'static, str>, serde_json::Value>>;
195
196    /// Get a filtered subgraph based on attribute filters.
197    ///
198    /// # Arguments
199    /// * `attribute_filters` - Filters as key-value pairs
200    ///
201    async fn get_filtered_graph_data(
202        &self,
203        attribute_filters: &HashMap<Cow<'static, str>, Vec<serde_json::Value>>,
204    ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
205
206    /// Get subgraph for a specific set of nodes.
207    ///
208    /// # Arguments
209    /// * `node_type` - Type name of nodes to retrieve
210    /// * `node_names` - Names of specific nodes
211    /// * `node_name_filter_operator` - "OR" to include neighbors of ANY named node,
212    ///   "AND" to include only neighbors connected to ALL named nodes
213    ///
214    /// Returns nodes and edges connecting them.
215    ///
216    async fn get_nodeset_subgraph(
217        &self,
218        node_type: &str,
219        node_names: &[String],
220        node_name_filter_operator: &str,
221    ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)>;
222
223    /// Find nodes of the given type that have exactly one edge (any direction).
224    ///
225    /// Used by hard-delete mode to locate orphaned Entity/EntityType nodes that
226    /// are no longer meaningfully connected after a soft deletion.
227    ///
228    /// Default implementation fetches the full graph and computes degree in
229    /// memory (O(N+E)).  Backends may override with an efficient Cypher/SQL query.
230    async fn get_degree_one_nodes(&self, node_type: &str) -> GraphDBResult<Vec<crate::GraphNode>> {
231        let (nodes, edges) = self.get_graph_data().await?;
232
233        // Build a degree map from edges (count both endpoints)
234        let mut degree: HashMap<String, usize> = HashMap::new();
235        for (src, tgt, _, _) in &edges {
236            *degree.entry(src.clone()).or_default() += 1;
237            *degree.entry(tgt.clone()).or_default() += 1;
238        }
239
240        Ok(nodes
241            .into_iter()
242            .filter(|(id, props)| {
243                let type_matches = props
244                    .get("type")
245                    .and_then(|v| v.as_str())
246                    .is_some_and(|t| t == node_type);
247                let deg = degree.get(id).copied().unwrap_or(0);
248                type_matches && deg == 1
249            })
250            .collect())
251    }
252
253    /// Return the set of all unique relationship names from edges in the graph.
254    ///
255    /// Used by orphan cleanup to determine which EdgeType nodes still have
256    /// corresponding edges. Default implementation fetches the full graph via
257    /// `get_graph_data()` and collects distinct relationship names.
258    /// Backends may override with a more efficient query.
259    async fn get_all_relationship_names(&self) -> GraphDBResult<HashSet<String>> {
260        let (_, edges) = self.get_graph_data().await?;
261        Ok(edges.into_iter().map(|(_, _, rel, _)| rel).collect())
262    }
263
264    /// Find EdgeType nodes in the graph that have zero edges (degree 0).
265    ///
266    /// Used by hard-delete orphan sweep to find EdgeType nodes whose
267    /// relationship name no longer appears in any edge.
268    ///
269    /// Default implementation fetches the full graph and filters in memory.
270    /// Backends may override with a more efficient query.
271    async fn get_zero_degree_edge_type_nodes(&self) -> GraphDBResult<Vec<crate::GraphNode>> {
272        let (nodes, edges) = self.get_graph_data().await?;
273
274        // Collect all relationship names still in use
275        let active_rel_names: HashSet<&str> =
276            edges.iter().map(|(_, _, rel, _)| rel.as_str()).collect();
277
278        // Build a degree map from edges
279        let mut degree: HashMap<String, usize> = HashMap::new();
280        for (src, tgt, _, _) in &edges {
281            *degree.entry(src.clone()).or_default() += 1;
282            *degree.entry(tgt.clone()).or_default() += 1;
283        }
284
285        Ok(nodes
286            .into_iter()
287            .filter(|(id, props)| {
288                let is_edge_type = props
289                    .get("type")
290                    .and_then(|v| v.as_str())
291                    .is_some_and(|t| t == "EdgeType");
292                if !is_edge_type {
293                    return false;
294                }
295                // Check degree is 0 (no edges at all)
296                let deg = degree.get(id).copied().unwrap_or(0);
297                if deg > 0 {
298                    return false;
299                }
300                // Also check that the relationship_name is not in any edge
301                let rel_name = props
302                    .get("relationship_name")
303                    .and_then(|v| v.as_str())
304                    .unwrap_or("");
305                !active_rel_names.contains(rel_name)
306            })
307            .collect())
308    }
309
310    /// Update a single property on a node.
311    ///
312    /// # Arguments
313    /// * `node_id` - The node identifier
314    /// * `key` - Property name
315    /// * `value` - New property value
316    ///
317    /// Default implementation fetches the node and its edges, modifies the
318    /// property, removes the old node (which may cascade-delete edges), re-adds
319    /// the node, and restores the edges. Backends should override with an
320    /// in-place `SET` operation for better performance and atomicity.
321    async fn update_node_property(
322        &self,
323        node_id: &str,
324        key: &str,
325        value: serde_json::Value,
326    ) -> GraphDBResult<()> {
327        let node = self
328            .get_node(node_id)
329            .await?
330            .ok_or_else(|| crate::GraphDBError::NodeError(format!("Node not found: {node_id}")))?;
331
332        // Save edges before deleting the node, since delete_node may cascade.
333        let edges = self.get_edges(node_id).await.unwrap_or_default();
334
335        let mut props = serde_json::Map::new();
336        for (k, v) in node {
337            props.insert(k.into_owned(), v);
338        }
339        props.insert(key.to_string(), value);
340
341        self.delete_node(node_id).await?;
342        self.add_node_raw(Value::Object(props)).await?;
343
344        // Restore edges that were removed by the cascade delete.
345        if !edges.is_empty() {
346            self.add_edges(&edges).await?;
347        }
348
349        Ok(())
350    }
351
352    /// Update a single property on an edge.
353    ///
354    /// # Arguments
355    /// * `source_id` - Source node ID
356    /// * `target_id` - Target node ID
357    /// * `relationship_name` - Edge label/relationship type
358    /// * `key` - Property name
359    /// * `value` - New property value
360    ///
361    /// Default implementation is a no-op that logs a warning. Backends that
362    /// support in-place edge property updates should override this method.
363    async fn update_edge_property(
364        &self,
365        source_id: &str,
366        target_id: &str,
367        relationship_name: &str,
368        key: &str,
369        value: serde_json::Value,
370    ) -> GraphDBResult<()> {
371        let _ = (source_id, target_id, relationship_name, key, value);
372        tracing::warn!(
373            "update_edge_property not implemented for this backend; \
374             edge {source_id} -> {target_id} ({relationship_name}) property {key} not updated"
375        );
376        Ok(())
377    }
378
379    /// Batch-fetch `feedback_weight` values for the given node IDs.
380    ///
381    /// Returns only IDs that exist and have a numeric `feedback_weight`
382    /// property. IDs missing from the graph or missing the property are
383    /// omitted from the result map.
384    ///
385    /// Default implementation calls [`get_node`] per id; backends should
386    /// override with a single batch query for efficiency.
387    async fn get_node_feedback_weights(
388        &self,
389        node_ids: &[String],
390    ) -> GraphDBResult<HashMap<String, f64>> {
391        let mut out = HashMap::with_capacity(node_ids.len());
392        for id in node_ids {
393            if let Some(node) = self.get_node(id).await?
394                && let Some(v) = node.get("feedback_weight").and_then(|v| v.as_f64())
395            {
396                out.insert(id.clone(), v);
397            }
398        }
399        Ok(out)
400    }
401
402    /// Batch-write `feedback_weight` values on the given nodes.
403    ///
404    /// Returns a map `node_id -> success` indicating whether each update
405    /// succeeded. Default implementation delegates to `update_node_property`
406    /// for each id; backends should override with a single batch query.
407    async fn set_node_feedback_weights(
408        &self,
409        updates: &HashMap<String, f64>,
410    ) -> GraphDBResult<HashMap<String, bool>> {
411        let mut out = HashMap::with_capacity(updates.len());
412        for (id, w) in updates {
413            let ok = self
414                .update_node_property(id, "feedback_weight", serde_json::json!(w))
415                .await
416                .is_ok();
417            out.insert(id.clone(), ok);
418        }
419        Ok(out)
420    }
421
422    /// Batch-fetch `feedback_weight` values for the given edges.
423    ///
424    /// Default implementation returns an empty map and logs a warning,
425    /// because the generic `GraphDBTrait` does not expose a per-edge
426    /// property read. Backends that support edge-property queries should
427    /// override this method.
428    async fn get_edge_feedback_weights(
429        &self,
430        edge_keys: &[EdgeKey],
431    ) -> GraphDBResult<HashMap<EdgeKey, f64>> {
432        if !edge_keys.is_empty() {
433            tracing::warn!(
434                "get_edge_feedback_weights not implemented for this backend; \
435                 returning empty map for {} edge(s)",
436                edge_keys.len()
437            );
438        }
439        Ok(HashMap::new())
440    }
441
442    /// Batch-write `feedback_weight` values on the given edges.
443    ///
444    /// Default implementation delegates to [`update_edge_property`] per
445    /// edge. Backends with no edge-update support will silently succeed
446    /// (because the default `update_edge_property` returns `Ok(())` with
447    /// a warning).
448    async fn set_edge_feedback_weights(
449        &self,
450        updates: &HashMap<EdgeKey, f64>,
451    ) -> GraphDBResult<HashMap<EdgeKey, bool>> {
452        let mut out = HashMap::with_capacity(updates.len());
453        for (key, w) in updates {
454            let ok = self
455                .update_edge_property(
456                    &key.0,
457                    &key.1,
458                    &key.2,
459                    "feedback_weight",
460                    serde_json::json!(w),
461                )
462                .await
463                .is_ok();
464            out.insert(key.clone(), ok);
465        }
466        Ok(out)
467    }
468
469    /// Retrieve a subgraph containing only the specified nodes and edges between them.
470    ///
471    /// Default implementation fetches the full graph and filters in memory.
472    /// Backends may override this with a more efficient query.
473    async fn get_id_filtered_graph_data(
474        &self,
475        node_ids: &[String],
476    ) -> GraphDBResult<(Vec<GraphNode>, Vec<EdgeData>)> {
477        if node_ids.is_empty() {
478            return Ok((vec![], vec![]));
479        }
480        let (all_nodes, all_edges) = self.get_graph_data().await?;
481        let id_set: std::collections::HashSet<&str> = node_ids.iter().map(String::as_str).collect();
482
483        let filtered_nodes: Vec<GraphNode> = all_nodes
484            .into_iter()
485            .filter(|(id, _)| id_set.contains(id.as_str()))
486            .collect();
487
488        let filtered_edges: Vec<EdgeData> = all_edges
489            .into_iter()
490            .filter(|(src, tgt, _, _)| {
491                id_set.contains(src.as_str()) && id_set.contains(tgt.as_str())
492            })
493            .collect();
494
495        Ok((filtered_nodes, filtered_edges))
496    }
497}
498
499/// Extension trait providing generic convenience methods on top of [`GraphDBTrait`].
500/// Auto-implemented for all types that implement `GraphDBTrait`.
501#[async_trait]
502pub trait GraphDBTraitExt: GraphDBTrait {
503    /// Add a single node to the graph.
504    async fn add_node<T: Serialize + Sync>(&self, node: &T) -> GraphDBResult<()> {
505        let value = serde_json::to_value(node).map_err(|e| {
506            crate::GraphDBError::QueryError(format!("Failed to serialize node: {e}"))
507        })?;
508        self.add_node_raw(value).await
509    }
510
511    /// Add multiple nodes in a batch operation.
512    async fn add_nodes<T: Serialize + Sync>(&self, nodes: &[&T]) -> GraphDBResult<()> {
513        let values: Vec<Value> = nodes
514            .iter()
515            .map(serde_json::to_value)
516            .collect::<Result<_, _>>()
517            .map_err(|e| {
518                crate::GraphDBError::QueryError(format!("Failed to serialize nodes: {e}"))
519            })?;
520        self.add_nodes_raw(values).await
521    }
522}
523
524impl<T: GraphDBTrait + ?Sized> GraphDBTraitExt for T {}