ai_agent/services/
tool_execution.rs1use crate::tools::ToolDefinition;
7use crate::types::Message;
8
/// Hook-timing display threshold, in milliseconds.
///
/// How the duration is compared against this value is decided by the caller.
pub const HOOK_TIMING_DISPLAY_THRESHOLD_MS: u64 = 500;

/// Slow-phase logging threshold, in milliseconds.
///
/// How the duration is compared against this value is decided by the caller.
pub const SLOW_PHASE_LOG_THRESHOLD_MS: u64 = 2_000;
/// Classifies an error into a short errno-style label.
///
/// # Search
///
/// `error` and every error reachable through [`std::error::Error::source`] are
/// searched, in order, for an [`std::io::Error`]. Wrapped I/O failures are
/// classified the same way as bare ones.
///
/// # Labels
///
/// The first I/O error found is labelled as follows:
///
/// * `"Error:<ERRNO>"` (e.g. `"Error:ENOENT"`) when its [`std::io::ErrorKind`]
///   has a symbolic mapping. This matches the labels produced by
///   [`classify_tool_error_from_message`].
/// * `"Error:<code>"` when the kind is unmapped but a raw OS error code is present.
/// * `"Error"` otherwise.
///
/// Non-I/O errors are labelled `"Error"`.
///
/// # Limitations
///
/// The concrete type of a `dyn Error` cannot be recovered at runtime:
/// `std::any::type_name_of_val` on a trait object reports only the trait-object
/// type. For that reason no type-name-based label is attempted.
pub fn classify_tool_error(error: &(dyn std::error::Error + 'static)) -> String {
    let mut current = Some(error);
    while let Some(err) = current {
        if let Some(io_err) = err.downcast_ref::<std::io::Error>() {
            return classify_io_error(io_err);
        }
        current = err.source();
    }
    "Error".to_string()
}

/// Maps an I/O error to an errno-style label.
///
/// The portable `ErrorKind` is preferred over the platform-specific raw OS code.
/// `ErrorKind` is coarser than errno (several errno values can share one kind),
/// so the label names a representative errno for that kind.
fn classify_io_error(error: &std::io::Error) -> String {
    use std::io::ErrorKind;

    let symbolic = match error.kind() {
        ErrorKind::NotFound => Some("ENOENT"),
        ErrorKind::PermissionDenied => Some("EACCES"),
        ErrorKind::TimedOut => Some("ETIMEDOUT"),
        ErrorKind::AlreadyExists => Some("EEXIST"),
        ErrorKind::ConnectionRefused => Some("ECONNREFUSED"),
        ErrorKind::ConnectionReset => Some("ECONNRESET"),
        ErrorKind::BrokenPipe => Some("EPIPE"),
        ErrorKind::AddrInUse => Some("EADDRINUSE"),
        ErrorKind::Interrupted => Some("EINTR"),
        _ => None,
    };

    match (symbolic, error.raw_os_error()) {
        (Some(name), _) => format!("Error:{}", name),
        (None, Some(code)) => format!("Error:{}", code),
        (None, None) => "Error".to_string(),
    }
}
55
/// Classifies a free-form error message into an errno-style label.
///
/// Matching is case-insensitive and checked in this order (first match wins):
///
/// * `"Error:ENOENT"`: "enoent", "file not found", or "no such file or directory"
///   (the OS text that std I/O errors display, e.g. "No such file or directory (os error 2)").
/// * `"Error:EACCES"`: "eacces" or "permission denied".
/// * `"Error:ETIMEDOUT"`: "etimedout", "timeout", or "timed out"
///   (e.g. "Connection timed out (os error 110)").
///
/// Any other message, including an empty one, yields `"Error"`.
pub fn classify_tool_error_from_message(message: &str) -> String {
    /// True if `haystack` contains any of `needles`.
    fn mentions_any(haystack: &str, needles: &[&str]) -> bool {
        needles.iter().any(|&needle| haystack.contains(needle))
    }

    let lower = message.to_lowercase();

    if mentions_any(&lower, &["enoent", "file not found", "no such file or directory"]) {
        return "Error:ENOENT".to_string();
    }
    if mentions_any(&lower, &["eacces", "permission denied"]) {
        return "Error:EACCES".to_string();
    }
    // "etimedout" does not contain "timeout", so it is listed explicitly.
    if mentions_any(&lower, &["etimedout", "timeout", "timed out"]) {
        return "Error:ETIMEDOUT".to_string();
    }

    "Error".to_string()
}
74
75pub fn build_schema_not_sent_hint(
79 tool_name: &str,
80 messages: &[Message],
81 tools: &[ToolDefinition],
82) -> Option<String> {
83 let tool_available = tools.iter().any(|t| t.name == tool_name);
85 if tool_available {
86 return None;
87 }
88
89 let discovered_in_messages = messages.iter().any(|m| m.content.contains(tool_name));
91
92 if discovered_in_messages {
93 return Some(format!(
94 "\n\nThis tool's schema was not sent to the API — it was not in the discovered-tool set derived from message history. \
95 Without the schema in your prompt, typed parameters (arrays, numbers, booleans) get emitted as strings and the client-side parser rejects them. \
96 Load the tool first: call tool_search with query \"select:{}\", then retry this call.",
97 tool_name
98 ));
99 }
100
101 None
102}
103
104#[derive(Debug, Clone)]
106pub struct MessageUpdateLazy {
107 pub message: Message,
108 pub context_modifier: Option<ContextModifier>,
109}
110
/// A context modifier, identified by the tool use it is associated with.
#[derive(Clone, Debug)]
pub struct ContextModifier {
    /// ID of the associated tool use.
    pub tool_use_id: String,
}
116
117#[derive(Debug, Clone)]
119pub struct ToolProgress {
120 pub tool_use_id: String,
121 pub data: serde_json::Value,
122}
123
/// Failure modes of a tool execution.
///
/// The `Display` output is a short prefix describing the variant, followed by
/// the variant's message where it has one.
#[derive(Clone, Debug)]
pub enum ToolExecutionError {
    /// No tool with the given name is available; carries the requested name.
    ToolNotFound(String),
    /// The tool input failed validation; carries the validation message.
    InputValidation(String),
    /// Permission was denied; carries the explanation.
    PermissionDenied(String),
    /// Calling the tool failed; carries the failure message.
    ExecutionFailed(String),
    /// Execution was aborted.
    Aborted,
}

impl std::fmt::Display for ToolExecutionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (prefix, detail) = match self {
            Self::ToolNotFound(name) => ("No such tool available: ", name.as_str()),
            Self::InputValidation(msg) => ("InputValidationError: ", msg.as_str()),
            Self::PermissionDenied(msg) => ("Permission denied: ", msg.as_str()),
            Self::ExecutionFailed(msg) => ("Error calling tool: ", msg.as_str()),
            Self::Aborted => return f.write_str("Tool execution was aborted"),
        };
        f.write_str(prefix)?;
        f.write_str(detail)
    }
}

impl std::error::Error for ToolExecutionError {}
152
153pub fn create_tool_error_message(tool_use_id: &str, error: &str, is_error: bool) -> Message {
155 Message {
156 role: crate::types::MessageRole::Tool,
157 content: format!("<tool_use_error>{}</tool_use_error>", error),
158 tool_call_id: Some(tool_use_id.to_string()),
159 is_error: Some(is_error),
160 ..Default::default()
161 }
162}
163
164pub fn create_progress_message(tool_use_id: &str, data: serde_json::Value) -> Message {
166 Message {
167 role: crate::types::MessageRole::User,
168 content: serde_json::json!({
169 "type": "progress",
170 "tool_use_id": tool_use_id,
171 "data": data,
172 })
173 .to_string(),
174 ..Default::default()
175 }
176}
177
/// Formats an input-parsing failure for `tool_name`.
///
/// The result has the form `"Error parsing <tool_name> input: <error_message>"`.
pub fn format_input_validation_error(tool_name: &str, error_message: &str) -> String {
    format!("Error parsing {tool_name} input: {error_message}")
}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal tool definition named `name`, with every optional field unset.
    fn tool_named(name: &str) -> ToolDefinition {
        ToolDefinition {
            name: name.to_string(),
            description: "Test tool".to_string(),
            input_schema: crate::types::ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: None,
            should_defer: None,
            always_load: None,
            is_mcp: None,
            search_hint: None,
            aliases: None,
            user_facing_name: None,
            interrupt_behavior: None,
        }
    }

    #[test]
    fn test_classify_tool_error_io() {
        let not_found = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let label = classify_tool_error(&not_found);
        assert!(label == "Error" || label.contains("Error:"));
    }

    #[test]
    fn test_classify_tool_error_from_message() {
        let cases = [
            ("File not found", "Error:ENOENT"),
            ("Permission denied", "Error:EACCES"),
            ("timeout error", "Error:ETIMEDOUT"),
            ("Some other error", "Error"),
        ];
        for &(input, expected) in &cases {
            assert_eq!(classify_tool_error_from_message(input), expected);
        }
    }

    #[test]
    fn test_build_schema_not_sent_hint_tool_available() {
        let tools = vec![tool_named("test_tool")];
        let messages: Vec<Message> = Vec::new();

        let hint = build_schema_not_sent_hint("test_tool", &messages, &tools);
        assert!(hint.is_none());
    }

    #[test]
    fn test_build_schema_not_sent_hint_discovered() {
        let tools: Vec<ToolDefinition> = Vec::new();
        let messages = vec![Message {
            role: crate::types::MessageRole::Assistant,
            content: "Using discovered_tool".to_string(),
            ..Default::default()
        }];

        let hint = build_schema_not_sent_hint("discovered_tool", &messages, &tools)
            .expect("tool mentioned in history but absent from tools should produce a hint");
        assert!(hint.contains("discovered_tool"));
    }

    #[test]
    fn test_create_tool_error_message() {
        let msg = create_tool_error_message("tool_123", "Test error", true);
        for fragment in ["tool_use_error", "Test error"] {
            assert!(msg.content.contains(fragment));
        }
        assert_eq!(msg.is_error, Some(true));
    }

    #[test]
    fn test_format_input_validation_error() {
        let text = format_input_validation_error("Read", "expected string, got number");
        assert!(text.contains("Read"));
        assert!(text.contains("expected string"));
    }

    #[test]
    fn test_constants() {
        assert_eq!(HOOK_TIMING_DISPLAY_THRESHOLD_MS, 500);
        assert_eq!(SLOW_PHASE_LOG_THRESHOLD_MS, 2000);
    }
}