ricecoder_mcp/
executor.rs

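//! Custom tool executor: keeps a registry of tools defined in configuration and runs them,
//! validating call parameters before execution and the result type afterwards.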
2
3use crate::config::CustomToolConfig;
4use crate::error::{Error, Result};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11/// Custom Tool Executor for executing custom tools defined in configuration
12#[derive(Debug, Clone)]
13pub struct CustomToolExecutor {
14    tools: Arc<RwLock<HashMap<String, CustomToolConfig>>>,
15}
16
17impl CustomToolExecutor {
18    /// Creates a new custom tool executor
19    pub fn new() -> Self {
20        Self {
21            tools: Arc::new(RwLock::new(HashMap::new())),
22        }
23    }
24
25    /// Registers a custom tool from configuration
26    ///
27    /// # Arguments
28    /// * `tool` - The custom tool configuration to register
29    pub async fn register_tool(&self, tool: CustomToolConfig) -> Result<()> {
30        debug!("Registering custom tool: {}", tool.id);
31
32        // Validate tool configuration
33        if tool.id.is_empty() {
34            return Err(Error::ValidationError(
35                "Custom tool ID cannot be empty".to_string(),
36            ));
37        }
38
39        if tool.handler.is_empty() {
40            return Err(Error::ValidationError(format!(
41                "Custom tool '{}' has no handler",
42                tool.id
43            )));
44        }
45
46        let mut tools = self.tools.write().await;
47        tools.insert(tool.id.clone(), tool.clone());
48
49        info!("Custom tool registered: {}", tool.id);
50        Ok(())
51    }
52
53    /// Registers multiple custom tools
54    pub async fn register_tools(&self, tools: Vec<CustomToolConfig>) -> Result<()> {
55        for tool in tools {
56            self.register_tool(tool).await?;
57        }
58        Ok(())
59    }
60
61    /// Unregisters a custom tool
62    pub async fn unregister_tool(&self, tool_id: &str) -> Result<()> {
63        debug!("Unregistering custom tool: {}", tool_id);
64
65        let mut tools = self.tools.write().await;
66        tools.remove(tool_id);
67
68        info!("Custom tool unregistered: {}", tool_id);
69        Ok(())
70    }
71
72    /// Executes a custom tool with parameter validation
73    ///
74    /// # Arguments
75    /// * `tool_id` - The ID of the tool to execute
76    /// * `parameters` - The parameters to pass to the tool
77    ///
78    /// # Returns
79    /// The result of the tool execution
80    pub async fn execute_tool(&self, tool_id: &str, parameters: Value) -> Result<Value> {
81        debug!("Executing custom tool: {}", tool_id);
82
83        let tools = self.tools.read().await;
84        let tool = tools
85            .get(tool_id)
86            .ok_or_else(|| Error::ToolNotFound(format!("Custom tool not found: {}", tool_id)))?
87            .clone();
88        drop(tools);
89
90        // Validate parameters
91        self.validate_parameters(&tool, &parameters)?;
92
93        // Execute the tool
94        let result = self.execute_handler(&tool, parameters).await?;
95
96        // Validate output
97        self.validate_output(&tool, &result)?;
98
99        info!("Custom tool executed successfully: {}", tool_id);
100        Ok(result)
101    }
102
103    /// Validates tool parameters before execution
104    fn validate_parameters(&self, tool: &CustomToolConfig, parameters: &Value) -> Result<()> {
105        debug!("Validating parameters for tool: {}", tool.id);
106
107        let params_obj = parameters.as_object().ok_or_else(|| {
108            Error::ParameterValidationError("Parameters must be a JSON object".to_string())
109        })?;
110
111        // Check required parameters
112        for param in &tool.parameters {
113            if param.required && !params_obj.contains_key(&param.name) {
114                return Err(Error::ParameterValidationError(format!(
115                    "Required parameter '{}' is missing",
116                    param.name
117                )));
118            }
119
120            // Validate parameter types if present
121            if let Some(value) = params_obj.get(&param.name) {
122                self.validate_parameter_type(&param.name, &param.type_, value)?;
123            }
124        }
125
126        Ok(())
127    }
128
129    /// Validates a single parameter type
130    fn validate_parameter_type(&self, name: &str, expected_type: &str, value: &Value) -> Result<()> {
131        let type_matches = match expected_type {
132            "string" => value.is_string(),
133            "number" => value.is_number(),
134            "integer" => value.is_i64() || value.is_u64(),
135            "boolean" => value.is_boolean(),
136            "array" => value.is_array(),
137            "object" => value.is_object(),
138            _ => true, // Unknown types are allowed
139        };
140
141        if !type_matches {
142            return Err(Error::ParameterValidationError(format!(
143                "Parameter '{}' has invalid type. Expected: {}, Got: {}",
144                name,
145                expected_type,
146                value.type_str()
147            )));
148        }
149
150        Ok(())
151    }
152
153    /// Validates tool output after execution
154    fn validate_output(&self, tool: &CustomToolConfig, output: &Value) -> Result<()> {
155        debug!("Validating output for tool: {}", tool.id);
156
157        // Validate output type
158        self.validate_parameter_type("output", &tool.return_type, output)?;
159
160        Ok(())
161    }
162
163    /// Executes the tool handler
164    ///
165    /// In a real implementation, this would invoke the actual handler function.
166    /// For now, we simulate successful execution.
167    async fn execute_handler(&self, tool: &CustomToolConfig, parameters: Value) -> Result<Value> {
168        debug!("Executing handler for tool: {}", tool.id);
169
170        // Simulate tool execution
171        // In a real implementation, this would:
172        // 1. Look up the handler function
173        // 2. Call it with the parameters
174        // 3. Return the result
175
176        // For now, return a success response
177        Ok(json!({
178            "success": true,
179            "tool_id": tool.id,
180            "message": format!("Tool '{}' executed successfully", tool.id),
181            "parameters": parameters
182        }))
183    }
184
185    /// Gets a registered custom tool
186    pub async fn get_tool(&self, tool_id: &str) -> Result<CustomToolConfig> {
187        let tools = self.tools.read().await;
188        tools
189            .get(tool_id)
190            .cloned()
191            .ok_or_else(|| Error::ToolNotFound(format!("Custom tool not found: {}", tool_id)))
192    }
193
194    /// Lists all registered custom tools
195    pub async fn list_tools(&self) -> Vec<CustomToolConfig> {
196        let tools = self.tools.read().await;
197        tools.values().cloned().collect()
198    }
199
200    /// Gets the count of registered custom tools
201    pub async fn tool_count(&self) -> usize {
202        let tools = self.tools.read().await;
203        tools.len()
204    }
205
206    /// Checks if a tool is registered
207    pub async fn has_tool(&self, tool_id: &str) -> bool {
208        let tools = self.tools.read().await;
209        tools.contains_key(tool_id)
210    }
211
212    /// Clears all registered custom tools
213    pub async fn clear_tools(&self) {
214        let mut tools = self.tools.write().await;
215        tools.clear();
216        info!("All custom tools cleared");
217    }
218}
219
220impl Default for CustomToolExecutor {
221    fn default() -> Self {
222        Self::new()
223    }
224}
225
226/// Extension trait for Value to get type string
227trait ValueTypeStr {
228    fn type_str(&self) -> &'static str;
229}
230
231impl ValueTypeStr for Value {
232    fn type_str(&self) -> &'static str {
233        match self {
234            Value::Null => "null",
235            Value::Bool(_) => "boolean",
236            Value::Number(_) => "number",
237            Value::String(_) => "string",
238            Value::Array(_) => "array",
239            Value::Object(_) => "object",
240        }
241    }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ParameterConfig;

    /// Builds a tool with a required `input: string` and an optional `count: integer`.
    fn create_test_tool(id: &str) -> CustomToolConfig {
        CustomToolConfig {
            id: id.to_string(),
            name: format!("Test Tool {}", id),
            description: "A test tool".to_string(),
            category: "test".to_string(),
            parameters: vec![
                ParameterConfig {
                    name: "input".to_string(),
                    type_: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: true,
                    default: None,
                },
                ParameterConfig {
                    name: "count".to_string(),
                    type_: "integer".to_string(),
                    description: "Count parameter".to_string(),
                    required: false,
                    default: Some(json!(1)),
                },
            ],
            return_type: "object".to_string(),
            handler: "test::handler".to_string(),
        }
    }

    #[tokio::test]
    async fn test_create_executor() {
        let executor = CustomToolExecutor::new();
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        let result = executor.register_tool(tool).await;
        assert!(result.is_ok());
        assert_eq!(executor.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_register_tool_empty_id() {
        let executor = CustomToolExecutor::new();
        let mut tool = create_test_tool("tool1");
        tool.id = "".to_string();

        let result = executor.register_tool(tool).await;
        assert!(matches!(result, Err(Error::ValidationError(_))));
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_tool_empty_handler() {
        let executor = CustomToolExecutor::new();
        let mut tool = create_test_tool("tool1");
        tool.handler = "".to_string();

        let result = executor.register_tool(tool).await;
        assert!(matches!(result, Err(Error::ValidationError(_))));
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_tool_replaces_existing_id() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let mut replacement = create_test_tool("tool1");
        replacement.description = "Replacement".to_string();
        executor.register_tool(replacement).await.unwrap();

        assert_eq!(executor.tool_count().await, 1);
        assert_eq!(
            executor.get_tool("tool1").await.unwrap().description,
            "Replacement"
        );
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        executor.register_tool(tool).await.unwrap();
        assert_eq!(executor.tool_count().await, 1);

        executor.unregister_tool("tool1").await.unwrap();
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_unknown_tool_is_ok() {
        let executor = CustomToolExecutor::new();
        assert!(executor.unregister_tool("missing").await.is_ok());
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        executor.register_tool(tool.clone()).await.unwrap();
        let retrieved = executor.get_tool("tool1").await.unwrap();
        assert_eq!(retrieved.id, tool.id);
    }

    #[tokio::test]
    async fn test_get_tool_not_found() {
        let executor = CustomToolExecutor::new();
        let result = executor.get_tool("nonexistent").await;
        assert!(matches!(result, Err(Error::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_list_tools() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();
        executor.register_tool(create_test_tool("tool2")).await.unwrap();

        let tools = executor.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_has_tool() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        assert!(executor.has_tool("tool1").await);
        assert!(!executor.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_clone_shares_registry() {
        let executor = CustomToolExecutor::new();
        let handle = executor.clone();

        handle.register_tool(create_test_tool("tool1")).await.unwrap();
        assert!(executor.has_tool("tool1").await);
    }

    #[tokio::test]
    async fn test_clear_tools() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();
        executor.register_tool(create_test_tool("tool2")).await.unwrap();

        assert_eq!(executor.tool_count().await, 2);
        executor.clear_tools().await;
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_execute_tool_valid_parameters() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "input": "test",
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tool_echoes_parameters() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let params = json!({ "input": "test", "count": 5 });
        let result = executor.execute_tool("tool1", params.clone()).await.unwrap();

        assert_eq!(result["tool_id"], "tool1");
        assert_eq!(result["parameters"], params);
    }

    #[tokio::test]
    async fn test_execute_tool_optional_parameter_may_be_omitted() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let result = executor.execute_tool("tool1", json!({ "input": "test" })).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tool_missing_required_parameter() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(matches!(result, Err(Error::ParameterValidationError(_))));
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_parameter_type() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "input": 123,
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(matches!(result, Err(Error::ParameterValidationError(_))));
    }

    #[tokio::test]
    async fn test_execute_tool_rejects_non_object_parameters() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let result = executor.execute_tool("tool1", json!(["test"])).await;
        assert!(matches!(result, Err(Error::ParameterValidationError(_))));
    }

    #[tokio::test]
    async fn test_execute_tool_not_found() {
        let executor = CustomToolExecutor::new();
        let params = json!({
            "input": "test"
        });

        let result = executor.execute_tool("nonexistent", params).await;
        assert!(matches!(result, Err(Error::ToolNotFound(_))));
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let executor = CustomToolExecutor::new();
        let tools = vec![
            create_test_tool("tool1"),
            create_test_tool("tool2"),
            create_test_tool("tool3"),
        ];

        let result = executor.register_tools(tools).await;
        assert!(result.is_ok());
        assert_eq!(executor.tool_count().await, 3);
    }

    #[tokio::test]
    async fn test_validate_parameter_types() {
        let executor = CustomToolExecutor::new();

        // Test string validation
        assert!(executor
            .validate_parameter_type("test", "string", &json!("hello"))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "string", &json!(123))
            .is_err());

        // Test number validation
        assert!(executor
            .validate_parameter_type("test", "number", &json!(123.45))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "number", &json!("hello"))
            .is_err());

        // Test integer validation: whole JSON numbers only
        assert!(executor
            .validate_parameter_type("test", "integer", &json!(42))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "integer", &json!(4.2))
            .is_err());
        assert!(executor
            .validate_parameter_type("test", "integer", &json!("42"))
            .is_err());

        // Test boolean validation
        assert!(executor
            .validate_parameter_type("test", "boolean", &json!(true))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "boolean", &json!("hello"))
            .is_err());

        // Test array validation
        assert!(executor
            .validate_parameter_type("test", "array", &json!([1, 2, 3]))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "array", &json!("hello"))
            .is_err());

        // Test object validation
        assert!(executor
            .validate_parameter_type("test", "object", &json!({}))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "object", &json!("hello"))
            .is_err());

        // Unrecognized type names are not enforced
        assert!(executor
            .validate_parameter_type("test", "custom", &json!(null))
            .is_ok());
    }
}