use anyhow::Error;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::fmt;
/// Category of a session failure.
///
/// Serialized by serde and rendered by `Display` as the same `snake_case` label,
/// e.g. `blank_assistant_response`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SessionFailureKind {
    /// The assistant's response was blank (label: `blank_assistant_response`).
    BlankAssistantResponse,
    /// A required completion tool was missing (label: `required_completion_tool_missing`).
    RequiredCompletionToolMissing,
}

impl fmt::Display for SessionFailureKind {
    /// Writes the variant's `snake_case` label, matching its serde representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::BlankAssistantResponse => "blank_assistant_response",
            Self::RequiredCompletionToolMissing => "required_completion_tool_missing",
        };
        f.write_str(label)
    }
}
/// An error tagged with a [`SessionFailureKind`] and a human-readable message.
///
/// Displaying it prints only the message. The kind can be recovered from an
/// `anyhow::Error` cause chain with [`extract_session_failure_kind`].
#[derive(Debug)]
pub struct SessionFailureError {
    /// Machine-readable category of the failure.
    pub kind: SessionFailureKind,
    /// Text produced when the error is displayed.
    pub message: String,
}

impl SessionFailureError {
    /// Builds an error of the given `kind` carrying `message`.
    pub fn new(kind: SessionFailureKind, message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self { kind, message }
    }
}

impl fmt::Display for SessionFailureError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

// No `source()` override: this error has no underlying cause.
impl std::error::Error for SessionFailureError {}
/// Returns the kind of the first [`SessionFailureError`] yielded by `error.chain()`,
/// or `None` when no cause in the chain is a `SessionFailureError`.
pub fn extract_session_failure_kind(error: &Error) -> Option<SessionFailureKind> {
    for cause in error.chain() {
        if let Some(failure) = cause.downcast_ref::<SessionFailureError>() {
            return Some(failure.kind);
        }
    }
    None
}
#[derive(Debug)]
pub struct ToolExecutionError {
pub tool_name: String,
pub call_id: Option<String>,
pub arguments: Value,
pub error: Error,
pub suggestion: Option<String>,
}
impl ToolExecutionError {
pub fn new(tool_name: impl Into<String>, args: Value, error: Error) -> Self {
let tool_name = tool_name.into();
let suggestion = analyze_tool_error(&tool_name, &args, &error);
Self {
tool_name,
call_id: None,
arguments: args,
error,
suggestion,
}
}
pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
self.call_id = Some(call_id.into());
self
}
pub fn with_suggestion(mut self, suggestion: impl Into<String>) -> Self {
self.suggestion = Some(suggestion.into());
self
}
}
impl fmt::Display for ToolExecutionError {
    /// Renders a tree-style report with these lines:
    ///
    /// - the tool name and, if set, the call id;
    /// - the pretty-printed arguments, truncated to at most 200 bytes with the total
    ///   size noted;
    /// - the underlying reason;
    /// - an optional help line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Upper bound, in bytes, on the argument preview before it is truncated.
        const MAX_ARGS_PREVIEW_BYTES: usize = 200;

        writeln!(f, "Tool execution failed")?;
        writeln!(f, " ├─ Tool: {}", self.tool_name)?;
        if let Some(call_id) = &self.call_id {
            writeln!(f, " ├─ Call ID: {}", call_id)?;
        }
        let args_str = serde_json::to_string_pretty(&self.arguments)
            .unwrap_or_else(|_| format!("{:?}", self.arguments));
        if args_str.len() > MAX_ARGS_PREVIEW_BYTES {
            // Slicing a `str` at a byte index inside a multi-byte UTF-8 character panics.
            // JSON output keeps non-ASCII text unescaped, so back off to the nearest
            // char boundary at or below the limit. Index 0 is always a boundary.
            let cut = (0..=MAX_ARGS_PREVIEW_BYTES)
                .rev()
                .find(|&i| args_str.is_char_boundary(i))
                .unwrap_or(0);
            writeln!(f, " ├─ Arguments: {}...", &args_str[..cut])?;
            writeln!(f, " │ (truncated, {} bytes total)", args_str.len())?;
        } else {
            writeln!(f, " ├─ Arguments: {}", args_str)?;
        }
        writeln!(f, " └─ Reason: {}", self.error)?;
        if let Some(suggestion) = &self.suggestion {
            writeln!(f)?;
            writeln!(f, " Help: {}", suggestion)?;
        }
        Ok(())
    }
}
impl std::error::Error for ToolExecutionError {
    /// Exposes the wrapped `anyhow` error so cause-chain walks continue through it.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        let inner: &(dyn std::error::Error + 'static) = self.error.as_ref();
        Some(inner)
    }
}
/// Inspects a tool failure and returns a human-readable fix-it hint, if one applies.
///
/// The top-level message (`error.to_string()`) is matched case-insensitively against
/// known patterns, in this order:
///
/// 1. missing field;
/// 2. invalid type / type mismatch;
/// 3. parse / deserialization failure;
/// 4. file or resource not found;
/// 5. permission / access denied.
///
/// For a missing field, the field name is taken from the first delimited token found
/// by `extract_field_name`, and `args` is searched for a similarly named key. If no
/// name can be extracted, the remaining patterns are tried.
///
/// Returns `None` when nothing matches.
pub fn analyze_tool_error(tool_name: &str, args: &Value, error: &Error) -> Option<String> {
    let message = error.to_string();
    // Keyword detection is case-insensitive, so it runs on a lowercased copy.
    let error_str = message.to_lowercase();
    if error_str.contains("missing field") {
        // Extract the field name from the ORIGINAL message. Lowercasing would misreport
        // camelCase names (`filePath` -> `filepath`) and stop them from matching the
        // argument keys, which keep their original case.
        if let Some(field) = extract_field_name(&message) {
            if let Some(similar) = find_similar_field(args, &field) {
                return Some(format!(
                    "The tool '{}' expects field '{}' but received '{}'. \n \
                     Check the tool schema or ensure the LLM provides the correct field name.",
                    tool_name, field, similar
                ));
            }
            return Some(format!(
                "The tool '{}' requires field '{}' which was not provided. \n \
                 Ensure the LLM includes all required parameters.",
                tool_name, field
            ));
        }
    }
    if error_str.contains("invalid type") || error_str.contains("type mismatch") {
        return Some(format!(
            "Parameter type mismatch for tool '{}'. \n \
             Ensure the LLM provides correct types (string, number, boolean, object, array). \n \
             Check the tool's parameter schema and LLM prompt.",
            tool_name
        ));
    }
    if error_str.contains("failed to parse") || error_str.contains("deserialization") {
        return Some(format!(
            "Failed to parse parameters for tool '{}'. \n \
             The LLM may have provided malformed JSON or incorrect parameter structure. \n \
             Review the tool schema and improve the system prompt.",
            tool_name
        ));
    }
    if error_str.contains("no such file") || error_str.contains("not found") {
        return Some(format!(
            "Tool '{}' could not find the specified file or resource. \n \
             Verify the path is correct and the file exists. \n \
             Consider adding file existence checks in your tool implementation.",
            tool_name
        ));
    }
    if error_str.contains("permission denied") || error_str.contains("access denied") {
        return Some(format!(
            "Tool '{}' encountered a permission error. \n \
             Check file permissions or consider running with appropriate privileges. \n \
             You may need to add permission checks or user confirmation.",
            tool_name
        ));
    }
    None
}
/// Returns the text between the first pair of matching delimiters in `error`.
///
/// Delimiters are tried in priority order: backtick, single quote, then double
/// quote. A delimiter that appears only once is skipped in favour of the next kind.
/// Returns `None` when no delimiter kind appears at least twice.
fn extract_field_name(error: &str) -> Option<String> {
    const DELIMITERS: [char; 3] = ['`', '\'', '"'];
    DELIMITERS.iter().find_map(|&quote| {
        let (_, after_open) = error.split_once(quote)?;
        let (inner, _) = after_open.split_once(quote)?;
        Some(inner.to_owned())
    })
}
/// Returns the first key of `args` (when it is a JSON object) that looks like a
/// misnaming of `target`. A key qualifies when one string contains the other, or
/// when their Levenshtein distance is at most 2.
///
/// Some keys are never reported:
///
/// - Keys identical to `target`: an already-present key cannot be the "wrong name"
///   for a missing field. Reporting it would yield "expects 'x' but received 'x'".
/// - Empty keys: `""` is a substring of everything and would match unconditionally.
///
/// An empty `target` likewise yields `None`. `None` is also returned for non-object
/// `args` or when no key is similar.
fn find_similar_field(args: &Value, target: &str) -> Option<String> {
    if target.is_empty() {
        return None;
    }
    let obj = args.as_object()?;
    obj.keys()
        .map(|key| key.as_str())
        .filter(|key| !key.is_empty() && *key != target)
        .find(|key| {
            target.contains(key) || key.contains(target) || levenshtein_distance(key, target) <= 2
        })
        .map(str::to_owned)
}
/// Computes the Levenshtein (insert / delete / substitute) edit distance between
/// `a` and `b`, counted in `char`s rather than bytes.
///
/// Uses the classic dynamic-programming recurrence, keeping only the previous and
/// current rows of the table.
fn levenshtein_distance(a: &str, b: &str) -> usize {
    let target: Vec<char> = b.chars().collect();
    // `prev[j]` = distance between the processed prefix of `a` and `target[..j]`.
    let mut prev: Vec<usize> = (0..=target.len()).collect();
    let mut curr: Vec<usize> = vec![0; target.len() + 1];

    for (i, source_char) in a.chars().enumerate() {
        // Turning `i + 1` chars of `a` into the empty string takes `i + 1` deletions.
        curr[0] = i + 1;
        for (j, &target_char) in target.iter().enumerate() {
            let substitute = prev[j] + usize::from(source_char != target_char);
            let delete = prev[j + 1] + 1;
            let insert = curr[j] + 1;
            curr[j + 1] = substitute.min(delete).min(insert);
        }
        std::mem::swap(&mut prev, &mut curr);
    }

    prev[target.len()]
}
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::anyhow;
    use serde_json::json;

    #[test]
    fn test_extract_field_name() {
        let cases = [
            ("missing field `file_path`", "file_path"),
            ("missing field 'username'", "username"),
            ("missing field \"count\"", "count"),
        ];
        for &(message, field) in &cases {
            assert_eq!(extract_field_name(message), Some(field.to_string()));
        }
    }

    #[test]
    fn test_levenshtein_distance() {
        let cases = [
            ("file_path", "filepath", 1),
            ("file_path", "file_pat", 1),
            ("count", "cont", 1),
            ("hello", "hello", 0),
            ("hello", "world", 4),
        ];
        for &(a, b, expected) in &cases {
            assert_eq!(
                levenshtein_distance(a, b),
                expected,
                "distance({:?}, {:?})",
                a,
                b
            );
        }
    }

    #[test]
    fn test_find_similar_field() {
        let args = json!({
            "path": "file.txt",
            "count": 10,
        });
        let similar = |target: &str| find_similar_field(&args, target);
        assert_eq!(similar("file_path"), Some("path".to_string()));
        assert_eq!(similar("cont"), Some("count".to_string()));
        assert_eq!(similar("totally_different"), None);
    }

    #[test]
    fn test_analyze_missing_field() {
        let args = json!({"path": "file.txt"});
        let error = anyhow!("Missing field `file_path`");
        let msg = analyze_tool_error("read_file", &args, &error)
            .expect("a missing-field error should produce a suggestion");
        assert!(msg.contains("file_path"));
        assert!(msg.contains("path"));
    }

    #[test]
    fn test_analyze_type_mismatch() {
        let args = json!({"count": "not a number"});
        let error = anyhow!("invalid type: string \"not a number\", expected number");
        let msg = analyze_tool_error("counter", &args, &error)
            .expect("a type error should produce a suggestion");
        assert!(msg.contains("type mismatch"));
    }

    #[test]
    fn test_tool_execution_error_display() {
        let err = ToolExecutionError::new(
            "read_file",
            json!({"path": "test.txt"}),
            anyhow!("Missing field `file_path`"),
        )
        .with_call_id("abc123");
        let display = err.to_string();
        for needle in ["Tool: read_file", "Call ID: abc123", "Arguments", "Help:"].iter() {
            assert!(
                display.contains(needle),
                "expected {:?} in:\n{}",
                needle,
                display
            );
        }
    }

    #[test]
    fn test_extract_session_failure_kind_from_error_chain() {
        let failure = SessionFailureError::new(
            SessionFailureKind::RequiredCompletionToolMissing,
            "required completion tool missing",
        );
        let error = anyhow::Error::new(failure);
        assert_eq!(
            extract_session_failure_kind(&error),
            Some(SessionFailureKind::RequiredCompletionToolMissing)
        );
    }
}