agent_kernel/recovery.rs

//! State recovery and checkpoint persistence for graceful shutdown and restart.
//!
//! This module provides mechanisms to persist critical agent state before termination
//! and recover that state on startup, enabling agents to resume operation without
//! losing in-flight work or configuration.
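//!
//! # Example
//!
//! A minimal sketch of the intended flow, assuming a `FileJournal` from
//! `agent_memory` as the backing store (any `Journal` implementation works the
//! same way); the path and state values are illustrative only.
//!
//! ```ignore
//! use std::sync::Arc;
//!
//! use agent_memory::FileJournal;
//! use agent_primitives::AgentId;
//!
//! let path = std::path::PathBuf::from("agent.journal");
//! let journal = Arc::new(FileJournal::open(&path).await?);
//! let recovery = StateRecovery::new(journal);
//!
//! // On shutdown: snapshot the agent's critical state.
//! let checkpoint = StateRecovery::create_checkpoint(
//!     AgentId::random(),
//!     AgentState::Active,
//!     serde_json::json!({ "pending": ["task-1"] }),
//!     1,
//! );
//! recovery.persist_checkpoint(&checkpoint).await?;
//!
//! // On startup: resume from the last checkpoint, or start fresh.
//! match recovery.recover_checkpoint().await {
//!     Ok(checkpoint) => { /* restore from checkpoint.app_state */ }
//!     Err(RecoveryError::NoCheckpoint) => { /* nothing to resume */ }
//!     Err(e) => return Err(e.into()),
//! }
//! ```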

use std::time::SystemTime;

use agent_memory::Journal;
use agent_primitives::AgentId;
use bytes::Bytes;
use serde::{Deserialize, Serialize};
use serde_json::{Value, json};
use thiserror::Error;
use tracing::{debug, warn};

use crate::lifecycle::AgentState;

/// Errors from state recovery operations.
#[derive(Debug, Error)]
pub enum RecoveryError {
    /// Journal operation failed.
    #[error("journal error: {0}")]
    JournalError(String),

    /// Serialization error.
    #[error("serialization error: {0}")]
    SerializationError(#[from] serde_json::Error),

    /// No checkpoint found for recovery.
    #[error("no checkpoint found")]
    NoCheckpoint,

    /// Checkpoint is corrupted or invalid.
    #[error("checkpoint corrupted: {0}")]
    CorruptedCheckpoint(String),

    /// I/O error.
    #[error("io error: {0}")]
    IoError(#[from] std::io::Error),
}

/// Result type for recovery operations.
pub type RecoveryResult<T> = Result<T, RecoveryError>;

/// Checkpoint containing critical agent state for recovery.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateCheckpoint {
    /// Agent identifier.
    pub agent_id: AgentId,

    /// Agent lifecycle state at checkpoint time.
    pub agent_state: AgentState,

    /// Timestamp when checkpoint was created.
    pub timestamp: SystemTime,

    /// Custom application state (opaque to recovery layer).
    pub app_state: Value,

    /// Metadata about the checkpoint.
    pub metadata: CheckpointMetadata,
}

/// Metadata about a checkpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Version of checkpoint format.
    pub version: u32,

    /// Number of in-flight tasks at checkpoint time.
    pub in_flight_tasks: u32,

    /// Optional reason for checkpoint (e.g., "shutdown", "periodic").
    pub reason: Option<String>,

    /// Custom metadata fields.
    #[serde(default)]
    pub custom: Value,
}

impl Default for CheckpointMetadata {
    fn default() -> Self {
        Self {
            version: 1,
            in_flight_tasks: 0,
            reason: None,
            custom: Value::Null,
        }
    }
}

/// Manages checkpoint persistence and recovery.
pub struct StateRecovery {
    journal: std::sync::Arc<dyn Journal>,
    checkpoint_tag: String,
}

impl std::fmt::Debug for StateRecovery {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StateRecovery")
            .field("checkpoint_tag", &self.checkpoint_tag)
            .finish_non_exhaustive()
    }
}

impl StateRecovery {
    /// Creates a new state recovery manager that tags checkpoint records in the
    /// journal with `"checkpoint"`.
    pub fn new(journal: std::sync::Arc<dyn Journal>) -> Self {
        Self {
            journal,
            checkpoint_tag: "checkpoint".to_string(),
        }
    }

    /// Persists a checkpoint to the journal.
    ///
    /// # Arguments
    ///
    /// * `checkpoint` - The state checkpoint to persist
    ///
    /// # Errors
    ///
    /// Returns `RecoveryError` if journal write fails or serialization fails.
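    ///
    /// # Example
    ///
    /// A sketch of persisting a periodic (non-shutdown) checkpoint with custom
    /// metadata; `recovery`, `agent_id`, and the field values are illustrative:
    ///
    /// ```ignore
    /// let checkpoint = StateCheckpoint {
    ///     agent_id,
    ///     agent_state: AgentState::Active,
    ///     timestamp: std::time::SystemTime::now(),
    ///     app_state: serde_json::json!({ "cursor": 42 }),
    ///     metadata: CheckpointMetadata {
    ///         in_flight_tasks: 3,
    ///         reason: Some("periodic".to_string()),
    ///         ..CheckpointMetadata::default()
    ///     },
    /// };
    /// recovery.persist_checkpoint(&checkpoint).await?;
    /// ```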
    pub async fn persist_checkpoint(&self, checkpoint: &StateCheckpoint) -> RecoveryResult<()> {
        let payload = serde_json::to_vec(checkpoint)?;
        let record = agent_memory::MemoryRecord::builder(
            agent_memory::MemoryChannel::System,
            Bytes::from(payload),
        )
        .tag(&self.checkpoint_tag)
        .map_err(|e| RecoveryError::JournalError(e.to_string()))?
        .metadata("checkpoint_version", json!(checkpoint.metadata.version))
        .metadata("agent_id", json!(checkpoint.agent_id.to_string()))
        .metadata(
            "agent_state",
            json!(format!("{:?}", checkpoint.agent_state)),
        )
        .build()
        .map_err(|e| RecoveryError::JournalError(e.to_string()))?;

        self.journal
            .append(&record)
            .await
            .map_err(|e| RecoveryError::JournalError(e.to_string()))?;

        debug!(
            agent_id = %checkpoint.agent_id,
            state = ?checkpoint.agent_state,
            "checkpoint persisted"
        );

        Ok(())
    }

    /// Recovers the most recent checkpoint from the journal.
    ///
    /// Only the most recent journal records (currently the last 100) are scanned,
    /// so checkpoints older than that window will not be found.
    ///
    /// # Errors
    ///
    /// Returns `RecoveryError::NoCheckpoint` if no checkpoint exists in the scanned records.
    /// Returns `RecoveryError::CorruptedCheckpoint` if the checkpoint payload cannot be deserialized.
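    ///
    /// # Example
    ///
    /// A sketch of a typical startup path, where `restore_from` and `start_fresh`
    /// are hypothetical application hooks:
    ///
    /// ```ignore
    /// match recovery.recover_checkpoint().await {
    ///     Ok(checkpoint) => restore_from(checkpoint),
    ///     Err(RecoveryError::NoCheckpoint) => start_fresh(),
    ///     Err(RecoveryError::CorruptedCheckpoint(reason)) => {
    ///         tracing::warn!(%reason, "ignoring corrupted checkpoint");
    ///         start_fresh()
    ///     }
    ///     Err(e) => return Err(e),
    /// }
    /// ```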
    pub async fn recover_checkpoint(&self) -> RecoveryResult<StateCheckpoint> {
        // Only the most recent journal records are scanned for a checkpoint.
        let records = self
            .journal
            .tail(100)
            .await
            .map_err(|e| RecoveryError::JournalError(e.to_string()))?;

        // Walk the records newest-first and return the first checkpoint found.
        for record in records.iter().rev() {
            if record.tags().contains(&self.checkpoint_tag) {
                let checkpoint: StateCheckpoint = serde_json::from_slice(record.payload())
                    .map_err(|e| {
                        RecoveryError::CorruptedCheckpoint(format!("deserialization failed: {e}"))
                    })?;

                debug!(
                    agent_id = %checkpoint.agent_id,
                    state = ?checkpoint.agent_state,
                    "checkpoint recovered"
                );

                return Ok(checkpoint);
            }
        }

        warn!("no checkpoint found in journal");
        Err(RecoveryError::NoCheckpoint)
    }

    /// Creates a checkpoint from the current agent state.
    ///
    /// The checkpoint is timestamped with the current time and its metadata reason
    /// is set to `"shutdown"`.
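    ///
    /// # Example
    ///
    /// A sketch of the result (`agent_id` is assumed to be in scope):
    ///
    /// ```ignore
    /// let checkpoint = StateRecovery::create_checkpoint(
    ///     agent_id,
    ///     AgentState::Ready,
    ///     serde_json::json!({ "cursor": 42 }),
    ///     0,
    /// );
    /// assert_eq!(checkpoint.metadata.reason.as_deref(), Some("shutdown"));
    /// assert_eq!(checkpoint.metadata.in_flight_tasks, 0);
    /// ```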
    #[must_use]
    pub fn create_checkpoint(
        agent_id: AgentId,
        agent_state: AgentState,
        app_state: Value,
        in_flight_tasks: u32,
    ) -> StateCheckpoint {
        StateCheckpoint {
            agent_id,
            agent_state,
            timestamp: SystemTime::now(),
            app_state,
            metadata: CheckpointMetadata {
                version: 1,
                in_flight_tasks,
                reason: Some("shutdown".to_string()),
                custom: Value::Null,
            },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use agent_memory::FileJournal;
    use std::sync::Arc;

    fn temp_path() -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!(
            "recovery-test-{}-{}.log",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        path
    }

    #[tokio::test]
    async fn checkpoint_roundtrip() {
        let path = temp_path();
        let journal = Arc::new(FileJournal::open(&path).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();
        let checkpoint = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"key": "value"}),
            metadata: CheckpointMetadata {
                version: 1,
                in_flight_tasks: 5,
                reason: Some("shutdown".to_string()),
                custom: Value::Null,
            },
        };

        recovery.persist_checkpoint(&checkpoint).await.unwrap();
        let recovered = recovery.recover_checkpoint().await.unwrap();

        assert_eq!(recovered.agent_id, checkpoint.agent_id);
        assert_eq!(recovered.agent_state, checkpoint.agent_state);
        assert_eq!(recovered.app_state, checkpoint.app_state);
        assert_eq!(recovered.metadata.in_flight_tasks, 5);

        if path.exists() {
            let _ = std::fs::remove_file(path);
        }
    }

    #[tokio::test]
    async fn no_checkpoint_returns_error() {
        let path = temp_path();
        let journal = Arc::new(FileJournal::open(&path).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let result = recovery.recover_checkpoint().await;
        assert!(matches!(result, Err(RecoveryError::NoCheckpoint)));

        if path.exists() {
            let _ = std::fs::remove_file(path);
        }
    }

    #[tokio::test]
    async fn multiple_checkpoints_recovers_latest() {
        let path = temp_path();
        let journal = Arc::new(FileJournal::open(&path).await.unwrap());
        let recovery = StateRecovery::new(journal);

        let agent_id = AgentId::random();

        // Persist first checkpoint
        let checkpoint1 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Ready,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 1}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint1).await.unwrap();

        // Small delay to ensure different timestamps
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        // Persist second checkpoint
        let checkpoint2 = StateCheckpoint {
            agent_id,
            agent_state: crate::lifecycle::AgentState::Active,
            timestamp: SystemTime::now(),
            app_state: json!({"version": 2}),
            metadata: CheckpointMetadata::default(),
        };
        recovery.persist_checkpoint(&checkpoint2).await.unwrap();

        // Should recover the latest
        let recovered = recovery.recover_checkpoint().await.unwrap();
        assert_eq!(recovered.app_state, json!({"version": 2}));

        if path.exists() {
            let _ = std::fs::remove_file(path);
        }
    }
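
    // A sketch of a negative-path test, assuming the `MemoryRecord` builder API
    // behaves as it is used in `persist_checkpoint`: a record carrying the
    // "checkpoint" tag but a non-JSON payload should surface as `CorruptedCheckpoint`.
    #[tokio::test]
    async fn corrupted_checkpoint_returns_error() {
        let path = temp_path();
        let journal = Arc::new(FileJournal::open(&path).await.unwrap());

        // Append a record that looks like a checkpoint but cannot be deserialized.
        let record = agent_memory::MemoryRecord::builder(
            agent_memory::MemoryChannel::System,
            Bytes::from_static(b"not valid json"),
        )
        .tag("checkpoint")
        .unwrap()
        .build()
        .unwrap();
        journal.append(&record).await.unwrap();

        let recovery = StateRecovery::new(journal);
        let result = recovery.recover_checkpoint().await;
        assert!(matches!(result, Err(RecoveryError::CorruptedCheckpoint(_))));

        if path.exists() {
            let _ = std::fs::remove_file(path);
        }
    }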
}