Skip to main content

ai_agent/utils/
tool_errors.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct ShellErrorData {
5    pub code: Option<i32>,
6    pub interrupted: bool,
7    pub stderr: String,
8    pub stdout: String,
9}
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct ZodIssue {
13    pub code: String,
14    pub message: String,
15    pub path: Vec<serde_json::Value>,
16    #[serde(default)]
17    pub keys: Option<Vec<String>>,
18    #[serde(default)]
19    pub expected: Option<String>,
20}
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct ZodErrorData {
24    pub issues: Vec<ZodIssue>,
25}
26
/// Text reported for an interruption: `format_error` returns it for an empty
/// error string, and `get_error_parts` emits it when a command was interrupted.
const INTERRUPT_MESSAGE: &str = "Interrupted";

/// Prepares an error message for display, truncating very long messages.
///
/// - An empty `error` yields `INTERRUPT_MESSAGE`.
/// - Messages of at most 10 000 characters are returned unchanged.
/// - Longer messages keep their first and last 5 000 characters, joined by a
///   marker stating how many characters were dropped.
///
/// Lengths are counted in `char`s and both cut points are taken from
/// `char_indices`, so they always fall on UTF-8 character boundaries. Slicing
/// at raw byte offsets would panic on multi-byte text and miscount "characters".
pub fn format_error(error: &str) -> String {
    const MAX_CHARS: usize = 10_000;
    const HALF: usize = MAX_CHARS / 2;

    if error.is_empty() {
        return INTERRUPT_MESSAGE.to_string();
    }

    let char_count = error.chars().count();
    if char_count <= MAX_CHARS {
        return error.to_string();
    }

    // char_count > MAX_CHARS, so both nth() lookups are in range; the map_or
    // fallback exists only to avoid an unwrap.
    let head_end = error
        .char_indices()
        .nth(HALF)
        .map_or(error.len(), |(byte_idx, _)| byte_idx);
    let tail_start = error
        .char_indices()
        .nth(char_count - HALF)
        .map_or(error.len(), |(byte_idx, _)| byte_idx);

    format!(
        "{}\n\n... [{} characters truncated] ...\n\n{}",
        &error[..head_end],
        char_count - MAX_CHARS,
        &error[tail_start..]
    )
}

49pub fn get_error_parts(error: &ShellErrorData) -> Vec<String> {
50    let mut parts = Vec::new();
51
52    if let Some(code) = error.code {
53        parts.push(format!("Exit code {}", code));
54    }
55
56    if error.interrupted {
57        parts.push(INTERRUPT_MESSAGE.to_string());
58    }
59
60    if !error.stderr.is_empty() {
61        parts.push(error.stderr.clone());
62    }
63
64    if !error.stdout.is_empty() {
65        parts.push(error.stdout.clone());
66    }
67
68    parts
69}
70
71fn format_validation_path(path: &[serde_json::Value]) -> String {
72    if path.is_empty() {
73        return String::new();
74    }
75
76    path.iter()
77        .enumerate()
78        .map(|(i, segment)| match segment {
79            serde_json::Value::Number(n) => format!("[{}]", n),
80            serde_json::Value::String(s) => {
81                if i == 0 {
82                    s.clone()
83                } else {
84                    format!(".{}", s)
85                }
86            }
87            other => other.to_string(),
88        })
89        .collect()
90}
91
92pub fn format_zod_validation_error(tool_name: &str, error: &ZodErrorData) -> String {
93    let missing_params: Vec<String> = error
94        .issues
95        .iter()
96        .filter(|err| err.code == "invalid_type" && err.message.contains("received undefined"))
97        .map(|err| format_validation_path(&err.path))
98        .collect();
99
100    let unexpected_params: Vec<String> = error
101        .issues
102        .iter()
103        .filter(|err| err.code == "unrecognized_keys")
104        .flat_map(|err| err.keys.clone().unwrap_or_default())
105        .collect();
106
107    let type_mismatch_params: Vec<(String, String, String)> = error
108        .issues
109        .iter()
110        .filter(|err| err.code == "invalid_type" && !err.message.contains("received undefined"))
111        .map(|err| {
112            let param = format_validation_path(&err.path);
113            let expected = err
114                .expected
115                .clone()
116                .unwrap_or_else(|| "unknown".to_string());
117            let received = err
118                .message
119                .split("received ")
120                .nth(1)
121                .map(|s| s.split_whitespace().next().unwrap_or("unknown"))
122                .unwrap_or("unknown")
123                .to_string();
124            (param, expected, received)
125        })
126        .collect();
127
128    let mut error_parts = Vec::new();
129
130    for param in &missing_params {
131        error_parts.push(format!("The required parameter `{}` is missing", param));
132    }
133
134    for param in &unexpected_params {
135        error_parts.push(format!("An unexpected parameter `{}` was provided", param));
136    }
137
138    for (param, expected, received) in &type_mismatch_params {
139        error_parts.push(format!(
140            "The parameter `{}` type is expected as `{}` but provided as `{}`",
141            param, expected, received
142        ));
143    }
144
145    if error_parts.is_empty() {
146        error
147            .issues
148            .first()
149            .map(|i| i.message.clone())
150            .unwrap_or_default()
151    } else {
152        let issue_word = if error_parts.len() > 1 {
153            "issues"
154        } else {
155            "issue"
156        };
157        format!(
158            "{} failed due to the following {}:\n{}",
159            tool_name,
160            issue_word,
161            error_parts.join("\n")
162        )
163    }
164}
165
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a string path segment.
    fn key(name: &str) -> serde_json::Value {
        serde_json::Value::String(name.to_string())
    }

    /// Builds a numeric (array index) path segment.
    fn index(i: u64) -> serde_json::Value {
        serde_json::Value::Number(serde_json::Number::from(i))
    }

    /// Builds an issue with no `keys` / `expected`; tweak via struct update.
    fn issue(code: &str, message: &str, path: Vec<serde_json::Value>) -> ZodIssue {
        ZodIssue {
            code: code.to_string(),
            message: message.to_string(),
            path,
            keys: None,
            expected: None,
        }
    }

    #[test]
    fn test_format_error_short() {
        let result = format_error("short error");
        assert_eq!(result, "short error");
    }

    #[test]
    fn test_format_error_empty_is_interrupt() {
        assert_eq!(format_error(""), INTERRUPT_MESSAGE);
    }

    #[test]
    fn test_format_error_truncates_long_ascii() {
        let input = format!("{}{}{}", "a".repeat(5000), "b".repeat(7), "c".repeat(5000));
        let expected = format!(
            "{}\n\n... [7 characters truncated] ...\n\n{}",
            "a".repeat(5000),
            "c".repeat(5000)
        );
        assert_eq!(format_error(&input), expected);
    }

    #[test]
    fn test_get_error_parts_order_and_filtering() {
        let error = ShellErrorData {
            code: Some(1),
            interrupted: true,
            stderr: "boom".to_string(),
            stdout: "partial".to_string(),
        };
        assert_eq!(
            get_error_parts(&error),
            vec!["Exit code 1", INTERRUPT_MESSAGE, "boom", "partial"]
        );

        let quiet = ShellErrorData {
            code: None,
            interrupted: false,
            stderr: String::new(),
            stdout: String::new(),
        };
        assert!(get_error_parts(&quiet).is_empty());
    }

    #[test]
    fn test_format_validation_path() {
        let path = vec![
            serde_json::Value::String("todos".to_string()),
            serde_json::Value::Number(serde_json::Number::from(0)),
            serde_json::Value::String("activeForm".to_string()),
        ];
        let result = format_validation_path(&path);
        assert_eq!(result, "todos[0].activeForm");
    }

    #[test]
    fn test_format_zod_groups_issues_in_fixed_order() {
        let error = ZodErrorData {
            issues: vec![
                ZodIssue {
                    keys: Some(vec!["foo".to_string()]),
                    ..issue("unrecognized_keys", "Unrecognized key(s)", vec![])
                },
                ZodIssue {
                    expected: Some("number".to_string()),
                    ..issue(
                        "invalid_type",
                        "Expected number, received string",
                        vec![key("timeout")],
                    )
                },
                ZodIssue {
                    expected: Some("string".to_string()),
                    ..issue(
                        "invalid_type",
                        "Invalid input: expected string, received undefined",
                        vec![key("command")],
                    )
                },
            ],
        };
        assert_eq!(
            format_zod_validation_error("Bash", &error),
            "Bash failed due to the following issues:\n\
             The required parameter `command` is missing\n\
             An unexpected parameter `foo` was provided\n\
             The parameter `timeout` type is expected as `number` but provided as `string`"
        );
    }

    #[test]
    fn test_format_zod_single_issue_nested_path() {
        let error = ZodErrorData {
            issues: vec![issue(
                "invalid_type",
                "Required, received undefined",
                vec![key("todos"), index(0), key("activeForm")],
            )],
        };
        assert_eq!(
            format_zod_validation_error("Write", &error),
            "Write failed due to the following issue:\n\
             The required parameter `todos[0].activeForm` is missing"
        );
    }

    #[test]
    fn test_format_zod_unknown_fallbacks() {
        let error = ZodErrorData {
            issues: vec![issue("invalid_type", "Bad type", vec![key("x")])],
        };
        assert_eq!(
            format_zod_validation_error("Tool", &error),
            "Tool failed due to the following issue:\n\
             The parameter `x` type is expected as `unknown` but provided as `unknown`"
        );
    }

    #[test]
    fn test_format_zod_falls_back_to_first_message() {
        let error = ZodErrorData {
            issues: vec![issue("custom", "boom", vec![])],
        };
        assert_eq!(format_zod_validation_error("Tool", &error), "boom");

        let empty = ZodErrorData { issues: vec![] };
        assert_eq!(format_zod_validation_error("Tool", &empty), "");
    }
}