Skip to main content

adk_managed/
checkpoint.rs

//! Checkpoint management for durable sessions.
//!
//! The [`CheckpointManager`] provides atomic checkpoint persistence so that
//! a crash cannot leave an event emitted but un-checkpointed (or vice versa).
//! For the initial implementation, storage is in-memory (`Vec<SessionEvent>`),
//! so checkpoints do not survive a process restart; integrating with
//! `SessionService` for persistent storage is a platform concern.
//!
//! # Responsibilities
//!
//! 1. **Atomicity guarantee**: an event and its run-state are saved together in one operation
//! 2. **Load/resume interface**: retrieve all events and the last run state
//! 3. **Event log maintenance**: keep an ordered log for replay

15use serde::{Deserialize, Serialize};
16
17use crate::types::{SessionEvent, SessionStatus};
18
19/// Run-state persisted with each checkpoint.
20///
21/// Contains everything needed to resume a session after a crash:
22/// the current sequence counter value, which tool calls are parked,
23/// and the session's lifecycle status.
24///
25/// # Example
26///
27/// ```rust
28/// use adk_managed::checkpoint::RunState;
29/// use adk_managed::types::SessionStatus;
30///
31/// let state = RunState {
32///     seq: 5,
33///     pending_tool_ids: vec!["ctu_001".to_string()],
34///     status: SessionStatus::Running,
35/// };
36/// assert_eq!(state.seq, 5);
37/// assert!(!state.pending_tool_ids.is_empty());
38/// ```
39#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
40pub struct RunState {
41    /// Current sequence counter value.
42    pub seq: u64,
43    /// IDs of custom tool calls that are currently parked (awaiting client response).
44    pub pending_tool_ids: Vec<String>,
45    /// Current session status.
46    pub status: SessionStatus,
47}
48
49impl RunState {
50    /// Create a new initial run state (seq=0, no pending tools, queued status).
51    pub fn initial() -> Self {
52        Self { seq: 0, pending_tool_ids: Vec::new(), status: SessionStatus::Queued }
53    }
54}
55
56/// Manages atomic checkpoint persistence for durable sessions.
57///
58/// Each checkpoint atomically stores an event and the updated run-state so that
59/// a crash cannot leave an event emitted but un-checkpointed (or vice versa).
60///
61/// # Example
62///
63/// ```rust
64/// use adk_managed::checkpoint::{CheckpointManager, RunState};
65/// use adk_managed::types::{SessionEvent, SessionStatus, ContentBlock};
66///
67/// let mut mgr = CheckpointManager::new("session_001".to_string());
68///
69/// let event = SessionEvent::StatusRunning { seq: 0 };
70/// let state = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
71/// mgr.checkpoint(event, state.clone());
72///
73/// assert_eq!(mgr.events().len(), 1);
74/// assert_eq!(mgr.run_state(), &state);
75/// ```
76pub struct CheckpointManager {
77    /// The session ID this manager is checkpointing for.
78    session_id: String,
79    /// The event log (in-memory implementation).
80    events: Vec<SessionEvent>,
81    /// Current run state.
82    run_state: RunState,
83}
84
85impl CheckpointManager {
86    /// Create a new checkpoint manager for the given session.
87    ///
88    /// Initializes with an empty event log and the initial run state
89    /// (seq=0, no pending tools, queued status).
90    pub fn new(session_id: String) -> Self {
91        Self { session_id, events: Vec::new(), run_state: RunState::initial() }
92    }
93
94    /// Atomically persist an event and updated run-state.
95    ///
96    /// Both the event and the new state are stored together in one operation,
97    /// guaranteeing that replay will see a consistent view after any crash.
98    pub fn checkpoint(&mut self, event: SessionEvent, run_state: RunState) {
99        self.events.push(event);
100        self.run_state = run_state;
101    }
102
103    /// Load the last checkpoint for resume.
104    ///
105    /// Returns all stored events and the current run state, providing
106    /// everything needed to reconstruct a session after a restart.
107    pub fn load_checkpoint(&self) -> (Vec<SessionEvent>, RunState) {
108        (self.events.clone(), self.run_state.clone())
109    }
110
111    /// Get all events stored in the checkpoint log.
112    pub fn events(&self) -> &[SessionEvent] {
113        &self.events
114    }
115
116    /// Get current run state.
117    pub fn run_state(&self) -> &RunState {
118        &self.run_state
119    }
120
121    /// Get the session ID this manager is checkpointing for.
122    pub fn session_id(&self) -> &str {
123        &self.session_id
124    }
125}
126
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;
    use serde_json::json;

    #[test]
    fn test_run_state_initial() {
        let state = RunState::initial();
        assert_eq!(state.seq, 0);
        assert!(state.pending_tool_ids.is_empty());
        assert_eq!(state.status, SessionStatus::Queued);
    }

    #[test]
    fn test_run_state_serialization_round_trip() {
        let state = RunState {
            seq: 42,
            pending_tool_ids: vec!["ctu_001".to_string(), "ctu_002".to_string()],
            status: SessionStatus::Running,
        };
        let json = serde_json::to_string(&state).unwrap();
        let deserialized: RunState = serde_json::from_str(&json).unwrap();
        assert_eq!(state, deserialized);
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let mgr = CheckpointManager::new("sess_123".to_string());
        assert_eq!(mgr.session_id(), "sess_123");
        assert!(mgr.events().is_empty());
        assert_eq!(mgr.run_state(), &RunState::initial());
    }

    #[test]
    fn test_checkpoint_stores_event_and_state_atomically() {
        let mut mgr = CheckpointManager::new("sess_001".to_string());

        let event = SessionEvent::StatusRunning { seq: 0 };
        let state = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };

        mgr.checkpoint(event, state.clone());

        // Both event and state should be stored together
        assert_eq!(mgr.events().len(), 1);
        assert_eq!(mgr.run_state(), &state);
    }

    #[test]
    fn test_checkpoint_multiple_events() {
        let mut mgr = CheckpointManager::new("sess_002".to_string());

        // First checkpoint
        let event1 = SessionEvent::StatusRunning { seq: 0 };
        let state1 = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event1, state1.clone());
        assert_eq!(mgr.events().len(), 1);
        assert_eq!(mgr.run_state(), &state1);

        // Second checkpoint
        let event2 = SessionEvent::Message {
            content: vec![ContentBlock::Text { text: "Hello".to_string() }],
            seq: 1,
        };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event2, state2.clone());
        assert_eq!(mgr.events().len(), 2);
        assert_eq!(mgr.run_state(), &state2);

        // Third checkpoint — idle with pending tool
        let event3 = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_001".to_string(),
            name: "deploy".to_string(),
            input: json!({"target": "staging"}),
            seq: 2,
        };
        let state3 = RunState {
            seq: 3,
            pending_tool_ids: vec!["ctu_001".to_string()],
            status: SessionStatus::Idle,
        };
        mgr.checkpoint(event3, state3.clone());

        assert_eq!(mgr.events().len(), 3);
        // Run state should reflect the LAST checkpoint
        assert_eq!(mgr.run_state(), &state3);
        // Checkpointing never changes which session the manager belongs to
        assert_eq!(mgr.session_id(), "sess_002");
    }

    #[test]
    fn test_load_checkpoint_returns_all_events_and_current_state() {
        let mut mgr = CheckpointManager::new("sess_003".to_string());

        let event1 = SessionEvent::StatusRunning { seq: 0 };
        let state1 = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event1, state1);

        let event2 = SessionEvent::StatusIdle { seq: 1, stop_reason: None, usage: None };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Idle };
        mgr.checkpoint(event2, state2.clone());

        let (events, run_state) = mgr.load_checkpoint();
        assert_eq!(events.len(), 2);
        assert_eq!(run_state, state2);
    }

    #[test]
    fn test_load_checkpoint_returns_snapshot() {
        let mut mgr = CheckpointManager::new("sess_snapshot".to_string());

        let state1 = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(SessionEvent::StatusRunning { seq: 0 }, state1.clone());
        let (events, run_state) = mgr.load_checkpoint();

        // A later checkpoint must not retroactively change a loaded snapshot
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Idle };
        mgr.checkpoint(SessionEvent::StatusIdle { seq: 1, stop_reason: None, usage: None }, state2);

        assert_eq!(events.len(), 1);
        assert_eq!(run_state, state1);
        assert_eq!(mgr.events().len(), 2);
    }

    #[test]
    fn test_load_checkpoint_empty_manager() {
        let mgr = CheckpointManager::new("sess_empty".to_string());
        let (events, run_state) = mgr.load_checkpoint();
        assert!(events.is_empty());
        assert_eq!(run_state, RunState::initial());
    }

    #[test]
    fn test_run_state_updates_atomically_with_event() {
        let mut mgr = CheckpointManager::new("sess_atomic".to_string());

        // Simulate a custom tool use that parks
        let event = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_park".to_string(),
            name: "user_action".to_string(),
            input: json!({}),
            seq: 0,
        };
        let state = RunState {
            seq: 1,
            pending_tool_ids: vec!["ctu_park".to_string()],
            status: SessionStatus::Idle,
        };
        mgr.checkpoint(event, state);

        // Verify the state reflects the parked tool
        assert_eq!(mgr.run_state().pending_tool_ids, vec!["ctu_park"]);
        assert_eq!(mgr.run_state().status, SessionStatus::Idle);

        // Simulate the tool result arriving and session resuming
        let event2 = SessionEvent::StatusRunning { seq: 1 };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event2, state2);

        // Pending tools should be cleared, and both events must be in the log
        assert!(mgr.run_state().pending_tool_ids.is_empty());
        assert_eq!(mgr.run_state().status, SessionStatus::Running);
        assert_eq!(mgr.events().len(), 2);
    }

    #[test]
    fn test_run_state_with_multiple_pending_tools() {
        let state = RunState {
            seq: 10,
            pending_tool_ids: vec![
                "ctu_001".to_string(),
                "ctu_002".to_string(),
                "ctu_003".to_string(),
            ],
            status: SessionStatus::Idle,
        };

        let json = serde_json::to_string(&state).unwrap();
        let deserialized: RunState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pending_tool_ids.len(), 3);
        assert_eq!(deserialized, state);
    }
}