1use crate::kernel::{ExecutionId, StepId};
52use serde::{Deserialize, Serialize};
53use std::sync::Arc;
54use std::time::Duration;
55
/// Context handed to agent-level callbacks (`BeforeAgentCallback` / `AfterAgentCallback`).
///
/// Built with [`AgentCallbackContext::new`] plus the `with_*` builder methods; every
/// field except `execution_id` and `agent_name` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentCallbackContext {
    /// Identifier of the execution this context refers to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was set (see `with_step`).
    pub step_id: Option<StepId>,
    /// Name of the agent.
    pub agent_name: String,
    /// Optional human-readable description of the agent (see `with_description`).
    pub agent_description: Option<String>,
    /// Optional preview of the agent's input; `with_input_preview` caps it at 500 bytes.
    pub input_preview: Option<String>,
    /// Optional tenant identifier (see `with_tenant`).
    pub tenant_id: Option<String>,
    /// Optional user identifier (see `with_user`).
    pub user_id: Option<String>,
    /// Optional trace identifier (see `with_trace`).
    pub trace_id: Option<String>,
    /// Optional parent span identifier, set together with `trace_id` by `with_trace`.
    pub parent_span_id: Option<String>,
}
84
85impl AgentCallbackContext {
86 pub fn new(execution_id: ExecutionId, agent_name: impl Into<String>) -> Self {
88 Self {
89 execution_id,
90 step_id: None,
91 agent_name: agent_name.into(),
92 agent_description: None,
93 input_preview: None,
94 tenant_id: None,
95 user_id: None,
96 trace_id: None,
97 parent_span_id: None,
98 }
99 }
100
101 pub fn with_step(mut self, step_id: StepId) -> Self {
103 self.step_id = Some(step_id);
104 self
105 }
106
107 pub fn with_description(mut self, description: impl Into<String>) -> Self {
109 self.agent_description = Some(description.into());
110 self
111 }
112
113 pub fn with_input_preview(mut self, input: impl Into<String>) -> Self {
115 let input = input.into();
116 self.input_preview = Some(if input.len() > 500 {
118 format!("{}...", &input[..497])
119 } else {
120 input
121 });
122 self
123 }
124
125 pub fn with_tenant(mut self, tenant_id: impl Into<String>) -> Self {
127 self.tenant_id = Some(tenant_id.into());
128 self
129 }
130
131 pub fn with_user(mut self, user_id: impl Into<String>) -> Self {
133 self.user_id = Some(user_id.into());
134 self
135 }
136
137 pub fn with_trace(
139 mut self,
140 trace_id: impl Into<String>,
141 parent_span_id: Option<String>,
142 ) -> Self {
143 self.trace_id = Some(trace_id.into());
144 self.parent_span_id = parent_span_id;
145 self
146 }
147}
148
/// Outcome of an agent invocation, handed to `AfterAgentCallback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentCallbackResult {
    /// `true` when built with `success`, `false` when built with `failure`.
    pub success: bool,
    /// How long the agent invocation took.
    pub duration: Duration,
    /// Optional preview of the output; `with_output_preview` caps it at 500 bytes.
    pub output_preview: Option<String>,
    /// Error message, set by `failure`.
    pub error: Option<String>,
    /// Number of steps executed, if recorded via `with_stats`.
    pub steps_executed: Option<u32>,
    /// Number of tool calls, if recorded via `with_stats`.
    pub tool_calls: Option<u32>,
    /// Number of model calls, if recorded via `with_stats`.
    pub model_calls: Option<u32>,
}
167
168impl AgentCallbackResult {
169 pub fn success(duration: Duration) -> Self {
171 Self {
172 success: true,
173 duration,
174 output_preview: None,
175 error: None,
176 steps_executed: None,
177 tool_calls: None,
178 model_calls: None,
179 }
180 }
181
182 pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
184 Self {
185 success: false,
186 duration,
187 output_preview: None,
188 error: Some(error.into()),
189 steps_executed: None,
190 tool_calls: None,
191 model_calls: None,
192 }
193 }
194
195 pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
197 let output = output.into();
198 self.output_preview = Some(if output.len() > 500 {
199 format!("{}...", &output[..497])
200 } else {
201 output
202 });
203 self
204 }
205
206 pub fn with_stats(mut self, steps: u32, tool_calls: u32, model_calls: u32) -> Self {
208 self.steps_executed = Some(steps);
209 self.tool_calls = Some(tool_calls);
210 self.model_calls = Some(model_calls);
211 self
212 }
213}
214
/// Context handed to model-level callbacks (`BeforeModelCallback` / `AfterModelCallback`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCallbackContext {
    /// Identifier of the execution this context refers to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was set (see `with_step`).
    pub step_id: Option<StepId>,
    /// Provider name as passed to `new` (e.g. "openai").
    pub provider: String,
    /// Model name as passed to `new`.
    pub model: String,
    /// Temperature parameter, if set via `with_params`.
    pub temperature: Option<f32>,
    /// Max-tokens parameter, if set via `with_params`.
    pub max_tokens: Option<u32>,
    /// Number of messages in the request; `new` initializes it to 0.
    pub message_count: usize,
    /// Whether the request is streaming; `new` initializes it to `false`.
    pub streaming: bool,
    /// Whether tools are enabled for the request; `new` initializes it to `false`.
    pub tools_enabled: bool,
    /// Optional trace identifier (see `with_trace`).
    pub trace_id: Option<String>,
}
241
242impl ModelCallbackContext {
243 pub fn new(
245 execution_id: ExecutionId,
246 provider: impl Into<String>,
247 model: impl Into<String>,
248 ) -> Self {
249 Self {
250 execution_id,
251 step_id: None,
252 provider: provider.into(),
253 model: model.into(),
254 temperature: None,
255 max_tokens: None,
256 message_count: 0,
257 streaming: false,
258 tools_enabled: false,
259 trace_id: None,
260 }
261 }
262
263 pub fn with_step(mut self, step_id: StepId) -> Self {
265 self.step_id = Some(step_id);
266 self
267 }
268
269 pub fn with_params(mut self, temperature: Option<f32>, max_tokens: Option<u32>) -> Self {
271 self.temperature = temperature;
272 self.max_tokens = max_tokens;
273 self
274 }
275
276 pub fn with_request_info(
278 mut self,
279 message_count: usize,
280 streaming: bool,
281 tools_enabled: bool,
282 ) -> Self {
283 self.message_count = message_count;
284 self.streaming = streaming;
285 self.tools_enabled = tools_enabled;
286 self
287 }
288
289 pub fn with_trace(mut self, trace_id: impl Into<String>) -> Self {
291 self.trace_id = Some(trace_id.into());
292 self
293 }
294}
295
/// Outcome of a model call, handed to `AfterModelCallback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelCallbackResult {
    /// `true` when built with `success`, `false` when built with `failure`.
    pub success: bool,
    /// How long the model call took.
    pub duration: Duration,
    /// Input token count, if recorded via `with_tokens`.
    pub input_tokens: Option<u32>,
    /// Output token count, if recorded via `with_tokens`.
    pub output_tokens: Option<u32>,
    /// Sum of input and output tokens, computed by `with_tokens`.
    pub total_tokens: Option<u32>,
    /// Finish reason string, if recorded via `with_finish_info`.
    pub finish_reason: Option<String>,
    /// Tool-call count, if recorded via `with_finish_info`.
    pub tool_calls_count: Option<u32>,
    /// Error message, set by `failure`.
    pub error: Option<String>,
    /// Flag set by `cached`; `false` by default.
    pub cached: bool,
}
318
319impl ModelCallbackResult {
320 pub fn success(duration: Duration) -> Self {
322 Self {
323 success: true,
324 duration,
325 input_tokens: None,
326 output_tokens: None,
327 total_tokens: None,
328 finish_reason: None,
329 tool_calls_count: None,
330 error: None,
331 cached: false,
332 }
333 }
334
335 pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
337 Self {
338 success: false,
339 duration,
340 input_tokens: None,
341 output_tokens: None,
342 total_tokens: None,
343 finish_reason: None,
344 tool_calls_count: None,
345 error: Some(error.into()),
346 cached: false,
347 }
348 }
349
350 pub fn with_tokens(mut self, input: u32, output: u32) -> Self {
352 self.input_tokens = Some(input);
353 self.output_tokens = Some(output);
354 self.total_tokens = Some(input + output);
355 self
356 }
357
358 pub fn with_finish_info(mut self, reason: impl Into<String>, tool_calls: u32) -> Self {
360 self.finish_reason = Some(reason.into());
361 self.tool_calls_count = Some(tool_calls);
362 self
363 }
364
365 pub fn cached(mut self) -> Self {
367 self.cached = true;
368 self
369 }
370}
371
/// Context handed to tool-level callbacks (`BeforeToolCallback` / `AfterToolCallback`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallbackContext {
    /// Identifier of the execution this context refers to.
    pub execution_id: ExecutionId,
    /// Step within the execution, if one was set (see `with_step`).
    pub step_id: Option<StepId>,
    /// Name of the tool.
    pub tool_name: String,
    /// Optional tool description (see `with_tool_info`).
    pub tool_description: Option<String>,
    /// Network-requirement flag; `new` initializes it to `true`.
    pub requires_network: bool,
    /// Optional preview of the tool arguments; `with_args_preview` caps it at 500 bytes.
    pub args_preview: Option<String>,
    /// Optional trace identifier (see `with_trace`).
    pub trace_id: Option<String>,
}
392
393impl ToolCallbackContext {
394 pub fn new(execution_id: ExecutionId, tool_name: impl Into<String>) -> Self {
396 Self {
397 execution_id,
398 step_id: None,
399 tool_name: tool_name.into(),
400 tool_description: None,
401 requires_network: true,
402 args_preview: None,
403 trace_id: None,
404 }
405 }
406
407 pub fn with_step(mut self, step_id: StepId) -> Self {
409 self.step_id = Some(step_id);
410 self
411 }
412
413 pub fn with_tool_info(mut self, description: Option<String>, requires_network: bool) -> Self {
415 self.tool_description = description;
416 self.requires_network = requires_network;
417 self
418 }
419
420 pub fn with_args_preview(mut self, args: impl Into<String>) -> Self {
422 let args = args.into();
423 self.args_preview = Some(if args.len() > 500 {
424 format!("{}...", &args[..497])
425 } else {
426 args
427 });
428 self
429 }
430
431 pub fn with_trace(mut self, trace_id: impl Into<String>) -> Self {
433 self.trace_id = Some(trace_id.into());
434 self
435 }
436}
437
/// Outcome of a tool call, handed to `AfterToolCallback`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallbackResult {
    /// `true` only when built with `success`; `failure` and `blocked` set it to `false`.
    pub success: bool,
    /// How long the tool call took.
    pub duration: Duration,
    /// Optional preview of the output; `with_output_preview` caps it at 500 bytes.
    pub output_preview: Option<String>,
    /// Error message, set by `failure` (left `None` by `blocked`).
    pub error: Option<String>,
    /// `true` when built with `blocked`.
    pub blocked: bool,
    /// Reason passed to `blocked`.
    pub blocked_reason: Option<String>,
}
454
455impl ToolCallbackResult {
456 pub fn success(duration: Duration) -> Self {
458 Self {
459 success: true,
460 duration,
461 output_preview: None,
462 error: None,
463 blocked: false,
464 blocked_reason: None,
465 }
466 }
467
468 pub fn failure(duration: Duration, error: impl Into<String>) -> Self {
470 Self {
471 success: false,
472 duration,
473 output_preview: None,
474 error: Some(error.into()),
475 blocked: false,
476 blocked_reason: None,
477 }
478 }
479
480 pub fn blocked(duration: Duration, reason: impl Into<String>) -> Self {
482 Self {
483 success: false,
484 duration,
485 output_preview: None,
486 error: None,
487 blocked: true,
488 blocked_reason: Some(reason.into()),
489 }
490 }
491
492 pub fn with_output_preview(mut self, output: impl Into<String>) -> Self {
494 let output = output.into();
495 self.output_preview = Some(if output.len() > 500 {
496 format!("{}...", &output[..497])
497 } else {
498 output
499 });
500 self
501 }
502}
503
/// Hook intended to run before an agent starts.
///
/// Registered via [`CallbackRegistry::on_before_agent`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_before_agent`]. The `Send + Sync` supertraits let a
/// registry holding these callbacks be shared between threads.
pub trait BeforeAgentCallback: Send + Sync {
    /// Receives the context of the agent invocation.
    fn on_before_agent(&self, ctx: &AgentCallbackContext);
}
518
/// Hook intended to run after an agent finishes.
///
/// Registered via [`CallbackRegistry::on_after_agent`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_after_agent`].
pub trait AfterAgentCallback: Send + Sync {
    /// Receives the invocation context together with its outcome.
    fn on_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult);
}
529
/// Hook intended to run before a model call.
///
/// Registered via [`CallbackRegistry::on_before_model`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_before_model`].
pub trait BeforeModelCallback: Send + Sync {
    /// Receives the context of the model call.
    fn on_before_model(&self, ctx: &ModelCallbackContext);
}
540
/// Hook intended to run after a model call completes.
///
/// Registered via [`CallbackRegistry::on_after_model`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_after_model`].
pub trait AfterModelCallback: Send + Sync {
    /// Receives the model-call context together with its outcome.
    fn on_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult);
}
551
/// Hook intended to run before a tool call.
///
/// Registered via [`CallbackRegistry::on_before_tool`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_before_tool`].
pub trait BeforeToolCallback: Send + Sync {
    /// Receives the context of the tool call.
    fn on_before_tool(&self, ctx: &ToolCallbackContext);
}
562
/// Hook intended to run after a tool call completes (or is blocked).
///
/// Registered via [`CallbackRegistry::on_after_tool`] (or `register_all`) and invoked
/// by [`CallbackRegistry::invoke_after_tool`].
pub trait AfterToolCallback: Send + Sync {
    /// Receives the tool-call context together with its outcome.
    fn on_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult);
}
573
/// Marker trait for a type that implements all six hook traits.
///
/// It declares no methods of its own. The blanket impl below makes every type that
/// implements all six hooks an `ExecutionCallbacks` automatically; this is the bound
/// required by [`CallbackRegistry::register_all`].
pub trait ExecutionCallbacks:
    BeforeAgentCallback
    + AfterAgentCallback
    + BeforeModelCallback
    + AfterModelCallback
    + BeforeToolCallback
    + AfterToolCallback
{
}

// Blanket impl: any (sized) type implementing every hook trait is `ExecutionCallbacks`,
// so implementors never need to write this impl by hand.
impl<T> ExecutionCallbacks for T where
    T: BeforeAgentCallback
        + AfterAgentCallback
        + BeforeModelCallback
        + AfterModelCallback
        + BeforeToolCallback
        + AfterToolCallback
{
}
602
/// Collection of registered callbacks, one list per hook point.
///
/// Entries are stored as `Arc<dyn …>` so a single object can be registered for several
/// hooks (see `register_all`). Each `invoke_*` method walks its list in registration
/// order.
#[derive(Default)]
pub struct CallbackRegistry {
    /// Callbacks run by `invoke_before_agent`.
    before_agent: Vec<Arc<dyn BeforeAgentCallback>>,
    /// Callbacks run by `invoke_after_agent`.
    after_agent: Vec<Arc<dyn AfterAgentCallback>>,
    /// Callbacks run by `invoke_before_model`.
    before_model: Vec<Arc<dyn BeforeModelCallback>>,
    /// Callbacks run by `invoke_after_model`.
    after_model: Vec<Arc<dyn AfterModelCallback>>,
    /// Callbacks run by `invoke_before_tool`.
    before_tool: Vec<Arc<dyn BeforeToolCallback>>,
    /// Callbacks run by `invoke_after_tool`.
    after_tool: Vec<Arc<dyn AfterToolCallback>>,
}
619
620impl CallbackRegistry {
621 pub fn new() -> Self {
623 Self::default()
624 }
625
626 pub fn on_before_agent<C: BeforeAgentCallback + 'static>(&mut self, callback: C) -> &mut Self {
628 self.before_agent.push(Arc::new(callback));
629 self
630 }
631
632 pub fn on_after_agent<C: AfterAgentCallback + 'static>(&mut self, callback: C) -> &mut Self {
634 self.after_agent.push(Arc::new(callback));
635 self
636 }
637
638 pub fn on_before_model<C: BeforeModelCallback + 'static>(&mut self, callback: C) -> &mut Self {
640 self.before_model.push(Arc::new(callback));
641 self
642 }
643
644 pub fn on_after_model<C: AfterModelCallback + 'static>(&mut self, callback: C) -> &mut Self {
646 self.after_model.push(Arc::new(callback));
647 self
648 }
649
650 pub fn on_before_tool<C: BeforeToolCallback + 'static>(&mut self, callback: C) -> &mut Self {
652 self.before_tool.push(Arc::new(callback));
653 self
654 }
655
656 pub fn on_after_tool<C: AfterToolCallback + 'static>(&mut self, callback: C) -> &mut Self {
658 self.after_tool.push(Arc::new(callback));
659 self
660 }
661
662 pub fn register_all<C>(&mut self, callback: Arc<C>) -> &mut Self
664 where
665 C: ExecutionCallbacks + 'static,
666 {
667 self.before_agent.push(callback.clone());
668 self.after_agent.push(callback.clone());
669 self.before_model.push(callback.clone());
670 self.after_model.push(callback.clone());
671 self.before_tool.push(callback.clone());
672 self.after_tool.push(callback);
673 self
674 }
675
676 pub fn invoke_before_agent(&self, ctx: &AgentCallbackContext) {
678 for callback in &self.before_agent {
679 callback.on_before_agent(ctx);
680 }
681 }
682
683 pub fn invoke_after_agent(&self, ctx: &AgentCallbackContext, result: &AgentCallbackResult) {
685 for callback in &self.after_agent {
686 callback.on_after_agent(ctx, result);
687 }
688 }
689
690 pub fn invoke_before_model(&self, ctx: &ModelCallbackContext) {
692 for callback in &self.before_model {
693 callback.on_before_model(ctx);
694 }
695 }
696
697 pub fn invoke_after_model(&self, ctx: &ModelCallbackContext, result: &ModelCallbackResult) {
699 for callback in &self.after_model {
700 callback.on_after_model(ctx, result);
701 }
702 }
703
704 pub fn invoke_before_tool(&self, ctx: &ToolCallbackContext) {
706 for callback in &self.before_tool {
707 callback.on_before_tool(ctx);
708 }
709 }
710
711 pub fn invoke_after_tool(&self, ctx: &ToolCallbackContext, result: &ToolCallbackResult) {
713 for callback in &self.after_tool {
714 callback.on_after_tool(ctx, result);
715 }
716 }
717}
718
719#[derive(Debug, Clone, Copy, Default)]
727pub struct NoOpCallbacks;
728
729impl BeforeAgentCallback for NoOpCallbacks {
730 fn on_before_agent(&self, _ctx: &AgentCallbackContext) {}
731}
732
733impl AfterAgentCallback for NoOpCallbacks {
734 fn on_after_agent(&self, _ctx: &AgentCallbackContext, _result: &AgentCallbackResult) {}
735}
736
737impl BeforeModelCallback for NoOpCallbacks {
738 fn on_before_model(&self, _ctx: &ModelCallbackContext) {}
739}
740
741impl AfterModelCallback for NoOpCallbacks {
742 fn on_after_model(&self, _ctx: &ModelCallbackContext, _result: &ModelCallbackResult) {}
743}
744
745impl BeforeToolCallback for NoOpCallbacks {
746 fn on_before_tool(&self, _ctx: &ToolCallbackContext) {}
747}
748
749impl AfterToolCallback for NoOpCallbacks {
750 fn on_after_tool(&self, _ctx: &ToolCallbackContext, _result: &ToolCallbackResult) {}
751}
752
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU32, Ordering};
    use std::sync::Mutex;
    use std::time::Duration;

    // ---------------------------------------------------------------------
    // AgentCallbackContext
    // ---------------------------------------------------------------------

    #[test]
    fn test_agent_callback_context_new() {
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent");
        assert_eq!(ctx.agent_name, "test_agent");
        assert!(ctx.step_id.is_none());
        assert!(ctx.agent_description.is_none());
        assert!(ctx.input_preview.is_none());
        assert!(ctx.tenant_id.is_none());
        assert!(ctx.user_id.is_none());
        assert!(ctx.trace_id.is_none());
        assert!(ctx.parent_span_id.is_none());
    }

    #[test]
    fn test_agent_callback_context_with_step() {
        let step_id = StepId::new();
        let ctx =
            AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_step(step_id.clone());
        assert!(ctx.step_id.is_some());
        assert_eq!(ctx.step_id.unwrap().as_str(), step_id.as_str());
    }

    #[test]
    fn test_agent_callback_context_with_description() {
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
            .with_description("A test agent");
        assert_eq!(ctx.agent_description, Some("A test agent".to_string()));
    }

    #[test]
    fn test_agent_callback_context_input_preview_truncation() {
        let long_input = "x".repeat(1000);
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
            .with_input_preview(&long_input);
        let preview = ctx.input_preview.unwrap();
        assert!(preview.len() <= 500);
        assert!(preview.ends_with("..."));
    }

    #[test]
    fn test_agent_callback_context_short_input_not_truncated() {
        let short_input = "hello world";
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
            .with_input_preview(short_input);
        assert_eq!(ctx.input_preview, Some("hello world".to_string()));
    }

    #[test]
    fn test_agent_callback_context_with_tenant() {
        let ctx =
            AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_tenant("tenant_123");
        assert_eq!(ctx.tenant_id, Some("tenant_123".to_string()));
    }

    #[test]
    fn test_agent_callback_context_with_user() {
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent").with_user("user_456");
        assert_eq!(ctx.user_id, Some("user_456".to_string()));
    }

    #[test]
    fn test_agent_callback_context_with_trace() {
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
            .with_trace("trace_abc", Some("span_xyz".to_string()));
        assert_eq!(ctx.trace_id, Some("trace_abc".to_string()));
        assert_eq!(ctx.parent_span_id, Some("span_xyz".to_string()));
    }

    #[test]
    fn test_agent_callback_context_builder_chain() {
        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test_agent")
            .with_step(StepId::new())
            .with_description("Description")
            .with_input_preview("Input")
            .with_tenant("tenant")
            .with_user("user")
            .with_trace("trace", None);

        assert!(ctx.step_id.is_some());
        assert!(ctx.agent_description.is_some());
        assert!(ctx.input_preview.is_some());
        assert!(ctx.tenant_id.is_some());
        assert!(ctx.user_id.is_some());
        assert!(ctx.trace_id.is_some());
        assert!(ctx.parent_span_id.is_none());
    }

    #[test]
    fn test_agent_callback_context_serde() {
        let ctx = AgentCallbackContext::new(ExecutionId::from_string("exec_test"), "test_agent")
            .with_description("Test description");
        let json = serde_json::to_string(&ctx).unwrap();
        let parsed: AgentCallbackContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx.agent_name, parsed.agent_name);
        assert_eq!(ctx.agent_description, parsed.agent_description);
    }

    // ---------------------------------------------------------------------
    // AgentCallbackResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_agent_callback_result_success() {
        let result = AgentCallbackResult::success(Duration::from_millis(100));
        assert!(result.success);
        assert_eq!(result.duration, Duration::from_millis(100));
        assert!(result.error.is_none());
    }

    #[test]
    fn test_agent_callback_result_failure() {
        let result =
            AgentCallbackResult::failure(Duration::from_millis(50), "Something went wrong");
        assert!(!result.success);
        assert_eq!(result.error, Some("Something went wrong".to_string()));
    }

    #[test]
    fn test_agent_callback_result_with_output_preview() {
        let result = AgentCallbackResult::success(Duration::from_millis(100))
            .with_output_preview("Output here");
        assert_eq!(result.output_preview, Some("Output here".to_string()));
    }

    #[test]
    fn test_agent_callback_result_output_truncation() {
        let long_output = "y".repeat(1000);
        let result = AgentCallbackResult::success(Duration::from_millis(100))
            .with_output_preview(&long_output);
        let preview = result.output_preview.unwrap();
        assert!(preview.len() <= 500);
        assert!(preview.ends_with("..."));
    }

    #[test]
    fn test_agent_callback_result_with_stats() {
        let result = AgentCallbackResult::success(Duration::from_millis(100)).with_stats(5, 3, 2);
        assert_eq!(result.steps_executed, Some(5));
        assert_eq!(result.tool_calls, Some(3));
        assert_eq!(result.model_calls, Some(2));
    }

    // ---------------------------------------------------------------------
    // ModelCallbackContext
    // ---------------------------------------------------------------------

    #[test]
    fn test_model_callback_context_new() {
        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
        assert_eq!(ctx.provider, "openai");
        assert_eq!(ctx.model, "gpt-4");
        assert!(ctx.step_id.is_none());
        assert!(ctx.temperature.is_none());
        assert!(ctx.max_tokens.is_none());
        assert_eq!(ctx.message_count, 0);
        assert!(!ctx.streaming);
        assert!(!ctx.tools_enabled);
        assert!(ctx.trace_id.is_none());
    }

    #[test]
    fn test_model_callback_context_with_params() {
        let ctx = ModelCallbackContext::new(ExecutionId::new(), "anthropic", "claude-3-opus")
            .with_params(Some(0.7), Some(4096));
        assert_eq!(ctx.temperature, Some(0.7));
        assert_eq!(ctx.max_tokens, Some(4096));
    }

    #[test]
    fn test_model_callback_context_with_request_info() {
        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
            .with_request_info(5, true, true);
        assert_eq!(ctx.message_count, 5);
        assert!(ctx.streaming);
        assert!(ctx.tools_enabled);
    }

    #[test]
    fn test_model_callback_context_with_step() {
        let step_id = StepId::new();
        let ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4")
            .with_step(step_id.clone());
        assert!(ctx.step_id.is_some());
        // Previously only presence was checked; make sure it is the same step.
        assert_eq!(ctx.step_id.unwrap().as_str(), step_id.as_str());
    }

    #[test]
    fn test_model_callback_context_with_trace() {
        let ctx =
            ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4").with_trace("trace_1");
        assert_eq!(ctx.trace_id, Some("trace_1".to_string()));
    }

    #[test]
    fn test_model_callback_context_serde() {
        let ctx =
            ModelCallbackContext::new(ExecutionId::from_string("exec_test"), "openai", "gpt-4")
                .with_params(Some(0.5), Some(1000));
        let json = serde_json::to_string(&ctx).unwrap();
        let parsed: ModelCallbackContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx.provider, parsed.provider);
        assert_eq!(ctx.model, parsed.model);
        assert_eq!(ctx.temperature, parsed.temperature);
    }

    // ---------------------------------------------------------------------
    // ModelCallbackResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_model_callback_result_success() {
        let result = ModelCallbackResult::success(Duration::from_millis(500));
        assert!(result.success);
        assert!(result.error.is_none());
        assert!(!result.cached);
    }

    #[test]
    fn test_model_callback_result_failure() {
        let result =
            ModelCallbackResult::failure(Duration::from_millis(100), "Rate limit exceeded");
        assert!(!result.success);
        assert_eq!(result.error, Some("Rate limit exceeded".to_string()));
    }

    #[test]
    fn test_model_callback_result_with_tokens() {
        let result =
            ModelCallbackResult::success(Duration::from_millis(500)).with_tokens(1000, 500);
        assert_eq!(result.input_tokens, Some(1000));
        assert_eq!(result.output_tokens, Some(500));
        assert_eq!(result.total_tokens, Some(1500));
    }

    #[test]
    fn test_model_callback_result_with_finish_info() {
        let result = ModelCallbackResult::success(Duration::from_millis(500))
            .with_finish_info("tool_use", 2);
        assert_eq!(result.finish_reason, Some("tool_use".to_string()));
        assert_eq!(result.tool_calls_count, Some(2));
    }

    #[test]
    fn test_model_callback_result_cached() {
        let result = ModelCallbackResult::success(Duration::from_millis(10)).cached();
        assert!(result.cached);
    }

    // ---------------------------------------------------------------------
    // ToolCallbackContext
    // ---------------------------------------------------------------------

    #[test]
    fn test_tool_callback_context_new() {
        let ctx = ToolCallbackContext::new(ExecutionId::new(), "read_file");
        assert_eq!(ctx.tool_name, "read_file");
        // `new` defaults to requiring network access.
        assert!(ctx.requires_network);
        assert!(ctx.step_id.is_none());
        assert!(ctx.tool_description.is_none());
        assert!(ctx.args_preview.is_none());
        assert!(ctx.trace_id.is_none());
    }

    #[test]
    fn test_tool_callback_context_with_tool_info() {
        let ctx = ToolCallbackContext::new(ExecutionId::new(), "calculator")
            .with_tool_info(Some("Performs calculations".to_string()), false);
        assert_eq!(
            ctx.tool_description,
            Some("Performs calculations".to_string())
        );
        assert!(!ctx.requires_network);
    }

    #[test]
    fn test_tool_callback_context_with_step_and_trace() {
        let step_id = StepId::new();
        let ctx = ToolCallbackContext::new(ExecutionId::new(), "search")
            .with_step(step_id.clone())
            .with_trace("trace_2");
        assert_eq!(ctx.step_id.unwrap().as_str(), step_id.as_str());
        assert_eq!(ctx.trace_id, Some("trace_2".to_string()));
    }

    #[test]
    fn test_tool_callback_context_with_args_preview() {
        let ctx = ToolCallbackContext::new(ExecutionId::new(), "search")
            .with_args_preview(r#"{"query": "rust programming"}"#);
        assert_eq!(
            ctx.args_preview,
            Some(r#"{"query": "rust programming"}"#.to_string())
        );
    }

    #[test]
    fn test_tool_callback_context_args_truncation() {
        let long_args = "z".repeat(1000);
        let ctx =
            ToolCallbackContext::new(ExecutionId::new(), "tool").with_args_preview(&long_args);
        let preview = ctx.args_preview.unwrap();
        assert!(preview.len() <= 500);
        assert!(preview.ends_with("..."));
    }

    #[test]
    fn test_tool_callback_context_serde() {
        let ctx = ToolCallbackContext::new(ExecutionId::from_string("exec_test"), "my_tool")
            .with_tool_info(Some("A tool".to_string()), false);
        let json = serde_json::to_string(&ctx).unwrap();
        let parsed: ToolCallbackContext = serde_json::from_str(&json).unwrap();
        assert_eq!(ctx.tool_name, parsed.tool_name);
        assert_eq!(ctx.tool_description, parsed.tool_description);
        assert_eq!(ctx.requires_network, parsed.requires_network);
    }

    // ---------------------------------------------------------------------
    // ToolCallbackResult
    // ---------------------------------------------------------------------

    #[test]
    fn test_tool_callback_result_success() {
        let result = ToolCallbackResult::success(Duration::from_millis(50));
        assert!(result.success);
        assert!(!result.blocked);
    }

    #[test]
    fn test_tool_callback_result_failure() {
        let result = ToolCallbackResult::failure(Duration::from_millis(20), "File not found");
        assert!(!result.success);
        assert!(!result.blocked);
        assert_eq!(result.error, Some("File not found".to_string()));
    }

    #[test]
    fn test_tool_callback_result_blocked() {
        let result =
            ToolCallbackResult::blocked(Duration::from_millis(5), "Tool disabled by policy");
        assert!(!result.success);
        assert!(result.blocked);
        assert!(result.error.is_none());
        assert_eq!(
            result.blocked_reason,
            Some("Tool disabled by policy".to_string())
        );
    }

    #[test]
    fn test_tool_callback_result_with_output_preview() {
        let result = ToolCallbackResult::success(Duration::from_millis(50))
            .with_output_preview("Result: 42");
        assert_eq!(result.output_preview, Some("Result: 42".to_string()));
    }

    // ---------------------------------------------------------------------
    // NoOpCallbacks
    // ---------------------------------------------------------------------

    #[test]
    fn test_noop_callbacks_compiles() {
        let callbacks = NoOpCallbacks;
        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
        let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
        let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
        let model_result = ModelCallbackResult::success(Duration::from_millis(500));
        let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
        let tool_result = ToolCallbackResult::success(Duration::from_millis(50));

        callbacks.on_before_agent(&agent_ctx);
        callbacks.on_after_agent(&agent_ctx, &agent_result);
        callbacks.on_before_model(&model_ctx);
        callbacks.on_after_model(&model_ctx, &model_result);
        callbacks.on_before_tool(&tool_ctx);
        callbacks.on_after_tool(&tool_ctx, &tool_result);
    }

    // ---------------------------------------------------------------------
    // CallbackRegistry
    // ---------------------------------------------------------------------

    /// Test double that counts how often each hook fires.
    struct CountingCallback {
        before_agent_count: AtomicU32,
        after_agent_count: AtomicU32,
        before_model_count: AtomicU32,
        after_model_count: AtomicU32,
        before_tool_count: AtomicU32,
        after_tool_count: AtomicU32,
    }

    impl CountingCallback {
        fn new() -> Self {
            Self {
                before_agent_count: AtomicU32::new(0),
                after_agent_count: AtomicU32::new(0),
                before_model_count: AtomicU32::new(0),
                after_model_count: AtomicU32::new(0),
                before_tool_count: AtomicU32::new(0),
                after_tool_count: AtomicU32::new(0),
            }
        }
    }

    impl BeforeAgentCallback for CountingCallback {
        fn on_before_agent(&self, _ctx: &AgentCallbackContext) {
            self.before_agent_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl AfterAgentCallback for CountingCallback {
        fn on_after_agent(&self, _ctx: &AgentCallbackContext, _result: &AgentCallbackResult) {
            self.after_agent_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl BeforeModelCallback for CountingCallback {
        fn on_before_model(&self, _ctx: &ModelCallbackContext) {
            self.before_model_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl AfterModelCallback for CountingCallback {
        fn on_after_model(&self, _ctx: &ModelCallbackContext, _result: &ModelCallbackResult) {
            self.after_model_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl BeforeToolCallback for CountingCallback {
        fn on_before_tool(&self, _ctx: &ToolCallbackContext) {
            self.before_tool_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    impl AfterToolCallback for CountingCallback {
        fn on_after_tool(&self, _ctx: &ToolCallbackContext, _result: &ToolCallbackResult) {
            self.after_tool_count.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn test_callback_registry_new() {
        // An empty registry must accept every kind of invocation without effect or panic.
        let registry = CallbackRegistry::new();
        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
        let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
        let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");

        registry.invoke_before_agent(&agent_ctx);
        registry.invoke_after_agent(
            &agent_ctx,
            &AgentCallbackResult::success(Duration::from_millis(1)),
        );
        registry.invoke_before_model(&model_ctx);
        registry.invoke_after_model(
            &model_ctx,
            &ModelCallbackResult::success(Duration::from_millis(1)),
        );
        registry.invoke_before_tool(&tool_ctx);
        registry.invoke_after_tool(
            &tool_ctx,
            &ToolCallbackResult::success(Duration::from_millis(1)),
        );
    }

    #[test]
    fn test_callback_registry_register_all() {
        let callback = Arc::new(CountingCallback::new());
        let mut registry = CallbackRegistry::new();
        registry.register_all(callback.clone());

        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
        let agent_result = AgentCallbackResult::success(Duration::from_millis(100));
        let model_ctx = ModelCallbackContext::new(ExecutionId::new(), "openai", "gpt-4");
        let model_result = ModelCallbackResult::success(Duration::from_millis(500));
        let tool_ctx = ToolCallbackContext::new(ExecutionId::new(), "tool");
        let tool_result = ToolCallbackResult::success(Duration::from_millis(50));

        registry.invoke_before_agent(&agent_ctx);
        registry.invoke_after_agent(&agent_ctx, &agent_result);
        registry.invoke_before_model(&model_ctx);
        registry.invoke_after_model(&model_ctx, &model_result);
        registry.invoke_before_tool(&tool_ctx);
        registry.invoke_after_tool(&tool_ctx, &tool_result);

        assert_eq!(callback.before_agent_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback.after_agent_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback.before_model_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback.after_model_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback.before_tool_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback.after_tool_count.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn test_callback_registry_multiple_callbacks() {
        let callback1 = Arc::new(CountingCallback::new());
        let callback2 = Arc::new(CountingCallback::new());
        let mut registry = CallbackRegistry::new();
        registry.register_all(callback1.clone());
        registry.register_all(callback2.clone());

        let agent_ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
        registry.invoke_before_agent(&agent_ctx);

        assert_eq!(callback1.before_agent_count.load(Ordering::SeqCst), 1);
        assert_eq!(callback2.before_agent_count.load(Ordering::SeqCst), 1);
        // Only the before-agent hook was invoked.
        assert_eq!(callback1.after_agent_count.load(Ordering::SeqCst), 0);
        assert_eq!(callback2.after_agent_count.load(Ordering::SeqCst), 0);
    }

    #[test]
    fn test_callback_registry_individual_registration() {
        struct SimpleBeforeAgent(Arc<AtomicU32>);
        impl BeforeAgentCallback for SimpleBeforeAgent {
            fn on_before_agent(&self, _ctx: &AgentCallbackContext) {
                self.0.fetch_add(1, Ordering::SeqCst);
            }
        }

        let hits = Arc::new(AtomicU32::new(0));
        let mut registry = CallbackRegistry::new();
        registry.on_before_agent(SimpleBeforeAgent(Arc::clone(&hits)));

        let ctx = AgentCallbackContext::new(ExecutionId::new(), "test");
        registry.invoke_before_agent(&ctx);
        registry.invoke_before_agent(&ctx);
        // No after-agent callback is registered, so this must not touch the counter.
        registry.invoke_after_agent(&ctx, &AgentCallbackResult::success(Duration::from_millis(1)));

        assert_eq!(hits.load(Ordering::SeqCst), 2);
    }

    #[test]
    fn test_callback_registry_invokes_in_registration_order() {
        struct Recorder {
            label: &'static str,
            log: Arc<Mutex<Vec<&'static str>>>,
        }
        impl BeforeToolCallback for Recorder {
            fn on_before_tool(&self, _ctx: &ToolCallbackContext) {
                self.log.lock().unwrap().push(self.label);
            }
        }

        let log = Arc::new(Mutex::new(Vec::new()));
        let mut registry = CallbackRegistry::new();
        registry
            .on_before_tool(Recorder {
                label: "first",
                log: Arc::clone(&log),
            })
            .on_before_tool(Recorder {
                label: "second",
                log: Arc::clone(&log),
            });

        registry.invoke_before_tool(&ToolCallbackContext::new(ExecutionId::new(), "tool"));

        assert_eq!(*log.lock().unwrap(), vec!["first", "second"]);
    }

    // ---------------------------------------------------------------------
    // ExecutionCallbacks blanket impl
    // ---------------------------------------------------------------------

    #[test]
    fn test_execution_callbacks_trait() {
        fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
        let noop = NoOpCallbacks;
        accept_execution_callbacks(&noop);
    }

    #[test]
    fn test_counting_callback_implements_execution_callbacks() {
        fn accept_execution_callbacks<T: ExecutionCallbacks>(_: &T) {}
        let counting = CountingCallback::new();
        accept_execution_callbacks(&counting);
    }
}
1250}