ai_sdk_core/tool/mod.rs
1mod tool_output;
2
3pub use tool_output::ToolOutput;
4
5use crate::error::ToolError;
6use ai_sdk_provider::language_model::{
7 FunctionTool, Message, ToolCallPart, ToolResultOutput, ToolResultPart,
8};
9use ai_sdk_provider::JsonValue;
10use async_trait::async_trait;
11use serde_json::Value;
12use std::sync::Arc;
13
14/// Context information provided to tools during execution.
15///
16/// This structure contains metadata and conversation history that tools may need
17/// during execution. It provides tools with access to the tool call identifier and
18/// the full conversation context that led to the tool invocation.
19pub struct ToolContext {
20 /// Unique identifier for the tool call being executed.
21 /// This ID is used to correlate tool results back to the original invocation.
22 pub tool_call_id: String,
23 /// Conversation messages leading up to this tool call.
24 /// Includes all previous user, assistant, and tool result messages
25 /// that provide context for the current execution.
26 pub messages: Vec<Message>,
27}
28
29/// Trait that tools must implement to be available for language model invocation.
30///
31/// Tools are callable units of functionality that language models can invoke to
32/// perform actions, retrieve information, or integrate with external systems.
33/// Implementations must provide metadata (name, description, schema) and an
34/// execution method that processes the tool input and returns results.
35#[async_trait]
36pub trait Tool: Send + Sync {
37 /// Returns the name of the tool as recognized by the language model.
38 ///
39 /// This name must be unique among all available tools and is used by the
40 /// language model to reference this specific tool when making invocations.
41 /// Names should be alphanumeric with underscores and clearly describe the tool's purpose.
42 fn name(&self) -> &str;
43
44 /// Returns a human-readable description of what the tool does.
45 ///
46 /// This description is provided to the language model to help it understand
47 /// the tool's purpose and when to use it. It should be clear, concise, and
48 /// explain the primary function and use cases.
49 fn description(&self) -> &str;
50
51 /// Returns the JSON Schema that defines the structure of the tool's input parameters.
52 ///
53 /// The schema should be a valid JSON Schema document that describes all required
54 /// and optional input parameters. This schema is used by the language model to
55 /// understand what inputs the tool expects and by the runtime to validate inputs.
56 fn input_schema(&self) -> Value;
57
58 /// Executes the tool with the given input and context.
59 ///
60 /// This method is responsible for performing the actual work of the tool.
61 /// It may return either a single value or a stream of preliminary results
62 /// (useful for long-running operations that can produce incremental output).
63 ///
64 /// # Arguments
65 ///
66 /// * `input` - The parsed JSON input to the tool, must conform to the schema
67 /// returned by `input_schema()`.
68 /// * `context` - Additional execution context including the tool call ID and
69 /// conversation history.
70 ///
71 /// # Returns
72 ///
73 /// * `Ok(ToolOutput)` - Either a single result value or a stream of preliminary results.
74 /// * `Err(ToolError)` - An error that occurred during execution.
75 async fn execute(&self, input: Value, context: &ToolContext) -> Result<ToolOutput, ToolError>;
76
77 /// Determines whether this tool requires explicit user approval before execution.
78 ///
79 /// Implement this method to require approval for potentially sensitive operations.
80 /// The runtime will check this method before invoking the tool and will deny
81 /// execution if approval is required.
82 ///
83 /// # Arguments
84 ///
85 /// * `_input` - The input parameters that would be passed to the tool. Can be
86 /// examined to make approval decisions based on the specific operation.
87 ///
88 /// # Returns
89 ///
90 /// * `true` if approval is required before execution.
91 /// * `false` if the tool can be executed automatically (default).
92 fn needs_approval(&self, _input: &Value) -> bool {
93 false
94 }
95
96 /// Converts the tool's raw output into a structured format for the language model.
97 ///
98 /// This method allows customization of how tool output is formatted and presented
99 /// back to the language model. The default implementation converts strings to text
100 /// output and other values to JSON output.
101 ///
102 /// # Arguments
103 ///
104 /// * `output` - The raw output value from `execute()`.
105 ///
106 /// # Returns
107 ///
108 /// A `ToolResultOutput` variant representing the formatted output that will be
109 /// returned to the language model.
110 fn to_model_output(&self, output: JsonValue) -> ToolResultOutput {
111 // Default implementation
112 match output {
113 JsonValue::String(s) => ToolResultOutput::Text {
114 value: s,
115 provider_metadata: None,
116 },
117 other => ToolResultOutput::Json {
118 value: other,
119 provider_metadata: None,
120 },
121 }
122 }
123}
124
125/// Manages the execution of tools invoked by language models.
126///
127/// The `ToolExecutor` is responsible for:
128/// - Maintaining a registry of available tools
129/// - Executing tool calls invoked by the language model
130/// - Handling approval checks for sensitive operations
131/// - Managing both single and streaming tool execution
132/// - Converting tool results to structured output formats
133///
134/// It supports both parallel execution of multiple tool calls and sequential
135/// streaming execution with preliminary result callbacks.
136pub struct ToolExecutor {
137 tools: Vec<Arc<dyn Tool>>,
138}
139
140impl ToolExecutor {
141 /// Creates a new `ToolExecutor` with the provided set of tools.
142 ///
143 /// # Arguments
144 ///
145 /// * `tools` - A vector of tool implementations to be managed by this executor.
146 /// Tools are wrapped in `Arc` for thread-safe shared ownership.
147 pub fn new(tools: Vec<Arc<dyn Tool>>) -> Self {
148 Self { tools }
149 }
150
151 /// Retrieves the tool definitions in a format suitable for language model APIs.
152 ///
153 /// This converts the internal tool representations into `FunctionTool` definitions
154 /// that can be provided to language models. These definitions include the tool's
155 /// name, description, and input schema.
156 ///
157 /// # Returns
158 ///
159 /// A vector of `FunctionTool` definitions, one for each registered tool.
160 pub fn tool_definitions(&self) -> Vec<FunctionTool> {
161 self.tools
162 .iter()
163 .map(|tool| FunctionTool {
164 name: tool.name().to_string(),
165 description: Some(tool.description().to_string()),
166 input_schema: tool.input_schema(),
167 provider_options: None,
168 })
169 .collect()
170 }
171
172 /// Executes multiple tool calls in parallel and returns their results.
173 ///
174 /// This method processes all provided tool calls concurrently, allowing for
175 /// efficient execution of multiple independent tools. Each tool call is handled
176 /// independently, with its own error handling and result conversion.
177 ///
178 /// The execution includes:
179 /// - Tool lookup by name
180 /// - Input validation and parsing
181 /// - Approval checks (tools may require user approval)
182 /// - Actual tool execution
183 /// - Result conversion to structured format
184 ///
185 /// # Arguments
186 ///
187 /// * `tool_calls` - A vector of tool calls to execute, each specifying the tool
188 /// name, call ID, and JSON-encoded input parameters.
189 ///
190 /// # Returns
191 ///
192 /// A vector of `ToolResultPart` structures, one for each input tool call.
193 /// Results include either successful output or error information for each tool.
194 pub async fn execute_tools(&self, tool_calls: Vec<ToolCallPart>) -> Vec<ToolResultPart> {
195 let mut futures = Vec::new();
196
197 for tool_call in tool_calls {
198 let tool_opt = self.find_tool(&tool_call.tool_name);
199 let tool_call_id = tool_call.tool_call_id.clone();
200 let tool_name = tool_call.tool_name.clone();
201 let input_str = tool_call.input.clone();
202
203 let future = async move {
204 // Handle tool not found
205 let tool = match tool_opt {
206 Some(t) => t,
207 None => {
208 return ToolResultPart {
209 tool_call_id,
210 tool_name: tool_name.clone(),
211 output: ToolResultOutput::ErrorText {
212 value: format!("Tool '{}' not found", tool_name),
213 provider_metadata: None,
214 },
215 preliminary: None,
216 provider_metadata: None,
217 };
218 }
219 };
220
221 let context = ToolContext {
222 tool_call_id: tool_call_id.clone(),
223 messages: vec![], // TODO: pass actual messages
224 };
225
226 // Parse input
227 let input: Value = match serde_json::from_str(&input_str) {
228 Ok(v) => v,
229 Err(e) => {
230 // Return error result instead of propagating
231 return ToolResultPart {
232 tool_call_id,
233 tool_name,
234 output: ToolResultOutput::ErrorText {
235 value: format!("Invalid input: {}", e),
236 provider_metadata: None,
237 },
238 preliminary: None,
239 provider_metadata: None,
240 };
241 }
242 };
243
244 // Check approval
245 if tool.needs_approval(&input) {
246 // Return denial result
247 return ToolResultPart {
248 tool_call_id,
249 tool_name,
250 output: ToolResultOutput::ExecutionDenied {
251 reason: Some("Execution denied by user".to_string()),
252 provider_metadata: None,
253 },
254 preliminary: None,
255 provider_metadata: None,
256 };
257 }
258
259 // Execute tool and convert to structured output
260 let output = match tool.execute(input, &context).await {
261 Ok(tool_output) => {
262 // Handle both Value and Stream outputs
263 match tool_output {
264 ToolOutput::Value(raw_output) => {
265 // Success - convert to structured output
266 tool.to_model_output(raw_output)
267 }
268 ToolOutput::Stream(mut stream) => {
269 // For non-streaming execute_tools, just get the last value
270 use futures::stream::StreamExt;
271 let mut last_output = None;
272 while let Some(item) = stream.next().await {
273 match item {
274 Ok(output) => last_output = Some(output),
275 Err(e) => {
276 return ToolResultPart {
277 tool_call_id,
278 tool_name,
279 output: ToolResultOutput::ErrorText {
280 value: e.to_string(),
281 provider_metadata: None,
282 },
283 preliminary: None,
284 provider_metadata: None,
285 };
286 }
287 }
288 }
289 // Convert final output
290 let final_value = last_output.unwrap_or(JsonValue::Null);
291 tool.to_model_output(final_value)
292 }
293 }
294 }
295 Err(error) => {
296 // Execution error - return error output
297 ToolResultOutput::ErrorText {
298 value: error.to_string(),
299 provider_metadata: None,
300 }
301 }
302 };
303
304 ToolResultPart {
305 tool_call_id,
306 tool_name,
307 output,
308 preliminary: None,
309 provider_metadata: None,
310 }
311 };
312
313 futures.push(future);
314 }
315
316 // Execute all tools in parallel
317 futures::future::join_all(futures).await
318 }
319
320 /// Executes a single tool call with streaming support and preliminary result callbacks.
321 ///
322 /// This method is specialized for tools that produce streaming output. It handles
323 /// the execution of a single tool call and invokes a callback for each preliminary
324 /// result as they arrive. This is useful for long-running operations that can
325 /// produce incremental output.
326 ///
327 /// The method:
328 /// - Finds and validates the tool
329 /// - Parses and validates input
330 /// - Checks approval requirements
331 /// - Executes the tool
332 /// - For streaming results, invokes the callback for each intermediate value
333 /// - Returns the final result after all streaming is complete
334 ///
335 /// # Arguments
336 ///
337 /// * `tool_call` - The tool call to execute, specifying the tool name, call ID,
338 /// and JSON-encoded input parameters.
339 /// * `on_preliminary` - A callback function that is invoked for each preliminary
340 /// result as they stream in. Useful for real-time processing or UI updates.
341 /// The callback is not invoked for the final result.
342 ///
343 /// # Returns
344 ///
345 /// A `ToolResultPart` containing the final tool result, or an error if the
346 /// tool execution or streaming fails. The final result does not have the
347 /// `preliminary` flag set.
348 pub async fn execute_tool_with_stream<F>(
349 &self,
350 tool_call: ToolCallPart,
351 on_preliminary: F,
352 ) -> ToolResultPart
353 where
354 F: Fn(ToolResultPart) + Send,
355 {
356 let tool_call_id = tool_call.tool_call_id.clone();
357 let tool_name = tool_call.tool_name.clone();
358
359 // Find tool
360 let tool = match self.find_tool(&tool_call.tool_name) {
361 Some(t) => t,
362 None => {
363 return ToolResultPart {
364 tool_call_id,
365 tool_name: tool_name.clone(),
366 output: ToolResultOutput::ErrorText {
367 value: format!("Tool '{}' not found", tool_name),
368 provider_metadata: None,
369 },
370 preliminary: None,
371 provider_metadata: None,
372 };
373 }
374 };
375
376 let context = ToolContext {
377 tool_call_id: tool_call_id.clone(),
378 messages: vec![], // TODO: pass actual messages
379 };
380
381 // Parse input
382 let input: Value = match serde_json::from_str(&tool_call.input) {
383 Ok(v) => v,
384 Err(e) => {
385 return ToolResultPart {
386 tool_call_id,
387 tool_name,
388 output: ToolResultOutput::ErrorText {
389 value: format!("Invalid input: {}", e),
390 provider_metadata: None,
391 },
392 preliminary: None,
393 provider_metadata: None,
394 };
395 }
396 };
397
398 // Check approval
399 if tool.needs_approval(&input) {
400 return ToolResultPart {
401 tool_call_id,
402 tool_name,
403 output: ToolResultOutput::ExecutionDenied {
404 reason: Some("Execution denied by user".to_string()),
405 provider_metadata: None,
406 },
407 preliminary: None,
408 provider_metadata: None,
409 };
410 }
411
412 // Execute tool
413 match tool.execute(input, &context).await {
414 Ok(ToolOutput::Value(value)) => {
415 // Simple case: single result
416 let output = tool.to_model_output(value);
417 ToolResultPart {
418 tool_call_id,
419 tool_name,
420 output,
421 preliminary: None,
422 provider_metadata: None,
423 }
424 }
425 Ok(ToolOutput::Stream(mut stream)) => {
426 use futures::stream::StreamExt;
427
428 let mut last_output = None;
429
430 // Process all stream items
431 while let Some(item) = stream.next().await {
432 match item {
433 Ok(output) => {
434 // Emit preliminary result
435 let structured = tool.to_model_output(output.clone());
436 let preliminary_result = ToolResultPart {
437 tool_call_id: tool_call_id.clone(),
438 tool_name: tool_name.clone(),
439 output: structured,
440 preliminary: Some(true),
441 provider_metadata: None,
442 };
443
444 on_preliminary(preliminary_result);
445 last_output = Some(output);
446 }
447 Err(e) => {
448 return ToolResultPart {
449 tool_call_id,
450 tool_name,
451 output: ToolResultOutput::ErrorText {
452 value: e.to_string(),
453 provider_metadata: None,
454 },
455 preliminary: None,
456 provider_metadata: None,
457 };
458 }
459 }
460 }
461
462 // Return final result (last output without preliminary flag)
463 let final_value = last_output.unwrap_or(JsonValue::Null);
464 let final_output = tool.to_model_output(final_value);
465
466 ToolResultPart {
467 tool_call_id,
468 tool_name,
469 output: final_output,
470 preliminary: None, // Final result
471 provider_metadata: None,
472 }
473 }
474 Err(error) => ToolResultPart {
475 tool_call_id,
476 tool_name,
477 output: ToolResultOutput::ErrorText {
478 value: error.to_string(),
479 provider_metadata: None,
480 },
481 preliminary: None,
482 provider_metadata: None,
483 },
484 }
485 }
486
487 /// Searches for a tool by name in the executor's tool registry.
488 ///
489 /// # Arguments
490 ///
491 /// * `name` - The name of the tool to find. Must match the value returned
492 /// by the tool's `name()` method exactly.
493 ///
494 /// # Returns
495 ///
496 /// * `Some(Arc<dyn Tool>)` if a tool with the specified name is found.
497 /// * `None` if no tool with the specified name is available.
498 fn find_tool(&self, name: &str) -> Option<Arc<dyn Tool>> {
499 self.tools.iter().find(|t| t.name() == name).cloned()
500 }
501
502 /// Returns a reference to the list of all available tools in this executor.
503 ///
504 /// This provides access to the complete registry of tools that can be invoked
505 /// by the language model. Useful for introspection and tool enumeration.
506 ///
507 /// # Returns
508 ///
509 /// A slice of all registered tools, allowing iteration over the available tools.
510 pub fn tools(&self) -> &[Arc<dyn Tool>] {
511 &self.tools
512 }
513}
514
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Minimal tool that returns a fixed string and can optionally demand approval.
    struct TestTool {
        name: String,
        result: String,
        requires_approval: bool,
    }

    #[async_trait]
    impl Tool for TestTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "A test tool"
        }

        fn input_schema(&self) -> Value {
            serde_json::json!({
                "type": "object",
                "properties": {}
            })
        }

        async fn execute(
            &self,
            _input: Value,
            _context: &ToolContext,
        ) -> Result<ToolOutput, ToolError> {
            Ok(ToolOutput::Value(JsonValue::String(self.result.clone())))
        }

        fn needs_approval(&self, _input: &Value) -> bool {
            self.requires_approval
        }
    }

    /// Builds a registered-ready tool that returns `result` without approval.
    fn test_tool(name: &str, result: &str) -> Arc<dyn Tool> {
        Arc::new(TestTool {
            name: name.to_string(),
            result: result.to_string(),
            requires_approval: false,
        })
    }

    /// Builds a tool call with the given ID, tool name and raw JSON input.
    fn tool_call(id: &str, name: &str, input: &str) -> ToolCallPart {
        ToolCallPart {
            tool_call_id: id.to_string(),
            tool_name: name.to_string(),
            input: input.to_string(),
            provider_executed: None,
            dynamic: None,
            provider_metadata: None,
        }
    }

    /// Extracts the text of a `Text` output, panicking on any other variant.
    fn text_of(output: &ToolResultOutput) -> &str {
        match output {
            ToolResultOutput::Text { value, .. } => value.as_str(),
            _ => panic!("Expected Text output variant"),
        }
    }

    #[test]
    fn test_tool_executor_find_tool() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);
        assert!(executor.find_tool("test").is_some());
        assert!(executor.find_tool("nonexistent").is_none());
    }

    #[test]
    fn test_tool_definitions() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);
        let definitions = executor.tool_definitions();
        assert_eq!(definitions.len(), 1);
        assert_eq!(definitions[0].name, "test");
        assert_eq!(definitions[0].description.as_deref(), Some("A test tool"));
        assert_eq!(
            definitions[0].input_schema,
            serde_json::json!({ "type": "object", "properties": {} })
        );
    }

    #[tokio::test]
    async fn test_tool_executor_execute() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);

        let results = executor
            .execute_tools(vec![tool_call("call_123", "test", "{}")])
            .await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].tool_call_id, "call_123");
        assert_eq!(results[0].tool_name, "test");
        assert!(results[0].preliminary.is_none());
        assert_eq!(text_of(&results[0].output), "success");
    }

    #[tokio::test]
    async fn test_tool_executor_tool_not_found() {
        let executor = ToolExecutor::new(vec![]);

        let results = executor
            .execute_tools(vec![tool_call("call_123", "nonexistent", "{}")])
            .await;
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].tool_name, "nonexistent");

        match &results[0].output {
            ToolResultOutput::ErrorText { value, .. } => {
                assert!(value.contains("not found"));
            }
            _ => panic!("Expected ErrorText output variant"),
        }
    }

    #[tokio::test]
    async fn test_tool_executor_invalid_input() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);

        let results = executor
            .execute_tools(vec![tool_call("call_1", "test", "not json")])
            .await;
        match &results[0].output {
            ToolResultOutput::ErrorText { value, .. } => {
                assert!(value.starts_with("Invalid input"));
            }
            _ => panic!("Expected ErrorText output variant"),
        }
    }

    #[tokio::test]
    async fn test_tool_executor_denies_when_approval_required() {
        let guarded: Arc<dyn Tool> = Arc::new(TestTool {
            name: "guarded".to_string(),
            result: "ran".to_string(),
            requires_approval: true,
        });
        let executor = ToolExecutor::new(vec![guarded]);

        let results = executor
            .execute_tools(vec![tool_call("call_1", "guarded", "{}")])
            .await;
        match &results[0].output {
            ToolResultOutput::ExecutionDenied { reason, .. } => {
                assert_eq!(reason.as_deref(), Some("Execution denied by user"));
            }
            _ => panic!("Expected ExecutionDenied output variant"),
        }
    }

    #[tokio::test]
    async fn test_execute_tools_preserves_call_order() {
        let executor = ToolExecutor::new(vec![test_tool("a", "A"), test_tool("b", "B")]);

        let results = executor
            .execute_tools(vec![
                tool_call("call_1", "b", "{}"),
                tool_call("call_2", "missing", "{}"),
                tool_call("call_3", "a", "{}"),
            ])
            .await;
        let ids: Vec<&str> = results.iter().map(|r| r.tool_call_id.as_str()).collect();
        assert_eq!(ids, ["call_1", "call_2", "call_3"]);
        assert_eq!(text_of(&results[0].output), "B");
        assert!(matches!(
            results[1].output,
            ToolResultOutput::ErrorText { .. }
        ));
        assert_eq!(text_of(&results[2].output), "A");
    }

    #[tokio::test]
    async fn test_execute_tools_empty_batch() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);
        assert!(executor.execute_tools(vec![]).await.is_empty());
    }

    #[tokio::test]
    async fn test_execute_tool_with_stream_single_value_skips_callback() {
        let executor = ToolExecutor::new(vec![test_tool("test", "success")]);
        let calls = AtomicUsize::new(0);

        let result = executor
            .execute_tool_with_stream(tool_call("call_1", "test", "{}"), |_| {
                calls.fetch_add(1, Ordering::SeqCst);
            })
            .await;
        assert_eq!(calls.load(Ordering::SeqCst), 0);
        assert_eq!(result.tool_call_id, "call_1");
        assert!(result.preliminary.is_none());
        assert_eq!(text_of(&result.output), "success");
    }
}