1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
16#[serde(rename_all = "snake_case")]
17pub enum Phase {
18 RunStart,
19 StepStart,
20 BeforeInference,
21 AfterInference,
22 ToolGate,
23 BeforeToolExecute,
24 AfterToolExecute,
25 StepEnd,
26 RunEnd,
27}
28
29impl Phase {
30 pub const ALL: [Phase; 9] = [
32 Phase::RunStart,
33 Phase::StepStart,
34 Phase::BeforeInference,
35 Phase::AfterInference,
36 Phase::ToolGate,
37 Phase::BeforeToolExecute,
38 Phase::AfterToolExecute,
39 Phase::StepEnd,
40 Phase::RunEnd,
41 ];
42
43 pub fn is_run_level(self) -> bool {
45 matches!(self, Phase::RunStart | Phase::RunEnd)
46 }
47
48 pub fn is_step_level(self) -> bool {
50 !self.is_run_level()
51 }
52}
53
54impl std::fmt::Display for Phase {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 Phase::RunStart => write!(f, "RunStart"),
58 Phase::StepStart => write!(f, "StepStart"),
59 Phase::BeforeInference => write!(f, "BeforeInference"),
60 Phase::AfterInference => write!(f, "AfterInference"),
61 Phase::ToolGate => write!(f, "ToolGate"),
62 Phase::BeforeToolExecute => write!(f, "BeforeToolExecute"),
63 Phase::AfterToolExecute => write!(f, "AfterToolExecute"),
64 Phase::StepEnd => write!(f, "StepEnd"),
65 Phase::RunEnd => write!(f, "RunEnd"),
66 }
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn phase_all_has_9_variants() {
        assert_eq!(Phase::ALL.len(), 9);
    }

    #[test]
    fn phase_all_order_matches_lifecycle() {
        assert_eq!(
            Phase::ALL,
            [
                Phase::RunStart,
                Phase::StepStart,
                Phase::BeforeInference,
                Phase::AfterInference,
                Phase::ToolGate,
                Phase::BeforeToolExecute,
                Phase::AfterToolExecute,
                Phase::StepEnd,
                Phase::RunEnd,
            ]
        );
    }

    #[test]
    fn phase_run_level_vs_step_level() {
        // Both run-level phases are checked in both directions (RunEnd's
        // step-level check was previously missing).
        for phase in [Phase::RunStart, Phase::RunEnd] {
            assert!(phase.is_run_level(), "{phase} should be run-level");
            assert!(!phase.is_step_level(), "{phase} should not be step-level");
        }

        for phase in [
            Phase::StepStart,
            Phase::BeforeInference,
            Phase::AfterInference,
            Phase::ToolGate,
            Phase::BeforeToolExecute,
            Phase::AfterToolExecute,
            Phase::StepEnd,
        ] {
            assert!(phase.is_step_level(), "{phase} should be step-level");
            assert!(!phase.is_run_level(), "{phase} should not be run-level");
        }
    }

    #[test]
    fn phase_serde_roundtrip() {
        for phase in Phase::ALL {
            let json = serde_json::to_string(&phase).unwrap();
            let parsed: Phase = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, phase);
        }
    }

    #[test]
    fn phase_display() {
        assert_eq!(Phase::StepStart.to_string(), "StepStart");
        assert_eq!(Phase::BeforeInference.to_string(), "BeforeInference");
    }

    #[test]
    fn phase_serde_snake_case() {
        let expected = [
            "\"run_start\"",
            "\"step_start\"",
            "\"before_inference\"",
            "\"after_inference\"",
            "\"tool_gate\"",
            "\"before_tool_execute\"",
            "\"after_tool_execute\"",
            "\"step_end\"",
            "\"run_end\"",
        ];
        // `zip` stops at the shorter side, so guard against the tables drifting apart.
        assert_eq!(Phase::ALL.len(), expected.len());
        for (phase, json) in Phase::ALL.iter().zip(expected.iter()) {
            assert_eq!(serde_json::to_string(phase).unwrap(), *json);
        }

        // The PascalCase `Display` form is not a valid serde form.
        assert!(serde_json::from_str::<Phase>("\"RunStart\"").is_err());
    }

    #[test]
    fn phase_display_all_variants() {
        let expected = [
            "RunStart",
            "StepStart",
            "BeforeInference",
            "AfterInference",
            "ToolGate",
            "BeforeToolExecute",
            "AfterToolExecute",
            "StepEnd",
            "RunEnd",
        ];
        // `zip` stops at the shorter side, so guard against the tables drifting apart.
        assert_eq!(Phase::ALL.len(), expected.len());
        for (phase, name) in Phase::ALL.iter().zip(expected.iter()) {
            assert_eq!(phase.to_string(), *name);
        }
    }

    #[test]
    fn phase_equality_and_hash() {
        assert_eq!(Phase::RunStart, Phase::RunStart);
        assert_ne!(Phase::RunStart, Phase::RunEnd);

        let mut set = std::collections::HashSet::new();
        for phase in Phase::ALL {
            assert!(set.insert(phase), "duplicate phase: {phase}");
        }
        assert_eq!(set.len(), 9);
    }

    #[test]
    fn phase_copy_and_clone() {
        let phase = Phase::BeforeInference;
        let copied = phase;
        // Deliberately call the derived `Clone` impl; a plain assignment only exercises `Copy`.
        #[allow(clippy::clone_on_copy)]
        let cloned = phase.clone();
        assert_eq!(phase, copied);
        assert_eq!(phase, cloned);
    }
}
180}