use serdes_ai_models::ModelError;
use serdes_ai_tools::ToolError;
use thiserror::Error;
/// Errors that can terminate an agent run.
#[derive(Debug, Error)]
pub enum AgentRunError {
    /// Wraps a [`ModelError`]; `?` converts into this variant via the derived `From` impl.
    #[error("Model error: {0}")]
    Model(#[from] ModelError),

    /// Wraps a [`ToolError`]; `?` converts into this variant via the derived `From` impl.
    #[error("Tool error: {0}")]
    Tool(#[from] ToolError),

    /// Output validation failed. The inner error is exposed through `source()`.
    /// No `From` impl is derived, so this variant must be constructed explicitly.
    #[error("Output validation failed: {0}")]
    OutputValidationFailed(#[source] OutputValidationError),

    /// Output parsing failed. The inner error is exposed through `source()`.
    /// No `From` impl is derived, so this variant must be constructed explicitly.
    #[error("Output parsing failed: {0}")]
    OutputParseFailed(#[source] OutputParseError),

    /// A usage limit was exceeded; converts from [`UsageLimitError`] via `From`.
    #[error("Usage limit exceeded: {0}")]
    UsageLimitExceeded(#[from] UsageLimitError),

    /// The model stopped without producing output.
    #[error("Model stopped unexpectedly without output")]
    UnexpectedStop,

    /// The run ended without any output.
    #[error("No output produced")]
    NoOutput,

    /// The retry budget was exhausted; `message` carries caller-supplied detail.
    #[error("Max retries exceeded: {message}")]
    MaxRetriesExceeded { message: String },

    /// A `serde_json` error; converts from [`serde_json::Error`] via `From`.
    #[error("Serialization error: {0}")]
    Serialization(#[from] serde_json::Error),

    /// A configuration problem, described by the contained text.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// The run was cancelled.
    #[error("Agent run was cancelled")]
    Cancelled,

    /// The run timed out; `seconds` is the duration shown in the message.
    #[error("Agent run timed out after {seconds}s")]
    Timeout { seconds: u64 },

    /// A provider-reported failure, carried as text.
    #[error("Provider error: {0}")]
    Provider(String),

    /// Any other error. `Display` and `source()` are forwarded unchanged to the
    /// wrapped [`anyhow::Error`] (`#[error(transparent)]`).
    #[error(transparent)]
    Other(#[from] anyhow::Error),
}
impl AgentRunError {
    /// Creates an [`AgentRunError::MaxRetriesExceeded`] carrying `message`.
    pub fn max_retries(message: impl Into<String>) -> Self {
        Self::MaxRetriesExceeded {
            message: message.into(),
        }
    }

    /// Creates an [`AgentRunError::Configuration`] carrying `message`.
    pub fn config(message: impl Into<String>) -> Self {
        Self::Configuration(message.into())
    }

    /// Creates an [`AgentRunError::Timeout`] reporting `seconds`.
    pub fn timeout(seconds: u64) -> Self {
        Self::Timeout { seconds }
    }

    /// Returns `true` when retrying the failed operation might succeed.
    ///
    /// - Model and tool errors defer to their own `is_retryable` classification.
    /// - Usage limits, cancellation, timeouts, exhausted retries and configuration
    ///   errors are terminal: repeating the same request cannot change the outcome.
    /// - Every other variant is treated as potentially transient.
    pub fn is_retryable(&self) -> bool {
        // Deliberately exhaustive (no `_` arm): adding a variant must force an
        // explicit retry decision here instead of silently defaulting to `true`.
        match self {
            Self::Model(e) => e.is_retryable(),
            Self::Tool(e) => e.is_retryable(),
            Self::UsageLimitExceeded(_)
            | Self::Cancelled
            | Self::Timeout { .. }
            | Self::MaxRetriesExceeded { .. }
            // A misconfiguration is deterministic; retrying it only repeats the failure.
            | Self::Configuration(_) => false,
            Self::OutputValidationFailed(_)
            | Self::OutputParseFailed(_)
            | Self::UnexpectedStop
            | Self::NoOutput
            | Self::Serialization(_)
            | Self::Provider(_)
            | Self::Other(_) => true,
        }
    }
}
/// Reasons an agent's output can fail validation.
#[derive(Debug, Error)]
pub enum OutputValidationError {
    /// A validation failure, optionally attributed to a field.
    ///
    /// The `Display` output shows only `message`. `field` is not rendered by
    /// `Display`, although `retry_message` includes it.
    #[error("Validation failed: {message}")]
    ValidationFailed { message: String, field: Option<String> },

    /// A free-form failure raised by custom validation logic.
    #[error("Custom validation failed: {0}")]
    Custom(String),

    /// A value had a different type than expected.
    #[error("Type mismatch: expected {expected}, got {actual}")]
    TypeMismatch { expected: String, actual: String },

    /// A required field was absent; the payload is the field name.
    #[error("Missing required field: {0}")]
    MissingField(String),
}
impl OutputValidationError {
    /// Builds an [`OutputValidationError::ValidationFailed`] with no associated field.
    pub fn failed(message: impl Into<String>) -> Self {
        let message: String = message.into();
        Self::ValidationFailed {
            message,
            field: None,
        }
    }

    /// Builds an [`OutputValidationError::ValidationFailed`] attributed to `field`.
    pub fn field_failed(field: impl Into<String>, message: impl Into<String>) -> Self {
        let message: String = message.into();
        let field: String = field.into();
        Self::ValidationFailed {
            message,
            field: Some(field),
        }
    }

    /// Builds an [`OutputValidationError::Custom`] from free-form text.
    pub fn custom(message: impl Into<String>) -> Self {
        Self::Custom(message.into())
    }

    /// Returns a human-readable description of this error for use as a retry message.
    ///
    /// The result differs from `Display` in two cases:
    /// - `ValidationFailed` names the field when one is set.
    /// - `Custom` returns its text verbatim, without a prefix.
    pub fn retry_message(&self) -> String {
        match self {
            Self::ValidationFailed {
                message,
                field: Some(name),
            } => format!("Validation failed for field '{}': {}", name, message),
            Self::ValidationFailed {
                message,
                field: None,
            } => format!("Validation failed: {}", message),
            Self::Custom(text) => text.to_owned(),
            Self::TypeMismatch { expected, actual } => {
                format!("Type mismatch: expected {}, got {}", expected, actual)
            }
            Self::MissingField(name) => format!("Missing required field: {}", name),
        }
    }
}
/// Reasons the model's output could not be parsed.
#[derive(Debug, Error)]
pub enum OutputParseError {
    /// A `serde_json` error raised while parsing; converts from [`serde_json::Error`] via `From`.
    #[error("JSON parse error: {0}")]
    Json(#[from] serde_json::Error),

    /// No output was found in the model response.
    #[error("No output found in model response")]
    NotFound,

    /// The model did not call the output tool.
    #[error("Output tool was not called by the model")]
    ToolNotCalled,

    /// The output had an invalid format, described by the contained text.
    #[error("Invalid output format: {0}")]
    InvalidFormat(String),

    /// The output did not match the expected schema, described by the contained text.
    #[error("Output does not match schema: {0}")]
    SchemaMismatch(String),
}
impl OutputParseError {
    /// Creates an [`OutputParseError::InvalidFormat`] from any string-like message.
    pub fn invalid_format(message: impl Into<String>) -> Self {
        let text: String = message.into();
        Self::InvalidFormat(text)
    }

    /// Creates an [`OutputParseError::SchemaMismatch`] from any string-like message.
    pub fn schema_mismatch(message: impl Into<String>) -> Self {
        let text: String = message.into();
        Self::SchemaMismatch(text)
    }
}
/// A usage limit that was exceeded; wrapped by `AgentRunError::UsageLimitExceeded`.
///
/// Each variant records the observed value and the limit. `Display` renders them
/// as `<observed> > <limit>`.
#[derive(Debug, Error)]
pub enum UsageLimitError {
    /// Request tokens used exceeded the limit.
    #[error("Request token limit exceeded: {used} > {limit}")]
    RequestTokens { used: u64, limit: u64 },

    /// Response tokens used exceeded the limit.
    #[error("Response token limit exceeded: {used} > {limit}")]
    ResponseTokens { used: u64, limit: u64 },

    /// Total tokens used exceeded the limit.
    #[error("Total token limit exceeded: {used} > {limit}")]
    TotalTokens { used: u64, limit: u64 },

    /// The number of requests exceeded the limit.
    #[error("Request count limit exceeded: {count} > {limit}")]
    RequestCount { count: u32, limit: u32 },

    /// The number of tool calls exceeded the limit.
    #[error("Tool call limit exceeded: {count} > {limit}")]
    ToolCalls { count: u32, limit: u32 },

    /// Elapsed time exceeded the limit (both in seconds).
    #[error("Time limit exceeded: {elapsed_seconds}s > {limit_seconds}s")]
    TimeLimit { elapsed_seconds: u64, limit_seconds: u64 },
}
/// Errors reported while building an agent.
#[derive(Debug, Error)]
pub enum AgentBuildError {
    /// A required field was not supplied; the payload is the field name.
    #[error("Missing required field: {0}")]
    MissingField(&'static str),

    /// The supplied configuration is invalid, described by the contained text.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// Registering a tool failed, described by the contained text.
    #[error("Tool registration error: {0}")]
    ToolRegistration(String),
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the exact `Display` text so format regressions are caught,
    /// not just the presence of a substring.
    #[test]
    fn test_agent_run_error_display() {
        assert_eq!(AgentRunError::NoOutput.to_string(), "No output produced");
        assert_eq!(
            AgentRunError::max_retries("tool call").to_string(),
            "Max retries exceeded: tool call"
        );
        assert_eq!(
            AgentRunError::timeout(30).to_string(),
            "Agent run timed out after 30s"
        );
        assert_eq!(
            AgentRunError::config("no model").to_string(),
            "Configuration error: no model"
        );
    }

    /// The `#[source]` variants must embed the inner message and expose it as `source()`.
    #[test]
    fn test_output_validation_failed_source_chain() {
        let err = AgentRunError::OutputValidationFailed(OutputValidationError::custom("x"));
        assert_eq!(
            err.to_string(),
            "Output validation failed: Custom validation failed: x"
        );
        assert!(std::error::Error::source(&err).is_some());
    }

    #[test]
    fn test_output_validation_error() {
        assert_eq!(
            OutputValidationError::failed("invalid value").retry_message(),
            "Validation failed: invalid value"
        );
        assert_eq!(
            OutputValidationError::field_failed("name", "too short").retry_message(),
            "Validation failed for field 'name': too short"
        );
        assert_eq!(
            OutputValidationError::custom("be nicer").retry_message(),
            "be nicer"
        );
        assert_eq!(
            OutputValidationError::TypeMismatch {
                expected: "string".to_string(),
                actual: "number".to_string(),
            }
            .retry_message(),
            "Type mismatch: expected string, got number"
        );
        assert_eq!(
            OutputValidationError::MissingField("id".to_string()).retry_message(),
            "Missing required field: id"
        );
    }

    #[test]
    fn test_output_parse_error_constructors() {
        assert_eq!(
            OutputParseError::invalid_format("not json").to_string(),
            "Invalid output format: not json"
        );
        assert_eq!(
            OutputParseError::schema_mismatch("missing id").to_string(),
            "Output does not match schema: missing id"
        );
    }

    #[test]
    fn test_usage_limit_error() {
        let err = UsageLimitError::TotalTokens {
            used: 1000,
            limit: 500,
        };
        assert_eq!(err.to_string(), "Total token limit exceeded: 1000 > 500");
        let err = UsageLimitError::TimeLimit {
            elapsed_seconds: 90,
            limit_seconds: 60,
        };
        assert_eq!(err.to_string(), "Time limit exceeded: 90s > 60s");
    }

    #[test]
    fn test_is_retryable() {
        assert!(!AgentRunError::Cancelled.is_retryable());
        assert!(!AgentRunError::timeout(60).is_retryable());
        assert!(!AgentRunError::max_retries("tool call").is_retryable());
        assert!(!AgentRunError::from(UsageLimitError::RequestCount { count: 5, limit: 4 })
            .is_retryable());
        assert!(AgentRunError::NoOutput.is_retryable());
        assert!(AgentRunError::UnexpectedStop.is_retryable());
    }
}