ai_sdk_core/tool/
tool_output.rs1use crate::error::ToolError;
2use ai_sdk_provider::language_model::ToolResultOutput;
3use ai_sdk_provider::JsonValue;
4use futures::stream::Stream;
5use std::pin::Pin;
6
7pub enum ToolOutput {
9 Value(JsonValue),
11
12 Stream(Pin<Box<dyn Stream<Item = Result<JsonValue, ToolError>> + Send>>),
15}
16
/// Controls whether `create_tool_output` reports a tool's output as an error,
/// and in which encoding.
// NOTE(review): `dead_code` is allowed presumably because some variants are not
// yet constructed outside tests — confirm, and drop the attribute once they are.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorMode {
    /// Not an error: the output goes through the custom converter (if any)
    /// or the default string → `Text` / other → `Json` mapping.
    None,
    /// Report the output as a plain-text error (`ToolResultOutput::ErrorText`).
    Text,
    /// Report the output as a structured JSON error (`ToolResultOutput::ErrorJson`).
    Json,
}
28
29#[allow(dead_code)]
39pub fn create_tool_output(
40 output: JsonValue,
41 error_mode: ErrorMode,
42 custom_converter: Option<&dyn Fn(JsonValue) -> ToolResultOutput>,
43) -> ToolResultOutput {
44 match error_mode {
46 ErrorMode::Text => {
47 return ToolResultOutput::ErrorText {
48 value: match output {
49 JsonValue::String(s) => s,
50 other => serde_json::to_string(&other)
51 .unwrap_or_else(|_| "Error serializing value".to_string()),
52 },
53 provider_metadata: None,
54 };
55 }
56 ErrorMode::Json => {
57 return ToolResultOutput::ErrorJson {
58 value: output,
59 provider_metadata: None,
60 };
61 }
62 ErrorMode::None => {}
63 }
64
65 if let Some(converter) = custom_converter {
67 return converter(output);
68 }
69
70 match output {
72 JsonValue::String(s) => ToolResultOutput::Text {
73 value: s,
74 provider_metadata: None,
75 },
76 other => ToolResultOutput::Json {
77 value: other,
78 provider_metadata: None,
79 },
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_create_tool_output_string_to_text() {
        let output = JsonValue::String("Hello, world!".to_string());
        let result = create_tool_output(output, ErrorMode::None, None);

        match result {
            ToolResultOutput::Text { value, .. } => {
                assert_eq!(value, "Hello, world!");
            }
            _ => panic!("Expected Text variant"),
        }
    }

    #[test]
    fn test_create_tool_output_object_to_json() {
        use std::collections::HashMap;
        let mut map = HashMap::new();
        map.insert(
            "result".to_string(),
            JsonValue::String("success".to_string()),
        );
        map.insert(
            "count".to_string(),
            JsonValue::Number(serde_json::Number::from(42)),
        );
        let output = JsonValue::Object(map);

        let result = create_tool_output(output, ErrorMode::None, None);

        match result {
            ToolResultOutput::Json { value, .. } => {
                if let JsonValue::Object(obj) = value {
                    assert!(obj.contains_key("result"));
                    assert!(obj.contains_key("count"));
                } else {
                    panic!("Expected Object");
                }
            }
            _ => panic!("Expected Json variant"),
        }
    }

    /// Only JSON strings map to `Text`; `null` must take the `Json` branch.
    #[test]
    fn test_create_tool_output_null_to_json() {
        let result = create_tool_output(JsonValue::Null, ErrorMode::None, None);

        assert!(
            matches!(
                result,
                ToolResultOutput::Json {
                    value: JsonValue::Null,
                    ..
                }
            ),
            "Expected Json variant wrapping Null"
        );
    }

    #[test]
    fn test_create_tool_output_error_text() {
        let output = JsonValue::String("An error occurred".to_string());
        let result = create_tool_output(output, ErrorMode::Text, None);

        match result {
            ToolResultOutput::ErrorText { value, .. } => {
                assert_eq!(value, "An error occurred");
            }
            _ => panic!("Expected ErrorText variant"),
        }
    }

    /// Non-string values in text error mode go through the serialization path
    /// and must still produce a (non-empty) `ErrorText`.
    #[test]
    fn test_create_tool_output_error_text_non_string() {
        let output = JsonValue::Number(serde_json::Number::from(500));
        let result = create_tool_output(output, ErrorMode::Text, None);

        match result {
            ToolResultOutput::ErrorText { value, .. } => {
                assert!(!value.is_empty());
            }
            _ => panic!("Expected ErrorText variant for non-string error value"),
        }
    }

    #[test]
    fn test_create_tool_output_error_json() {
        use std::collections::HashMap;
        let mut map = HashMap::new();
        map.insert(
            "error".to_string(),
            JsonValue::String("Not found".to_string()),
        );
        map.insert(
            "code".to_string(),
            JsonValue::Number(serde_json::Number::from(404)),
        );
        let output = JsonValue::Object(map);

        let result = create_tool_output(output, ErrorMode::Json, None);

        match result {
            ToolResultOutput::ErrorJson { value, .. } => {
                if let JsonValue::Object(obj) = value {
                    assert!(obj.contains_key("error"));
                    assert!(obj.contains_key("code"));
                } else {
                    panic!("Expected Object");
                }
            }
            _ => panic!("Expected ErrorJson variant"),
        }
    }

    #[test]
    fn test_create_tool_output_custom_converter() {
        let output = JsonValue::Null;

        let custom_converter = |_: JsonValue| ToolResultOutput::Text {
            value: "Custom conversion".to_string(),
            provider_metadata: None,
        };

        let result = create_tool_output(output, ErrorMode::None, Some(&custom_converter));

        match result {
            ToolResultOutput::Text { value, .. } => {
                assert_eq!(value, "Custom conversion");
            }
            _ => panic!("Expected Text variant from custom converter"),
        }
    }

    #[test]
    fn test_error_mode_takes_precedence_over_custom() {
        let output = JsonValue::String("test".to_string());

        let custom_converter = |_: JsonValue| ToolResultOutput::Text {
            value: "Should not be used".to_string(),
            provider_metadata: None,
        };

        let result = create_tool_output(output, ErrorMode::Text, Some(&custom_converter));

        match result {
            ToolResultOutput::ErrorText { value, .. } => {
                assert_eq!(value, "test");
            }
            _ => panic!("Expected ErrorText variant - error mode should take precedence"),
        }
    }

    /// The JSON error mode must also win over a supplied custom converter, and
    /// must pass the value through unchanged.
    #[test]
    fn test_error_json_takes_precedence_over_custom() {
        let output = JsonValue::String("boom".to_string());

        let custom_converter = |_: JsonValue| ToolResultOutput::Text {
            value: "Should not be used".to_string(),
            provider_metadata: None,
        };

        let result = create_tool_output(output, ErrorMode::Json, Some(&custom_converter));

        match result {
            ToolResultOutput::ErrorJson {
                value: JsonValue::String(s),
                ..
            } => {
                assert_eq!(s, "boom");
            }
            _ => panic!("Expected ErrorJson variant - error mode should take precedence"),
        }
    }
}