use crate::kernel::{ExecutionId, StepId};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use std::time::Duration;
/// Metadata about one agent invocation, handed to [`BeforeAgentCallback`] and
/// [`AfterAgentCallback`] hooks.
///
/// Construct with [`AgentCallbackContext::new`] and refine with the `with_*`
/// builder methods; every optional field starts as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentCallbackContext {
    /// Execution this agent invocation belongs to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was attached.
    pub step_id: Option<StepId>,
    /// Name of the agent being invoked.
    pub agent_name: String,
    /// Optional human-readable description of the agent.
    pub agent_description: Option<String>,
    /// Agent input, capped at 500 bytes by [`AgentCallbackContext::with_input_preview`].
    pub input_preview: Option<String>,
    /// Tenant identifier, if known.
    pub tenant_id: Option<String>,
    /// User identifier, if known.
    pub user_id: Option<String>,
    /// Trace identifier, if known.
    pub trace_id: Option<String>,
    /// Parent span identifier supplied together with the trace id, if any.
    pub parent_span_id: Option<String>,
}

impl AgentCallbackContext {
    /// Creates a context for `agent_name` under `execution_id` with every
    /// optional field unset.
    pub fn new(execution_id: ExecutionId, agent_name: impl Into<String>) -> Self {
        Self {
            execution_id,
            step_id: None,
            agent_name: agent_name.into(),
            agent_description: None,
            input_preview: None,
            tenant_id: None,
            user_id: None,
            trace_id: None,
            parent_span_id: None,
        }
    }

    /// Attaches the step this invocation runs in.
    pub fn with_step(mut self, step_id: StepId) -> Self {
        self.step_id = Some(step_id);
        self
    }

    /// Sets the agent's human-readable description.
    pub fn with_description(mut self, description: impl Into<String>) -> Self {
        self.agent_description = Some(description.into());
        self
    }

    /// Stores a preview of the agent input.
    ///
    /// Input longer than 500 bytes is cut to at most 497 bytes and suffixed with
    /// `...`. The cut backs off to the nearest UTF-8 character boundary, so
    /// non-ASCII input never panics. ASCII input is cut at exactly 497 bytes.
    pub fn with_input_preview(mut self, input: impl Into<String>) -> Self {
        self.input_preview = Some(Self::truncate_preview(input.into()));
        self
    }

    /// Sets the tenant identifier.
    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
        self.tenant_id = Some(tenant_id.into());
        self
    }

    /// Sets the user identifier.
    pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
        self.user_id = Some(user_id.into());
        self
    }

    /// Sets the trace id and replaces the parent span id.
    ///
    /// Passing `None` for `parent_span_id` clears any previously set value.
    pub fn with_trace(
        mut self,
        trace_id: impl Into<String>,
        parent_span_id: Option<String>,
    ) -> Self {
        self.trace_id = Some(trace_id.into());
        self.parent_span_id = parent_span_id;
        self
    }

    /// Caps `text` at 500 bytes, appending `...` when it had to be shortened.
    ///
    /// The cut starts at byte 497 and moves left until it lands on a `char`
    /// boundary. Slicing at a raw byte offset inside a multi-byte character
    /// would panic. Byte 0 is always a boundary, so the loop terminates.
    fn truncate_preview(text: String) -> String {
        const MAX_BYTES: usize = 500;
        const KEEP_BYTES: usize = MAX_BYTES - 3; // room for the "..." marker
        if text.len() <= MAX_BYTES {
            return text;
        }
        let mut cut = KEEP_BYTES;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &text[..cut])
    }
}
/// Outcome of one agent invocation, handed to [`AfterAgentCallback`] hooks.
///
/// Build with [`AgentCallbackResult::success`] or [`AgentCallbackResult::failure`],
/// then optionally attach a preview and statistics.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentCallbackResult {
    /// `true` when built via `success`, `false` via `failure`.
    pub success: bool,
    /// Wall-clock duration reported by the caller.
    pub duration: Duration,
    /// Agent output, capped at 500 bytes by [`AgentCallbackResult::with_output_preview`].
    pub output_preview: Option<String>,
    /// Error message; set only by `failure`.
    pub error: Option<String>,
    /// Number of steps executed, if reported via `with_stats`.
    pub steps_executed: Option<u32>,
    /// Number of tool calls, if reported via `with_stats`.
    pub tool_calls: Option<u32>,
    /// Number of model calls, if reported via `with_stats`.
    pub model_calls: Option<u32>,
}

impl AgentCallbackResult {
    /// Builds a successful result with no preview, error, or statistics.
    pub fn success(duration: Duration) -> Self {
        Self {
            success: true,
            duration,
            output_preview: None,
            error: None,
            steps_executed: None,
            tool_calls: None,
            model_calls: None,
        }
    }

    /// Builds a failed result carrying `error`.
    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
        Self {
            success: false,
            duration,
            output_preview: None,
            error: Some(error.into()),
            steps_executed: None,
            tool_calls: None,
            model_calls: None,
        }
    }

    /// Stores a preview of the agent output.
    ///
    /// Output longer than 500 bytes is cut to at most 497 bytes and suffixed with
    /// `...`. The cut backs off to the nearest UTF-8 character boundary, so
    /// non-ASCII output never panics.
    pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
        self.output_preview = Some(Self::truncate_preview(output.into()));
        self
    }

    /// Records step, tool-call, and model-call counts.
    pub fn with_stats(mut self, steps: u32, tool_calls: u32, model_calls: u32) -> Self {
        self.steps_executed = Some(steps);
        self.tool_calls = Some(tool_calls);
        self.model_calls = Some(model_calls);
        self
    }

    /// Caps `text` at 500 bytes, appending `...` when it had to be shortened.
    ///
    /// The cut starts at byte 497 and moves left to a `char` boundary, because
    /// slicing inside a multi-byte character panics. Byte 0 is always a
    /// boundary, so the loop terminates.
    fn truncate_preview(text: String) -> String {
        const MAX_BYTES: usize = 500;
        const KEEP_BYTES: usize = MAX_BYTES - 3; // room for the "..." marker
        if text.len() <= MAX_BYTES {
            return text;
        }
        let mut cut = KEEP_BYTES;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &text[..cut])
    }
}
/// Description of one model (LLM provider) request, handed to
/// [`BeforeModelCallback`] and [`AfterModelCallback`] hooks.
///
/// Start from [`ModelCallbackContext::new`]; numeric/boolean request fields
/// default to `0`/`false` and optional fields to `None` until a `with_*`
/// builder sets them.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCallbackContext {
    /// Execution this model call belongs to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was attached.
    pub step_id: Option<StepId>,
    /// Provider name as supplied by the caller.
    pub provider: String,
    /// Model identifier as supplied by the caller.
    pub model: String,
    /// Sampling temperature, if specified.
    pub temperature: Option<f32>,
    /// Token cap for the response, if specified.
    pub max_tokens: Option<u32>,
    /// Number of messages in the request (`0` until set).
    pub message_count: usize,
    /// Whether the request streams (`false` until set).
    pub streaming: bool,
    /// Whether tools are enabled for the request (`false` until set).
    pub tools_enabled: bool,
    /// Trace identifier, if known.
    pub trace_id: Option<String>,
}

impl ModelCallbackContext {
    /// Creates a context for `model` served by `provider` under `execution_id`.
    pub fn new(
        execution_id: ExecutionId,
        provider: impl Into<String>,
        model: impl Into<String>,
    ) -> Self {
        let provider: String = provider.into();
        let model: String = model.into();
        Self {
            execution_id,
            provider,
            model,
            step_id: None,
            temperature: None,
            max_tokens: None,
            trace_id: None,
            message_count: 0,
            streaming: false,
            tools_enabled: false,
        }
    }

    /// Attaches the step this model call runs in.
    pub fn with_step(self, step_id: StepId) -> Self {
        Self {
            step_id: Some(step_id),
            ..self
        }
    }

    /// Overwrites both sampling parameters; `None` clears a parameter.
    pub fn with_params(self, temperature: Option<f32>, max_tokens: Option<u32>) -> Self {
        Self {
            temperature,
            max_tokens,
            ..self
        }
    }

    /// Records the request's message count, streaming flag, and tools flag.
    pub fn with_request_info(
        self,
        message_count: usize,
        streaming: bool,
        tools_enabled: bool,
    ) -> Self {
        Self {
            message_count,
            streaming,
            tools_enabled,
            ..self
        }
    }

    /// Sets the trace identifier.
    pub fn with_trace(self, trace_id: impl Into<String>) -> Self {
        Self {
            trace_id: Some(trace_id.into()),
            ..self
        }
    }
}
/// Outcome of one model request, handed to [`AfterModelCallback`] hooks.
///
/// Build with [`ModelCallbackResult::success`] or [`ModelCallbackResult::failure`],
/// then optionally attach token counts, finish info, or the cached flag.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCallbackResult {
    /// `true` when built via `success`, `false` via `failure`.
    pub success: bool,
    /// Wall-clock duration reported by the caller.
    pub duration: Duration,
    /// Prompt-side token count, if reported.
    pub input_tokens: Option<u32>,
    /// Completion-side token count, if reported.
    pub output_tokens: Option<u32>,
    /// `input_tokens + output_tokens`, saturating at `u32::MAX`.
    pub total_tokens: Option<u32>,
    /// Finish reason string as supplied by the caller.
    pub finish_reason: Option<String>,
    /// Number of tool calls in the response, if reported.
    pub tool_calls_count: Option<u32>,
    /// Error message; set only by `failure`.
    pub error: Option<String>,
    /// Whether the result was marked as cached via [`ModelCallbackResult::cached`].
    pub cached: bool,
}

impl ModelCallbackResult {
    /// Builds a successful, uncached result with no token or finish data.
    pub fn success(duration: Duration) -> Self {
        Self {
            success: true,
            duration,
            input_tokens: None,
            output_tokens: None,
            total_tokens: None,
            finish_reason: None,
            tool_calls_count: None,
            error: None,
            cached: false,
        }
    }

    /// Builds a failed, uncached result carrying `error`.
    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
        Self {
            success: false,
            duration,
            input_tokens: None,
            output_tokens: None,
            total_tokens: None,
            finish_reason: None,
            tool_calls_count: None,
            error: Some(error.into()),
            cached: false,
        }
    }

    /// Records input/output token counts and derives `total_tokens`.
    ///
    /// The total saturates at `u32::MAX`. A plain `+` would panic in debug
    /// builds and silently wrap in release builds on overflow.
    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
        self.input_tokens = Some(input);
        self.output_tokens = Some(output);
        self.total_tokens = Some(input.saturating_add(output));
        self
    }

    /// Records the finish reason and the number of tool calls requested.
    pub fn with_finish_info(mut self, reason: impl Into<String>, tool_calls: u32) -> Self {
        self.finish_reason = Some(reason.into());
        self.tool_calls_count = Some(tool_calls);
        self
    }

    /// Marks the result as served from cache.
    pub fn cached(mut self) -> Self {
        self.cached = true;
        self
    }
}
/// Metadata about one tool invocation, handed to [`BeforeToolCallback`] and
/// [`AfterToolCallback`] hooks.
///
/// Built with [`ToolCallbackContext::new`]. `requires_network` starts as `true`
/// and every optional field starts as `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallbackContext {
    /// Execution this tool call belongs to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was attached.
    pub step_id: Option<StepId>,
    /// Name of the tool being invoked.
    pub tool_name: String,
    /// Optional human-readable description of the tool.
    pub tool_description: Option<String>,
    /// Whether the tool needs network access (`true` unless overridden).
    pub requires_network: bool,
    /// Tool arguments, capped at 500 bytes by [`ToolCallbackContext::with_args_preview`].
    pub args_preview: Option<String>,
    /// Trace identifier, if known.
    pub trace_id: Option<String>,
}

impl ToolCallbackContext {
    /// Creates a context for `tool_name` under `execution_id`.
    ///
    /// `requires_network` defaults to `true`; all optional fields are `None`.
    pub fn new(execution_id: ExecutionId, tool_name: impl Into<String>) -> Self {
        Self {
            execution_id,
            step_id: None,
            tool_name: tool_name.into(),
            tool_description: None,
            requires_network: true,
            args_preview: None,
            trace_id: None,
        }
    }

    /// Attaches the step this tool call runs in.
    pub fn with_step(mut self, step_id: StepId) -> Self {
        self.step_id = Some(step_id);
        self
    }

    /// Overwrites the tool description (`None` clears it) and the network flag.
    pub fn with_tool_info(mut self, description: Option<String>, requires_network: bool) -> Self {
        self.tool_description = description;
        self.requires_network = requires_network;
        self
    }

    /// Stores a preview of the tool arguments.
    ///
    /// Arguments longer than 500 bytes are cut to at most 497 bytes and suffixed
    /// with `...`. The cut backs off to the nearest UTF-8 character boundary, so
    /// non-ASCII arguments never panic.
    pub fn with_args_preview(mut self, args: impl Into<String>) -> Self {
        self.args_preview = Some(Self::truncate_preview(args.into()));
        self
    }

    /// Sets the trace identifier.
    pub fn with_trace(mut self, trace_id: impl Into<String>) -> Self {
        self.trace_id = Some(trace_id.into());
        self
    }

    /// Caps `text` at 500 bytes, appending `...` when it had to be shortened.
    ///
    /// The cut starts at byte 497 and moves left to a `char` boundary, because
    /// slicing inside a multi-byte character panics. Byte 0 is always a
    /// boundary, so the loop terminates.
    fn truncate_preview(text: String) -> String {
        const MAX_BYTES: usize = 500;
        const KEEP_BYTES: usize = MAX_BYTES - 3; // room for the "..." marker
        if text.len() <= MAX_BYTES {
            return text;
        }
        let mut cut = KEEP_BYTES;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &text[..cut])
    }
}
/// Outcome of one tool invocation, handed to [`AfterToolCallback`] hooks.
///
/// There are three constructors: [`ToolCallbackResult::success`],
/// [`ToolCallbackResult::failure`], and [`ToolCallbackResult::blocked`]. A
/// blocked result has `success == false`, `blocked == true`, and no `error`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallbackResult {
    /// `true` only when built via `success`.
    pub success: bool,
    /// Wall-clock duration reported by the caller.
    pub duration: Duration,
    /// Tool output, capped at 500 bytes by [`ToolCallbackResult::with_output_preview`].
    pub output_preview: Option<String>,
    /// Error message; set only by `failure`.
    pub error: Option<String>,
    /// `true` only when built via `blocked`.
    pub blocked: bool,
    /// Reason the tool was blocked; set only by `blocked`.
    pub blocked_reason: Option<String>,
}

impl ToolCallbackResult {
    /// Builds a successful result with no preview.
    pub fn success(duration: Duration) -> Self {
        Self {
            success: true,
            duration,
            output_preview: None,
            error: None,
            blocked: false,
            blocked_reason: None,
        }
    }

    /// Builds a failed (but not blocked) result carrying `error`.
    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
        Self {
            success: false,
            duration,
            output_preview: None,
            error: Some(error.into()),
            blocked: false,
            blocked_reason: None,
        }
    }

    /// Builds a blocked result carrying `reason`; `error` stays `None`.
    pub fn blocked(duration: Duration, reason: impl Into<String>) -> Self {
        Self {
            success: false,
            duration,
            output_preview: None,
            error: None,
            blocked: true,
            blocked_reason: Some(reason.into()),
        }
    }

    /// Stores a preview of the tool output.
    ///
    /// Output longer than 500 bytes is cut to at most 497 bytes and suffixed with
    /// `...`. The cut backs off to the nearest UTF-8 character boundary, so
    /// non-ASCII output never panics.
    pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
        self.output_preview = Some(Self::truncate_preview(output.into()));
        self
    }

    /// Caps `text` at 500 bytes, appending `...` when it had to be shortened.
    ///
    /// The cut starts at byte 497 and moves left to a `char` boundary, because
    /// slicing inside a multi-byte character panics. Byte 0 is always a
    /// boundary, so the loop terminates.
    fn truncate_preview(text: String) -> String {
        const MAX_BYTES: usize = 500;
        const KEEP_BYTES: usize = MAX_BYTES - 3; // room for the "..." marker
        if text.len() <= MAX_BYTES {
            return text;
        }
        let mut cut = KEEP_BYTES;
        while !text.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &text[..cut])
    }
}
/// Hook that receives an [`AgentCallbackContext`]; dispatched by
/// `CallbackRegistry::invoke_before_agent`.
///
/// The method returns `()`, so the hook can observe but not alter or veto the
/// invocation. `Send + Sync` lets implementations be shared behind `Arc`.
pub trait BeforeAgentCallback: Send + Sync {
/// Called with the context of the agent invocation.
fn on_before_agent(&self, ctx: &AgentCallbackContext);
}
/// Hook that receives an agent context plus its [`AgentCallbackResult`];
/// dispatched by `CallbackRegistry::invoke_after_agent`.
pub trait AfterAgentCallback: Send + Sync {
/// Called with the invocation context and its outcome.
fn on_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult);
}
/// Hook that receives a [`ModelCallbackContext`]; dispatched by
/// `CallbackRegistry::invoke_before_model`.
pub trait BeforeModelCallback: Send + Sync {
/// Called with the context of the model request.
fn on_before_model(&self, ctx: &ModelCallbackContext);
}
/// Hook that receives a model context plus its [`ModelCallbackResult`];
/// dispatched by `CallbackRegistry::invoke_after_model`.
pub trait AfterModelCallback: Send + Sync {
/// Called with the request context and its outcome.
fn on_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult);
}
/// Hook that receives a [`ToolCallbackContext`]; dispatched by
/// `CallbackRegistry::invoke_before_tool`.
pub trait BeforeToolCallback: Send + Sync {
/// Called with the context of the tool invocation.
fn on_before_tool(&self, ctx: &ToolCallbackContext);
}
/// Hook that receives a tool context plus its [`ToolCallbackResult`];
/// dispatched by `CallbackRegistry::invoke_after_tool`.
pub trait AfterToolCallback: Send + Sync {
/// Called with the invocation context and its outcome.
fn on_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult);
}
/// Umbrella trait for types that implement all six lifecycle hooks.
///
/// It has no methods of its own. The blanket impl below implements it for
/// every type that already implements each individual hook trait, so such a
/// type can be handed to `CallbackRegistry::register_all` in one call.
pub trait ExecutionCallbacks:
    BeforeAgentCallback
    + AfterAgentCallback
    + BeforeModelCallback
    + AfterModelCallback
    + BeforeToolCallback
    + AfterToolCallback
{
}

// Blanket impl: any type satisfying all six hook traits is automatically an
// `ExecutionCallbacks`; no manual opt-in is required.
impl<
        T: BeforeAgentCallback
            + AfterAgentCallback
            + BeforeModelCallback
            + AfterModelCallback
            + BeforeToolCallback
            + AfterToolCallback,
    > ExecutionCallbacks for T
{
}
/// Registry of lifecycle hooks, kept as one ordered list per hook kind.
///
/// Hooks are stored as `Arc<dyn …>` trait objects. Each `invoke_*` method calls
/// every hook of its kind in registration order. Registering needs `&mut self`;
/// invoking needs only `&self`.
#[derive(Default)]
pub struct CallbackRegistry {
    before_agent: Vec<Arc<dyn BeforeAgentCallback>>,
    after_agent: Vec<Arc<dyn AfterAgentCallback>>,
    before_model: Vec<Arc<dyn BeforeModelCallback>>,
    after_model: Vec<Arc<dyn AfterModelCallback>>,
    before_tool: Vec<Arc<dyn BeforeToolCallback>>,
    after_tool: Vec<Arc<dyn AfterToolCallback>>,
}

impl CallbackRegistry {
    /// Returns a registry with no hooks registered.
    pub fn new() -> Self {
        Default::default()
    }

    /// Registers a hook for [`CallbackRegistry::invoke_before_agent`].
    pub fn on_before_agent<C>(&mut self, callback: C) -> &mut Self
    where
        C: BeforeAgentCallback + 'static,
    {
        let hook: Arc<dyn BeforeAgentCallback> = Arc::new(callback);
        self.before_agent.push(hook);
        self
    }

    /// Registers a hook for [`CallbackRegistry::invoke_after_agent`].
    pub fn on_after_agent<C>(&mut self, callback: C) -> &mut Self
    where
        C: AfterAgentCallback + 'static,
    {
        let hook: Arc<dyn AfterAgentCallback> = Arc::new(callback);
        self.after_agent.push(hook);
        self
    }

    /// Registers a hook for [`CallbackRegistry::invoke_before_model`].
    pub fn on_before_model<C>(&mut self, callback: C) -> &mut Self
    where
        C: BeforeModelCallback + 'static,
    {
        let hook: Arc<dyn BeforeModelCallback> = Arc::new(callback);
        self.before_model.push(hook);
        self
    }

    /// Registers a hook for [`CallbackRegistry::invoke_after_model`].
    pub fn on_after_model<C>(&mut self, callback: C) -> &mut Self
    where
        C: AfterModelCallback + 'static,
    {
        let hook: Arc<dyn AfterModelCallback> = Arc::new(callback);
        self.after_model.push(hook);
        self
    }

    /// Registers a hook for [`CallbackRegistry::invoke_before_tool`].
    pub fn on_before_tool<C>(&mut self, callback: C) -> &mut Self
    where
        C: BeforeToolCallback + 'static,
    {
        let hook: Arc<dyn BeforeToolCallback> = Arc::new(callback);
        self.before_tool.push(hook);
        self
    }

    /// Registers a hook for [`CallbackRegistry::invoke_after_tool`].
    pub fn on_after_tool<C>(&mut self, callback: C) -> &mut Self
    where
        C: AfterToolCallback + 'static,
    {
        let hook: Arc<dyn AfterToolCallback> = Arc::new(callback);
        self.after_tool.push(hook);
        self
    }

    /// Registers one shared object in all six hook lists.
    ///
    /// The same `Arc` allocation is referenced from every list, so the caller
    /// can keep its own clone to inspect the object's state afterwards.
    pub fn register_all<C>(&mut self, callback: Arc<C>) -> &mut Self
    where
        C: ExecutionCallbacks + 'static,
    {
        self.before_agent.push(Arc::clone(&callback));
        self.after_agent.push(Arc::clone(&callback));
        self.before_model.push(Arc::clone(&callback));
        self.after_model.push(Arc::clone(&callback));
        self.before_tool.push(Arc::clone(&callback));
        // Last list takes ownership of the caller's Arc instead of cloning again.
        self.after_tool.push(callback);
        self
    }

    /// Calls every before-agent hook, in registration order.
    pub fn invoke_before_agent(&self, ctx: &AgentCallbackContext) {
        self.before_agent
            .iter()
            .for_each(|hook| hook.on_before_agent(ctx));
    }

    /// Calls every after-agent hook, in registration order.
    pub fn invoke_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult) {
        self.after_agent
            .iter()
            .for_each(|hook| hook.on_after_agent(ctx, result));
    }

    /// Calls every before-model hook, in registration order.
    pub fn invoke_before_model(&self, ctx: &ModelCallbackContext) {
        self.before_model
            .iter()
            .for_each(|hook| hook.on_before_model(ctx));
    }

    /// Calls every after-model hook, in registration order.
    pub fn invoke_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult) {
        self.after_model
            .iter()
            .for_each(|hook| hook.on_after_model(ctx, result));
    }

    /// Calls every before-tool hook, in registration order.
    pub fn invoke_before_tool(&self, ctx: &ToolCallbackContext) {
        self.before_tool
            .iter()
            .for_each(|hook| hook.on_before_tool(ctx));
    }

    /// Calls every after-tool hook, in registration order.
    pub fn invoke_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult) {
        self.after_tool
            .iter()
            .for_each(|hook| hook.on_after_tool(ctx, result));
    }
}
/// Zero-sized implementation of every hook trait whose methods do nothing.
///
/// Because it implements all six hook traits, it also satisfies
/// `ExecutionCallbacks` through that trait's blanket impl.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpCallbacks;

impl BeforeAgentCallback for NoOpCallbacks {
    fn on_before_agent(&self, _: &AgentCallbackContext) {}
}

impl AfterAgentCallback for NoOpCallbacks {
    fn on_after_agent(&self, _: &AgentCallbackContext, _: &AgentCallbackResult) {}
}

impl BeforeModelCallback for NoOpCallbacks {
    fn on_before_model(&self, _: &ModelCallbackContext) {}
}

impl AfterModelCallback for NoOpCallbacks {
    fn on_after_model(&self, _: &ModelCallbackContext, _: &ModelCallbackResult) {}
}

impl BeforeToolCallback for NoOpCallbacks {
    fn on_before_tool(&self, _: &ToolCallbackContext) {}
}

impl AfterToolCallback for NoOpCallbacks {
    fn on_after_tool(&self, _: &ToolCallbackContext, _: &ToolCallbackResult) {}
}
// Unit tests for the callback context/result builders, the no-op hooks, and
// `CallbackRegistry` dispatch.
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicU32, Ordering};
use std::time::Duration;
#[test]
fn test_agent_callback_context_new() {
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent");
assert_eq!(ctx.agent_name, "test_agent");
assert!(ctx.step_id.is_none());
assert!(ctx.agent_description.is_none());
assert!(ctx.input_preview.is_none());
}
#[test]
fn test_agent_callback_context_with_step() {
let step_id = StepId::new();
let ctx =
AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_step(step_id.clone());
assert!(ctx.step_id.is_some());
assert_eq!(ctx.step_id.unwrap().as_str(), step_id.as_str());
}
#[test]
fn test_agent_callback_context_with_description() {
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
.with_description("A test agent");
assert_eq!(ctx.agent_description, Some("A test agent".to_string()));
}
// Long ASCII input must be capped at 500 bytes and end with the "..." marker.
#[test]
fn test_agent_callback_context_input_preview_truncation() {
let long_input = "x".repeat(1000);
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
.with_input_preview(&long_input);
let preview = ctx.input_preview.unwrap();
assert!(preview.len() <= 500);
assert!(preview.ends_with("..."));
}
#[test]
fn test_agent_callback_context_short_input_not_truncated() {
let short_input = "hello world";
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
.with_input_preview(short_input);
assert_eq!(ctx.input_preview, Some("hello world".to_string()));
}
#[test]
fn test_agent_callback_context_with_tenant() {
let ctx =
AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_tenant("tenant_123");
assert_eq!(ctx.tenant_id, Some("tenant_123".to_string()));
}
#[test]
fn test_agent_callback_context_with_user() {
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_user("user_456");
assert_eq!(ctx.user_id, Some("user_456".to_string()));
}
#[test]
fn test_agent_callback_context_with_trace() {
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
.with_trace("trace_abc", Some("span_xyz".to_string()));
assert_eq!(ctx.trace_id, Some("trace_abc".to_string()));
assert_eq!(ctx.parent_span_id, Some("span_xyz".to_string()));
}
// Every builder in one chain: each call must leave the fields set by earlier calls intact.
#[test]
fn test_agent_callback_context_builder_chain() {
let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
.with_step(StepId::new())
.with_description("Description")
.with_input_preview("Input")
.with_tenant("tenant")
.with_user("user")
.with_trace("trace", None);
assert!(ctx.step_id.is_some());
assert!(ctx.agent_description.is_some());
assert!(ctx.input_preview.is_some());
assert!(ctx.tenant_id.is_some());
assert!(ctx.user_id.is_some());
assert!(ctx.trace_id.is_some());
}
// JSON round-trip through the serde derives.
#[test]
fn test_agent_callback_context_serde() {
let ctx = AgentCallbackContext::new(ExecutionId::from_string("exec_test"), "test_agent")
.with_description("Test description");
let json = serde_json::to_string(&ctx).unwrap();
let parsed: AgentCallbackContext = serde_json::from_str(&json).unwrap();
assert_eq!(ctx.agent_name, parsed.agent_name);
assert_eq!(ctx.agent_description, parsed.agent_description);
}
#[test]
fn test_agent_callback_result_success() {
let result = AgentCallbackResult::success(Duration::from_millis(100));
assert!(result.success);
assert_eq!(result.duration, Duration::from_millis(100));
assert!(result.error.is_none());
}
#[test]
fn test_agent_callback_result_failure() {
let result =
AgentCallbackResult::failure(Duration::from_millis(50), "Something went wrong");
assert!(!result.success);
assert_eq!(result.error, Some("Something went wrong".to_string()));
}
#[test]
fn test_agent_callback_result_with_output_preview() {
let result = AgentCallbackResult::success(Duration::from_millis(100))
.with_output_preview("Output here");
assert_eq!(result.output_preview, Some("Output here".to_string()));
}
#[test]
fn test_agent_callback_result_output_truncation() {
let long_output = "y".repeat(1000);
let result = AgentCallbackResult::success(Duration::from_millis(100))
.with_output_preview(&long_output);
let preview = result.output_preview.unwrap();
assert!(preview.len() <= 500);
assert!(preview.ends_with("..."));
}
#[test]
fn test_agent_callback_result_with_stats() {
let result = AgentCallbackResult::success(Duration::from_millis(100)).with_stats(5, 3, 2);
assert_eq!(result.steps_executed, Some(5));
assert_eq!(result.tool_calls, Some(3));
assert_eq!(result.model_calls, Some(2));
}
#[test]
fn test_model_callback_context_new() {
let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
assert_eq!(ctx.provider, "openai");
assert_eq!(ctx.model, "gpt-4");
assert!(ctx.step_id.is_none());
}
#[test]
fn test_model_callback_context_with_params() {
let ctx = ModelCallbackContext::new(ExecutionId::new(), "anthropic", "claude-3-opus")
.with_params(Some(0.7), Some(4096));
assert_eq!(ctx.temperature, Some(0.7));
assert_eq!(ctx.max_tokens, Some(4096));
}
#[test]
fn test_model_callback_context_with_request_info() {
let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
.with_request_info(5, true, true);
assert_eq!(ctx.message_count, 5);
assert!(ctx.streaming);
assert!(ctx.tools_enabled);
}
#[test]
fn test_model_callback_context_with_step() {
let step_id = StepId::new();
let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
.with_step(step_id.clone());
assert!(ctx.step_id.is_some());
}
#[test]
fn test_model_callback_context_serde() {
let ctx =
ModelCallbackContext::new(ExecutionId::from_string("exec_test"), "openai", "gpt-4")
.with_params(Some(0.5), Some(1000));
let json = serde_json::to_string(&ctx).unwrap();
let parsed: ModelCallbackContext = serde_json::from_str(&json).unwrap();
assert_eq!(ctx.provider, parsed.provider);
assert_eq!(ctx.model, parsed.model);
assert_eq!(ctx.temperature, parsed.temperature);
}
#[test]
fn test_model_callback_result_success() {
let result = ModelCallbackResult::success(Duration::from_millis(500));
assert!(result.success);
assert!(result.error.is_none());
}
#[test]
fn test_model_callback_result_failure() {
let result =
ModelCallbackResult::failure(Duration::from_millis(100), "Rate limit exceeded");
assert!(!result.success);
assert_eq!(result.error, Some("Rate limit exceeded".to_string()));
}
// total_tokens is derived from the input and output counts.
#[test]
fn test_model_callback_result_with_tokens() {
let result =
ModelCallbackResult::success(Duration::from_millis(500)).with_tokens(1000, 500);
assert_eq!(result.input_tokens, Some(1000));
assert_eq!(result.output_tokens, Some(500));
assert_eq!(result.total_tokens, Some(1500));
}
#[test]
fn test_model_callback_result_with_finish_info() {
let result = ModelCallbackResult::success(Duration::from_millis(500))
.with_finish_info("tool_use", 2);
assert_eq!(result.finish_reason, Some("tool_use".to_string()));
assert_eq!(result.tool_calls_count, Some(2));
}
#[test]
fn test_model_callback_result_cached() {
let result = ModelCallbackResult::success(Duration::from_millis(10)).cached();
assert!(result.cached);
}
// `ToolCallbackContext::new` defaults `requires_network` to true.
#[test]
fn test_tool_callback_context_new() {
let ctx = ToolCallbackContext::new(ExecutionId::new(), "read_file");
assert_eq!(ctx.tool_name, "read_file");
assert!(ctx.requires_network); }
#[test]
fn test_tool_callback_context_with_tool_info() {
let ctx = ToolCallbackContext::new(ExecutionId::new(), "calculator")
.with_tool_info(Some("Performs calculations".to_string()), false);
assert_eq!(
ctx.tool_description,
Some("Performs calculations".to_string())
);
assert!(!ctx.requires_network);
}
#[test]
fn test_tool_callback_context_with_args_preview() {
let ctx = ToolCallbackContext::new(ExecutionId::new(), "search")
.with_args_preview(r#"{"query": "rust programming"}"#);
assert!(ctx.args_preview.is_some());
}
#[test]
fn test_tool_callback_context_args_truncation() {
let long_args = "z".repeat(1000);
let ctx =
ToolCallbackContext::new(ExecutionId::new(), "tool").with_args_preview(&long_args);
let preview = ctx.args_preview.unwrap();
assert!(preview.len() <= 500);
assert!(preview.ends_with("..."));
}
#[test]
fn test_tool_callback_context_serde() {
let ctx = ToolCallbackContext::new(ExecutionId::from_string("exec_test"), "my_tool")
.with_tool_info(Some("A tool".to_string()), false);
let json = serde_json::to_string(&ctx).unwrap();
let parsed: ToolCallbackContext = serde_json::from_str(&json).unwrap();
assert_eq!(ctx.tool_name, parsed.tool_name);
assert_eq!(ctx.tool_description, parsed.tool_description);
}
#[test]
fn test_tool_callback_result_success() {
let result = ToolCallbackResult::success(Duration::from_millis(50));
assert!(result.success);
assert!(!result.blocked);
}
#[test]
fn test_tool_callback_result_failure() {
let result = ToolCallbackResult::failure(Duration::from_millis(20), "File not found");
assert!(!result.success);
assert_eq!(result.error, Some("File not found".to_string()));
}
// A blocked result is distinct from a failure: success is false, but the reason lives in blocked_reason.
#[test]
fn test_tool_callback_result_blocked() {
let result =
ToolCallbackResult::blocked(Duration::from_millis(5), "Tool disabled by policy");
assert!(!result.success);
assert!(result.blocked);
assert_eq!(
result.blocked_reason,
Some("Tool disabled by policy".to_string())
);
}
#[test]
fn test_tool_callback_result_with_output_preview() {
let result = ToolCallbackResult::success(Duration::from_millis(50))
.with_output_preview("Result: 42");
assert_eq!(result.output_preview, Some("Result: 42".to_string()));
}
// Smoke test: every NoOpCallbacks hook can be called directly without panicking.
#[test]
fn test_noop_callbacks_compiles() {
let callbacks = NoOpCallbacks;
let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
let model_result = ModelCallbackResult::success(Duration::from_millis(500));
let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
let tool_result = ToolCallbackResult::success(Duration::from_millis(50));
callbacks.on_before_agent(&agent_ctx);
callbacks.on_after_agent(&agent_ctx, &agent_result);
callbacks.on_before_model(&model_ctx);
callbacks.on_after_model(&model_ctx, &model_result);
callbacks.on_before_tool(&tool_ctx);
callbacks.on_after_tool(&tool_ctx, &tool_result);
}
// Test double that counts invocations per hook. Atomics are used because the
// hook traits take `&self` and require `Send + Sync`.
struct CountingCallback {
before_agent_count: AtomicU32,
after_agent_count: AtomicU32,
before_model_count: AtomicU32,
after_model_count: AtomicU32,
before_tool_count: AtomicU32,
after_tool_count: AtomicU32,
}
impl CountingCallback {
fn new() -> Self {
Self {
before_agent_count: AtomicU32::new(0),
after_agent_count: AtomicU32::new(0),
before_model_count: AtomicU32::new(0),
after_model_count: AtomicU32::new(0),
before_tool_count: AtomicU32::new(0),
after_tool_count: AtomicU32::new(0),
}
}
}
impl BeforeAgentCallback for CountingCallback {
fn on_before_agent(&self, _ctx: &AgentCallbackContext) {
self.before_agent_count.fetch_add(1, Ordering::SeqCst);
}
}
impl AfterAgentCallback for CountingCallback {
fn on_after_agent(&self, _ctx: &AgentCallbackContext, _result: &AgentCallbackResult) {
self.after_agent_count.fetch_add(1, Ordering::SeqCst);
}
}
impl BeforeModelCallback for CountingCallback {
fn on_before_model(&self, _ctx: &ModelCallbackContext) {
self.before_model_count.fetch_add(1, Ordering::SeqCst);
}
}
impl AfterModelCallback for CountingCallback {
fn on_after_model(&self, _ctx: &ModelCallbackContext, _result: &ModelCallbackResult) {
self.after_model_count.fetch_add(1, Ordering::SeqCst);
}
}
impl BeforeToolCallback for CountingCallback {
fn on_before_tool(&self, _ctx: &ToolCallbackContext) {
self.before_tool_count.fetch_add(1, Ordering::SeqCst);
}
}
impl AfterToolCallback for CountingCallback {
fn on_after_tool(&self, _ctx: &ToolCallbackContext, _result: &ToolCallbackResult) {
self.after_tool_count.fetch_add(1, Ordering::SeqCst);
}
}
// Invoking on an empty registry is a no-op and must not panic.
#[test]
fn test_callback_registry_new() {
let registry = CallbackRegistry::new();
registry.invoke_before_agent(&AgentCallbackContext::new(ExecutionId::new(), "test"));
}
// register_all wires one shared object into all six lists; each invoke_* hits it exactly once.
#[test]
fn test_callback_registry_register_all() {
let callback = Arc::new(CountingCallback::new());
let mut registry = CallbackRegistry::new();
registry.register_all(callback.clone());
let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
let model_result = ModelCallbackResult::success(Duration::from_millis(500));
let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
let tool_result = ToolCallbackResult::success(Duration::from_millis(50));
registry.invoke_before_agent(&agent_ctx);
registry.invoke_after_agent(&agent_ctx, &agent_result);
registry.invoke_before_model(&model_ctx);
registry.invoke_after_model(&model_ctx, &model_result);
registry.invoke_before_tool(&tool_ctx);
registry.invoke_after_tool(&tool_ctx, &tool_result);
assert_eq!(callback.before_agent_count.load(Ordering::SeqCst), 1);
assert_eq!(callback.after_agent_count.load(Ordering::SeqCst), 1);
assert_eq!(callback.before_model_count.load(Ordering::SeqCst), 1);
assert_eq!(callback.after_model_count.load(Ordering::SeqCst), 1);
assert_eq!(callback.before_tool_count.load(Ordering::SeqCst), 1);
assert_eq!(callback.after_tool_count.load(Ordering::SeqCst), 1);
}
// Every registered hook fires, not just the most recent one.
#[test]
fn test_callback_registry_multiple_callbacks() {
let callback1 = Arc::new(CountingCallback::new());
let callback2 = Arc::new(CountingCallback::new());
let mut registry = CallbackRegistry::new();
registry.register_all(callback1.clone());
registry.register_all(callback2.clone());
let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
registry.invoke_before_agent(&agent_ctx);
assert_eq!(callback1.before_agent_count.load(Ordering::SeqCst), 1);
assert_eq!(callback2.before_agent_count.load(Ordering::SeqCst), 1);
}
// A type implementing only one hook trait can be registered through its dedicated method.
#[test]
fn test_callback_registry_individual_registration() {
struct SimpleBeforeAgent;
impl BeforeAgentCallback for SimpleBeforeAgent {
fn on_before_agent(&self, _ctx: &AgentCallbackContext) {}
}
let mut registry = CallbackRegistry::new();
registry.on_before_agent(SimpleBeforeAgent);
registry.invoke_before_agent(&AgentCallbackContext::new(ExecutionId::new(), "test"));
}
// Compile-time checks that the ExecutionCallbacks blanket impl covers these types.
#[test]
fn test_execution_callbacks_trait() {
fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
let noop = NoOpCallbacks;
accept_execution_callbacks(&noop);
}
#[test]
fn test_counting_callback_implements_execution_callbacks() {
fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
let counting = CountingCallback::new();
accept_execution_callbacks(&counting);
}
}