1use std::sync::Arc;
45
46use adk_core::{CallbackContext, LlmRequest, LlmResponse, Result, Tool};
47use serde_json::Value;
48use tracing::{debug, warn};
49
50use crate::context::PluginContext;
51use crate::enhanced_plugin::EnhancedPlugin;
52use crate::hook_result::{
53 AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
54};
55use crate::manager::PluginManagerConfig;
56
57pub struct EnhancedPluginManager {
68 plugins: Vec<Arc<dyn EnhancedPlugin>>,
70 context: Arc<PluginContext>,
72 config: PluginManagerConfig,
74}
75
76impl EnhancedPluginManager {
77 pub fn new(mut plugins: Vec<Arc<dyn EnhancedPlugin>>) -> Self {
94 plugins.sort_by_key(|p| p.priority());
95 Self {
96 plugins,
97 context: Arc::new(PluginContext::new()),
98 config: PluginManagerConfig::default(),
99 }
100 }
101
102 pub fn with_config(
104 mut plugins: Vec<Arc<dyn EnhancedPlugin>>,
105 config: PluginManagerConfig,
106 ) -> Self {
107 plugins.sort_by_key(|p| p.priority());
108 Self { plugins, context: Arc::new(PluginContext::new()), config }
109 }
110
111 pub fn add_plugin(&mut self, plugin: Arc<dyn EnhancedPlugin>) {
116 self.plugins.push(plugin);
117 self.plugins.sort_by_key(|p| p.priority());
118 }
119
120 pub fn context(&self) -> &Arc<PluginContext> {
125 &self.context
126 }
127
128 pub fn plugin_count(&self) -> usize {
130 self.plugins.len()
131 }
132
133 pub fn plugin_names(&self) -> Vec<&str> {
135 self.plugins.iter().map(|p| p.name()).collect()
136 }
137
138 pub async fn run_before_tool_call(
157 &self,
158 tool: Arc<dyn Tool>,
159 args: Value,
160 ctx: Arc<dyn CallbackContext>,
161 ) -> Result<BeforeToolCallResult> {
162 let mut current_args = args;
163
164 for plugin in &self.plugins {
165 debug!(plugin = plugin.name(), "running before_tool_call");
166 match plugin
167 .before_tool_call(tool.clone(), current_args, ctx.clone(), &self.context)
168 .await?
169 {
170 BeforeToolCallResult::Continue(modified_args) => {
171 current_args = modified_args;
172 }
173 BeforeToolCallResult::ShortCircuit(result) => {
174 debug!(plugin = plugin.name(), "before_tool_call short-circuited");
175 return Ok(BeforeToolCallResult::ShortCircuit(result));
176 }
177 }
178 }
179
180 Ok(BeforeToolCallResult::Continue(current_args))
181 }
182
183 pub async fn run_after_tool_call(
201 &self,
202 tool: Arc<dyn Tool>,
203 args: &Value,
204 result: Value,
205 ctx: Arc<dyn CallbackContext>,
206 ) -> Result<AfterToolCallResult> {
207 let mut current_result = result;
208
209 for plugin in &self.plugins {
210 debug!(plugin = plugin.name(), "running after_tool_call");
211 match plugin
212 .after_tool_call(tool.clone(), args, current_result, ctx.clone(), &self.context)
213 .await?
214 {
215 AfterToolCallResult::Continue(modified_result) => {
216 current_result = modified_result;
217 }
218 }
219 }
220
221 Ok(AfterToolCallResult::Continue(current_result))
222 }
223
224 pub async fn run_before_model_call(
242 &self,
243 request: LlmRequest,
244 ctx: Arc<dyn CallbackContext>,
245 ) -> Result<BeforeModelCallResult> {
246 let mut current_request = request;
247
248 for plugin in &self.plugins {
249 debug!(plugin = plugin.name(), "running before_model_call");
250 match plugin.before_model_call(current_request, ctx.clone(), &self.context).await? {
251 BeforeModelCallResult::Continue(modified_request) => {
252 current_request = modified_request;
253 }
254 BeforeModelCallResult::ShortCircuit(response) => {
255 debug!(plugin = plugin.name(), "before_model_call short-circuited");
256 return Ok(BeforeModelCallResult::ShortCircuit(response));
257 }
258 }
259 }
260
261 Ok(BeforeModelCallResult::Continue(current_request))
262 }
263
264 pub async fn run_after_model_call(
279 &self,
280 response: LlmResponse,
281 ctx: Arc<dyn CallbackContext>,
282 ) -> Result<AfterModelCallResult> {
283 let mut current_response = response;
284
285 for plugin in &self.plugins {
286 debug!(plugin = plugin.name(), "running after_model_call");
287 match plugin.after_model_call(current_response, ctx.clone(), &self.context).await? {
288 AfterModelCallResult::Continue(modified_response) => {
289 current_response = modified_response;
290 }
291 }
292 }
293
294 Ok(AfterModelCallResult::Continue(current_response))
295 }
296
297 pub async fn close(&self) {
302 debug!("closing {} enhanced plugins", self.plugins.len());
303
304 for plugin in &self.plugins {
305 let close_future = plugin.close();
306 match tokio::time::timeout(self.config.close_timeout, close_future).await {
307 Ok(()) => {
308 debug!(plugin = plugin.name(), "enhanced plugin closed successfully");
309 }
310 Err(_) => {
311 warn!(plugin = plugin.name(), "enhanced plugin close timed out");
312 }
313 }
314 }
315 }
316}
317
318impl std::fmt::Debug for EnhancedPluginManager {
319 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320 f.debug_struct("EnhancedPluginManager")
321 .field("plugin_count", &self.plugins.len())
322 .field("plugin_names", &self.plugin_names())
323 .field("close_timeout", &self.config.close_timeout)
324 .finish()
325 }
326}
327
328#[cfg(test)]
329mod tests {
330 use super::*;
331 use adk_core::Content as AdkContent;
332 use adk_core::{AdkError, LlmRequest, LlmResponse, async_trait};
333 use serde_json::json;
334 use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
335
/// Test plugin that only carries a name and a priority.
struct NoOpPlugin {
    name: String,
    priority: i32,
}

impl NoOpPlugin {
    /// Creates a plugin called `name` with the given `priority`.
    fn new(name: &str, priority: i32) -> Self {
        let name = String::from(name);
        Self { name, priority }
    }
}
349
350 #[async_trait]
351 impl EnhancedPlugin for NoOpPlugin {
352 fn name(&self) -> &str {
353 &self.name
354 }
355
356 fn priority(&self) -> i32 {
357 self.priority
358 }
359 }
360
361 struct ArgModifierPlugin {
363 name: String,
364 priority: i32,
365 key: String,
366 value: Value,
367 }
368
369 #[async_trait]
370 impl EnhancedPlugin for ArgModifierPlugin {
371 fn name(&self) -> &str {
372 &self.name
373 }
374
375 fn priority(&self) -> i32 {
376 self.priority
377 }
378
379 async fn before_tool_call(
380 &self,
381 _tool: Arc<dyn Tool>,
382 args: Value,
383 _ctx: Arc<dyn CallbackContext>,
384 _plugin_ctx: &PluginContext,
385 ) -> Result<BeforeToolCallResult> {
386 let mut modified = args;
387 if let Value::Object(ref mut map) = modified {
388 map.insert(self.key.clone(), self.value.clone());
389 }
390 Ok(BeforeToolCallResult::Continue(modified))
391 }
392 }
393
394 struct ResultModifierPlugin {
396 name: String,
397 priority: i32,
398 key: String,
399 value: Value,
400 }
401
402 #[async_trait]
403 impl EnhancedPlugin for ResultModifierPlugin {
404 fn name(&self) -> &str {
405 &self.name
406 }
407
408 fn priority(&self) -> i32 {
409 self.priority
410 }
411
412 async fn after_tool_call(
413 &self,
414 _tool: Arc<dyn Tool>,
415 _args: &Value,
416 result: Value,
417 _ctx: Arc<dyn CallbackContext>,
418 _plugin_ctx: &PluginContext,
419 ) -> Result<AfterToolCallResult> {
420 let mut modified = result;
421 if let Value::Object(ref mut map) = modified {
422 map.insert(self.key.clone(), self.value.clone());
423 }
424 Ok(AfterToolCallResult::Continue(modified))
425 }
426 }
427
428 struct ShortCircuitPlugin {
430 name: String,
431 priority: i32,
432 result: Value,
433 }
434
435 #[async_trait]
436 impl EnhancedPlugin for ShortCircuitPlugin {
437 fn name(&self) -> &str {
438 &self.name
439 }
440
441 fn priority(&self) -> i32 {
442 self.priority
443 }
444
445 async fn before_tool_call(
446 &self,
447 _tool: Arc<dyn Tool>,
448 _args: Value,
449 _ctx: Arc<dyn CallbackContext>,
450 _plugin_ctx: &PluginContext,
451 ) -> Result<BeforeToolCallResult> {
452 Ok(BeforeToolCallResult::ShortCircuit(self.result.clone()))
453 }
454 }
455
/// Test plugin whose four call hooks all return an error.
struct ErrorPlugin {
    /// Name reported by the plugin.
    name: String,
    /// Priority reported by the plugin.
    priority: i32,
}
461
462 #[async_trait]
463 impl EnhancedPlugin for ErrorPlugin {
464 fn name(&self) -> &str {
465 &self.name
466 }
467
468 fn priority(&self) -> i32 {
469 self.priority
470 }
471
472 async fn before_tool_call(
473 &self,
474 _tool: Arc<dyn Tool>,
475 _args: Value,
476 _ctx: Arc<dyn CallbackContext>,
477 _plugin_ctx: &PluginContext,
478 ) -> Result<BeforeToolCallResult> {
479 Err(AdkError::agent("test error from plugin"))
480 }
481
482 async fn after_tool_call(
483 &self,
484 _tool: Arc<dyn Tool>,
485 _args: &Value,
486 _result: Value,
487 _ctx: Arc<dyn CallbackContext>,
488 _plugin_ctx: &PluginContext,
489 ) -> Result<AfterToolCallResult> {
490 Err(AdkError::agent("test error from after_tool"))
491 }
492
493 async fn before_model_call(
494 &self,
495 _request: LlmRequest,
496 _ctx: Arc<dyn CallbackContext>,
497 _plugin_ctx: &PluginContext,
498 ) -> Result<BeforeModelCallResult> {
499 Err(AdkError::agent("test error from before_model"))
500 }
501
502 async fn after_model_call(
503 &self,
504 _response: LlmResponse,
505 _ctx: Arc<dyn CallbackContext>,
506 _plugin_ctx: &PluginContext,
507 ) -> Result<AfterModelCallResult> {
508 Err(AdkError::agent("test error from after_model"))
509 }
510 }
511
/// Test plugin that records which of its hooks have been invoked.
struct TrackingPlugin {
    name: String,
    priority: i32,
    /// Set once `before_tool_call` has run.
    before_tool_called: AtomicBool,
    /// Set once `after_tool_call` has run.
    after_tool_called: AtomicBool,
    /// Set once `before_model_call` has run.
    before_model_called: AtomicBool,
    /// Set once `after_model_call` has run.
    after_model_called: AtomicBool,
}

impl TrackingPlugin {
    /// Creates a plugin with every "called" flag cleared.
    fn new(name: &str, priority: i32) -> Self {
        let unset = || AtomicBool::new(false);
        Self {
            name: name.to_owned(),
            priority,
            before_tool_called: unset(),
            after_tool_called: unset(),
            before_model_called: unset(),
            after_model_called: unset(),
        }
    }
}
534
535 #[async_trait]
536 impl EnhancedPlugin for TrackingPlugin {
537 fn name(&self) -> &str {
538 &self.name
539 }
540
541 fn priority(&self) -> i32 {
542 self.priority
543 }
544
545 async fn before_tool_call(
546 &self,
547 _tool: Arc<dyn Tool>,
548 args: Value,
549 _ctx: Arc<dyn CallbackContext>,
550 _plugin_ctx: &PluginContext,
551 ) -> Result<BeforeToolCallResult> {
552 self.before_tool_called.store(true, Ordering::SeqCst);
553 Ok(BeforeToolCallResult::Continue(args))
554 }
555
556 async fn after_tool_call(
557 &self,
558 _tool: Arc<dyn Tool>,
559 _args: &Value,
560 result: Value,
561 _ctx: Arc<dyn CallbackContext>,
562 _plugin_ctx: &PluginContext,
563 ) -> Result<AfterToolCallResult> {
564 self.after_tool_called.store(true, Ordering::SeqCst);
565 Ok(AfterToolCallResult::Continue(result))
566 }
567
568 async fn before_model_call(
569 &self,
570 request: LlmRequest,
571 _ctx: Arc<dyn CallbackContext>,
572 _plugin_ctx: &PluginContext,
573 ) -> Result<BeforeModelCallResult> {
574 self.before_model_called.store(true, Ordering::SeqCst);
575 Ok(BeforeModelCallResult::Continue(request))
576 }
577
578 async fn after_model_call(
579 &self,
580 response: LlmResponse,
581 _ctx: Arc<dyn CallbackContext>,
582 _plugin_ctx: &PluginContext,
583 ) -> Result<AfterModelCallResult> {
584 self.after_model_called.store(true, Ordering::SeqCst);
585 Ok(AfterModelCallResult::Continue(response))
586 }
587 }
588
/// Test plugin whose `before_model_call` hook always short-circuits.
struct ModelShortCircuitPlugin {
    /// Name reported by the plugin.
    name: String,
    /// Priority reported by the plugin.
    priority: i32,
}
594
595 #[async_trait]
596 impl EnhancedPlugin for ModelShortCircuitPlugin {
597 fn name(&self) -> &str {
598 &self.name
599 }
600
601 fn priority(&self) -> i32 {
602 self.priority
603 }
604
605 async fn before_model_call(
606 &self,
607 _request: LlmRequest,
608 _ctx: Arc<dyn CallbackContext>,
609 _plugin_ctx: &PluginContext,
610 ) -> Result<BeforeModelCallResult> {
611 Ok(BeforeModelCallResult::ShortCircuit(LlmResponse::default()))
612 }
613 }
614
/// Test plugin that records the position at which its hook ran, using a
/// counter shared with the other plugins of the same test.
struct OrderTrackingPlugin {
    name: String,
    priority: i32,
    /// Counter shared by all plugins taking part in one ordering test.
    order_counter: Arc<AtomicUsize>,
    /// Counter value observed when this plugin's hook ran, or
    /// [`Self::NOT_RUN`] while it has not run.
    recorded_order: AtomicUsize,
}

impl OrderTrackingPlugin {
    /// Sentinel reported by `execution_order` before the hook has run.
    ///
    /// Position 0 is a valid result ("ran first"), so starting at 0 would
    /// let an assertion on the first plugin pass even if it never executed.
    const NOT_RUN: usize = usize::MAX;

    /// Creates a plugin that draws its execution position from `counter`.
    fn new(name: &str, priority: i32, counter: Arc<AtomicUsize>) -> Self {
        Self {
            name: name.to_string(),
            priority,
            order_counter: counter,
            recorded_order: AtomicUsize::new(Self::NOT_RUN),
        }
    }

    /// Returns the zero-based position at which the hook last ran, or
    /// [`Self::NOT_RUN`] if it has not run.
    fn execution_order(&self) -> usize {
        self.recorded_order.load(Ordering::SeqCst)
    }
}
637
638 #[async_trait]
639 impl EnhancedPlugin for OrderTrackingPlugin {
640 fn name(&self) -> &str {
641 &self.name
642 }
643
644 fn priority(&self) -> i32 {
645 self.priority
646 }
647
648 async fn before_tool_call(
649 &self,
650 _tool: Arc<dyn Tool>,
651 args: Value,
652 _ctx: Arc<dyn CallbackContext>,
653 _plugin_ctx: &PluginContext,
654 ) -> Result<BeforeToolCallResult> {
655 let order = self.order_counter.fetch_add(1, Ordering::SeqCst);
656 self.recorded_order.store(order, Ordering::SeqCst);
657 Ok(BeforeToolCallResult::Continue(args))
658 }
659 }
660
661 struct MockTool;
664
665 #[async_trait]
666 impl Tool for MockTool {
667 fn name(&self) -> &str {
668 "mock_tool"
669 }
670
671 fn description(&self) -> &str {
672 "A mock tool for testing"
673 }
674
675 async fn execute(
676 &self,
677 _ctx: Arc<dyn adk_core::ToolContext>,
678 _args: Value,
679 ) -> Result<Value> {
680 Ok(json!({"result": "mock"}))
681 }
682 }
683
684 struct MockCallbackContext {
685 content: AdkContent,
686 }
687
688 impl MockCallbackContext {
689 fn new() -> Self {
690 Self { content: AdkContent::new("user") }
691 }
692 }
693
694 impl adk_core::ReadonlyContext for MockCallbackContext {
695 fn invocation_id(&self) -> &str {
696 "test-invocation"
697 }
698
699 fn agent_name(&self) -> &str {
700 "test-agent"
701 }
702
703 fn user_id(&self) -> &str {
704 "test-user"
705 }
706
707 fn app_name(&self) -> &str {
708 "test-app"
709 }
710
711 fn session_id(&self) -> &str {
712 "test-session"
713 }
714
715 fn branch(&self) -> &str {
716 ""
717 }
718
719 fn user_content(&self) -> &AdkContent {
720 &self.content
721 }
722 }
723
724 #[async_trait]
725 impl CallbackContext for MockCallbackContext {
726 fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
727 None
728 }
729
730 fn tool_name(&self) -> Option<&str> {
731 Some("mock_tool")
732 }
733 }
734
735 fn mock_tool() -> Arc<dyn Tool> {
736 Arc::new(MockTool)
737 }
738
739 fn mock_ctx() -> Arc<dyn CallbackContext> {
740 Arc::new(MockCallbackContext::new())
741 }
742
743 fn mock_request() -> LlmRequest {
744 LlmRequest::new("test-model", vec![])
745 }
746
/// `new` must order plugins by ascending priority regardless of input order.
#[test]
fn test_new_sorts_by_priority() {
    let noop = |name: &str, priority: i32| -> Arc<dyn EnhancedPlugin> {
        Arc::new(NoOpPlugin::new(name, priority))
    };

    // Supplied out of order on purpose.
    let manager = EnhancedPluginManager::new(vec![noop("c", 100), noop("a", 10), noop("b", 50)]);

    assert_eq!(manager.plugin_names(), ["a", "b", "c"]);
}
760
/// Plugins with equal priority must keep the order they were supplied in.
#[test]
fn test_stable_sort_preserves_registration_order() {
    let noop = |name: &str| -> Arc<dyn EnhancedPlugin> { Arc::new(NoOpPlugin::new(name, 100)) };

    let manager = EnhancedPluginManager::new(vec![noop("first"), noop("second"), noop("third")]);

    assert_eq!(manager.plugin_names(), ["first", "second", "third"]);
}
772
/// A plugin added later must be slotted in according to its priority.
#[test]
fn test_add_plugin_resorts() {
    let noop = |name: &str, priority: i32| -> Arc<dyn EnhancedPlugin> {
        Arc::new(NoOpPlugin::new(name, priority))
    };

    let mut manager = EnhancedPluginManager::new(vec![noop("b", 50), noop("c", 100)]);
    manager.add_plugin(noop("a", 10));

    assert_eq!(manager.plugin_names(), ["a", "b", "c"]);
    assert_eq!(manager.plugin_count(), 3);
}
784
/// `context` must expose the manager's own shared context.
#[test]
fn test_context_accessor() {
    let manager = EnhancedPluginManager::new(vec![]);
    let ctx = manager.context();

    // A freshly built manager is the sole owner of its context
    // (`>= 1` would hold for any live `Arc` and so proves nothing).
    assert_eq!(Arc::strong_count(ctx), 1);

    // Repeated calls hand out the same allocation, not a new context.
    assert!(Arc::ptr_eq(ctx, manager.context()));
}
792
/// A manager built without plugins reports no plugins and no names.
#[test]
fn test_empty_manager() {
    let manager = EnhancedPluginManager::new(Vec::new());

    let names = manager.plugin_names();
    assert_eq!(manager.plugin_count(), 0);
    assert!(names.is_empty());
}
799
800 #[tokio::test]
801 async fn test_before_tool_call_pipeline_propagation() {
802 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
803 Arc::new(ArgModifierPlugin {
804 name: "plugin1".to_string(),
805 priority: 10,
806 key: "added_by_1".to_string(),
807 value: json!(true),
808 }),
809 Arc::new(ArgModifierPlugin {
810 name: "plugin2".to_string(),
811 priority: 20,
812 key: "added_by_2".to_string(),
813 value: json!("hello"),
814 }),
815 ];
816
817 let manager = EnhancedPluginManager::new(plugins);
818 let result = manager
819 .run_before_tool_call(mock_tool(), json!({"original": "value"}), mock_ctx())
820 .await
821 .unwrap();
822
823 match result {
824 BeforeToolCallResult::Continue(args) => {
825 assert_eq!(args["original"], "value");
826 assert_eq!(args["added_by_1"], true);
827 assert_eq!(args["added_by_2"], "hello");
828 }
829 BeforeToolCallResult::ShortCircuit(_) => panic!("expected Continue"),
830 }
831 }
832
833 #[tokio::test]
834 async fn test_before_tool_call_short_circuit() {
835 let tracking = Arc::new(TrackingPlugin::new("after_short_circuit", 50));
836 let tracking_clone = tracking.clone();
837
838 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
839 Arc::new(ShortCircuitPlugin {
840 name: "short_circuit".to_string(),
841 priority: 10,
842 result: json!({"cached": true}),
843 }),
844 tracking_clone,
845 ];
846
847 let manager = EnhancedPluginManager::new(plugins);
848 let result =
849 manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await.unwrap();
850
851 match result {
852 BeforeToolCallResult::ShortCircuit(value) => {
853 assert_eq!(value, json!({"cached": true}));
854 }
855 BeforeToolCallResult::Continue(_) => panic!("expected ShortCircuit"),
856 }
857
858 assert!(!tracking.before_tool_called.load(Ordering::SeqCst));
860 }
861
862 #[tokio::test]
863 async fn test_before_tool_call_error_propagation() {
864 let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
865 let tracking_clone = tracking.clone();
866
867 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
868 Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
869 tracking_clone,
870 ];
871
872 let manager = EnhancedPluginManager::new(plugins);
873 let result = manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await;
874
875 assert!(result.is_err());
876 assert!(!tracking.before_tool_called.load(Ordering::SeqCst));
878 }
879
880 #[tokio::test]
881 async fn test_after_tool_call_pipeline_propagation() {
882 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
883 Arc::new(ResultModifierPlugin {
884 name: "plugin1".to_string(),
885 priority: 10,
886 key: "enriched_by_1".to_string(),
887 value: json!(true),
888 }),
889 Arc::new(ResultModifierPlugin {
890 name: "plugin2".to_string(),
891 priority: 20,
892 key: "enriched_by_2".to_string(),
893 value: json!(42),
894 }),
895 ];
896
897 let manager = EnhancedPluginManager::new(plugins);
898 let args = json!({"tool_arg": "test"});
899 let result = manager
900 .run_after_tool_call(mock_tool(), &args, json!({"status": "ok"}), mock_ctx())
901 .await
902 .unwrap();
903
904 match result {
905 AfterToolCallResult::Continue(value) => {
906 assert_eq!(value["status"], "ok");
907 assert_eq!(value["enriched_by_1"], true);
908 assert_eq!(value["enriched_by_2"], 42);
909 }
910 }
911 }
912
913 #[tokio::test]
914 async fn test_after_tool_call_error_propagation() {
915 let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
916 let tracking_clone = tracking.clone();
917
918 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
919 Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
920 tracking_clone,
921 ];
922
923 let manager = EnhancedPluginManager::new(plugins);
924 let result =
925 manager.run_after_tool_call(mock_tool(), &json!({}), json!({}), mock_ctx()).await;
926
927 assert!(result.is_err());
928 assert!(!tracking.after_tool_called.load(Ordering::SeqCst));
929 }
930
931 #[tokio::test]
932 async fn test_before_model_call_pipeline_propagation() {
933 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
935 Arc::new(NoOpPlugin::new("plugin1", 10)),
936 Arc::new(NoOpPlugin::new("plugin2", 20)),
937 ];
938
939 let manager = EnhancedPluginManager::new(plugins);
940 let request = mock_request();
941 let result = manager.run_before_model_call(request, mock_ctx()).await.unwrap();
942
943 match result {
944 BeforeModelCallResult::Continue(_) => { }
945 BeforeModelCallResult::ShortCircuit(_) => panic!("expected Continue"),
946 }
947 }
948
949 #[tokio::test]
950 async fn test_before_model_call_short_circuit() {
951 let tracking = Arc::new(TrackingPlugin::new("after_short_circuit", 50));
952 let tracking_clone = tracking.clone();
953
954 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
955 Arc::new(ModelShortCircuitPlugin {
956 name: "model_short_circuit".to_string(),
957 priority: 10,
958 }),
959 tracking_clone,
960 ];
961
962 let manager = EnhancedPluginManager::new(plugins);
963 let result = manager.run_before_model_call(mock_request(), mock_ctx()).await.unwrap();
964
965 match result {
966 BeforeModelCallResult::ShortCircuit(_) => { }
967 BeforeModelCallResult::Continue(_) => panic!("expected ShortCircuit"),
968 }
969
970 assert!(!tracking.before_model_called.load(Ordering::SeqCst));
971 }
972
973 #[tokio::test]
974 async fn test_before_model_call_error_propagation() {
975 let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
976 let tracking_clone = tracking.clone();
977
978 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
979 Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
980 tracking_clone,
981 ];
982
983 let manager = EnhancedPluginManager::new(plugins);
984 let result = manager.run_before_model_call(mock_request(), mock_ctx()).await;
985
986 assert!(result.is_err());
987 assert!(!tracking.before_model_called.load(Ordering::SeqCst));
988 }
989
990 #[tokio::test]
991 async fn test_after_model_call_pipeline_propagation() {
992 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
993 Arc::new(NoOpPlugin::new("plugin1", 10)),
994 Arc::new(NoOpPlugin::new("plugin2", 20)),
995 ];
996
997 let manager = EnhancedPluginManager::new(plugins);
998 let result =
999 manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await.unwrap();
1000
1001 match result {
1002 AfterModelCallResult::Continue(_) => { }
1003 }
1004 }
1005
1006 #[tokio::test]
1007 async fn test_after_model_call_error_propagation() {
1008 let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
1009 let tracking_clone = tracking.clone();
1010
1011 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
1012 Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
1013 tracking_clone,
1014 ];
1015
1016 let manager = EnhancedPluginManager::new(plugins);
1017 let result = manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await;
1018
1019 assert!(result.is_err());
1020 assert!(!tracking.after_model_called.load(Ordering::SeqCst));
1021 }
1022
1023 #[tokio::test]
1024 async fn test_empty_plugin_list_before_tool_call() {
1025 let manager = EnhancedPluginManager::new(vec![]);
1026 let result = manager
1027 .run_before_tool_call(mock_tool(), json!({"key": "value"}), mock_ctx())
1028 .await
1029 .unwrap();
1030
1031 match result {
1032 BeforeToolCallResult::Continue(args) => {
1033 assert_eq!(args, json!({"key": "value"}));
1034 }
1035 BeforeToolCallResult::ShortCircuit(_) => panic!("expected Continue"),
1036 }
1037 }
1038
1039 #[tokio::test]
1040 async fn test_empty_plugin_list_after_tool_call() {
1041 let manager = EnhancedPluginManager::new(vec![]);
1042 let result = manager
1043 .run_after_tool_call(mock_tool(), &json!({}), json!({"result": 42}), mock_ctx())
1044 .await
1045 .unwrap();
1046
1047 match result {
1048 AfterToolCallResult::Continue(value) => {
1049 assert_eq!(value, json!({"result": 42}));
1050 }
1051 }
1052 }
1053
1054 #[tokio::test]
1055 async fn test_empty_plugin_list_before_model_call() {
1056 let manager = EnhancedPluginManager::new(vec![]);
1057 let request = mock_request();
1058 let result = manager.run_before_model_call(request, mock_ctx()).await.unwrap();
1059
1060 match result {
1061 BeforeModelCallResult::Continue(_) => { }
1062 BeforeModelCallResult::ShortCircuit(_) => panic!("expected Continue"),
1063 }
1064 }
1065
1066 #[tokio::test]
1067 async fn test_empty_plugin_list_after_model_call() {
1068 let manager = EnhancedPluginManager::new(vec![]);
1069 let result =
1070 manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await.unwrap();
1071
1072 match result {
1073 AfterModelCallResult::Continue(_) => { }
1074 }
1075 }
1076
1077 #[tokio::test]
1078 async fn test_priority_ordering_execution() {
1079 let counter = Arc::new(AtomicUsize::new(0));
1080
1081 let p1 = Arc::new(OrderTrackingPlugin::new("high_priority", 10, counter.clone()));
1082 let p2 = Arc::new(OrderTrackingPlugin::new("medium_priority", 50, counter.clone()));
1083 let p3 = Arc::new(OrderTrackingPlugin::new("low_priority", 100, counter.clone()));
1084
1085 let p1_clone = p1.clone();
1086 let p2_clone = p2.clone();
1087 let p3_clone = p3.clone();
1088
1089 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![p3_clone, p1_clone, p2_clone];
1090
1091 let manager = EnhancedPluginManager::new(plugins);
1092 manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await.unwrap();
1093
1094 assert_eq!(p1.execution_order(), 0);
1096 assert_eq!(p2.execution_order(), 1);
1097 assert_eq!(p3.execution_order(), 2);
1098 }
1099
1100 #[tokio::test]
1101 async fn test_close_calls_all_plugins() {
1102 let closed = Arc::new(AtomicUsize::new(0));
1103
1104 struct CloseTrackingPlugin {
1105 name: String,
1106 closed: Arc<AtomicUsize>,
1107 }
1108
1109 #[async_trait]
1110 impl EnhancedPlugin for CloseTrackingPlugin {
1111 fn name(&self) -> &str {
1112 &self.name
1113 }
1114
1115 async fn close(&self) {
1116 self.closed.fetch_add(1, Ordering::SeqCst);
1117 }
1118 }
1119
1120 let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
1121 Arc::new(CloseTrackingPlugin { name: "p1".to_string(), closed: closed.clone() }),
1122 Arc::new(CloseTrackingPlugin { name: "p2".to_string(), closed: closed.clone() }),
1123 Arc::new(CloseTrackingPlugin { name: "p3".to_string(), closed: closed.clone() }),
1124 ];
1125
1126 let manager = EnhancedPluginManager::new(plugins);
1127 manager.close().await;
1128
1129 assert_eq!(closed.load(Ordering::SeqCst), 3);
1130 }
1131
/// The `Debug` output must name the type and include every field it emits.
///
/// This is a plain `#[test]`: nothing here awaits, so no async runtime is
/// needed.
#[test]
fn test_debug_impl() {
    let plugins: Vec<Arc<dyn EnhancedPlugin>> =
        vec![Arc::new(NoOpPlugin::new("alpha", 10)), Arc::new(NoOpPlugin::new("beta", 20))];

    let manager = EnhancedPluginManager::new(plugins);
    let debug_str = format!("{manager:?}");

    assert!(debug_str.contains("EnhancedPluginManager"));
    assert!(debug_str.contains("plugin_count: 2"));
    // Names are listed in priority order.
    assert!(debug_str.contains(r#"plugin_names: ["alpha", "beta"]"#));
    assert!(debug_str.contains("close_timeout"));
}
1142}