ai_agent/services/
tool_execution.rs1use crate::types::Message;
7use crate::tools::ToolDefinition;
8
/// Hook-timing display threshold, in milliseconds.
///
/// NOTE(review): per the name, hook runs at or above this duration get their timing
/// shown; the actual comparison lives in callers outside this file — confirm there.
pub const HOOK_TIMING_DISPLAY_THRESHOLD_MS: u64 = 500;

/// Slow-phase logging threshold, in milliseconds (the comparison against it is done
/// by callers outside this file).
pub const SLOW_PHASE_LOG_THRESHOLD_MS: u64 = 2_000;
15
/// Classifies an error into a short, stable label suitable for telemetry.
///
/// The error and every error reachable through its
/// [`source`](std::error::Error::source) chain are searched for a [`std::io::Error`];
/// the first one found decides the label:
///
/// * kinds `NotFound`, `PermissionDenied` and `TimedOut` map to `"Error:ENOENT"`,
///   `"Error:EACCES"` and `"Error:ETIMEDOUT"` — the same labels produced by
///   [`classify_tool_error_from_message`];
/// * any other I/O error that carries a raw OS error code maps to `"Error:<code>"`.
///
/// Everything else (including I/O errors without an OS code) yields `"Error"`.
pub fn classify_tool_error(error: &(dyn std::error::Error + 'static)) -> String {
    // `std::any::type_name_of_val` cannot recover a useful name here: applied to a
    // `&dyn Error` it reports the trait-object type ("dyn core::error::Error"), never
    // the concrete error type, so only downcasting can classify the error.
    let io_error = std::iter::successors(Some(error), |e| e.source())
        .find_map(|e| e.downcast_ref::<std::io::Error>());

    match io_error {
        None => "Error".to_string(),
        Some(io_error) => match io_error.kind() {
            std::io::ErrorKind::NotFound => "Error:ENOENT".to_string(),
            std::io::ErrorKind::PermissionDenied => "Error:EACCES".to_string(),
            std::io::ErrorKind::TimedOut => "Error:ETIMEDOUT".to_string(),
            // Unmapped kinds still report the raw OS code when one is available.
            _ => match io_error.raw_os_error() {
                Some(code) => format!("Error:{}", code),
                None => "Error".to_string(),
            },
        },
    }
}
55
/// Classifies a free-form error message into an errno-style label.
///
/// Matching is case-insensitive. Categories are tried in order — not-found, then
/// permission, then timeout — and the first match wins.
///
/// Returns `"Error:ENOENT"`, `"Error:EACCES"`, `"Error:ETIMEDOUT"`, or `"Error"` when
/// no known wording is present.
pub fn classify_tool_error_from_message(message: &str) -> String {
    // (lowercase needle, label) pairs, grouped by category in priority order, so the
    // first matching entry also belongs to the highest-priority matching category.
    // Besides errno names, the needles cover the standard OS wording
    // ("No such file or directory", "Permission denied", "Connection timed out") and
    // the literal "ETIMEDOUT", which does not contain the substring "timeout".
    const PATTERNS: &[(&str, &str)] = &[
        ("enoent", "Error:ENOENT"),
        ("file not found", "Error:ENOENT"),
        ("no such file or directory", "Error:ENOENT"),
        ("eacces", "Error:EACCES"),
        ("permission denied", "Error:EACCES"),
        ("etimedout", "Error:ETIMEDOUT"),
        ("timeout", "Error:ETIMEDOUT"),
        ("timed out", "Error:ETIMEDOUT"),
    ];

    let lower = message.to_lowercase();
    PATTERNS
        .iter()
        .find(|(needle, _)| lower.contains(needle))
        .map_or("Error", |&(_, label)| label)
        .to_string()
}
74
75pub fn build_schema_not_sent_hint(
79 tool_name: &str,
80 messages: &[Message],
81 tools: &[ToolDefinition],
82) -> Option<String> {
83 let tool_available = tools.iter().any(|t| t.name == tool_name);
85 if tool_available {
86 return None;
87 }
88
89 let discovered_in_messages = messages.iter().any(|m| {
91 m.content.contains(tool_name)
92 });
93
94 if discovered_in_messages {
95 return Some(format!(
96 "\n\nThis tool's schema was not sent to the API — it was not in the discovered-tool set derived from message history. \
97 Without the schema in your prompt, typed parameters (arrays, numbers, booleans) get emitted as strings and the client-side parser rejects them. \
98 Load the tool first: call tool_search with query \"select:{}\", then retry this call.",
99 tool_name
100 ));
101 }
102
103 None
104}
105
106#[derive(Debug, Clone)]
108pub struct MessageUpdateLazy {
109 pub message: Message,
110 pub context_modifier: Option<ContextModifier>,
111}
112
/// Identifies the tool use that a context modification belongs to.
#[derive(Clone, Debug)]
pub struct ContextModifier {
    /// Id of the tool use this modifier is associated with.
    pub tool_use_id: String,
}
118
119#[derive(Debug, Clone)]
121pub struct ToolProgress {
122 pub tool_use_id: String,
123 pub data: serde_json::Value,
124}
125
/// Ways a tool invocation can fail.
#[derive(Clone, Debug)]
pub enum ToolExecutionError {
    /// No tool with the given name is available; carries the requested name.
    ToolNotFound(String),
    /// The tool input failed validation; carries the validation message.
    InputValidation(String),
    /// The call was not permitted; carries the reason.
    PermissionDenied(String),
    /// The tool ran but failed; carries the failure message.
    ExecutionFailed(String),
    /// The execution was aborted before completing.
    Aborted,
}

impl std::fmt::Display for ToolExecutionError {
    /// Renders a human-readable message; the flags of the outer formatter
    /// (width, fill, …) are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ToolNotFound(tool) => write!(f, "No such tool available: {}", tool),
            Self::InputValidation(detail) => write!(f, "InputValidationError: {}", detail),
            Self::PermissionDenied(reason) => write!(f, "Permission denied: {}", reason),
            Self::ExecutionFailed(detail) => write!(f, "Error calling tool: {}", detail),
            Self::Aborted => f.write_str("Tool execution was aborted"),
        }
    }
}

impl std::error::Error for ToolExecutionError {}
154
155pub fn create_tool_error_message(
157 tool_use_id: &str,
158 error: &str,
159 is_error: bool,
160) -> Message {
161 Message {
162 role: crate::types::MessageRole::Tool,
163 content: format!("<tool_use_error>{}</tool_use_error>", error),
164 tool_call_id: Some(tool_use_id.to_string()),
165 is_error: Some(is_error),
166 ..Default::default()
167 }
168}
169
170pub fn create_progress_message(
172 tool_use_id: &str,
173 data: serde_json::Value,
174) -> Message {
175 Message {
176 role: crate::types::MessageRole::User,
177 content: serde_json::json!({
178 "type": "progress",
179 "tool_use_id": tool_use_id,
180 "data": data,
181 }).to_string(),
182 ..Default::default()
183 }
184}
185
/// Formats the error text reported when a tool's input cannot be parsed.
///
/// The result has the shape `Error parsing <tool_name> input: <error_message>`.
pub fn format_input_validation_error(tool_name: &str, error_message: &str) -> String {
    format!("Error parsing {} input: {}", tool_name, error_message)
}
193
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_classify_tool_error_io() {
        let io_error = std::io::Error::new(std::io::ErrorKind::NotFound, "file not found");
        let label = classify_tool_error(&io_error);
        assert!(label == "Error" || label.contains("Error:"), "unexpected label: {label}");
    }

    #[test]
    fn test_classify_tool_error_from_message() {
        let cases = [
            ("File not found", "Error:ENOENT"),
            ("Permission denied", "Error:EACCES"),
            ("timeout error", "Error:ETIMEDOUT"),
            ("Some other error", "Error"),
        ];
        for (input, expected) in cases {
            assert_eq!(classify_tool_error_from_message(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_build_schema_not_sent_hint_tool_available() {
        let tools = [ToolDefinition {
            name: "test_tool".to_string(),
            description: "Test tool".to_string(),
            input_schema: crate::types::ToolInputSchema {
                schema_type: "object".to_string(),
                properties: serde_json::json!({}),
                required: None,
            },
            annotations: None,
        }];
        let messages: Vec<Message> = Vec::new();

        assert!(build_schema_not_sent_hint("test_tool", &messages, &tools).is_none());
    }

    #[test]
    fn test_build_schema_not_sent_hint_discovered() {
        let tools: Vec<ToolDefinition> = Vec::new();
        let messages = vec![Message {
            role: crate::types::MessageRole::Assistant,
            content: "Using discovered_tool".to_string(),
            ..Default::default()
        }];

        let hint = build_schema_not_sent_hint("discovered_tool", &messages, &tools)
            .expect("a tool mentioned only in history should produce a hint");
        assert!(hint.contains("discovered_tool"));
    }

    #[test]
    fn test_create_tool_error_message() {
        let message = create_tool_error_message("tool_123", "Test error", true);
        assert!(message.content.contains("tool_use_error"));
        assert!(message.content.contains("Test error"));
        assert_eq!(message.is_error, Some(true));
    }

    #[test]
    fn test_format_input_validation_error() {
        let formatted = format_input_validation_error("Read", "expected string, got number");
        assert!(formatted.contains("Read"));
        assert!(formatted.contains("expected string"));
    }

    #[test]
    fn test_constants() {
        assert_eq!(HOOK_TIMING_DISPLAY_THRESHOLD_MS, 500);
        assert_eq!(SLOW_PHASE_LOG_THRESHOLD_MS, 2000);
    }
}
269}