use serde::{Deserialize, Serialize};
use crate::types::{SessionEvent, SessionStatus};
/// Execution state recorded together with each session event by
/// [`CheckpointManager::checkpoint`].
///
/// Serializable so it can be persisted and reloaded alongside the event log.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct RunState {
    /// Sequence number supplied by the caller with the latest checkpoint.
    /// NOTE(review): presumably the next event sequence number to assign — confirm with callers.
    pub seq: u64,
    /// Identifiers of tool uses the caller still considers outstanding.
    pub pending_tool_ids: Vec<String>,
    /// Current status of the session.
    pub status: SessionStatus,
}

impl RunState {
    /// Returns the starting state: sequence `0`, no pending tool uses, and
    /// status [`SessionStatus::Queued`].
    #[must_use]
    pub fn initial() -> Self {
        Self {
            seq: 0,
            pending_tool_ids: Vec::new(),
            status: SessionStatus::Queued,
        }
    }
}

impl Default for RunState {
    /// Equivalent to [`RunState::initial`], so the type works with
    /// `..Default::default()`, `#[serde(default)]`, and `std::mem::take`.
    fn default() -> Self {
        Self::initial()
    }
}
/// In-memory store pairing a session's event log with its latest [`RunState`].
///
/// [`checkpoint`](Self::checkpoint) appends an event and replaces the run state
/// within a single `&mut self` call, so no shared borrow of the manager can
/// observe one update without the other.
pub struct CheckpointManager {
    /// Identifier of the session this manager tracks.
    session_id: String,
    /// Every event checkpointed so far, in insertion order.
    events: Vec<SessionEvent>,
    /// Run state supplied with the most recent checkpoint.
    run_state: RunState,
}

impl CheckpointManager {
    /// Creates a manager for `session_id` with an empty event log and
    /// [`RunState::initial`] as its state.
    pub fn new(session_id: String) -> Self {
        let events = Vec::new();
        let run_state = RunState::initial();
        Self { session_id, events, run_state }
    }

    /// Appends `event` to the log, then replaces the stored state with `run_state`.
    pub fn checkpoint(&mut self, event: SessionEvent, run_state: RunState) {
        let Self { events, run_state: current, .. } = self;
        events.push(event);
        *current = run_state;
    }

    /// Returns owned copies of every recorded event (oldest first) together
    /// with the current run state. The manager itself is left unchanged.
    pub fn load_checkpoint(&self) -> (Vec<SessionEvent>, RunState) {
        let history = self.events.to_vec();
        let state = self.run_state.clone();
        (history, state)
    }

    /// Borrows the recorded events, oldest first.
    pub fn events(&self) -> &[SessionEvent] {
        self.events.as_slice()
    }

    /// Borrows the state from the most recent checkpoint, or the initial state
    /// if nothing has been checkpointed yet.
    pub fn run_state(&self) -> &RunState {
        &self.run_state
    }

    /// Returns the session identifier passed to [`new`](Self::new).
    pub fn session_id(&self) -> &str {
        self.session_id.as_str()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ContentBlock;
    use serde_json::json;

    #[test]
    fn test_run_state_initial() {
        let state = RunState::initial();
        assert_eq!(state.seq, 0);
        assert!(state.pending_tool_ids.is_empty());
        assert_eq!(state.status, SessionStatus::Queued);
    }

    #[test]
    fn test_run_state_serialization_round_trip() {
        let state = RunState {
            seq: 42,
            pending_tool_ids: vec!["ctu_001".to_string(), "ctu_002".to_string()],
            status: SessionStatus::Running,
        };
        let json = serde_json::to_string(&state).unwrap();
        let deserialized: RunState = serde_json::from_str(&json).unwrap();
        assert_eq!(state, deserialized);
    }

    #[test]
    fn test_checkpoint_manager_new() {
        let mgr = CheckpointManager::new("sess_123".to_string());
        assert_eq!(mgr.session_id(), "sess_123");
        assert!(mgr.events().is_empty());
        assert_eq!(mgr.run_state(), &RunState::initial());
    }

    #[test]
    fn test_checkpoint_stores_event_and_state_atomically() {
        let mut mgr = CheckpointManager::new("sess_001".to_string());
        let event = SessionEvent::StatusRunning { seq: 0 };
        let state = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event, state.clone());
        assert_eq!(mgr.events().len(), 1);
        assert!(matches!(mgr.events()[0], SessionEvent::StatusRunning { seq: 0 }));
        assert_eq!(mgr.run_state(), &state);
    }

    #[test]
    fn test_checkpoint_multiple_events() {
        let mut mgr = CheckpointManager::new("sess_002".to_string());

        let event1 = SessionEvent::StatusRunning { seq: 0 };
        let state1 = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event1, state1.clone());
        assert_eq!(mgr.events().len(), 1);
        assert_eq!(mgr.run_state(), &state1);

        let event2 = SessionEvent::Message {
            content: vec![ContentBlock::Text { text: "Hello".to_string() }],
            seq: 1,
        };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event2, state2.clone());
        // Each checkpoint replaces the previous state wholesale.
        assert_eq!(mgr.events().len(), 2);
        assert_eq!(mgr.run_state(), &state2);

        let event3 = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_001".to_string(),
            name: "deploy".to_string(),
            input: json!({"target": "staging"}),
            seq: 2,
        };
        let state3 = RunState {
            seq: 3,
            pending_tool_ids: vec!["ctu_001".to_string()],
            status: SessionStatus::Idle,
        };
        mgr.checkpoint(event3, state3.clone());
        assert_eq!(mgr.events().len(), 3);
        assert_eq!(mgr.run_state(), &state3);

        // Events must be kept in insertion order.
        let events = mgr.events();
        assert!(matches!(events[0], SessionEvent::StatusRunning { seq: 0 }));
        assert!(matches!(events[1], SessionEvent::Message { seq: 1, .. }));
        assert!(matches!(events[2], SessionEvent::CustomToolUse { seq: 2, .. }));
    }

    #[test]
    fn test_load_checkpoint_returns_all_events_and_current_state() {
        let mut mgr = CheckpointManager::new("sess_003".to_string());
        let event1 = SessionEvent::StatusRunning { seq: 0 };
        let state1 = RunState { seq: 1, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event1, state1);
        let event2 = SessionEvent::StatusIdle { seq: 1, stop_reason: None, usage: None };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Idle };
        mgr.checkpoint(event2, state2.clone());

        let (events, run_state) = mgr.load_checkpoint();
        assert_eq!(events.len(), 2);
        // Loaded events preserve insertion order.
        assert!(matches!(events[0], SessionEvent::StatusRunning { seq: 0 }));
        assert!(matches!(events[1], SessionEvent::StatusIdle { seq: 1, .. }));
        assert_eq!(run_state, state2);

        // Loading is a read: the manager still holds the full log and state.
        assert_eq!(mgr.events().len(), 2);
        assert_eq!(mgr.run_state(), &state2);
    }

    #[test]
    fn test_load_checkpoint_empty_manager() {
        let mgr = CheckpointManager::new("sess_empty".to_string());
        let (events, run_state) = mgr.load_checkpoint();
        assert!(events.is_empty());
        assert_eq!(run_state, RunState::initial());
    }

    #[test]
    fn test_run_state_updates_atomically_with_event() {
        let mut mgr = CheckpointManager::new("sess_atomic".to_string());
        let event = SessionEvent::CustomToolUse {
            custom_tool_use_id: "ctu_park".to_string(),
            name: "user_action".to_string(),
            input: json!({}),
            seq: 0,
        };
        let state = RunState {
            seq: 1,
            pending_tool_ids: vec!["ctu_park".to_string()],
            status: SessionStatus::Idle,
        };
        mgr.checkpoint(event, state.clone());
        assert_eq!(mgr.run_state(), &state);
        assert_eq!(mgr.run_state().pending_tool_ids, vec!["ctu_park"]);
        assert_eq!(mgr.run_state().status, SessionStatus::Idle);
        assert_eq!(mgr.events().len(), 1);

        let event2 = SessionEvent::StatusRunning { seq: 1 };
        let state2 = RunState { seq: 2, pending_tool_ids: vec![], status: SessionStatus::Running };
        mgr.checkpoint(event2, state2.clone());
        assert_eq!(mgr.run_state(), &state2);
        assert!(mgr.run_state().pending_tool_ids.is_empty());
        assert_eq!(mgr.run_state().status, SessionStatus::Running);
        assert_eq!(mgr.events().len(), 2);
    }

    #[test]
    fn test_run_state_with_multiple_pending_tools() {
        let state = RunState {
            seq: 10,
            pending_tool_ids: vec![
                "ctu_001".to_string(),
                "ctu_002".to_string(),
                "ctu_003".to_string(),
            ],
            status: SessionStatus::Idle,
        };
        let json = serde_json::to_string(&state).unwrap();
        let deserialized: RunState = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.pending_tool_ids.len(), 3);
        assert_eq!(deserialized, state);
    }
}