use std::time::Duration;
use async_trait::async_trait;
use crate::agent::AgentOutput;
use crate::types::completion::{CompletionRequest, CompletionResponse};
/// Decision returned by [`AgentHook::before_tool_execute`].
///
/// Derives `PartialEq`/`Eq` so decisions can be compared directly (e.g. with `assert_eq!`)
/// instead of only through `matches!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HookAction {
    /// Allow the tool call to go ahead; this is what the default
    /// `before_tool_execute` implementation returns.
    Continue,
    /// Ask for the tool call to be blocked. The `String` is the reason supplied by the hook;
    /// how the caller reports it is decided outside this file.
    Block(String),
}
/// Observation and interception points for an agent run.
///
/// Every method ships with a default body: the notification methods are no-ops and
/// [`AgentHook::before_tool_execute`] answers [`HookAction::Continue`]. An implementor
/// therefore overrides only the events it cares about. The code that invokes these
/// methods (and decides when) lives outside this file.
///
/// The `Send + Sync + 'static` supertraits allow a hook to be shared between tasks and
/// threads, e.g. behind an `Arc`.
#[async_trait]
pub trait AgentHook: Send + Sync + 'static {
    /// Receives the agent's input text. Default: does nothing.
    async fn on_agent_start(&self, _input: &str) {}

    /// Receives the agent's final output and the elapsed time. Default: does nothing.
    async fn on_agent_end(&self, _output: &AgentOutput, _duration: Duration) {}

    /// Receives the completion request about to be sent to the provider. Default: does nothing.
    async fn on_provider_start(&self, _request: &CompletionRequest) {}

    /// Receives the provider's completion response and the elapsed time. Default: does nothing.
    async fn on_provider_end(&self, _response: &CompletionResponse, _duration: Duration) {}

    /// Consulted with a tool's name and JSON arguments; the returned [`HookAction`] is the
    /// hook's verdict (presumably `Block` stops the call — confirm with the caller).
    /// Default: [`HookAction::Continue`].
    async fn before_tool_execute(&self, _name: &str, _args: &serde_json::Value) -> HookAction {
        HookAction::Continue
    }

    /// Receives a tool's name, its textual result, and the elapsed time. Default: does nothing.
    async fn after_tool_execute(&self, _name: &str, _result: &str, _duration: Duration) {}

    /// Receives one chunk of streamed text. Default: does nothing.
    async fn on_stream_chunk(&self, _chunk: &str) {}

    /// Receives a crate error. Default: does nothing.
    async fn on_error(&self, _error: &crate::Error) {}
}
/// Delegates every hook method to the hook inside the `Arc`, so a shared hook can be used
/// anywhere an [`AgentHook`] is accepted.
///
/// `T` is `?Sized`, so trait objects are covered as well: `Arc<dyn AgentHook>` is itself an
/// `AgentHook`. Each call is written as `T::method(&**self, ..)` to make explicit that it
/// dispatches to the inner hook's implementation rather than back into this one.
#[async_trait]
impl<T: AgentHook + ?Sized> AgentHook for std::sync::Arc<T> {
    async fn on_agent_start(&self, input: &str) {
        T::on_agent_start(&**self, input).await;
    }

    async fn on_agent_end(&self, output: &AgentOutput, duration: Duration) {
        T::on_agent_end(&**self, output, duration).await;
    }

    async fn on_provider_start(&self, request: &CompletionRequest) {
        T::on_provider_start(&**self, request).await;
    }

    async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
        T::on_provider_end(&**self, response, duration).await;
    }

    async fn before_tool_execute(&self, name: &str, args: &serde_json::Value) -> HookAction {
        T::before_tool_execute(&**self, name, args).await
    }

    async fn after_tool_execute(&self, name: &str, result: &str, duration: Duration) {
        T::after_tool_execute(&**self, name, result, duration).await;
    }

    async fn on_stream_chunk(&self, chunk: &str) {
        T::on_stream_chunk(&**self, chunk).await;
    }

    async fn on_error(&self, error: &crate::Error) {
        T::on_error(&**self, error).await;
    }
}
/// Ready-made [`AgentHook`] that reports agent lifecycle events through `tracing`.
///
/// Derives `Debug` (public types should be debuggable) and `Clone`; the only field is a
/// `tracing::Level`, which is itself `Clone`.
#[derive(Debug, Clone)]
pub struct LoggingHook {
    /// `tracing` level that selects which macro the events are emitted with.
    level: tracing::Level,
}
impl LoggingHook {
#[must_use]
pub fn new(level: tracing::Level) -> Self {
Self { level }
}
}
#[async_trait]
impl AgentHook for LoggingHook {
async fn on_agent_start(&self, input: &str) {
match self.level {
tracing::Level::TRACE => tracing::trace!(input_len = input.len(), "Agent starting"),
tracing::Level::DEBUG => tracing::debug!(input_len = input.len(), "Agent starting"),
_ => tracing::info!(input_len = input.len(), "Agent starting"),
}
}
async fn on_agent_end(&self, _output: &AgentOutput, duration: Duration) {
#[allow(clippy::cast_possible_truncation)]
let ms = duration.as_millis() as u64;
match self.level {
tracing::Level::TRACE => tracing::trace!(duration_ms = ms, "Agent completed"),
tracing::Level::DEBUG => tracing::debug!(duration_ms = ms, "Agent completed"),
_ => tracing::info!(duration_ms = ms, "Agent completed"),
}
}
async fn on_provider_start(&self, _request: &CompletionRequest) {
match self.level {
tracing::Level::TRACE => tracing::trace!("LLM call starting"),
tracing::Level::DEBUG => tracing::debug!("LLM call starting"),
_ => tracing::info!("LLM call starting"),
}
}
async fn on_provider_end(&self, response: &CompletionResponse, duration: Duration) {
#[allow(clippy::cast_possible_truncation)]
let ms = duration.as_millis() as u64;
let tokens = response.usage.total_tokens;
match self.level {
tracing::Level::TRACE => {
tracing::trace!(duration_ms = ms, tokens, "LLM call completed")
}
tracing::Level::DEBUG => {
tracing::debug!(duration_ms = ms, tokens, "LLM call completed")
}
_ => tracing::info!(duration_ms = ms, tokens, "LLM call completed"),
}
}
async fn before_tool_execute(&self, name: &str, _args: &serde_json::Value) -> HookAction {
match self.level {
tracing::Level::TRACE => tracing::trace!(tool = name, "Tool executing"),
tracing::Level::DEBUG => tracing::debug!(tool = name, "Tool executing"),
_ => tracing::info!(tool = name, "Tool executing"),
}
HookAction::Continue
}
async fn after_tool_execute(&self, name: &str, _result: &str, duration: Duration) {
#[allow(clippy::cast_possible_truncation)]
let ms = duration.as_millis() as u64;
match self.level {
tracing::Level::TRACE => tracing::trace!(tool = name, duration_ms = ms, "Tool done"),
tracing::Level::DEBUG => tracing::debug!(tool = name, duration_ms = ms, "Tool done"),
_ => tracing::info!(tool = name, duration_ms = ms, "Tool done"),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check: `AgentHook` must stay usable as a trait object.
    fn _assert_object_safe(_: &dyn AgentHook) {}

    #[test]
    fn test_hook_action_variants() {
        match HookAction::Continue {
            HookAction::Continue => {}
            other => panic!("expected Continue, got {other:?}"),
        }

        match HookAction::Block("reason".into()) {
            HookAction::Block(why) => assert_eq!(why, "reason"),
            other => panic!("expected Block, got {other:?}"),
        }
    }

    #[test]
    fn test_logging_hook_creation() {
        let wanted = tracing::Level::INFO;
        assert_eq!(LoggingHook::new(wanted).level, wanted);
    }
}