ai_sdk_core/tool/
tool_output.rs

1use crate::error::ToolError;
2use ai_sdk_provider::language_model::ToolResultOutput;
3use ai_sdk_provider::JsonValue;
4use futures::stream::Stream;
5use std::pin::Pin;
6
7/// Tool execution output - either a single value or a stream of values
8pub enum ToolOutput {
9    /// Single final result
10    Value(JsonValue),
11
12    /// Stream of preliminary results, ending with final result
13    /// Each item in the stream represents a preliminary update
14    Stream(Pin<Box<dyn Stream<Item = Result<JsonValue, ToolError>> + Send>>),
15}
16
/// Selects how a raw tool value is classified when it is converted into a
/// structured tool result.
// NOTE(review): the dead-code lint is silenced here; presumably not every
// variant is constructed by non-test code yet — confirm before removing.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorMode {
    /// The value is a regular, non-error result.
    None,
    /// The value is reported as an error in text form.
    Text,
    /// The value is reported as an error in JSON form.
    Json,
}
28
29/// Convert raw tool output to structured ToolResultOutput
30///
31/// # Arguments
32/// * `output` - The raw JSON output from tool execution
33/// * `error_mode` - How to handle the output (normal, error text, or error JSON)
34/// * `custom_converter` - Optional custom conversion function
35///
36/// # Returns
37/// A structured ToolResultOutput enum variant
38#[allow(dead_code)]
39pub fn create_tool_output(
40    output: JsonValue,
41    error_mode: ErrorMode,
42    custom_converter: Option<&dyn Fn(JsonValue) -> ToolResultOutput>,
43) -> ToolResultOutput {
44    // Handle errors first
45    match error_mode {
46        ErrorMode::Text => {
47            return ToolResultOutput::ErrorText {
48                value: match output {
49                    JsonValue::String(s) => s,
50                    other => serde_json::to_string(&other)
51                        .unwrap_or_else(|_| "Error serializing value".to_string()),
52                },
53                provider_metadata: None,
54            };
55        }
56        ErrorMode::Json => {
57            return ToolResultOutput::ErrorJson {
58                value: output,
59                provider_metadata: None,
60            };
61        }
62        ErrorMode::None => {}
63    }
64
65    // Custom conversion via hook
66    if let Some(converter) = custom_converter {
67        return converter(output);
68    }
69
70    // Default conversion
71    match output {
72        JsonValue::String(s) => ToolResultOutput::Text {
73            value: s,
74            provider_metadata: None,
75        },
76        other => ToolResultOutput::Json {
77            value: other,
78            provider_metadata: None,
79        },
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Build a `JsonValue::Object` from `(key, value)` pairs.
    fn object_of(entries: Vec<(&str, JsonValue)>) -> JsonValue {
        let fields: HashMap<String, JsonValue> = entries
            .into_iter()
            .map(|(key, value)| (key.to_string(), value))
            .collect();
        JsonValue::Object(fields)
    }

    /// Panic unless `value` is an object that contains every key in `keys`.
    fn assert_object_has_keys(value: &JsonValue, keys: &[&str]) {
        let fields = match value {
            JsonValue::Object(fields) => fields,
            _ => panic!("Expected Object"),
        };
        for key in keys {
            assert!(fields.contains_key(*key));
        }
    }

    /// A string with no error mode and no converter becomes `Text`.
    #[test]
    fn test_create_tool_output_string_to_text() {
        let raw = JsonValue::String("Hello, world!".to_string());

        let converted = create_tool_output(raw, ErrorMode::None, None);

        if let ToolResultOutput::Text { value, .. } = converted {
            assert_eq!(value, "Hello, world!");
        } else {
            panic!("Expected Text variant");
        }
    }

    /// An object with no error mode and no converter becomes `Json`.
    #[test]
    fn test_create_tool_output_object_to_json() {
        let raw = object_of(vec![
            ("result", JsonValue::String("success".to_string())),
            ("count", JsonValue::Number(serde_json::Number::from(42))),
        ]);

        let converted = create_tool_output(raw, ErrorMode::None, None);

        if let ToolResultOutput::Json { value, .. } = converted {
            // The payload must still be the object that was passed in.
            assert_object_has_keys(&value, &["result", "count"]);
        } else {
            panic!("Expected Json variant");
        }
    }

    /// A string under `ErrorMode::Text` is carried verbatim in `ErrorText`.
    #[test]
    fn test_create_tool_output_error_text() {
        let raw = JsonValue::String("An error occurred".to_string());

        let converted = create_tool_output(raw, ErrorMode::Text, None);

        if let ToolResultOutput::ErrorText { value, .. } = converted {
            assert_eq!(value, "An error occurred");
        } else {
            panic!("Expected ErrorText variant");
        }
    }

    /// An object under `ErrorMode::Json` is wrapped in `ErrorJson`.
    #[test]
    fn test_create_tool_output_error_json() {
        let raw = object_of(vec![
            ("error", JsonValue::String("Not found".to_string())),
            ("code", JsonValue::Number(serde_json::Number::from(404))),
        ]);

        let converted = create_tool_output(raw, ErrorMode::Json, None);

        if let ToolResultOutput::ErrorJson { value, .. } = converted {
            assert_object_has_keys(&value, &["error", "code"]);
        } else {
            panic!("Expected ErrorJson variant");
        }
    }

    /// With no error mode, the custom converter decides the result.
    #[test]
    fn test_create_tool_output_custom_converter() {
        let always_custom = |_: JsonValue| ToolResultOutput::Text {
            value: "Custom conversion".to_string(),
            provider_metadata: None,
        };

        let converted =
            create_tool_output(JsonValue::Null, ErrorMode::None, Some(&always_custom));

        if let ToolResultOutput::Text { value, .. } = converted {
            assert_eq!(value, "Custom conversion");
        } else {
            panic!("Expected Text variant from custom converter");
        }
    }

    /// An error mode bypasses the custom converter.
    #[test]
    fn test_error_mode_takes_precedence_over_custom() {
        let must_not_run = |_: JsonValue| ToolResultOutput::Text {
            value: "Should not be used".to_string(),
            provider_metadata: None,
        };
        let raw = JsonValue::String("test".to_string());

        let converted = create_tool_output(raw, ErrorMode::Text, Some(&must_not_run));

        if let ToolResultOutput::ErrorText { value, .. } = converted {
            assert_eq!(value, "test");
        } else {
            panic!("Expected ErrorText variant - error mode should take precedence");
        }
    }
}