rmcp_openapi/tool/
tool_collection.rs1use super::Tool;
2use crate::error::{ToolCallError, ToolCallValidationError};
3use rmcp::model::{CallToolResult, Tool as McpTool};
4use serde_json::Value;
5use tracing::debug_span;
6
7#[derive(Clone, Default)]
12pub struct ToolCollection {
13    tools: Vec<Tool>,
14}
15
16impl ToolCollection {
17    pub fn new() -> Self {
19        Self { tools: Vec::new() }
20    }
21
22    pub fn from_tools(tools: Vec<Tool>) -> Self {
24        Self { tools }
25    }
26
27    pub fn add_tool(&mut self, tool: Tool) {
29        self.tools.push(tool);
30    }
31
32    pub fn len(&self) -> usize {
34        self.tools.len()
35    }
36
37    pub fn is_empty(&self) -> bool {
39        self.tools.is_empty()
40    }
41
42    pub fn get_tool_names(&self) -> Vec<String> {
44        self.tools
45            .iter()
46            .map(|tool| tool.metadata.name.clone())
47            .collect()
48    }
49
50    pub fn has_tool(&self, name: &str) -> bool {
52        self.tools.iter().any(|tool| tool.metadata.name == name)
53    }
54
55    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
57        self.tools.iter().find(|tool| tool.metadata.name == name)
58    }
59
60    pub fn to_mcp_tools(&self) -> Vec<McpTool> {
62        self.tools.iter().map(McpTool::from).collect()
63    }
64
65    pub async fn call_tool(
72        &self,
73        tool_name: &str,
74        arguments: &Value,
75    ) -> Result<CallToolResult, ToolCallError> {
76        let span = debug_span!(
77            "tool_execution",
78            tool_name = %tool_name,
79            total_tools = self.tools.len()
80        );
81        let _enter = span.enter();
82
83        if let Some(tool) = self.get_tool(tool_name) {
85            tool.call(arguments).await
87        } else {
88            let tool_names: Vec<&str> = self
90                .tools
91                .iter()
92                .map(|tool| tool.metadata.name.as_str())
93                .collect();
94
95            Err(ToolCallError::Validation(
96                ToolCallValidationError::tool_not_found(tool_name.to_string(), &tool_names),
97            ))
98        }
99    }
100
101    pub fn get_stats(&self) -> String {
103        format!("Total tools: {}", self.tools.len())
104    }
105
106    pub fn iter(&self) -> impl Iterator<Item = &Tool> {
108        self.tools.iter()
109    }
110}
111
112impl From<Vec<Tool>> for ToolCollection {
113    fn from(tools: Vec<Tool>) -> Self {
114        Self::from_tools(tools)
115    }
116}
117
118impl IntoIterator for ToolCollection {
119    type Item = Tool;
120    type IntoIter = std::vec::IntoIter<Tool>;
121
122    fn into_iter(self) -> Self::IntoIter {
123        self.tools.into_iter()
124    }
125}
126
127impl<'a> IntoIterator for &'a ToolCollection {
128    type Item = &'a Tool;
129    type IntoIter = std::slice::Iter<'a, Tool>;
130
131    fn into_iter(self) -> Self::IntoIter {
132        self.tools.iter()
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tool::ToolMetadata;
    use serde_json::json;

    /// Builds a `GET /{name}` tool whose parameter schema requires one integer `id`.
    fn create_test_tool(name: &str, description: &str) -> Tool {
        let owned_name = name.to_string();
        let parameters = json!({
            "type": "object",
            "properties": {
                "id": {"type": "integer"}
            },
            "required": ["id"]
        });
        let metadata = ToolMetadata {
            path: format!("/{}", owned_name),
            title: Some(owned_name.clone()),
            name: owned_name,
            description: description.to_string(),
            parameters,
            output_schema: None,
            method: "GET".to_string(),
        };
        Tool::new(metadata, None, None).unwrap()
    }

    #[test]
    fn test_tool_collection_creation() {
        let empty = ToolCollection::new();
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
    }

    #[test]
    fn test_tool_collection_from_tools() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("test1", "Test tool 1"),
            create_test_tool("test2", "Test tool 2"),
        ]);

        assert!(!collection.is_empty());
        assert_eq!(collection.len(), 2);
        for present in ["test1", "test2"] {
            assert!(collection.has_tool(present));
        }
        assert!(!collection.has_tool("test3"));
    }

    #[test]
    fn test_add_tool() {
        let mut collection = ToolCollection::new();
        collection.add_tool(create_test_tool("test", "Test tool"));

        assert!(collection.has_tool("test"));
        assert_eq!(collection.len(), 1);
    }

    #[test]
    fn test_get_tool_names() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("getPetById", "Get pet by ID"),
            create_test_tool("getPetsByStatus", "Get pets by status"),
        ]);

        assert_eq!(
            collection.get_tool_names(),
            vec!["getPetById", "getPetsByStatus"]
        );
    }

    #[test]
    fn test_get_tool() {
        let collection = ToolCollection::from_tools(vec![create_test_tool("test", "Test tool")]);

        assert!(collection.get_tool("nonexistent").is_none());
        assert!(collection.get_tool("test").is_some());
    }

    #[test]
    fn test_to_mcp_tools() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("test1", "Test tool 1"),
            create_test_tool("test2", "Test tool 2"),
        ]);

        let mcp_tools = collection.to_mcp_tools();
        assert_eq!(mcp_tools.len(), 2);
        for (mcp_tool, expected) in mcp_tools.iter().zip(["test1", "test2"]) {
            assert_eq!(mcp_tool.name, expected);
        }
    }

    #[tokio::test]
    async fn test_call_tool_not_found_with_suggestions() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("getPetById", "Get pet by ID"),
            create_test_tool("getPetsByStatus", "Get pets by status"),
        ]);

        let result = collection.call_tool("getPetByID", &json!({})).await;
        assert!(result.is_err());

        match result {
            Err(ToolCallError::Validation(ToolCallValidationError::ToolNotFound {
                tool_name,
                suggestions,
            })) => {
                assert_eq!(tool_name, "getPetByID");
                assert!(suggestions.contains(&"getPetById".to_string()));
                assert!(!suggestions.is_empty());
            }
            _ => panic!("Expected ToolNotFound error with suggestions"),
        }
    }

    #[tokio::test]
    async fn test_call_tool_not_found_no_suggestions() {
        let collection =
            ToolCollection::from_tools(vec![create_test_tool("getPetById", "Get pet by ID")]);

        let result = collection
            .call_tool("completelyDifferentName", &json!({}))
            .await;
        assert!(result.is_err());

        match result {
            Err(ToolCallError::Validation(ToolCallValidationError::ToolNotFound {
                tool_name,
                suggestions,
            })) => {
                assert_eq!(tool_name, "completelyDifferentName");
                assert!(suggestions.is_empty());
            }
            _ => panic!("Expected ToolNotFound error with no suggestions"),
        }
    }

    #[test]
    fn test_iterators() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("test1", "Test tool 1"),
            create_test_tool("test2", "Test tool 2"),
        ]);
        let expected = vec!["test1", "test2"];

        // Inherent `iter()`.
        let via_iter: Vec<&str> = collection
            .iter()
            .map(|tool| tool.metadata.name.as_str())
            .collect();
        assert_eq!(via_iter, expected);

        // `IntoIterator for &ToolCollection`, driven by a `for` loop.
        let mut via_ref: Vec<String> = Vec::new();
        for tool in &collection {
            via_ref.push(tool.metadata.name.clone());
        }
        assert_eq!(via_ref, expected);

        // `IntoIterator for ToolCollection`, consuming the collection.
        let via_owned: Vec<String> = collection
            .into_iter()
            .map(|tool| tool.metadata.name.clone())
            .collect();
        assert_eq!(via_owned, expected);
    }

    #[test]
    fn test_from_vec() {
        let tools = vec![
            create_test_tool("test1", "Test tool 1"),
            create_test_tool("test2", "Test tool 2"),
        ];

        let collection = ToolCollection::from(tools);
        assert_eq!(collection.len(), 2);
    }
}