Skip to main content

mnemo_core/query/
checkpoint.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::Result;
5use crate::model::checkpoint::Checkpoint;
6use crate::model::event::EventType;
7use crate::query::MnemoEngine;
8use crate::storage::MemoryFilter;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub struct CheckpointRequest {
12    pub thread_id: String,
13    pub agent_id: Option<String>,
14    pub branch_name: Option<String>,
15    pub state_snapshot: serde_json::Value,
16    pub label: Option<String>,
17    pub metadata: Option<serde_json::Value>,
18}
19
20impl CheckpointRequest {
21    pub fn new(thread_id: String, state_snapshot: serde_json::Value) -> Self {
22        Self {
23            thread_id,
24            agent_id: None,
25            branch_name: None,
26            state_snapshot,
27            label: None,
28            metadata: None,
29        }
30    }
31}
32
33#[non_exhaustive]
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CheckpointResponse {
36    pub id: Uuid,
37    pub parent_id: Option<Uuid>,
38    pub branch_name: String,
39}
40
41impl CheckpointResponse {
42    pub fn new(id: Uuid, parent_id: Option<Uuid>, branch_name: String) -> Self {
43        Self {
44            id,
45            parent_id,
46            branch_name,
47        }
48    }
49}
50
51pub async fn execute(
52    engine: &MnemoEngine,
53    request: CheckpointRequest,
54) -> Result<CheckpointResponse> {
55    let agent_id = request
56        .agent_id
57        .unwrap_or_else(|| engine.default_agent_id.clone());
58    let branch_name = request.branch_name.unwrap_or_else(|| "main".to_string());
59    let now = chrono::Utc::now().to_rfc3339();
60
61    // Get latest checkpoint on this branch as parent
62    let parent = engine
63        .storage
64        .get_latest_checkpoint(&request.thread_id, &branch_name)
65        .await?;
66
67    let parent_id = parent.as_ref().map(|p| p.id);
68
69    // Compute state_diff from parent
70    let state_diff = parent.as_ref().map(|p| {
71        serde_json::json!({
72            "from": p.state_snapshot,
73            "to": request.state_snapshot,
74        })
75    });
76
77    // Collect memory_refs — active memories for this agent
78    let filter = MemoryFilter {
79        agent_id: Some(agent_id.clone()),
80        ..Default::default()
81    };
82    let memories = engine.storage.list_memories(&filter, 1000, 0).await?;
83    let memory_refs: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
84
85    // Get latest event as cursor
86    let events = engine.storage.list_events(&agent_id, 1, 0).await?;
87    let event_cursor = events.first().map(|e| e.id);
88
89    let id = Uuid::now_v7();
90    let cp = Checkpoint {
91        id,
92        thread_id: request.thread_id.clone(),
93        agent_id: agent_id.clone(),
94        parent_id,
95        branch_name: branch_name.clone(),
96        state_snapshot: request.state_snapshot,
97        state_diff,
98        memory_refs,
99        event_cursor,
100        label: request.label,
101        created_at: now.clone(),
102        metadata: request
103            .metadata
104            .unwrap_or(serde_json::Value::Object(serde_json::Map::new())),
105    };
106
107    engine.storage.insert_checkpoint(&cp).await?;
108
109    // Emit Checkpoint event
110    let event = super::event_builder::build_event(
111        engine,
112        &agent_id,
113        EventType::Checkpoint,
114        serde_json::json!({"checkpoint_id": id.to_string(), "branch": branch_name}),
115        &id.to_string(),
116        Some(request.thread_id),
117    )
118    .await;
119    if let Err(e) = engine.storage.insert_event(&event).await {
120        tracing::error!(event_id = %event.id, error = %e, "failed to insert audit event");
121    }
122
123    Ok(CheckpointResponse {
124        id,
125        parent_id,
126        branch_name,
127    })
128}