use std::collections::HashMap;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use crate::graph::TaskGraph;
/// Lifecycle state of one step inside a [`TaskState`].
///
/// Serialized as an internally tagged object: the snake_case variant name is stored under the
/// `"state"` key and any variant fields sit next to it, e.g.
/// `{"state": "awaiting_confirmation", "nonce": "...", "since": "..."}`.
///
/// [`StepState::is_terminal`] treats `Completed`, `Failed`, `Skipped` and `Cancelled` as
/// terminal; [`StepState::is_success`] is true only for `Completed`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "state", rename_all = "snake_case")]
pub enum StepState {
    /// Initial state; `TaskState::new` assigns it to every step of the graph.
    Pending,
    /// Non-terminal state with no payload.
    Ready,
    /// Non-terminal state carrying a start timestamp.
    Running {
        started_at: DateTime<Utc>,
    },
    /// Non-terminal state carrying a nonce and the time it was entered.
    AwaitingConfirmation {
        nonce: String,
        since: DateTime<Utc>,
    },
    /// Terminal, successful state carrying the step's [`StepOutcome`].
    Completed {
        outcome: StepOutcome,
        completed_at: DateTime<Utc>,
    },
    /// Terminal, unsuccessful state with an error message and a retryable flag.
    Failed {
        error: String,
        retryable: bool,
        failed_at: DateTime<Utc>,
    },
    /// Terminal state with a textual reason; not counted as success.
    Skipped {
        reason: String,
    },
    /// Terminal state with no payload; not counted as success.
    Cancelled,
}
impl StepState {
pub fn is_terminal(&self) -> bool {
matches!(
self,
StepState::Completed { .. }
| StepState::Failed { .. }
| StepState::Skipped { .. }
| StepState::Cancelled
)
}
pub fn is_success(&self) -> bool {
matches!(self, StepState::Completed { .. })
}
}
/// Result payload stored in [`StepState::Completed`].
///
/// Every field is a plain value type, so the struct derives `PartialEq`/`Eq`. This lets callers
/// and tests compare outcomes directly instead of field by field.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StepOutcome {
    /// Standard output text.
    pub stdout: String,
    /// Standard error text.
    pub stderr: String,
    /// Exit code, if one was recorded.
    pub exit_code: Option<i32>,
    /// Artifact identifiers, as strings.
    pub artifacts: Vec<String>,
    /// Short human-readable summary of the result.
    pub summary: String,
}
/// Coarse phase of a whole task.
///
/// Serialized as a snake_case string (e.g. `"awaiting_approval"`). `TaskPhase::as_str` returns
/// the same spelling. `TaskState::new` starts every task in `Planning`. `Completed`, `Failed`
/// and `Cancelled` are terminal per `TaskPhase::is_terminal`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum TaskPhase {
    /// Initial phase of a newly created task.
    Planning,
    AwaitingApproval,
    Executing,
    Reconciling,
    /// Terminal phase.
    Completed,
    /// Terminal phase.
    Failed,
    /// Terminal phase.
    Cancelled,
}
impl TaskPhase {
pub fn as_str(self) -> &'static str {
match self {
TaskPhase::Planning => "planning",
TaskPhase::AwaitingApproval => "awaiting_approval",
TaskPhase::Executing => "executing",
TaskPhase::Reconciling => "reconciling",
TaskPhase::Completed => "completed",
TaskPhase::Failed => "failed",
TaskPhase::Cancelled => "cancelled",
}
}
pub fn is_terminal(self) -> bool {
matches!(
self,
TaskPhase::Completed | TaskPhase::Failed | TaskPhase::Cancelled
)
}
}
/// Full mutable state of one task: its request, step graph, per-step states and phase.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskState {
    /// Task identifier.
    pub id: String,
    /// The request text the task was created for.
    pub request: String,
    /// Graph of steps that make up the task.
    pub graph: TaskGraph,
    /// Current state of each step, keyed by step id.
    ///
    /// `TaskState::new` seeds one `Pending` entry per key of `graph.steps`.
    pub step_states: HashMap<String, StepState>,
    /// Creation time; `TaskState::new` sets it to `Utc::now()`.
    pub created_at: DateTime<Utc>,
    /// Completion time; `None` at construction.
    pub completed_at: Option<DateTime<Utc>>,
    /// Current coarse phase; starts as `TaskPhase::Planning`.
    pub phase: TaskPhase,
    /// Replan counter, starting at 0.
    ///
    /// `#[serde(default)]` lets serialized states that lack this field load with 0.
    #[serde(default)]
    pub replan_attempts: u32,
}
impl TaskState {
    /// Creates a task in `TaskPhase::Planning` with every step of `graph` marked `Pending`.
    ///
    /// `created_at` is stamped with `Utc::now()`, `completed_at` starts as `None` and
    /// `replan_attempts` starts at 0.
    pub fn new(id: String, request: String, graph: TaskGraph) -> Self {
        // One Pending entry per step key, so counts() covers every step from the start.
        let step_states = graph
            .steps
            .keys()
            .map(|step_id| (step_id.clone(), StepState::Pending))
            .collect::<HashMap<String, StepState>>();

        Self {
            id,
            request,
            graph,
            step_states,
            created_at: Utc::now(),
            completed_at: None,
            phase: TaskPhase::Planning,
            replan_attempts: 0,
        }
    }

    /// Records `state` for `step_id`, replacing any previous state for that id.
    ///
    /// The id is not validated against `graph`. An unknown id inserts a new entry, which then
    /// shows up in `counts()`.
    pub fn set_step_state(&mut self, step_id: &str, state: StepState) {
        self.step_states.insert(step_id.to_owned(), state);
    }

    /// Returns `true` when every tracked step is in a terminal state.
    ///
    /// Vacuously `true` when no steps are tracked.
    pub fn is_complete(&self) -> bool {
        self.step_states.values().all(StepState::is_terminal)
    }

    /// Returns `true` when every tracked step is `Completed`.
    ///
    /// Skipped, cancelled or failed steps make this `false`. Vacuously `true` when no steps are
    /// tracked.
    pub fn all_succeeded(&self) -> bool {
        self.step_states.values().all(StepState::is_success)
    }

    /// Tallies how many tracked steps are in each `StepState` variant.
    pub fn counts(&self) -> TaskCounts {
        self.step_states
            .values()
            .fold(TaskCounts::default(), |mut tally, state| {
                let bucket = match state {
                    StepState::Pending => &mut tally.pending,
                    StepState::Ready => &mut tally.ready,
                    StepState::Running { .. } => &mut tally.running,
                    StepState::AwaitingConfirmation { .. } => &mut tally.awaiting,
                    StepState::Completed { .. } => &mut tally.completed,
                    StepState::Failed { .. } => &mut tally.failed,
                    StepState::Skipped { .. } => &mut tally.skipped,
                    StepState::Cancelled => &mut tally.cancelled,
                };
                *bucket += 1;
                tally
            })
    }
}
/// Number of steps in each `StepState` variant, as produced by `TaskState::counts`.
///
/// Derives `PartialEq`/`Eq` (all fields are `usize`) so two snapshots can be compared directly,
/// e.g. to detect whether progress changed or to assert exact tallies in tests.
#[derive(Debug, Default, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TaskCounts {
    /// Steps in `StepState::Pending`.
    pub pending: usize,
    /// Steps in `StepState::Ready`.
    pub ready: usize,
    /// Steps in `StepState::Running`.
    pub running: usize,
    /// Steps in `StepState::AwaitingConfirmation`.
    pub awaiting: usize,
    /// Steps in `StepState::Completed`.
    pub completed: usize,
    /// Steps in `StepState::Failed`.
    pub failed: usize,
    /// Steps in `StepState::Skipped`.
    pub skipped: usize,
    /// Steps in `StepState::Cancelled`.
    pub cancelled: usize,
}
impl TaskCounts {
    /// Sum of all per-state buckets, i.e. the number of steps that were counted.
    pub fn total(&self) -> usize {
        let buckets = [
            self.pending,
            self.ready,
            self.running,
            self.awaiting,
            self.completed,
            self.failed,
            self.skipped,
            self.cancelled,
        ];
        buckets.iter().sum()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TaskGraph;
    use crate::step::{StepAction, TaskStep};

    /// Two-step graph: `s1` (plan), then `s2` (test) which depends on `s1`.
    fn simple_graph() -> TaskGraph {
        let steps = vec![
            TaskStep {
                id: "s1".to_string(),
                description: "Step 1".to_string(),
                action: StepAction::Plan {
                    output: "plan".to_string(),
                },
                depends_on: vec![],
                tier: audit::ActionTier::Execute,
                estimated_tokens: 0,
            },
            TaskStep {
                id: "s2".to_string(),
                description: "Step 2".to_string(),
                action: StepAction::Test {
                    command: "cargo test".to_string(),
                    workdir: "/tmp".into(),
                },
                depends_on: vec!["s1".to_string()],
                tier: audit::ActionTier::Execute,
                estimated_tokens: 0,
            },
        ];
        TaskGraph::from_steps(steps).unwrap()
    }

    /// Fresh task over `simple_graph()` with fixed id and request strings.
    fn new_task() -> TaskState {
        TaskState::new("t1".to_string(), "build it".to_string(), simple_graph())
    }

    /// A successful `Completed` state with empty output and exit code 0.
    fn completed() -> StepState {
        StepState::Completed {
            outcome: StepOutcome {
                stdout: String::new(),
                stderr: String::new(),
                exit_code: Some(0),
                artifacts: vec![],
                summary: "done".to_string(),
            },
            completed_at: Utc::now(),
        }
    }

    #[test]
    fn test_new_task_state() {
        let state = new_task();
        assert_eq!(state.phase, TaskPhase::Planning);
        assert!(state.completed_at.is_none());
        assert_eq!(state.replan_attempts, 0);
        assert!(!state.is_complete());
        assert!(state
            .step_states
            .values()
            .all(|s| matches!(s, StepState::Pending)));
        let counts = state.counts();
        assert_eq!(counts.pending, 2);
        assert_eq!(counts.total(), 2);
    }

    #[test]
    fn test_step_transitions() {
        let mut state = new_task();
        state.set_step_state("s1", completed());
        assert!(!state.is_complete());
        state.set_step_state("s2", completed());
        assert!(state.is_complete());
        assert!(state.all_succeeded());
        assert_eq!(state.counts().completed, 2);
    }

    #[test]
    fn test_in_flight_steps_are_not_terminal() {
        let mut state = new_task();
        state.set_step_state(
            "s1",
            StepState::Running {
                started_at: Utc::now(),
            },
        );
        state.set_step_state(
            "s2",
            StepState::AwaitingConfirmation {
                nonce: "n1".to_string(),
                since: Utc::now(),
            },
        );
        assert!(!state.is_complete());
        assert!(!state.all_succeeded());
        let counts = state.counts();
        assert_eq!(counts.running, 1);
        assert_eq!(counts.awaiting, 1);
        assert_eq!(counts.total(), 2);
    }

    #[test]
    fn test_failed_step_is_terminal_but_not_success() {
        let mut state = new_task();
        state.set_step_state("s1", completed());
        state.set_step_state(
            "s2",
            StepState::Failed {
                error: "boom".to_string(),
                retryable: false,
                failed_at: Utc::now(),
            },
        );
        assert!(state.is_complete());
        assert!(!state.all_succeeded());
        let counts = state.counts();
        assert_eq!(counts.completed, 1);
        assert_eq!(counts.failed, 1);
        assert_eq!(counts.total(), 2);
    }

    #[test]
    fn test_skipped_and_cancelled_are_terminal_but_not_success() {
        let mut state = new_task();
        state.set_step_state(
            "s1",
            StepState::Skipped {
                reason: "not needed".to_string(),
            },
        );
        state.set_step_state("s2", StepState::Cancelled);
        assert!(state.is_complete());
        assert!(!state.all_succeeded());
        let counts = state.counts();
        assert_eq!(counts.skipped, 1);
        assert_eq!(counts.cancelled, 1);
        assert_eq!(counts.total(), 2);
    }

    #[test]
    fn test_task_phase_strings_and_terminality() {
        assert_eq!(TaskPhase::AwaitingApproval.as_str(), "awaiting_approval");
        assert_eq!(TaskPhase::Planning.as_str(), "planning");
        assert!(TaskPhase::Completed.is_terminal());
        assert!(TaskPhase::Failed.is_terminal());
        assert!(TaskPhase::Cancelled.is_terminal());
        assert!(!TaskPhase::Executing.is_terminal());
        assert!(!TaskPhase::Reconciling.is_terminal());
    }
}