1use crate::config::CustomToolConfig;
4use crate::error::{Error, Result};
5use serde_json::{json, Value};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use tracing::{debug, info};
10
11#[derive(Debug, Clone)]
13pub struct CustomToolExecutor {
14 tools: Arc<RwLock<HashMap<String, CustomToolConfig>>>,
15}
16
17impl CustomToolExecutor {
18 pub fn new() -> Self {
20 Self {
21 tools: Arc::new(RwLock::new(HashMap::new())),
22 }
23 }
24
25 pub async fn register_tool(&self, tool: CustomToolConfig) -> Result<()> {
30 debug!("Registering custom tool: {}", tool.id);
31
32 if tool.id.is_empty() {
34 return Err(Error::ValidationError(
35 "Custom tool ID cannot be empty".to_string(),
36 ));
37 }
38
39 if tool.handler.is_empty() {
40 return Err(Error::ValidationError(format!(
41 "Custom tool '{}' has no handler",
42 tool.id
43 )));
44 }
45
46 let mut tools = self.tools.write().await;
47 tools.insert(tool.id.clone(), tool.clone());
48
49 info!("Custom tool registered: {}", tool.id);
50 Ok(())
51 }
52
53 pub async fn register_tools(&self, tools: Vec<CustomToolConfig>) -> Result<()> {
55 for tool in tools {
56 self.register_tool(tool).await?;
57 }
58 Ok(())
59 }
60
61 pub async fn unregister_tool(&self, tool_id: &str) -> Result<()> {
63 debug!("Unregistering custom tool: {}", tool_id);
64
65 let mut tools = self.tools.write().await;
66 tools.remove(tool_id);
67
68 info!("Custom tool unregistered: {}", tool_id);
69 Ok(())
70 }
71
72 pub async fn execute_tool(&self, tool_id: &str, parameters: Value) -> Result<Value> {
81 debug!("Executing custom tool: {}", tool_id);
82
83 let tools = self.tools.read().await;
84 let tool = tools
85 .get(tool_id)
86 .ok_or_else(|| Error::ToolNotFound(format!("Custom tool not found: {}", tool_id)))?
87 .clone();
88 drop(tools);
89
90 self.validate_parameters(&tool, ¶meters)?;
92
93 let result = self.execute_handler(&tool, parameters).await?;
95
96 self.validate_output(&tool, &result)?;
98
99 info!("Custom tool executed successfully: {}", tool_id);
100 Ok(result)
101 }
102
103 fn validate_parameters(&self, tool: &CustomToolConfig, parameters: &Value) -> Result<()> {
105 debug!("Validating parameters for tool: {}", tool.id);
106
107 let params_obj = parameters.as_object().ok_or_else(|| {
108 Error::ParameterValidationError("Parameters must be a JSON object".to_string())
109 })?;
110
111 for param in &tool.parameters {
113 if param.required && !params_obj.contains_key(¶m.name) {
114 return Err(Error::ParameterValidationError(format!(
115 "Required parameter '{}' is missing",
116 param.name
117 )));
118 }
119
120 if let Some(value) = params_obj.get(¶m.name) {
122 self.validate_parameter_type(¶m.name, ¶m.type_, value)?;
123 }
124 }
125
126 Ok(())
127 }
128
129 fn validate_parameter_type(&self, name: &str, expected_type: &str, value: &Value) -> Result<()> {
131 let type_matches = match expected_type {
132 "string" => value.is_string(),
133 "number" => value.is_number(),
134 "integer" => value.is_i64() || value.is_u64(),
135 "boolean" => value.is_boolean(),
136 "array" => value.is_array(),
137 "object" => value.is_object(),
138 _ => true, };
140
141 if !type_matches {
142 return Err(Error::ParameterValidationError(format!(
143 "Parameter '{}' has invalid type. Expected: {}, Got: {}",
144 name,
145 expected_type,
146 value.type_str()
147 )));
148 }
149
150 Ok(())
151 }
152
153 fn validate_output(&self, tool: &CustomToolConfig, output: &Value) -> Result<()> {
155 debug!("Validating output for tool: {}", tool.id);
156
157 self.validate_parameter_type("output", &tool.return_type, output)?;
159
160 Ok(())
161 }
162
163 async fn execute_handler(&self, tool: &CustomToolConfig, parameters: Value) -> Result<Value> {
168 debug!("Executing handler for tool: {}", tool.id);
169
170 Ok(json!({
178 "success": true,
179 "tool_id": tool.id,
180 "message": format!("Tool '{}' executed successfully", tool.id),
181 "parameters": parameters
182 }))
183 }
184
185 pub async fn get_tool(&self, tool_id: &str) -> Result<CustomToolConfig> {
187 let tools = self.tools.read().await;
188 tools
189 .get(tool_id)
190 .cloned()
191 .ok_or_else(|| Error::ToolNotFound(format!("Custom tool not found: {}", tool_id)))
192 }
193
194 pub async fn list_tools(&self) -> Vec<CustomToolConfig> {
196 let tools = self.tools.read().await;
197 tools.values().cloned().collect()
198 }
199
200 pub async fn tool_count(&self) -> usize {
202 let tools = self.tools.read().await;
203 tools.len()
204 }
205
206 pub async fn has_tool(&self, tool_id: &str) -> bool {
208 let tools = self.tools.read().await;
209 tools.contains_key(tool_id)
210 }
211
212 pub async fn clear_tools(&self) {
214 let mut tools = self.tools.write().await;
215 tools.clear();
216 info!("All custom tools cleared");
217 }
218}
219
220impl Default for CustomToolExecutor {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
/// Gives a short name for the kind of a JSON value, for use in
/// validation error messages.
trait ValueTypeStr {
    /// Returns the kind name, such as `"string"` or `"object"`.
    fn type_str(&self) -> &'static str;
}
230
231impl ValueTypeStr for Value {
232 fn type_str(&self) -> &'static str {
233 match self {
234 Value::Null => "null",
235 Value::Bool(_) => "boolean",
236 Value::Number(_) => "number",
237 Value::String(_) => "string",
238 Value::Array(_) => "array",
239 Value::Object(_) => "object",
240 }
241 }
242}
243
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ParameterConfig;

    /// Builds a tool with one required `string` parameter (`input`) and one
    /// optional `integer` parameter (`count`), declaring an `object` return type.
    fn create_test_tool(id: &str) -> CustomToolConfig {
        CustomToolConfig {
            id: id.to_string(),
            name: format!("Test Tool {}", id),
            description: "A test tool".to_string(),
            category: "test".to_string(),
            parameters: vec![
                ParameterConfig {
                    name: "input".to_string(),
                    type_: "string".to_string(),
                    description: "Input parameter".to_string(),
                    required: true,
                    default: None,
                },
                ParameterConfig {
                    name: "count".to_string(),
                    type_: "integer".to_string(),
                    description: "Count parameter".to_string(),
                    required: false,
                    default: Some(json!(1)),
                },
            ],
            return_type: "object".to_string(),
            handler: "test::handler".to_string(),
        }
    }

    #[tokio::test]
    async fn test_create_executor() {
        let executor = CustomToolExecutor::new();
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_register_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        let result = executor.register_tool(tool).await;
        assert!(result.is_ok());
        assert_eq!(executor.tool_count().await, 1);
    }

    #[tokio::test]
    async fn test_register_tool_replaces_existing() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let mut updated = create_test_tool("tool1");
        updated.description = "Updated".to_string();
        executor.register_tool(updated).await.unwrap();

        assert_eq!(executor.tool_count().await, 1);
        assert_eq!(
            executor.get_tool("tool1").await.unwrap().description,
            "Updated"
        );
    }

    #[tokio::test]
    async fn test_register_tool_empty_id() {
        let executor = CustomToolExecutor::new();
        let mut tool = create_test_tool("tool1");
        tool.id = "".to_string();

        let result = executor.register_tool(tool).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_register_tool_empty_handler() {
        let executor = CustomToolExecutor::new();
        let mut tool = create_test_tool("tool1");
        tool.handler = "".to_string();

        let result = executor.register_tool(tool).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_unregister_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        executor.register_tool(tool).await.unwrap();
        assert_eq!(executor.tool_count().await, 1);

        executor.unregister_tool("tool1").await.unwrap();
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_missing_tool_is_ok() {
        let executor = CustomToolExecutor::new();
        assert!(executor.unregister_tool("missing").await.is_ok());
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_get_tool() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");

        executor.register_tool(tool.clone()).await.unwrap();
        let retrieved = executor.get_tool("tool1").await.unwrap();
        assert_eq!(retrieved.id, tool.id);
    }

    #[tokio::test]
    async fn test_get_tool_not_found() {
        let executor = CustomToolExecutor::new();
        let result = executor.get_tool("nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_list_tools() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();
        executor.register_tool(create_test_tool("tool2")).await.unwrap();

        let tools = executor.list_tools().await;
        assert_eq!(tools.len(), 2);
    }

    #[tokio::test]
    async fn test_has_tool() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        assert!(executor.has_tool("tool1").await);
        assert!(!executor.has_tool("tool2").await);
    }

    #[tokio::test]
    async fn test_clear_tools() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();
        executor.register_tool(create_test_tool("tool2")).await.unwrap();

        assert_eq!(executor.tool_count().await, 2);
        executor.clear_tools().await;
        assert_eq!(executor.tool_count().await, 0);
    }

    #[tokio::test]
    async fn test_execute_tool_valid_parameters() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "input": "test",
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tool_returns_handler_output() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let params = json!({
            "input": "test",
            "count": 5
        });

        let result = executor.execute_tool("tool1", params.clone()).await.unwrap();
        assert_eq!(result["success"], json!(true));
        assert_eq!(result["tool_id"], json!("tool1"));
        assert_eq!(result["parameters"], params);
    }

    #[tokio::test]
    async fn test_execute_tool_optional_parameter_omitted() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let result = executor
            .execute_tool("tool1", json!({ "input": "test" }))
            .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_execute_tool_missing_required_parameter() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_parameter_type() {
        let executor = CustomToolExecutor::new();
        let tool = create_test_tool("tool1");
        executor.register_tool(tool).await.unwrap();

        let params = json!({
            "input": 123,
            "count": 5
        });

        let result = executor.execute_tool("tool1", params).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_tool_invalid_optional_parameter_type() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let as_string = json!({ "input": "test", "count": "five" });
        assert!(executor.execute_tool("tool1", as_string).await.is_err());

        let as_float = json!({ "input": "test", "count": 1.5 });
        assert!(executor.execute_tool("tool1", as_float).await.is_err());
    }

    #[tokio::test]
    async fn test_execute_tool_non_object_parameters() {
        let executor = CustomToolExecutor::new();
        executor.register_tool(create_test_tool("tool1")).await.unwrap();

        let result = executor.execute_tool("tool1", json!(["test"])).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_execute_tool_not_found() {
        let executor = CustomToolExecutor::new();
        let params = json!({
            "input": "test"
        });

        let result = executor.execute_tool("nonexistent", params).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_register_multiple_tools() {
        let executor = CustomToolExecutor::new();
        let tools = vec![
            create_test_tool("tool1"),
            create_test_tool("tool2"),
            create_test_tool("tool3"),
        ];

        let result = executor.register_tools(tools).await;
        assert!(result.is_ok());
        assert_eq!(executor.tool_count().await, 3);
    }

    #[tokio::test]
    async fn test_validate_parameter_types() {
        let executor = CustomToolExecutor::new();

        assert!(executor
            .validate_parameter_type("test", "string", &json!("hello"))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "string", &json!(123))
            .is_err());

        assert!(executor
            .validate_parameter_type("test", "number", &json!(123.45))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "number", &json!("hello"))
            .is_err());

        assert!(executor
            .validate_parameter_type("test", "integer", &json!(42))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "integer", &json!(4.2))
            .is_err());

        assert!(executor
            .validate_parameter_type("test", "boolean", &json!(true))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "boolean", &json!("hello"))
            .is_err());

        assert!(executor
            .validate_parameter_type("test", "array", &json!([1, 2, 3]))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "array", &json!("hello"))
            .is_err());

        assert!(executor
            .validate_parameter_type("test", "object", &json!({}))
            .is_ok());
        assert!(executor
            .validate_parameter_type("test", "object", &json!("hello"))
            .is_err());

        // Unrecognised type names are accepted without checking.
        assert!(executor
            .validate_parameter_type("test", "custom", &json!(null))
            .is_ok());
    }
}