use serde::{Deserialize, Serialize};
/// Identifier for a point at which an agent hook can be attached.
///
/// The variant names suggest paired "before"/"after" stages of an agent run
/// (session setup, rounds, prompt assembly, LLM calls, tool execution, memory
/// recall, compression) plus a single pre-finalize point. The code that fires
/// these points is not in this file, so confirm exact timing against the
/// dispatcher.
///
/// Variants serialize as bare snake_case strings, e.g.
/// `AgentHookPoint::BeforeLlmCall` becomes `"before_llm_call"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum AgentHookPoint {
    // Session-level points.
    BeforeSessionSetup,
    AfterSessionSetup,
    BeforeFinalize,

    // Round-level points.
    BeforeRound,
    AfterRound,

    // Prompt assembly.
    BeforePromptAssembly,
    AfterPromptAssembly,

    // LLM invocation.
    BeforeLlmCall,
    AfterLlmCall,

    // Tool execution.
    BeforeToolExecution,
    AfterToolExecution,

    // Memory recall.
    BeforeMemoryRecall,
    AfterMemoryRecall,

    // Compression.
    BeforeCompression,
    AfterCompression,
}
/// Outcome reported by a hook.
///
/// Serialized in serde's internally tagged form: a JSON object whose `type`
/// field holds the snake_case variant name, with any variant fields next to
/// it, e.g. `{"type":"abort","reason":"..."}`.
///
/// NOTE(review): how callers react to each variant is decided outside this
/// file; confirm the intended control flow against the hook runner.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
#[derive(Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum HookResult {
    /// The value produced by `HookResult::default()`.
    #[default]
    Continue,
    /// Presumably signals that the hook changed data it was given; verify
    /// against callers.
    Mutated,
    /// Suspension outcome carrying a free-form explanation.
    Suspend {
        /// Text describing why the hook asked to suspend.
        reason: String,
    },
    /// Abort outcome carrying a free-form explanation.
    Abort {
        /// Text describing why the hook asked to abort.
        reason: String,
    },
}
#[cfg(test)]
mod tests {
    //! Serialization tests for the hook enums.
    //!
    //! Round-trip checks alone would still pass if the serde attributes were
    //! changed or removed, so these tests also pin the exact wire format.

    use super::*;

    /// Every hook point paired with the JSON string it must serialize to.
    const HOOK_POINT_WIRE_NAMES: [(AgentHookPoint, &str); 15] = [
        (AgentHookPoint::BeforeSessionSetup, "before_session_setup"),
        (AgentHookPoint::AfterSessionSetup, "after_session_setup"),
        (AgentHookPoint::BeforeFinalize, "before_finalize"),
        (AgentHookPoint::BeforeRound, "before_round"),
        (AgentHookPoint::AfterRound, "after_round"),
        (AgentHookPoint::BeforePromptAssembly, "before_prompt_assembly"),
        (AgentHookPoint::AfterPromptAssembly, "after_prompt_assembly"),
        (AgentHookPoint::BeforeLlmCall, "before_llm_call"),
        (AgentHookPoint::AfterLlmCall, "after_llm_call"),
        (AgentHookPoint::BeforeToolExecution, "before_tool_execution"),
        (AgentHookPoint::AfterToolExecution, "after_tool_execution"),
        (AgentHookPoint::BeforeMemoryRecall, "before_memory_recall"),
        (AgentHookPoint::AfterMemoryRecall, "after_memory_recall"),
        (AgentHookPoint::BeforeCompression, "before_compression"),
        (AgentHookPoint::AfterCompression, "after_compression"),
    ];

    /// One sample of each `HookResult` variant with its expected JSON value.
    ///
    /// Expectations are `serde_json::Value`s so the comparison does not depend
    /// on the order in which object keys are emitted.
    fn hook_result_samples() -> Vec<(HookResult, serde_json::Value)> {
        vec![
            (HookResult::Continue, serde_json::json!({ "type": "continue" })),
            (HookResult::Mutated, serde_json::json!({ "type": "mutated" })),
            (
                HookResult::Suspend {
                    reason: "waiting".to_string(),
                },
                serde_json::json!({ "type": "suspend", "reason": "waiting" }),
            ),
            (
                HookResult::Abort {
                    reason: "error".to_string(),
                },
                serde_json::json!({ "type": "abort", "reason": "error" }),
            ),
        ]
    }

    #[test]
    fn hook_point_serialization_round_trip() {
        for (point, _) in &HOOK_POINT_WIRE_NAMES {
            let json = serde_json::to_string(point).unwrap();
            let restored: AgentHookPoint = serde_json::from_str(&json).unwrap();
            assert_eq!(*point, restored);
        }
    }

    #[test]
    fn hook_point_wire_names_are_snake_case() {
        for (point, name) in &HOOK_POINT_WIRE_NAMES {
            let expected = format!("\"{name}\"");

            // Serialization must emit exactly the snake_case name...
            assert_eq!(serde_json::to_string(point).unwrap(), expected);

            // ...and that same name must parse back to the same variant.
            let parsed: AgentHookPoint = serde_json::from_str(&expected).unwrap();
            assert_eq!(parsed, *point);
        }
    }

    #[test]
    fn hook_point_rejects_unknown_or_wrongly_cased_names() {
        // The Rust-side spelling is not part of the wire format.
        assert!(serde_json::from_str::<AgentHookPoint>("\"BeforeRound\"").is_err());
        assert!(serde_json::from_str::<AgentHookPoint>("\"not_a_hook_point\"").is_err());
    }

    #[test]
    fn hook_result_default_is_continue() {
        assert_eq!(HookResult::default(), HookResult::Continue);
    }

    #[test]
    fn hook_result_variants_serialize() {
        for (variant, _) in &hook_result_samples() {
            let json = serde_json::to_string(variant).unwrap();
            let restored: HookResult = serde_json::from_str(&json).unwrap();
            assert_eq!(variant, &restored);
        }
    }

    #[test]
    fn hook_result_wire_format_is_internally_tagged() {
        for (variant, expected) in hook_result_samples() {
            assert_eq!(serde_json::to_value(&variant).unwrap(), expected);

            let parsed: HookResult = serde_json::from_value(expected).unwrap();
            assert_eq!(parsed, variant);
        }
    }

    #[test]
    fn hook_result_rejects_malformed_input() {
        // Unknown tag value.
        assert!(serde_json::from_str::<HookResult>(r#"{"type":"retry"}"#).is_err());
        // Tag missing entirely.
        assert!(serde_json::from_str::<HookResult>(r#"{"reason":"error"}"#).is_err());
        // Variant that requires `reason` but does not supply it.
        assert!(serde_json::from_str::<HookResult>(r#"{"type":"suspend"}"#).is_err());
    }
}