mcp_host/registry/
tools.rs

1//! Tool registry for MCP servers
2//!
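//! Provides registration, per-tool lifecycle management (enabled / disabled /
//! defective, exclusive groups, activation chains) and execution of MCP tools
//! with per-tool circuit breaker protection.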
4
5use std::collections::HashSet;
6use std::sync::{Arc, RwLock};
7
8use async_trait::async_trait;
9use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use state_machines::state_machine;
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17use crate::content::types::{Content, ImageContent, TextContent};
18use crate::server::multiplexer::ClientRequester;
19use crate::server::session::Session;
20use crate::server::visibility::{ExecutionContext, VisibilityContext};
21use crate::transport::traits::JsonRpcNotification;
22
23/// Tool execution errors
24#[derive(Debug, Error)]
25pub enum ToolError {
26    /// Tool not found
27    #[error("Tool not found: {0}")]
28    NotFound(String),
29
30    /// Invalid arguments
31    #[error("Invalid arguments: {0}")]
32    InvalidArguments(String),
33
34    /// Execution error
35    #[error("Execution error: {0}")]
36    Execution(String),
37
38    /// Internal error
39    #[error("Internal error: {0}")]
40    Internal(String),
41
42    /// Circuit breaker is open
43    #[error("Circuit breaker open for tool '{tool}': {message}")]
44    CircuitOpen { tool: String, message: String },
45}
46
47// =============================================================================
48// TOOL LIFECYCLE - Enabled/Disabled/Defective per tool
49// =============================================================================
50
/// Lifecycle state of one tool as tracked by [`ToolRegistry`].
///
/// Only `Enabled` (or untracked) tools are reported by `list_available`;
/// the other two states hide the tool.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ToolLifecycleState {
    /// Listed and callable.
    Enabled,
    /// Deliberately switched off (explicitly or via an exclusive group);
    /// hidden and not callable.
    Disabled,
    /// Taken out of service because its circuit breaker opened.
    Defective,
}
61
62impl ToolLifecycleState {
63    fn from_state_name(name: &str) -> Option<Self> {
64        match name {
65            "Enabled" => Some(Self::Enabled),
66            "Disabled" => Some(Self::Disabled),
67            "Defective" => Some(Self::Defective),
68            _ => None,
69        }
70    }
71}
72
73/// Policy configuration for a tool's lifecycle.
74#[derive(Debug, Clone)]
75pub struct ToolPolicy {
76    /// Desired initial state when the tool is registered.
77    pub initial_state: ToolLifecycleState,
78    /// Exclusive group name (only one tool in the group can be enabled).
79    pub exclusive_group: Option<String>,
80    /// Tools to activate immediately when this tool becomes enabled.
81    pub activates: Vec<String>,
82}
83
84impl Default for ToolPolicy {
85    fn default() -> Self {
86        Self {
87            initial_state: ToolLifecycleState::Enabled,
88            exclusive_group: None,
89            activates: Vec::new(),
90        }
91    }
92}
93
94impl ToolPolicy {
95    /// Create a new policy with defaults (enabled, no group, no activations).
96    pub fn new() -> Self {
97        Self::default()
98    }
99
100    /// Set the initial lifecycle state.
101    pub fn initial_state(mut self, state: ToolLifecycleState) -> Self {
102        self.initial_state = state;
103        self
104    }
105
106    /// Set an exclusive group for this tool.
107    pub fn exclusive_group(mut self, group: impl Into<String>) -> Self {
108        self.exclusive_group = Some(group.into());
109        self
110    }
111
112    /// Set tools to activate when this tool becomes enabled.
113    pub fn activates<I, S>(mut self, tools: I) -> Self
114    where
115        I: IntoIterator<Item = S>,
116        S: Into<String>,
117    {
118        self.activates = tools.into_iter().map(Into::into).collect();
119        self
120    }
121}
122
/// Context value carried by each lifecycle state machine.
///
/// The machine needs no per-tool data, so this is an empty unit type.
#[derive(Default, Debug)]
struct ToolLifecycleContext;
125
126impl ToolLifecycleContext {
127    fn new() -> Self {
128        Self
129    }
130}
131
// Per-tool lifecycle state machine.
//
// NOTE(review): `DynamicToolLifecycle` and `ToolLifecycleEvent`, which the
// registry below uses, are not declared in this file. They are presumably
// generated by this invocation (`dynamic: true`); confirm against the
// `state_machines` crate docs.
//
// Every machine starts in `Enabled`. Declared transitions:
//   enable          Disabled | Defective -> Enabled
//   disable         Enabled  | Defective -> Disabled
//   mark_defective  Enabled  | Disabled  -> Defective
//   recover         Defective            -> Enabled
state_machine! {
    name: ToolLifecycle,
    context: ToolLifecycleContext,
    dynamic: true,
    initial: Enabled,
    states: [Enabled, Disabled, Defective],
    events {
        enable {
            transition: { from: [Disabled, Defective], to: Enabled }
        }
        disable {
            transition: { from: [Enabled, Defective], to: Disabled }
        }
        mark_defective {
            transition: { from: [Enabled, Disabled], to: Defective }
        }
        // Both `enable` and `recover` lead Defective -> Enabled. The registry
        // fires `recover` for that particular path.
        recover {
            transition: { from: Defective, to: Enabled }
        }
    }
}
153
154// =============================================================================
155// TOOL OUTPUT - Exclusive content OR structured output
156// =============================================================================
157
158/// Tool execution output - either content OR structured, never both.
159///
160/// MCP spec allows tools to return either:
161/// - `content`: Array of content items (text, images, etc.) for display
162/// - `structuredContent`: JSON object conforming to the tool's `outputSchema`
163///
164/// This enum enforces exclusivity at the type level.
165pub enum ToolOutput {
166    /// Traditional content array (text, images, etc.)
167    Content(Vec<Box<dyn Content>>),
168    /// Structured JSON output (must conform to outputSchema if defined)
169    Structured(Value),
170}
171
172impl std::fmt::Debug for ToolOutput {
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        match self {
175            Self::Content(items) => f
176                .debug_struct("Content")
177                .field("count", &items.len())
178                .finish(),
179            Self::Structured(value) => f.debug_tuple("Structured").field(value).finish(),
180        }
181    }
182}
183
184impl ToolOutput {
185    /// Create content output with a single text item
186    pub fn text(s: impl Into<String>) -> Self {
187        Self::Content(vec![Box::new(TextContent::new(s))])
188    }
189
190    /// Create content output with multiple text items
191    pub fn texts(items: &[&str]) -> Self {
192        Self::Content(
193            items
194                .iter()
195                .map(|s| Box::new(TextContent::new(*s)) as Box<dyn Content>)
196                .collect(),
197        )
198    }
199
200    /// Create content output from a vec of content items
201    pub fn content(items: Vec<Box<dyn Content>>) -> Self {
202        Self::Content(items)
203    }
204
205    /// Create structured output from a serializable value
206    pub fn structured<T: Serialize>(value: T) -> Result<Self, serde_json::Error> {
207        Ok(Self::Structured(serde_json::to_value(value)?))
208    }
209
210    /// Create structured output from a raw JSON value
211    pub fn json(value: Value) -> Self {
212        Self::Structured(value)
213    }
214
215    /// Check if this is content output
216    pub fn is_content(&self) -> bool {
217        matches!(self, Self::Content(_))
218    }
219
220    /// Check if this is structured output
221    pub fn is_structured(&self) -> bool {
222        matches!(self, Self::Structured(_))
223    }
224}
225
226/// Tool metadata for listing
227#[derive(Debug, Clone, Serialize, Deserialize)]
228pub struct ToolInfo {
229    /// Tool name (programmatic identifier)
230    pub name: String,
231
232    /// Human-readable title for UI display
233    #[serde(skip_serializing_if = "Option::is_none")]
234    pub title: Option<String>,
235
236    /// Tool description
237    #[serde(skip_serializing_if = "Option::is_none")]
238    pub description: Option<String>,
239
240    /// Input schema (JSON Schema)
241    #[serde(rename = "inputSchema")]
242    pub input_schema: Value,
243
244    /// Output schema (JSON Schema) for structured output
245    #[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")]
246    pub output_schema: Option<Value>,
247
248    /// Execution metadata (task support, etc.)
249    #[serde(skip_serializing_if = "Option::is_none")]
250    pub execution: Option<crate::protocol::types::ToolExecution>,
251
252    /// Tool annotations (behavioral hints for LLMs)
253    #[serde(skip_serializing_if = "Option::is_none")]
254    pub annotations: Option<crate::protocol::types::ToolAnnotations>,
255}
256
257/// Tool trait for implementing MCP tools
258#[async_trait]
259pub trait Tool: Send + Sync {
260    /// Get tool name (programmatic identifier)
261    fn name(&self) -> &str;
262
263    /// Get human-readable title for UI display
264    ///
265    /// If not provided, clients should use `name` as fallback.
266    fn title(&self) -> Option<&str> {
267        None
268    }
269
270    /// Get tool description
271    fn description(&self) -> Option<&str> {
272        None
273    }
274
275    /// Get input schema (JSON Schema)
276    fn input_schema(&self) -> Value;
277
278    /// Get output schema (JSON Schema) for structured output
279    ///
280    /// When this returns Some, the tool MUST return `ToolOutput::Structured`
281    /// that conforms to this schema. When None, the tool returns `ToolOutput::Content`.
282    fn output_schema(&self) -> Option<Value> {
283        None
284    }
285
286    /// Get execution metadata (task support, etc.)
287    ///
288    /// Override this to declare task support for this tool.
289    /// Returns None by default (forbidden task execution).
290    fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
291        None
292    }
293
294    /// Get tool annotations (behavioral hints for LLMs)
295    ///
296    /// Override this to provide hints about tool behavior (read-only, destructive, etc.)
297    fn annotations(&self) -> Option<crate::protocol::types::ToolAnnotations> {
298        None
299    }
300
301    /// Check if this tool should be visible in the given context
302    ///
303    /// Override this to implement contextual visibility. The default implementation
304    /// always returns true (always visible).
305    ///
306    /// # Example
307    ///
308    /// ```rust,ignore
309    /// fn is_visible(&self, ctx: &VisibilityContext) -> bool {
310    ///     // Only show git commit tool if repo has uncommitted changes
311    ///     ctx.environment
312    ///         .map(|e| e.has_git_repo() && !e.git_is_clean())
313    ///         .unwrap_or(false)
314    /// }
315    /// ```
316    fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
317        true
318    }
319
320    /// Execute the tool with execution context
321    ///
322    /// The execution context provides access to:
323    /// - `ctx.params`: Input parameters from the tool call
324    /// - `ctx.session`: Current session (roles, state, client info)
325    /// - `ctx.environment`: Optional environment state
326    ///
327    /// # Returns
328    ///
329    /// - `ToolOutput::Content` for traditional content (text, images)
330    /// - `ToolOutput::Structured` for JSON output (requires `output_schema`)
331    ///
332    /// # Example
333    ///
334    /// ```rust,ignore
335    /// // Content output
336    /// async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<ToolOutput, ToolError> {
337    ///     Ok(ToolOutput::text("Hello, world!"))
338    /// }
339    ///
340    /// // Structured output
341    /// async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<ToolOutput, ToolError> {
342    ///     ToolOutput::structured(MyOutput { value: 42 })
343    ///         .map_err(|e| ToolError::Internal(e.to_string()))
344    /// }
345    /// ```
346    async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<ToolOutput, ToolError>;
347}
348
349/// Helper methods for tool implementations
350///
351/// Provides ergonomic shortcuts for creating common content types.
352/// Auto-implemented for all Tool trait objects.
353pub trait ToolHelpers {
354    /// Create text content
355    fn text(&self, content: &str) -> Box<dyn Content> {
356        Box::new(TextContent::new(content))
357    }
358
359    /// Create image content
360    fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
361        Box::new(ImageContent::new(data, mime_type))
362    }
363}
364
365// Auto-implement ToolHelpers for all Tool trait objects
366impl<T: Tool + ?Sized> ToolHelpers for T {}
367
/// Circuit-breaker settings used when building each tool's breaker.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Failures within `failure_window_secs` that open the circuit.
    pub failure_threshold: usize,
    /// Length, in seconds, of the window in which failures are counted.
    pub failure_window_secs: f64,
    /// Seconds an open circuit waits before moving to half-open.
    pub half_open_timeout_secs: f64,
    /// Successes needed to close the circuit again.
    pub success_threshold: usize,
}
380
381impl Default for ToolBreakerConfig {
382    fn default() -> Self {
383        Self {
384            failure_threshold: 5,
385            failure_window_secs: 60.0,
386            half_open_timeout_secs: 30.0,
387            success_threshold: 2,
388        }
389    }
390}
391
392impl From<ToolBreakerConfig> for BreakerConfig {
393    fn from(cfg: ToolBreakerConfig) -> Self {
394        BreakerConfig {
395            failure_threshold: Some(cfg.failure_threshold),
396            failure_rate_threshold: None,
397            minimum_calls: 1,
398            failure_window_secs: cfg.failure_window_secs,
399            half_open_timeout_secs: cfg.half_open_timeout_secs,
400            success_threshold: cfg.success_threshold,
401            jitter_factor: 0.1,
402        }
403    }
404}
405
406/// Tool registry for managing available tools with circuit breaker protection
407#[derive(Clone)]
408pub struct ToolRegistry {
409    tools: Arc<DashMap<String, Arc<dyn Tool>>>,
410    breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
411    breaker_config: Arc<RwLock<ToolBreakerConfig>>,
412    notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
413    tool_states: Arc<DashMap<String, Arc<RwLock<DynamicToolLifecycle>>>>,
414    tool_policies: Arc<DashMap<String, ToolPolicy>>,
415}
416
417impl ToolRegistry {
418    /// Create new tool registry
419    pub fn new() -> Self {
420        Self {
421            tools: Arc::new(DashMap::new()),
422            breakers: Arc::new(DashMap::new()),
423            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
424            notification_tx: None,
425            tool_states: Arc::new(DashMap::new()),
426            tool_policies: Arc::new(DashMap::new()),
427        }
428    }
429
430    /// Create tool registry with notification channel
431    pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
432        Self {
433            tools: Arc::new(DashMap::new()),
434            breakers: Arc::new(DashMap::new()),
435            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
436            notification_tx: Some(notification_tx),
437            tool_states: Arc::new(DashMap::new()),
438            tool_policies: Arc::new(DashMap::new()),
439        }
440    }
441
442    /// Set notification channel
443    pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
444        self.notification_tx = Some(tx);
445    }
446
447    /// Configure circuit breakers for all tools
448    pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
449        if let Ok(mut cfg) = self.breaker_config.write() {
450            *cfg = config;
451        }
452    }
453
454    /// Configure lifecycle policy for a tool.
455    ///
456    /// This applies the policy immediately, including the desired state.
457    pub fn set_tool_policy(&self, name: impl Into<String>, policy: ToolPolicy) {
458        let name = name.into();
459        self.tool_policies.insert(name.clone(), policy.clone());
460        self.ensure_tool_state(&name);
461
462        let mut visited = HashSet::new();
463        let changed = self.set_tool_state_internal(&name, policy.initial_state, &mut visited);
464        if !changed && self.tool_state(&name) == Some(ToolLifecycleState::Enabled) {
465            visited.insert(name.clone());
466            self.apply_enable_policy(&name, &policy, &mut visited);
467        }
468    }
469
470    /// Get lifecycle state for a tool if tracked.
471    pub fn tool_state(&self, name: &str) -> Option<ToolLifecycleState> {
472        self.tool_states.get(name).and_then(|entry| {
473            entry
474                .read()
475                .ok()
476                .and_then(|machine| ToolLifecycleState::from_state_name(machine.current_state()))
477        })
478    }
479
480    /// Check if a tool is enabled.
481    ///
482    /// If the tool has no lifecycle tracking yet, it is treated as enabled.
483    pub fn is_tool_enabled(&self, name: &str) -> bool {
484        match self.tool_state(name) {
485            Some(ToolLifecycleState::Enabled) => true,
486            Some(_) => false,
487            None => true,
488        }
489    }
490
491    /// Enable a tool (and apply activation/exclusive policies).
492    pub fn enable_tool(&self, name: &str) -> bool {
493        let mut visited = HashSet::new();
494        self.set_tool_state_internal(name, ToolLifecycleState::Enabled, &mut visited)
495    }
496
497    /// Disable a tool.
498    pub fn disable_tool(&self, name: &str) -> bool {
499        let mut visited = HashSet::new();
500        self.set_tool_state_internal(name, ToolLifecycleState::Disabled, &mut visited)
501    }
502
503    /// Mark a tool as defective (breaker-open state).
504    pub fn mark_tool_defective(&self, name: &str) -> bool {
505        let mut visited = HashSet::new();
506        self.set_tool_state_internal(name, ToolLifecycleState::Defective, &mut visited)
507    }
508
509    /// Recover a tool from defective state (no-op if not defective).
510    pub fn recover_tool(&self, name: &str) -> bool {
511        if self.tool_state(name) != Some(ToolLifecycleState::Defective) {
512            return false;
513        }
514        let mut visited = HashSet::new();
515        self.set_tool_state_internal(name, ToolLifecycleState::Enabled, &mut visited)
516    }
517
518    /// Register a tool with circuit breaker protection
519    pub fn register<T: Tool + 'static>(&self, tool: T) {
520        self.register_boxed(Arc::new(tool));
521    }
522
523    /// Register a tool with lifecycle policy.
524    pub fn register_with_policy<T: Tool + 'static>(&self, tool: T, policy: ToolPolicy) {
525        self.register_boxed_with_policy(Arc::new(tool), policy);
526    }
527
528    /// Register a boxed tool with circuit breaker
529    pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
530        let name = tool.name().to_string();
531        let policy = self
532            .tool_policies
533            .get(&name)
534            .map(|p| p.clone())
535            .unwrap_or_default();
536        self.register_boxed_with_policy(tool, policy);
537    }
538
539    /// Register a boxed tool with lifecycle policy.
540    pub fn register_boxed_with_policy(&self, tool: Arc<dyn Tool>, policy: ToolPolicy) {
541        let name = tool.name().to_string();
542
543        // Create circuit breaker for this tool
544        let breaker_config = self
545            .breaker_config
546            .read()
547            .map(|c| c.clone())
548            .unwrap_or_default();
549
550        let breaker = CircuitBreaker::builder(&name)
551            .failure_threshold(breaker_config.failure_threshold)
552            .failure_window_secs(breaker_config.failure_window_secs)
553            .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
554            .success_threshold(breaker_config.success_threshold)
555            .build();
556
557        self.breakers
558            .insert(name.clone(), Arc::new(RwLock::new(breaker)));
559        self.tools.insert(name.clone(), tool);
560
561        // Store policy (overrides any existing policy)
562        self.tool_policies.insert(name.clone(), policy.clone());
563
564        // Initialize lifecycle state only if we just created it.
565        let (_, created) = self.ensure_tool_state(&name);
566        if created && policy.initial_state != ToolLifecycleState::Enabled {
567            self.set_tool_state_initial(&name, policy.initial_state);
568        }
569        if self.tool_state(&name) == Some(ToolLifecycleState::Enabled) {
570            let mut visited = HashSet::new();
571            visited.insert(name.clone());
572            self.apply_enable_policy(&name, &policy, &mut visited);
573        }
574    }
575
576    /// Get a tool by name
577    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
578        self.tools.get(name).map(|t| Arc::clone(&t))
579    }
580
581    /// Check if a tool's circuit breaker is open
582    pub fn is_circuit_open(&self, name: &str) -> bool {
583        self.breakers
584            .get(name)
585            .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
586            .unwrap_or(false)
587    }
588
589    /// List all registered tools (excluding those with open circuits if filtered)
590    pub fn list(&self) -> Vec<ToolInfo> {
591        self.tools
592            .iter()
593            .map(|entry| {
594                let tool = entry.value();
595                ToolInfo {
596                    name: tool.name().to_string(),
597                    title: tool.title().map(|s| s.to_string()),
598                    description: tool.description().map(|s| s.to_string()),
599                    input_schema: tool.input_schema(),
600                    output_schema: tool.output_schema(),
601                    execution: tool.execution(),
602                    annotations: tool.annotations(),
603                }
604            })
605            .collect()
606    }
607
608    /// List only available tools (circuit closed or half-open)
609    pub fn list_available(&self) -> Vec<ToolInfo> {
610        self.tools
611            .iter()
612            .filter(|entry| {
613                !self.is_circuit_open(entry.key()) && self.is_tool_enabled(entry.key())
614            })
615            .map(|entry| {
616                let tool = entry.value();
617                ToolInfo {
618                    name: tool.name().to_string(),
619                    title: tool.title().map(|s| s.to_string()),
620                    description: tool.description().map(|s| s.to_string()),
621                    input_schema: tool.input_schema(),
622                    output_schema: tool.output_schema(),
623                    execution: tool.execution(),
624                    annotations: tool.annotations(),
625                }
626            })
627            .collect()
628    }
629
630    /// Send notification about circuit state change
631    fn send_notification(&self, method: &str, params: Option<Value>) {
632        if let Some(tx) = &self.notification_tx {
633            let notification = JsonRpcNotification::new(method, params);
634            let _ = tx.send(notification);
635        }
636    }
637
638    /// Send tools list changed notification
639    fn notify_tools_changed(&self) {
640        self.send_notification("notifications/tools/list_changed", None);
641    }
642
643    /// Send logging message notification (visible to LLM)
644    fn notify_message(&self, level: &str, logger: &str, message: &str) {
645        self.send_notification(
646            "notifications/message",
647            Some(serde_json::json!({
648                "level": level,
649                "logger": logger,
650                "data": message
651            })),
652        );
653    }
654
655    fn ensure_tool_state(&self, name: &str) -> (Arc<RwLock<DynamicToolLifecycle>>, bool) {
656        if let Some(entry) = self.tool_states.get(name) {
657            return (Arc::clone(entry.value()), false);
658        }
659
660        if !self.tool_policies.contains_key(name) {
661            self.tool_policies
662                .insert(name.to_string(), ToolPolicy::default());
663        }
664
665        let machine = DynamicToolLifecycle::new(ToolLifecycleContext::new());
666        let entry = Arc::new(RwLock::new(machine));
667        self.tool_states.insert(name.to_string(), Arc::clone(&entry));
668        (entry, true)
669    }
670
671    fn set_tool_state_initial(&self, name: &str, target: ToolLifecycleState) {
672        let event = match target {
673            ToolLifecycleState::Enabled => return,
674            ToolLifecycleState::Disabled => ToolLifecycleEvent::Disable,
675            ToolLifecycleState::Defective => ToolLifecycleEvent::MarkDefective,
676        };
677        let _ = self.transition_tool_state(name, event, false);
678    }
679
680    fn transition_tool_state(
681        &self,
682        name: &str,
683        event: ToolLifecycleEvent,
684        notify: bool,
685    ) -> Option<(ToolLifecycleState, ToolLifecycleState)> {
686        let (entry, _) = self.ensure_tool_state(name);
687        let mut guard = entry.write().ok()?;
688        let before = ToolLifecycleState::from_state_name(guard.current_state())?;
689        let _ = guard.handle(event);
690        let after = ToolLifecycleState::from_state_name(guard.current_state())?;
691        drop(guard);
692
693        if notify
694            && before != after
695            && (before == ToolLifecycleState::Enabled
696                || after == ToolLifecycleState::Enabled)
697        {
698            self.notify_tools_changed();
699        }
700
701        Some((before, after))
702    }
703
704    fn set_tool_state_internal(
705        &self,
706        name: &str,
707        target: ToolLifecycleState,
708        visited: &mut HashSet<String>,
709    ) -> bool {
710        let _ = self.ensure_tool_state(name);
711        let current = self
712            .tool_state(name)
713            .unwrap_or(ToolLifecycleState::Enabled);
714
715        if current == target {
716            return false;
717        }
718
719        let event = match (current, target) {
720            (ToolLifecycleState::Enabled, ToolLifecycleState::Disabled) => {
721                ToolLifecycleEvent::Disable
722            }
723            (ToolLifecycleState::Enabled, ToolLifecycleState::Defective) => {
724                ToolLifecycleEvent::MarkDefective
725            }
726            (ToolLifecycleState::Disabled, ToolLifecycleState::Enabled) => {
727                ToolLifecycleEvent::Enable
728            }
729            (ToolLifecycleState::Disabled, ToolLifecycleState::Defective) => {
730                ToolLifecycleEvent::MarkDefective
731            }
732            (ToolLifecycleState::Defective, ToolLifecycleState::Enabled) => {
733                ToolLifecycleEvent::Recover
734            }
735            (ToolLifecycleState::Defective, ToolLifecycleState::Disabled) => {
736                ToolLifecycleEvent::Disable
737            }
738            _ => return false,
739        };
740
741        let change = self.transition_tool_state(name, event, true);
742        let changed = matches!(change, Some((prev, next)) if prev != next);
743        let became_enabled = matches!(
744            change,
745            Some((prev, ToolLifecycleState::Enabled))
746                if prev != ToolLifecycleState::Enabled
747        );
748
749        if became_enabled {
750            if !visited.insert(name.to_string()) {
751                return true;
752            }
753            if let Some(policy) = self.tool_policies.get(name).map(|p| p.clone()) {
754                self.apply_enable_policy(name, &policy, visited);
755            }
756        }
757
758        changed
759    }
760
761    fn apply_enable_policy(
762        &self,
763        name: &str,
764        policy: &ToolPolicy,
765        visited: &mut HashSet<String>,
766    ) {
767        if let Some(group) = policy.exclusive_group.as_deref() {
768            let others: Vec<String> = self
769                .tool_policies
770                .iter()
771                .filter(|entry| {
772                    entry.key() != name
773                        && entry.value().exclusive_group.as_deref() == Some(group)
774                })
775                .map(|entry| entry.key().clone())
776                .collect();
777
778            for other in others {
779                if self.tool_state(&other) == Some(ToolLifecycleState::Enabled) {
780                    self.set_tool_state_internal(&other, ToolLifecycleState::Disabled, visited);
781                }
782            }
783        }
784
785        for target in &policy.activates {
786            self.set_tool_state_internal(target, ToolLifecycleState::Enabled, visited);
787        }
788    }
789
790    /// Call a tool by name with circuit breaker protection
791    ///
792    /// # Arguments
793    ///
794    /// * `name` - Tool name to call
795    /// * `params` - Tool parameters
796    /// * `session` - Current session
797    /// * `client_requester` - Optional client requester for server→client requests
798    pub async fn call(
799        &self,
800        name: &str,
801        params: Value,
802        session: &Session,
803        logger: &crate::logging::McpLogger,
804        client_requester: Option<ClientRequester>,
805    ) -> Result<ToolOutput, ToolError> {
806        let tool = self
807            .get(name)
808            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
809
810        if !self.is_tool_enabled(name) {
811            return Err(ToolError::NotFound(name.to_string()));
812        }
813
814        let breaker = self.breakers.get(name).ok_or_else(|| {
815            ToolError::Internal(format!("No circuit breaker for tool '{}'", name))
816        })?;
817
818        // Check circuit state before execution
819        let was_open = {
820            let breaker_guard = breaker
821                .read()
822                .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
823            breaker_guard.is_open()
824        };
825
826        if was_open {
827            return Err(ToolError::CircuitOpen {
828                tool: name.to_string(),
829                message: "Too many recent failures. Service temporarily unavailable.".to_string(),
830            });
831        }
832
833        // Create execution context with optional client requester
834        let ctx = match client_requester {
835            Some(cr) => ExecutionContext::new(params, session, logger).with_client_requester(cr),
836            None => ExecutionContext::new(params, session, logger),
837        };
838
839        // Execute the tool with timing
840        let start = std::time::Instant::now();
841        let result = tool.execute(ctx).await;
842        let duration_secs = start.elapsed().as_secs_f64();
843
844        // Update circuit breaker based on result
845        let breaker_guard = breaker
846            .write()
847            .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
848
849        let was_closed_before = !breaker_guard.is_open();
850
851        match &result {
852            Ok(_) => {
853                breaker_guard.record_success(duration_secs);
854                // Check if we recovered from open state
855                if was_open && !breaker_guard.is_open() {
856                    let _ = self.recover_tool(name);
857                    self.notify_message(
858                        "info",
859                        "breaker-machines",
860                        &format!("Tool '{}' recovered and available", name),
861                    );
862                }
863            }
864            Err(_) => {
865                breaker_guard.record_failure(duration_secs);
866                // Check if we just opened the circuit
867                if was_closed_before && breaker_guard.is_open() {
868                    let _ = self.mark_tool_defective(name);
869                    self.notify_message(
870                        "warning",
871                        "breaker-machines",
872                        &format!(
873                            "Tool '{}' disabled: circuit breaker open after failures",
874                            name
875                        ),
876                    );
877                }
878            }
879        }
880
881        result
882    }
883
884    // ==================== Session-Aware Methods ====================
885
886    /// List tools visible to a specific session
887    ///
888    /// Resolution order:
889    /// 1. Session overrides (replace global tools with session-specific implementations)
890    /// 2. Session extras (additional tools added to session)
891    /// 3. Global tools (filtered by hidden list and visibility predicate)
892    pub fn list_for_session(
893        &self,
894        session: &Session,
895        ctx: &VisibilityContext<'_>,
896    ) -> Vec<ToolInfo> {
897        let mut tools = std::collections::HashMap::new();
898
899        // 1. Add global tools (filtered by hidden and visibility)
900        for entry in self.tools.iter() {
901            let name = entry.key().clone();
902            if !session.is_tool_hidden(&name)
903                && !self.is_circuit_open(&name)
904                && self.is_tool_enabled(&name)
905            {
906                let tool = entry.value();
907                if tool.is_visible(ctx) {
908                    tools.insert(
909                        name,
910                        ToolInfo {
911                            name: tool.name().to_string(),
912                            title: tool.title().map(|s| s.to_string()),
913                            description: tool.description().map(|s| s.to_string()),
914                            input_schema: tool.input_schema(),
915                            output_schema: tool.output_schema(),
916                            execution: tool.execution(),
917                            annotations: tool.annotations(),
918                        },
919                    );
920                }
921            }
922        }
923
924        // 2. Add session extras (can add new tools or override global)
925        for entry in session.tool_extras().iter() {
926            let name = entry.key().clone();
927            let tool = entry.value();
928            if tool.is_visible(ctx) && self.is_tool_enabled(&name) {
929                tools.insert(
930                    name,
931                    ToolInfo {
932                        name: tool.name().to_string(),
933                        title: tool.title().map(|s| s.to_string()),
934                        description: tool.description().map(|s| s.to_string()),
935                        input_schema: tool.input_schema(),
936                        output_schema: tool.output_schema(),
937                        execution: tool.execution(),
938                        annotations: tool.annotations(),
939                    },
940                );
941            }
942        }
943
944        // 3. Apply session overrides (replace implementations)
945        for entry in session.tool_overrides().iter() {
946            let name = entry.key().clone();
947            let tool = entry.value();
948            if tool.is_visible(ctx) && self.is_tool_enabled(&name) {
949                tools.insert(
950                    name,
951                    ToolInfo {
952                        name: tool.name().to_string(),
953                        title: tool.title().map(|s| s.to_string()),
954                        description: tool.description().map(|s| s.to_string()),
955                        input_schema: tool.input_schema(),
956                        output_schema: tool.output_schema(),
957                        execution: tool.execution(),
958                        annotations: tool.annotations(),
959                    },
960                );
961            }
962        }
963
964        tools.into_values().collect()
965    }
966
967    /// Call a tool with session context
968    ///
969    /// Resolution order:
970    /// 1. Resolve alias to actual name
971    /// 2. Check session overrides (polymorphic tools)
972    /// 3. Check session extras (injected tools)
973    /// 4. Check session hidden (filtered out)
974    /// 5. Check visibility predicate (contextual)
975    /// 6. Fall back to global registry with circuit breaker
976    pub async fn call_for_session(
977        &self,
978        name: &str,
979        params: Value,
980        session: &Session,
981        logger: &crate::logging::McpLogger,
982        visibility_ctx: &VisibilityContext<'_>,
983        client_requester: Option<ClientRequester>,
984    ) -> Result<ToolOutput, ToolError> {
985        // 1. Resolve alias
986        let resolved_name = session.resolve_tool_alias(name);
987        let resolved = resolved_name.as_ref();
988
989        if !self.is_tool_enabled(resolved) {
990            return Err(ToolError::NotFound(name.to_string()));
991        }
992
993        // Create execution context (reuse environment from visibility context)
994        // Clone client_requester so we can use it both for exec_ctx and for nested call()
995        let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
996            (Some(env), Some(cr)) => {
997                ExecutionContext::with_environment(params.clone(), session, logger, env)
998                    .with_client_requester(cr.clone())
999            }
1000            (Some(env), None) => {
1001                ExecutionContext::with_environment(params.clone(), session, logger, env)
1002            }
1003            (None, Some(cr)) => ExecutionContext::new(params.clone(), session, logger)
1004                .with_client_requester(cr.clone()),
1005            (None, None) => ExecutionContext::new(params.clone(), session, logger),
1006        };
1007
1008        // 2. Check session override first
1009        if let Some(tool) = session.get_tool_override(resolved) {
1010            if !tool.is_visible(visibility_ctx) {
1011                return Err(ToolError::NotFound(name.to_string()));
1012            }
1013            return tool.execute(exec_ctx).await;
1014        }
1015
1016        // 3. Check session extras
1017        if let Some(tool) = session.get_tool_extra(resolved) {
1018            if !tool.is_visible(visibility_ctx) {
1019                return Err(ToolError::NotFound(name.to_string()));
1020            }
1021            return tool.execute(exec_ctx).await;
1022        }
1023
1024        // 4. Check if hidden in session
1025        if session.is_tool_hidden(resolved) {
1026            return Err(ToolError::NotFound(name.to_string()));
1027        }
1028
1029        // 5. Check global registry with visibility check
1030        let tool = self
1031            .get(resolved)
1032            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
1033
1034        if !tool.is_visible(visibility_ctx) {
1035            return Err(ToolError::NotFound(name.to_string()));
1036        }
1037
1038        // 6. Execute with circuit breaker (use resolved name for breaker tracking)
1039        self.call(resolved, params, session, logger, client_requester)
1040            .await
1041    }
1042
1043    /// Get number of registered tools
1044    pub fn len(&self) -> usize {
1045        self.tools.len()
1046    }
1047
1048    /// Check if registry is empty
1049    pub fn is_empty(&self) -> bool {
1050        self.tools.is_empty()
1051    }
1052}
1053
1054impl Default for ToolRegistry {
1055    fn default() -> Self {
1056        Self::new()
1057    }
1058}
1059
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that echoes its required `message` argument back as text and
    /// fails with `InvalidArguments` when the argument is missing.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> Option<&str> {
            Some("Echoes back the input message")
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            ctx: ExecutionContext<'_>,
        ) -> Result<ToolOutput, ToolError> {
            let message = ctx
                .params
                .get("message")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    ToolError::InvalidArguments("Missing 'message' field".to_string())
                })?;

            Ok(ToolOutput::text(format!("Echo: {}", message)))
        }
    }

    // Minimal tools used to exercise exclusive-group policies.
    struct ToolA;
    struct ToolB;

    #[async_trait]
    impl Tool for ToolA {
        fn name(&self) -> &str {
            "tool_a"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _ctx: ExecutionContext<'_>,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("A"))
        }
    }

    #[async_trait]
    impl Tool for ToolB {
        fn name(&self) -> &str {
            "tool_b"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _ctx: ExecutionContext<'_>,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::text("B"))
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_default_registry_is_empty() {
        let registry = ToolRegistry::default();
        assert!(registry.is_empty());
        assert_eq!(registry.len(), 0);
    }

    #[test]
    fn test_tool_registration() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_get_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(
            tools[0].description,
            Some("Echoes back the input message".to_string())
        );
    }

    #[tokio::test]
    async fn test_call_tool() {
        // `_rx` stays bound (rather than `_`) so the channel remains open for the whole test.
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({
            "message": "Hello, world!"
        });

        let result = registry
            .call("echo", params, &session, &logger, None)
            .await
            .unwrap();
        assert!(result.is_content());
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        // `_rx` stays bound (rather than `_`) so the channel remains open for the whole test.
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        let session = Session::new();

        let params = serde_json::json!({});
        let result = registry
            .call("nonexistent", params, &session, &logger, None)
            .await;

        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_invalid_arguments() {
        // `_rx` stays bound (rather than `_`) so the channel remains open for the whole test.
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({}); // Missing required 'message' field

        let result = registry.call("echo", params, &session, &logger, None).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_disabled_tool_not_callable() {
        // `_rx` stays bound (rather than `_`) so the channel remains open for the whole test.
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry.disable_tool("echo");
        let session = Session::new();

        let params = serde_json::json!({
            "message": "Hello, world!"
        });

        let result = registry.call("echo", params, &session, &logger, None).await;
        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[test]
    fn test_exclusive_group_disables_other() {
        let registry = ToolRegistry::new();
        registry.register_with_policy(
            ToolA,
            ToolPolicy::new().exclusive_group("exclusive"),
        );
        registry.register_with_policy(
            ToolB,
            ToolPolicy::new().exclusive_group("exclusive"),
        );

        // Both tools share the "exclusive" group, so enabling tool_b last
        // leaves tool_a disabled.
        registry.enable_tool("tool_a");
        registry.enable_tool("tool_b");

        assert_eq!(
            registry.tool_state("tool_a"),
            Some(ToolLifecycleState::Disabled)
        );
        assert_eq!(
            registry.tool_state("tool_b"),
            Some(ToolLifecycleState::Enabled)
        );
    }
}