Skip to main content

awaken_contract/model/
phase.rs

1use serde::{Deserialize, Serialize};
2
3/// Lifecycle phase within an agent run.
4///
5/// # Examples
6///
7/// ```
8/// use awaken_contract::Phase;
9///
10/// assert!(Phase::RunStart.is_run_level());
11/// assert!(!Phase::RunStart.is_step_level());
12/// assert!(Phase::BeforeInference.is_step_level());
13/// assert_eq!(Phase::ALL.len(), 9);
14/// ```
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum Phase {
18    RunStart,
19    StepStart,
20    BeforeInference,
21    AfterInference,
22    ToolGate,
23    BeforeToolExecute,
24    AfterToolExecute,
25    StepEnd,
26    RunEnd,
27}
28
29impl Phase {
30    /// All phases in execution order.
31    pub const ALL: [Phase; 9] = [
32        Phase::RunStart,
33        Phase::StepStart,
34        Phase::BeforeInference,
35        Phase::AfterInference,
36        Phase::ToolGate,
37        Phase::BeforeToolExecute,
38        Phase::AfterToolExecute,
39        Phase::StepEnd,
40        Phase::RunEnd,
41    ];
42
43    /// Whether this phase runs once per run (not per step).
44    pub fn is_run_level(self) -> bool {
45        matches!(self, Phase::RunStart | Phase::RunEnd)
46    }
47
48    /// Whether this phase runs within the step loop.
49    pub fn is_step_level(self) -> bool {
50        !self.is_run_level()
51    }
52}
53
54impl std::fmt::Display for Phase {
55    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56        match self {
57            Phase::RunStart => write!(f, "RunStart"),
58            Phase::StepStart => write!(f, "StepStart"),
59            Phase::BeforeInference => write!(f, "BeforeInference"),
60            Phase::AfterInference => write!(f, "AfterInference"),
61            Phase::ToolGate => write!(f, "ToolGate"),
62            Phase::BeforeToolExecute => write!(f, "BeforeToolExecute"),
63            Phase::AfterToolExecute => write!(f, "AfterToolExecute"),
64            Phase::StepEnd => write!(f, "StepEnd"),
65            Phase::RunEnd => write!(f, "RunEnd"),
66        }
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn phase_all_has_9_variants() {
        assert_eq!(Phase::ALL.len(), 9);
    }

    #[test]
    fn phase_all_order_matches_lifecycle() {
        let order = Phase::ALL;
        assert_eq!(order[0], Phase::RunStart);
        assert_eq!(order[1], Phase::StepStart);
        assert_eq!(order[2], Phase::BeforeInference);
        assert_eq!(order[3], Phase::AfterInference);
        assert_eq!(order[4], Phase::ToolGate);
        assert_eq!(order[5], Phase::BeforeToolExecute);
        assert_eq!(order[6], Phase::AfterToolExecute);
        assert_eq!(order[7], Phase::StepEnd);
        assert_eq!(order[8], Phase::RunEnd);
    }

    #[test]
    fn phase_run_level_vs_step_level() {
        assert!(Phase::RunStart.is_run_level());
        assert!(Phase::RunEnd.is_run_level());
        assert!(!Phase::RunStart.is_step_level());
        assert!(!Phase::RunEnd.is_step_level());

        for phase in [
            Phase::StepStart,
            Phase::BeforeInference,
            Phase::AfterInference,
            Phase::ToolGate,
            Phase::BeforeToolExecute,
            Phase::AfterToolExecute,
            Phase::StepEnd,
        ] {
            assert!(phase.is_step_level(), "{phase} should be step-level");
            assert!(!phase.is_run_level(), "{phase} should not be run-level");
        }
    }

    #[test]
    fn phase_levels_partition_all() {
        // Every phase must be exactly one of run-level / step-level.
        for phase in Phase::ALL {
            assert_ne!(
                phase.is_run_level(),
                phase.is_step_level(),
                "{phase} must be exactly one of run-level or step-level"
            );
        }
        let run_level = Phase::ALL.iter().filter(|p| p.is_run_level()).count();
        assert_eq!(run_level, 2);
    }

    #[test]
    fn phase_serde_roundtrip() {
        for phase in Phase::ALL {
            let json = serde_json::to_string(&phase).unwrap();
            let parsed: Phase = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, phase);
        }
    }

    #[test]
    fn phase_display() {
        assert_eq!(Phase::StepStart.to_string(), "StepStart");
        assert_eq!(Phase::BeforeInference.to_string(), "BeforeInference");
    }

    #[test]
    fn phase_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&Phase::StepStart).unwrap(),
            "\"step_start\""
        );
        assert_eq!(
            serde_json::to_string(&Phase::BeforeToolExecute).unwrap(),
            "\"before_tool_execute\""
        );
        assert_eq!(
            serde_json::to_string(&Phase::ToolGate).unwrap(),
            "\"tool_gate\""
        );
    }

    #[test]
    fn phase_serde_wire_names_all_variants() {
        // Pins the exact wire name of every variant in both directions.
        let expected = [
            "run_start",
            "step_start",
            "before_inference",
            "after_inference",
            "tool_gate",
            "before_tool_execute",
            "after_tool_execute",
            "step_end",
            "run_end",
        ];
        // `zip` stops at the shorter side, so guard against a silently partial check.
        assert_eq!(expected.len(), Phase::ALL.len());
        for (&phase, &name) in Phase::ALL.iter().zip(expected.iter()) {
            let json = format!("\"{name}\"");
            assert_eq!(serde_json::to_string(&phase).unwrap(), json);
            let parsed: Phase = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, phase);
        }
    }

    #[test]
    fn phase_deserialize_rejects_unknown_names() {
        // With `rename_all = "snake_case"`, the PascalCase Display form is not a valid wire name.
        assert!(serde_json::from_str::<Phase>("\"RunStart\"").is_err());
        assert!(serde_json::from_str::<Phase>("\"not_a_phase\"").is_err());
    }

    #[test]
    fn phase_display_all_variants() {
        let expected = [
            "RunStart",
            "StepStart",
            "BeforeInference",
            "AfterInference",
            "ToolGate",
            "BeforeToolExecute",
            "AfterToolExecute",
            "StepEnd",
            "RunEnd",
        ];
        // `zip` stops at the shorter side, so guard against a silently partial check.
        assert_eq!(expected.len(), Phase::ALL.len());
        for (phase, name) in Phase::ALL.iter().zip(expected.iter()) {
            assert_eq!(phase.to_string(), *name);
        }
    }

    #[test]
    fn phase_display_matches_variant_name() {
        for phase in Phase::ALL {
            assert_eq!(phase.to_string(), format!("{phase:?}"));
        }
    }

    #[test]
    fn phase_equality_and_hash() {
        assert_eq!(Phase::RunStart, Phase::RunStart);
        assert_ne!(Phase::RunStart, Phase::RunEnd);

        let mut set = std::collections::HashSet::new();
        for phase in Phase::ALL {
            assert!(set.insert(phase), "duplicate phase: {phase}");
        }
        assert_eq!(set.len(), 9);
    }

    #[test]
    fn phase_is_copy() {
        // `Phase` is `Copy`: assigning it leaves the original binding usable.
        let phase = Phase::BeforeInference;
        let copied = phase;
        assert_eq!(phase, copied);
    }
}