1use std::time::SystemTime;
8
9use agent_memory::Journal;
10use agent_primitives::AgentId;
11use bytes::Bytes;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15use tracing::{debug, warn};
16
17use crate::lifecycle::AgentState;
18
19#[derive(Debug, Error)]
21pub enum RecoveryError {
22 #[error("journal error: {0}")]
24 JournalError(String),
25
26 #[error("serialization error: {0}")]
28 SerializationError(#[from] serde_json::Error),
29
30 #[error("no checkpoint found")]
32 NoCheckpoint,
33
34 #[error("checkpoint corrupted: {0}")]
36 CorruptedCheckpoint(String),
37
38 #[error("io error: {0}")]
40 IoError(#[from] std::io::Error),
41}
42
43pub type RecoveryResult<T> = Result<T, RecoveryError>;
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct StateCheckpoint {
49 pub agent_id: AgentId,
51
52 pub agent_state: AgentState,
54
55 pub timestamp: SystemTime,
57
58 pub app_state: Value,
60
61 pub metadata: CheckpointMetadata,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CheckpointMetadata {
68 pub version: u32,
70
71 pub in_flight_tasks: u32,
73
74 pub reason: Option<String>,
76
77 #[serde(default)]
79 pub custom: Value,
80}
81
82impl Default for CheckpointMetadata {
83 fn default() -> Self {
84 Self {
85 version: 1,
86 in_flight_tasks: 0,
87 reason: None,
88 custom: Value::Null,
89 }
90 }
91}
92
93pub struct StateRecovery {
95 journal: std::sync::Arc<dyn Journal>,
96 checkpoint_tag: String,
97}
98
99impl std::fmt::Debug for StateRecovery {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("StateRecovery")
102 .field("checkpoint_tag", &self.checkpoint_tag)
103 .finish_non_exhaustive()
104 }
105}
106
107impl StateRecovery {
108 pub fn new(journal: std::sync::Arc<dyn Journal>) -> Self {
110 Self {
111 journal,
112 checkpoint_tag: "checkpoint".to_string(),
113 }
114 }
115
116 pub async fn persist_checkpoint(&self, checkpoint: &StateCheckpoint) -> RecoveryResult<()> {
126 let payload = serde_json::to_vec(checkpoint)?;
127 let record = agent_memory::MemoryRecord::builder(
128 agent_memory::MemoryChannel::System,
129 Bytes::from(payload),
130 )
131 .tag(&self.checkpoint_tag)
132 .map_err(|e| RecoveryError::JournalError(e.to_string()))?
133 .metadata("checkpoint_version", json!(checkpoint.metadata.version))
134 .metadata("agent_id", json!(checkpoint.agent_id.to_string()))
135 .metadata(
136 "agent_state",
137 json!(format!("{:?}", checkpoint.agent_state)),
138 )
139 .build()
140 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
141
142 self.journal
143 .append(&record)
144 .await
145 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
146
147 debug!(
148 agent_id = %checkpoint.agent_id,
149 state = ?checkpoint.agent_state,
150 "checkpoint persisted"
151 );
152
153 Ok(())
154 }
155
156 pub async fn recover_checkpoint(&self) -> RecoveryResult<StateCheckpoint> {
163 let records = self
165 .journal
166 .tail(100)
167 .await
168 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
169
170 for record in records.iter().rev() {
172 if record.tags().contains(&self.checkpoint_tag) {
173 let checkpoint: StateCheckpoint = serde_json::from_slice(record.payload())
174 .map_err(|e| {
175 RecoveryError::CorruptedCheckpoint(format!("deserialization failed: {e}"))
176 })?;
177
178 debug!(
179 agent_id = %checkpoint.agent_id,
180 state = ?checkpoint.agent_state,
181 "checkpoint recovered"
182 );
183
184 return Ok(checkpoint);
185 }
186 }
187
188 warn!("no checkpoint found in journal");
189 Err(RecoveryError::NoCheckpoint)
190 }
191
192 #[must_use]
194 pub fn create_checkpoint(
195 agent_id: AgentId,
196 agent_state: AgentState,
197 app_state: Value,
198 in_flight_tasks: u32,
199 ) -> StateCheckpoint {
200 StateCheckpoint {
201 agent_id,
202 agent_state,
203 timestamp: SystemTime::now(),
204 app_state,
205 metadata: CheckpointMetadata {
206 version: 1,
207 in_flight_tasks,
208 reason: Some("shutdown".to_string()),
209 custom: Value::Null,
210 },
211 }
212 }
213}
214
#[cfg(test)]
mod tests {
    use super::*;
    use agent_memory::FileJournal;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Unique temporary journal path whose file is removed when the guard is dropped.
    ///
    /// Cleanup runs even if an assertion panics, so failed runs do not leave stray files
    /// in the temp directory. Locals drop in reverse declaration order, so declaring the
    /// guard before the journal makes it run after the journal is dropped.
    struct TempJournalPath(PathBuf);

    impl TempJournalPath {
        fn new() -> Self {
            // A per-process counter keeps paths distinct for tests running on parallel
            // threads, even if two of them observe the same clock reading.
            static NEXT_ID: AtomicU64 = AtomicU64::new(0);

            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            let mut path = std::env::temp_dir();
            path.push(format!(
                "recovery-test-{}-{}-{}.log",
                std::process::id(),
                nanos,
                NEXT_ID.fetch_add(1, Ordering::Relaxed)
            ));
            Self(path)
        }
    }

    impl Drop for TempJournalPath {
        fn drop(&mut self) {
            // Best effort: the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[tokio::test]
    async fn checkpoint_roundtrip() {
        let tmp = TempJournalPath::new();
        let journal = Arc::new(FileJournal::open(&tmp.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();
        let checkpoint = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"key": "value"}),
            metadata: CheckpointMetadata {
                version: 1,
                in_flight_tasks: 5,
                reason: Some("shutdown".to_string()),
                custom: Value::Null,
            },
        };

        recovery.persist_checkpoint(&checkpoint).await.unwrap();
        let recovered = recovery.recover_checkpoint().await.unwrap();

        assert_eq!(recovered.agent_id, checkpoint.agent_id);
        assert_eq!(recovered.agent_state, checkpoint.agent_state);
        assert_eq!(recovered.app_state, checkpoint.app_state);
        assert_eq!(recovered.metadata.version, 1);
        assert_eq!(recovered.metadata.in_flight_tasks, 5);
        assert_eq!(recovered.metadata.reason.as_deref(), Some("shutdown"));
    }

    #[tokio::test]
    async fn no_checkpoint_returns_error() {
        let tmp = TempJournalPath::new();
        let journal = Arc::new(FileJournal::open(&tmp.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let result = recovery.recover_checkpoint().await;
        assert!(matches!(result, Err(RecoveryError::NoCheckpoint)));
    }

    #[tokio::test]
    async fn multiple_checkpoints_recovers_latest() {
        let tmp = TempJournalPath::new();
        let journal = Arc::new(FileJournal::open(&tmp.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();

        let checkpoint1 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Ready,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 1}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint1).await.unwrap();

        // Presumably here to give the two checkpoints distinct timestamps.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let checkpoint2 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 2}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint2).await.unwrap();

        let recovered = recovery.recover_checkpoint().await.unwrap();
        assert_eq!(recovered.app_state, json!({"version": 2}));
        assert_eq!(recovered.agent_state, checkpoint2.agent_state);
    }
}