Skip to main content

libbrat_grite/
state_machine.rs

//! State machine validation for Brat entity lifecycles.
//!
//! This module provides a generic state machine that checks lifecycle
//! transitions against defined rules before they are persisted to Grite. It is
//! intended for tasks, sessions, convoys, and roles; this file defines the
//! transition rules for task and session statuses.
6
7use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10use crate::types::{SessionStatus, TaskStatus};
11
/// Describes a rejected state transition: the state the entity was in, the
/// state that was requested, and a human-readable explanation of the refusal.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct TransitionError<S> {
    /// State the entity was in when the transition was attempted.
    pub from: S,
    /// State the caller asked to move to.
    pub to: S,
    /// Explanation of why the transition is not allowed.
    pub reason: String,
}
19
20impl<S: Display> Display for TransitionError<S> {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        write!(
23            f,
24            "invalid transition from '{}' to '{}': {}",
25            self.from, self.to, self.reason
26        )
27    }
28}
29
30impl<S: Debug + Display> std::error::Error for TransitionError<S> {}
31
/// The outcome of a successful validation: the source and target states, plus
/// whether the rule checks were skipped via the `force` override.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Transition<S> {
    /// State before the transition.
    pub from: S,
    /// State after the transition.
    pub to: S,
    /// `true` when the caller passed `force` and the rule checks were skipped.
    pub forced: bool,
}
39
/// A lifecycle state whose permitted transitions can be checked by a
/// [`StateMachine`].
///
/// Implementors describe their transition graph through three queries:
/// which states are final, which states are reachable from everywhere, and
/// which explicit successors each state has.
pub trait State
where
    Self: Copy + Clone + PartialEq + Eq + Hash + Debug + Display,
{
    /// Whether this state is final, i.e. no transition may leave it.
    fn is_terminal(&self) -> bool;

    /// Whether every state is allowed to move into this one.
    fn is_universal_target(&self) -> bool;

    /// The explicitly allowed successor states of this state.
    fn valid_targets(&self) -> &'static [Self];
}
51
/// Zero-sized validator for transitions between states of type `S`.
///
/// The type stores no data; every transition rule comes from the [`State`]
/// implementation of `S`. The `S: State` bound is declared on the `impl`
/// blocks that need it rather than on the struct itself (Rust API guideline
/// C-STRUCT-BOUNDS), so naming `StateMachine<S>` does not force callers to
/// repeat the bound.
#[derive(Debug, Clone)]
pub struct StateMachine<S> {
    _marker: std::marker::PhantomData<S>,
}
57
58impl<S: State + 'static> StateMachine<S> {
59    pub fn new() -> Self {
60        Self {
61            _marker: std::marker::PhantomData,
62        }
63    }
64
65    /// Validate a state transition.
66    ///
67    /// Returns Ok(Transition) if valid, Err(TransitionError) if invalid.
68    /// If `force` is true, any transition is allowed (for admin overrides).
69    pub fn validate(
70        &self,
71        from: S,
72        to: S,
73        force: bool,
74    ) -> Result<Transition<S>, TransitionError<S>> {
75        // Force flag bypasses all validation
76        if force {
77            return Ok(Transition {
78                from,
79                to,
80                forced: true,
81            });
82        }
83
84        // No-op transitions are always valid
85        if from == to {
86            return Ok(Transition {
87                from,
88                to,
89                forced: false,
90            });
91        }
92
93        // Cannot transition out of terminal states
94        if from.is_terminal() {
95            return Err(TransitionError {
96                from,
97                to,
98                reason: format!("'{}' is a terminal state", from),
99            });
100        }
101
102        // Universal targets (like Dropped) are always reachable
103        if to.is_universal_target() {
104            return Ok(Transition {
105                from,
106                to,
107                forced: false,
108            });
109        }
110
111        // Check if target is in valid transitions
112        if from.valid_targets().contains(&to) {
113            return Ok(Transition {
114                from,
115                to,
116                forced: false,
117            });
118        }
119
120        Err(TransitionError {
121            from,
122            to,
123            reason: format!(
124                "valid targets from '{}' are: {}",
125                from,
126                format_targets(from.valid_targets())
127            ),
128        })
129    }
130}
131
132impl<S: State + 'static> Default for StateMachine<S> {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
/// Renders a state's explicit successor states for error messages, e.g.
/// `'merged', 'blocked'`.
///
/// The only caller in this module, `StateMachine::validate`, uses this after
/// it has already rejected transitions out of terminal states. An empty list
/// therefore describes a *non-terminal* state with no explicit successors:
/// only universal targets (if the state type defines any) can be reached from
/// it. For example, a session in `handoff` can still reach `exit`.
fn format_targets<S: Display>(targets: &[S]) -> String {
    if targets.is_empty() {
        // Do not call this a terminal state: terminal states never reach here.
        return "none (only universal targets are reachable)".to_string();
    }
    targets
        .iter()
        .map(|t| format!("'{}'", t))
        .collect::<Vec<_>>()
        .join(", ")
}
150
151// =============================================================================
152// TaskStatus State Implementation
153// =============================================================================
154
155impl State for TaskStatus {
156    fn is_terminal(&self) -> bool {
157        matches!(self, TaskStatus::Merged | TaskStatus::Dropped)
158    }
159
160    fn is_universal_target(&self) -> bool {
161        matches!(self, TaskStatus::Dropped)
162    }
163
164    fn valid_targets(&self) -> &'static [Self] {
165        match self {
166            // queued -> running (session picks up task)
167            TaskStatus::Queued => &[TaskStatus::Running],
168
169            // running -> blocked (resource/dependency constraint)
170            // running -> needs-review (work complete)
171            TaskStatus::Running => &[TaskStatus::Blocked, TaskStatus::NeedsReview],
172
173            // blocked -> running (constraint resolved)
174            TaskStatus::Blocked => &[TaskStatus::Running],
175
176            // needs-review -> merged (approved)
177            // needs-review -> blocked (merge conflicts or check failures)
178            TaskStatus::NeedsReview => &[TaskStatus::Merged, TaskStatus::Blocked],
179
180            // Terminal states have no valid outgoing transitions
181            TaskStatus::Merged => &[],
182            TaskStatus::Dropped => &[],
183        }
184    }
185}
186
187impl Display for TaskStatus {
188    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            TaskStatus::Queued => write!(f, "queued"),
191            TaskStatus::Running => write!(f, "running"),
192            TaskStatus::Blocked => write!(f, "blocked"),
193            TaskStatus::NeedsReview => write!(f, "needs-review"),
194            TaskStatus::Merged => write!(f, "merged"),
195            TaskStatus::Dropped => write!(f, "dropped"),
196        }
197    }
198}
199
200// =============================================================================
201// SessionStatus State Implementation
202// =============================================================================
203
204impl State for SessionStatus {
205    fn is_terminal(&self) -> bool {
206        matches!(self, SessionStatus::Exit)
207    }
208
209    fn is_universal_target(&self) -> bool {
210        // Exit can be reached from any state (failure, timeout, user stop)
211        matches!(self, SessionStatus::Exit)
212    }
213
214    fn valid_targets(&self) -> &'static [Self] {
215        match self {
216            // spawned -> ready (engine health check passes)
217            SessionStatus::Spawned => &[SessionStatus::Ready],
218
219            // ready -> running (first task action begins)
220            SessionStatus::Ready => &[SessionStatus::Running],
221
222            // running -> handoff (task ready for review/merge)
223            SessionStatus::Running => &[SessionStatus::Handoff],
224
225            // handoff -> (exit only via universal target)
226            SessionStatus::Handoff => &[],
227
228            // Terminal state
229            SessionStatus::Exit => &[],
230        }
231    }
232}
233
234impl Display for SessionStatus {
235    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236        match self {
237            SessionStatus::Spawned => write!(f, "spawned"),
238            SessionStatus::Ready => write!(f, "ready"),
239            SessionStatus::Running => write!(f, "running"),
240            SessionStatus::Handoff => write!(f, "handoff"),
241            SessionStatus::Exit => write!(f, "exit"),
242        }
243    }
244}
245
246// =============================================================================
247// Tests
248// =============================================================================
249
250#[cfg(test)]
251mod tests {
252    use super::*;
253
254    // -------------------------------------------------------------------------
255    // TaskStatus Tests
256    // -------------------------------------------------------------------------
257
258    #[test]
259    fn test_valid_task_transitions() {
260        let sm = StateMachine::<TaskStatus>::new();
261
262        // Valid transitions
263        assert!(sm.validate(TaskStatus::Queued, TaskStatus::Running, false).is_ok());
264        assert!(sm.validate(TaskStatus::Running, TaskStatus::Blocked, false).is_ok());
265        assert!(sm.validate(TaskStatus::Running, TaskStatus::NeedsReview, false).is_ok());
266        assert!(sm.validate(TaskStatus::Blocked, TaskStatus::Running, false).is_ok());
267        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Merged, false).is_ok());
268        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Blocked, false).is_ok());
269    }
270
271    #[test]
272    fn test_invalid_task_transitions() {
273        let sm = StateMachine::<TaskStatus>::new();
274
275        // Invalid: cannot skip states
276        assert!(sm.validate(TaskStatus::Queued, TaskStatus::NeedsReview, false).is_err());
277        assert!(sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).is_err());
278
279        // Invalid: cannot go backward (except blocked -> running)
280        assert!(sm.validate(TaskStatus::Running, TaskStatus::Queued, false).is_err());
281        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Running, false).is_err());
282    }
283
284    #[test]
285    fn test_dropped_from_any_state() {
286        let sm = StateMachine::<TaskStatus>::new();
287
288        // Dropped is reachable from any non-terminal state
289        for status in [
290            TaskStatus::Queued,
291            TaskStatus::Running,
292            TaskStatus::Blocked,
293            TaskStatus::NeedsReview,
294        ] {
295            assert!(sm.validate(status, TaskStatus::Dropped, false).is_ok());
296        }
297    }
298
299    #[test]
300    fn test_terminal_states_cannot_transition() {
301        let sm = StateMachine::<TaskStatus>::new();
302
303        // Cannot transition out of Merged
304        let err = sm.validate(TaskStatus::Merged, TaskStatus::Running, false).unwrap_err();
305        assert!(err.reason.contains("terminal state"));
306
307        // Cannot transition out of Dropped
308        let err = sm.validate(TaskStatus::Dropped, TaskStatus::Running, false).unwrap_err();
309        assert!(err.reason.contains("terminal state"));
310    }
311
312    #[test]
313    fn test_force_bypasses_validation() {
314        let sm = StateMachine::<TaskStatus>::new();
315
316        // Force allows any transition, even from terminal states
317        let result = sm.validate(TaskStatus::Merged, TaskStatus::Running, true);
318        assert!(result.is_ok());
319        assert!(result.unwrap().forced);
320    }
321
322    #[test]
323    fn test_noop_transition_always_valid() {
324        let sm = StateMachine::<TaskStatus>::new();
325
326        for status in [
327            TaskStatus::Queued,
328            TaskStatus::Running,
329            TaskStatus::Blocked,
330            TaskStatus::NeedsReview,
331            TaskStatus::Merged,
332            TaskStatus::Dropped,
333        ] {
334            let result = sm.validate(status, status, false);
335            assert!(result.is_ok());
336            assert!(!result.unwrap().forced);
337        }
338    }
339
340    // -------------------------------------------------------------------------
341    // SessionStatus Tests
342    // -------------------------------------------------------------------------
343
344    #[test]
345    fn test_valid_session_transitions() {
346        let sm = StateMachine::<SessionStatus>::new();
347
348        assert!(sm.validate(SessionStatus::Spawned, SessionStatus::Ready, false).is_ok());
349        assert!(sm.validate(SessionStatus::Ready, SessionStatus::Running, false).is_ok());
350        assert!(sm.validate(SessionStatus::Running, SessionStatus::Handoff, false).is_ok());
351    }
352
353    #[test]
354    fn test_exit_from_any_session_state() {
355        let sm = StateMachine::<SessionStatus>::new();
356
357        // Exit is reachable from any non-terminal state
358        for status in [
359            SessionStatus::Spawned,
360            SessionStatus::Ready,
361            SessionStatus::Running,
362            SessionStatus::Handoff,
363        ] {
364            assert!(sm.validate(status, SessionStatus::Exit, false).is_ok());
365        }
366    }
367
368    #[test]
369    fn test_session_terminal_state() {
370        let sm = StateMachine::<SessionStatus>::new();
371
372        // Cannot transition out of Exit
373        let err = sm.validate(SessionStatus::Exit, SessionStatus::Running, false).unwrap_err();
374        assert!(err.reason.contains("terminal state"));
375    }
376
377    #[test]
378    fn test_invalid_session_transitions() {
379        let sm = StateMachine::<SessionStatus>::new();
380
381        // Cannot skip states
382        assert!(sm.validate(SessionStatus::Spawned, SessionStatus::Running, false).is_err());
383        assert!(sm.validate(SessionStatus::Ready, SessionStatus::Handoff, false).is_err());
384
385        // Cannot go backward
386        assert!(sm.validate(SessionStatus::Running, SessionStatus::Ready, false).is_err());
387        assert!(sm.validate(SessionStatus::Handoff, SessionStatus::Running, false).is_err());
388    }
389
390    #[test]
391    fn test_transition_error_display() {
392        let sm = StateMachine::<TaskStatus>::new();
393        let err = sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).unwrap_err();
394        let msg = err.to_string();
395        assert!(msg.contains("queued"));
396        assert!(msg.contains("merged"));
397        assert!(msg.contains("valid targets"));
398    }
399}