use serde::{Deserialize, Serialize};
use std::time::Duration;
use thiserror::Error;
/// A single problem found while validating a tool's arguments.
///
/// The type is serializable; the `From<&ToolError>` conversion into
/// `ToolErrorInfo` serializes a list of these into its `details` field.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct ValidationError {
/// Field the problem is attached to, or `None` when no field is given
/// (as `ToolError::invalid_arguments` does).
pub field: Option<String>,
/// Human-readable description of the problem.
pub message: String,
}
impl ValidationError {
    /// Builds a validation error for an optional `field` with the given `message`.
    ///
    /// `message` accepts anything convertible into a `String`, so both string
    /// literals and owned strings can be passed without extra conversion.
    #[must_use]
    pub fn new(field: Option<String>, message: impl Into<String>) -> Self {
        let message = message.into();
        Self { field, message }
    }
}
/// Errors that can arise while looking up, validating, or running a tool.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute. Use `ToolError::is_retryable` to find out whether a variant is
/// reported as retryable.
#[derive(Debug, Error)]
pub enum ToolError {
/// The tool failed while executing.
#[error("Tool execution failed: {message}")]
ExecutionFailed {
/// Description of the failure; included in the display text.
message: String,
/// Value that `ToolError::is_retryable` reports for this error.
retryable: bool,
},
/// No tool was found; the payload is the name shown in the display text.
#[error("Tool not found: {0}")]
NotFound(String),
/// A model retry was requested; the payload is the accompanying message.
/// `ToolError::is_retryable` always reports `true` for this variant.
#[error("Model retry requested: {0}")]
ModelRetry(String),
/// Approval is required for the named tool.
#[error("Approval required for tool '{tool_name}'")]
ApprovalRequired {
/// Name of the tool; included in the display text.
tool_name: String,
/// JSON value carried with the error and not shown in the display text.
/// NOTE(review): presumably the arguments of the pending call — confirm.
args: serde_json::Value,
},
/// The call to the named tool was deferred.
#[error("Tool call '{tool_name}' deferred")]
CallDeferred {
/// Name of the tool; included in the display text.
tool_name: String,
/// JSON value carried with the error and not shown in the display text.
/// NOTE(review): presumably the arguments of the deferred call — confirm.
args: serde_json::Value,
},
/// Execution timed out; the duration is printed with `{:?}` formatting.
/// `ToolError::is_retryable` always reports `true` for this variant.
#[error("Tool execution timed out after {0:?}")]
Timeout(Duration),
/// Argument validation failed for the named tool.
///
/// The display text names only the tool; the individual errors are not
/// part of it.
#[error("Tool argument validation failed for '{tool_name}'")]
ValidationFailed {
/// Name of the tool; included in the display text.
tool_name: String,
/// The individual validation problems.
errors: Vec<ValidationError>,
},
/// Execution was cancelled.
#[error("Tool execution cancelled")]
Cancelled,
/// The tool returned an error; the payload is shown in the display text.
#[error("Tool returned error: {0}")]
ToolReturnedError(String),
/// A `serde_json` error; `#[from]` generates `From<serde_json::Error>`,
/// so `?` converts such errors automatically.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// Any other error, wrapped as `anyhow::Error`; `transparent` forwards
/// `Display` and `source` to the wrapped error.
#[error(transparent)]
Other(#[from] anyhow::Error),
}
impl ToolError {
    /// Reports whether this error is classified as retryable.
    ///
    /// [`Self::ExecutionFailed`] answers with its own `retryable` flag,
    /// [`Self::ModelRetry`] and [`Self::Timeout`] are always retryable, and
    /// every other variant is not.
    #[must_use]
    pub fn is_retryable(&self) -> bool {
        // All variants are spelled out (no `_` arm) so that adding a new
        // variant forces an explicit decision here at compile time.
        match self {
            Self::ExecutionFailed { retryable, .. } => *retryable,
            Self::ModelRetry(_) | Self::Timeout(_) => true,
            Self::ValidationFailed { .. }
            | Self::NotFound(_)
            | Self::ApprovalRequired { .. }
            | Self::CallDeferred { .. }
            | Self::Cancelled
            | Self::ToolReturnedError(_)
            | Self::Json(_)
            | Self::Other(_) => false,
        }
    }

    /// Creates a non-retryable [`Self::ExecutionFailed`] with the given message.
    #[must_use]
    pub fn execution_failed(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self::ExecutionFailed {
            message,
            retryable: false,
        }
    }

    /// Creates a retryable [`Self::ExecutionFailed`] with the given message.
    #[must_use]
    pub fn retryable(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self::ExecutionFailed {
            message,
            retryable: true,
        }
    }

    /// Creates a [`Self::ValidationFailed`] for `tool_name` carrying `errors`.
    #[must_use]
    pub fn validation_failed(tool_name: impl Into<String>, errors: Vec<ValidationError>) -> Self {
        let tool_name = tool_name.into();
        Self::ValidationFailed { tool_name, errors }
    }

    /// Creates a [`Self::ValidationFailed`] holding exactly one error built
    /// from `field` and `message`.
    #[must_use]
    pub fn validation_error(
        tool_name: impl Into<String>,
        field: Option<String>,
        message: impl Into<String>,
    ) -> Self {
        let single = ValidationError::new(field, message);
        Self::validation_failed(tool_name, vec![single])
    }

    /// Creates a [`Self::ValidationFailed`] holding one error that is not
    /// attached to any field.
    #[must_use]
    pub fn invalid_arguments(tool_name: impl Into<String>, message: impl Into<String>) -> Self {
        Self::validation_error(tool_name, None, message)
    }

    /// Creates a [`Self::NotFound`] for the given tool name.
    #[must_use]
    pub fn not_found(name: impl Into<String>) -> Self {
        let name = name.into();
        Self::NotFound(name)
    }

    /// Creates a [`Self::ModelRetry`] with the given message.
    #[must_use]
    pub fn model_retry(msg: impl Into<String>) -> Self {
        let message = msg.into();
        Self::ModelRetry(message)
    }

    /// Creates a [`Self::ApprovalRequired`] for `tool_name` carrying `args`.
    #[must_use]
    pub fn approval_required(tool_name: impl Into<String>, args: serde_json::Value) -> Self {
        let tool_name = tool_name.into();
        Self::ApprovalRequired { tool_name, args }
    }

    /// Creates a [`Self::CallDeferred`] for `tool_name` carrying `args`.
    #[must_use]
    pub fn call_deferred(tool_name: impl Into<String>, args: serde_json::Value) -> Self {
        let tool_name = tool_name.into();
        Self::CallDeferred { tool_name, args }
    }

    /// Creates a [`Self::Timeout`] carrying `duration`.
    #[must_use]
    pub fn timeout(duration: Duration) -> Self {
        Self::Timeout(duration)
    }

    /// Returns the error's `Display` text as an owned `String`.
    #[must_use]
    pub fn message(&self) -> String {
        self.to_string()
    }

    /// Returns `true` when this is [`Self::ApprovalRequired`].
    #[must_use]
    pub fn is_approval_required(&self) -> bool {
        matches!(self, Self::ApprovalRequired { .. })
    }

    /// Returns `true` when this is [`Self::CallDeferred`].
    #[must_use]
    pub fn is_call_deferred(&self) -> bool {
        matches!(self, Self::CallDeferred { .. })
    }

    /// Returns `true` when this is [`Self::ModelRetry`].
    #[must_use]
    pub fn is_model_retry(&self) -> bool {
        matches!(self, Self::ModelRetry(_))
    }
}
impl From<String> for ToolError {
fn from(s: String) -> Self {
Self::execution_failed(s)
}
}
impl From<&str> for ToolError {
fn from(s: &str) -> Self {
Self::execution_failed(s)
}
}
/// Serializable summary of an error: a type tag, a message, a retry flag and
/// optional structured details.
///
/// The `From<&ToolError>` implementation in this file produces one of these
/// from a `ToolError`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolErrorInfo {
/// Short tag identifying the kind of error (e.g. `"timeout"`).
pub error_type: String,
/// Human-readable error message.
pub message: String,
/// Whether the error is reported as retryable.
pub retryable: bool,
/// Optional structured data about the error; left out of the serialized
/// output entirely when it is `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub details: Option<serde_json::Value>,
}
impl ToolErrorInfo {
    /// Creates an info record with the given type tag and message.
    ///
    /// The record starts out non-retryable and without details; chain
    /// [`Self::retryable`] and [`Self::with_details`] to change that.
    #[must_use]
    pub fn new(error_type: impl Into<String>, message: impl Into<String>) -> Self {
        let error_type = error_type.into();
        let message = message.into();
        Self {
            error_type,
            message,
            retryable: false,
            details: None,
        }
    }

    /// Returns the record with its `retryable` flag set to the given value.
    #[must_use]
    pub fn retryable(self, retryable: bool) -> Self {
        Self { retryable, ..self }
    }

    /// Returns the record with `details` replaced by the given JSON value.
    #[must_use]
    pub fn with_details(self, details: serde_json::Value) -> Self {
        Self {
            details: Some(details),
            ..self
        }
    }
}
/// Summarizes a `ToolError` as a serializable `ToolErrorInfo`.
impl From<&ToolError> for ToolErrorInfo {
    /// Maps the variant to a type tag, copies the display message and the
    /// retryability, and attaches the validation errors as `details` for
    /// `ToolError::ValidationFailed`.
    fn from(err: &ToolError) -> Self {
        // Only validation failures carry structured details. If serializing
        // the error list fails, the details are dropped.
        let details = if let ToolError::ValidationFailed { errors, .. } = err {
            serde_json::to_value(errors).ok()
        } else {
            None
        };

        // Exhaustive on purpose: a new variant must be given a tag here.
        let tag: &str = match err {
            ToolError::ExecutionFailed { .. } => "execution_failed",
            ToolError::NotFound(_) => "not_found",
            ToolError::ModelRetry(_) => "model_retry",
            ToolError::ApprovalRequired { .. } => "approval_required",
            ToolError::CallDeferred { .. } => "call_deferred",
            ToolError::Timeout(_) => "timeout",
            ToolError::ValidationFailed { .. } => "validation_failed",
            ToolError::Cancelled => "cancelled",
            ToolError::ToolReturnedError(_) => "tool_error",
            ToolError::Json(_) => "json_error",
            ToolError::Other(_) => "other",
        };

        Self {
            error_type: tag.to_owned(),
            message: err.message(),
            retryable: err.is_retryable(),
            details,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `execution_failed` yields a non-retryable error whose text carries the message.
    #[test]
    fn test_execution_failed() {
        let failure = ToolError::execution_failed("Something went wrong");
        assert!(!failure.is_retryable());
        let rendered = failure.message();
        assert!(rendered.contains("Something went wrong"));
    }

    /// `retryable` yields an error that reports itself as retryable.
    #[test]
    fn test_retryable_error() {
        let failure = ToolError::retryable("Temporary failure");
        assert!(failure.is_retryable());
    }

    /// `not_found` is not retryable and names the missing tool.
    #[test]
    fn test_not_found() {
        let missing = ToolError::not_found("unknown_tool");
        assert!(!missing.is_retryable());
        let rendered = missing.message();
        assert!(rendered.contains("unknown_tool"));
    }

    /// `approval_required` is recognised by its predicate and is not retryable.
    #[test]
    fn test_approval_required() {
        let args = serde_json::json!({"x": 1});
        let pending = ToolError::approval_required("dangerous_tool", args);
        assert!(pending.is_approval_required());
        assert!(!pending.is_retryable());
    }

    /// `call_deferred` is recognised by its predicate.
    #[test]
    fn test_call_deferred() {
        let args = serde_json::json!({"a": "b"});
        let deferred = ToolError::call_deferred("slow_tool", args);
        assert!(deferred.is_call_deferred());
    }

    /// `timeout` is retryable and its text mentions the duration.
    #[test]
    fn test_timeout() {
        let limit = Duration::from_secs(30);
        let timed_out = ToolError::timeout(limit);
        assert!(timed_out.is_retryable());
        let rendered = timed_out.message();
        assert!(rendered.contains("30"));
    }

    /// `model_retry` is recognised by its predicate and is retryable.
    #[test]
    fn test_model_retry() {
        let retry = ToolError::model_retry("Invalid format");
        assert!(retry.is_model_retry());
        assert!(retry.is_retryable());
    }

    /// Validation failures are not retryable and expose details in the info record.
    #[test]
    fn test_validation_failed() {
        let field = Some("field".to_string());
        let invalid = ToolError::validation_error("test_tool", field, "Invalid value");
        assert!(!invalid.is_retryable());

        let summary = ToolErrorInfo::from(&invalid);
        assert_eq!(summary.error_type, "validation_failed");
        assert!(summary.details.is_some());
    }

    /// The info record mirrors the error's type tag and retryability.
    #[test]
    fn test_error_info_from_error() {
        let failure = ToolError::execution_failed("Test error");
        let summary = ToolErrorInfo::from(&failure);
        assert_eq!(summary.error_type, "execution_failed");
        assert!(!summary.retryable);
    }

    /// A string slice converts into `ExecutionFailed`.
    #[test]
    fn test_from_string() {
        let converted = ToolError::from("error message");
        assert!(matches!(converted, ToolError::ExecutionFailed { .. }));
    }
}