agent_kernel/
recovery.rs

1//! State recovery and checkpoint persistence for graceful shutdown and restart.
2//!
3//! This module provides mechanisms to persist critical agent state before termination
4//! and recover that state on startup, enabling agents to resume operation without
5//! losing in-flight work or configuration.
6
7use std::time::SystemTime;
8
9use agent_memory::Journal;
10use agent_primitives::AgentId;
11use bytes::Bytes;
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15use tracing::{debug, warn};
16
17use crate::lifecycle::AgentState;
18
19/// Errors from state recovery operations.
20#[derive(Debug, Error)]
21pub enum RecoveryError {
22    /// Journal operation failed.
23    #[error("journal error: {0}")]
24    JournalError(String),
25
26    /// Serialization error.
27    #[error("serialization error: {0}")]
28    SerializationError(#[from] serde_json::Error),
29
30    /// No checkpoint found for recovery.
31    #[error("no checkpoint found")]
32    NoCheckpoint,
33
34    /// Checkpoint is corrupted or invalid.
35    #[error("checkpoint corrupted: {0}")]
36    CorruptedCheckpoint(String),
37
38    /// I/O error.
39    #[error("io error: {0}")]
40    IoError(#[from] std::io::Error),
41}
42
43/// Result type for recovery operations.
44pub type RecoveryResult<T> = Result<T, RecoveryError>;
45
46/// Checkpoint containing critical agent state for recovery.
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct StateCheckpoint {
49    /// Agent identifier.
50    pub agent_id: AgentId,
51
52    /// Agent lifecycle state at checkpoint time.
53    pub agent_state: AgentState,
54
55    /// Timestamp when checkpoint was created.
56    pub timestamp: SystemTime,
57
58    /// Custom application state (opaque to recovery layer).
59    pub app_state: Value,
60
61    /// Metadata about the checkpoint.
62    pub metadata: CheckpointMetadata,
63}
64
65/// Metadata about a checkpoint.
66#[derive(Debug, Clone, Serialize, Deserialize)]
67pub struct CheckpointMetadata {
68    /// Version of checkpoint format.
69    pub version: u32,
70
71    /// Number of in-flight tasks at checkpoint time.
72    pub in_flight_tasks: u32,
73
74    /// Optional reason for checkpoint (e.g., "shutdown", "periodic").
75    pub reason: Option<String>,
76
77    /// Custom metadata fields.
78    #[serde(default)]
79    pub custom: Value,
80}
81
82impl Default for CheckpointMetadata {
83    fn default() -> Self {
84        Self {
85            version: 1,
86            in_flight_tasks: 0,
87            reason: None,
88            custom: Value::Null,
89        }
90    }
91}
92
93/// Manages checkpoint persistence and recovery.
94pub struct StateRecovery {
95    journal: std::sync::Arc<dyn Journal>,
96    checkpoint_tag: String,
97}
98
99impl std::fmt::Debug for StateRecovery {
100    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
101        f.debug_struct("StateRecovery")
102            .field("checkpoint_tag", &self.checkpoint_tag)
103            .finish()
104    }
105}
106
107impl StateRecovery {
108    /// Creates a new state recovery manager.
109    pub fn new(journal: std::sync::Arc<dyn Journal>) -> Self {
110        Self {
111            journal,
112            checkpoint_tag: "checkpoint".to_string(),
113        }
114    }
115
116    /// Persists a checkpoint to the journal.
117    ///
118    /// # Arguments
119    ///
120    /// * `checkpoint` - The state checkpoint to persist
121    ///
122    /// # Errors
123    ///
124    /// Returns `RecoveryError` if journal write fails or serialization fails.
125    pub async fn persist_checkpoint(&self, checkpoint: &StateCheckpoint) -> RecoveryResult<()> {
126        let payload = serde_json::to_vec(checkpoint)?;
127        let record = agent_memory::MemoryRecord::builder(
128            agent_memory::MemoryChannel::System,
129            Bytes::from(payload),
130        )
131        .tag(&self.checkpoint_tag)
132        .map_err(|e| RecoveryError::JournalError(e.to_string()))?
133        .metadata("checkpoint_version", json!(checkpoint.metadata.version))
134        .metadata("agent_id", json!(checkpoint.agent_id.to_string()))
135        .metadata(
136            "agent_state",
137            json!(format!("{:?}", checkpoint.agent_state)),
138        )
139        .build()
140        .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
141
142        self.journal
143            .append(&record)
144            .await
145            .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
146
147        debug!(
148            agent_id = %checkpoint.agent_id,
149            state = ?checkpoint.agent_state,
150            "checkpoint persisted"
151        );
152
153        Ok(())
154    }
155
156    /// Recovers the most recent checkpoint from the journal.
157    ///
158    /// # Errors
159    ///
160    /// Returns `RecoveryError::NoCheckpoint` if no checkpoint exists.
161    /// Returns `RecoveryError::CorruptedCheckpoint` if checkpoint data is invalid.
162    pub async fn recover_checkpoint(&self) -> RecoveryResult<StateCheckpoint> {
163        // Retrieve recent records from journal
164        let records = self
165            .journal
166            .tail(100)
167            .await
168            .map_err(|e| RecoveryError::JournalError(e.to_string()))?;
169
170        // Find the most recent checkpoint record
171        for record in records.iter().rev() {
172            if record.tags().contains(&self.checkpoint_tag) {
173                let checkpoint: StateCheckpoint = serde_json::from_slice(record.payload())
174                    .map_err(|e| {
175                        RecoveryError::CorruptedCheckpoint(format!("deserialization failed: {}", e))
176                    })?;
177
178                debug!(
179                    agent_id = %checkpoint.agent_id,
180                    state = ?checkpoint.agent_state,
181                    "checkpoint recovered"
182                );
183
184                return Ok(checkpoint);
185            }
186        }
187
188        warn!("no checkpoint found in journal");
189        Err(RecoveryError::NoCheckpoint)
190    }
191
192    /// Creates a checkpoint from current agent state.
193    pub fn create_checkpoint(
194        agent_id: AgentId,
195        agent_state: AgentState,
196        app_state: Value,
197        in_flight_tasks: u32,
198    ) -> StateCheckpoint {
199        StateCheckpoint {
200            agent_id,
201            agent_state,
202            timestamp: SystemTime::now(),
203            app_state,
204            metadata: CheckpointMetadata {
205                version: 1,
206                in_flight_tasks,
207                reason: Some("shutdown".to_string()),
208                custom: Value::Null,
209            },
210        }
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use agent_memory::FileJournal;
    use std::path::PathBuf;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;

    /// Owns a temporary journal path and deletes the file when dropped.
    ///
    /// Declare the guard before the journal: locals drop in reverse order, so
    /// the journal is released first and the file is removed afterwards, even
    /// when an assertion panics midway through a test.
    struct TempJournalFile(PathBuf);

    impl TempJournalFile {
        fn new() -> Self {
            Self(temp_path())
        }
    }

    impl Drop for TempJournalFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may never have been created.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Produces a journal path in the system temp directory that is unique per call.
    ///
    /// The per-process counter guarantees distinct names for tests running in
    /// parallel inside one test binary even if the clock reads the same value
    /// twice; the pid and timestamp keep separate runs apart.
    fn temp_path() -> PathBuf {
        static NEXT_ID: AtomicU64 = AtomicU64::new(0);

        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|elapsed| elapsed.as_nanos())
            .unwrap_or_default();

        let mut path = std::env::temp_dir();
        path.push(format!(
            "recovery-test-{}-{}-{}.log",
            std::process::id(),
            nanos,
            NEXT_ID.fetch_add(1, Ordering::Relaxed)
        ));
        path
    }

    #[tokio::test]
    async fn checkpoint_roundtrip() {
        let file = TempJournalFile::new();
        let journal = Arc::new(FileJournal::open(&file.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();
        let checkpoint = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"key": "value"}),
            metadata: CheckpointMetadata {
                version: 1,
                in_flight_tasks: 5,
                reason: Some("shutdown".to_string()),
                custom: Value::Null,
            },
        };

        recovery.persist_checkpoint(&checkpoint).await.unwrap();
        let recovered = recovery.recover_checkpoint().await.unwrap();

        assert_eq!(recovered.agent_id, checkpoint.agent_id);
        assert_eq!(recovered.agent_state, checkpoint.agent_state);
        assert_eq!(recovered.app_state, checkpoint.app_state);
        assert_eq!(recovered.metadata.in_flight_tasks, 5);
    }

    #[tokio::test]
    async fn no_checkpoint_returns_error() {
        let file = TempJournalFile::new();
        let journal = Arc::new(FileJournal::open(&file.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let result = recovery.recover_checkpoint().await;
        assert!(matches!(result, Err(RecoveryError::NoCheckpoint)));
    }

    #[tokio::test]
    async fn multiple_checkpoints_recovers_latest() {
        let file = TempJournalFile::new();
        let journal = Arc::new(FileJournal::open(&file.0).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();

        // Persist first checkpoint
        let checkpoint1 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Ready,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 1}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint1).await.unwrap();

        // Small delay to ensure different timestamps
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        // Persist second checkpoint
        let checkpoint2 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 2}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint2).await.unwrap();

        // Should recover the latest
        let recovered = recovery.recover_checkpoint().await.unwrap();
        assert_eq!(recovered.app_state, json!({"version": 2}));
    }

    #[test]
    fn create_checkpoint_populates_metadata() {
        let agent_id = AgentId::random();
        let checkpoint = StateRecovery::create_checkpoint(
            agent_id,
            crate::lifecycle::AgentState::Active,
            json!({"key": "value"}),
            3,
        );

        assert_eq!(checkpoint.agent_id, agent_id);
        assert_eq!(
            checkpoint.agent_state,
            crate::lifecycle::AgentState::Active
        );
        assert_eq!(checkpoint.app_state, json!({"key": "value"}));
        assert_eq!(checkpoint.metadata.version, 1);
        assert_eq!(checkpoint.metadata.in_flight_tasks, 3);
        assert_eq!(checkpoint.metadata.reason.as_deref(), Some("shutdown"));
        assert_eq!(checkpoint.metadata.custom, Value::Null);
    }
}