mcp_host/registry/
tools.rs

//! Tool registry for MCP servers.
//!
//! Provides registration and execution of MCP tools, each protected by its own circuit breaker.
4
5use std::sync::{Arc, RwLock};
6
7use async_trait::async_trait;
8use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::mpsc;
14
15use crate::content::types::{Content, TextContent, ImageContent};
16use crate::server::multiplexer::ClientRequester;
17use crate::server::session::Session;
18use crate::server::visibility::{ExecutionContext, VisibilityContext};
19use crate::transport::traits::JsonRpcNotification;
20
21/// Tool execution errors
22#[derive(Debug, Error)]
23pub enum ToolError {
24    /// Tool not found
25    #[error("Tool not found: {0}")]
26    NotFound(String),
27
28    /// Invalid arguments
29    #[error("Invalid arguments: {0}")]
30    InvalidArguments(String),
31
32    /// Execution error
33    #[error("Execution error: {0}")]
34    Execution(String),
35
36    /// Internal error
37    #[error("Internal error: {0}")]
38    Internal(String),
39
40    /// Circuit breaker is open
41    #[error("Circuit breaker open for tool '{tool}': {message}")]
42    CircuitOpen { tool: String, message: String },
43}
44
45/// Tool metadata for listing
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolInfo {
48    /// Tool name
49    pub name: String,
50
51    /// Tool description
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub description: Option<String>,
54
55    /// Input schema (JSON Schema)
56    #[serde(rename = "inputSchema")]
57    pub input_schema: Value,
58
59    /// Execution metadata (task support, etc.)
60    #[serde(skip_serializing_if = "Option::is_none")]
61    pub execution: Option<crate::protocol::types::ToolExecution>,
62}
63
64/// Tool trait for implementing MCP tools
65#[async_trait]
66pub trait Tool: Send + Sync {
67    /// Get tool name
68    fn name(&self) -> &str;
69
70    /// Get tool description
71    fn description(&self) -> Option<&str> {
72        None
73    }
74
75    /// Get input schema (JSON Schema)
76    fn input_schema(&self) -> Value;
77
78    /// Get execution metadata (task support, etc.)
79    ///
80    /// Override this to declare task support for this tool.
81    /// Returns None by default (forbidden task execution).
82    fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
83        None
84    }
85
86    /// Check if this tool should be visible in the given context
87    ///
88    /// Override this to implement contextual visibility. The default implementation
89    /// always returns true (always visible).
90    ///
91    /// # Example
92    ///
93    /// ```rust,ignore
94    /// fn is_visible(&self, ctx: &VisibilityContext) -> bool {
95    ///     // Only show git commit tool if repo has uncommitted changes
96    ///     ctx.environment
97    ///         .map(|e| e.has_git_repo() && !e.git_is_clean())
98    ///         .unwrap_or(false)
99    /// }
100    /// ```
101    fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
102        true
103    }
104
105    /// Execute the tool with execution context
106    ///
107    /// The execution context provides access to:
108    /// - `ctx.params`: Input parameters from the tool call
109    /// - `ctx.session`: Current session (roles, state, client info)
110    /// - `ctx.environment`: Optional environment state
111    ///
112    /// # Example
113    ///
114    /// ```rust,ignore
115    /// async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
116    ///     let name = ctx.params.get("name").and_then(|v| v.as_str());
117    ///
118    ///     // Check session roles
119    ///     if ctx.is_admin() {
120    ///         // Admin-only behavior
121    ///     }
122    ///
123    ///     Ok(vec![])
124    /// }
125    /// ```
126    async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError>;
127}
128
129/// Helper methods for tool implementations
130///
131/// Provides ergonomic shortcuts for creating common content types.
132/// Auto-implemented for all Tool trait objects.
133pub trait ToolHelpers {
134    /// Create text content
135    fn text(&self, content: &str) -> Box<dyn Content> {
136        Box::new(TextContent::new(content))
137    }
138
139    /// Create image content
140    fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
141        Box::new(ImageContent::new(data, mime_type))
142    }
143}
144
145// Auto-implement ToolHelpers for all Tool trait objects
146impl<T: Tool + ?Sized> ToolHelpers for T {}
147
/// Settings used to build the per-tool circuit breakers.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Failures (within the window) that open the circuit.
    pub failure_threshold: usize,
    /// Length, in seconds, of the window in which failures are counted.
    pub failure_window_secs: f64,
    /// Seconds an open circuit waits before allowing trial calls (half-open).
    pub half_open_timeout_secs: f64,
    /// Successful trial calls required to close the circuit again.
    pub success_threshold: usize,
}

impl Default for ToolBreakerConfig {
    /// 5 failures per 60 s window, 30 s half-open timeout, 2 successes to close.
    fn default() -> Self {
        ToolBreakerConfig {
            success_threshold: 2,
            half_open_timeout_secs: 30.0,
            failure_window_secs: 60.0,
            failure_threshold: 5,
        }
    }
}
171
172impl From<ToolBreakerConfig> for BreakerConfig {
173    fn from(cfg: ToolBreakerConfig) -> Self {
174        BreakerConfig {
175            failure_threshold: Some(cfg.failure_threshold),
176            failure_rate_threshold: None,
177            minimum_calls: 1,
178            failure_window_secs: cfg.failure_window_secs,
179            half_open_timeout_secs: cfg.half_open_timeout_secs,
180            success_threshold: cfg.success_threshold,
181            jitter_factor: 0.1,
182        }
183    }
184}
185
186/// Tool registry for managing available tools with circuit breaker protection
187#[derive(Clone)]
188pub struct ToolRegistry {
189    tools: Arc<DashMap<String, Arc<dyn Tool>>>,
190    breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
191    breaker_config: Arc<RwLock<ToolBreakerConfig>>,
192    notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
193}
194
195impl ToolRegistry {
196    /// Create new tool registry
197    pub fn new() -> Self {
198        Self {
199            tools: Arc::new(DashMap::new()),
200            breakers: Arc::new(DashMap::new()),
201            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
202            notification_tx: None,
203        }
204    }
205
206    /// Create tool registry with notification channel
207    pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
208        Self {
209            tools: Arc::new(DashMap::new()),
210            breakers: Arc::new(DashMap::new()),
211            breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
212            notification_tx: Some(notification_tx),
213        }
214    }
215
216    /// Set notification channel
217    pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
218        self.notification_tx = Some(tx);
219    }
220
221    /// Configure circuit breakers for all tools
222    pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
223        if let Ok(mut cfg) = self.breaker_config.write() {
224            *cfg = config;
225        }
226    }
227
228    /// Register a tool with circuit breaker protection
229    pub fn register<T: Tool + 'static>(&self, tool: T) {
230        let name = tool.name().to_string();
231
232        // Create circuit breaker for this tool
233        let breaker_config = self.breaker_config.read()
234            .map(|c| c.clone())
235            .unwrap_or_default();
236
237        let breaker = CircuitBreaker::builder(&name)
238            .failure_threshold(breaker_config.failure_threshold)
239            .failure_window_secs(breaker_config.failure_window_secs)
240            .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
241            .success_threshold(breaker_config.success_threshold)
242            .build();
243
244        self.breakers.insert(name.clone(), Arc::new(RwLock::new(breaker)));
245        self.tools.insert(name, Arc::new(tool));
246    }
247
248    /// Register a boxed tool with circuit breaker
249    pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
250        let name = tool.name().to_string();
251
252        let breaker_config = self.breaker_config.read()
253            .map(|c| c.clone())
254            .unwrap_or_default();
255
256        let breaker = CircuitBreaker::builder(&name)
257            .failure_threshold(breaker_config.failure_threshold)
258            .failure_window_secs(breaker_config.failure_window_secs)
259            .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
260            .success_threshold(breaker_config.success_threshold)
261            .build();
262
263        self.breakers.insert(name.clone(), Arc::new(RwLock::new(breaker)));
264        self.tools.insert(name, tool);
265    }
266
267    /// Get a tool by name
268    pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
269        self.tools.get(name).map(|t| Arc::clone(&t))
270    }
271
272    /// Check if a tool's circuit breaker is open
273    pub fn is_circuit_open(&self, name: &str) -> bool {
274        self.breakers
275            .get(name)
276            .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
277            .unwrap_or(false)
278    }
279
280    /// List all registered tools (excluding those with open circuits if filtered)
281    pub fn list(&self) -> Vec<ToolInfo> {
282        self.tools
283            .iter()
284            .map(|entry| {
285                let tool = entry.value();
286                ToolInfo {
287                    name: tool.name().to_string(),
288                    description: tool.description().map(|s| s.to_string()),
289                    input_schema: tool.input_schema(),
290                    execution: tool.execution(),
291                }
292            })
293            .collect()
294    }
295
296    /// List only available tools (circuit closed or half-open)
297    pub fn list_available(&self) -> Vec<ToolInfo> {
298        self.tools
299            .iter()
300            .filter(|entry| !self.is_circuit_open(entry.key()))
301            .map(|entry| {
302                let tool = entry.value();
303                ToolInfo {
304                    name: tool.name().to_string(),
305                    description: tool.description().map(|s| s.to_string()),
306                    input_schema: tool.input_schema(),
307                    execution: tool.execution(),
308                }
309            })
310            .collect()
311    }
312
313    /// Send notification about circuit state change
314    fn send_notification(&self, method: &str, params: Option<Value>) {
315        if let Some(tx) = &self.notification_tx {
316            let notification = JsonRpcNotification::new(method, params);
317            let _ = tx.send(notification);
318        }
319    }
320
321    /// Send tools list changed notification
322    fn notify_tools_changed(&self) {
323        self.send_notification("notifications/tools/list_changed", None);
324    }
325
326    /// Send logging message notification (visible to LLM)
327    fn notify_message(&self, level: &str, logger: &str, message: &str) {
328        self.send_notification(
329            "notifications/message",
330            Some(serde_json::json!({
331                "level": level,
332                "logger": logger,
333                "data": message
334            })),
335        );
336    }
337
338    /// Call a tool by name with circuit breaker protection
339    ///
340    /// # Arguments
341    ///
342    /// * `name` - Tool name to call
343    /// * `params` - Tool parameters
344    /// * `session` - Current session
345    /// * `client_requester` - Optional client requester for server→client requests
346    pub async fn call(
347        &self,
348        name: &str,
349        params: Value,
350        session: &Session,
351        client_requester: Option<ClientRequester>,
352    ) -> Result<Vec<Box<dyn Content>>, ToolError> {
353        let tool = self
354            .get(name)
355            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
356
357        let breaker = self
358            .breakers
359            .get(name)
360            .ok_or_else(|| ToolError::Internal(format!("No circuit breaker for tool '{}'", name)))?;
361
362        // Check circuit state before execution
363        let was_open = {
364            let breaker_guard = breaker.read()
365                .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
366            breaker_guard.is_open()
367        };
368
369        if was_open {
370            return Err(ToolError::CircuitOpen {
371                tool: name.to_string(),
372                message: "Too many recent failures. Service temporarily unavailable.".to_string(),
373            });
374        }
375
376        // Create execution context with optional client requester
377        let ctx = match client_requester {
378            Some(cr) => ExecutionContext::new(params, session).with_client_requester(cr),
379            None => ExecutionContext::new(params, session),
380        };
381
382        // Execute the tool with timing
383        let start = std::time::Instant::now();
384        let result = tool.execute(ctx).await;
385        let duration_secs = start.elapsed().as_secs_f64();
386
387        // Update circuit breaker based on result
388        let breaker_guard = breaker.write()
389            .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
390
391        let was_closed_before = !breaker_guard.is_open();
392
393        match &result {
394            Ok(_) => {
395                breaker_guard.record_success(duration_secs);
396                // Check if we recovered from open state
397                if was_open && !breaker_guard.is_open() {
398                    self.notify_tools_changed();
399                    self.notify_message(
400                        "info",
401                        "breaker-machines",
402                        &format!("Tool '{}' recovered and available", name),
403                    );
404                }
405            }
406            Err(_) => {
407                breaker_guard.record_failure(duration_secs);
408                // Check if we just opened the circuit
409                if was_closed_before && breaker_guard.is_open() {
410                    self.notify_tools_changed();
411                    self.notify_message(
412                        "warning",
413                        "breaker-machines",
414                        &format!("Tool '{}' disabled: circuit breaker open after failures", name),
415                    );
416                }
417            }
418        }
419
420        result
421    }
422
423    // ==================== Session-Aware Methods ====================
424
425    /// List tools visible to a specific session
426    ///
427    /// Resolution order:
428    /// 1. Session overrides (replace global tools with session-specific implementations)
429    /// 2. Session extras (additional tools added to session)
430    /// 3. Global tools (filtered by hidden list and visibility predicate)
431    pub fn list_for_session(&self, session: &Session, ctx: &VisibilityContext<'_>) -> Vec<ToolInfo> {
432        let mut tools = std::collections::HashMap::new();
433
434        // 1. Add global tools (filtered by hidden and visibility)
435        for entry in self.tools.iter() {
436            let name = entry.key().clone();
437            if !session.is_tool_hidden(&name) && !self.is_circuit_open(&name) {
438                let tool = entry.value();
439                if tool.is_visible(ctx) {
440                    tools.insert(
441                        name,
442                        ToolInfo {
443                            name: tool.name().to_string(),
444                            description: tool.description().map(|s| s.to_string()),
445                            input_schema: tool.input_schema(),
446                            execution: tool.execution(),
447                        },
448                    );
449                }
450            }
451        }
452
453        // 2. Add session extras (can add new tools or override global)
454        for entry in session.tool_extras().iter() {
455            let name = entry.key().clone();
456            let tool = entry.value();
457            if tool.is_visible(ctx) {
458                tools.insert(
459                    name,
460                    ToolInfo {
461                        name: tool.name().to_string(),
462                        description: tool.description().map(|s| s.to_string(),),
463                        input_schema: tool.input_schema(),
464                        execution: tool.execution(),
465                    },
466                );
467            }
468        }
469
470        // 3. Apply session overrides (replace implementations)
471        for entry in session.tool_overrides().iter() {
472            let name = entry.key().clone();
473            let tool = entry.value();
474            if tool.is_visible(ctx) {
475                tools.insert(
476                    name,
477                    ToolInfo {
478                        name: tool.name().to_string(),
479                        description: tool.description().map(|s| s.to_string(),),
480                        input_schema: tool.input_schema(),
481                        execution: tool.execution(),
482                    },
483                );
484            }
485        }
486
487        tools.into_values().collect()
488    }
489
490    /// Call a tool with session context
491    ///
492    /// Resolution order:
493    /// 1. Resolve alias to actual name
494    /// 2. Check session overrides (polymorphic tools)
495    /// 3. Check session extras (injected tools)
496    /// 4. Check session hidden (filtered out)
497    /// 5. Check visibility predicate (contextual)
498    /// 6. Fall back to global registry with circuit breaker
499    pub async fn call_for_session(
500        &self,
501        name: &str,
502        params: Value,
503        session: &Session,
504        visibility_ctx: &VisibilityContext<'_>,
505        client_requester: Option<ClientRequester>,
506    ) -> Result<Vec<Box<dyn Content>>, ToolError> {
507        // 1. Resolve alias
508        let resolved_name = session.resolve_tool_alias(name);
509        let resolved = resolved_name.as_ref();
510
511        // Create execution context (reuse environment from visibility context)
512        // Clone client_requester so we can use it both for exec_ctx and for nested call()
513        let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
514            (Some(env), Some(cr)) => ExecutionContext::with_environment(params.clone(), session, env)
515                .with_client_requester(cr.clone()),
516            (Some(env), None) => ExecutionContext::with_environment(params.clone(), session, env),
517            (None, Some(cr)) => ExecutionContext::new(params.clone(), session)
518                .with_client_requester(cr.clone()),
519            (None, None) => ExecutionContext::new(params.clone(), session),
520        };
521
522        // 2. Check session override first
523        if let Some(tool) = session.get_tool_override(resolved) {
524            if !tool.is_visible(visibility_ctx) {
525                return Err(ToolError::NotFound(name.to_string()));
526            }
527            return tool.execute(exec_ctx).await;
528        }
529
530        // 3. Check session extras
531        if let Some(tool) = session.get_tool_extra(resolved) {
532            if !tool.is_visible(visibility_ctx) {
533                return Err(ToolError::NotFound(name.to_string()));
534            }
535            return tool.execute(exec_ctx).await;
536        }
537
538        // 4. Check if hidden in session
539        if session.is_tool_hidden(resolved) {
540            return Err(ToolError::NotFound(name.to_string()));
541        }
542
543        // 5. Check global registry with visibility check
544        let tool = self
545            .get(resolved)
546            .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
547
548        if !tool.is_visible(visibility_ctx) {
549            return Err(ToolError::NotFound(name.to_string()));
550        }
551
552        // 6. Execute with circuit breaker (use resolved name for breaker tracking)
553        self.call(resolved, params, session, client_requester).await
554    }
555
556    /// Get number of registered tools
557    pub fn len(&self) -> usize {
558        self.tools.len()
559    }
560
561    /// Check if registry is empty
562    pub fn is_empty(&self) -> bool {
563        self.tools.is_empty()
564    }
565}
566
567impl Default for ToolRegistry {
568    fn default() -> Self {
569        Self::new()
570    }
571}
572
#[cfg(test)]
mod tests {
    use super::*;
    use crate::content::types::TextContent;

    /// Test tool that echoes its required `message` argument.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> Option<&str> {
            Some("Echoes back the input message")
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
            let message = ctx.params
                .get("message")
                .and_then(|v| v.as_str())
                .ok_or_else(|| ToolError::InvalidArguments("Missing 'message' field".to_string()))?;

            let content = TextContent::new(format!("Echo: {}", message));
            Ok(vec![Box::new(content)])
        }
    }

    /// Test tool whose every call fails; used to drive the circuit breaker.
    struct FailingTool;

    #[async_trait]
    impl Tool for FailingTool {
        fn name(&self) -> &str {
            "always_fails"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({ "type": "object" })
        }

        async fn execute(&self, _ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError> {
            Err(ToolError::Execution("always fails".to_string()))
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_tool_registration() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_register_boxed_replaces_same_name() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry.register_boxed(Arc::new(EchoTool));

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_circuit_open("echo"));
    }

    #[test]
    fn test_get_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(tools[0].description, Some("Echoes back the input message".to_string()));
    }

    #[tokio::test]
    async fn test_call_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({
            "message": "Hello, world!"
        });

        let result = registry.call("echo", params, &session, None).await.unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        let registry = ToolRegistry::new();
        let session = Session::new();

        let params = serde_json::json!({});
        let result = registry.call("nonexistent", params, &session, None).await;

        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_invalid_arguments() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({}); // Missing required 'message' field

        let result = registry.call("echo", params, &session, None).await;
        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }

    #[tokio::test]
    async fn test_circuit_opens_after_repeated_failures() {
        let (tx, mut rx) = mpsc::unbounded_channel();
        let registry = ToolRegistry::with_notifications(tx);
        // Configure before registering: the config is read at registration time.
        registry.set_breaker_config(ToolBreakerConfig {
            failure_threshold: 2,
            ..ToolBreakerConfig::default()
        });
        registry.register(FailingTool);
        let session = Session::new();

        for _ in 0..2 {
            let result = registry
                .call("always_fails", serde_json::json!({}), &session, None)
                .await;
            assert!(matches!(result, Err(ToolError::Execution(_))));
        }

        assert!(registry.is_circuit_open("always_fails"));
        assert_eq!(registry.list().len(), 1);
        assert!(registry.list_available().is_empty());

        let rejected = registry
            .call("always_fails", serde_json::json!({}), &session, None)
            .await;
        assert!(matches!(rejected, Err(ToolError::CircuitOpen { .. })));

        // Tripping the circuit emits tools/list_changed plus a logging message.
        assert!(rx.try_recv().is_ok());
        assert!(rx.try_recv().is_ok());
    }
}