1use std::time::SystemTime;
8
9use agent_memory::Journal;
10use agent_primitives::AgentId;
11use bytes::Bytes;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15use tracing::{debug, warn};
16
17use crate::lifecycle::AgentState;
18
19#[derive(Debug, Error)]
21pub enum RecoveryError {
22 #[error("journal error: {0}")]
24 JournalError(String),
25
26 #[error("serialization error: {0}")]
28 SerializationError(#[from] serde_json::Error),
29
30 #[error("no checkpoint found")]
32 NoCheckpoint,
33
34 #[error("checkpoint corrupted: {0}")]
36 CorruptedCheckpoint(String),
37
38 #[error("io error: {0}")]
40 IoError(#[from] std::io::Error),
41}
42
43pub type RecoveryResult<T> = Result<T, RecoveryError>;
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct StateCheckpoint {
49 pub agent_id: AgentId,
51
52 pub agent_state: AgentState,
54
55 pub timestamp: SystemTime,
57
58 pub app_state: Value,
60
61 pub metadata: CheckpointMetadata,
63}
64
65#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CheckpointMetadata {
68 pub version: u32,
70
71 pub in_flight_tasks: u32,
73
74 pub reason: Option<String>,
76
77 #[serde(default)]
79 pub custom: Value,
80}
81
82impl Default for CheckpointMetadata {
83 fn default() -> Self {
84 Self {
85 version: 1,
86 in_flight_tasks: 0,
87 reason: None,
88 custom: Value::Null,
89 }
90 }
91}
92
93pub struct StateRecovery {
95 journal: std::sync::Arc<dyn Journal>,
96 checkpoint_tag: String,
97}
98
99impl std::fmt::Debug for StateRecovery {
100 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101 f.debug_struct("StateRecovery")
102 .field("checkpoint_tag", &self.checkpoint_tag)
103 .finish()
104 }
105}
106
107impl StateRecovery {
108 pub fn new(journal: std::sync::Arc<dyn Journal>) -> Self {
110 Self {
111 journal,
112 checkpoint_tag: "checkpoint".to_string(),
113 }
114 }
115
116 pub async fn persist_checkpoint(&self, checkpoint: &StateCheckpoint) -> RecoveryResult<()> {
126 let payload = serde_json::to_vec(checkpoint)?;
127 let record = agent_memory::MemoryRecord::builder(
128 agent_memory::MemoryChannel::System,
129 Bytes::from(payload),
130 )
131 .tag(&self.checkpoint_tag)
132 .map_err(|e| RecoveryError::JournalError(e.to_string()))?
133 .metadata("checkpoint_version", json!(checkpoint.metadata.version))
134 .metadata("agent_id", json!(checkpoint.agent_id.to_string()))
135 .metadata(
136 "agent_state",
137 json!(format!("{:?}", checkpoint.agent_state)),
138 )
139 .build()
140 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
141
142 self.journal
143 .append(&record)
144 .await
145 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
146
147 debug!(
148 agent_id = %checkpoint.agent_id,
149 state = ?checkpoint.agent_state,
150 "checkpoint persisted"
151 );
152
153 Ok(())
154 }
155
156 pub async fn recover_checkpoint(&self) -> RecoveryResult<StateCheckpoint> {
163 let records = self
165 .journal
166 .tail(100)
167 .await
168 .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
169
170 for record in records.iter().rev() {
172 if record.tags().contains(&self.checkpoint_tag) {
173 let checkpoint: StateCheckpoint = serde_json::from_slice(record.payload())
174 .map_err(|e| {
175 RecoveryError::CorruptedCheckpoint(format!("deserialization failed: {}", e))
176 })?;
177
178 debug!(
179 agent_id = %checkpoint.agent_id,
180 state = ?checkpoint.agent_state,
181 "checkpoint recovered"
182 );
183
184 return Ok(checkpoint);
185 }
186 }
187
188 warn!("no checkpoint found in journal");
189 Err(RecoveryError::NoCheckpoint)
190 }
191
192 pub fn create_checkpoint(
194 agent_id: AgentId,
195 agent_state: AgentState,
196 app_state: Value,
197 in_flight_tasks: u32,
198 ) -> StateCheckpoint {
199 StateCheckpoint {
200 agent_id,
201 agent_state,
202 timestamp: SystemTime::now(),
203 app_state,
204 metadata: CheckpointMetadata {
205 version: 1,
206 in_flight_tasks,
207 reason: Some("shutdown".to_string()),
208 custom: Value::Null,
209 },
210 }
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use agent_memory::FileJournal;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Path of a per-test journal file that is deleted when the guard is dropped.
    ///
    /// Using `Drop` instead of cleanup code at the end of each test means the
    /// file is also removed when an assertion or `unwrap` panics.
    struct TempJournalFile(PathBuf);

    impl Drop for TempJournalFile {
        fn drop(&mut self) {
            // Best effort: the file may never have been created, and a
            // failed cleanup must not mask the real test outcome.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Returns a fresh path in the system temp directory.
    ///
    /// The name combines the process id, the current time in nanoseconds and
    /// a process-wide counter. The counter keeps paths distinct even when
    /// tests running in parallel read the same clock value.
    fn temp_path() -> PathBuf {
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        let sequence = NEXT_ID.fetch_add(1, Ordering::Relaxed);
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_nanos();

        std::env::temp_dir().join(format!(
            "recovery-test-{}-{}-{}.log",
            std::process::id(),
            nanos,
            sequence
        ))
    }

    /// Opens a file journal at `file`'s path and wraps it in a `StateRecovery`.
    async fn open_recovery(file: &TempJournalFile) -> StateRecovery {
        let journal = Arc::new(FileJournal::open(&file.0).await.unwrap());
        StateRecovery::new(journal)
    }

    /// A persisted checkpoint comes back with the same contents.
    #[tokio::test]
    async fn checkpoint_roundtrip() {
        // Declared before `recovery` so the journal handle is dropped first
        // and the file is removed afterwards.
        let journal_file = TempJournalFile(temp_path());
        let recovery = open_recovery(&journal_file).await;

        let checkpoint = StateCheckpoint {
            agent_id: AgentId::random(),
            agent_state: AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"key": "value"}),
            metadata: CheckpointMetadata {
                version: 1,
                in_flight_tasks: 5,
                reason: Some("shutdown".to_string()),
                custom: Value::Null,
            },
        };

        recovery.persist_checkpoint(&checkpoint).await.unwrap();
        let recovered = recovery.recover_checkpoint().await.unwrap();

        assert_eq!(recovered.agent_id, checkpoint.agent_id);
        assert_eq!(recovered.agent_state, checkpoint.agent_state);
        assert_eq!(recovered.app_state, checkpoint.app_state);
        assert_eq!(recovered.metadata.in_flight_tasks, 5);
    }

    /// Recovering from an empty journal reports `NoCheckpoint`.
    #[tokio::test]
    async fn no_checkpoint_returns_error() {
        let journal_file = TempJournalFile(temp_path());
        let recovery = open_recovery(&journal_file).await;

        let result = recovery.recover_checkpoint().await;
        assert!(matches!(result, Err(RecoveryError::NoCheckpoint)));
    }

    /// With two checkpoints in the journal, the one written last is recovered.
    #[tokio::test]
    async fn multiple_checkpoints_recovers_latest() {
        let journal_file = TempJournalFile(temp_path());
        let recovery = open_recovery(&journal_file).await;

        let agent_id = AgentId::random();

        let first = StateCheckpoint {
            agent_id,
            agent_state: AgentState::Ready,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 1}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&first).await.unwrap();

        // Keeps the two writes apart in time (kept from the original test;
        // presumably guards against identical timestamps).
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let second = StateCheckpoint {
            agent_id,
            agent_state: AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 2}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&second).await.unwrap();

        let recovered = recovery.recover_checkpoint().await.unwrap();
        assert_eq!(recovered.app_state, json!({"version": 2}));
    }
}