1use std::collections::HashSet;
6use std::sync::{Arc, RwLock};
7
8use async_trait::async_trait;
9use breaker_machines::{CircuitBreaker, Config as BreakerConfig};
10use dashmap::DashMap;
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use state_machines::state_machine;
14use thiserror::Error;
15use tokio::sync::mpsc;
16
17use crate::content::types::{Content, ImageContent, TextContent};
18use crate::server::multiplexer::ClientRequester;
19use crate::server::session::Session;
20use crate::server::visibility::{ExecutionContext, VisibilityContext};
21use crate::transport::traits::JsonRpcNotification;
22
/// Errors produced while looking up or executing a tool.
#[derive(Debug, Error)]
pub enum ToolError {
    /// No callable tool with this name. The registry also returns this for
    /// tools that are disabled, defective, hidden by the session, or not
    /// visible in the current context.
    #[error("Tool not found: {0}")]
    NotFound(String),

    /// The supplied arguments were rejected (e.g. a required field is missing).
    #[error("Invalid arguments: {0}")]
    InvalidArguments(String),

    /// A failure reported by the tool while it was executing.
    #[error("Execution error: {0}")]
    Execution(String),

    /// A registry-side failure, e.g. a missing circuit breaker or a poisoned
    /// breaker lock.
    #[error("Internal error: {0}")]
    Internal(String),

    /// The tool's circuit breaker is open, so the call was rejected without
    /// executing the tool.
    #[error("Circuit breaker open for tool '{tool}': {message}")]
    CircuitOpen { tool: String, message: String },
}
46
/// Lifecycle state of a tool tracked by the registry.
///
/// Only `Enabled` tools (or tools the registry has never tracked) are listed
/// as available and accepted for execution by `ToolRegistry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolLifecycleState {
    /// The tool is listed and callable.
    Enabled,
    /// The tool has been switched off, explicitly or by an exclusive-group policy.
    Disabled,
    /// The tool has been flagged as faulty, e.g. after its circuit breaker opened.
    Defective,
}

impl ToolLifecycleState {
    /// Maps a state name reported by the lifecycle state machine back to the enum.
    ///
    /// The comparison is exact (case-sensitive). Any unknown name yields `None`.
    fn from_state_name(name: &str) -> Option<Self> {
        let state = match name {
            "Enabled" => Self::Enabled,
            "Disabled" => Self::Disabled,
            "Defective" => Self::Defective,
            _ => return None,
        };
        Some(state)
    }
}
72
73#[derive(Debug, Clone)]
75pub struct ToolPolicy {
76 pub initial_state: ToolLifecycleState,
78 pub exclusive_group: Option<String>,
80 pub activates: Vec<String>,
82}
83
84impl Default for ToolPolicy {
85 fn default() -> Self {
86 Self {
87 initial_state: ToolLifecycleState::Enabled,
88 exclusive_group: None,
89 activates: Vec::new(),
90 }
91 }
92}
93
94impl ToolPolicy {
95 pub fn new() -> Self {
97 Self::default()
98 }
99
100 pub fn initial_state(mut self, state: ToolLifecycleState) -> Self {
102 self.initial_state = state;
103 self
104 }
105
106 pub fn exclusive_group(mut self, group: impl Into<String>) -> Self {
108 self.exclusive_group = Some(group.into());
109 self
110 }
111
112 pub fn activates<I, S>(mut self, tools: I) -> Self
114 where
115 I: IntoIterator<Item = S>,
116 S: Into<String>,
117 {
118 self.activates = tools.into_iter().map(Into::into).collect();
119 self
120 }
121}
122
/// Context value handed to the lifecycle state machine. It carries no data.
#[derive(Debug, Default)]
struct ToolLifecycleContext;

impl ToolLifecycleContext {
    /// Creates the (empty) lifecycle context.
    fn new() -> Self {
        Self::default()
    }
}
131
// Lifecycle state machine for a single tool.
//
// `ToolRegistry` drives it through the `DynamicToolLifecycle` and
// `ToolLifecycleEvent` types generated by this macro. NOTE(review): these are
// presumably emitted because of `dynamic: true`; confirm against the
// `state_machines` docs.
//
// The registry reads `current_state()` back through
// `ToolLifecycleState::from_state_name`, so the state names here must match
// that function.
//
// Events:
// - `enable`: Disabled | Defective -> Enabled
// - `disable`: Enabled | Defective -> Disabled
// - `mark_defective`: Enabled | Disabled -> Defective
// - `recover`: Defective -> Enabled (the registry uses this for
//   Defective -> Enabled rather than `enable`)
state_machine! {
    name: ToolLifecycle,
    context: ToolLifecycleContext,
    dynamic: true,
    initial: Enabled,
    states: [Enabled, Disabled, Defective],
    events {
        enable {
            transition: { from: [Disabled, Defective], to: Enabled }
        }
        disable {
            transition: { from: [Enabled, Defective], to: Disabled }
        }
        mark_defective {
            transition: { from: [Enabled, Disabled], to: Defective }
        }
        recover {
            transition: { from: Defective, to: Enabled }
        }
    }
}
153
154pub enum ToolOutput {
166 Content(Vec<Box<dyn Content>>),
168 Structured(Value),
170}
171
172impl std::fmt::Debug for ToolOutput {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 match self {
175 Self::Content(items) => f
176 .debug_struct("Content")
177 .field("count", &items.len())
178 .finish(),
179 Self::Structured(value) => f.debug_tuple("Structured").field(value).finish(),
180 }
181 }
182}
183
184impl ToolOutput {
185 pub fn text(s: impl Into<String>) -> Self {
187 Self::Content(vec![Box::new(TextContent::new(s))])
188 }
189
190 pub fn texts(items: &[&str]) -> Self {
192 Self::Content(
193 items
194 .iter()
195 .map(|s| Box::new(TextContent::new(*s)) as Box<dyn Content>)
196 .collect(),
197 )
198 }
199
200 pub fn content(items: Vec<Box<dyn Content>>) -> Self {
202 Self::Content(items)
203 }
204
205 pub fn structured<T: Serialize>(value: T) -> Result<Self, serde_json::Error> {
207 Ok(Self::Structured(serde_json::to_value(value)?))
208 }
209
210 pub fn json(value: Value) -> Self {
212 Self::Structured(value)
213 }
214
215 pub fn is_content(&self) -> bool {
217 matches!(self, Self::Content(_))
218 }
219
220 pub fn is_structured(&self) -> bool {
222 matches!(self, Self::Structured(_))
223 }
224}
225
/// Protocol-facing description of a tool, as returned by the registry's list
/// methods. Optional fields are omitted from the serialized JSON when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolInfo {
    /// Tool name (the key the registry stores the tool under).
    pub name: String,

    /// Optional human-readable title.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub title: Option<String>,

    /// Optional description.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub description: Option<String>,

    /// Schema of the accepted arguments, serialized as `inputSchema`.
    #[serde(rename = "inputSchema")]
    pub input_schema: Value,

    /// Optional schema of the output, serialized as `outputSchema`.
    #[serde(rename = "outputSchema", skip_serializing_if = "Option::is_none")]
    pub output_schema: Option<Value>,

    /// Optional execution metadata reported by the tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub execution: Option<crate::protocol::types::ToolExecution>,

    /// Optional annotations reported by the tool.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub annotations: Option<crate::protocol::types::ToolAnnotations>,
}
256
/// A callable tool exposed by the server.
///
/// Implementors must be `Send + Sync`; the registry stores them as
/// `Arc<dyn Tool>`. Every method except `name`, `input_schema` and `execute`
/// has a default.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Name the tool is registered and called under.
    fn name(&self) -> &str;

    /// Optional human-readable title. Defaults to `None`.
    fn title(&self) -> Option<&str> {
        None
    }

    /// Optional description. Defaults to `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Schema of the accepted arguments; published as `inputSchema`.
    fn input_schema(&self) -> Value;

    /// Optional schema of the output; published as `outputSchema`.
    /// Defaults to `None`.
    fn output_schema(&self) -> Option<Value> {
        None
    }

    /// Optional execution metadata. Defaults to `None`.
    fn execution(&self) -> Option<crate::protocol::types::ToolExecution> {
        None
    }

    /// Optional annotations. Defaults to `None`.
    fn annotations(&self) -> Option<crate::protocol::types::ToolAnnotations> {
        None
    }

    /// Whether the tool is visible in the given context. Defaults to `true`.
    ///
    /// `ToolRegistry::list_for_session` omits tools returning `false`.
    /// `ToolRegistry::call_for_session` rejects calls to them with
    /// `ToolError::NotFound`.
    fn is_visible(&self, _ctx: &VisibilityContext) -> bool {
        true
    }

    /// Runs the tool. Call arguments are available as `ctx.params`.
    async fn execute(&self, ctx: ExecutionContext<'_>) -> Result<ToolOutput, ToolError>;
}
348
349pub trait ToolHelpers {
354 fn text(&self, content: &str) -> Box<dyn Content> {
356 Box::new(TextContent::new(content))
357 }
358
359 fn image(&self, data: &str, mime_type: &str) -> Box<dyn Content> {
361 Box::new(ImageContent::new(data, mime_type))
362 }
363}
364
365impl<T: Tool + ?Sized> ToolHelpers for T {}
367
/// Circuit-breaker settings applied to each tool when it is registered.
///
/// Each field is passed to the `breaker_machines` builder option of the same
/// name.
#[derive(Debug, Clone)]
pub struct ToolBreakerConfig {
    /// Builder option `failure_threshold`.
    pub failure_threshold: usize,
    /// Builder option `failure_window_secs` (seconds).
    pub failure_window_secs: f64,
    /// Builder option `half_open_timeout_secs` (seconds).
    pub half_open_timeout_secs: f64,
    /// Builder option `success_threshold`.
    pub success_threshold: usize,
}

impl Default for ToolBreakerConfig {
    /// 5 failures within 60 s, a 30 s half-open timeout, and 2 successes.
    fn default() -> Self {
        ToolBreakerConfig {
            success_threshold: 2,
            half_open_timeout_secs: 30.0,
            failure_window_secs: 60.0,
            failure_threshold: 5,
        }
    }
}
391
392impl From<ToolBreakerConfig> for BreakerConfig {
393 fn from(cfg: ToolBreakerConfig) -> Self {
394 BreakerConfig {
395 failure_threshold: Some(cfg.failure_threshold),
396 failure_rate_threshold: None,
397 minimum_calls: 1,
398 failure_window_secs: cfg.failure_window_secs,
399 half_open_timeout_secs: cfg.half_open_timeout_secs,
400 success_threshold: cfg.success_threshold,
401 jitter_factor: 0.1,
402 }
403 }
404}
405
/// Registry of tools with per-tool circuit breakers, lifecycle state and
/// lifecycle policies, backed by `DashMap`s and `RwLock`s.
///
/// Cloning is cheap. Clones share the tool, breaker, breaker-config, state and
/// policy maps, since they are behind `Arc`. `notification_tx` is copied per
/// clone, so `set_notification_tx` on one clone does not affect the others.
#[derive(Clone)]
pub struct ToolRegistry {
    /// Registered tools keyed by `Tool::name`.
    tools: Arc<DashMap<String, Arc<dyn Tool>>>,
    /// One circuit breaker per registered tool, keyed by tool name.
    breakers: Arc<DashMap<String, Arc<RwLock<CircuitBreaker>>>>,
    /// Settings used to build breakers for tools registered from now on.
    breaker_config: Arc<RwLock<ToolBreakerConfig>>,
    /// Optional sink for `notifications/tools/list_changed` and
    /// `notifications/message`.
    notification_tx: Option<mpsc::UnboundedSender<JsonRpcNotification>>,
    /// Lifecycle state machine per tool name. A name can be tracked before,
    /// or without, the tool being registered.
    tool_states: Arc<DashMap<String, Arc<RwLock<DynamicToolLifecycle>>>>,
    /// Lifecycle policy per tool name.
    tool_policies: Arc<DashMap<String, ToolPolicy>>,
}
416
417impl ToolRegistry {
418 pub fn new() -> Self {
420 Self {
421 tools: Arc::new(DashMap::new()),
422 breakers: Arc::new(DashMap::new()),
423 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
424 notification_tx: None,
425 tool_states: Arc::new(DashMap::new()),
426 tool_policies: Arc::new(DashMap::new()),
427 }
428 }
429
430 pub fn with_notifications(notification_tx: mpsc::UnboundedSender<JsonRpcNotification>) -> Self {
432 Self {
433 tools: Arc::new(DashMap::new()),
434 breakers: Arc::new(DashMap::new()),
435 breaker_config: Arc::new(RwLock::new(ToolBreakerConfig::default())),
436 notification_tx: Some(notification_tx),
437 tool_states: Arc::new(DashMap::new()),
438 tool_policies: Arc::new(DashMap::new()),
439 }
440 }
441
442 pub fn set_notification_tx(&mut self, tx: mpsc::UnboundedSender<JsonRpcNotification>) {
444 self.notification_tx = Some(tx);
445 }
446
447 pub fn set_breaker_config(&self, config: ToolBreakerConfig) {
449 if let Ok(mut cfg) = self.breaker_config.write() {
450 *cfg = config;
451 }
452 }
453
454 pub fn set_tool_policy(&self, name: impl Into<String>, policy: ToolPolicy) {
458 let name = name.into();
459 self.tool_policies.insert(name.clone(), policy.clone());
460 self.ensure_tool_state(&name);
461
462 let mut visited = HashSet::new();
463 let changed = self.set_tool_state_internal(&name, policy.initial_state, &mut visited);
464 if !changed && self.tool_state(&name) == Some(ToolLifecycleState::Enabled) {
465 visited.insert(name.clone());
466 self.apply_enable_policy(&name, &policy, &mut visited);
467 }
468 }
469
470 pub fn tool_state(&self, name: &str) -> Option<ToolLifecycleState> {
472 self.tool_states.get(name).and_then(|entry| {
473 entry
474 .read()
475 .ok()
476 .and_then(|machine| ToolLifecycleState::from_state_name(machine.current_state()))
477 })
478 }
479
480 pub fn is_tool_enabled(&self, name: &str) -> bool {
484 match self.tool_state(name) {
485 Some(ToolLifecycleState::Enabled) => true,
486 Some(_) => false,
487 None => true,
488 }
489 }
490
491 pub fn enable_tool(&self, name: &str) -> bool {
493 let mut visited = HashSet::new();
494 self.set_tool_state_internal(name, ToolLifecycleState::Enabled, &mut visited)
495 }
496
497 pub fn disable_tool(&self, name: &str) -> bool {
499 let mut visited = HashSet::new();
500 self.set_tool_state_internal(name, ToolLifecycleState::Disabled, &mut visited)
501 }
502
503 pub fn mark_tool_defective(&self, name: &str) -> bool {
505 let mut visited = HashSet::new();
506 self.set_tool_state_internal(name, ToolLifecycleState::Defective, &mut visited)
507 }
508
509 pub fn recover_tool(&self, name: &str) -> bool {
511 if self.tool_state(name) != Some(ToolLifecycleState::Defective) {
512 return false;
513 }
514 let mut visited = HashSet::new();
515 self.set_tool_state_internal(name, ToolLifecycleState::Enabled, &mut visited)
516 }
517
518 pub fn register<T: Tool + 'static>(&self, tool: T) {
520 self.register_boxed(Arc::new(tool));
521 }
522
523 pub fn register_with_policy<T: Tool + 'static>(&self, tool: T, policy: ToolPolicy) {
525 self.register_boxed_with_policy(Arc::new(tool), policy);
526 }
527
528 pub fn register_boxed(&self, tool: Arc<dyn Tool>) {
530 let name = tool.name().to_string();
531 let policy = self
532 .tool_policies
533 .get(&name)
534 .map(|p| p.clone())
535 .unwrap_or_default();
536 self.register_boxed_with_policy(tool, policy);
537 }
538
539 pub fn register_boxed_with_policy(&self, tool: Arc<dyn Tool>, policy: ToolPolicy) {
541 let name = tool.name().to_string();
542
543 let breaker_config = self
545 .breaker_config
546 .read()
547 .map(|c| c.clone())
548 .unwrap_or_default();
549
550 let breaker = CircuitBreaker::builder(&name)
551 .failure_threshold(breaker_config.failure_threshold)
552 .failure_window_secs(breaker_config.failure_window_secs)
553 .half_open_timeout_secs(breaker_config.half_open_timeout_secs)
554 .success_threshold(breaker_config.success_threshold)
555 .build();
556
557 self.breakers
558 .insert(name.clone(), Arc::new(RwLock::new(breaker)));
559 self.tools.insert(name.clone(), tool);
560
561 self.tool_policies.insert(name.clone(), policy.clone());
563
564 let (_, created) = self.ensure_tool_state(&name);
566 if created && policy.initial_state != ToolLifecycleState::Enabled {
567 self.set_tool_state_initial(&name, policy.initial_state);
568 }
569 if self.tool_state(&name) == Some(ToolLifecycleState::Enabled) {
570 let mut visited = HashSet::new();
571 visited.insert(name.clone());
572 self.apply_enable_policy(&name, &policy, &mut visited);
573 }
574 }
575
576 pub fn get(&self, name: &str) -> Option<Arc<dyn Tool>> {
578 self.tools.get(name).map(|t| Arc::clone(&t))
579 }
580
581 pub fn is_circuit_open(&self, name: &str) -> bool {
583 self.breakers
584 .get(name)
585 .and_then(|b| b.read().ok().map(|breaker| breaker.is_open()))
586 .unwrap_or(false)
587 }
588
589 pub fn list(&self) -> Vec<ToolInfo> {
591 self.tools
592 .iter()
593 .map(|entry| {
594 let tool = entry.value();
595 ToolInfo {
596 name: tool.name().to_string(),
597 title: tool.title().map(|s| s.to_string()),
598 description: tool.description().map(|s| s.to_string()),
599 input_schema: tool.input_schema(),
600 output_schema: tool.output_schema(),
601 execution: tool.execution(),
602 annotations: tool.annotations(),
603 }
604 })
605 .collect()
606 }
607
608 pub fn list_available(&self) -> Vec<ToolInfo> {
610 self.tools
611 .iter()
612 .filter(|entry| {
613 !self.is_circuit_open(entry.key()) && self.is_tool_enabled(entry.key())
614 })
615 .map(|entry| {
616 let tool = entry.value();
617 ToolInfo {
618 name: tool.name().to_string(),
619 title: tool.title().map(|s| s.to_string()),
620 description: tool.description().map(|s| s.to_string()),
621 input_schema: tool.input_schema(),
622 output_schema: tool.output_schema(),
623 execution: tool.execution(),
624 annotations: tool.annotations(),
625 }
626 })
627 .collect()
628 }
629
630 fn send_notification(&self, method: &str, params: Option<Value>) {
632 if let Some(tx) = &self.notification_tx {
633 let notification = JsonRpcNotification::new(method, params);
634 let _ = tx.send(notification);
635 }
636 }
637
638 fn notify_tools_changed(&self) {
640 self.send_notification("notifications/tools/list_changed", None);
641 }
642
643 fn notify_message(&self, level: &str, logger: &str, message: &str) {
645 self.send_notification(
646 "notifications/message",
647 Some(serde_json::json!({
648 "level": level,
649 "logger": logger,
650 "data": message
651 })),
652 );
653 }
654
655 fn ensure_tool_state(&self, name: &str) -> (Arc<RwLock<DynamicToolLifecycle>>, bool) {
656 if let Some(entry) = self.tool_states.get(name) {
657 return (Arc::clone(entry.value()), false);
658 }
659
660 if !self.tool_policies.contains_key(name) {
661 self.tool_policies
662 .insert(name.to_string(), ToolPolicy::default());
663 }
664
665 let machine = DynamicToolLifecycle::new(ToolLifecycleContext::new());
666 let entry = Arc::new(RwLock::new(machine));
667 self.tool_states.insert(name.to_string(), Arc::clone(&entry));
668 (entry, true)
669 }
670
671 fn set_tool_state_initial(&self, name: &str, target: ToolLifecycleState) {
672 let event = match target {
673 ToolLifecycleState::Enabled => return,
674 ToolLifecycleState::Disabled => ToolLifecycleEvent::Disable,
675 ToolLifecycleState::Defective => ToolLifecycleEvent::MarkDefective,
676 };
677 let _ = self.transition_tool_state(name, event, false);
678 }
679
680 fn transition_tool_state(
681 &self,
682 name: &str,
683 event: ToolLifecycleEvent,
684 notify: bool,
685 ) -> Option<(ToolLifecycleState, ToolLifecycleState)> {
686 let (entry, _) = self.ensure_tool_state(name);
687 let mut guard = entry.write().ok()?;
688 let before = ToolLifecycleState::from_state_name(guard.current_state())?;
689 let _ = guard.handle(event);
690 let after = ToolLifecycleState::from_state_name(guard.current_state())?;
691 drop(guard);
692
693 if notify
694 && before != after
695 && (before == ToolLifecycleState::Enabled
696 || after == ToolLifecycleState::Enabled)
697 {
698 self.notify_tools_changed();
699 }
700
701 Some((before, after))
702 }
703
704 fn set_tool_state_internal(
705 &self,
706 name: &str,
707 target: ToolLifecycleState,
708 visited: &mut HashSet<String>,
709 ) -> bool {
710 let _ = self.ensure_tool_state(name);
711 let current = self
712 .tool_state(name)
713 .unwrap_or(ToolLifecycleState::Enabled);
714
715 if current == target {
716 return false;
717 }
718
719 let event = match (current, target) {
720 (ToolLifecycleState::Enabled, ToolLifecycleState::Disabled) => {
721 ToolLifecycleEvent::Disable
722 }
723 (ToolLifecycleState::Enabled, ToolLifecycleState::Defective) => {
724 ToolLifecycleEvent::MarkDefective
725 }
726 (ToolLifecycleState::Disabled, ToolLifecycleState::Enabled) => {
727 ToolLifecycleEvent::Enable
728 }
729 (ToolLifecycleState::Disabled, ToolLifecycleState::Defective) => {
730 ToolLifecycleEvent::MarkDefective
731 }
732 (ToolLifecycleState::Defective, ToolLifecycleState::Enabled) => {
733 ToolLifecycleEvent::Recover
734 }
735 (ToolLifecycleState::Defective, ToolLifecycleState::Disabled) => {
736 ToolLifecycleEvent::Disable
737 }
738 _ => return false,
739 };
740
741 let change = self.transition_tool_state(name, event, true);
742 let changed = matches!(change, Some((prev, next)) if prev != next);
743 let became_enabled = matches!(
744 change,
745 Some((prev, ToolLifecycleState::Enabled))
746 if prev != ToolLifecycleState::Enabled
747 );
748
749 if became_enabled {
750 if !visited.insert(name.to_string()) {
751 return true;
752 }
753 if let Some(policy) = self.tool_policies.get(name).map(|p| p.clone()) {
754 self.apply_enable_policy(name, &policy, visited);
755 }
756 }
757
758 changed
759 }
760
761 fn apply_enable_policy(
762 &self,
763 name: &str,
764 policy: &ToolPolicy,
765 visited: &mut HashSet<String>,
766 ) {
767 if let Some(group) = policy.exclusive_group.as_deref() {
768 let others: Vec<String> = self
769 .tool_policies
770 .iter()
771 .filter(|entry| {
772 entry.key() != name
773 && entry.value().exclusive_group.as_deref() == Some(group)
774 })
775 .map(|entry| entry.key().clone())
776 .collect();
777
778 for other in others {
779 if self.tool_state(&other) == Some(ToolLifecycleState::Enabled) {
780 self.set_tool_state_internal(&other, ToolLifecycleState::Disabled, visited);
781 }
782 }
783 }
784
785 for target in &policy.activates {
786 self.set_tool_state_internal(target, ToolLifecycleState::Enabled, visited);
787 }
788 }
789
790 pub async fn call(
799 &self,
800 name: &str,
801 params: Value,
802 session: &Session,
803 logger: &crate::logging::McpLogger,
804 client_requester: Option<ClientRequester>,
805 ) -> Result<ToolOutput, ToolError> {
806 let tool = self
807 .get(name)
808 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
809
810 if !self.is_tool_enabled(name) {
811 return Err(ToolError::NotFound(name.to_string()));
812 }
813
814 let breaker = self.breakers.get(name).ok_or_else(|| {
815 ToolError::Internal(format!("No circuit breaker for tool '{}'", name))
816 })?;
817
818 let was_open = {
820 let breaker_guard = breaker
821 .read()
822 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
823 breaker_guard.is_open()
824 };
825
826 if was_open {
827 return Err(ToolError::CircuitOpen {
828 tool: name.to_string(),
829 message: "Too many recent failures. Service temporarily unavailable.".to_string(),
830 });
831 }
832
833 let ctx = match client_requester {
835 Some(cr) => ExecutionContext::new(params, session, logger).with_client_requester(cr),
836 None => ExecutionContext::new(params, session, logger),
837 };
838
839 let start = std::time::Instant::now();
841 let result = tool.execute(ctx).await;
842 let duration_secs = start.elapsed().as_secs_f64();
843
844 let breaker_guard = breaker
846 .write()
847 .map_err(|e| ToolError::Internal(format!("Breaker lock error: {}", e)))?;
848
849 let was_closed_before = !breaker_guard.is_open();
850
851 match &result {
852 Ok(_) => {
853 breaker_guard.record_success(duration_secs);
854 if was_open && !breaker_guard.is_open() {
856 let _ = self.recover_tool(name);
857 self.notify_message(
858 "info",
859 "breaker-machines",
860 &format!("Tool '{}' recovered and available", name),
861 );
862 }
863 }
864 Err(_) => {
865 breaker_guard.record_failure(duration_secs);
866 if was_closed_before && breaker_guard.is_open() {
868 let _ = self.mark_tool_defective(name);
869 self.notify_message(
870 "warning",
871 "breaker-machines",
872 &format!(
873 "Tool '{}' disabled: circuit breaker open after failures",
874 name
875 ),
876 );
877 }
878 }
879 }
880
881 result
882 }
883
884 pub fn list_for_session(
893 &self,
894 session: &Session,
895 ctx: &VisibilityContext<'_>,
896 ) -> Vec<ToolInfo> {
897 let mut tools = std::collections::HashMap::new();
898
899 for entry in self.tools.iter() {
901 let name = entry.key().clone();
902 if !session.is_tool_hidden(&name)
903 && !self.is_circuit_open(&name)
904 && self.is_tool_enabled(&name)
905 {
906 let tool = entry.value();
907 if tool.is_visible(ctx) {
908 tools.insert(
909 name,
910 ToolInfo {
911 name: tool.name().to_string(),
912 title: tool.title().map(|s| s.to_string()),
913 description: tool.description().map(|s| s.to_string()),
914 input_schema: tool.input_schema(),
915 output_schema: tool.output_schema(),
916 execution: tool.execution(),
917 annotations: tool.annotations(),
918 },
919 );
920 }
921 }
922 }
923
924 for entry in session.tool_extras().iter() {
926 let name = entry.key().clone();
927 let tool = entry.value();
928 if tool.is_visible(ctx) && self.is_tool_enabled(&name) {
929 tools.insert(
930 name,
931 ToolInfo {
932 name: tool.name().to_string(),
933 title: tool.title().map(|s| s.to_string()),
934 description: tool.description().map(|s| s.to_string()),
935 input_schema: tool.input_schema(),
936 output_schema: tool.output_schema(),
937 execution: tool.execution(),
938 annotations: tool.annotations(),
939 },
940 );
941 }
942 }
943
944 for entry in session.tool_overrides().iter() {
946 let name = entry.key().clone();
947 let tool = entry.value();
948 if tool.is_visible(ctx) && self.is_tool_enabled(&name) {
949 tools.insert(
950 name,
951 ToolInfo {
952 name: tool.name().to_string(),
953 title: tool.title().map(|s| s.to_string()),
954 description: tool.description().map(|s| s.to_string()),
955 input_schema: tool.input_schema(),
956 output_schema: tool.output_schema(),
957 execution: tool.execution(),
958 annotations: tool.annotations(),
959 },
960 );
961 }
962 }
963
964 tools.into_values().collect()
965 }
966
967 pub async fn call_for_session(
977 &self,
978 name: &str,
979 params: Value,
980 session: &Session,
981 logger: &crate::logging::McpLogger,
982 visibility_ctx: &VisibilityContext<'_>,
983 client_requester: Option<ClientRequester>,
984 ) -> Result<ToolOutput, ToolError> {
985 let resolved_name = session.resolve_tool_alias(name);
987 let resolved = resolved_name.as_ref();
988
989 if !self.is_tool_enabled(resolved) {
990 return Err(ToolError::NotFound(name.to_string()));
991 }
992
993 let exec_ctx = match (visibility_ctx.environment, client_requester.as_ref()) {
996 (Some(env), Some(cr)) => {
997 ExecutionContext::with_environment(params.clone(), session, logger, env)
998 .with_client_requester(cr.clone())
999 }
1000 (Some(env), None) => {
1001 ExecutionContext::with_environment(params.clone(), session, logger, env)
1002 }
1003 (None, Some(cr)) => ExecutionContext::new(params.clone(), session, logger)
1004 .with_client_requester(cr.clone()),
1005 (None, None) => ExecutionContext::new(params.clone(), session, logger),
1006 };
1007
1008 if let Some(tool) = session.get_tool_override(resolved) {
1010 if !tool.is_visible(visibility_ctx) {
1011 return Err(ToolError::NotFound(name.to_string()));
1012 }
1013 return tool.execute(exec_ctx).await;
1014 }
1015
1016 if let Some(tool) = session.get_tool_extra(resolved) {
1018 if !tool.is_visible(visibility_ctx) {
1019 return Err(ToolError::NotFound(name.to_string()));
1020 }
1021 return tool.execute(exec_ctx).await;
1022 }
1023
1024 if session.is_tool_hidden(resolved) {
1026 return Err(ToolError::NotFound(name.to_string()));
1027 }
1028
1029 let tool = self
1031 .get(resolved)
1032 .ok_or_else(|| ToolError::NotFound(name.to_string()))?;
1033
1034 if !tool.is_visible(visibility_ctx) {
1035 return Err(ToolError::NotFound(name.to_string()));
1036 }
1037
1038 self.call(resolved, params, session, logger, client_requester)
1040 .await
1041 }
1042
1043 pub fn len(&self) -> usize {
1045 self.tools.len()
1046 }
1047
1048 pub fn is_empty(&self) -> bool {
1050 self.tools.is_empty()
1051 }
1052}
1053
1054impl Default for ToolRegistry {
1055 fn default() -> Self {
1056 Self::new()
1057 }
1058}
1059
1060#[cfg(test)]
1061mod tests {
1062 use super::*;
1063
1064 struct EchoTool;
1066
1067 #[async_trait]
1068 impl Tool for EchoTool {
1069 fn name(&self) -> &str {
1070 "echo"
1071 }
1072
1073 fn description(&self) -> Option<&str> {
1074 Some("Echoes back the input message")
1075 }
1076
1077 fn input_schema(&self) -> Value {
1078 serde_json::json!({
1079 "type": "object",
1080 "properties": {
1081 "message": {
1082 "type": "string",
1083 "description": "Message to echo"
1084 }
1085 },
1086 "required": ["message"]
1087 })
1088 }
1089
1090 async fn execute(
1091 &self,
1092 ctx: ExecutionContext<'_>,
1093 ) -> Result<ToolOutput, ToolError> {
1094 let message = ctx
1095 .params
1096 .get("message")
1097 .and_then(|v| v.as_str())
1098 .ok_or_else(|| {
1099 ToolError::InvalidArguments("Missing 'message' field".to_string())
1100 })?;
1101
1102 Ok(ToolOutput::text(format!("Echo: {}", message)))
1103 }
1104 }
1105
1106 struct ToolA;
1107 struct ToolB;
1108
1109 #[async_trait]
1110 impl Tool for ToolA {
1111 fn name(&self) -> &str {
1112 "tool_a"
1113 }
1114
1115 fn input_schema(&self) -> Value {
1116 serde_json::json!({ "type": "object" })
1117 }
1118
1119 async fn execute(
1120 &self,
1121 _ctx: ExecutionContext<'_>,
1122 ) -> Result<ToolOutput, ToolError> {
1123 Ok(ToolOutput::text("A"))
1124 }
1125 }
1126
1127 #[async_trait]
1128 impl Tool for ToolB {
1129 fn name(&self) -> &str {
1130 "tool_b"
1131 }
1132
1133 fn input_schema(&self) -> Value {
1134 serde_json::json!({ "type": "object" })
1135 }
1136
1137 async fn execute(
1138 &self,
1139 _ctx: ExecutionContext<'_>,
1140 ) -> Result<ToolOutput, ToolError> {
1141 Ok(ToolOutput::text("B"))
1142 }
1143 }
1144
1145 #[test]
1146 fn test_registry_creation() {
1147 let registry = ToolRegistry::new();
1148 assert!(registry.is_empty());
1149 }
1150
1151 #[test]
1152 fn test_tool_registration() {
1153 let registry = ToolRegistry::new();
1154 registry.register(EchoTool);
1155
1156 assert_eq!(registry.len(), 1);
1157 assert!(!registry.is_empty());
1158 }
1159
1160 #[test]
1161 fn test_get_tool() {
1162 let registry = ToolRegistry::new();
1163 registry.register(EchoTool);
1164
1165 let tool = registry.get("echo");
1166 assert!(tool.is_some());
1167 assert_eq!(tool.unwrap().name(), "echo");
1168
1169 let missing = registry.get("nonexistent");
1170 assert!(missing.is_none());
1171 }
1172
1173 #[test]
1174 fn test_list_tools() {
1175 let registry = ToolRegistry::new();
1176 registry.register(EchoTool);
1177
1178 let tools = registry.list();
1179 assert_eq!(tools.len(), 1);
1180 assert_eq!(tools[0].name, "echo");
1181 assert_eq!(
1182 tools[0].description,
1183 Some("Echoes back the input message".to_string())
1184 );
1185 }
1186
1187 #[tokio::test]
1188 async fn test_call_tool() {
1189 let (_tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1190 let logger = crate::logging::McpLogger::new(_tx, "test");
1191 let registry = ToolRegistry::new();
1192 registry.register(EchoTool);
1193 let session = Session::new();
1194
1195 let params = serde_json::json!({
1196 "message": "Hello, world!"
1197 });
1198
1199 let result = registry
1200 .call("echo", params, &session, &logger, None)
1201 .await
1202 .unwrap();
1203 assert!(result.is_content());
1204 }
1205
1206 #[tokio::test]
1207 async fn test_call_missing_tool() {
1208 let (_tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1209 let logger = crate::logging::McpLogger::new(_tx, "test");
1210 let registry = ToolRegistry::new();
1211 let session = Session::new();
1212
1213 let params = serde_json::json!({});
1214 let result = registry
1215 .call("nonexistent", params, &session, &logger, None)
1216 .await;
1217
1218 assert!(matches!(result, Err(ToolError::NotFound(_))));
1219 }
1220
1221 #[tokio::test]
1222 async fn test_tool_invalid_arguments() {
1223 let (_tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1224 let logger = crate::logging::McpLogger::new(_tx, "test");
1225 let registry = ToolRegistry::new();
1226 registry.register(EchoTool);
1227 let session = Session::new();
1228
1229 let params = serde_json::json!({}); let result = registry.call("echo", params, &session, &logger, None).await;
1232 assert!(matches!(result, Err(ToolError::InvalidArguments(_))));
1233 }
1234
1235 #[tokio::test]
1236 async fn test_disabled_tool_not_callable() {
1237 let (_tx, _rx) = tokio::sync::mpsc::unbounded_channel();
1238 let logger = crate::logging::McpLogger::new(_tx, "test");
1239 let registry = ToolRegistry::new();
1240 registry.register(EchoTool);
1241 registry.disable_tool("echo");
1242 let session = Session::new();
1243
1244 let params = serde_json::json!({
1245 "message": "Hello, world!"
1246 });
1247
1248 let result = registry.call("echo", params, &session, &logger, None).await;
1249 assert!(matches!(result, Err(ToolError::NotFound(_))));
1250 }
1251
1252 #[test]
1253 fn test_exclusive_group_disables_other() {
1254 let registry = ToolRegistry::new();
1255 registry.register_with_policy(
1256 ToolA,
1257 ToolPolicy::new().exclusive_group("exclusive"),
1258 );
1259 registry.register_with_policy(
1260 ToolB,
1261 ToolPolicy::new().exclusive_group("exclusive"),
1262 );
1263
1264 registry.enable_tool("tool_a");
1265 registry.enable_tool("tool_b");
1266
1267 assert_eq!(
1268 registry.tool_state("tool_a"),
1269 Some(ToolLifecycleState::Disabled)
1270 );
1271 assert_eq!(
1272 registry.tool_state("tool_b"),
1273 Some(ToolLifecycleState::Enabled)
1274 );
1275 }
1276}