claude_agent/mcp/toolset.rs

1//! MCP Toolset configuration for API requests with deferred loading support.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Default, Serialize, Deserialize)]
8pub struct ToolLoadConfig {
9    #[serde(skip_serializing_if = "Option::is_none")]
10    pub defer_loading: Option<bool>,
11}
12
13impl ToolLoadConfig {
14    pub fn deferred() -> Self {
15        Self {
16            defer_loading: Some(true),
17        }
18    }
19
20    pub fn immediate() -> Self {
21        Self {
22            defer_loading: Some(false),
23        }
24    }
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct McpToolset {
29    #[serde(rename = "type")]
30    pub toolset_type: String,
31    pub mcp_server_name: String,
32    #[serde(skip_serializing_if = "Option::is_none")]
33    pub default_config: Option<ToolLoadConfig>,
34    #[serde(skip_serializing_if = "Option::is_none")]
35    pub configs: Option<HashMap<String, ToolLoadConfig>>,
36}
37
38impl McpToolset {
39    pub fn new(server_name: impl Into<String>) -> Self {
40        Self {
41            toolset_type: "mcp_toolset".to_string(),
42            mcp_server_name: server_name.into(),
43            default_config: None,
44            configs: None,
45        }
46    }
47
48    pub fn defer_all(mut self) -> Self {
49        self.default_config = Some(ToolLoadConfig::deferred());
50        self
51    }
52
53    pub fn keep_loaded(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
54        let configs = self.configs.get_or_insert_with(HashMap::new);
55        for name in tool_names {
56            configs.insert(name.into(), ToolLoadConfig::immediate());
57        }
58        self
59    }
60
61    pub fn defer_tools(mut self, tool_names: impl IntoIterator<Item = impl Into<String>>) -> Self {
62        let configs = self.configs.get_or_insert_with(HashMap::new);
63        for name in tool_names {
64            configs.insert(name.into(), ToolLoadConfig::deferred());
65        }
66        self
67    }
68
69    pub fn is_deferred(&self, tool_name: &str) -> bool {
70        if let Some(defer) = self
71            .configs
72            .as_ref()
73            .and_then(|c| c.get(tool_name))
74            .and_then(|c| c.defer_loading)
75        {
76            return defer;
77        }
78        self.default_config
79            .as_ref()
80            .and_then(|c| c.defer_loading)
81            .unwrap_or(false)
82    }
83
84    pub fn server_name(&self) -> &str {
85        &self.mcp_server_name
86    }
87}
88
89#[derive(Debug, Clone, Default)]
90pub struct McpToolsetRegistry {
91    toolsets: HashMap<String, McpToolset>,
92}
93
94impl McpToolsetRegistry {
95    pub fn new() -> Self {
96        Self::default()
97    }
98
99    pub fn register(&mut self, toolset: McpToolset) {
100        self.toolsets
101            .insert(toolset.mcp_server_name.clone(), toolset);
102    }
103
104    pub fn get(&self, server_name: &str) -> Option<&McpToolset> {
105        self.toolsets.get(server_name)
106    }
107
108    pub fn is_deferred(&self, server_name: &str, tool_name: &str) -> bool {
109        self.toolsets
110            .get(server_name)
111            .map(|ts| ts.is_deferred(tool_name))
112            .unwrap_or(false)
113    }
114
115    pub fn iter(&self) -> impl Iterator<Item = &McpToolset> {
116        self.toolsets.values()
117    }
118}
119
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_toolset_defaults_to_immediate() {
        let toolset = McpToolset::new("database");
        assert_eq!(toolset.server_name(), "database");
        assert!(!toolset.is_deferred("any_tool"));
    }

    #[test]
    fn test_toolset_defer_all() {
        let toolset = McpToolset::new("database").defer_all();
        assert!(toolset.is_deferred("any_tool"));
    }

    #[test]
    fn test_toolset_keep_loaded() {
        let toolset = McpToolset::new("database")
            .defer_all()
            .keep_loaded(["search_events"]);

        assert!(!toolset.is_deferred("search_events"));
        assert!(toolset.is_deferred("other_tool"));
    }

    #[test]
    fn test_toolset_defer_tools() {
        let toolset = McpToolset::new("database").defer_tools(["slow_query"]);

        assert!(toolset.is_deferred("slow_query"));
        assert!(!toolset.is_deferred("fast_query"));
    }

    #[test]
    fn test_toolset_later_override_wins() {
        let toolset = McpToolset::new("database")
            .defer_tools(["search"])
            .keep_loaded(["search"]);

        assert!(!toolset.is_deferred("search"));
    }

    #[test]
    fn test_toolset_serialization() {
        let toolset = McpToolset::new("database")
            .defer_all()
            .keep_loaded(["search"]);

        // Compare structurally so the assertion pins the full wire shape.
        let value = serde_json::to_value(&toolset).unwrap();
        assert_eq!(
            value,
            serde_json::json!({
                "type": "mcp_toolset",
                "mcp_server_name": "database",
                "default_config": { "defer_loading": true },
                "configs": { "search": { "defer_loading": false } }
            })
        );
    }

    #[test]
    fn test_toolset_serialization_omits_unset_fields() {
        let value = serde_json::to_value(McpToolset::new("database")).unwrap();
        assert_eq!(
            value,
            serde_json::json!({ "type": "mcp_toolset", "mcp_server_name": "database" })
        );
    }

    #[test]
    fn test_toolset_deserialization() {
        let toolset: McpToolset = serde_json::from_str(
            r#"{"type":"mcp_toolset","mcp_server_name":"database","default_config":{"defer_loading":true},"configs":{"search":{"defer_loading":false}}}"#,
        )
        .unwrap();

        assert_eq!(toolset.server_name(), "database");
        assert!(toolset.is_deferred("other_tool"));
        assert!(!toolset.is_deferred("search"));
    }

    #[test]
    fn test_registry() {
        let mut registry = McpToolsetRegistry::new();
        registry.register(McpToolset::new("server1").defer_all());
        registry.register(McpToolset::new("server2"));

        assert!(registry.is_deferred("server1", "any_tool"));
        assert!(!registry.is_deferred("server2", "any_tool"));
        assert!(!registry.is_deferred("server3", "any_tool"));
        assert_eq!(
            registry.get("server1").map(McpToolset::server_name),
            Some("server1")
        );
        assert!(registry.get("server3").is_none());
        assert_eq!(registry.iter().count(), 2);
    }

    #[test]
    fn test_registry_register_replaces_same_server() {
        let mut registry = McpToolsetRegistry::new();
        registry.register(McpToolset::new("server1").defer_all());
        registry.register(McpToolset::new("server1"));

        assert!(!registry.is_deferred("server1", "any_tool"));
        assert_eq!(registry.iter().count(), 1);
    }
}