1use crate::error::Result;
7use crate::registry::ToolRegistry;
8use crate::metadata::ToolMetadata;
9use serde_json::Value;
10use std::collections::HashMap;
11use std::sync::Arc;
12
13pub trait ToolInvoker: Send + Sync {
18 fn invoke_tool(&self, tool_id: &str, parameters: HashMap<String, Value>) -> Result<Value>;
29}
30
31pub trait ToolDiscovery: Send + Sync {
35 fn get_all_tools(&self) -> Vec<ToolMetadata>;
37
38 fn get_tools_by_category(&self, category: &str) -> Vec<ToolMetadata>;
40
41 fn get_tools_by_server(&self, server_id: &str) -> Vec<ToolMetadata>;
43
44 fn get_tool(&self, tool_id: &str) -> Option<ToolMetadata>;
46
47 fn search_tools(&self, query: &str) -> Vec<ToolMetadata>;
49}
50
51pub struct AgentToolCapabilities {
56 registry: Arc<ToolRegistry>,
57 invoker: Arc<dyn ToolInvoker>,
58}
59
60impl AgentToolCapabilities {
61 pub fn new(registry: Arc<ToolRegistry>, invoker: Arc<dyn ToolInvoker>) -> Self {
68 Self { registry, invoker }
69 }
70
71 pub fn invoke_tool(&self, tool_id: &str, parameters: HashMap<String, Value>) -> Result<Value> {
82 self.invoker.invoke_tool(tool_id, parameters)
83 }
84
85 pub fn get_all_tools(&self) -> Vec<ToolMetadata> {
87 self.registry
88 .list_tools()
89 .into_iter()
90 .cloned()
91 .collect()
92 }
93
94 pub fn get_tools_by_category(&self, category: &str) -> Vec<ToolMetadata> {
96 self.registry
97 .list_tools_by_category(category)
98 .into_iter()
99 .cloned()
100 .collect()
101 }
102
103 pub fn get_tools_by_server(&self, server_id: &str) -> Vec<ToolMetadata> {
105 self.registry
106 .list_tools_by_server(server_id)
107 .into_iter()
108 .cloned()
109 .collect()
110 }
111
112 pub fn get_tool(&self, tool_id: &str) -> Option<ToolMetadata> {
114 self.registry.get_tool(tool_id).cloned()
115 }
116
117 pub fn search_tools(&self, query: &str) -> Vec<ToolMetadata> {
119 let query_lower = query.to_lowercase();
120 self.registry
121 .list_tools()
122 .into_iter()
123 .filter(|tool| {
124 tool.name.to_lowercase().contains(&query_lower)
125 || tool.description.to_lowercase().contains(&query_lower)
126 })
127 .cloned()
128 .collect()
129 }
130
131 pub fn tool_count(&self) -> usize {
133 self.registry.tool_count()
134 }
135
136 pub fn get_tool_documentation(&self, tool_id: &str) -> Option<String> {
138 self.registry.get_tool(tool_id).map(|tool| tool.get_documentation())
139 }
140}
141
142pub struct ToolExecutionContext {
146 pub agent_id: String,
148 pub task_id: String,
150 pub metadata: HashMap<String, Value>,
152}
153
154impl ToolExecutionContext {
155 pub fn new(agent_id: String, task_id: String) -> Self {
157 Self {
158 agent_id,
159 task_id,
160 metadata: HashMap::new(),
161 }
162 }
163
164 pub fn with_metadata(mut self, key: String, value: Value) -> Self {
166 self.metadata.insert(key, value);
167 self
168 }
169}
170
171#[derive(Debug, Clone)]
175pub struct ToolExecutionResult {
176 pub tool_id: String,
178 pub success: bool,
180 pub output: Value,
182 pub error: Option<String>,
184 pub duration_ms: u64,
186}
187
188impl ToolExecutionResult {
189 pub fn success(tool_id: String, output: Value, duration_ms: u64) -> Self {
191 Self {
192 tool_id,
193 success: true,
194 output,
195 error: None,
196 duration_ms,
197 }
198 }
199
200 pub fn failure(tool_id: String, error: String, duration_ms: u64) -> Self {
202 Self {
203 tool_id,
204 success: false,
205 output: Value::Null,
206 error: Some(error),
207 duration_ms,
208 }
209 }
210}
211
212pub struct ToolWorkflowIntegration {
217 capabilities: Arc<AgentToolCapabilities>,
218}
219
220impl ToolWorkflowIntegration {
221 pub fn new(capabilities: Arc<AgentToolCapabilities>) -> Self {
223 Self { capabilities }
224 }
225
226 pub async fn execute_tool(
228 &self,
229 context: &ToolExecutionContext,
230 tool_id: &str,
231 parameters: HashMap<String, Value>,
232 ) -> Result<ToolExecutionResult> {
233 let _ = context; let start = std::time::Instant::now();
235
236 match self.capabilities.invoke_tool(tool_id, parameters) {
237 Ok(output) => {
238 let duration_ms = start.elapsed().as_millis() as u64;
239 Ok(ToolExecutionResult::success(
240 tool_id.to_string(),
241 output,
242 duration_ms,
243 ))
244 }
245 Err(e) => {
246 let duration_ms = start.elapsed().as_millis() as u64;
247 Ok(ToolExecutionResult::failure(
248 tool_id.to_string(),
249 e.to_string(),
250 duration_ms,
251 ))
252 }
253 }
254 }
255
256 pub async fn execute_tools_sequential(
258 &self,
259 context: &ToolExecutionContext,
260 tools: Vec<(String, HashMap<String, Value>)>,
261 ) -> Result<Vec<ToolExecutionResult>> {
262 let mut results = Vec::new();
263
264 for (tool_id, parameters) in tools {
265 let result = self.execute_tool(context, &tool_id, parameters).await?;
266 results.push(result);
267 }
268
269 Ok(results)
270 }
271
272 pub async fn execute_tools_parallel(
274 &self,
275 _context: &ToolExecutionContext,
276 tools: Vec<(String, HashMap<String, Value>)>,
277 ) -> Result<Vec<ToolExecutionResult>> {
278 let mut handles = Vec::new();
279
280 for (tool_id, parameters) in tools {
281 let capabilities = self.capabilities.clone();
282
283 let handle = tokio::spawn(async move {
284 let start = std::time::Instant::now();
285
286 match capabilities.invoke_tool(&tool_id, parameters) {
287 Ok(output) => {
288 let duration_ms = start.elapsed().as_millis() as u64;
289 ToolExecutionResult::success(tool_id, output, duration_ms)
290 }
291 Err(e) => {
292 let duration_ms = start.elapsed().as_millis() as u64;
293 ToolExecutionResult::failure(tool_id, e.to_string(), duration_ms)
294 }
295 }
296 });
297
298 handles.push(handle);
299 }
300
301 let mut results = Vec::new();
302 for handle in handles {
303 match handle.await {
304 Ok(result) => results.push(result),
305 Err(e) => {
306 return Err(crate::error::Error::ExecutionError(format!(
307 "Tool execution task failed: {}",
308 e
309 )))
310 }
311 }
312 }
313
314 Ok(results)
315 }
316
317 pub fn get_available_tools(&self) -> Vec<ToolMetadata> {
319 self.capabilities.get_all_tools()
320 }
321
322 pub fn get_tool_documentation(&self, tool_id: &str) -> Option<String> {
324 self.capabilities.get_tool_documentation(tool_id)
325 }
326}
327
#[cfg(test)]
mod tests {
    use super::*;

    /// Invoker stub that always succeeds and echoes the requested tool id.
    struct MockToolInvoker;

    impl ToolInvoker for MockToolInvoker {
        fn invoke_tool(&self, tool_id: &str, _parameters: HashMap<String, Value>) -> Result<Value> {
            let payload = serde_json::json!({
                "tool_id": tool_id,
                "result": "success"
            });
            Ok(payload)
        }
    }

    /// Capabilities backed by an empty registry and the mock invoker.
    fn mock_capabilities() -> AgentToolCapabilities {
        let invoker: Arc<dyn ToolInvoker> = Arc::new(MockToolInvoker);
        AgentToolCapabilities::new(Arc::new(ToolRegistry::new()), invoker)
    }

    /// Workflow integration wired to the mock capabilities.
    fn mock_workflow() -> ToolWorkflowIntegration {
        ToolWorkflowIntegration::new(Arc::new(mock_capabilities()))
    }

    /// Context used by all tests: agent "agent-1" working on "task-1".
    fn sample_context() -> ToolExecutionContext {
        ToolExecutionContext::new("agent-1".to_string(), "task-1".to_string())
    }

    /// Two parameterless tool requests, "tool-1" and "tool-2".
    fn two_tools() -> Vec<(String, HashMap<String, Value>)> {
        ["tool-1", "tool-2"]
            .iter()
            .map(|&id| (id.to_string(), HashMap::new()))
            .collect()
    }

    #[test]
    fn test_agent_tool_capabilities_creation() {
        let capabilities = mock_capabilities();

        assert_eq!(capabilities.tool_count(), 0);
    }

    #[test]
    fn test_tool_execution_context_creation() {
        let context = sample_context();

        assert_eq!(context.agent_id, "agent-1");
        assert_eq!(context.task_id, "task-1");
        assert!(context.metadata.is_empty());
    }

    #[test]
    fn test_tool_execution_context_with_metadata() {
        let expected = serde_json::json!("value1");
        let context = sample_context().with_metadata("key1".to_string(), expected.clone());

        assert_eq!(context.metadata.len(), 1);
        assert_eq!(context.metadata.get("key1"), Some(&expected));
    }

    #[test]
    fn test_tool_execution_result_success() {
        let output = serde_json::json!({"result": "success"});
        let result = ToolExecutionResult::success("tool-1".to_string(), output, 100);

        assert_eq!(result.tool_id, "tool-1");
        assert!(result.success);
        assert_eq!(result.duration_ms, 100);
        assert!(result.error.is_none());
    }

    #[test]
    fn test_tool_execution_result_failure() {
        let message = "Tool execution failed".to_string();
        let result = ToolExecutionResult::failure("tool-1".to_string(), message, 100);

        assert_eq!(result.tool_id, "tool-1");
        assert!(!result.success);
        assert_eq!(result.duration_ms, 100);
        assert_eq!(result.error, Some("Tool execution failed".to_string()));
    }

    #[tokio::test]
    async fn test_tool_workflow_integration_execute_tool() {
        let workflow = mock_workflow();
        let context = sample_context();

        let mut params = HashMap::new();
        params.insert("param1".to_string(), serde_json::json!("value1"));

        let result = workflow
            .execute_tool(&context, "tool-1", params)
            .await
            .unwrap();

        assert_eq!(result.tool_id, "tool-1");
        assert!(result.success);
    }

    #[tokio::test]
    async fn test_tool_workflow_integration_execute_tools_sequential() {
        let workflow = mock_workflow();
        let context = sample_context();

        let results = workflow
            .execute_tools_sequential(&context, two_tools())
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results[0].success);
        assert!(results[1].success);
    }

    #[tokio::test]
    async fn test_tool_workflow_integration_execute_tools_parallel() {
        let workflow = mock_workflow();
        let context = sample_context();

        let results = workflow
            .execute_tools_parallel(&context, two_tools())
            .await
            .unwrap();

        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|r| r.success));
    }
}