mcp_host/registry/tools.rs

//! Tool registry for MCP servers.
//!
//! Provides registration and execution of MCP tools with circuit breaker protection.
4
5use std::sync::{Arc, RwLock};
6
7use async_trait::async_trait;
8use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::mpsc;
14
15use crate::content::types::{Content, ImageContent, TextContent};
16use crate::server::multiplexer::ClientRequester;
17use crate::server::session::Session;
18use crate::server::visibility::{ExecutionContext, VisibilityContext};
19use crate::transport::traits::JsonRpcNotification;
20
21/// Tool execution errors
22#[derive(Debug, Error)]
23pub enum ToolError {
24    /// Tool not found
25    #[error("Tool not found: {0}")]
26    NotFound(String),
27
28    /// Invalid arguments
29    #[error("Invalid arguments: {0}")]
30    InvalidArguments(String),
31
32    /// Execution error
33    #[error("Execution error: {0}")]
34    Execution(String),
35
36    /// Internal error
37    #[error("Internal error: {0}")]
38    Internal(String),
39
40    /// Circuit breaker is open
41    #[error("Circuit breaker open for tool '{tool}': {message}")]
42    CircuitOpen { tool: String, message: String },
43}
44
45/// Tool metadata for listing
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolInfo {
48    /// Tool name
49    pub name: String,
50
51    /// Tool description
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub description: Option<String>,
54
55    /// Input schema (JSON Schema)
56    #[serde(rename = "inputSchema")]
57    pub input_schema: Value,
58
59    /// Execution metadata (task support, etc.)
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub execution: Option<crate::protocol::types::ToolExecution>,
62}
63
64/// Tool trait for implementing MCP tools
65#[async_trait]
66pub trait Tool: Send + Sync {
67    /// Get tool name
68    fn name(&self) -> &str;
69
70    /// Get tool description
71    fn description(&self) -> Option<&str> {
72        None
73    }
74
75    /// Get input schema (JSON Schema)
76    fn input_schema(&self) -> Value;
77
78    /// Get execution metadata (task support, etc.)
79    ///
80    /// Override this to declare task support for this tool.
81    /// Returns None by default (forbidden task execution).
82    fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
83        None
84    }
85
86    /// Check if this tool should be visible in the given context
87    ///
88    /// Override this to implement contextual visibility. The default implementation
89    /// always returns true (always visible).
90    ///
91    /// # Example
92    ///
93    /// ```rust,ignore
94    /// fn is_visible(&self, ctx: &VisibilityContext) -> bool {
95    ///     // Only show git commit tool if repo has uncommitted changes
96    ///     ctx.environment
97    ///         .map(|e| e.has_git_repo() && !e.git_is_clean())
98    ///         .unwrap_or(false)
99    /// }
100    /// ```
101    fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
102        true
103    }
104
105    /// Execute the tool with execution context
106    ///
107    /// The execution context provides access to:
108    /// - `ctx.params`: Input parameters from the tool call
109    /// - `ctx.session`: Current session (roles, state, client info)
110    /// - `ctx.environment`: Optional environment state
111    ///
112    /// # Example
113    ///
114    /// ```rust,ignore
115    /// async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
116    ///     let name = ctx.params.get("name").and_then(|v| v.as_str());
117    ///
118    ///     // Check session roles
119    ///     if ctx.is_admin() {
120    ///         // Admin-only behavior
121    ///     }
122    ///
123    ///     Ok(vec![])
124    /// }
125    /// ```
126    async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError>;
127}
128
129/// Helper methods for tool implementations
130///
131/// Provides ergonomic shortcuts for creating common content types.
132/// Auto-implemented for all Tool trait objects.
133pub trait ToolHelpers {
134    /// Create text content
135    fn text(&self, content: &str) -> Box<dyn Content> {
136        Box::new(TextContent::new(content))
137    }
138
139    /// Create image content
140    fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
141        Box::new(ImageContent::new(data, mime_type))
142    }
143}
144
145// Auto-implement ToolHelpers for all Tool trait objects
146impl<T: Tool + ?Sized> ToolHelpers for T {}
147
/// Settings used to build each tool's circuit breaker.
///
/// Every duration is expressed in seconds.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Number of failures that opens the circuit.
    pub failure_threshold: usize,
    /// Length of the window, in seconds, over which failures are counted.
    pub failure_window_secs: f64,
    /// Seconds an open circuit waits before transitioning to half-open.
    pub half_open_timeout_secs: f64,
    /// Number of successes needed to close the circuit again.
    pub success_threshold: usize,
}
160
161impl Default for ToolBreakerConfig {
162    fn default() -> Self {
163        Self {
164            failure_threshold: 5,
165            failure_window_secs: 60.0,
166            half_open_timeout_secs: 30.0,
167            success_threshold: 2,
168        }
169    }
170}
171
172impl From<ToolBreakerConfig> for BreakerConfig {
173    fn from(cfg: ToolBreakerConfig) -> Self {
174        BreakerConfig {
175            failure_threshold: Some(cfg.failure_threshold),
176            failure_rate_threshold: None,
177            minimum_calls: 1,
178            failure_window_secs: cfg.failure_window_secs,
179            half_open_timeout_secs: cfg.half_open_timeout_secs,
180            success_threshold: cfg.success_threshold,
181            jitter_factor: 0.1,
182        }
183    }
184}
185
186/// Tool registry for managing available tools with circuit breaker protection
187#[derive(Clone)]
188pub struct ToolRegistry {
189    tools: Arc<DashMap<String, Arc<dyn Tool>>>,
190    breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
191    breaker_config: Arc<RwLock<ToolBreakerConfig>>,
192    notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
193}
194
195impl ToolRegistry {
196    /// Create new tool registry
197    pub fn new() -> Self {
198        Self {
199            tools: Arc::new(DashMap::new()),
200            breakers: Arc::new(DashMap::new()),
201            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
202            notification_tx: None,
203        }
204    }
205
206    /// Create tool registry with notification channel
207    pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
208        Self {
209            tools: Arc::new(DashMap::new()),
210            breakers: Arc::new(DashMap::new()),
211            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
212            notification_tx: Some(notification_tx),
213        }
214    }
215
216    /// Set notification channel
217    pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
218        self.notification_tx = Some(tx);
219    }
220
221    /// Configure circuit breakers for all tools
222    pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
223        if let Ok(mut cfg) = self.breaker_config.write() {
224            *cfg = config;
225        }
226    }
227
228    /// Register a tool with circuit breaker protection
229    pub fn register<T: Tool + 'static>(&self, tool: T) {
230        let name = tool.name().to_string();
231
232        // Create circuit breaker for this tool
233        let breaker_config = self
234            .breaker_config
235            .read()
236            .map(|c| c.clone())
237            .unwrap_or_default();
238
239        let breaker = CircuitBreaker::builder(&name)
240            .failure_threshold(breaker_config.failure_threshold)
241            .failure_window_secs(breaker_config.failure_window_secs)
242            .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
243            .success_threshold(breaker_config.success_threshold)
244            .build();
245
246        self.breakers
247            .insert(name.clone(), Arc::new(RwLock::new(breaker)));
248        self.tools.insert(name, Arc::new(tool));
249    }
250
251    /// Register a boxed tool with circuit breaker
252    pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
253        let name = tool.name().to_string();
254
255        let breaker_config = self
256            .breaker_config
257            .read()
258            .map(|c| c.clone())
259            .unwrap_or_default();
260
261        let breaker = CircuitBreaker::builder(&name)
262            .failure_threshold(breaker_config.failure_threshold)
263            .failure_window_secs(breaker_config.failure_window_secs)
264            .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
265            .success_threshold(breaker_config.success_threshold)
266            .build();
267
268        self.breakers
269            .insert(name.clone(), Arc::new(RwLock::new(breaker)));
270        self.tools.insert(name, tool);
271    }
272
273    /// Get a tool by name
274    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
275        self.tools.get(name).map(|t| Arc::clone(&t))
276    }
277
278    /// Check if a tool's circuit breaker is open
279    pub fn is_circuit_open(&self, name: &str) -> bool {
280        self.breakers
281            .get(name)
282            .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
283            .unwrap_or(false)
284    }
285
286    /// List all registered tools (excluding those with open circuits if filtered)
287    pub fn list(&self) -> Vec<ToolInfo> {
288        self.tools
289            .iter()
290            .map(|entry| {
291                let tool = entry.value();
292                ToolInfo {
293                    name: tool.name().to_string(),
294                    description: tool.description().map(|s| s.to_string()),
295                    input_schema: tool.input_schema(),
296                    execution: tool.execution(),
297                }
298            })
299            .collect()
300    }
301
302    /// List only available tools (circuit closed or half-open)
303    pub fn list_available(&self) -> Vec<ToolInfo> {
304        self.tools
305            .iter()
306            .filter(|entry| !self.is_circuit_open(entry.key()))
307            .map(|entry| {
308                let tool = entry.value();
309                ToolInfo {
310                    name: tool.name().to_string(),
311                    description: tool.description().map(|s| s.to_string()),
312                    input_schema: tool.input_schema(),
313                    execution: tool.execution(),
314                }
315            })
316            .collect()
317    }
318
319    /// Send notification about circuit state change
320    fn send_notification(&self, method: &str, params: Option<Value>) {
321        if let Some(tx) = &self.notification_tx {
322            let notification = JsonRpcNotification::new(method, params);
323            let _ = tx.send(notification);
324        }
325    }
326
327    /// Send tools list changed notification
328    fn notify_tools_changed(&self) {
329        self.send_notification("notifications/tools/list_changed", None);
330    }
331
332    /// Send logging message notification (visible to LLM)
333    fn notify_message(&self, level: &str, logger: &str, message: &str) {
334        self.send_notification(
335            "notifications/message",
336            Some(serde_json::json!({
337                "level": level,
338                "logger": logger,
339                "data": message
340            })),
341        );
342    }
343
344    /// Call a tool by name with circuit breaker protection
345    ///
346    /// # Arguments
347    ///
348    /// * `name` - Tool name to call
349    /// * `params` - Tool parameters
350    /// * `session` - Current session
351    /// * `client_requester` - Optional client requester for server→client requests
352    pub async fn call(
353        &self,
354        name: &str,
355        params: Value,
356        session: &Session,
357        logger: &crate::logging::McpLogger,
358        client_requester: Option<ClientRequester>,
359    ) -> Result<Vec<Box<dyn Content>>, ToolError> {
360        let tool = self
361            .get(name)
362            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
363
364        let breaker = self.breakers.get(name).ok_or_else(|| {
365            ToolError::Internal(format!("No circuit breaker for tool '{}'", name))
366        })?;
367
368        // Check circuit state before execution
369        let was_open = {
370            let breaker_guard = breaker
371                .read()
372                .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
373            breaker_guard.is_open()
374        };
375
376        if was_open {
377            return Err(ToolError::CircuitOpen {
378                tool: name.to_string(),
379                message: "Too many recent failures. Service temporarily unavailable.".to_string(),
380            });
381        }
382
383        // Create execution context with optional client requester
384        let ctx = match client_requester {
385            Some(cr) => ExecutionContext::new(params, session, logger).with_client_requester(cr),
386            None => ExecutionContext::new(params, session, logger),
387        };
388
389        // Execute the tool with timing
390        let start = std::time::Instant::now();
391        let result = tool.execute(ctx).await;
392        let duration_secs = start.elapsed().as_secs_f64();
393
394        // Update circuit breaker based on result
395        let breaker_guard = breaker
396            .write()
397            .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
398
399        let was_closed_before = !breaker_guard.is_open();
400
401        match &result {
402            Ok(_) => {
403                breaker_guard.record_success(duration_secs);
404                // Check if we recovered from open state
405                if was_open && !breaker_guard.is_open() {
406                    self.notify_tools_changed();
407                    self.notify_message(
408                        "info",
409                        "breaker-machines",
410                        &format!("Tool '{}' recovered and available", name),
411                    );
412                }
413            }
414            Err(_) => {
415                breaker_guard.record_failure(duration_secs);
416                // Check if we just opened the circuit
417                if was_closed_before && breaker_guard.is_open() {
418                    self.notify_tools_changed();
419                    self.notify_message(
420                        "warning",
421                        "breaker-machines",
422                        &format!(
423                            "Tool '{}' disabled: circuit breaker open after failures",
424                            name
425                        ),
426                    );
427                }
428            }
429        }
430
431        result
432    }
433
434    // ==================== Session-Aware Methods ====================
435
436    /// List tools visible to a specific session
437    ///
438    /// Resolution order:
439    /// 1. Session overrides (replace global tools with session-specific implementations)
440    /// 2. Session extras (additional tools added to session)
441    /// 3. Global tools (filtered by hidden list and visibility predicate)
442    pub fn list_for_session(
443        &self,
444        session: &Session,
445        ctx: &VisibilityContext<'_>,
446    ) -> Vec<ToolInfo> {
447        let mut tools = std::collections::HashMap::new();
448
449        // 1. Add global tools (filtered by hidden and visibility)
450        for entry in self.tools.iter() {
451            let name = entry.key().clone();
452            if !session.is_tool_hidden(&name) && !self.is_circuit_open(&name) {
453                let tool = entry.value();
454                if tool.is_visible(ctx) {
455                    tools.insert(
456                        name,
457                        ToolInfo {
458                            name: tool.name().to_string(),
459                            description: tool.description().map(|s| s.to_string()),
460                            input_schema: tool.input_schema(),
461                            execution: tool.execution(),
462                        },
463                    );
464                }
465            }
466        }
467
468        // 2. Add session extras (can add new tools or override global)
469        for entry in session.tool_extras().iter() {
470            let name = entry.key().clone();
471            let tool = entry.value();
472            if tool.is_visible(ctx) {
473                tools.insert(
474                    name,
475                    ToolInfo {
476                        name: tool.name().to_string(),
477                        description: tool.description().map(|s| s.to_string()),
478                        input_schema: tool.input_schema(),
479                        execution: tool.execution(),
480                    },
481                );
482            }
483        }
484
485        // 3. Apply session overrides (replace implementations)
486        for entry in session.tool_overrides().iter() {
487            let name = entry.key().clone();
488            let tool = entry.value();
489            if tool.is_visible(ctx) {
490                tools.insert(
491                    name,
492                    ToolInfo {
493                        name: tool.name().to_string(),
494                        description: tool.description().map(|s| s.to_string()),
495                        input_schema: tool.input_schema(),
496                        execution: tool.execution(),
497                    },
498                );
499            }
500        }
501
502        tools.into_values().collect()
503    }
504
505    /// Call a tool with session context
506    ///
507    /// Resolution order:
508    /// 1. Resolve alias to actual name
509    /// 2. Check session overrides (polymorphic tools)
510    /// 3. Check session extras (injected tools)
511    /// 4. Check session hidden (filtered out)
512    /// 5. Check visibility predicate (contextual)
513    /// 6. Fall back to global registry with circuit breaker
514    pub async fn call_for_session(
515        &self,
516        name: &str,
517        params: Value,
518        session: &Session,
519        logger: &crate::logging::McpLogger,
520        visibility_ctx: &VisibilityContext<'_>,
521        client_requester: Option<ClientRequester>,
522    ) -> Result<Vec<Box<dyn Content>>, ToolError> {
523        // 1. Resolve alias
524        let resolved_name = session.resolve_tool_alias(name);
525        let resolved = resolved_name.as_ref();
526
527        // Create execution context (reuse environment from visibility context)
528        // Clone client_requester so we can use it both for exec_ctx and for nested call()
529        let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
530            (Some(env), Some(cr)) => {
531                ExecutionContext::with_environment(params.clone(), session, logger, env)
532                    .with_client_requester(cr.clone())
533            }
534            (Some(env), None) => {
535                ExecutionContext::with_environment(params.clone(), session, logger, env)
536            }
537            (None, Some(cr)) => ExecutionContext::new(params.clone(), session, logger)
538                .with_client_requester(cr.clone()),
539            (None, None) => ExecutionContext::new(params.clone(), session, logger),
540        };
541
542        // 2. Check session override first
543        if let Some(tool) = session.get_tool_override(resolved) {
544            if !tool.is_visible(visibility_ctx) {
545                return Err(ToolError::NotFound(name.to_string()));
546            }
547            return tool.execute(exec_ctx).await;
548        }
549
550        // 3. Check session extras
551        if let Some(tool) = session.get_tool_extra(resolved) {
552            if !tool.is_visible(visibility_ctx) {
553                return Err(ToolError::NotFound(name.to_string()));
554            }
555            return tool.execute(exec_ctx).await;
556        }
557
558        // 4. Check if hidden in session
559        if session.is_tool_hidden(resolved) {
560            return Err(ToolError::NotFound(name.to_string()));
561        }
562
563        // 5. Check global registry with visibility check
564        let tool = self
565            .get(resolved)
566            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
567
568        if !tool.is_visible(visibility_ctx) {
569            return Err(ToolError::NotFound(name.to_string()));
570        }
571
572        // 6. Execute with circuit breaker (use resolved name for breaker tracking)
573        self.call(resolved, params, session, logger, client_requester)
574            .await
575    }
576
577    /// Get number of registered tools
578    pub fn len(&self) -> usize {
579        self.tools.len()
580    }
581
582    /// Check if registry is empty
583    pub fn is_empty(&self) -> bool {
584        self.tools.is_empty()
585    }
586}
587
588impl Default for ToolRegistry {
589    fn default() -> Self {
590        Self::new()
591    }
592}
593
#[cfg(test)]
mod tests {
    use super::*;
    use crate::content::types::TextContent;

    /// Example tool that echoes its `message` argument back as text.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> Option<&str> {
            Some("Echoes back the input message")
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            ctx: ExecutionContext<'_>,
        ) -> Result<Vec<Box<dyn Content>>, ToolError> {
            let message = ctx
                .params
                .get("message")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    ToolError::InvalidArguments("Missing 'message' field".to_string())
                })?;

            let content = TextContent::new(format!("Echo: {}", message));
            Ok(vec![Box::new(content)])
        }
    }

    /// Tool that fails on every call, used to drive a circuit breaker open.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "always_fails"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(
            &self,
            _ctx: ExecutionContext<'_>,
        ) -> Result<Vec<Box<dyn Content>>, ToolError> {
            Err(ToolError::Execution("boom".to_string()))
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_tool_registration() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_reregistering_replaces_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_circuit_open("echo"));
    }

    #[test]
    fn test_get_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(
            tools[0].description,
            Some("Echoes back the input message".to_string())
        );
    }

    #[tokio::test]
    async fn test_call_tool() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({
            "message": "Hello, world!"
        });

        let result = registry
            .call("echo", params, &session, &logger, None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        let session = Session::new();

        let params = serde_json::json!({});
        let result = registry
            .call("nonexistent", params, &session, &logger, None)
            .await;

        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_invalid_arguments() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({}); // Missing required 'message' field

        let result = registry.call("echo", params, &session, &logger, None).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_circuit_opens_after_repeated_failures() {
        let (log_tx, _log_rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(log_tx, "test");
        let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
        let registry = ToolRegistry::with_notifications(notify_tx);
        registry.set_breaker_config(ToolBreakerConfig {
            failure_threshold: 2,
            ..ToolBreakerConfig::default()
        });
        registry.register(FailingTool);
        let session = Session::new();

        // Keep failing until the breaker rejects the call; the bound only guards
        // against an infinite loop if the breaker never opens.
        let mut opened = false;
        for _ in 0..50 {
            let result = registry
                .call("always_fails", serde_json::json!({}), &session, &logger, None)
                .await;
            match result {
                Err(ToolError::CircuitOpen { .. }) => {
                    opened = true;
                    break;
                }
                Err(ToolError::Execution(_)) => {}
                other => panic!("unexpected result: {:?}", other.map(|c| c.len())),
            }
        }

        assert!(opened, "breaker never opened");
        assert!(registry.is_circuit_open("always_fails"));
        assert!(registry.list_available().is_empty());
        assert_eq!(registry.list().len(), 1);
        assert!(
            notify_rx.try_recv().is_ok(),
            "opening the circuit should emit a notification"
        );
    }
}