awaken-contract 0.4.0

Core types, traits, and state model for the Awaken AI agent runtime
Documentation
use serde::{Deserialize, Serialize};

/// Lifecycle phase within an agent run.
///
/// # Examples
///
/// ```
/// use awaken_contract::Phase;
///
/// assert!(Phase::RunStart.is_run_level());
/// assert!(!Phase::RunStart.is_step_level());
/// assert!(Phase::BeforeInference.is_step_level());
/// assert_eq!(Phase::ALL.len(), 9);
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Phase {
    RunStart,
    StepStart,
    BeforeInference,
    AfterInference,
    ToolGate,
    BeforeToolExecute,
    AfterToolExecute,
    StepEnd,
    RunEnd,
}

impl Phase {
    /// All phases in execution order.
    pub const ALL: [Phase; 9] = [
        Phase::RunStart,
        Phase::StepStart,
        Phase::BeforeInference,
        Phase::AfterInference,
        Phase::ToolGate,
        Phase::BeforeToolExecute,
        Phase::AfterToolExecute,
        Phase::StepEnd,
        Phase::RunEnd,
    ];

    /// Whether this phase runs once per run (not per step).
    pub fn is_run_level(self) -> bool {
        matches!(self, Phase::RunStart | Phase::RunEnd)
    }

    /// Whether this phase runs within the step loop.
    pub fn is_step_level(self) -> bool {
        !self.is_run_level()
    }
}

impl std::fmt::Display for Phase {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Phase::RunStart => write!(f, "RunStart"),
            Phase::StepStart => write!(f, "StepStart"),
            Phase::BeforeInference => write!(f, "BeforeInference"),
            Phase::AfterInference => write!(f, "AfterInference"),
            Phase::ToolGate => write!(f, "ToolGate"),
            Phase::BeforeToolExecute => write!(f, "BeforeToolExecute"),
            Phase::AfterToolExecute => write!(f, "AfterToolExecute"),
            Phase::StepEnd => write!(f, "StepEnd"),
            Phase::RunEnd => write!(f, "RunEnd"),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn phase_all_has_9_variants() {
        assert_eq!(Phase::ALL.len(), 9);
    }

    #[test]
    fn phase_all_order_matches_lifecycle() {
        let order = Phase::ALL;
        assert_eq!(order[0], Phase::RunStart);
        assert_eq!(order[1], Phase::StepStart);
        assert_eq!(order[2], Phase::BeforeInference);
        assert_eq!(order[3], Phase::AfterInference);
        assert_eq!(order[4], Phase::ToolGate);
        assert_eq!(order[5], Phase::BeforeToolExecute);
        assert_eq!(order[6], Phase::AfterToolExecute);
        assert_eq!(order[7], Phase::StepEnd);
        assert_eq!(order[8], Phase::RunEnd);
    }

    #[test]
    fn phase_run_level_vs_step_level() {
        assert!(Phase::RunStart.is_run_level());
        assert!(Phase::RunEnd.is_run_level());
        assert!(!Phase::RunStart.is_step_level());

        for phase in [
            Phase::StepStart,
            Phase::BeforeInference,
            Phase::AfterInference,
            Phase::ToolGate,
            Phase::BeforeToolExecute,
            Phase::AfterToolExecute,
            Phase::StepEnd,
        ] {
            assert!(phase.is_step_level(), "{phase} should be step-level");
            assert!(!phase.is_run_level(), "{phase} should not be run-level");
        }
    }

    #[test]
    fn phase_serde_roundtrip() {
        for phase in Phase::ALL {
            let json = serde_json::to_string(&phase).unwrap();
            let parsed: Phase = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, phase);
        }
    }

    #[test]
    fn phase_display() {
        assert_eq!(Phase::StepStart.to_string(), "StepStart");
        assert_eq!(Phase::BeforeInference.to_string(), "BeforeInference");
    }

    #[test]
    fn phase_serde_snake_case() {
        assert_eq!(
            serde_json::to_string(&Phase::StepStart).unwrap(),
            "\"step_start\""
        );
        assert_eq!(
            serde_json::to_string(&Phase::BeforeToolExecute).unwrap(),
            "\"before_tool_execute\""
        );
        assert_eq!(
            serde_json::to_string(&Phase::ToolGate).unwrap(),
            "\"tool_gate\""
        );
    }

    #[test]
    fn phase_display_all_variants() {
        let expected = [
            "RunStart",
            "StepStart",
            "BeforeInference",
            "AfterInference",
            "ToolGate",
            "BeforeToolExecute",
            "AfterToolExecute",
            "StepEnd",
            "RunEnd",
        ];
        for (phase, name) in Phase::ALL.iter().zip(expected.iter()) {
            assert_eq!(phase.to_string(), *name);
        }
    }

    #[test]
    fn phase_equality_and_hash() {
        assert_eq!(Phase::RunStart, Phase::RunStart);
        assert_ne!(Phase::RunStart, Phase::RunEnd);

        let mut set = std::collections::HashSet::new();
        for phase in Phase::ALL {
            assert!(set.insert(phase), "duplicate phase: {phase}");
        }
        assert_eq!(set.len(), 9);
    }

    #[test]
    fn phase_clone() {
        let phase = Phase::BeforeInference;
        let cloned = phase;
        assert_eq!(phase, cloned);
    }
}