Skip to main content

brainos_orchestrate/
state.rs

1//! Task and step state machine.
2
3use std::collections::HashMap;
4
5use chrono::{DateTime, Utc};
6use serde::{Deserialize, Serialize};
7
8use crate::graph::TaskGraph;
9
10/// State of a single step in the task plan.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(tag = "state", rename_all = "snake_case")]
13pub enum StepState {
14    /// Waiting for dependencies.
15    Pending,
16    /// Dependencies met, awaiting execution slot.
17    Ready,
18    /// Currently executing.
19    Running { started_at: DateTime<Utc> },
20    /// Waiting for human approval.
21    AwaitingConfirmation { nonce: String, since: DateTime<Utc> },
22    /// Completed successfully.
23    Completed {
24        outcome: StepOutcome,
25        completed_at: DateTime<Utc>,
26    },
27    /// Failed.
28    Failed {
29        error: String,
30        retryable: bool,
31        failed_at: DateTime<Utc>,
32    },
33    /// Skipped (dependency failed, user chose to skip).
34    Skipped { reason: String },
35    /// Cancelled by user or timeout.
36    Cancelled,
37}
38
39impl StepState {
40    pub fn is_terminal(&self) -> bool {
41        matches!(
42            self,
43            StepState::Completed { .. }
44                | StepState::Failed { .. }
45                | StepState::Skipped { .. }
46                | StepState::Cancelled
47        )
48    }
49
50    pub fn is_success(&self) -> bool {
51        matches!(self, StepState::Completed { .. })
52    }
53}
54
55/// Outcome of a completed step.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct StepOutcome {
58    pub stdout: String,
59    pub stderr: String,
60    pub exit_code: Option<i32>,
61    pub artifacts: Vec<String>,
62    pub summary: String,
63}
64
65/// Overall phase of the task.
66///
67/// The canonical state machine is
68/// `Planning → Executing → Reconciling → (Completed | Failed)`.
69/// `AwaitingApproval` is a side-channel state entered before `Executing`
70/// when the plan needs human consent; `Cancelled` is a terminal state
71/// reachable from anywhere via [`crate::TaskOrchestrator::cancel`].
72#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
73#[serde(rename_all = "snake_case")]
74pub enum TaskPhase {
75    /// Plan is being assembled.
76    Planning,
77    /// Waiting for user approval of the plan.
78    AwaitingApproval,
79    /// Steps are being executed.
80    Executing,
81    /// All steps complete; verifying the world matches what was planned
82    /// before committing to a terminal phase. The explicit "world-drift
83    /// detection" hook; today it's a brief no-op transition that lands
84    /// on `Completed` or `Failed` based on step outcomes.
85    Reconciling,
86    /// All steps completed successfully.
87    Completed,
88    /// At least one step failed and the orchestrator gave up
89    /// (replan budget exhausted or non-retryable error).
90    Failed,
91    /// Task was cancelled.
92    Cancelled,
93}
94
95impl TaskPhase {
96    /// Stable wire string used by [`observe::BrainEvent::TaskStateChange`]
97    /// and the `task_states` audit table. Snake-case matches the serde
98    /// representation; keep them in sync if either changes.
99    pub fn as_str(self) -> &'static str {
100        match self {
101            TaskPhase::Planning => "planning",
102            TaskPhase::AwaitingApproval => "awaiting_approval",
103            TaskPhase::Executing => "executing",
104            TaskPhase::Reconciling => "reconciling",
105            TaskPhase::Completed => "completed",
106            TaskPhase::Failed => "failed",
107            TaskPhase::Cancelled => "cancelled",
108        }
109    }
110
111    /// True for phases the orchestrator never transitions out of.
112    pub fn is_terminal(self) -> bool {
113        matches!(
114            self,
115            TaskPhase::Completed | TaskPhase::Failed | TaskPhase::Cancelled
116        )
117    }
118}
119
120/// Complete state of a task — graph + per-step states.
121#[derive(Debug, Clone, Serialize, Deserialize)]
122pub struct TaskState {
123    /// Unique task ID.
124    pub id: String,
125    /// Original user request.
126    pub request: String,
127    /// The task graph (steps + dependencies).
128    pub graph: TaskGraph,
129    /// Per-step state.
130    pub step_states: HashMap<String, StepState>,
131    /// When the task was created.
132    pub created_at: DateTime<Utc>,
133    /// When the task completed (if it has).
134    pub completed_at: Option<DateTime<Utc>>,
135    /// Current phase.
136    pub phase: TaskPhase,
137    /// How many times the orchestrator has invoked the
138    /// replan-on-failure loop for this task. Capped to bound LLM cost.
139    #[serde(default)]
140    pub replan_attempts: u32,
141}
142
143impl TaskState {
144    /// Create a new task in the Planning phase.
145    pub fn new(id: String, request: String, graph: TaskGraph) -> Self {
146        let step_states: HashMap<String, StepState> = graph
147            .steps
148            .keys()
149            .map(|id| (id.clone(), StepState::Pending))
150            .collect();
151
152        Self {
153            id,
154            request,
155            graph,
156            step_states,
157            created_at: Utc::now(),
158            completed_at: None,
159            phase: TaskPhase::Planning,
160            replan_attempts: 0,
161        }
162    }
163
164    /// Transition a step to a new state.
165    pub fn set_step_state(&mut self, step_id: &str, state: StepState) {
166        self.step_states.insert(step_id.to_string(), state);
167    }
168
169    /// Check whether all steps are in a terminal state.
170    pub fn is_complete(&self) -> bool {
171        self.step_states.values().all(|s| s.is_terminal())
172    }
173
174    /// Check whether all steps succeeded.
175    pub fn all_succeeded(&self) -> bool {
176        self.step_states.values().all(|s| s.is_success())
177    }
178
179    /// Count steps by state category.
180    pub fn counts(&self) -> TaskCounts {
181        let mut c = TaskCounts::default();
182        for state in self.step_states.values() {
183            match state {
184                StepState::Pending => c.pending += 1,
185                StepState::Ready => c.ready += 1,
186                StepState::Running { .. } => c.running += 1,
187                StepState::AwaitingConfirmation { .. } => c.awaiting += 1,
188                StepState::Completed { .. } => c.completed += 1,
189                StepState::Failed { .. } => c.failed += 1,
190                StepState::Skipped { .. } => c.skipped += 1,
191                StepState::Cancelled => c.cancelled += 1,
192            }
193        }
194        c
195    }
196}
197
198/// Step counts by state.
199#[derive(Debug, Default, Clone, Serialize, Deserialize)]
200pub struct TaskCounts {
201    pub pending: usize,
202    pub ready: usize,
203    pub running: usize,
204    pub awaiting: usize,
205    pub completed: usize,
206    pub failed: usize,
207    pub skipped: usize,
208    pub cancelled: usize,
209}
210
211impl TaskCounts {
212    pub fn total(&self) -> usize {
213        self.pending
214            + self.ready
215            + self.running
216            + self.awaiting
217            + self.completed
218            + self.failed
219            + self.skipped
220            + self.cancelled
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;
    use crate::graph::TaskGraph;
    use crate::step::{StepAction, TaskStep};

    /// Two-step fixture: `s1` has no dependencies and `s2` depends on `s1`.
    fn simple_graph() -> TaskGraph {
        let first = TaskStep {
            id: "s1".to_string(),
            description: "Step 1".to_string(),
            action: StepAction::Plan {
                output: "plan".to_string(),
            },
            depends_on: vec![],
            tier: audit::ActionTier::Execute,
            estimated_tokens: 0,
        };
        let second = TaskStep {
            id: "s2".to_string(),
            description: "Step 2".to_string(),
            action: StepAction::Test {
                command: "cargo test".to_string(),
                workdir: "/tmp".into(),
            },
            depends_on: vec!["s1".to_string()],
            tier: audit::ActionTier::Execute,
            estimated_tokens: 0,
        };
        TaskGraph::from_steps(vec![first, second]).unwrap()
    }

    /// A `Completed` state carrying an empty outcome with exit code 0.
    fn completed_ok() -> StepState {
        StepState::Completed {
            outcome: StepOutcome {
                stdout: String::new(),
                stderr: String::new(),
                exit_code: Some(0),
                artifacts: vec![],
                summary: "done".to_string(),
            },
            completed_at: Utc::now(),
        }
    }

    /// A fresh task starts in `Planning` with all steps pending.
    #[test]
    fn test_new_task_state() {
        let state = TaskState::new("t1".to_string(), "build it".to_string(), simple_graph());

        assert_eq!(state.phase, TaskPhase::Planning);
        assert!(!state.is_complete());

        let tally = state.counts();
        assert_eq!(tally.pending, 2);
        assert_eq!(tally.total(), 2);
    }

    /// The task only counts as complete once both steps are terminal.
    #[test]
    fn test_step_transitions() {
        let mut state = TaskState::new("t1".to_string(), "build it".to_string(), simple_graph());

        state.set_step_state("s1", completed_ok());
        assert!(!state.is_complete());

        state.set_step_state("s2", completed_ok());
        assert!(state.is_complete());
        assert!(state.all_succeeded());
    }
}