ai_sdk_core/tool/
mod.rs

1mod tool_output;
2
3pub use tool_output::ToolOutput;
4
5use crate::error::ToolError;
6use ai_sdk_provider::language_model::{
7    FunctionTool, Message, ToolCallPart, ToolResultOutput, ToolResultPart,
8};
9use ai_sdk_provider::JsonValue;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::sync::Arc;
13
14/// Context information provided to tools during execution.
15///
16/// This structure contains metadata and conversation history that tools may need
17/// during execution. It provides tools with access to the tool call identifier and
18/// the full conversation context that led to the tool invocation.
19pub struct ToolContext {
20    /// Unique identifier for the tool call being executed.
21    /// This ID is used to correlate tool results back to the original invocation.
22    pub tool_call_id: String,
23    /// Conversation messages leading up to this tool call.
24    /// Includes all previous user, assistant, and tool result messages
25    /// that provide context for the current execution.
26    pub messages: Vec<Message>,
27}
28
29/// Trait that tools must implement to be available for language model invocation.
30///
31/// Tools are callable units of functionality that language models can invoke to
32/// perform actions, retrieve information, or integrate with external systems.
33/// Implementations must provide metadata (name, description, schema) and an
34/// execution method that processes the tool input and returns results.
35#[async_trait]
36pub trait Tool: Send + Sync {
37    /// Returns the name of the tool as recognized by the language model.
38    ///
39    /// This name must be unique among all available tools and is used by the
40    /// language model to reference this specific tool when making invocations.
41    /// Names should be alphanumeric with underscores and clearly describe the tool's purpose.
42    fn name(&self) -> &str;
43
44    /// Returns a human-readable description of what the tool does.
45    ///
46    /// This description is provided to the language model to help it understand
47    /// the tool's purpose and when to use it. It should be clear, concise, and
48    /// explain the primary function and use cases.
49    fn description(&self) -> &str;
50
51    /// Returns the JSON Schema that defines the structure of the tool's input parameters.
52    ///
53    /// The schema should be a valid JSON Schema document that describes all required
54    /// and optional input parameters. This schema is used by the language model to
55    /// understand what inputs the tool expects and by the runtime to validate inputs.
56    fn input_schema(&self) -> Value;
57
58    /// Executes the tool with the given input and context.
59    ///
60    /// This method is responsible for performing the actual work of the tool.
61    /// It may return either a single value or a stream of preliminary results
62    /// (useful for long-running operations that can produce incremental output).
63    ///
64    /// # Arguments
65    ///
66    /// * `input` - The parsed JSON input to the tool, must conform to the schema
67    ///   returned by `input_schema()`.
68    /// * `context` - Additional execution context including the tool call ID and
69    ///   conversation history.
70    ///
71    /// # Returns
72    ///
73    /// * `Ok(ToolOutput)` - Either a single result value or a stream of preliminary results.
74    /// * `Err(ToolError)` - An error that occurred during execution.
75    async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolOutput, ToolError>;
76
77    /// Determines whether this tool requires explicit user approval before execution.
78    ///
79    /// Implement this method to require approval for potentially sensitive operations.
80    /// The runtime will check this method before invoking the tool and will deny
81    /// execution if approval is required.
82    ///
83    /// # Arguments
84    ///
85    /// * `_input` - The input parameters that would be passed to the tool. Can be
86    ///   examined to make approval decisions based on the specific operation.
87    ///
88    /// # Returns
89    ///
90    /// * `true` if approval is required before execution.
91    /// * `false` if the tool can be executed automatically (default).
92    fn needs_approval(&self, _input: &Value) -> bool {
93        false
94    }
95
96    /// Converts the tool's raw output into a structured format for the language model.
97    ///
98    /// This method allows customization of how tool output is formatted and presented
99    /// back to the language model. The default implementation converts strings to text
100    /// output and other values to JSON output.
101    ///
102    /// # Arguments
103    ///
104    /// * `output` - The raw output value from `execute()`.
105    ///
106    /// # Returns
107    ///
108    /// A `ToolResultOutput` variant representing the formatted output that will be
109    /// returned to the language model.
110    fn to_model_output(&self, output: JsonValue) -> ToolResultOutput {
111        // Default implementation
112        match output {
113            JsonValue::String(s) => ToolResultOutput::Text {
114                value: s,
115                provider_metadata: None,
116            },
117            other => ToolResultOutput::Json {
118                value: other,
119                provider_metadata: None,
120            },
121        }
122    }
123}
124
125/// Manages the execution of tools invoked by language models.
126///
127/// The `ToolExecutor` is responsible for:
128/// - Maintaining a registry of available tools
129/// - Executing tool calls invoked by the language model
130/// - Handling approval checks for sensitive operations
131/// - Managing both single and streaming tool execution
132/// - Converting tool results to structured output formats
133///
134/// It supports both parallel execution of multiple tool calls and sequential
135/// streaming execution with preliminary result callbacks.
136pub struct ToolExecutor {
137    tools: Vec<Arc<dyn Tool>>,
138}
139
140impl ToolExecutor {
141    /// Creates a new `ToolExecutor` with the provided set of tools.
142    ///
143    /// # Arguments
144    ///
145    /// * `tools` - A vector of tool implementations to be managed by this executor.
146    ///   Tools are wrapped in `Arc` for thread-safe shared ownership.
147    pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
148        Self { tools }
149    }
150
151    /// Retrieves the tool definitions in a format suitable for language model APIs.
152    ///
153    /// This converts the internal tool representations into `FunctionTool` definitions
154    /// that can be provided to language models. These definitions include the tool's
155    /// name, description, and input schema.
156    ///
157    /// # Returns
158    ///
159    /// A vector of `FunctionTool` definitions, one for each registered tool.
160    pub fn tool_definitions(&self) -> Vec<FunctionTool> {
161        self.tools
162            .iter()
163            .map(|tool| FunctionTool {
164                name: tool.name().to_string(),
165                description: Some(tool.description().to_string()),
166                input_schema: tool.input_schema(),
167                provider_options: None,
168            })
169            .collect()
170    }
171
172    /// Executes multiple tool calls in parallel and returns their results.
173    ///
174    /// This method processes all provided tool calls concurrently, allowing for
175    /// efficient execution of multiple independent tools. Each tool call is handled
176    /// independently, with its own error handling and result conversion.
177    ///
178    /// The execution includes:
179    /// - Tool lookup by name
180    /// - Input validation and parsing
181    /// - Approval checks (tools may require user approval)
182    /// - Actual tool execution
183    /// - Result conversion to structured format
184    ///
185    /// # Arguments
186    ///
187    /// * `tool_calls` - A vector of tool calls to execute, each specifying the tool
188    ///   name, call ID, and JSON-encoded input parameters.
189    ///
190    /// # Returns
191    ///
192    /// A vector of `ToolResultPart` structures, one for each input tool call.
193    /// Results include either successful output or error information for each tool.
194    pub async fn execute_tools(&self, tool_calls: Vec<ToolCallPart>) -> Vec<ToolResultPart> {
195        let mut futures = Vec::new();
196
197        for tool_call in tool_calls {
198            let tool_opt = self.find_tool(&tool_call.tool_name);
199            let tool_call_id = tool_call.tool_call_id.clone();
200            let tool_name = tool_call.tool_name.clone();
201            let input_str = tool_call.input.clone();
202
203            let future = async move {
204                // Handle tool not found
205                let tool = match tool_opt {
206                    Some(t) => t,
207                    None => {
208                        return ToolResultPart {
209                            tool_call_id,
210                            tool_name: tool_name.clone(),
211                            output: ToolResultOutput::ErrorText {
212                                value: format!("Tool '{}' not found", tool_name),
213                                provider_metadata: None,
214                            },
215                            preliminary: None,
216                            provider_metadata: None,
217                        };
218                    }
219                };
220
221                let context = ToolContext {
222                    tool_call_id: tool_call_id.clone(),
223                    messages: vec![], // TODO: pass actual messages
224                };
225
226                // Parse input
227                let input: Value = match serde_json::from_str(&input_str) {
228                    Ok(v) => v,
229                    Err(e) => {
230                        // Return error result instead of propagating
231                        return ToolResultPart {
232                            tool_call_id,
233                            tool_name,
234                            output: ToolResultOutput::ErrorText {
235                                value: format!("Invalid input: {}", e),
236                                provider_metadata: None,
237                            },
238                            preliminary: None,
239                            provider_metadata: None,
240                        };
241                    }
242                };
243
244                // Check approval
245                if tool.needs_approval(&input) {
246                    // Return denial result
247                    return ToolResultPart {
248                        tool_call_id,
249                        tool_name,
250                        output: ToolResultOutput::ExecutionDenied {
251                            reason: Some("Execution denied by user".to_string()),
252                            provider_metadata: None,
253                        },
254                        preliminary: None,
255                        provider_metadata: None,
256                    };
257                }
258
259                // Execute tool and convert to structured output
260                let output = match tool.execute(input, &context).await {
261                    Ok(tool_output) => {
262                        // Handle both Value and Stream outputs
263                        match tool_output {
264                            ToolOutput::Value(raw_output) => {
265                                // Success - convert to structured output
266                                tool.to_model_output(raw_output)
267                            }
268                            ToolOutput::Stream(mut stream) => {
269                                // For non-streaming execute_tools, just get the last value
270                                use futures::stream::StreamExt;
271                                let mut last_output = None;
272                                while let Some(item) = stream.next().await {
273                                    match item {
274                                        Ok(output) => last_output = Some(output),
275                                        Err(e) => {
276                                            return ToolResultPart {
277                                                tool_call_id,
278                                                tool_name,
279                                                output: ToolResultOutput::ErrorText {
280                                                    value: e.to_string(),
281                                                    provider_metadata: None,
282                                                },
283                                                preliminary: None,
284                                                provider_metadata: None,
285                                            };
286                                        }
287                                    }
288                                }
289                                // Convert final output
290                                let final_value = last_output.unwrap_or(JsonValue::Null);
291                                tool.to_model_output(final_value)
292                            }
293                        }
294                    }
295                    Err(error) => {
296                        // Execution error - return error output
297                        ToolResultOutput::ErrorText {
298                            value: error.to_string(),
299                            provider_metadata: None,
300                        }
301                    }
302                };
303
304                ToolResultPart {
305                    tool_call_id,
306                    tool_name,
307                    output,
308                    preliminary: None,
309                    provider_metadata: None,
310                }
311            };
312
313            futures.push(future);
314        }
315
316        // Execute all tools in parallel
317        futures::future::join_all(futures).await
318    }
319
320    /// Executes a single tool call with streaming support and preliminary result callbacks.
321    ///
322    /// This method is specialized for tools that produce streaming output. It handles
323    /// the execution of a single tool call and invokes a callback for each preliminary
324    /// result as they arrive. This is useful for long-running operations that can
325    /// produce incremental output.
326    ///
327    /// The method:
328    /// - Finds and validates the tool
329    /// - Parses and validates input
330    /// - Checks approval requirements
331    /// - Executes the tool
332    /// - For streaming results, invokes the callback for each intermediate value
333    /// - Returns the final result after all streaming is complete
334    ///
335    /// # Arguments
336    ///
337    /// * `tool_call` - The tool call to execute, specifying the tool name, call ID,
338    ///   and JSON-encoded input parameters.
339    /// * `on_preliminary` - A callback function that is invoked for each preliminary
340    ///   result as they stream in. Useful for real-time processing or UI updates.
341    ///   The callback is not invoked for the final result.
342    ///
343    /// # Returns
344    ///
345    /// A `ToolResultPart` containing the final tool result, or an error if the
346    /// tool execution or streaming fails. The final result does not have the
347    /// `preliminary` flag set.
348    pub async fn execute_tool_with_stream<F>(
349        &self,
350        tool_call: ToolCallPart,
351        on_preliminary: F,
352    ) -> ToolResultPart
353    where
354        F: Fn(ToolResultPart) + Send,
355    {
356        let tool_call_id = tool_call.tool_call_id.clone();
357        let tool_name = tool_call.tool_name.clone();
358
359        // Find tool
360        let tool = match self.find_tool(&tool_call.tool_name) {
361            Some(t) => t,
362            None => {
363                return ToolResultPart {
364                    tool_call_id,
365                    tool_name: tool_name.clone(),
366                    output: ToolResultOutput::ErrorText {
367                        value: format!("Tool '{}' not found", tool_name),
368                        provider_metadata: None,
369                    },
370                    preliminary: None,
371                    provider_metadata: None,
372                };
373            }
374        };
375
376        let context = ToolContext {
377            tool_call_id: tool_call_id.clone(),
378            messages: vec![], // TODO: pass actual messages
379        };
380
381        // Parse input
382        let input: Value = match serde_json::from_str(&tool_call.input) {
383            Ok(v) => v,
384            Err(e) => {
385                return ToolResultPart {
386                    tool_call_id,
387                    tool_name,
388                    output: ToolResultOutput::ErrorText {
389                        value: format!("Invalid input: {}", e),
390                        provider_metadata: None,
391                    },
392                    preliminary: None,
393                    provider_metadata: None,
394                };
395            }
396        };
397
398        // Check approval
399        if tool.needs_approval(&input) {
400            return ToolResultPart {
401                tool_call_id,
402                tool_name,
403                output: ToolResultOutput::ExecutionDenied {
404                    reason: Some("Execution denied by user".to_string()),
405                    provider_metadata: None,
406                },
407                preliminary: None,
408                provider_metadata: None,
409            };
410        }
411
412        // Execute tool
413        match tool.execute(input, &context).await {
414            Ok(ToolOutput::Value(value)) => {
415                // Simple case: single result
416                let output = tool.to_model_output(value);
417                ToolResultPart {
418                    tool_call_id,
419                    tool_name,
420                    output,
421                    preliminary: None,
422                    provider_metadata: None,
423                }
424            }
425            Ok(ToolOutput::Stream(mut stream)) => {
426                use futures::stream::StreamExt;
427
428                let mut last_output = None;
429
430                // Process all stream items
431                while let Some(item) = stream.next().await {
432                    match item {
433                        Ok(output) => {
434                            // Emit preliminary result
435                            let structured = tool.to_model_output(output.clone());
436                            let preliminary_result = ToolResultPart {
437                                tool_call_id: tool_call_id.clone(),
438                                tool_name: tool_name.clone(),
439                                output: structured,
440                                preliminary: Some(true),
441                                provider_metadata: None,
442                            };
443
444                            on_preliminary(preliminary_result);
445                            last_output = Some(output);
446                        }
447                        Err(e) => {
448                            return ToolResultPart {
449                                tool_call_id,
450                                tool_name,
451                                output: ToolResultOutput::ErrorText {
452                                    value: e.to_string(),
453                                    provider_metadata: None,
454                                },
455                                preliminary: None,
456                                provider_metadata: None,
457                            };
458                        }
459                    }
460                }
461
462                // Return final result (last output without preliminary flag)
463                let final_value = last_output.unwrap_or(JsonValue::Null);
464                let final_output = tool.to_model_output(final_value);
465
466                ToolResultPart {
467                    tool_call_id,
468                    tool_name,
469                    output: final_output,
470                    preliminary: None, // Final result
471                    provider_metadata: None,
472                }
473            }
474            Err(error) => ToolResultPart {
475                tool_call_id,
476                tool_name,
477                output: ToolResultOutput::ErrorText {
478                    value: error.to_string(),
479                    provider_metadata: None,
480                },
481                preliminary: None,
482                provider_metadata: None,
483            },
484        }
485    }
486
487    /// Searches for a tool by name in the executor's tool registry.
488    ///
489    /// # Arguments
490    ///
491    /// * `name` - The name of the tool to find. Must match the value returned
492    ///   by the tool's `name()` method exactly.
493    ///
494    /// # Returns
495    ///
496    /// * `Some(Arc<dyn Tool>)` if a tool with the specified name is found.
497    /// * `None` if no tool with the specified name is available.
498    fn find_tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
499        self.tools.iter().find(|t| t.name() == name).cloned()
500    }
501
502    /// Returns a reference to the list of all available tools in this executor.
503    ///
504    /// This provides access to the complete registry of tools that can be invoked
505    /// by the language model. Useful for introspection and tool enumeration.
506    ///
507    /// # Returns
508    ///
509    /// A slice of all registered tools, allowing iteration over the available tools.
510    pub fn tools(&self) -> &[Arc<dyn Tool>] {
511        &self.tools
512    }
513}
514
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool that ignores its input and always returns a fixed string.
    struct TestTool {
        name: String,
        result: String,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            self.name.as_str()
        }

        fn description(&self) -> &str {
            "A test tool"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(
            &self,
            _input: Value,
            _context: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            let payload = JsonValue::String(self.result.clone());
            Ok(ToolOutput::Value(payload))
        }
    }

    /// Executor holding a single `TestTool` named "test" that returns "success".
    fn executor_with_test_tool() -> ToolExecutor {
        let tool: Arc<dyn Tool> = Arc::new(TestTool {
            name: "test".to_string(),
            result: "success".to_string(),
        });
        ToolExecutor::new(vec![tool])
    }

    /// Tool call with ID "call_123" and empty-object input for `tool_name`.
    fn call_to(tool_name: &str) -> ToolCallPart {
        ToolCallPart {
            tool_call_id: "call_123".to_string(),
            tool_name: tool_name.to_string(),
            input: "{}".to_string(),
            provider_executed: None,
            dynamic: None,
            provider_metadata: None,
        }
    }

    #[tokio::test]
    async fn test_tool_executor_find_tool() {
        let executor = executor_with_test_tool();

        assert!(executor.find_tool("test").is_some());
        assert!(executor.find_tool("nonexistent").is_none());
    }

    #[tokio::test]
    async fn test_tool_executor_execute() {
        let executor = executor_with_test_tool();

        let results = executor.execute_tools(vec![call_to("test")]).await;
        assert_eq!(results.len(), 1);

        let result = &results[0];
        assert_eq!(result.tool_call_id, "call_123");
        assert_eq!(result.tool_name, "test");

        // A string result must come back as the Text variant.
        if let ToolResultOutput::Text { value, .. } = &result.output {
            assert_eq!(value, "success");
        } else {
            panic!("Expected Text output variant");
        }
    }

    #[tokio::test]
    async fn test_tool_executor_tool_not_found() {
        let executor = ToolExecutor::new(Vec::new());

        let results = executor.execute_tools(vec![call_to("nonexistent")]).await;
        assert_eq!(results.len(), 1);

        // An unknown tool must be reported as the ErrorText variant.
        if let ToolResultOutput::ErrorText { value, .. } = &results[0].output {
            assert!(value.contains("not found"));
        } else {
            panic!("Expected ErrorText output variant");
        }
    }
}
618}