rmcp_openapi/tool/
tool_collection.rs

use super::Tool;
use crate::config::Authorization;
use crate::error::{Error, ToolCallError, ToolCallValidationError};
use crate::transformer::ResponseTransformer;
use rmcp::model::{CallToolResult, Tool as McpTool};
use serde_json::Value;
use std::sync::Arc;
use tracing::{Instrument, debug_span};
9
10/// Collection of tools with built-in validation and lookup capabilities
11///
12/// This struct encapsulates all tool management logic in the library layer,
13/// providing a clean API for the binary to delegate tool operations to.
14#[derive(Clone, Default)]
15pub struct ToolCollection {
16    tools: Vec<Tool>,
17}
18
19impl ToolCollection {
20    /// Create a new empty tool collection
21    pub fn new() -> Self {
22        Self { tools: Vec::new() }
23    }
24
25    /// Create a tool collection from a vector of tools
26    pub fn from_tools(tools: Vec<Tool>) -> Self {
27        Self { tools }
28    }
29
30    /// Add a tool to the collection
31    pub fn add_tool(&mut self, tool: Tool) {
32        self.tools.push(tool);
33    }
34
35    /// Get the number of tools in the collection
36    pub fn len(&self) -> usize {
37        self.tools.len()
38    }
39
40    /// Check if the collection is empty
41    pub fn is_empty(&self) -> bool {
42        self.tools.is_empty()
43    }
44
45    /// Get all tool names
46    pub fn get_tool_names(&self) -> Vec<String> {
47        self.tools
48            .iter()
49            .map(|tool| tool.metadata.name.clone())
50            .collect()
51    }
52
53    /// Check if a specific tool exists
54    pub fn has_tool(&self, name: &str) -> bool {
55        self.tools.iter().any(|tool| tool.metadata.name == name)
56    }
57
58    /// Get a tool by name
59    pub fn get_tool(&self, name: &str) -> Option<&Tool> {
60        self.tools.iter().find(|tool| tool.metadata.name == name)
61    }
62
63    /// Convert all tools to MCP Tool format for list_tools response
64    pub fn to_mcp_tools(&self) -> Vec<McpTool> {
65        self.tools.iter().map(McpTool::from).collect()
66    }
67
68    /// Set a response transformer for a specific tool, overriding the global one.
69    ///
70    /// The transformer's `transform_schema` method is immediately applied to the tool's
71    /// output schema. The `transform_response` method will be applied to responses
72    /// when the tool is called.
73    ///
74    /// # Errors
75    ///
76    /// Returns an error if the tool is not found
77    pub fn set_tool_transformer(
78        &mut self,
79        tool_name: &str,
80        transformer: Arc<dyn ResponseTransformer>,
81    ) -> Result<(), Error> {
82        let tool = self
83            .tools
84            .iter_mut()
85            .find(|t| t.metadata.name == tool_name)
86            .ok_or_else(|| Error::ToolNotFound(tool_name.to_string()))?;
87
88        // Transform the existing output schema
89        if let Some(schema) = tool.metadata.output_schema.take() {
90            tool.metadata.output_schema = Some(transformer.transform_schema(schema));
91        }
92
93        tool.response_transformer = Some(transformer);
94        Ok(())
95    }
96
97    /// Call a tool by name with validation
98    ///
99    /// This method encapsulates all tool validation logic:
100    /// - Tool not found errors with suggestions
101    /// - Parameter validation
102    /// - Tool execution
103    ///
104    /// # Arguments
105    ///
106    /// * `tool_name` - The name of the tool to call
107    /// * `arguments` - The tool call arguments
108    /// * `authorization` - Authorization configuration
109    /// * `server_transformer` - Optional server-level response transformer
110    pub async fn call_tool(
111        &self,
112        tool_name: &str,
113        arguments: &Value,
114        authorization: Authorization,
115        server_transformer: Option<&dyn ResponseTransformer>,
116    ) -> Result<CallToolResult, ToolCallError> {
117        let span = debug_span!(
118            "tool_execution",
119            tool_name = %tool_name,
120            total_tools = self.tools.len()
121        );
122        let _enter = span.enter();
123
124        // First validate that the tool exists
125        if let Some(tool) = self.get_tool(tool_name) {
126            // Tool exists, delegate to the tool's call method
127            tool.call(arguments, authorization, server_transformer)
128                .await
129        } else {
130            // Tool not found - generate suggestions and return validation error
131            let tool_names: Vec<&str> = self
132                .tools
133                .iter()
134                .map(|tool| tool.metadata.name.as_str())
135                .collect();
136
137            Err(ToolCallError::Validation(
138                ToolCallValidationError::tool_not_found(tool_name.to_string(), &tool_names),
139            ))
140        }
141    }
142
143    /// Get basic statistics about the tool collection
144    pub fn get_stats(&self) -> String {
145        format!("Total tools: {}", self.tools.len())
146    }
147
148    /// Get an iterator over the tools
149    pub fn iter(&self) -> impl Iterator<Item = &Tool> {
150        self.tools.iter()
151    }
152}
153
154impl From<Vec<Tool>> for ToolCollection {
155    fn from(tools: Vec<Tool>) -> Self {
156        Self::from_tools(tools)
157    }
158}
159
160impl IntoIterator for ToolCollection {
161    type Item = Tool;
162    type IntoIter = std::vec::IntoIter<Tool>;
163
164    fn into_iter(self) -> Self::IntoIter {
165        self.tools.into_iter()
166    }
167}
168
169impl<'a> IntoIterator for &'a ToolCollection {
170    type Item = &'a Tool;
171    type IntoIter = std::slice::Iter<'a, Tool>;
172
173    fn into_iter(self) -> Self::IntoIter {
174        self.tools.iter()
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{HttpClient, tool::ToolMetadata};
    use serde_json::json;

    /// Build a `GET /<name>` tool whose only parameter is a required integer `id`.
    fn create_test_tool(name: &str, description: &str) -> Tool {
        let metadata = ToolMetadata {
            name: name.to_string(),
            title: Some(name.to_string()),
            description: Some(description.to_string()),
            parameters: json!({
                "type": "object",
                "properties": {
                    "id": {"type": "integer"}
                },
                "required": ["id"]
            }),
            output_schema: None,
            method: "GET".to_string(),
            path: format!("/{}", name),
            security: None,
            parameter_mappings: std::collections::HashMap::new(),
        };
        Tool::new(metadata, HttpClient::new()).unwrap()
    }

    #[test]
    fn test_tool_collection_creation() {
        let collection = ToolCollection::new();
        assert_eq!(collection.len(), 0);
        assert!(collection.is_empty());
    }

    #[test]
    fn test_tool_collection_from_tools() {
        let tool1 = create_test_tool("test1", "Test tool 1");
        let tool2 = create_test_tool("test2", "Test tool 2");
        let tools = vec![tool1, tool2];

        let collection = ToolCollection::from_tools(tools);
        assert_eq!(collection.len(), 2);
        assert!(!collection.is_empty());
        assert!(collection.has_tool("test1"));
        assert!(collection.has_tool("test2"));
        assert!(!collection.has_tool("test3"));
    }

    #[test]
    fn test_add_tool() {
        let mut collection = ToolCollection::new();
        let tool = create_test_tool("test", "Test tool");

        collection.add_tool(tool);
        assert_eq!(collection.len(), 1);
        assert!(collection.has_tool("test"));
    }

    #[test]
    fn test_get_tool_names() {
        let tool1 = create_test_tool("getPetById", "Get pet by ID");
        let tool2 = create_test_tool("getPetsByStatus", "Get pets by status");
        let collection = ToolCollection::from_tools(vec![tool1, tool2]);

        let names = collection.get_tool_names();
        assert_eq!(names, vec!["getPetById", "getPetsByStatus"]);
    }

    #[test]
    fn test_get_tool() {
        let tool = create_test_tool("test", "Test tool");
        let collection = ToolCollection::from_tools(vec![tool]);

        assert!(collection.get_tool("test").is_some());
        assert!(collection.get_tool("nonexistent").is_none());
    }

    #[test]
    fn test_empty_collection_lookups() {
        let collection = ToolCollection::default();

        assert!(collection.get_tool_names().is_empty());
        assert!(collection.to_mcp_tools().is_empty());
        assert!(!collection.has_tool("anything"));
        assert!(collection.get_tool("anything").is_none());
        assert_eq!(collection.iter().count(), 0);
    }

    #[test]
    fn test_duplicate_names_resolve_to_first_tool() {
        let mut collection = ToolCollection::new();
        collection.add_tool(create_test_tool("dup", "first"));
        collection.add_tool(create_test_tool("dup", "second"));

        // Both tools are stored, but name lookups return the earliest one.
        assert_eq!(collection.len(), 2);
        let tool = collection.get_tool("dup").expect("tool should exist");
        assert_eq!(tool.metadata.description.as_deref(), Some("first"));
    }

    #[test]
    fn test_to_mcp_tools() {
        let tool1 = create_test_tool("test1", "Test tool 1");
        let tool2 = create_test_tool("test2", "Test tool 2");
        let collection = ToolCollection::from_tools(vec![tool1, tool2]);

        let mcp_tools = collection.to_mcp_tools();
        assert_eq!(mcp_tools.len(), 2);
        assert_eq!(mcp_tools[0].name, "test1");
        assert_eq!(mcp_tools[1].name, "test2");
    }

    #[test]
    fn test_get_stats() {
        let collection = ToolCollection::from_tools(vec![
            create_test_tool("test1", "Test tool 1"),
            create_test_tool("test2", "Test tool 2"),
        ]);

        assert_eq!(collection.get_stats(), "Total tools: 2");
        assert_eq!(ToolCollection::new().get_stats(), "Total tools: 0");
    }

    #[actix_web::test]
    async fn test_call_tool_not_found_with_suggestions() {
        let tool1 = create_test_tool("getPetById", "Get pet by ID");
        let tool2 = create_test_tool("getPetsByStatus", "Get pets by status");
        let collection = ToolCollection::from_tools(vec![tool1, tool2]);

        let result = collection
            .call_tool("getPetByID", &json!({}), Authorization::default(), None)
            .await;
        assert!(result.is_err());

        if let Err(ToolCallError::Validation(ToolCallValidationError::ToolNotFound {
            tool_name,
            suggestions,
        })) = result
        {
            assert_eq!(tool_name, "getPetByID");
            // The requested name differs from an existing tool only by letter case,
            // so that tool must be among the suggestions.
            assert!(suggestions.contains(&"getPetById".to_string()));
        } else {
            panic!("Expected ToolNotFound error with suggestions");
        }
    }

    #[actix_web::test]
    async fn test_call_tool_not_found_no_suggestions() {
        let tool = create_test_tool("getPetById", "Get pet by ID");
        let collection = ToolCollection::from_tools(vec![tool]);

        let result = collection
            .call_tool(
                "completelyDifferentName",
                &json!({}),
                Authorization::default(),
                None,
            )
            .await;
        assert!(result.is_err());

        if let Err(ToolCallError::Validation(ToolCallValidationError::ToolNotFound {
            tool_name,
            suggestions,
        })) = result
        {
            assert_eq!(tool_name, "completelyDifferentName");
            assert!(suggestions.is_empty());
        } else {
            panic!("Expected ToolNotFound error with no suggestions");
        }
    }

    #[test]
    fn test_iterators() {
        let tool1 = create_test_tool("test1", "Test tool 1");
        let tool2 = create_test_tool("test2", "Test tool 2");
        let collection = ToolCollection::from_tools(vec![tool1, tool2]);

        // iter()
        let names: Vec<String> = collection
            .iter()
            .map(|tool| tool.metadata.name.clone())
            .collect();
        assert_eq!(names, vec!["test1", "test2"]);

        // IntoIterator for &collection
        let names: Vec<String> = (&collection)
            .into_iter()
            .map(|tool| tool.metadata.name.clone())
            .collect();
        assert_eq!(names, vec!["test1", "test2"]);

        // IntoIterator for collection (consumes it)
        let names: Vec<String> = collection
            .into_iter()
            .map(|tool| tool.metadata.name.clone())
            .collect();
        assert_eq!(names, vec!["test1", "test2"]);
    }

    #[test]
    fn test_from_vec() {
        let tool1 = create_test_tool("test1", "Test tool 1");
        let tool2 = create_test_tool("test2", "Test tool 2");
        let tools = vec![tool1, tool2];

        let collection: ToolCollection = tools.into();
        assert_eq!(collection.len(), 2);
    }
}