use serde::{Deserialize, Serialize};
/// Structured error value: a category, an optional finer-grained code, an optional
/// parameter name, and a message.
///
/// Serialized with the keys `type`, `code`, `param`, and `message`; `code` and `param`
/// are omitted entirely when they are `None`. `Display` (derived via `thiserror`) renders
/// `"{error_type}: {message}"`, e.g. `invalid_request: Invalid model ID`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, thiserror::Error)]
#[error("{error_type}: {message}")]
pub struct OpenResponseError {
/// Error category; serialized under the key `type` rather than `error_type`.
#[serde(rename = "type")]
pub error_type: OpenResponseErrorType,
/// Optional finer-grained code; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub code: Option<OpenResponseErrorCode>,
/// Request parameter the error refers to (set by `invalid_param` / `with_param`);
/// skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub param: Option<String>,
/// Free-form message; appears after the `": "` separator in the `Display` output.
pub message: String,
}
impl OpenResponseError {
    /// Creates an error of the given category with no `code` and no `param`.
    #[must_use]
    pub fn new(error_type: OpenResponseErrorType, message: impl Into<String>) -> Self {
        Self {
            error_type,
            code: None,
            param: None,
            message: message.into(),
        }
    }

    /// Creates an [`OpenResponseErrorType::ServerError`] error.
    #[must_use]
    pub fn server_error(message: impl Into<String>) -> Self {
        Self::new(OpenResponseErrorType::ServerError, message)
    }

    /// Creates an [`OpenResponseErrorType::InvalidRequest`] error with no `param`.
    #[must_use]
    pub fn invalid_request(message: impl Into<String>) -> Self {
        Self::new(OpenResponseErrorType::InvalidRequest, message)
    }

    /// Creates an [`OpenResponseErrorType::InvalidRequest`] error that names the request
    /// parameter it refers to.
    ///
    /// Equivalent to `invalid_request(message).with_param(param)`; composed that way so the
    /// field defaults stay defined in one place (`new`).
    #[must_use]
    pub fn invalid_param(param: impl Into<String>, message: impl Into<String>) -> Self {
        Self::invalid_request(message).with_param(param)
    }

    /// Creates an [`OpenResponseErrorType::ModelError`] error.
    #[must_use]
    pub fn model_error(message: impl Into<String>) -> Self {
        Self::new(OpenResponseErrorType::ModelError, message)
    }

    /// Creates an [`OpenResponseErrorType::NotFound`] error.
    #[must_use]
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(OpenResponseErrorType::NotFound, message)
    }

    /// Creates an [`OpenResponseErrorType::TooManyRequests`] error.
    ///
    /// No `code` is set; chain [`with_code`](Self::with_code) if one is wanted.
    #[must_use]
    pub fn rate_limit(message: impl Into<String>) -> Self {
        Self::new(OpenResponseErrorType::TooManyRequests, message)
    }

    /// Sets (or replaces) the `code`, returning the updated error.
    ///
    /// Consumes `self`, so ignoring the return value drops the error — hence `#[must_use]`.
    #[must_use]
    pub fn with_code(mut self, code: OpenResponseErrorCode) -> Self {
        self.code = Some(code);
        self
    }

    /// Sets (or replaces) the `param`, returning the updated error.
    ///
    /// Consumes `self`, so ignoring the return value drops the error — hence `#[must_use]`.
    #[must_use]
    pub fn with_param(mut self, param: impl Into<String>) -> Self {
        self.param = Some(param.into());
        self
    }
}
/// Broad category of an [`OpenResponseError`], serialized as its snake_case name
/// (e.g. `ServerError` -> `"server_error"`, `TooManyRequests` -> `"too_many_requests"`).
///
/// NOTE: the hand-written `Display` impl repeats these snake_case names as string
/// literals; keep the two in sync when adding or renaming variants.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OpenResponseErrorType {
/// Built by `OpenResponseError::server_error`.
ServerError,
/// Built by `OpenResponseError::invalid_request` and `OpenResponseError::invalid_param`.
InvalidRequest,
/// Built by `OpenResponseError::not_found`.
NotFound,
/// Built by `OpenResponseError::model_error`.
ModelError,
/// Built by `OpenResponseError::rate_limit`.
TooManyRequests,
/// No dedicated constructor; build with `OpenResponseError::new`.
AuthenticationError,
/// No dedicated constructor; build with `OpenResponseError::new`.
PermissionDenied,
}
/// Renders the error type as its snake_case name, the same string serde emits via
/// `#[serde(rename_all = "snake_case")]` (e.g. `ServerError` -> `server_error`).
///
/// Emitted through [`std::fmt::Formatter::pad`] so width, alignment, and precision flags
/// such as `{:<20}` or `{:>12}` are honored. The previous `write!(f, "...")` form
/// silently ignored them. Plain `{}` output is unchanged.
impl std::fmt::Display for OpenResponseErrorType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant forces a matching name here.
        let name = match self {
            Self::ServerError => "server_error",
            Self::InvalidRequest => "invalid_request",
            Self::NotFound => "not_found",
            Self::ModelError => "model_error",
            Self::TooManyRequests => "too_many_requests",
            Self::AuthenticationError => "authentication_error",
            Self::PermissionDenied => "permission_denied",
        };
        f.pad(name)
    }
}
/// Optional finer-grained code attached to an [`OpenResponseError`].
///
/// Named variants serialize as their snake_case names (e.g. `RateLimitExceeded` ->
/// `"rate_limit_exceeded"`). `Custom` is marked `#[serde(untagged)]`, so it serializes as
/// the bare inner string. On deserialization, a string that matches no named variant
/// falls back to `Custom`.
///
/// NOTE(review): because named variants are tried first, `Custom("timeout".into())`
/// deserializes back as `Timeout`, not `Custom`. Avoid putting a name that a dedicated
/// variant already covers inside `Custom`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OpenResponseErrorCode {
InvalidApiKey,
InsufficientQuota,
ContextLengthExceeded,
InvalidModel,
ContentFilter,
ToolExecutionFailed,
Timeout,
RateLimitExceeded,
/// Free-form code; must stay the last variant because serde requires untagged
/// variants to be declared after all tagged ones.
#[serde(untagged)]
Custom(String),
}
/// Renders the code as the same string serde emits: the snake_case name for named
/// variants, or the inner string verbatim for `Custom`.
///
/// Emitted through [`std::fmt::Formatter::pad`] so width, alignment, and precision flags
/// (e.g. `{:>20}`) are honored. The previous `write!` form ignored them. Plain `{}`
/// output is unchanged.
impl std::fmt::Display for OpenResponseErrorCode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name: &str = match self {
            Self::InvalidApiKey => "invalid_api_key",
            Self::InsufficientQuota => "insufficient_quota",
            Self::ContextLengthExceeded => "context_length_exceeded",
            Self::InvalidModel => "invalid_model",
            Self::ContentFilter => "content_filter",
            Self::ToolExecutionFailed => "tool_execution_failed",
            Self::Timeout => "timeout",
            Self::RateLimitExceeded => "rate_limit_exceeded",
            Self::Custom(code) => code.as_str(),
        };
        f.pad(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    #[test]
    fn test_error_creation() {
        let err = OpenResponseError::invalid_param("model", "Invalid model ID");
        assert_eq!(err.error_type, OpenResponseErrorType::InvalidRequest);
        assert_eq!(err.code, None);
        assert_eq!(err.param.as_deref(), Some("model"));
        assert_eq!(err.message, "Invalid model ID");
    }

    #[test]
    fn test_error_serialization() {
        let err = OpenResponseError::server_error("Internal error")
            .with_code(OpenResponseErrorCode::Timeout);
        // Compare whole JSON values rather than substrings. This also verifies that the
        // unset `param` is omitted instead of being serialized as `null`.
        assert_eq!(
            serde_json::to_value(&err).unwrap(),
            json!({"type": "server_error", "code": "timeout", "message": "Internal error"})
        );
    }

    #[test]
    fn test_error_round_trip() {
        let err = OpenResponseError::invalid_param("model", "Invalid model ID")
            .with_code(OpenResponseErrorCode::InvalidModel);
        let text = serde_json::to_string(&err).unwrap();
        let back: OpenResponseError = serde_json::from_str(&text).unwrap();
        assert_eq!(back, err);
    }

    #[test]
    fn test_custom_code_serde() {
        let custom = OpenResponseErrorCode::Custom("quota_frozen".to_string());
        // Untagged variant: serialized as the bare inner string.
        assert_eq!(serde_json::to_value(&custom).unwrap(), json!("quota_frozen"));
        // Known names map to their dedicated variant; anything else falls back to `Custom`.
        let known: OpenResponseErrorCode = serde_json::from_value(json!("timeout")).unwrap();
        assert_eq!(known, OpenResponseErrorCode::Timeout);
        let unknown: OpenResponseErrorCode =
            serde_json::from_value(json!("quota_frozen")).unwrap();
        assert_eq!(unknown, custom);
    }

    #[test]
    fn test_display() {
        let err = OpenResponseError::not_found("No such response");
        assert_eq!(err.to_string(), "not_found: No such response");
        assert_eq!(
            OpenResponseErrorCode::RateLimitExceeded.to_string(),
            "rate_limit_exceeded"
        );
    }
}