Skip to main content

cognee_cognify/memify/
sync_graph_session.rs

//! Stage 4 of `improve()` — incrementally sync recent graph edges into a
//! session's `graph_context`.
//!
//! Ported from `/tmp/cognee-python/cognee/tasks/memify/sync_graph_to_session.py`.
//!
//! Behavior:
//! 1. Load the checkpoint timestamp for
//!    `graph_sync_checkpoint:{user_id}:{dataset_id}:{session_id}`.
//! 2. Paginate through edges in the relational DB with `created_at > since`,
//!    in batches of [`BATCH_SIZE`]. Stop when an empty or partial batch is
//!    returned.
//! 3. Resolve edge endpoints to node records so each line can include
//!    `label`/`type`/`description`.
//! 4. Emit each edge as a JSON-line
//!    `{"source": ..., "relationship": ..., "target": ...}`.
//! 5. Merge with the existing `graph_context` and cap the result at the
//!    caller-supplied line limit (see [`DEFAULT_MAX_LINES`]), dropping the
//!    oldest lines when over.
//! 6. Persist the merged context and advance the checkpoint.
18
19use std::collections::HashMap;
20
21use chrono::{DateTime, Utc};
22use cognee_database::ops::graph_storage::{get_edges_since, get_nodes_by_ids};
23use cognee_database::uuid_hex;
24use cognee_database::{CheckpointStore, DatabaseConnection, DatabaseError, GraphEdge, GraphNode};
25use cognee_session::SessionManager;
26use thiserror::Error;
27use tracing::info;
28use uuid::Uuid;
29
/// Number of edges requested per page while scanning for new edges.
///
/// Mirrors the batch size used by the Python implementation
/// (`sync_graph_to_session.py:34`).
pub const BATCH_SIZE: u64 = 500;
32
/// Default upper bound on how many JSON-lines a session's graph context
/// keeps. Mirrors Python's `DEFAULT_MAX_LINES = 500`.
pub const DEFAULT_MAX_LINES: usize = 500;
36
37/// Error type for Stage 4 (`sync_graph_to_session`).
38#[derive(Debug, Error)]
39pub enum SyncError {
40    #[error("Database error: {0}")]
41    Database(#[from] DatabaseError),
42
43    #[error("Session error: {0}")]
44    Session(#[from] cognee_session::SessionError),
45}
46
/// Summary of a Stage 4 run.
///
/// The `Default` value (`synced == 0`, `total == 0`) is what
/// `sync_graph_to_session` returns when it rendered no new lines.
/// `PartialEq`/`Eq` are derived so results can be compared directly,
/// e.g. with `assert_eq!` in tests.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SyncResult {
    /// Number of newly synced edges.
    pub synced: usize,
    /// Total number of JSON-lines in the session graph_context after merge.
    pub total: usize,
}
55
56/// Build the checkpoint key. Matches Python `_checkpoint_key()`.
57pub fn checkpoint_key(user_id: &str, dataset_id: Uuid, session_id: &str) -> String {
58    format!("graph_sync_checkpoint:{user_id}:{dataset_id}:{session_id}")
59}
60
61/// Render an edge as a JSON-line using node metadata from `node_map`.
62///
63/// Returns `None` if either endpoint is missing from `node_map`. Matches
64/// Python `_edge_to_text()` semantics (keys: `label`, `type`, `description`).
65pub fn edge_to_json_line(edge: &GraphEdge, node_map: &HashMap<Uuid, GraphNode>) -> Option<String> {
66    let src = node_map.get(&edge.source_node_id)?;
67    let dst = node_map.get(&edge.destination_node_id)?;
68
69    let mut src_obj = serde_json::Map::new();
70    src_obj.insert(
71        "label".to_string(),
72        serde_json::Value::String(src.label.clone().unwrap_or_else(|| {
73            if !src.node_type.is_empty() {
74                src.node_type.clone()
75            } else {
76                src.id.to_string()
77            }
78        })),
79    );
80    if !src.node_type.is_empty() {
81        src_obj.insert(
82            "type".to_string(),
83            serde_json::Value::String(src.node_type.clone()),
84        );
85    }
86    if let Some(attrs) = src
87        .attributes
88        .as_ref()
89        .and_then(|v| v.get("description"))
90        .and_then(|v| v.as_str())
91    {
92        src_obj.insert(
93            "description".to_string(),
94            serde_json::Value::String(attrs.to_string()),
95        );
96    }
97
98    let mut dst_obj = serde_json::Map::new();
99    dst_obj.insert(
100        "label".to_string(),
101        serde_json::Value::String(dst.label.clone().unwrap_or_else(|| {
102            if !dst.node_type.is_empty() {
103                dst.node_type.clone()
104            } else {
105                dst.id.to_string()
106            }
107        })),
108    );
109    if !dst.node_type.is_empty() {
110        dst_obj.insert(
111            "type".to_string(),
112            serde_json::Value::String(dst.node_type.clone()),
113        );
114    }
115    if let Some(attrs) = dst
116        .attributes
117        .as_ref()
118        .and_then(|v| v.get("description"))
119        .and_then(|v| v.as_str())
120    {
121        dst_obj.insert(
122            "description".to_string(),
123            serde_json::Value::String(attrs.to_string()),
124        );
125    }
126
127    let relationship = if edge.relationship_name.is_empty() {
128        "related_to".to_string()
129    } else {
130        edge.relationship_name.clone()
131    };
132
133    let mut line = serde_json::Map::new();
134    line.insert("source".to_string(), serde_json::Value::Object(src_obj));
135    line.insert(
136        "relationship".to_string(),
137        serde_json::Value::String(relationship),
138    );
139    line.insert("target".to_string(), serde_json::Value::Object(dst_obj));
140    Some(serde_json::Value::Object(line).to_string())
141}
142
143/// Sync graph→session for a single session.
144#[allow(clippy::too_many_arguments)]
145pub async fn sync_graph_to_session(
146    user_id: &str,
147    session_id: &str,
148    dataset_id: Uuid,
149    db: &DatabaseConnection,
150    session_manager: &SessionManager,
151    checkpoint_store: &dyn CheckpointStore,
152    max_lines: usize,
153) -> Result<SyncResult, SyncError> {
154    let ck = checkpoint_key(user_id, dataset_id, session_id);
155    let since: Option<DateTime<Utc>> = checkpoint_store.load(&ck).await?;
156
157    let mut new_lines: Vec<String> = Vec::new();
158    let mut latest: Option<DateTime<Utc>> = since;
159
160    loop {
161        let edges = get_edges_since(db, dataset_id, latest, BATCH_SIZE).await?;
162        if edges.is_empty() {
163            break;
164        }
165
166        // Collect endpoint hex ids for batch node fetch.
167        let mut id_hex_set: std::collections::HashSet<String> = std::collections::HashSet::new();
168        for e in &edges {
169            id_hex_set.insert(uuid_hex::to_hex(e.source_node_id));
170            id_hex_set.insert(uuid_hex::to_hex(e.destination_node_id));
171        }
172        let id_hex_vec: Vec<String> = id_hex_set.into_iter().collect();
173        let nodes = get_nodes_by_ids(db, &id_hex_vec).await?;
174        let node_map: HashMap<Uuid, GraphNode> = nodes.into_iter().map(|n| (n.id, n)).collect();
175
176        for e in &edges {
177            if let Some(line) = edge_to_json_line(e, &node_map) {
178                new_lines.push(line);
179            }
180            if latest.map(|t| e.created_at > t).unwrap_or(true) {
181                latest = Some(e.created_at);
182            }
183        }
184
185        if (edges.len() as u64) < BATCH_SIZE {
186            break;
187        }
188    }
189
190    if new_lines.is_empty() {
191        info!(
192            session_id = session_id,
193            "sync_graph_to_session: no new edges"
194        );
195        return Ok(SyncResult::default());
196    }
197
198    let existing = session_manager
199        .get_graph_context(Some(session_id), Some(user_id))
200        .await?;
201    let mut merged: Vec<String> = existing
202        .as_deref()
203        .map(|s| {
204            s.split('\n')
205                .filter(|l| !l.is_empty())
206                .map(|s| s.to_string())
207                .collect()
208        })
209        .unwrap_or_default();
210    merged.extend(new_lines.iter().cloned());
211    if merged.len() > max_lines {
212        let drop = merged.len() - max_lines;
213        info!(
214            session_id = session_id,
215            dropped = drop,
216            cap = max_lines,
217            "sync_graph_to_session: capping, dropping oldest"
218        );
219        merged.drain(0..drop);
220    }
221
222    let merged_str = merged.join("\n");
223    session_manager
224        .set_graph_context(Some(session_id), Some(user_id), &merged_str)
225        .await?;
226
227    if let Some(ts) = latest
228        && Some(ts) != since
229    {
230        checkpoint_store.save(&ck, ts).await?;
231    }
232
233    info!(
234        session_id = session_id,
235        synced = new_lines.len(),
236        total = merged.len(),
237        max_lines = max_lines,
238        "sync_graph_to_session: complete"
239    );
240
241    Ok(SyncResult {
242        synced: new_lines.len(),
243        total: merged.len(),
244    })
245}
246
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    reason = "test code — panics are acceptable failures"
)]
mod tests {
    use super::*;

    /// Edge from `src` to `dst` named `relationship`; all other ids are random.
    fn edge_between(src: Uuid, dst: Uuid, relationship: &str) -> GraphEdge {
        GraphEdge {
            id: Uuid::new_v4(),
            slug: Uuid::new_v4(),
            user_id: Uuid::new_v4(),
            data_id: Uuid::new_v4(),
            dataset_id: Uuid::new_v4(),
            source_node_id: src,
            destination_node_id: dst,
            relationship_name: relationship.to_string(),
            label: None,
            attributes: None,
            created_at: chrono::Utc::now(),
        }
    }

    /// `Person` node with the given id, label and optional attributes.
    fn person(id: Uuid, label: &str, attributes: Option<serde_json::Value>) -> GraphNode {
        GraphNode {
            id,
            slug: Uuid::new_v4(),
            user_id: Uuid::new_v4(),
            data_id: Uuid::new_v4(),
            dataset_id: Uuid::new_v4(),
            label: Some(label.to_string()),
            node_type: "Person".to_string(),
            indexed_fields: serde_json::json!({}),
            attributes,
            created_at: chrono::Utc::now(),
        }
    }

    #[test]
    fn checkpoint_key_format() {
        let dataset = Uuid::nil();
        let key = checkpoint_key("user-1", dataset, "sess-1");
        assert_eq!(key, format!("graph_sync_checkpoint:user-1:{dataset}:sess-1"));
    }

    #[test]
    fn edge_to_json_line_full() {
        let src_id = Uuid::new_v4();
        let dst_id = Uuid::new_v4();
        let edge = edge_between(src_id, dst_id, "knows");

        let node_map: HashMap<Uuid, GraphNode> = [
            person(
                src_id,
                "Alice",
                Some(serde_json::json!({"description": "An engineer"})),
            ),
            person(dst_id, "Bob", None),
        ]
        .into_iter()
        .map(|n| (n.id, n))
        .collect();

        let line = edge_to_json_line(&edge, &node_map).unwrap();
        let parsed: serde_json::Value = serde_json::from_str(&line).unwrap();

        assert_eq!(parsed["relationship"].as_str(), Some("knows"));
        assert_eq!(parsed["source"]["label"].as_str(), Some("Alice"));
        assert_eq!(parsed["source"]["type"].as_str(), Some("Person"));
        assert_eq!(
            parsed["source"]["description"].as_str(),
            Some("An engineer")
        );
        assert_eq!(parsed["target"]["label"].as_str(), Some("Bob"));
        // The target node carries no attributes, so no description key is emitted.
        assert!(parsed["target"].get("description").is_none());
    }

    #[test]
    fn edge_to_json_line_missing_endpoint() {
        let edge = edge_between(Uuid::new_v4(), Uuid::new_v4(), "r");
        let no_nodes: HashMap<Uuid, GraphNode> = HashMap::new();
        assert!(edge_to_json_line(&edge, &no_nodes).is_none());
    }
}
345}