use crate::error::StreamError;
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};
use std::fmt;
/// Serializable error payload produced by the mapping functions in this module.
///
/// Carries a numeric HTTP status code together with a message. Optional fields are
/// left out of the serialized output when they are `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Numeric HTTP status code (e.g. `404`).
    pub status: u16,
    /// Human-readable error message (the error's `Display` text or a caller-supplied message).
    pub message: String,
    /// Short category label such as `"NotFoundError"`; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Extra diagnostic information; omitted from output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<ErrorDetails>,
}
/// Optional diagnostic payload attached to an [`ErrorResponse`].
///
/// The constructors in this module fill it in only when their `include_details`
/// argument is `true`. Optional fields are left out of the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorDetails {
    /// Text of the underlying error (its `Display` or `Debug` form, depending on the constructor).
    pub error: String,
    /// Name of the component that raised the error, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub component: Option<String>,
    /// Time associated with the error, formatted as an RFC 3339 string, when known.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timestamp: Option<String>,
    /// `Debug` dump of the error's context, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context: Option<String>,
}
pub fn map_to_http_error<T>(error: &StreamError<T>, include_details: bool) -> ErrorResponse
where
T: fmt::Debug + Clone + Send + Sync,
{
let error_message = error.to_string();
let error_lower = error_message.to_lowercase();
let (status, error_type) = if error_lower.contains("not found")
|| error_lower.contains("missing") && error_lower.contains("not found")
{
(StatusCode::NOT_FOUND, Some("NotFoundError".to_string()))
} else if error_lower.contains("unauthorized") || error_lower.contains("authentication") {
(
StatusCode::UNAUTHORIZED,
Some("UnauthorizedError".to_string()),
)
} else if error_lower.contains("forbidden") || error_lower.contains("permission") {
(StatusCode::FORBIDDEN, Some("ForbiddenError".to_string()))
} else if error_lower.contains("rate limit") || error_lower.contains("too many requests") {
(
StatusCode::TOO_MANY_REQUESTS,
Some("RateLimitError".to_string()),
)
} else if error_lower.contains("invalid")
|| error_lower.contains("validation")
|| error_lower.contains("bad request")
|| error_lower.contains("malformed")
{
(StatusCode::BAD_REQUEST, Some("ValidationError".to_string()))
} else if error_lower.contains("conflict") || error_lower.contains("already exists") {
(StatusCode::CONFLICT, Some("ConflictError".to_string()))
} else if error_lower.contains("timeout") || error_lower.contains("timed out") {
(
StatusCode::REQUEST_TIMEOUT,
Some("TimeoutError".to_string()),
)
} else if error_lower.contains("not implemented") {
(
StatusCode::NOT_IMPLEMENTED,
Some("NotImplementedError".to_string()),
)
} else if error_lower.contains("service unavailable") || error_lower.contains("unavailable") {
(
StatusCode::SERVICE_UNAVAILABLE,
Some("ServiceUnavailableError".to_string()),
)
} else {
(
StatusCode::INTERNAL_SERVER_ERROR,
Some("InternalServerError".to_string()),
)
};
let details = if include_details {
Some(ErrorDetails {
error: error_message.clone(),
component: Some(error.context.component_name.clone()),
timestamp: Some(error.context.timestamp.to_rfc3339()),
context: Some(format!("{:?}", error.context)),
})
} else {
None
};
ErrorResponse {
status: status.as_u16(),
message: error_message,
error: error_type,
details,
}
}
/// Builds an [`ErrorResponse`] for an arbitrary [`std::error::Error`].
///
/// # Arguments
/// * `error` - the error to report. Its `Display` text becomes the response `message`.
/// * `status` - HTTP status to report; `None` falls back to `500 Internal Server Error`.
/// * `include_details` - when `true`, attach [`ErrorDetails`]. Its `error` field holds the
///   error's `Debug` representation and its `timestamp` is the current UTC time in RFC 3339 form.
///
/// # Returns
/// An [`ErrorResponse`] whose `error` label is always `Some("Error")`.
pub fn map_generic_error(
    error: &dyn std::error::Error,
    status: Option<StatusCode>,
    include_details: bool,
) -> ErrorResponse {
    let code = status.unwrap_or(StatusCode::INTERNAL_SERVER_ERROR).as_u16();
    let display_text = error.to_string();

    // Details (and the clock read) are only produced when requested.
    let details = include_details.then(|| ErrorDetails {
        error: format!("{:?}", error),
        component: None,
        timestamp: Some(chrono::Utc::now().to_rfc3339()),
        context: None,
    });

    ErrorResponse {
        status: code,
        message: display_text,
        error: Some(String::from("Error")),
        details,
    }
}
/// Builds an [`ErrorResponse`] from an explicit status, message and optional type label.
///
/// # Arguments
/// * `status` - HTTP status to report.
/// * `message` - human-readable message, stored as the response `message`.
/// * `error_type` - optional category label copied into [`ErrorResponse::error`].
/// * `include_details` - when `true`, attach [`ErrorDetails`] holding a copy of the message
///   and the current UTC time in RFC 3339 form.
pub fn create_custom_error(
    status: StatusCode,
    message: impl Into<String>,
    error_type: Option<String>,
    include_details: bool,
) -> ErrorResponse {
    let message: String = message.into();

    // Copy the message only when the details payload actually needs its own copy;
    // the common no-details path now avoids the extra allocation.
    let details = include_details.then(|| ErrorDetails {
        error: message.clone(),
        component: None,
        timestamp: Some(chrono::Utc::now().to_rfc3339()),
        context: None,
    });

    ErrorResponse {
        status: status.as_u16(),
        message,
        error: error_type,
        details,
    }
}
/// Reports whether the process is configured to run in development mode.
///
/// Reads `RUST_ENV`, falling back to `ENVIRONMENT` only when `RUST_ENV` cannot be read
/// (unset or not valid Unicode). The value is compared case-insensitively against
/// `"development"` and `"dev"`. If neither variable can be read, this returns `false`.
pub fn is_development_mode() -> bool {
    let raw = match std::env::var("RUST_ENV") {
        Ok(value) => value,
        Err(_) => match std::env::var("ENVIRONMENT") {
            Ok(value) => value,
            Err(_) => return false,
        },
    };
    matches!(raw.to_lowercase().as_str(), "development" | "dev")
}