1use std::fmt::{Debug, Display};
8use std::hash::Hash;
9
10use crate::types::{SessionStatus, TaskStatus};
11
/// Error returned by [`StateMachine::validate`] when a transition is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TransitionError<S> {
    /// State the rejected transition started from.
    pub from: S,
    /// State the rejected transition was trying to reach.
    pub to: S,
    /// Human-readable explanation of why the transition was rejected.
    pub reason: String,
}
19
20impl<S: Display> Display for TransitionError<S> {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 write!(
23 f,
24 "invalid transition from '{}' to '{}': {}",
25 self.from, self.to, self.reason
26 )
27 }
28}
29
30impl<S: Debug + Display> std::error::Error for TransitionError<S> {}
31
/// A transition accepted by [`StateMachine::validate`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Transition<S> {
    /// State before the transition.
    pub from: S,
    /// State after the transition.
    pub to: S,
    /// `true` only when the caller passed `force = true`, in which case the
    /// transition was accepted without checking any rules.
    pub forced: bool,
}
39
/// A status type whose transitions can be checked by [`StateMachine`].
pub trait State: Copy + Clone + PartialEq + Eq + Hash + Debug + Display {
    /// Returns `true` if the state cannot be left without forcing.
    ///
    /// [`StateMachine::validate`] still accepts a no-op transition
    /// (`from == to`) out of a terminal state.
    fn is_terminal(&self) -> bool;

    /// Returns `true` if the state may be entered from any non-terminal
    /// state, whether or not it appears in [`State::valid_targets`].
    fn is_universal_target(&self) -> bool;

    /// Returns the states directly reachable from this one.
    ///
    /// Universal targets need not be listed; [`StateMachine::validate`]
    /// accepts them separately.
    fn valid_targets(&self) -> &'static [Self];
}
51
/// Validator for transitions between values of a [`State`] type `S`.
///
/// The machine stores no data (it is zero-sized); every rule comes from the
/// `State` implementation of `S`. The `S: State` bound lives on the `impl`
/// blocks rather than on the struct, so `StateMachine<S>` can be named in
/// generic code without restating the bound.
#[derive(Debug, Clone, Copy)]
pub struct StateMachine<S> {
    // Ties the machine to its state type without storing a value of it.
    _marker: std::marker::PhantomData<S>,
}
57
58impl<S: State + 'static> StateMachine<S> {
59 pub fn new() -> Self {
60 Self {
61 _marker: std::marker::PhantomData,
62 }
63 }
64
65 pub fn validate(
70 &self,
71 from: S,
72 to: S,
73 force: bool,
74 ) -> Result<Transition<S>, TransitionError<S>> {
75 if force {
77 return Ok(Transition {
78 from,
79 to,
80 forced: true,
81 });
82 }
83
84 if from == to {
86 return Ok(Transition {
87 from,
88 to,
89 forced: false,
90 });
91 }
92
93 if from.is_terminal() {
95 return Err(TransitionError {
96 from,
97 to,
98 reason: format!("'{}' is a terminal state", from),
99 });
100 }
101
102 if to.is_universal_target() {
104 return Ok(Transition {
105 from,
106 to,
107 forced: false,
108 });
109 }
110
111 if from.valid_targets().contains(&to) {
113 return Ok(Transition {
114 from,
115 to,
116 forced: false,
117 });
118 }
119
120 Err(TransitionError {
121 from,
122 to,
123 reason: format!(
124 "valid targets from '{}' are: {}",
125 from,
126 format_targets(from.valid_targets())
127 ),
128 })
129 }
130}
131
132impl<S: State + 'static> Default for StateMachine<S> {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
/// Formats states as a comma-separated list of quoted names, e.g.
/// `'running', 'blocked'`, for use in a [`TransitionError`] reason.
///
/// An empty slice yields `"none"`. `StateMachine::validate` rejects
/// transitions out of terminal states before it formats any targets. An empty
/// list here therefore belongs to a non-terminal state with no listed forward
/// targets (one that can only be left through a universal target). The old
/// label "none (terminal state)" was wrong for those states.
fn format_targets<S: Display>(targets: &[S]) -> String {
    if targets.is_empty() {
        return "none".to_string();
    }
    targets
        .iter()
        .map(|t| format!("'{}'", t))
        .collect::<Vec<_>>()
        .join(", ")
}
150
151impl State for TaskStatus {
156 fn is_terminal(&self) -> bool {
157 matches!(self, TaskStatus::Merged | TaskStatus::Dropped)
158 }
159
160 fn is_universal_target(&self) -> bool {
161 matches!(self, TaskStatus::Dropped)
162 }
163
164 fn valid_targets(&self) -> &'static [Self] {
165 match self {
166 TaskStatus::Queued => &[TaskStatus::Running],
168
169 TaskStatus::Running => &[TaskStatus::Blocked, TaskStatus::NeedsReview],
172
173 TaskStatus::Blocked => &[TaskStatus::Running],
175
176 TaskStatus::NeedsReview => &[TaskStatus::Merged, TaskStatus::Blocked],
179
180 TaskStatus::Merged => &[],
182 TaskStatus::Dropped => &[],
183 }
184 }
185}
186
187impl Display for TaskStatus {
188 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189 match self {
190 TaskStatus::Queued => write!(f, "queued"),
191 TaskStatus::Running => write!(f, "running"),
192 TaskStatus::Blocked => write!(f, "blocked"),
193 TaskStatus::NeedsReview => write!(f, "needs-review"),
194 TaskStatus::Merged => write!(f, "merged"),
195 TaskStatus::Dropped => write!(f, "dropped"),
196 }
197 }
198}
199
200impl State for SessionStatus {
205 fn is_terminal(&self) -> bool {
206 matches!(self, SessionStatus::Exit)
207 }
208
209 fn is_universal_target(&self) -> bool {
210 matches!(self, SessionStatus::Exit)
212 }
213
214 fn valid_targets(&self) -> &'static [Self] {
215 match self {
216 SessionStatus::Spawned => &[SessionStatus::Ready],
218
219 SessionStatus::Ready => &[SessionStatus::Running],
221
222 SessionStatus::Running => &[SessionStatus::Handoff],
224
225 SessionStatus::Handoff => &[],
227
228 SessionStatus::Exit => &[],
230 }
231 }
232}
233
234impl Display for SessionStatus {
235 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
236 match self {
237 SessionStatus::Spawned => write!(f, "spawned"),
238 SessionStatus::Ready => write!(f, "ready"),
239 SessionStatus::Running => write!(f, "running"),
240 SessionStatus::Handoff => write!(f, "handoff"),
241 SessionStatus::Exit => write!(f, "exit"),
242 }
243 }
244}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// Every edge listed in `TaskStatus::valid_targets` is accepted.
    #[test]
    fn test_valid_task_transitions() {
        let sm = StateMachine::<TaskStatus>::new();

        assert!(sm.validate(TaskStatus::Queued, TaskStatus::Running, false).is_ok());
        assert!(sm.validate(TaskStatus::Running, TaskStatus::Blocked, false).is_ok());
        assert!(sm.validate(TaskStatus::Running, TaskStatus::NeedsReview, false).is_ok());
        assert!(sm.validate(TaskStatus::Blocked, TaskStatus::Running, false).is_ok());
        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Merged, false).is_ok());
        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Blocked, false).is_ok());
    }

    /// An accepted transition carries the requested endpoints and is not forced.
    #[test]
    fn test_accepted_transition_contents() {
        let sm = StateMachine::<TaskStatus>::new();

        assert_eq!(
            sm.validate(TaskStatus::Queued, TaskStatus::Running, false),
            Ok(Transition {
                from: TaskStatus::Queued,
                to: TaskStatus::Running,
                forced: false,
            })
        );
    }

    /// Skipping states or moving backwards is rejected.
    #[test]
    fn test_invalid_task_transitions() {
        let sm = StateMachine::<TaskStatus>::new();

        assert!(sm.validate(TaskStatus::Queued, TaskStatus::NeedsReview, false).is_err());
        assert!(sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).is_err());

        assert!(sm.validate(TaskStatus::Running, TaskStatus::Queued, false).is_err());
        assert!(sm.validate(TaskStatus::NeedsReview, TaskStatus::Running, false).is_err());
    }

    #[test]
    fn test_dropped_from_any_state() {
        let sm = StateMachine::<TaskStatus>::new();

        for status in [
            TaskStatus::Queued,
            TaskStatus::Running,
            TaskStatus::Blocked,
            TaskStatus::NeedsReview,
        ] {
            assert!(sm.validate(status, TaskStatus::Dropped, false).is_ok());
        }
    }

    #[test]
    fn test_terminal_states_cannot_transition() {
        let sm = StateMachine::<TaskStatus>::new();

        let err = sm.validate(TaskStatus::Merged, TaskStatus::Running, false).unwrap_err();
        assert!(err.reason.contains("terminal state"));

        let err = sm.validate(TaskStatus::Dropped, TaskStatus::Running, false).unwrap_err();
        assert!(err.reason.contains("terminal state"));
    }

    /// The terminal check runs before the universal-target check, so a
    /// terminal state cannot even move to the universal target.
    #[test]
    fn test_terminal_state_cannot_reach_universal_target() {
        let sm = StateMachine::<TaskStatus>::new();

        let err = sm.validate(TaskStatus::Merged, TaskStatus::Dropped, false).unwrap_err();
        assert!(err.reason.contains("terminal state"));
    }

    #[test]
    fn test_force_bypasses_validation() {
        let sm = StateMachine::<TaskStatus>::new();

        let result = sm.validate(TaskStatus::Merged, TaskStatus::Running, true);
        assert!(result.is_ok());
        assert!(result.unwrap().forced);
    }

    /// Forcing takes precedence over the no-op rule: the result is still marked forced.
    #[test]
    fn test_forced_noop_is_marked_forced() {
        let sm = StateMachine::<TaskStatus>::new();

        let transition = sm.validate(TaskStatus::Queued, TaskStatus::Queued, true).unwrap();
        assert!(transition.forced);
    }

    #[test]
    fn test_noop_transition_always_valid() {
        let sm = StateMachine::<TaskStatus>::new();

        for status in [
            TaskStatus::Queued,
            TaskStatus::Running,
            TaskStatus::Blocked,
            TaskStatus::NeedsReview,
            TaskStatus::Merged,
            TaskStatus::Dropped,
        ] {
            let result = sm.validate(status, status, false);
            assert!(result.is_ok());
            assert!(!result.unwrap().forced);
        }
    }

    /// Every edge listed in `SessionStatus::valid_targets` is accepted.
    #[test]
    fn test_valid_session_transitions() {
        let sm = StateMachine::<SessionStatus>::new();

        assert!(sm.validate(SessionStatus::Spawned, SessionStatus::Ready, false).is_ok());
        assert!(sm.validate(SessionStatus::Ready, SessionStatus::Running, false).is_ok());
        assert!(sm.validate(SessionStatus::Running, SessionStatus::Handoff, false).is_ok());
    }

    #[test]
    fn test_exit_from_any_session_state() {
        let sm = StateMachine::<SessionStatus>::new();

        for status in [
            SessionStatus::Spawned,
            SessionStatus::Ready,
            SessionStatus::Running,
            SessionStatus::Handoff,
        ] {
            assert!(sm.validate(status, SessionStatus::Exit, false).is_ok());
        }
    }

    #[test]
    fn test_session_terminal_state() {
        let sm = StateMachine::<SessionStatus>::new();

        let err = sm.validate(SessionStatus::Exit, SessionStatus::Running, false).unwrap_err();
        assert!(err.reason.contains("terminal state"));
    }

    #[test]
    fn test_invalid_session_transitions() {
        let sm = StateMachine::<SessionStatus>::new();

        assert!(sm.validate(SessionStatus::Spawned, SessionStatus::Running, false).is_err());
        assert!(sm.validate(SessionStatus::Ready, SessionStatus::Handoff, false).is_err());

        assert!(sm.validate(SessionStatus::Running, SessionStatus::Ready, false).is_err());
        assert!(sm.validate(SessionStatus::Handoff, SessionStatus::Running, false).is_err());
    }

    #[test]
    fn test_transition_error_display() {
        let sm = StateMachine::<TaskStatus>::new();
        let err = sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("queued"));
        assert!(msg.contains("merged"));
        assert!(msg.contains("valid targets"));
    }

    /// Pins the full rendered text of both rejection kinds.
    #[test]
    fn test_transition_error_exact_messages() {
        let sm = StateMachine::<TaskStatus>::new();

        let err = sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid transition from 'queued' to 'merged': valid targets from 'queued' are: 'running'"
        );

        let err = sm.validate(TaskStatus::Merged, TaskStatus::Running, false).unwrap_err();
        assert_eq!(
            err.to_string(),
            "invalid transition from 'merged' to 'running': 'merged' is a terminal state"
        );
    }

    /// `TransitionError` can be used wherever a `std::error::Error` is expected.
    #[test]
    fn test_transition_error_is_std_error() {
        let sm = StateMachine::<TaskStatus>::new();
        let err = sm.validate(TaskStatus::Queued, TaskStatus::Merged, false).unwrap_err();
        let as_error: &dyn std::error::Error = &err;
        assert!(as_error.source().is_none());
    }
}