use serde::{Deserialize, Serialize};
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Phase {
RunStart,
StepStart,
BeforeInference,
AfterInference,
ToolGate,
BeforeToolExecute,
AfterToolExecute,
StepEnd,
RunEnd,
}
impl Phase {
pub const ALL: [Phase; 9] = [
Phase::RunStart,
Phase::StepStart,
Phase::BeforeInference,
Phase::AfterInference,
Phase::ToolGate,
Phase::BeforeToolExecute,
Phase::AfterToolExecute,
Phase::StepEnd,
Phase::RunEnd,
];
pub fn is_run_level(self) -> bool {
matches!(self, Phase::RunStart | Phase::RunEnd)
}
pub fn is_step_level(self) -> bool {
!self.is_run_level()
}
}
impl std::fmt::Display for Phase {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Phase::RunStart => write!(f, "RunStart"),
Phase::StepStart => write!(f, "StepStart"),
Phase::BeforeInference => write!(f, "BeforeInference"),
Phase::AfterInference => write!(f, "AfterInference"),
Phase::ToolGate => write!(f, "ToolGate"),
Phase::BeforeToolExecute => write!(f, "BeforeToolExecute"),
Phase::AfterToolExecute => write!(f, "AfterToolExecute"),
Phase::StepEnd => write!(f, "StepEnd"),
Phase::RunEnd => write!(f, "RunEnd"),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn phase_all_has_9_variants() {
assert_eq!(Phase::ALL.len(), 9);
}
#[test]
fn phase_all_order_matches_lifecycle() {
let order = Phase::ALL;
assert_eq!(order[0], Phase::RunStart);
assert_eq!(order[1], Phase::StepStart);
assert_eq!(order[2], Phase::BeforeInference);
assert_eq!(order[3], Phase::AfterInference);
assert_eq!(order[4], Phase::ToolGate);
assert_eq!(order[5], Phase::BeforeToolExecute);
assert_eq!(order[6], Phase::AfterToolExecute);
assert_eq!(order[7], Phase::StepEnd);
assert_eq!(order[8], Phase::RunEnd);
}
#[test]
fn phase_run_level_vs_step_level() {
assert!(Phase::RunStart.is_run_level());
assert!(Phase::RunEnd.is_run_level());
assert!(!Phase::RunStart.is_step_level());
for phase in [
Phase::StepStart,
Phase::BeforeInference,
Phase::AfterInference,
Phase::ToolGate,
Phase::BeforeToolExecute,
Phase::AfterToolExecute,
Phase::StepEnd,
] {
assert!(phase.is_step_level(), "{phase} should be step-level");
assert!(!phase.is_run_level(), "{phase} should not be run-level");
}
}
#[test]
fn phase_serde_roundtrip() {
for phase in Phase::ALL {
let json = serde_json::to_string(&phase).unwrap();
let parsed: Phase = serde_json::from_str(&json).unwrap();
assert_eq!(parsed, phase);
}
}
#[test]
fn phase_display() {
assert_eq!(Phase::StepStart.to_string(), "StepStart");
assert_eq!(Phase::BeforeInference.to_string(), "BeforeInference");
}
#[test]
fn phase_serde_snake_case() {
assert_eq!(
serde_json::to_string(&Phase::StepStart).unwrap(),
"\"step_start\""
);
assert_eq!(
serde_json::to_string(&Phase::BeforeToolExecute).unwrap(),
"\"before_tool_execute\""
);
assert_eq!(
serde_json::to_string(&Phase::ToolGate).unwrap(),
"\"tool_gate\""
);
}
#[test]
fn phase_display_all_variants() {
let expected = [
"RunStart",
"StepStart",
"BeforeInference",
"AfterInference",
"ToolGate",
"BeforeToolExecute",
"AfterToolExecute",
"StepEnd",
"RunEnd",
];
for (phase, name) in Phase::ALL.iter().zip(expected.iter()) {
assert_eq!(phase.to_string(), *name);
}
}
#[test]
fn phase_equality_and_hash() {
assert_eq!(Phase::RunStart, Phase::RunStart);
assert_ne!(Phase::RunStart, Phase::RunEnd);
let mut set = std::collections::HashSet::new();
for phase in Phase::ALL {
assert!(set.insert(phase), "duplicate phase: {phase}");
}
assert_eq!(set.len(), 9);
}
#[test]
fn phase_clone() {
let phase = Phase::BeforeInference;
let cloned = phase;
assert_eq!(phase, cloned);
}
}