1use std::sync::{Arc, RwLock};
6
7use async_trait::async_trait;
8use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
9use dashmap::DashMap;
10use serde::{Deserialize, Serialize};
11use serde_json::Value;
12use thiserror::Error;
13use tokio::sync::mpsc;
14
15use crate::content::types::{Content, ImageContent, TextContent};
16use crate::server::multiplexer::ClientRequester;
17use crate::server::session::Session;
18use crate::server::visibility::{ExecutionContext, VisibilityContext};
19use crate::transport::traits::JsonRpcNotification;
20
21#[derive(Debug, Error)]
23pub enum ToolError {
24 #[error("Tool not found: {0}")]
26 NotFound(String),
27
28 #[error("Invalid arguments: {0}")]
30 InvalidArguments(String),
31
32 #[error("Execution error: {0}")]
34 Execution(String),
35
36 #[error("Internal error: {0}")]
38 Internal(String),
39
40 #[error("Circuit breaker open for tool '{tool}': {message}")]
42 CircuitOpen { tool: String, message: String },
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct ToolInfo {
48 pub name: String,
50
51 #[serde(skip_serializing_if = "Option::is_none")]
53 pub description: Option<String>,
54
55 #[serde(rename = "inputSchema")]
57 pub input_schema: Value,
58
59 #[serde(skip_serializing_if = "Option::is_none")]
61 pub execution: Option<crate::protocol::types::ToolExecution>,
62}
63
64#[async_trait]
66pub trait Tool: Send + Sync {
67 fn name(&self) -> &str;
69
70 fn description(&self) -> Option<&str> {
72 None
73 }
74
75 fn input_schema(&self) -> Value;
77
78 fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
83 None
84 }
85
86 fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
102 true
103 }
104
105 async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<Vec<Box<dyn Content>>, ToolError>;
127}
128
129pub trait ToolHelpers {
134 fn text(&self, content: &str) -> Box<dyn Content> {
136 Box::new(TextContent::new(content))
137 }
138
139 fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
141 Box::new(ImageContent::new(data, mime_type))
142 }
143}
144
145impl<T: Tool + ?Sized> ToolHelpers for T {}
147
/// Default value for [`ToolBreakerConfig::failure_threshold`].
const DEFAULT_FAILURE_THRESHOLD: usize = 5;
/// Default value for [`ToolBreakerConfig::failure_window_secs`].
const DEFAULT_FAILURE_WINDOW_SECS: f64 = 60.0;
/// Default value for [`ToolBreakerConfig::half_open_timeout_secs`].
const DEFAULT_HALF_OPEN_TIMEOUT_SECS: f64 = 30.0;
/// Default value for [`ToolBreakerConfig::success_threshold`].
const DEFAULT_SUCCESS_THRESHOLD: usize = 2;

/// Circuit-breaker settings used when a tool is registered.
///
/// Each field is forwarded to the breaker builder option of the same name;
/// see the breaker library for the exact semantics.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Failure count passed to the breaker (presumably failures within the
    /// window that open the circuit).
    pub failure_threshold: usize,
    /// Length, in seconds, of the window over which failures are counted.
    pub failure_window_secs: f64,
    /// Seconds before an open breaker may be retried (half-open state).
    pub half_open_timeout_secs: f64,
    /// Successes required by the breaker (presumably while half-open) to close again.
    pub success_threshold: usize,
}

impl Default for ToolBreakerConfig {
    /// 5 failures per 60 s window, 30 s half-open timeout, 2 successes to recover.
    fn default() -> Self {
        ToolBreakerConfig {
            success_threshold: DEFAULT_SUCCESS_THRESHOLD,
            half_open_timeout_secs: DEFAULT_HALF_OPEN_TIMEOUT_SECS,
            failure_window_secs: DEFAULT_FAILURE_WINDOW_SECS,
            failure_threshold: DEFAULT_FAILURE_THRESHOLD,
        }
    }
}
171
172impl From<ToolBreakerConfig> for BreakerConfig {
173 fn from(cfg: ToolBreakerConfig) -> Self {
174 BreakerConfig {
175 failure_threshold: Some(cfg.failure_threshold),
176 failure_rate_threshold: None,
177 minimum_calls: 1,
178 failure_window_secs: cfg.failure_window_secs,
179 half_open_timeout_secs: cfg.half_open_timeout_secs,
180 success_threshold: cfg.success_threshold,
181 jitter_factor: 0.1,
182 }
183 }
184}
185
186#[derive(Clone)]
188pub struct ToolRegistry {
189 tools: Arc<DashMap<String, Arc<dyn Tool>>>,
190 breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
191 breaker_config: Arc<RwLock<ToolBreakerConfig>>,
192 notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
193}
194
195impl ToolRegistry {
196 pub fn new() -> Self {
198 Self {
199 tools: Arc::new(DashMap::new()),
200 breakers: Arc::new(DashMap::new()),
201 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
202 notification_tx: None,
203 }
204 }
205
206 pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
208 Self {
209 tools: Arc::new(DashMap::new()),
210 breakers: Arc::new(DashMap::new()),
211 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
212 notification_tx: Some(notification_tx),
213 }
214 }
215
216 pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
218 self.notification_tx = Some(tx);
219 }
220
221 pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
223 if let Ok(mut cfg) = self.breaker_config.write() {
224 *cfg = config;
225 }
226 }
227
228 pub fn register<T: Tool + 'static>(&self, tool: T) {
230 let name = tool.name().to_string();
231
232 let breaker_config = self
234 .breaker_config
235 .read()
236 .map(|c| c.clone())
237 .unwrap_or_default();
238
239 let breaker = CircuitBreaker::builder(&name)
240 .failure_threshold(breaker_config.failure_threshold)
241 .failure_window_secs(breaker_config.failure_window_secs)
242 .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
243 .success_threshold(breaker_config.success_threshold)
244 .build();
245
246 self.breakers
247 .insert(name.clone(), Arc::new(RwLock::new(breaker)));
248 self.tools.insert(name, Arc::new(tool));
249 }
250
251 pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
253 let name = tool.name().to_string();
254
255 let breaker_config = self
256 .breaker_config
257 .read()
258 .map(|c| c.clone())
259 .unwrap_or_default();
260
261 let breaker = CircuitBreaker::builder(&name)
262 .failure_threshold(breaker_config.failure_threshold)
263 .failure_window_secs(breaker_config.failure_window_secs)
264 .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
265 .success_threshold(breaker_config.success_threshold)
266 .build();
267
268 self.breakers
269 .insert(name.clone(), Arc::new(RwLock::new(breaker)));
270 self.tools.insert(name, tool);
271 }
272
273 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
275 self.tools.get(name).map(|t| Arc::clone(&t))
276 }
277
278 pub fn is_circuit_open(&self, name: &str) -> bool {
280 self.breakers
281 .get(name)
282 .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
283 .unwrap_or(false)
284 }
285
286 pub fn list(&self) -> Vec<ToolInfo> {
288 self.tools
289 .iter()
290 .map(|entry| {
291 let tool = entry.value();
292 ToolInfo {
293 name: tool.name().to_string(),
294 description: tool.description().map(|s| s.to_string()),
295 input_schema: tool.input_schema(),
296 execution: tool.execution(),
297 }
298 })
299 .collect()
300 }
301
302 pub fn list_available(&self) -> Vec<ToolInfo> {
304 self.tools
305 .iter()
306 .filter(|entry| !self.is_circuit_open(entry.key()))
307 .map(|entry| {
308 let tool = entry.value();
309 ToolInfo {
310 name: tool.name().to_string(),
311 description: tool.description().map(|s| s.to_string()),
312 input_schema: tool.input_schema(),
313 execution: tool.execution(),
314 }
315 })
316 .collect()
317 }
318
319 fn send_notification(&self, method: &str, params: Option<Value>) {
321 if let Some(tx) = &self.notification_tx {
322 let notification = JsonRpcNotification::new(method, params);
323 let _ = tx.send(notification);
324 }
325 }
326
327 fn notify_tools_changed(&self) {
329 self.send_notification("notifications/tools/list_changed", None);
330 }
331
332 fn notify_message(&self, level: &str, logger: &str, message: &str) {
334 self.send_notification(
335 "notifications/message",
336 Some(serde_json::json!({
337 "level": level,
338 "logger": logger,
339 "data": message
340 })),
341 );
342 }
343
344 pub async fn call(
353 &self,
354 name: &str,
355 params: Value,
356 session: &Session,
357 logger: &crate::logging::McpLogger,
358 client_requester: Option<ClientRequester>,
359 ) -> Result<Vec<Box<dyn Content>>, ToolError> {
360 let tool = self
361 .get(name)
362 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
363
364 let breaker = self.breakers.get(name).ok_or_else(|| {
365 ToolError::Internal(format!("No circuit breaker for tool '{}'", name))
366 })?;
367
368 let was_open = {
370 let breaker_guard = breaker
371 .read()
372 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
373 breaker_guard.is_open()
374 };
375
376 if was_open {
377 return Err(ToolError::CircuitOpen {
378 tool: name.to_string(),
379 message: "Too many recent failures. Service temporarily unavailable.".to_string(),
380 });
381 }
382
383 let ctx = match client_requester {
385 Some(cr) => ExecutionContext::new(params, session, logger).with_client_requester(cr),
386 None => ExecutionContext::new(params, session, logger),
387 };
388
389 let start = std::time::Instant::now();
391 let result = tool.execute(ctx).await;
392 let duration_secs = start.elapsed().as_secs_f64();
393
394 let breaker_guard = breaker
396 .write()
397 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
398
399 let was_closed_before = !breaker_guard.is_open();
400
401 match &result {
402 Ok(_) => {
403 breaker_guard.record_success(duration_secs);
404 if was_open && !breaker_guard.is_open() {
406 self.notify_tools_changed();
407 self.notify_message(
408 "info",
409 "breaker-machines",
410 &format!("Tool '{}' recovered and available", name),
411 );
412 }
413 }
414 Err(_) => {
415 breaker_guard.record_failure(duration_secs);
416 if was_closed_before && breaker_guard.is_open() {
418 self.notify_tools_changed();
419 self.notify_message(
420 "warning",
421 "breaker-machines",
422 &format!(
423 "Tool '{}' disabled: circuit breaker open after failures",
424 name
425 ),
426 );
427 }
428 }
429 }
430
431 result
432 }
433
434 pub fn list_for_session(
443 &self,
444 session: &Session,
445 ctx: &VisibilityContext<'_>,
446 ) -> Vec<ToolInfo> {
447 let mut tools = std::collections::HashMap::new();
448
449 for entry in self.tools.iter() {
451 let name = entry.key().clone();
452 if !session.is_tool_hidden(&name) && !self.is_circuit_open(&name) {
453 let tool = entry.value();
454 if tool.is_visible(ctx) {
455 tools.insert(
456 name,
457 ToolInfo {
458 name: tool.name().to_string(),
459 description: tool.description().map(|s| s.to_string()),
460 input_schema: tool.input_schema(),
461 execution: tool.execution(),
462 },
463 );
464 }
465 }
466 }
467
468 for entry in session.tool_extras().iter() {
470 let name = entry.key().clone();
471 let tool = entry.value();
472 if tool.is_visible(ctx) {
473 tools.insert(
474 name,
475 ToolInfo {
476 name: tool.name().to_string(),
477 description: tool.description().map(|s| s.to_string()),
478 input_schema: tool.input_schema(),
479 execution: tool.execution(),
480 },
481 );
482 }
483 }
484
485 for entry in session.tool_overrides().iter() {
487 let name = entry.key().clone();
488 let tool = entry.value();
489 if tool.is_visible(ctx) {
490 tools.insert(
491 name,
492 ToolInfo {
493 name: tool.name().to_string(),
494 description: tool.description().map(|s| s.to_string()),
495 input_schema: tool.input_schema(),
496 execution: tool.execution(),
497 },
498 );
499 }
500 }
501
502 tools.into_values().collect()
503 }
504
505 pub async fn call_for_session(
515 &self,
516 name: &str,
517 params: Value,
518 session: &Session,
519 logger: &crate::logging::McpLogger,
520 visibility_ctx: &VisibilityContext<'_>,
521 client_requester: Option<ClientRequester>,
522 ) -> Result<Vec<Box<dyn Content>>, ToolError> {
523 let resolved_name = session.resolve_tool_alias(name);
525 let resolved = resolved_name.as_ref();
526
527 let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
530 (Some(env), Some(cr)) => {
531 ExecutionContext::with_environment(params.clone(), session, logger, env)
532 .with_client_requester(cr.clone())
533 }
534 (Some(env), None) => {
535 ExecutionContext::with_environment(params.clone(), session, logger, env)
536 }
537 (None, Some(cr)) => ExecutionContext::new(params.clone(), session, logger)
538 .with_client_requester(cr.clone()),
539 (None, None) => ExecutionContext::new(params.clone(), session, logger),
540 };
541
542 if let Some(tool) = session.get_tool_override(resolved) {
544 if !tool.is_visible(visibility_ctx) {
545 return Err(ToolError::NotFound(name.to_string()));
546 }
547 return tool.execute(exec_ctx).await;
548 }
549
550 if let Some(tool) = session.get_tool_extra(resolved) {
552 if !tool.is_visible(visibility_ctx) {
553 return Err(ToolError::NotFound(name.to_string()));
554 }
555 return tool.execute(exec_ctx).await;
556 }
557
558 if session.is_tool_hidden(resolved) {
560 return Err(ToolError::NotFound(name.to_string()));
561 }
562
563 let tool = self
565 .get(resolved)
566 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
567
568 if !tool.is_visible(visibility_ctx) {
569 return Err(ToolError::NotFound(name.to_string()));
570 }
571
572 self.call(resolved, params, session, logger, client_requester)
574 .await
575 }
576
577 pub fn len(&self) -> usize {
579 self.tools.len()
580 }
581
582 pub fn is_empty(&self) -> bool {
584 self.tools.is_empty()
585 }
586}
587
588impl Default for ToolRegistry {
589 fn default() -> Self {
590 Self::new()
591 }
592}
593
#[cfg(test)]
mod tests {
    use super::*;

    /// Test tool that replies with `"Echo: <message>"`.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> Option<&str> {
            Some("Echoes back the input message")
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "Message to echo"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            ctx: ExecutionContext<'_>,
        ) -> Result<Vec<Box<dyn Content>>, ToolError> {
            let message = ctx
                .params
                .get("message")
                .and_then(|v| v.as_str())
                .ok_or_else(|| {
                    ToolError::InvalidArguments("Missing 'message' field".to_string())
                })?;

            let content = TextContent::new(format!("Echo: {}", message));
            Ok(vec![Box::new(content)])
        }
    }

    #[test]
    fn test_registry_creation() {
        let registry = ToolRegistry::new();
        assert!(registry.is_empty());
    }

    #[test]
    fn test_tool_registration() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
        assert!(!registry.is_empty());
    }

    #[test]
    fn test_register_boxed() {
        let registry = ToolRegistry::new();
        registry.register_boxed(Arc::new(EchoTool));

        assert_eq!(registry.len(), 1);
        assert_eq!(
            registry.get("echo").map(|t| t.name().to_string()),
            Some("echo".to_string())
        );
    }

    #[test]
    fn test_reregistering_replaces_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        registry.register(EchoTool);

        assert_eq!(registry.len(), 1);
    }

    #[test]
    fn test_get_tool() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tool = registry.get("echo");
        assert!(tool.is_some());
        assert_eq!(tool.unwrap().name(), "echo");

        let missing = registry.get("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_list_tools() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        let tools = registry.list();
        assert_eq!(tools.len(), 1);
        assert_eq!(tools[0].name, "echo");
        assert_eq!(
            tools[0].description,
            Some("Echoes back the input message".to_string())
        );
    }

    #[test]
    fn test_new_tool_circuit_is_closed() {
        let registry = ToolRegistry::new();
        registry.register(EchoTool);

        assert!(!registry.is_circuit_open("echo"));
        assert!(!registry.is_circuit_open("nonexistent"));
        assert_eq!(registry.list_available().len(), 1);
    }

    #[tokio::test]
    async fn test_call_tool() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({
            "message": "Hello, world!"
        });

        let result = registry
            .call("echo", params, &session, &logger, None)
            .await
            .unwrap();
        assert_eq!(result.len(), 1);
    }

    #[tokio::test]
    async fn test_successful_call_sends_no_notifications() {
        let (log_tx, _log_rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(log_tx, "test");
        let (notify_tx, mut notify_rx) = tokio::sync::mpsc::unbounded_channel();
        let registry = ToolRegistry::with_notifications(notify_tx);
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({ "message": "hi" });
        registry
            .call("echo", params, &session, &logger, None)
            .await
            .unwrap();

        assert!(notify_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_call_missing_tool() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        let session = Session::new();

        let params = serde_json::json!({});
        let result = registry
            .call("nonexistent", params, &session, &logger, None)
            .await;

        assert!(matches!(result, Err(ToolError::NotFound(_))));
    }

    #[tokio::test]
    async fn test_tool_invalid_arguments() {
        let (tx, _rx) = tokio::sync::mpsc::unbounded_channel();
        let logger = crate::logging::McpLogger::new(tx, "test");
        let registry = ToolRegistry::new();
        registry.register(EchoTool);
        let session = Session::new();

        let params = serde_json::json!({});
        let result = registry
            .call("echo", params, &session, &logger, None)
            .await;

        assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
    }
}