Skip to main content

adk_plugin/
enhanced_manager.rs

1//! Enhanced Plugin Manager with priority-based pipeline execution.
2//!
3//! [`EnhancedPluginManager`] manages a collection of [`EnhancedPlugin`] instances,
4//! executing their hooks in priority order with pipeline semantics where each
5//! plugin's output feeds the next plugin's input.
6//!
7//! # Overview
8//!
9//! The manager provides four pipeline methods:
10//!
11//! - [`run_before_tool_call`](EnhancedPluginManager::run_before_tool_call) — intercept tool calls before execution
12//! - [`run_after_tool_call`](EnhancedPluginManager::run_after_tool_call) — transform tool results after execution
13//! - [`run_before_model_call`](EnhancedPluginManager::run_before_model_call) — intercept model calls before execution
14//! - [`run_after_model_call`](EnhancedPluginManager::run_after_model_call) — transform model responses after execution
15//!
16//! # Pipeline Semantics
17//!
18//! - **Continue**: The (possibly modified) value is passed to the next plugin in the chain.
19//! - **ShortCircuit** (before-hooks only): Stops the pipeline immediately, skips the
20//!   underlying operation, and returns the synthetic result.
21//! - **Error**: Stops the pipeline immediately and propagates the error to the caller.
22//!
23//! # Priority Ordering
24//!
25//! Plugins execute in ascending priority order (lower values run first).
26//! Plugins with the same priority execute in registration order (stable sort).
27//!
28//! # Examples
29//!
30//! ```rust,ignore
31//! use std::sync::Arc;
32//! use adk_plugin::{EnhancedPluginManager, EnhancedPlugin};
33//!
34//! let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
35//!     Arc::new(SecurityPlugin),   // priority = 10
36//!     Arc::new(CachePlugin),      // priority = 50
37//!     Arc::new(LoggingPlugin),    // priority = 100
38//! ];
39//!
40//! let manager = EnhancedPluginManager::new(plugins);
41//! // Plugins will execute in order: Security → Cache → Logging
42//! ```
43
44use std::sync::Arc;
45
46use adk_core::{CallbackContext, LlmRequest, LlmResponse, Result, Tool};
47use serde_json::Value;
48use tracing::{debug, warn};
49
50use crate::context::PluginContext;
51use crate::enhanced_plugin::EnhancedPlugin;
52use crate::hook_result::{
53    AfterModelCallResult, AfterToolCallResult, BeforeModelCallResult, BeforeToolCallResult,
54};
55use crate::manager::PluginManagerConfig;
56
57/// Manages enhanced plugins with priority-based pipeline execution.
58///
59/// Plugins are stored sorted by priority (ascending). All pipeline methods
60/// iterate plugins in this order, passing each plugin's output as input to
61/// the next plugin in the chain.
62///
63/// # Thread Safety
64///
65/// `EnhancedPluginManager` is `Send + Sync` and can be shared across async tasks
66/// via `Arc<EnhancedPluginManager>`.
67pub struct EnhancedPluginManager {
68    /// Plugins sorted by priority (ascending). Same-priority preserves registration order.
69    plugins: Vec<Arc<dyn EnhancedPlugin>>,
70    /// Shared plugin context for the lifetime of this manager.
71    context: Arc<PluginContext>,
72    /// Configuration for the manager (e.g., close timeout).
73    config: PluginManagerConfig,
74}
75
76impl EnhancedPluginManager {
77    /// Creates a new `EnhancedPluginManager` with the given plugins.
78    ///
79    /// Plugins are sorted by priority in ascending order using a stable sort,
80    /// so plugins with the same priority retain their registration order.
81    ///
82    /// # Examples
83    ///
84    /// ```rust,ignore
85    /// use std::sync::Arc;
86    /// use adk_plugin::{EnhancedPluginManager, EnhancedPlugin};
87    ///
88    /// let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
89    ///     Arc::new(MyPlugin),
90    /// ];
91    /// let manager = EnhancedPluginManager::new(plugins);
92    /// ```
93    pub fn new(mut plugins: Vec<Arc<dyn EnhancedPlugin>>) -> Self {
94        plugins.sort_by_key(|p| p.priority());
95        Self {
96            plugins,
97            context: Arc::new(PluginContext::new()),
98            config: PluginManagerConfig::default(),
99        }
100    }
101
102    /// Creates a new `EnhancedPluginManager` with custom configuration.
103    pub fn with_config(
104        mut plugins: Vec<Arc<dyn EnhancedPlugin>>,
105        config: PluginManagerConfig,
106    ) -> Self {
107        plugins.sort_by_key(|p| p.priority());
108        Self { plugins, context: Arc::new(PluginContext::new()), config }
109    }
110
111    /// Adds a plugin after construction, re-sorting by priority.
112    ///
113    /// The plugin is inserted and the entire list is re-sorted using a stable sort
114    /// to maintain registration order for same-priority plugins.
115    pub fn add_plugin(&mut self, plugin: Arc<dyn EnhancedPlugin>) {
116        self.plugins.push(plugin);
117        self.plugins.sort_by_key(|p| p.priority());
118    }
119
120    /// Returns a reference to the shared plugin context.
121    ///
122    /// The context is shared across all hook invocations and persists for the
123    /// lifetime of this manager.
124    pub fn context(&self) -> &Arc<PluginContext> {
125        &self.context
126    }
127
128    /// Returns the number of registered plugins.
129    pub fn plugin_count(&self) -> usize {
130        self.plugins.len()
131    }
132
133    /// Returns the names of all registered plugins in execution order.
134    pub fn plugin_names(&self) -> Vec<&str> {
135        self.plugins.iter().map(|p| p.name()).collect()
136    }
137
138    /// Executes the `before_tool_call` pipeline across all plugins in priority order.
139    ///
140    /// Each plugin receives the (possibly modified) arguments from the previous plugin.
141    /// If a plugin returns `ShortCircuit`, the pipeline stops and the synthetic result
142    /// is returned. If a plugin returns an error, the pipeline stops and the error
143    /// is propagated.
144    ///
145    /// # Arguments
146    ///
147    /// * `tool` - The tool about to be executed
148    /// * `args` - The initial tool call arguments
149    /// * `ctx` - The callback context for the current invocation
150    ///
151    /// # Returns
152    ///
153    /// - `Ok(BeforeToolCallResult::Continue(args))` — final modified arguments for tool execution
154    /// - `Ok(BeforeToolCallResult::ShortCircuit(result))` — synthetic result, skip tool execution
155    /// - `Err(e)` — pipeline error, skip tool execution
156    pub async fn run_before_tool_call(
157        &self,
158        tool: Arc<dyn Tool>,
159        args: Value,
160        ctx: Arc<dyn CallbackContext>,
161    ) -> Result<BeforeToolCallResult> {
162        let mut current_args = args;
163
164        for plugin in &self.plugins {
165            debug!(plugin = plugin.name(), "running before_tool_call");
166            match plugin
167                .before_tool_call(tool.clone(), current_args, ctx.clone(), &self.context)
168                .await?
169            {
170                BeforeToolCallResult::Continue(modified_args) => {
171                    current_args = modified_args;
172                }
173                BeforeToolCallResult::ShortCircuit(result) => {
174                    debug!(plugin = plugin.name(), "before_tool_call short-circuited");
175                    return Ok(BeforeToolCallResult::ShortCircuit(result));
176                }
177            }
178        }
179
180        Ok(BeforeToolCallResult::Continue(current_args))
181    }
182
183    /// Executes the `after_tool_call` pipeline across all plugins in priority order.
184    ///
185    /// Each plugin receives the (possibly modified) result from the previous plugin.
186    /// The `args` parameter contains the final modified arguments from the before-hook
187    /// pipeline (not the original arguments).
188    ///
189    /// # Arguments
190    ///
191    /// * `tool` - The tool that was executed
192    /// * `args` - The final arguments used for tool execution (after before-hook modifications)
193    /// * `result` - The initial tool execution result
194    /// * `ctx` - The callback context for the current invocation
195    ///
196    /// # Returns
197    ///
198    /// - `Ok(AfterToolCallResult::Continue(result))` — final modified result
199    /// - `Err(e)` — pipeline error
200    pub async fn run_after_tool_call(
201        &self,
202        tool: Arc<dyn Tool>,
203        args: &Value,
204        result: Value,
205        ctx: Arc<dyn CallbackContext>,
206    ) -> Result<AfterToolCallResult> {
207        let mut current_result = result;
208
209        for plugin in &self.plugins {
210            debug!(plugin = plugin.name(), "running after_tool_call");
211            match plugin
212                .after_tool_call(tool.clone(), args, current_result, ctx.clone(), &self.context)
213                .await?
214            {
215                AfterToolCallResult::Continue(modified_result) => {
216                    current_result = modified_result;
217                }
218            }
219        }
220
221        Ok(AfterToolCallResult::Continue(current_result))
222    }
223
224    /// Executes the `before_model_call` pipeline across all plugins in priority order.
225    ///
226    /// Each plugin receives the (possibly modified) request from the previous plugin.
227    /// If a plugin returns `ShortCircuit`, the pipeline stops and the synthetic response
228    /// is returned. If a plugin returns an error, the pipeline stops and the error
229    /// is propagated.
230    ///
231    /// # Arguments
232    ///
233    /// * `request` - The initial LLM request
234    /// * `ctx` - The callback context for the current invocation
235    ///
236    /// # Returns
237    ///
238    /// - `Ok(BeforeModelCallResult::Continue(request))` — final modified request for model call
239    /// - `Ok(BeforeModelCallResult::ShortCircuit(response))` — synthetic response, skip model call
240    /// - `Err(e)` — pipeline error, skip model call
241    pub async fn run_before_model_call(
242        &self,
243        request: LlmRequest,
244        ctx: Arc<dyn CallbackContext>,
245    ) -> Result<BeforeModelCallResult> {
246        let mut current_request = request;
247
248        for plugin in &self.plugins {
249            debug!(plugin = plugin.name(), "running before_model_call");
250            match plugin.before_model_call(current_request, ctx.clone(), &self.context).await? {
251                BeforeModelCallResult::Continue(modified_request) => {
252                    current_request = modified_request;
253                }
254                BeforeModelCallResult::ShortCircuit(response) => {
255                    debug!(plugin = plugin.name(), "before_model_call short-circuited");
256                    return Ok(BeforeModelCallResult::ShortCircuit(response));
257                }
258            }
259        }
260
261        Ok(BeforeModelCallResult::Continue(current_request))
262    }
263
264    /// Executes the `after_model_call` pipeline across all plugins in priority order.
265    ///
266    /// Each plugin receives the (possibly modified) response from the previous plugin.
267    /// If a plugin returns an error, the pipeline stops and the error is propagated.
268    ///
269    /// # Arguments
270    ///
271    /// * `response` - The initial LLM response
272    /// * `ctx` - The callback context for the current invocation
273    ///
274    /// # Returns
275    ///
276    /// - `Ok(AfterModelCallResult::Continue(response))` — final modified response
277    /// - `Err(e)` — pipeline error
278    pub async fn run_after_model_call(
279        &self,
280        response: LlmResponse,
281        ctx: Arc<dyn CallbackContext>,
282    ) -> Result<AfterModelCallResult> {
283        let mut current_response = response;
284
285        for plugin in &self.plugins {
286            debug!(plugin = plugin.name(), "running after_model_call");
287            match plugin.after_model_call(current_response, ctx.clone(), &self.context).await? {
288                AfterModelCallResult::Continue(modified_response) => {
289                    current_response = modified_response;
290                }
291            }
292        }
293
294        Ok(AfterModelCallResult::Continue(current_response))
295    }
296
297    /// Closes all plugins, ignoring individual close errors.
298    ///
299    /// Each plugin's `close()` method is called in sequence. Errors during
300    /// close are logged but do not prevent other plugins from being closed.
301    pub async fn close(&self) {
302        debug!("closing {} enhanced plugins", self.plugins.len());
303
304        for plugin in &self.plugins {
305            let close_future = plugin.close();
306            match tokio::time::timeout(self.config.close_timeout, close_future).await {
307                Ok(()) => {
308                    debug!(plugin = plugin.name(), "enhanced plugin closed successfully");
309                }
310                Err(_) => {
311                    warn!(plugin = plugin.name(), "enhanced plugin close timed out");
312                }
313            }
314        }
315    }
316}
317
318impl std::fmt::Debug for EnhancedPluginManager {
319    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
320        f.debug_struct("EnhancedPluginManager")
321            .field("plugin_count", &self.plugins.len())
322            .field("plugin_names", &self.plugin_names())
323            .field("close_timeout", &self.config.close_timeout)
324            .finish()
325    }
326}
327
328#[cfg(test)]
329mod tests {
330    use super::*;
331    use adk_core::Content as AdkContent;
332    use adk_core::{AdkError, LlmRequest, LlmResponse, async_trait};
333    use serde_json::json;
334    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
335
336    // --- Test helpers ---
337
338    /// A simple plugin that passes through unchanged.
339    struct NoOpPlugin {
340        name: String,
341        priority: i32,
342    }
343
344    impl NoOpPlugin {
345        fn new(name: &str, priority: i32) -> Self {
346            Self { name: name.to_string(), priority }
347        }
348    }
349
350    #[async_trait]
351    impl EnhancedPlugin for NoOpPlugin {
352        fn name(&self) -> &str {
353            &self.name
354        }
355
356        fn priority(&self) -> i32 {
357            self.priority
358        }
359    }
360
361    /// A plugin that appends a field to tool call arguments.
362    struct ArgModifierPlugin {
363        name: String,
364        priority: i32,
365        key: String,
366        value: Value,
367    }
368
369    #[async_trait]
370    impl EnhancedPlugin for ArgModifierPlugin {
371        fn name(&self) -> &str {
372            &self.name
373        }
374
375        fn priority(&self) -> i32 {
376            self.priority
377        }
378
379        async fn before_tool_call(
380            &self,
381            _tool: Arc<dyn Tool>,
382            args: Value,
383            _ctx: Arc<dyn CallbackContext>,
384            _plugin_ctx: &PluginContext,
385        ) -> Result<BeforeToolCallResult> {
386            let mut modified = args;
387            if let Value::Object(ref mut map) = modified {
388                map.insert(self.key.clone(), self.value.clone());
389            }
390            Ok(BeforeToolCallResult::Continue(modified))
391        }
392    }
393
394    /// A plugin that modifies tool results.
395    struct ResultModifierPlugin {
396        name: String,
397        priority: i32,
398        key: String,
399        value: Value,
400    }
401
402    #[async_trait]
403    impl EnhancedPlugin for ResultModifierPlugin {
404        fn name(&self) -> &str {
405            &self.name
406        }
407
408        fn priority(&self) -> i32 {
409            self.priority
410        }
411
412        async fn after_tool_call(
413            &self,
414            _tool: Arc<dyn Tool>,
415            _args: &Value,
416            result: Value,
417            _ctx: Arc<dyn CallbackContext>,
418            _plugin_ctx: &PluginContext,
419        ) -> Result<AfterToolCallResult> {
420            let mut modified = result;
421            if let Value::Object(ref mut map) = modified {
422                map.insert(self.key.clone(), self.value.clone());
423            }
424            Ok(AfterToolCallResult::Continue(modified))
425        }
426    }
427
428    /// A plugin that short-circuits before_tool_call.
429    struct ShortCircuitPlugin {
430        name: String,
431        priority: i32,
432        result: Value,
433    }
434
435    #[async_trait]
436    impl EnhancedPlugin for ShortCircuitPlugin {
437        fn name(&self) -> &str {
438            &self.name
439        }
440
441        fn priority(&self) -> i32 {
442            self.priority
443        }
444
445        async fn before_tool_call(
446            &self,
447            _tool: Arc<dyn Tool>,
448            _args: Value,
449            _ctx: Arc<dyn CallbackContext>,
450            _plugin_ctx: &PluginContext,
451        ) -> Result<BeforeToolCallResult> {
452            Ok(BeforeToolCallResult::ShortCircuit(self.result.clone()))
453        }
454    }
455
456    /// A plugin that returns an error from before_tool_call.
457    struct ErrorPlugin {
458        name: String,
459        priority: i32,
460    }
461
462    #[async_trait]
463    impl EnhancedPlugin for ErrorPlugin {
464        fn name(&self) -> &str {
465            &self.name
466        }
467
468        fn priority(&self) -> i32 {
469            self.priority
470        }
471
472        async fn before_tool_call(
473            &self,
474            _tool: Arc<dyn Tool>,
475            _args: Value,
476            _ctx: Arc<dyn CallbackContext>,
477            _plugin_ctx: &PluginContext,
478        ) -> Result<BeforeToolCallResult> {
479            Err(AdkError::agent("test error from plugin"))
480        }
481
482        async fn after_tool_call(
483            &self,
484            _tool: Arc<dyn Tool>,
485            _args: &Value,
486            _result: Value,
487            _ctx: Arc<dyn CallbackContext>,
488            _plugin_ctx: &PluginContext,
489        ) -> Result<AfterToolCallResult> {
490            Err(AdkError::agent("test error from after_tool"))
491        }
492
493        async fn before_model_call(
494            &self,
495            _request: LlmRequest,
496            _ctx: Arc<dyn CallbackContext>,
497            _plugin_ctx: &PluginContext,
498        ) -> Result<BeforeModelCallResult> {
499            Err(AdkError::agent("test error from before_model"))
500        }
501
502        async fn after_model_call(
503            &self,
504            _response: LlmResponse,
505            _ctx: Arc<dyn CallbackContext>,
506            _plugin_ctx: &PluginContext,
507        ) -> Result<AfterModelCallResult> {
508            Err(AdkError::agent("test error from after_model"))
509        }
510    }
511
512    /// A plugin that tracks whether it was called.
513    struct TrackingPlugin {
514        name: String,
515        priority: i32,
516        before_tool_called: AtomicBool,
517        after_tool_called: AtomicBool,
518        before_model_called: AtomicBool,
519        after_model_called: AtomicBool,
520    }
521
522    impl TrackingPlugin {
523        fn new(name: &str, priority: i32) -> Self {
524            Self {
525                name: name.to_string(),
526                priority,
527                before_tool_called: AtomicBool::new(false),
528                after_tool_called: AtomicBool::new(false),
529                before_model_called: AtomicBool::new(false),
530                after_model_called: AtomicBool::new(false),
531            }
532        }
533    }
534
535    #[async_trait]
536    impl EnhancedPlugin for TrackingPlugin {
537        fn name(&self) -> &str {
538            &self.name
539        }
540
541        fn priority(&self) -> i32 {
542            self.priority
543        }
544
545        async fn before_tool_call(
546            &self,
547            _tool: Arc<dyn Tool>,
548            args: Value,
549            _ctx: Arc<dyn CallbackContext>,
550            _plugin_ctx: &PluginContext,
551        ) -> Result<BeforeToolCallResult> {
552            self.before_tool_called.store(true, Ordering::SeqCst);
553            Ok(BeforeToolCallResult::Continue(args))
554        }
555
556        async fn after_tool_call(
557            &self,
558            _tool: Arc<dyn Tool>,
559            _args: &Value,
560            result: Value,
561            _ctx: Arc<dyn CallbackContext>,
562            _plugin_ctx: &PluginContext,
563        ) -> Result<AfterToolCallResult> {
564            self.after_tool_called.store(true, Ordering::SeqCst);
565            Ok(AfterToolCallResult::Continue(result))
566        }
567
568        async fn before_model_call(
569            &self,
570            request: LlmRequest,
571            _ctx: Arc<dyn CallbackContext>,
572            _plugin_ctx: &PluginContext,
573        ) -> Result<BeforeModelCallResult> {
574            self.before_model_called.store(true, Ordering::SeqCst);
575            Ok(BeforeModelCallResult::Continue(request))
576        }
577
578        async fn after_model_call(
579            &self,
580            response: LlmResponse,
581            _ctx: Arc<dyn CallbackContext>,
582            _plugin_ctx: &PluginContext,
583        ) -> Result<AfterModelCallResult> {
584            self.after_model_called.store(true, Ordering::SeqCst);
585            Ok(AfterModelCallResult::Continue(response))
586        }
587    }
588
589    /// A plugin that short-circuits before_model_call.
590    struct ModelShortCircuitPlugin {
591        name: String,
592        priority: i32,
593    }
594
595    #[async_trait]
596    impl EnhancedPlugin for ModelShortCircuitPlugin {
597        fn name(&self) -> &str {
598            &self.name
599        }
600
601        fn priority(&self) -> i32 {
602            self.priority
603        }
604
605        async fn before_model_call(
606            &self,
607            _request: LlmRequest,
608            _ctx: Arc<dyn CallbackContext>,
609            _plugin_ctx: &PluginContext,
610        ) -> Result<BeforeModelCallResult> {
611            Ok(BeforeModelCallResult::ShortCircuit(LlmResponse::default()))
612        }
613    }
614
615    /// A plugin that records execution order via a shared counter.
616    struct OrderTrackingPlugin {
617        name: String,
618        priority: i32,
619        order_counter: Arc<AtomicUsize>,
620        recorded_order: AtomicUsize,
621    }
622
623    impl OrderTrackingPlugin {
624        fn new(name: &str, priority: i32, counter: Arc<AtomicUsize>) -> Self {
625            Self {
626                name: name.to_string(),
627                priority,
628                order_counter: counter,
629                recorded_order: AtomicUsize::new(0),
630            }
631        }
632
633        fn execution_order(&self) -> usize {
634            self.recorded_order.load(Ordering::SeqCst)
635        }
636    }
637
638    #[async_trait]
639    impl EnhancedPlugin for OrderTrackingPlugin {
640        fn name(&self) -> &str {
641            &self.name
642        }
643
644        fn priority(&self) -> i32 {
645            self.priority
646        }
647
648        async fn before_tool_call(
649            &self,
650            _tool: Arc<dyn Tool>,
651            args: Value,
652            _ctx: Arc<dyn CallbackContext>,
653            _plugin_ctx: &PluginContext,
654        ) -> Result<BeforeToolCallResult> {
655            let order = self.order_counter.fetch_add(1, Ordering::SeqCst);
656            self.recorded_order.store(order, Ordering::SeqCst);
657            Ok(BeforeToolCallResult::Continue(args))
658        }
659    }
660
661    // --- Mock Tool and CallbackContext ---
662
663    struct MockTool;
664
665    #[async_trait]
666    impl Tool for MockTool {
667        fn name(&self) -> &str {
668            "mock_tool"
669        }
670
671        fn description(&self) -> &str {
672            "A mock tool for testing"
673        }
674
675        async fn execute(
676            &self,
677            _ctx: Arc<dyn adk_core::ToolContext>,
678            _args: Value,
679        ) -> Result<Value> {
680            Ok(json!({"result": "mock"}))
681        }
682    }
683
684    struct MockCallbackContext {
685        content: AdkContent,
686    }
687
688    impl MockCallbackContext {
689        fn new() -> Self {
690            Self { content: AdkContent::new("user") }
691        }
692    }
693
694    impl adk_core::ReadonlyContext for MockCallbackContext {
695        fn invocation_id(&self) -> &str {
696            "test-invocation"
697        }
698
699        fn agent_name(&self) -> &str {
700            "test-agent"
701        }
702
703        fn user_id(&self) -> &str {
704            "test-user"
705        }
706
707        fn app_name(&self) -> &str {
708            "test-app"
709        }
710
711        fn session_id(&self) -> &str {
712            "test-session"
713        }
714
715        fn branch(&self) -> &str {
716            ""
717        }
718
719        fn user_content(&self) -> &AdkContent {
720            &self.content
721        }
722    }
723
724    #[async_trait]
725    impl CallbackContext for MockCallbackContext {
726        fn artifacts(&self) -> Option<Arc<dyn adk_core::Artifacts>> {
727            None
728        }
729
730        fn tool_name(&self) -> Option<&str> {
731            Some("mock_tool")
732        }
733    }
734
735    fn mock_tool() -> Arc<dyn Tool> {
736        Arc::new(MockTool)
737    }
738
739    fn mock_ctx() -> Arc<dyn CallbackContext> {
740        Arc::new(MockCallbackContext::new())
741    }
742
743    fn mock_request() -> LlmRequest {
744        LlmRequest::new("test-model", vec![])
745    }
746
747    // --- Tests ---
748
749    #[test]
750    fn test_new_sorts_by_priority() {
751        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
752            Arc::new(NoOpPlugin::new("c", 100)),
753            Arc::new(NoOpPlugin::new("a", 10)),
754            Arc::new(NoOpPlugin::new("b", 50)),
755        ];
756
757        let manager = EnhancedPluginManager::new(plugins);
758        assert_eq!(manager.plugin_names(), vec!["a", "b", "c"]);
759    }
760
761    #[test]
762    fn test_stable_sort_preserves_registration_order() {
763        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
764            Arc::new(NoOpPlugin::new("first", 100)),
765            Arc::new(NoOpPlugin::new("second", 100)),
766            Arc::new(NoOpPlugin::new("third", 100)),
767        ];
768
769        let manager = EnhancedPluginManager::new(plugins);
770        assert_eq!(manager.plugin_names(), vec!["first", "second", "third"]);
771    }
772
773    #[test]
774    fn test_add_plugin_resorts() {
775        let plugins: Vec<Arc<dyn EnhancedPlugin>> =
776            vec![Arc::new(NoOpPlugin::new("b", 50)), Arc::new(NoOpPlugin::new("c", 100))];
777
778        let mut manager = EnhancedPluginManager::new(plugins);
779        manager.add_plugin(Arc::new(NoOpPlugin::new("a", 10)));
780
781        assert_eq!(manager.plugin_names(), vec!["a", "b", "c"]);
782        assert_eq!(manager.plugin_count(), 3);
783    }
784
785    #[test]
786    fn test_context_accessor() {
787        let manager = EnhancedPluginManager::new(vec![]);
788        let ctx = manager.context();
789        // Just verify we can access it without panic
790        assert!(Arc::strong_count(ctx) >= 1);
791    }
792
793    #[test]
794    fn test_empty_manager() {
795        let manager = EnhancedPluginManager::new(vec![]);
796        assert_eq!(manager.plugin_count(), 0);
797        assert!(manager.plugin_names().is_empty());
798    }
799
800    #[tokio::test]
801    async fn test_before_tool_call_pipeline_propagation() {
802        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
803            Arc::new(ArgModifierPlugin {
804                name: "plugin1".to_string(),
805                priority: 10,
806                key: "added_by_1".to_string(),
807                value: json!(true),
808            }),
809            Arc::new(ArgModifierPlugin {
810                name: "plugin2".to_string(),
811                priority: 20,
812                key: "added_by_2".to_string(),
813                value: json!("hello"),
814            }),
815        ];
816
817        let manager = EnhancedPluginManager::new(plugins);
818        let result = manager
819            .run_before_tool_call(mock_tool(), json!({"original": "value"}), mock_ctx())
820            .await
821            .unwrap();
822
823        match result {
824            BeforeToolCallResult::Continue(args) => {
825                assert_eq!(args["original"], "value");
826                assert_eq!(args["added_by_1"], true);
827                assert_eq!(args["added_by_2"], "hello");
828            }
829            BeforeToolCallResult::ShortCircuit(_) => panic!("expected Continue"),
830        }
831    }
832
833    #[tokio::test]
834    async fn test_before_tool_call_short_circuit() {
835        let tracking = Arc::new(TrackingPlugin::new("after_short_circuit", 50));
836        let tracking_clone = tracking.clone();
837
838        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
839            Arc::new(ShortCircuitPlugin {
840                name: "short_circuit".to_string(),
841                priority: 10,
842                result: json!({"cached": true}),
843            }),
844            tracking_clone,
845        ];
846
847        let manager = EnhancedPluginManager::new(plugins);
848        let result =
849            manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await.unwrap();
850
851        match result {
852            BeforeToolCallResult::ShortCircuit(value) => {
853                assert_eq!(value, json!({"cached": true}));
854            }
855            BeforeToolCallResult::Continue(_) => panic!("expected ShortCircuit"),
856        }
857
858        // The tracking plugin should NOT have been called
859        assert!(!tracking.before_tool_called.load(Ordering::SeqCst));
860    }
861
862    #[tokio::test]
863    async fn test_before_tool_call_error_propagation() {
864        let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
865        let tracking_clone = tracking.clone();
866
867        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
868            Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
869            tracking_clone,
870        ];
871
872        let manager = EnhancedPluginManager::new(plugins);
873        let result = manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await;
874
875        assert!(result.is_err());
876        // The tracking plugin should NOT have been called
877        assert!(!tracking.before_tool_called.load(Ordering::SeqCst));
878    }
879
880    #[tokio::test]
881    async fn test_after_tool_call_pipeline_propagation() {
882        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
883            Arc::new(ResultModifierPlugin {
884                name: "plugin1".to_string(),
885                priority: 10,
886                key: "enriched_by_1".to_string(),
887                value: json!(true),
888            }),
889            Arc::new(ResultModifierPlugin {
890                name: "plugin2".to_string(),
891                priority: 20,
892                key: "enriched_by_2".to_string(),
893                value: json!(42),
894            }),
895        ];
896
897        let manager = EnhancedPluginManager::new(plugins);
898        let args = json!({"tool_arg": "test"});
899        let result = manager
900            .run_after_tool_call(mock_tool(), &args, json!({"status": "ok"}), mock_ctx())
901            .await
902            .unwrap();
903
904        match result {
905            AfterToolCallResult::Continue(value) => {
906                assert_eq!(value["status"], "ok");
907                assert_eq!(value["enriched_by_1"], true);
908                assert_eq!(value["enriched_by_2"], 42);
909            }
910        }
911    }
912
913    #[tokio::test]
914    async fn test_after_tool_call_error_propagation() {
915        let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
916        let tracking_clone = tracking.clone();
917
918        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
919            Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
920            tracking_clone,
921        ];
922
923        let manager = EnhancedPluginManager::new(plugins);
924        let result =
925            manager.run_after_tool_call(mock_tool(), &json!({}), json!({}), mock_ctx()).await;
926
927        assert!(result.is_err());
928        assert!(!tracking.after_tool_called.load(Ordering::SeqCst));
929    }
930
931    #[tokio::test]
932    async fn test_before_model_call_pipeline_propagation() {
933        // Use no-op plugins to verify the request passes through
934        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
935            Arc::new(NoOpPlugin::new("plugin1", 10)),
936            Arc::new(NoOpPlugin::new("plugin2", 20)),
937        ];
938
939        let manager = EnhancedPluginManager::new(plugins);
940        let request = mock_request();
941        let result = manager.run_before_model_call(request, mock_ctx()).await.unwrap();
942
943        match result {
944            BeforeModelCallResult::Continue(_) => { /* pass */ }
945            BeforeModelCallResult::ShortCircuit(_) => panic!("expected Continue"),
946        }
947    }
948
949    #[tokio::test]
950    async fn test_before_model_call_short_circuit() {
951        let tracking = Arc::new(TrackingPlugin::new("after_short_circuit", 50));
952        let tracking_clone = tracking.clone();
953
954        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
955            Arc::new(ModelShortCircuitPlugin {
956                name: "model_short_circuit".to_string(),
957                priority: 10,
958            }),
959            tracking_clone,
960        ];
961
962        let manager = EnhancedPluginManager::new(plugins);
963        let result = manager.run_before_model_call(mock_request(), mock_ctx()).await.unwrap();
964
965        match result {
966            BeforeModelCallResult::ShortCircuit(_) => { /* pass */ }
967            BeforeModelCallResult::Continue(_) => panic!("expected ShortCircuit"),
968        }
969
970        assert!(!tracking.before_model_called.load(Ordering::SeqCst));
971    }
972
973    #[tokio::test]
974    async fn test_before_model_call_error_propagation() {
975        let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
976        let tracking_clone = tracking.clone();
977
978        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
979            Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
980            tracking_clone,
981        ];
982
983        let manager = EnhancedPluginManager::new(plugins);
984        let result = manager.run_before_model_call(mock_request(), mock_ctx()).await;
985
986        assert!(result.is_err());
987        assert!(!tracking.before_model_called.load(Ordering::SeqCst));
988    }
989
990    #[tokio::test]
991    async fn test_after_model_call_pipeline_propagation() {
992        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
993            Arc::new(NoOpPlugin::new("plugin1", 10)),
994            Arc::new(NoOpPlugin::new("plugin2", 20)),
995        ];
996
997        let manager = EnhancedPluginManager::new(plugins);
998        let result =
999            manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await.unwrap();
1000
1001        match result {
1002            AfterModelCallResult::Continue(_) => { /* pass */ }
1003        }
1004    }
1005
1006    #[tokio::test]
1007    async fn test_after_model_call_error_propagation() {
1008        let tracking = Arc::new(TrackingPlugin::new("after_error", 50));
1009        let tracking_clone = tracking.clone();
1010
1011        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
1012            Arc::new(ErrorPlugin { name: "error_plugin".to_string(), priority: 10 }),
1013            tracking_clone,
1014        ];
1015
1016        let manager = EnhancedPluginManager::new(plugins);
1017        let result = manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await;
1018
1019        assert!(result.is_err());
1020        assert!(!tracking.after_model_called.load(Ordering::SeqCst));
1021    }
1022
1023    #[tokio::test]
1024    async fn test_empty_plugin_list_before_tool_call() {
1025        let manager = EnhancedPluginManager::new(vec![]);
1026        let result = manager
1027            .run_before_tool_call(mock_tool(), json!({"key": "value"}), mock_ctx())
1028            .await
1029            .unwrap();
1030
1031        match result {
1032            BeforeToolCallResult::Continue(args) => {
1033                assert_eq!(args, json!({"key": "value"}));
1034            }
1035            BeforeToolCallResult::ShortCircuit(_) => panic!("expected Continue"),
1036        }
1037    }
1038
1039    #[tokio::test]
1040    async fn test_empty_plugin_list_after_tool_call() {
1041        let manager = EnhancedPluginManager::new(vec![]);
1042        let result = manager
1043            .run_after_tool_call(mock_tool(), &json!({}), json!({"result": 42}), mock_ctx())
1044            .await
1045            .unwrap();
1046
1047        match result {
1048            AfterToolCallResult::Continue(value) => {
1049                assert_eq!(value, json!({"result": 42}));
1050            }
1051        }
1052    }
1053
1054    #[tokio::test]
1055    async fn test_empty_plugin_list_before_model_call() {
1056        let manager = EnhancedPluginManager::new(vec![]);
1057        let request = mock_request();
1058        let result = manager.run_before_model_call(request, mock_ctx()).await.unwrap();
1059
1060        match result {
1061            BeforeModelCallResult::Continue(_) => { /* pass */ }
1062            BeforeModelCallResult::ShortCircuit(_) => panic!("expected Continue"),
1063        }
1064    }
1065
1066    #[tokio::test]
1067    async fn test_empty_plugin_list_after_model_call() {
1068        let manager = EnhancedPluginManager::new(vec![]);
1069        let result =
1070            manager.run_after_model_call(LlmResponse::default(), mock_ctx()).await.unwrap();
1071
1072        match result {
1073            AfterModelCallResult::Continue(_) => { /* pass */ }
1074        }
1075    }
1076
1077    #[tokio::test]
1078    async fn test_priority_ordering_execution() {
1079        let counter = Arc::new(AtomicUsize::new(0));
1080
1081        let p1 = Arc::new(OrderTrackingPlugin::new("high_priority", 10, counter.clone()));
1082        let p2 = Arc::new(OrderTrackingPlugin::new("medium_priority", 50, counter.clone()));
1083        let p3 = Arc::new(OrderTrackingPlugin::new("low_priority", 100, counter.clone()));
1084
1085        let p1_clone = p1.clone();
1086        let p2_clone = p2.clone();
1087        let p3_clone = p3.clone();
1088
1089        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![p3_clone, p1_clone, p2_clone];
1090
1091        let manager = EnhancedPluginManager::new(plugins);
1092        manager.run_before_tool_call(mock_tool(), json!({}), mock_ctx()).await.unwrap();
1093
1094        // Verify execution order: high (0), medium (1), low (2)
1095        assert_eq!(p1.execution_order(), 0);
1096        assert_eq!(p2.execution_order(), 1);
1097        assert_eq!(p3.execution_order(), 2);
1098    }
1099
1100    #[tokio::test]
1101    async fn test_close_calls_all_plugins() {
1102        let closed = Arc::new(AtomicUsize::new(0));
1103
1104        struct CloseTrackingPlugin {
1105            name: String,
1106            closed: Arc<AtomicUsize>,
1107        }
1108
1109        #[async_trait]
1110        impl EnhancedPlugin for CloseTrackingPlugin {
1111            fn name(&self) -> &str {
1112                &self.name
1113            }
1114
1115            async fn close(&self) {
1116                self.closed.fetch_add(1, Ordering::SeqCst);
1117            }
1118        }
1119
1120        let plugins: Vec<Arc<dyn EnhancedPlugin>> = vec![
1121            Arc::new(CloseTrackingPlugin { name: "p1".to_string(), closed: closed.clone() }),
1122            Arc::new(CloseTrackingPlugin { name: "p2".to_string(), closed: closed.clone() }),
1123            Arc::new(CloseTrackingPlugin { name: "p3".to_string(), closed: closed.clone() }),
1124        ];
1125
1126        let manager = EnhancedPluginManager::new(plugins);
1127        manager.close().await;
1128
1129        assert_eq!(closed.load(Ordering::SeqCst), 3);
1130    }
1131
1132    #[tokio::test]
1133    async fn test_debug_impl() {
1134        let plugins: Vec<Arc<dyn EnhancedPlugin>> =
1135            vec![Arc::new(NoOpPlugin::new("alpha", 10)), Arc::new(NoOpPlugin::new("beta", 20))];
1136
1137        let manager = EnhancedPluginManager::new(plugins);
1138        let debug_str = format!("{manager:?}");
1139        assert!(debug_str.contains("EnhancedPluginManager"));
1140        assert!(debug_str.contains("plugin_count: 2"));
1141    }
1142}