Skip to main content

mnemo_core/query/
replay.rs

1use serde::{Deserialize, Serialize};
2use uuid::Uuid;
3
4use crate::error::{Error, Result};
5use crate::hash::{ChainVerificationResult, verify_chain};
6use crate::model::checkpoint::Checkpoint;
7use crate::model::event::AgentEvent;
8use crate::model::memory::MemoryRecord;
9use crate::query::MnemoEngine;
10use crate::storage::MemoryFilter;
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct ReplayRequest {
14    pub thread_id: String,
15    pub agent_id: Option<String>,
16    pub checkpoint_id: Option<Uuid>,
17    pub branch_name: Option<String>,
18    /// Synthesize a virtual checkpoint from the memories and events that
19    /// existed at this RFC3339 timestamp. When set, `checkpoint_id` and
20    /// `branch_name` are ignored.
21    pub as_of: Option<String>,
22}
23
24impl ReplayRequest {
25    pub fn new(thread_id: String) -> Self {
26        Self {
27            thread_id,
28            agent_id: None,
29            checkpoint_id: None,
30            branch_name: None,
31            as_of: None,
32        }
33    }
34}
35
36#[non_exhaustive]
37#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReplayResponse {
39    pub checkpoint: Checkpoint,
40    pub memories: Vec<MemoryRecord>,
41    pub events: Vec<AgentEvent>,
42    pub chain_verification: Option<ChainVerificationResult>,
43}
44
45impl ReplayResponse {
46    pub fn new(
47        checkpoint: Checkpoint,
48        memories: Vec<MemoryRecord>,
49        events: Vec<AgentEvent>,
50        chain_verification: Option<ChainVerificationResult>,
51    ) -> Self {
52        Self {
53            checkpoint,
54            memories,
55            events,
56            chain_verification,
57        }
58    }
59}
60
61pub async fn execute(engine: &MnemoEngine, request: ReplayRequest) -> Result<ReplayResponse> {
62    // Time-travel path: synthesize a virtual checkpoint at `as_of`.
63    if let Some(ref as_of) = request.as_of {
64        return replay_as_of(engine, &request, as_of).await;
65    }
66
67    let branch = request.branch_name.as_deref().unwrap_or("main");
68
69    // Get checkpoint (specified or latest)
70    let checkpoint = if let Some(cp_id) = request.checkpoint_id {
71        engine
72            .storage
73            .get_checkpoint(cp_id)
74            .await?
75            .ok_or_else(|| Error::NotFound(format!("checkpoint {cp_id} not found")))?
76    } else {
77        engine
78            .storage
79            .get_latest_checkpoint(&request.thread_id, branch)
80            .await?
81            .ok_or_else(|| {
82                Error::NotFound(format!(
83                    "no checkpoint found on branch '{branch}' for thread '{}'",
84                    request.thread_id
85                ))
86            })?
87    };
88
89    // Load memories referenced by checkpoint.memory_refs
90    let mut memories = Vec::new();
91    for mem_id in &checkpoint.memory_refs {
92        if let Some(record) = engine.storage.get_memory(*mem_id).await? {
93            memories.push(record);
94        }
95    }
96
97    // Verify hash chain integrity on loaded memories
98    let chain_verification = Some(verify_chain(&memories));
99
100    // Load events up to checkpoint.event_cursor (or all thread events if no cursor)
101    let events = engine
102        .storage
103        .get_events_by_thread(&checkpoint.thread_id, 1000)
104        .await?;
105
106    let events = if let Some(cursor_id) = checkpoint.event_cursor {
107        // Return events up to and including the cursor
108        let mut filtered = Vec::new();
109        for event in events {
110            filtered.push(event.clone());
111            if event.id == cursor_id {
112                break;
113            }
114        }
115        filtered
116    } else {
117        events
118    };
119
120    Ok(ReplayResponse {
121        checkpoint,
122        memories,
123        events,
124        chain_verification,
125    })
126}
127
128/// Build a synthetic `Checkpoint` that describes agent state as it existed at
129/// `as_of_str` — every memory created at or before that instant, excluding
130/// memories already deleted. Events are filtered by timestamp identically so
131/// the returned `ReplayResponse` looks like a real checkpoint from that time.
132async fn replay_as_of(
133    engine: &MnemoEngine,
134    request: &ReplayRequest,
135    as_of_str: &str,
136) -> Result<ReplayResponse> {
137    let as_of = chrono::DateTime::parse_from_rfc3339(as_of_str)
138        .map_err(|e| Error::Validation(format!("invalid as_of timestamp '{as_of_str}': {e}")))?
139        .with_timezone(&chrono::Utc);
140
141    let agent_id = request
142        .agent_id
143        .clone()
144        .unwrap_or_else(|| engine.default_agent_id.clone());
145    super::validate_agent_id(&agent_id)?;
146
147    // Pull all memories for the agent (including soft-deleted ones, so we can
148    // decide per-record whether they existed at `as_of`).
149    let filter = MemoryFilter {
150        agent_id: Some(agent_id.clone()),
151        thread_id: Some(request.thread_id.clone()),
152        include_deleted: true,
153        ..Default::default()
154    };
155    let candidates = engine
156        .storage
157        .list_memories(&filter, super::MAX_BATCH_QUERY_LIMIT, 0)
158        .await?;
159
160    let mut memories: Vec<MemoryRecord> = Vec::new();
161    for record in candidates {
162        let Ok(created) = chrono::DateTime::parse_from_rfc3339(&record.created_at) else {
163            continue;
164        };
165        if created.with_timezone(&chrono::Utc) > as_of {
166            continue;
167        }
168        if let Some(ref deleted_at) = record.deleted_at
169            && let Ok(del) = chrono::DateTime::parse_from_rfc3339(deleted_at)
170            && del.with_timezone(&chrono::Utc) <= as_of
171        {
172            continue;
173        }
174        memories.push(record);
175    }
176
177    let chain_verification = Some(verify_chain(&memories));
178
179    let all_events = engine
180        .storage
181        .get_events_by_thread(&request.thread_id, super::MAX_BATCH_QUERY_LIMIT)
182        .await?;
183    let events: Vec<AgentEvent> = all_events
184        .into_iter()
185        .filter(|e| {
186            chrono::DateTime::parse_from_rfc3339(&e.timestamp)
187                .map(|ts| ts.with_timezone(&chrono::Utc) <= as_of)
188                .unwrap_or(false)
189        })
190        .collect();
191
192    let memory_refs: Vec<Uuid> = memories.iter().map(|m| m.id).collect();
193
194    let virtual_checkpoint = Checkpoint {
195        id: Uuid::nil(),
196        thread_id: request.thread_id.clone(),
197        agent_id,
198        parent_id: None,
199        branch_name: request
200            .branch_name
201            .clone()
202            .unwrap_or_else(|| "main".to_string()),
203        state_snapshot: serde_json::json!({
204            "as_of": as_of_str,
205            "virtual": true,
206        }),
207        state_diff: None,
208        memory_refs,
209        event_cursor: events.last().map(|e| e.id),
210        label: Some(format!("virtual@{as_of_str}")),
211        created_at: as_of_str.to_string(),
212        metadata: serde_json::json!({"synthesized": true}),
213    };
214
215    Ok(ReplayResponse {
216        checkpoint: virtual_checkpoint,
217        memories,
218        events,
219        chain_verification,
220    })
221}