Skip to main content

enact_core/runner/
callbacks.rs

1//! Runner Callbacks - Before/after hooks for agents, models, and tools
2//!
3//! This module defines callback traits for observability and instrumentation
4//! of the runner execution. Callbacks are invoked before and after agent,
5//! model, and tool executions, providing hooks for:
6//!
7//! - **Telemetry**: Emit spans/metrics to OpenTelemetry
8//! - **Logging**: Structured logging of execution events
9//! - **Audit**: Record decisions and actions for compliance
10//! - **Monitoring**: Track latencies, success rates, token usage
11//!
12//! ## Architecture
13//!
14//! ```text
15//! Runner
16//!   │
17//!   ├── BeforeAgentCallback::on_before_agent()
18//!   │     └── Agent Execution
19//!   ├── AfterAgentCallback::on_after_agent()
20//!   │
21//!   ├── BeforeModelCallback::on_before_model()
22//!   │     └── Model Call
23//!   ├── AfterModelCallback::on_after_model()
24//!   │
25//!   ├── BeforeToolCallback::on_before_tool()
26//!   │     └── Tool Execution
27//!   └── AfterToolCallback::on_after_tool()
28//! ```
29//!
30//! ## Usage
31//!
32//! ```rust,ignore
33//! use enact_core::runner::callbacks::*;
34//!
35//! struct MyTelemetryCallback;
36//!
37//! impl BeforeAgentCallback for MyTelemetryCallback {
38//!     fn on_before_agent(&self, ctx: &AgentCallbackContext) {
39//!         tracing::info!(
40//!             execution_id = %ctx.execution_id,
41//!             agent_name = %ctx.agent_name,
42//!             "Agent starting"
43//!         );
44//!     }
45//! }
46//! ```
47//!
48//! @see docs/TECHNICAL/01-EXECUTION-TELEMETRY.md
49//! @see crate::telemetry::spans - Span attribute types for OpenTelemetry integration
50
51use crate::kernel::{ExecutionId, StepId};
52use serde::{Deserialize, Serialize};
53use std::sync::Arc;
54use std::time::Duration;
55
56// =============================================================================
57// Callback Context Types
58// =============================================================================
59
60/// Context passed to agent callbacks
61///
62/// Contains all information needed to identify and trace an agent execution.
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct AgentCallbackContext {
65    /// Execution ID for this agent run
66    pub execution_id: ExecutionId,
67    /// Step ID (if running within a graph)
68    pub step_id: Option<StepId>,
69    /// Agent name
70    pub agent_name: String,
71    /// Agent description (if available)
72    pub agent_description: Option<String>,
73    /// Input provided to the agent (may be truncated for large inputs)
74    pub input_preview: Option<String>,
75    /// Tenant ID (if multi-tenant)
76    pub tenant_id: Option<String>,
77    /// User ID (if authenticated)
78    pub user_id: Option<String>,
79    /// Trace ID for distributed tracing
80    pub trace_id: Option<String>,
81    /// Parent span ID
82    pub parent_span_id: Option<String>,
83}
84
85impl AgentCallbackContext {
86    /// Create a new agent callback context
87    pub fn new(execution_id: ExecutionId, agent_name: impl Into<String>) -> Self {
88        Self {
89            execution_id,
90            step_id: None,
91            agent_name: agent_name.into(),
92            agent_description: None,
93            input_preview: None,
94            tenant_id: None,
95            user_id: None,
96            trace_id: None,
97            parent_span_id: None,
98        }
99    }
100
101    /// Add step context
102    pub fn with_step(mut self, step_id: StepId) -> Self {
103        self.step_id = Some(step_id);
104        self
105    }
106
107    /// Add agent description
108    pub fn with_description(mut self, description: impl Into<String>) -> Self {
109        self.agent_description = Some(description.into());
110        self
111    }
112
113    /// Add input preview (truncated if needed)
114    pub fn with_input_preview(mut self, input: impl Into<String>) -> Self {
115        let input = input.into();
116        // Truncate long inputs to prevent callback overhead
117        self.input_preview = Some(if input.len() > 500 {
118            format!("{}...", &input[..497])
119        } else {
120            input
121        });
122        self
123    }
124
125    /// Add tenant context
126    pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
127        self.tenant_id = Some(tenant_id.into());
128        self
129    }
130
131    /// Add user context
132    pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
133        self.user_id = Some(user_id.into());
134        self
135    }
136
137    /// Add trace context
138    pub fn with_trace(
139        mut self,
140        trace_id: impl Into<String>,
141        parent_span_id: Option<String>,
142    ) -> Self {
143        self.trace_id = Some(trace_id.into());
144        self.parent_span_id = parent_span_id;
145        self
146    }
147}
148
149/// Result information for agent completion callback
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct AgentCallbackResult {
152    /// Whether the agent succeeded
153    pub success: bool,
154    /// Duration of agent execution
155    pub duration: Duration,
156    /// Output preview (truncated if needed)
157    pub output_preview: Option<String>,
158    /// Error message if failed
159    pub error: Option<String>,
160    /// Number of steps executed
161    pub steps_executed: Option<u32>,
162    /// Number of tool calls made
163    pub tool_calls: Option<u32>,
164    /// Number of model calls made
165    pub model_calls: Option<u32>,
166}
167
168impl AgentCallbackResult {
169    /// Create a successful result
170    pub fn success(duration: Duration) -> Self {
171        Self {
172            success: true,
173            duration,
174            output_preview: None,
175            error: None,
176            steps_executed: None,
177            tool_calls: None,
178            model_calls: None,
179        }
180    }
181
182    /// Create a failed result
183    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
184        Self {
185            success: false,
186            duration,
187            output_preview: None,
188            error: Some(error.into()),
189            steps_executed: None,
190            tool_calls: None,
191            model_calls: None,
192        }
193    }
194
195    /// Add output preview
196    pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
197        let output = output.into();
198        self.output_preview = Some(if output.len() > 500 {
199            format!("{}...", &output[..497])
200        } else {
201            output
202        });
203        self
204    }
205
206    /// Add execution stats
207    pub fn with_stats(mut self, steps: u32, tool_calls: u32, model_calls: u32) -> Self {
208        self.steps_executed = Some(steps);
209        self.tool_calls = Some(tool_calls);
210        self.model_calls = Some(model_calls);
211        self
212    }
213}
214
215/// Context passed to model callbacks
216///
217/// Contains information about an LLM model call.
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct ModelCallbackContext {
220    /// Execution ID
221    pub execution_id: ExecutionId,
222    /// Step ID
223    pub step_id: Option<StepId>,
224    /// Provider name (e.g., "openai", "anthropic")
225    pub provider: String,
226    /// Model name (e.g., "gpt-4", "claude-3-opus")
227    pub model: String,
228    /// Temperature setting
229    pub temperature: Option<f32>,
230    /// Max tokens setting
231    pub max_tokens: Option<u32>,
232    /// Number of messages in the request
233    pub message_count: usize,
234    /// Whether this is a streaming request
235    pub streaming: bool,
236    /// Whether tools are enabled
237    pub tools_enabled: bool,
238    /// Trace context
239    pub trace_id: Option<String>,
240}
241
242impl ModelCallbackContext {
243    /// Create a new model callback context
244    pub fn new(
245        execution_id: ExecutionId,
246        provider: impl Into<String>,
247        model: impl Into<String>,
248    ) -> Self {
249        Self {
250            execution_id,
251            step_id: None,
252            provider: provider.into(),
253            model: model.into(),
254            temperature: None,
255            max_tokens: None,
256            message_count: 0,
257            streaming: false,
258            tools_enabled: false,
259            trace_id: None,
260        }
261    }
262
263    /// Add step context
264    pub fn with_step(mut self, step_id: StepId) -> Self {
265        self.step_id = Some(step_id);
266        self
267    }
268
269    /// Add model parameters
270    pub fn with_params(mut self, temperature: Option<f32>, max_tokens: Option<u32>) -> Self {
271        self.temperature = temperature;
272        self.max_tokens = max_tokens;
273        self
274    }
275
276    /// Add request info
277    pub fn with_request_info(
278        mut self,
279        message_count: usize,
280        streaming: bool,
281        tools_enabled: bool,
282    ) -> Self {
283        self.message_count = message_count;
284        self.streaming = streaming;
285        self.tools_enabled = tools_enabled;
286        self
287    }
288
289    /// Add trace context
290    pub fn with_trace(mut self, trace_id: impl Into<String>) -> Self {
291        self.trace_id = Some(trace_id.into());
292        self
293    }
294}
295
296/// Result information for model completion callback
297#[derive(Debug, Clone, Serialize, Deserialize)]
298pub struct ModelCallbackResult {
299    /// Whether the model call succeeded
300    pub success: bool,
301    /// Duration of model call
302    pub duration: Duration,
303    /// Input tokens used
304    pub input_tokens: Option<u32>,
305    /// Output tokens generated
306    pub output_tokens: Option<u32>,
307    /// Total tokens
308    pub total_tokens: Option<u32>,
309    /// Finish reason (e.g., "stop", "tool_use", "length")
310    pub finish_reason: Option<String>,
311    /// Number of tool calls in response
312    pub tool_calls_count: Option<u32>,
313    /// Error message if failed
314    pub error: Option<String>,
315    /// Whether response was cached
316    pub cached: bool,
317}
318
319impl ModelCallbackResult {
320    /// Create a successful result
321    pub fn success(duration: Duration) -> Self {
322        Self {
323            success: true,
324            duration,
325            input_tokens: None,
326            output_tokens: None,
327            total_tokens: None,
328            finish_reason: None,
329            tool_calls_count: None,
330            error: None,
331            cached: false,
332        }
333    }
334
335    /// Create a failed result
336    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
337        Self {
338            success: false,
339            duration,
340            input_tokens: None,
341            output_tokens: None,
342            total_tokens: None,
343            finish_reason: None,
344            tool_calls_count: None,
345            error: Some(error.into()),
346            cached: false,
347        }
348    }
349
350    /// Add token usage
351    pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
352        self.input_tokens = Some(input);
353        self.output_tokens = Some(output);
354        self.total_tokens = Some(input + output);
355        self
356    }
357
358    /// Add finish info
359    pub fn with_finish_info(mut self, reason: impl Into<String>, tool_calls: u32) -> Self {
360        self.finish_reason = Some(reason.into());
361        self.tool_calls_count = Some(tool_calls);
362        self
363    }
364
365    /// Mark as cached
366    pub fn cached(mut self) -> Self {
367        self.cached = true;
368        self
369    }
370}
371
372/// Context passed to tool callbacks
373///
374/// Contains information about a tool execution.
375#[derive(Debug, Clone, Serialize, Deserialize)]
376pub struct ToolCallbackContext {
377    /// Execution ID
378    pub execution_id: ExecutionId,
379    /// Step ID
380    pub step_id: Option<StepId>,
381    /// Tool name
382    pub tool_name: String,
383    /// Tool description
384    pub tool_description: Option<String>,
385    /// Whether tool requires network access
386    pub requires_network: bool,
387    /// Arguments preview (truncated if needed)
388    pub args_preview: Option<String>,
389    /// Trace context
390    pub trace_id: Option<String>,
391}
392
393impl ToolCallbackContext {
394    /// Create a new tool callback context
395    pub fn new(execution_id: ExecutionId, tool_name: impl Into<String>) -> Self {
396        Self {
397            execution_id,
398            step_id: None,
399            tool_name: tool_name.into(),
400            tool_description: None,
401            requires_network: true,
402            args_preview: None,
403            trace_id: None,
404        }
405    }
406
407    /// Add step context
408    pub fn with_step(mut self, step_id: StepId) -> Self {
409        self.step_id = Some(step_id);
410        self
411    }
412
413    /// Add tool info
414    pub fn with_tool_info(mut self, description: Option<String>, requires_network: bool) -> Self {
415        self.tool_description = description;
416        self.requires_network = requires_network;
417        self
418    }
419
420    /// Add arguments preview
421    pub fn with_args_preview(mut self, args: impl Into<String>) -> Self {
422        let args = args.into();
423        self.args_preview = Some(if args.len() > 500 {
424            format!("{}...", &args[..497])
425        } else {
426            args
427        });
428        self
429    }
430
431    /// Add trace context
432    pub fn with_trace(mut self, trace_id: impl Into<String>) -> Self {
433        self.trace_id = Some(trace_id.into());
434        self
435    }
436}
437
438/// Result information for tool completion callback
439#[derive(Debug, Clone, Serialize, Deserialize)]
440pub struct ToolCallbackResult {
441    /// Whether the tool succeeded
442    pub success: bool,
443    /// Duration of tool execution
444    pub duration: Duration,
445    /// Output preview (truncated if needed)
446    pub output_preview: Option<String>,
447    /// Error message if failed
448    pub error: Option<String>,
449    /// Whether the tool was blocked by policy
450    pub blocked: bool,
451    /// Blocking reason if blocked
452    pub blocked_reason: Option<String>,
453}
454
455impl ToolCallbackResult {
456    /// Create a successful result
457    pub fn success(duration: Duration) -> Self {
458        Self {
459            success: true,
460            duration,
461            output_preview: None,
462            error: None,
463            blocked: false,
464            blocked_reason: None,
465        }
466    }
467
468    /// Create a failed result
469    pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
470        Self {
471            success: false,
472            duration,
473            output_preview: None,
474            error: Some(error.into()),
475            blocked: false,
476            blocked_reason: None,
477        }
478    }
479
480    /// Create a blocked result
481    pub fn blocked(duration: Duration, reason: impl Into<String>) -> Self {
482        Self {
483            success: false,
484            duration,
485            output_preview: None,
486            error: None,
487            blocked: true,
488            blocked_reason: Some(reason.into()),
489        }
490    }
491
492    /// Add output preview
493    pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
494        let output = output.into();
495        self.output_preview = Some(if output.len() > 500 {
496            format!("{}...", &output[..497])
497        } else {
498            output
499        });
500        self
501    }
502}
503
504// =============================================================================
505// Callback Traits
506// =============================================================================
507
508/// Callback invoked before agent execution starts
509///
510/// Use this for:
511/// - Starting spans/traces
512/// - Logging agent invocation
513/// - Recording audit events
514pub trait BeforeAgentCallback: Send + Sync {
515    /// Called before an agent starts executing
516    fn on_before_agent(&self, ctx: &AgentCallbackContext);
517}
518
519/// Callback invoked after agent execution completes
520///
521/// Use this for:
522/// - Ending spans/traces with status
523/// - Logging completion/failure
524/// - Recording metrics (latency, token usage)
525pub trait AfterAgentCallback: Send + Sync {
526    /// Called after an agent finishes (success or failure)
527    fn on_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult);
528}
529
530/// Callback invoked before a model call
531///
532/// Use this for:
533/// - Starting LLM spans
534/// - Logging model requests
535/// - Request auditing
536pub trait BeforeModelCallback: Send + Sync {
537    /// Called before a model call starts
538    fn on_before_model(&self, ctx: &ModelCallbackContext);
539}
540
541/// Callback invoked after a model call completes
542///
543/// Use this for:
544/// - Ending LLM spans with token counts
545/// - Recording token usage metrics
546/// - Cost tracking
547pub trait AfterModelCallback: Send + Sync {
548    /// Called after a model call finishes
549    fn on_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult);
550}
551
552/// Callback invoked before tool execution
553///
554/// Use this for:
555/// - Starting tool spans
556/// - Logging tool invocations
557/// - Security auditing
558pub trait BeforeToolCallback: Send + Sync {
559    /// Called before a tool executes
560    fn on_before_tool(&self, ctx: &ToolCallbackContext);
561}
562
563/// Callback invoked after tool execution completes
564///
565/// Use this for:
566/// - Ending tool spans
567/// - Recording tool metrics
568/// - Audit logging
569pub trait AfterToolCallback: Send + Sync {
570    /// Called after a tool finishes
571    fn on_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult);
572}
573
574// =============================================================================
575// Combined Callback Trait
576// =============================================================================
577
578/// Combined callback trait for all execution events
579///
580/// Implement this trait to handle all callback types with a single implementation.
581/// This is useful for telemetry backends that need to track all execution events.
582pub trait ExecutionCallbacks:
583    BeforeAgentCallback
584    + AfterAgentCallback
585    + BeforeModelCallback
586    + AfterModelCallback
587    + BeforeToolCallback
588    + AfterToolCallback
589{
590}
591
592// Blanket implementation for any type that implements all callback traits
593impl<T> ExecutionCallbacks for T where
594    T: BeforeAgentCallback
595        + AfterAgentCallback
596        + BeforeModelCallback
597        + AfterModelCallback
598        + BeforeToolCallback
599        + AfterToolCallback
600{
601}
602
603// =============================================================================
604// Callback Registry
605// =============================================================================
606
607/// Registry for callback handlers
608///
609/// Allows registering multiple callbacks that will all be invoked.
610#[derive(Default)]
611pub struct CallbackRegistry {
612    before_agent: Vec<Arc<dyn BeforeAgentCallback>>,
613    after_agent: Vec<Arc<dyn AfterAgentCallback>>,
614    before_model: Vec<Arc<dyn BeforeModelCallback>>,
615    after_model: Vec<Arc<dyn AfterModelCallback>>,
616    before_tool: Vec<Arc<dyn BeforeToolCallback>>,
617    after_tool: Vec<Arc<dyn AfterToolCallback>>,
618}
619
620impl CallbackRegistry {
621    /// Create a new empty registry
622    pub fn new() -> Self {
623        Self::default()
624    }
625
626    /// Register a before-agent callback
627    pub fn on_before_agent<C: BeforeAgentCallback + 'static>(&mut self, callback: C) -> &mut Self {
628        self.before_agent.push(Arc::new(callback));
629        self
630    }
631
632    /// Register an after-agent callback
633    pub fn on_after_agent<C: AfterAgentCallback + 'static>(&mut self, callback: C) -> &mut Self {
634        self.after_agent.push(Arc::new(callback));
635        self
636    }
637
638    /// Register a before-model callback
639    pub fn on_before_model<C: BeforeModelCallback + 'static>(&mut self, callback: C) -> &mut Self {
640        self.before_model.push(Arc::new(callback));
641        self
642    }
643
644    /// Register an after-model callback
645    pub fn on_after_model<C: AfterModelCallback + 'static>(&mut self, callback: C) -> &mut Self {
646        self.after_model.push(Arc::new(callback));
647        self
648    }
649
650    /// Register a before-tool callback
651    pub fn on_before_tool<C: BeforeToolCallback + 'static>(&mut self, callback: C) -> &mut Self {
652        self.before_tool.push(Arc::new(callback));
653        self
654    }
655
656    /// Register an after-tool callback
657    pub fn on_after_tool<C: AfterToolCallback + 'static>(&mut self, callback: C) -> &mut Self {
658        self.after_tool.push(Arc::new(callback));
659        self
660    }
661
662    /// Register a combined callback handler for all events
663    pub fn register_all<C>(&mut self, callback: Arc<C>) -> &mut Self
664    where
665        C: ExecutionCallbacks + 'static,
666    {
667        self.before_agent.push(callback.clone());
668        self.after_agent.push(callback.clone());
669        self.before_model.push(callback.clone());
670        self.after_model.push(callback.clone());
671        self.before_tool.push(callback.clone());
672        self.after_tool.push(callback);
673        self
674    }
675
676    /// Invoke all before-agent callbacks
677    pub fn invoke_before_agent(&self, ctx: &AgentCallbackContext) {
678        for callback in &self.before_agent {
679            callback.on_before_agent(ctx);
680        }
681    }
682
683    /// Invoke all after-agent callbacks
684    pub fn invoke_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult) {
685        for callback in &self.after_agent {
686            callback.on_after_agent(ctx, result);
687        }
688    }
689
690    /// Invoke all before-model callbacks
691    pub fn invoke_before_model(&self, ctx: &ModelCallbackContext) {
692        for callback in &self.before_model {
693            callback.on_before_model(ctx);
694        }
695    }
696
697    /// Invoke all after-model callbacks
698    pub fn invoke_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult) {
699        for callback in &self.after_model {
700            callback.on_after_model(ctx, result);
701        }
702    }
703
704    /// Invoke all before-tool callbacks
705    pub fn invoke_before_tool(&self, ctx: &ToolCallbackContext) {
706        for callback in &self.before_tool {
707            callback.on_before_tool(ctx);
708        }
709    }
710
711    /// Invoke all after-tool callbacks
712    pub fn invoke_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult) {
713        for callback in &self.after_tool {
714            callback.on_after_tool(ctx, result);
715        }
716    }
717}
718
719// =============================================================================
720// No-op Implementation
721// =============================================================================
722
723/// A no-op callback implementation that does nothing
724///
725/// Useful as a default or for testing.
726#[derive(Debug, Clone, Copy, Default)]
727pub struct NoOpCallbacks;
728
729impl BeforeAgentCallback for NoOpCallbacks {
730    fn on_before_agent(&self, _ctx: &AgentCallbackContext) {}
731}
732
733impl AfterAgentCallback for NoOpCallbacks {
734    fn on_after_agent(&self, _ctx: &AgentCallbackContext, _result: &AgentCallbackResult) {}
735}
736
737impl BeforeModelCallback for NoOpCallbacks {
738    fn on_before_model(&self, _ctx: &ModelCallbackContext) {}
739}
740
741impl AfterModelCallback for NoOpCallbacks {
742    fn on_after_model(&self, _ctx: &ModelCallbackContext, _result: &ModelCallbackResult) {}
743}
744
745impl BeforeToolCallback for NoOpCallbacks {
746    fn on_before_tool(&self, _ctx: &ToolCallbackContext) {}
747}
748
749impl AfterToolCallback for NoOpCallbacks {
750    fn on_after_tool(&self, _ctx: &ToolCallbackContext, _result: &ToolCallbackResult) {}
751}
752
753// =============================================================================
754// Tests
755// =============================================================================
756
757#[cfg(test)]
758mod tests {
759    use super::*;
760    use std::sync::atomic::{AtomicU32, Ordering};
761    use std::time::Duration;
762
763    // =========================================================================
764    // AgentCallbackContext Tests
765    // =========================================================================
766
767    #[test]
768    fn test_agent_callback_context_new() {
769        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent");
770        assert_eq!(ctx.agent_name, "test_agent");
771        assert!(ctx.step_id.is_none());
772        assert!(ctx.agent_description.is_none());
773        assert!(ctx.input_preview.is_none());
774    }
775
776    #[test]
777    fn test_agent_callback_context_with_step() {
778        let step_id = StepId::new();
779        let ctx =
780            AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_step(step_id.clone());
781        assert!(ctx.step_id.is_some());
782        assert_eq!(ctx.step_id.unwrap().as_str(), step_id.as_str());
783    }
784
785    #[test]
786    fn test_agent_callback_context_with_description() {
787        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
788            .with_description("A test agent");
789        assert_eq!(ctx.agent_description, Some("A test agent".to_string()));
790    }
791
792    #[test]
793    fn test_agent_callback_context_input_preview_truncation() {
794        let long_input = "x".repeat(1000);
795        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
796            .with_input_preview(&long_input);
797        let preview = ctx.input_preview.unwrap();
798        assert!(preview.len() <= 500);
799        assert!(preview.ends_with("..."));
800    }
801
802    #[test]
803    fn test_agent_callback_context_short_input_not_truncated() {
804        let short_input = "hello world";
805        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
806            .with_input_preview(short_input);
807        assert_eq!(ctx.input_preview, Some("hello world".to_string()));
808    }
809
810    #[test]
811    fn test_agent_callback_context_with_tenant() {
812        let ctx =
813            AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_tenant("tenant_123");
814        assert_eq!(ctx.tenant_id, Some("tenant_123".to_string()));
815    }
816
817    #[test]
818    fn test_agent_callback_context_with_user() {
819        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_user("user_456");
820        assert_eq!(ctx.user_id, Some("user_456".to_string()));
821    }
822
823    #[test]
824    fn test_agent_callback_context_with_trace() {
825        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
826            .with_trace("trace_abc", Some("span_xyz".to_string()));
827        assert_eq!(ctx.trace_id, Some("trace_abc".to_string()));
828        assert_eq!(ctx.parent_span_id, Some("span_xyz".to_string()));
829    }
830
831    #[test]
832    fn test_agent_callback_context_builder_chain() {
833        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
834            .with_step(StepId::new())
835            .with_description("Description")
836            .with_input_preview("Input")
837            .with_tenant("tenant")
838            .with_user("user")
839            .with_trace("trace", None);
840
841        assert!(ctx.step_id.is_some());
842        assert!(ctx.agent_description.is_some());
843        assert!(ctx.input_preview.is_some());
844        assert!(ctx.tenant_id.is_some());
845        assert!(ctx.user_id.is_some());
846        assert!(ctx.trace_id.is_some());
847    }
848
849    #[test]
850    fn test_agent_callback_context_serde() {
851        let ctx = AgentCallbackContext::new(ExecutionId::from_string("exec_test"), "test_agent")
852            .with_description("Test description");
853        let json = serde_json::to_string(&ctx).unwrap();
854        let parsed: AgentCallbackContext = serde_json::from_str(&json).unwrap();
855        assert_eq!(ctx.agent_name, parsed.agent_name);
856        assert_eq!(ctx.agent_description, parsed.agent_description);
857    }
858
859    // =========================================================================
860    // AgentCallbackResult Tests
861    // =========================================================================
862
863    #[test]
864    fn test_agent_callback_result_success() {
865        let result = AgentCallbackResult::success(Duration::from_millis(100));
866        assert!(result.success);
867        assert_eq!(result.duration, Duration::from_millis(100));
868        assert!(result.error.is_none());
869    }
870
871    #[test]
872    fn test_agent_callback_result_failure() {
873        let result =
874            AgentCallbackResult::failure(Duration::from_millis(50), "Something went wrong");
875        assert!(!result.success);
876        assert_eq!(result.error, Some("Something went wrong".to_string()));
877    }
878
879    #[test]
880    fn test_agent_callback_result_with_output_preview() {
881        let result = AgentCallbackResult::success(Duration::from_millis(100))
882            .with_output_preview("Output here");
883        assert_eq!(result.output_preview, Some("Output here".to_string()));
884    }
885
886    #[test]
887    fn test_agent_callback_result_output_truncation() {
888        let long_output = "y".repeat(1000);
889        let result = AgentCallbackResult::success(Duration::from_millis(100))
890            .with_output_preview(&long_output);
891        let preview = result.output_preview.unwrap();
892        assert!(preview.len() <= 500);
893        assert!(preview.ends_with("..."));
894    }
895
896    #[test]
897    fn test_agent_callback_result_with_stats() {
898        let result = AgentCallbackResult::success(Duration::from_millis(100)).with_stats(5, 3, 2);
899        assert_eq!(result.steps_executed, Some(5));
900        assert_eq!(result.tool_calls, Some(3));
901        assert_eq!(result.model_calls, Some(2));
902    }
903
904    // =========================================================================
905    // ModelCallbackContext Tests
906    // =========================================================================
907
908    #[test]
909    fn test_model_callback_context_new() {
910        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
911        assert_eq!(ctx.provider, "openai");
912        assert_eq!(ctx.model, "gpt-4");
913        assert!(ctx.step_id.is_none());
914    }
915
916    #[test]
917    fn test_model_callback_context_with_params() {
918        let ctx = ModelCallbackContext::new(ExecutionId::new(), "anthropic", "claude-3-opus")
919            .with_params(Some(0.7), Some(4096));
920        assert_eq!(ctx.temperature, Some(0.7));
921        assert_eq!(ctx.max_tokens, Some(4096));
922    }
923
924    #[test]
925    fn test_model_callback_context_with_request_info() {
926        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
927            .with_request_info(5, true, true);
928        assert_eq!(ctx.message_count, 5);
929        assert!(ctx.streaming);
930        assert!(ctx.tools_enabled);
931    }
932
933    #[test]
934    fn test_model_callback_context_with_step() {
935        let step_id = StepId::new();
936        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
937            .with_step(step_id.clone());
938        assert!(ctx.step_id.is_some());
939    }
940
941    #[test]
942    fn test_model_callback_context_serde() {
943        let ctx =
944            ModelCallbackContext::new(ExecutionId::from_string("exec_test"), "openai", "gpt-4")
945                .with_params(Some(0.5), Some(1000));
946        let json = serde_json::to_string(&ctx).unwrap();
947        let parsed: ModelCallbackContext = serde_json::from_str(&json).unwrap();
948        assert_eq!(ctx.provider, parsed.provider);
949        assert_eq!(ctx.model, parsed.model);
950        assert_eq!(ctx.temperature, parsed.temperature);
951    }
952
953    // =========================================================================
954    // ModelCallbackResult Tests
955    // =========================================================================
956
957    #[test]
958    fn test_model_callback_result_success() {
959        let result = ModelCallbackResult::success(Duration::from_millis(500));
960        assert!(result.success);
961        assert!(result.error.is_none());
962    }
963
964    #[test]
965    fn test_model_callback_result_failure() {
966        let result =
967            ModelCallbackResult::failure(Duration::from_millis(100), "Rate limit exceeded");
968        assert!(!result.success);
969        assert_eq!(result.error, Some("Rate limit exceeded".to_string()));
970    }
971
972    #[test]
973    fn test_model_callback_result_with_tokens() {
974        let result =
975            ModelCallbackResult::success(Duration::from_millis(500)).with_tokens(1000, 500);
976        assert_eq!(result.input_tokens, Some(1000));
977        assert_eq!(result.output_tokens, Some(500));
978        assert_eq!(result.total_tokens, Some(1500));
979    }
980
981    #[test]
982    fn test_model_callback_result_with_finish_info() {
983        let result = ModelCallbackResult::success(Duration::from_millis(500))
984            .with_finish_info("tool_use", 2);
985        assert_eq!(result.finish_reason, Some("tool_use".to_string()));
986        assert_eq!(result.tool_calls_count, Some(2));
987    }
988
989    #[test]
990    fn test_model_callback_result_cached() {
991        let result = ModelCallbackResult::success(Duration::from_millis(10)).cached();
992        assert!(result.cached);
993    }
994
995    // =========================================================================
996    // ToolCallbackContext Tests
997    // =========================================================================
998
999    #[test]
1000    fn test_tool_callback_context_new() {
1001        let ctx = ToolCallbackContext::new(ExecutionId::new(), "read_file");
1002        assert_eq!(ctx.tool_name, "read_file");
1003        assert!(ctx.requires_network); // Default is true
1004    }
1005
1006    #[test]
1007    fn test_tool_callback_context_with_tool_info() {
1008        let ctx = ToolCallbackContext::new(ExecutionId::new(), "calculator")
1009            .with_tool_info(Some("Performs calculations".to_string()), false);
1010        assert_eq!(
1011            ctx.tool_description,
1012            Some("Performs calculations".to_string())
1013        );
1014        assert!(!ctx.requires_network);
1015    }
1016
1017    #[test]
1018    fn test_tool_callback_context_with_args_preview() {
1019        let ctx = ToolCallbackContext::new(ExecutionId::new(), "search")
1020            .with_args_preview(r#"{"query": "rust programming"}"#);
1021        assert!(ctx.args_preview.is_some());
1022    }
1023
1024    #[test]
1025    fn test_tool_callback_context_args_truncation() {
1026        let long_args = "z".repeat(1000);
1027        let ctx =
1028            ToolCallbackContext::new(ExecutionId::new(), "tool").with_args_preview(&long_args);
1029        let preview = ctx.args_preview.unwrap();
1030        assert!(preview.len() <= 500);
1031        assert!(preview.ends_with("..."));
1032    }
1033
1034    #[test]
1035    fn test_tool_callback_context_serde() {
1036        let ctx = ToolCallbackContext::new(ExecutionId::from_string("exec_test"), "my_tool")
1037            .with_tool_info(Some("A tool".to_string()), false);
1038        let json = serde_json::to_string(&ctx).unwrap();
1039        let parsed: ToolCallbackContext = serde_json::from_str(&json).unwrap();
1040        assert_eq!(ctx.tool_name, parsed.tool_name);
1041        assert_eq!(ctx.tool_description, parsed.tool_description);
1042    }
1043
1044    // =========================================================================
1045    // ToolCallbackResult Tests
1046    // =========================================================================
1047
1048    #[test]
1049    fn test_tool_callback_result_success() {
1050        let result = ToolCallbackResult::success(Duration::from_millis(50));
1051        assert!(result.success);
1052        assert!(!result.blocked);
1053    }
1054
1055    #[test]
1056    fn test_tool_callback_result_failure() {
1057        let result = ToolCallbackResult::failure(Duration::from_millis(20), "File not found");
1058        assert!(!result.success);
1059        assert_eq!(result.error, Some("File not found".to_string()));
1060    }
1061
1062    #[test]
1063    fn test_tool_callback_result_blocked() {
1064        let result =
1065            ToolCallbackResult::blocked(Duration::from_millis(5), "Tool disabled by policy");
1066        assert!(!result.success);
1067        assert!(result.blocked);
1068        assert_eq!(
1069            result.blocked_reason,
1070            Some("Tool disabled by policy".to_string())
1071        );
1072    }
1073
1074    #[test]
1075    fn test_tool_callback_result_with_output_preview() {
1076        let result = ToolCallbackResult::success(Duration::from_millis(50))
1077            .with_output_preview("Result: 42");
1078        assert_eq!(result.output_preview, Some("Result: 42".to_string()));
1079    }
1080
1081    // =========================================================================
1082    // NoOpCallbacks Tests
1083    // =========================================================================
1084
1085    #[test]
1086    fn test_noop_callbacks_compiles() {
1087        let callbacks = NoOpCallbacks;
1088        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
1089        let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
1090        let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
1091        let model_result = ModelCallbackResult::success(Duration::from_millis(500));
1092        let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
1093        let tool_result = ToolCallbackResult::success(Duration::from_millis(50));
1094
1095        // These should compile and not panic
1096        callbacks.on_before_agent(&agent_ctx);
1097        callbacks.on_after_agent(&agent_ctx, &agent_result);
1098        callbacks.on_before_model(&model_ctx);
1099        callbacks.on_after_model(&model_ctx, &model_result);
1100        callbacks.on_before_tool(&tool_ctx);
1101        callbacks.on_after_tool(&tool_ctx, &tool_result);
1102    }
1103
1104    // =========================================================================
1105    // CallbackRegistry Tests
1106    // =========================================================================
1107
1108    struct CountingCallback {
1109        before_agent_count: AtomicU32,
1110        after_agent_count: AtomicU32,
1111        before_model_count: AtomicU32,
1112        after_model_count: AtomicU32,
1113        before_tool_count: AtomicU32,
1114        after_tool_count: AtomicU32,
1115    }
1116
1117    impl CountingCallback {
1118        fn new() -> Self {
1119            Self {
1120                before_agent_count: AtomicU32::new(0),
1121                after_agent_count: AtomicU32::new(0),
1122                before_model_count: AtomicU32::new(0),
1123                after_model_count: AtomicU32::new(0),
1124                before_tool_count: AtomicU32::new(0),
1125                after_tool_count: AtomicU32::new(0),
1126            }
1127        }
1128    }
1129
1130    impl BeforeAgentCallback for CountingCallback {
1131        fn on_before_agent(&self, _ctx: &AgentCallbackContext) {
1132            self.before_agent_count.fetch_add(1, Ordering::SeqCst);
1133        }
1134    }
1135
1136    impl AfterAgentCallback for CountingCallback {
1137        fn on_after_agent(&self, _ctx: &AgentCallbackContext, _result: &AgentCallbackResult) {
1138            self.after_agent_count.fetch_add(1, Ordering::SeqCst);
1139        }
1140    }
1141
1142    impl BeforeModelCallback for CountingCallback {
1143        fn on_before_model(&self, _ctx: &ModelCallbackContext) {
1144            self.before_model_count.fetch_add(1, Ordering::SeqCst);
1145        }
1146    }
1147
1148    impl AfterModelCallback for CountingCallback {
1149        fn on_after_model(&self, _ctx: &ModelCallbackContext, _result: &ModelCallbackResult) {
1150            self.after_model_count.fetch_add(1, Ordering::SeqCst);
1151        }
1152    }
1153
1154    impl BeforeToolCallback for CountingCallback {
1155        fn on_before_tool(&self, _ctx: &ToolCallbackContext) {
1156            self.before_tool_count.fetch_add(1, Ordering::SeqCst);
1157        }
1158    }
1159
1160    impl AfterToolCallback for CountingCallback {
1161        fn on_after_tool(&self, _ctx: &ToolCallbackContext, _result: &ToolCallbackResult) {
1162            self.after_tool_count.fetch_add(1, Ordering::SeqCst);
1163        }
1164    }
1165
1166    #[test]
1167    fn test_callback_registry_new() {
1168        let registry = CallbackRegistry::new();
1169        // Should not panic
1170        registry.invoke_before_agent(&AgentCallbackContext::new(ExecutionId::new(), "test"));
1171    }
1172
1173    #[test]
1174    fn test_callback_registry_register_all() {
1175        let callback = Arc::new(CountingCallback::new());
1176        let mut registry = CallbackRegistry::new();
1177        registry.register_all(callback.clone());
1178
1179        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
1180        let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
1181        let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
1182        let model_result = ModelCallbackResult::success(Duration::from_millis(500));
1183        let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
1184        let tool_result = ToolCallbackResult::success(Duration::from_millis(50));
1185
1186        registry.invoke_before_agent(&agent_ctx);
1187        registry.invoke_after_agent(&agent_ctx, &agent_result);
1188        registry.invoke_before_model(&model_ctx);
1189        registry.invoke_after_model(&model_ctx, &model_result);
1190        registry.invoke_before_tool(&tool_ctx);
1191        registry.invoke_after_tool(&tool_ctx, &tool_result);
1192
1193        assert_eq!(callback.before_agent_count.load(Ordering::SeqCst), 1);
1194        assert_eq!(callback.after_agent_count.load(Ordering::SeqCst), 1);
1195        assert_eq!(callback.before_model_count.load(Ordering::SeqCst), 1);
1196        assert_eq!(callback.after_model_count.load(Ordering::SeqCst), 1);
1197        assert_eq!(callback.before_tool_count.load(Ordering::SeqCst), 1);
1198        assert_eq!(callback.after_tool_count.load(Ordering::SeqCst), 1);
1199    }
1200
1201    #[test]
1202    fn test_callback_registry_multiple_callbacks() {
1203        let callback1 = Arc::new(CountingCallback::new());
1204        let callback2 = Arc::new(CountingCallback::new());
1205        let mut registry = CallbackRegistry::new();
1206        registry.register_all(callback1.clone());
1207        registry.register_all(callback2.clone());
1208
1209        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
1210        registry.invoke_before_agent(&agent_ctx);
1211
1212        // Both callbacks should have been invoked
1213        assert_eq!(callback1.before_agent_count.load(Ordering::SeqCst), 1);
1214        assert_eq!(callback2.before_agent_count.load(Ordering::SeqCst), 1);
1215    }
1216
1217    #[test]
1218    fn test_callback_registry_individual_registration() {
1219        struct SimpleBeforeAgent;
1220        impl BeforeAgentCallback for SimpleBeforeAgent {
1221            fn on_before_agent(&self, _ctx: &AgentCallbackContext) {}
1222        }
1223
1224        let mut registry = CallbackRegistry::new();
1225        registry.on_before_agent(SimpleBeforeAgent);
1226
1227        // Should not panic
1228        registry.invoke_before_agent(&AgentCallbackContext::new(ExecutionId::new(), "test"));
1229    }
1230
1231    // =========================================================================
1232    // Combined Trait Tests
1233    // =========================================================================
1234
1235    #[test]
1236    fn test_execution_callbacks_trait() {
1237        // NoOpCallbacks should implement ExecutionCallbacks
1238        fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
1239        let noop = NoOpCallbacks;
1240        accept_execution_callbacks(&noop);
1241    }
1242
1243    #[test]
1244    fn test_counting_callback_implements_execution_callbacks() {
1245        // CountingCallback should implement ExecutionCallbacks
1246        fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
1247        let counting = CountingCallback::new();
1248        accept_execution_callbacks(&counting);
1249    }
1250}