use http::StatusCode;
use serde::Serialize;
use std::fmt;
use std::sync::OnceLock;
use uuid::Uuid;
/// Result alias whose error type defaults to [`ApiError`], so `Result<T>`
/// means `Result<T, ApiError>`; another error type may still be given
/// explicitly as the second parameter.
pub type Result<T, E = ApiError> = ::std::result::Result<T, E>;
/// Deployment environment the process runs in.
///
/// `ErrorResponse::from_api_error` uses it to decide whether the messages of
/// 5xx errors are masked before being returned.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Environment {
    /// Default; error messages are passed through unchanged.
    #[default]
    Development,
    /// Messages of 5xx errors are replaced with a generic text.
    Production,
}

impl Environment {
    /// Detects the environment from the `RUSTAPI_ENV` variable.
    ///
    /// `production` or `prod` select [`Environment::Production`]. The match
    /// is case-insensitive and ignores surrounding whitespace. Any other
    /// value, an unset variable, or a value that is not valid Unicode selects
    /// [`Environment::Development`].
    pub fn from_env() -> Self {
        std::env::var("RUSTAPI_ENV")
            .map(|value| Self::from_env_value(&value))
            .unwrap_or_default()
    }

    /// Maps a raw `RUSTAPI_ENV` value onto an [`Environment`].
    ///
    /// The value is trimmed before comparison. A stray space or trailing
    /// newline (e.g. from an env file) must not silently fall back to
    /// Development, because that would switch production error masking off.
    fn from_env_value(value: &str) -> Self {
        match value.trim().to_lowercase().as_str() {
            "production" | "prod" => Environment::Production,
            _ => Environment::Development,
        }
    }

    /// Returns `true` for [`Environment::Production`].
    pub fn is_production(&self) -> bool {
        matches!(self, Environment::Production)
    }

    /// Returns `true` for [`Environment::Development`].
    pub fn is_development(&self) -> bool {
        matches!(self, Environment::Development)
    }
}

impl fmt::Display for Environment {
    /// Writes the lowercase name: `development` or `production`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Environment::Development => write!(f, "development"),
            Environment::Production => write!(f, "production"),
        }
    }
}
/// Lazily initialised, process-wide [`Environment`].
///
/// It is filled on the first call to [`get_environment`], or by
/// `set_environment_for_test` in test builds. It never changes afterwards.
static ENVIRONMENT: OnceLock<Environment> = OnceLock::new();

/// Returns the process-wide environment.
///
/// On first use, `RUSTAPI_ENV` is read via [`Environment::from_env`] and the
/// result is cached for all later calls.
pub fn get_environment() -> Environment {
    let cached: &Environment = ENVIRONMENT.get_or_init(Environment::from_env);
    *cached
}

/// Pins the process-wide environment for tests.
///
/// # Errors
///
/// Returns `Err(env)`, handing back the rejected value, if the environment
/// was already initialised by an earlier call to this function or to
/// [`get_environment`].
#[cfg(test)]
// Not every test build calls this helper.
#[allow(dead_code)]
pub fn set_environment_for_test(env: Environment) -> Result<(), Environment> {
    ENVIRONMENT.set(env)
}
/// Creates a fresh error identifier of the form `err_<32 hex digits>`.
///
/// The suffix is a random (v4) UUID in its hyphen-less "simple" form.
pub fn generate_error_id() -> String {
    let uuid = Uuid::new_v4();
    format!("err_{}", uuid.simple())
}
/// Error value of this module: the default error type of the `Result` alias.
///
/// Convertible into an [`ErrorResponse`]. See `ErrorResponse::from_api_error`
/// for how it is logged and masked.
#[derive(Debug, Clone)]
pub struct ApiError {
/// HTTP status. Its class (4xx / 5xx) also selects the log level and
/// whether production masking applies.
pub status: StatusCode,
/// Machine-readable error identifier (e.g. `not_found`), serialized as
/// `type` in the response body.
pub error_type: String,
/// Human-readable message. It is replaced by a generic text for 5xx errors
/// in production.
pub message: String,
/// Per-field validation failures, set by `ApiError::validation`.
pub fields: Option<Vec<FieldError>>,
/// Diagnostic details. They are logged for 5xx errors but are not part of the
/// serialized response. Outside the crate they can only be set through
/// `ApiError::with_internal`.
pub(crate) internal: Option<String>,
}
/// One field-level validation failure carried by a validation error.
#[derive(Debug, Clone, Serialize)]
pub struct FieldError {
/// Name of the input field that failed validation.
pub field: String,
/// Machine-readable failure code (e.g. `min`, `invalid_format`).
pub code: String,
/// Human-readable description of the failure.
pub message: String,
}
impl ApiError {
    /// Creates an error from an explicit status, machine-readable type and
    /// message. No field errors or internal details are attached.
    pub fn new(
        status: StatusCode,
        error_type: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        let error_type = error_type.into();
        let message = message.into();
        Self {
            status,
            error_type,
            message,
            fields: None,
            internal: None,
        }
    }

    /// 422 Unprocessable Entity (`validation_error`) carrying the given
    /// per-field failures.
    pub fn validation(fields: Vec<FieldError>) -> Self {
        Self {
            fields: Some(fields),
            ..Self::new(
                StatusCode::UNPROCESSABLE_ENTITY,
                "validation_error",
                "Request validation failed",
            )
        }
    }

    /// 400 Bad Request (`bad_request`).
    pub fn bad_request(message: impl Into<String>) -> Self {
        Self::new(StatusCode::BAD_REQUEST, "bad_request", message)
    }

    /// 401 Unauthorized (`unauthorized`).
    pub fn unauthorized(message: impl Into<String>) -> Self {
        Self::new(StatusCode::UNAUTHORIZED, "unauthorized", message)
    }

    /// 403 Forbidden (`forbidden`).
    pub fn forbidden(message: impl Into<String>) -> Self {
        Self::new(StatusCode::FORBIDDEN, "forbidden", message)
    }

    /// 404 Not Found (`not_found`).
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(StatusCode::NOT_FOUND, "not_found", message)
    }

    /// 409 Conflict (`conflict`).
    pub fn conflict(message: impl Into<String>) -> Self {
        Self::new(StatusCode::CONFLICT, "conflict", message)
    }

    /// 500 Internal Server Error (`internal_error`).
    pub fn internal(message: impl Into<String>) -> Self {
        Self::new(StatusCode::INTERNAL_SERVER_ERROR, "internal_error", message)
    }

    /// Attaches diagnostic details, replacing any that were attached before.
    ///
    /// `ErrorResponse::from_api_error` logs these for 5xx errors. They are
    /// never copied into the serialized response.
    pub fn with_internal(self, details: impl Into<String>) -> Self {
        Self {
            internal: Some(details.into()),
            ..self
        }
    }
}
impl fmt::Display for ApiError {
    /// Formats as `<error_type>: <message>`. `status`, `fields` and
    /// `internal` are not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.error_type)?;
        f.write_str(": ")?;
        f.write_str(&self.message)
    }
}

/// Marker impl. The default `source()` is used, so no underlying cause is
/// exposed.
impl std::error::Error for ApiError {}
/// Serializable error payload built from an [`ApiError`].
#[derive(Serialize)]
pub struct ErrorResponse {
/// Error type, message and optional field failures.
pub error: ErrorBody,
/// `err_`-prefixed id generated by `from_api_error`, which also logs it.
pub error_id: String,
/// Presumably the id of the originating request (TODO confirm). It is left
/// `None` by `from_api_error` and omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub request_id: Option<String>,
}
/// The `error` object nested inside an [`ErrorResponse`].
#[derive(Serialize)]
pub struct ErrorBody {
/// Machine-readable error identifier, serialized under the key `type`.
#[serde(rename = "type")]
pub error_type: String,
/// Client-facing message. For 5xx errors in production it is the masked
/// generic text.
pub message: String,
/// Field-level failures; omitted from the output when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub fields: Option<Vec<FieldError>>,
}
impl ErrorResponse {
    /// Builds the response body for `err` under the given environment.
    ///
    /// A new error id is generated, logged together with the error, and
    /// returned in the body. The log level follows the status class:
    /// - 5xx: error, including the `internal` details;
    /// - 4xx: warning;
    /// - anything else: info.
    ///
    /// In [`Environment::Production`], 5xx messages are replaced with
    /// `"An internal error occurred"`. Their field details are dropped unless
    /// the error type is `"validation_error"`. The error type is always passed
    /// through. `request_id` is left as `None`.
    pub fn from_api_error(err: ApiError, env: Environment) -> Self {
        let error_id = generate_error_id();
        if err.status.is_server_error() {
            crate::trace_error!(
                error_id = %error_id,
                error_type = %err.error_type,
                message = %err.message,
                status = %err.status.as_u16(),
                internal = ?err.internal,
                environment = %env,
                "Server error occurred"
            );
        } else if err.status.is_client_error() {
            crate::trace_warn!(
                error_id = %error_id,
                error_type = %err.error_type,
                message = %err.message,
                status = %err.status.as_u16(),
                environment = %env,
                "Client error occurred"
            );
        } else {
            crate::trace_info!(
                error_id = %error_id,
                error_type = %err.error_type,
                message = %err.message,
                status = %err.status.as_u16(),
                environment = %env,
                "Error response generated"
            );
        }

        // Redaction applies only to server errors in production.
        let redact = env.is_production() && err.status.is_server_error();
        // Decided before any field of `err` is moved out below.
        let keep_fields = !redact || err.error_type == "validation_error";

        let message = if redact {
            "An internal error occurred".to_string()
        } else {
            err.message
        };
        let fields = if keep_fields { err.fields } else { None };

        Self {
            error: ErrorBody {
                error_type: err.error_type,
                message,
                fields,
            },
            error_id,
            request_id: None,
        }
    }
}
impl From<ApiError> for ErrorResponse {
    /// Builds the response for the process-wide environment returned by
    /// [`get_environment`]. See [`ErrorResponse::from_api_error`].
    fn from(err: ApiError) -> Self {
        Self::from_api_error(err, get_environment())
    }
}
impl From<serde_json::Error> for ApiError {
    /// Maps any `serde_json` error to a 400 `bad_request`. The message embeds
    /// serde_json's own description.
    ///
    /// NOTE(review): the conversion does not distinguish errors raised while
    /// *serializing* a response, which would also surface as 400. Confirm that
    /// callers only use it for request bodies.
    fn from(err: serde_json::Error) -> Self {
        let message = format!("Invalid JSON: {}", err);
        Self::bad_request(message)
    }
}
impl From<crate::json::JsonError> for ApiError {
    /// Maps a `crate::json` error to a 400 `bad_request` with an
    /// `Invalid JSON: ...` message embedding the error's description.
    fn from(err: crate::json::JsonError) -> Self {
        let message = format!("Invalid JSON: {}", err);
        Self::bad_request(message)
    }
}
impl From<std::io::Error> for ApiError {
fn from(err: std::io::Error) -> Self {
ApiError::internal("I/O error").with_internal(err.to_string())
}
}
impl From<hyper::Error> for ApiError {
    /// Maps a hyper failure to a 500 with the generic message `HTTP error`.
    /// hyper's description is kept only as internal detail.
    fn from(err: hyper::Error) -> Self {
        let details = err.to_string();
        Self::internal("HTTP error").with_internal(details)
    }
}
impl From<rustapi_validate::ValidationError> for ApiError {
    /// Copies each field failure (`field`, `code`, `message`) one-to-one into
    /// a [`FieldError`], preserving order. The list is wrapped in a 422 via
    /// [`ApiError::validation`].
    fn from(err: rustapi_validate::ValidationError) -> Self {
        Self::validation(
            err.fields
                .into_iter()
                .map(|entry| FieldError {
                    field: entry.field,
                    code: entry.code,
                    message: entry.message,
                })
                .collect(),
        )
    }
}
impl From<rustapi_validate::v2::ValidationErrors> for ApiError {
    /// Flattens the per-field error lists into one [`FieldError`] per failure.
    ///
    /// - The field name is repeated for each failure of that field.
    /// - Each message comes from `interpolate_message()`.
    /// - The result is wrapped in a 422 via [`ApiError::validation`].
    fn from(err: rustapi_validate::v2::ValidationErrors) -> Self {
        let mut fields = Vec::new();
        for (field_name, failures) in err.fields {
            for failure in failures {
                // Interpolate before `failure.code` is moved out below.
                let message = failure.interpolate_message();
                fields.push(FieldError {
                    field: field_name.clone(),
                    code: failure.code,
                    message,
                });
            }
        }
        Self::validation(fields)
    }
}
impl ApiError {
    /// Explicitly named equivalent of `ApiError::from(err)` for a
    /// `rustapi_validate::ValidationError`.
    pub fn from_validation_error(err: rustapi_validate::ValidationError) -> Self {
        Self::from(err)
    }

    /// 503 Service Unavailable (`service_unavailable`).
    pub fn service_unavailable(message: impl Into<String>) -> Self {
        Self::new(StatusCode::SERVICE_UNAVAILABLE, "service_unavailable", message)
    }
}
#[cfg(feature = "sqlx")]
impl From<sqlx::Error> for ApiError {
    /// Maps a `sqlx` error onto an HTTP-facing [`ApiError`].
    ///
    /// * `PoolTimedOut`, `PoolClosed`, `Io` and `Tls` become 503
    ///   `service_unavailable`.
    /// * `RowNotFound` becomes 404 `not_found`, with no internal details.
    /// * `Database` errors are classified by their driver code:
    ///   - unique and primary-key violations become 409 `conflict`;
    ///   - foreign-key and check violations become 400 `bad_request`;
    ///   - anything else becomes 500.
    /// * All remaining variants become 500 `internal_error` with a
    ///   variant-specific message.
    ///
    /// Except for `RowNotFound`, the sqlx error text is attached with
    /// [`ApiError::with_internal`].
    fn from(err: sqlx::Error) -> Self {
        match &err {
            sqlx::Error::PoolTimedOut => {
                ApiError::service_unavailable("Database connection pool exhausted")
                    .with_internal(err.to_string())
            }
            sqlx::Error::PoolClosed => {
                ApiError::service_unavailable("Database connection pool is closed")
                    .with_internal(err.to_string())
            }
            sqlx::Error::RowNotFound => ApiError::not_found("Resource not found"),
            sqlx::Error::Database(db_err) => {
                let details = db_err.to_string();
                let code = db_err.code();
                // `code()` uses each driver's own scheme:
                // - PostgreSQL: SQLSTATEs (class 23);
                // - MySQL: error numbers;
                // - SQLite: *extended* result codes.
                // SQLite reports PRIMARY KEY and CHECK failures with their own
                // codes (1555, 275), distinct from UNIQUE (2067).
                // NOTE(review): some sqlx versions report the SQLSTATE for MySQL
                // (e.g. "23000") instead of the error number. In that case the
                // MySQL entries below never match; confirm for the sqlx version
                // in use (`DatabaseError::kind()` is an alternative on newer sqlx).
                match code.as_deref() {
                    // unique_violation / ER_DUP_ENTRY /
                    // SQLITE_CONSTRAINT_UNIQUE / SQLITE_CONSTRAINT_PRIMARYKEY
                    Some("23505" | "1062" | "2067" | "1555") => {
                        ApiError::conflict("Resource already exists").with_internal(details)
                    }
                    // foreign_key_violation / ER_NO_REFERENCED_ROW_2 /
                    // SQLITE_CONSTRAINT_FOREIGNKEY
                    Some("23503" | "1452" | "787") => {
                        ApiError::bad_request("Referenced resource does not exist")
                            .with_internal(details)
                    }
                    // check_violation / SQLITE_CONSTRAINT_CHECK
                    Some("23514" | "275") => {
                        ApiError::bad_request("Data validation failed").with_internal(details)
                    }
                    _ => ApiError::internal("Database error").with_internal(details),
                }
            }
            sqlx::Error::Io(_) => ApiError::service_unavailable("Database connection error")
                .with_internal(err.to_string()),
            sqlx::Error::Tls(_) => {
                ApiError::service_unavailable("Database TLS error").with_internal(err.to_string())
            }
            sqlx::Error::Protocol(_) => {
                ApiError::internal("Database protocol error").with_internal(err.to_string())
            }
            sqlx::Error::TypeNotFound { .. } => {
                ApiError::internal("Database type error").with_internal(err.to_string())
            }
            sqlx::Error::ColumnNotFound(_) => {
                ApiError::internal("Database column not found").with_internal(err.to_string())
            }
            sqlx::Error::ColumnIndexOutOfBounds { .. } => {
                ApiError::internal("Database column index error").with_internal(err.to_string())
            }
            sqlx::Error::ColumnDecode { .. } => {
                ApiError::internal("Database decode error").with_internal(err.to_string())
            }
            sqlx::Error::Configuration(_) => {
                ApiError::internal("Database configuration error").with_internal(err.to_string())
            }
            sqlx::Error::Migrate(_) => {
                ApiError::internal("Database migration error").with_internal(err.to_string())
            }
            _ => ApiError::internal("Database error").with_internal(err.to_string()),
        }
    }
}
// Tests for error-id generation, environment handling and the conversion of
// `ApiError` into `ErrorResponse`, including production masking.
#[cfg(test)]
mod tests {
use super::*;
use proptest::prelude::*;
use std::collections::HashSet;
// Property: ids generated in one batch are pairwise distinct, and each has
// the shape `err_` followed by 32 hex characters.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_error_id_uniqueness(
num_errors in 10usize..200,
) {
let error_ids: Vec<String> = (0..num_errors)
.map(|_| generate_error_id())
.collect();
let unique_ids: HashSet<&String> = error_ids.iter().collect();
prop_assert_eq!(
unique_ids.len(),
error_ids.len(),
"Generated {} error IDs but only {} were unique",
error_ids.len(),
unique_ids.len()
);
for id in &error_ids {
prop_assert!(
id.starts_with("err_"),
"Error ID '{}' does not start with 'err_'",
id
);
// Skip the 4-byte `err_` prefix checked above.
let uuid_part = &id[4..];
prop_assert_eq!(
uuid_part.len(),
32,
"UUID part '{}' should be 32 characters, got {}",
uuid_part,
uuid_part.len()
);
prop_assert!(
uuid_part.chars().all(|c| c.is_ascii_hexdigit()),
"UUID part '{}' contains non-hex characters",
uuid_part
);
}
}
}
// Property: converting through `From<ApiError>` (which reads the process-wide
// environment via `get_environment()`) always attaches a well-formed error id.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_error_response_contains_error_id(
error_type in "[a-z_]{1,20}",
message in "[a-zA-Z0-9 ]{1,100}",
) {
let api_error = ApiError::new(StatusCode::INTERNAL_SERVER_ERROR, error_type, message);
let error_response = ErrorResponse::from(api_error);
prop_assert!(
error_response.error_id.starts_with("err_"),
"Error ID '{}' does not start with 'err_'",
error_response.error_id
);
let uuid_part = &error_response.error_id[4..];
prop_assert_eq!(uuid_part.len(), 32);
prop_assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
}
}
#[test]
fn test_error_id_format() {
let error_id = generate_error_id();
assert!(error_id.starts_with("err_"));
// 4 bytes of `err_` prefix + 32 hex digits.
assert_eq!(error_id.len(), 36);
let uuid_part = &error_id[4..];
assert!(uuid_part.chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn test_error_response_includes_error_id() {
let api_error = ApiError::bad_request("test error");
let error_response = ErrorResponse::from(api_error);
assert!(error_response.error_id.starts_with("err_"));
assert_eq!(error_response.error_id.len(), 36);
}
// The serialized body must expose the id under the `error_id` key.
#[test]
fn test_error_id_in_json_serialization() {
let api_error = ApiError::internal("test error");
let error_response = ErrorResponse::from(api_error);
let json = serde_json::to_string(&error_response).unwrap();
assert!(json.contains("\"error_id\":"));
assert!(json.contains("err_"));
}
#[test]
fn test_multiple_error_ids_are_unique() {
let ids: Vec<String> = (0..1000).map(|_| generate_error_id()).collect();
let unique: HashSet<_> = ids.iter().collect();
assert_eq!(ids.len(), unique.len(), "All error IDs should be unique");
}
// Property: in production, every 5xx message is replaced by the generic text.
// Neither the original message nor the internal details appear anywhere in
// the serialized body.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_production_error_masking(
sensitive_message in "[a-zA-Z0-9_]{10,200}",
internal_details in "[a-zA-Z0-9_]{10,200}",
status_code in prop::sample::select(vec![500u16, 501, 502, 503, 504, 505]),
) {
let api_error = ApiError::new(
StatusCode::from_u16(status_code).unwrap(),
"internal_error",
sensitive_message.clone()
).with_internal(internal_details.clone());
let error_response = ErrorResponse::from_api_error(api_error, Environment::Production);
prop_assert_eq!(
&error_response.error.message,
"An internal error occurred",
"Production 5xx error should have masked message, got: {}",
&error_response.error.message
);
// Always true: the strategy above generates 10..=200 characters.
if sensitive_message.len() >= 10 {
prop_assert!(
!error_response.error.message.contains(&sensitive_message),
"Production error response should not contain original message"
);
}
let json = serde_json::to_string(&error_response).unwrap();
// Always true: the strategy above generates 10..=200 characters.
if internal_details.len() >= 10 {
prop_assert!(
!json.contains(&internal_details),
"Production error response should not contain internal details"
);
}
prop_assert!(
error_response.error_id.starts_with("err_"),
"Error ID should be present in production error response"
);
}
}
// Property: in development, the message and error type are returned unchanged
// for both 4xx and 5xx statuses.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_development_error_details(
error_message in "[a-zA-Z0-9 ]{1,100}",
error_type in "[a-z_]{1,20}",
status_code in prop::sample::select(vec![400u16, 401, 403, 404, 500, 502, 503]),
) {
let api_error = ApiError::new(
StatusCode::from_u16(status_code).unwrap(),
error_type.clone(),
error_message.clone()
);
let error_response = ErrorResponse::from_api_error(api_error, Environment::Development);
prop_assert_eq!(
error_response.error.message,
error_message,
"Development error should preserve original message"
);
prop_assert_eq!(
error_response.error.error_type,
error_type,
"Development error should preserve error type"
);
prop_assert!(
error_response.error_id.starts_with("err_"),
"Error ID should be present in development error response"
);
}
}
// Property: validation errors keep every field detail, in the struct and in
// the JSON, in both environments.
proptest! {
#![proptest_config(ProptestConfig::with_cases(100))]
#[test]
fn prop_validation_error_field_details(
field_name in "[a-z_]{1,20}",
field_code in "[a-z_]{1,15}",
field_message in "[a-zA-Z0-9 ]{1,50}",
is_production in proptest::bool::ANY,
) {
let env = if is_production {
Environment::Production
} else {
Environment::Development
};
let field_error = FieldError {
field: field_name.clone(),
code: field_code.clone(),
message: field_message.clone(),
};
let api_error = ApiError::validation(vec![field_error]);
let error_response = ErrorResponse::from_api_error(api_error, env);
prop_assert!(
error_response.error.fields.is_some(),
"Validation error should always include fields in {} mode",
env
);
let fields = error_response.error.fields.as_ref().unwrap();
prop_assert_eq!(
fields.len(),
1,
"Should have exactly one field error"
);
let field = &fields[0];
prop_assert_eq!(
&field.field,
&field_name,
"Field name should be preserved in {} mode",
env
);
prop_assert_eq!(
&field.code,
&field_code,
"Field code should be preserved in {} mode",
env
);
prop_assert_eq!(
&field.message,
&field_message,
"Field message should be preserved in {} mode",
env
);
let json = serde_json::to_string(&error_response).unwrap();
prop_assert!(
json.contains(&field_name),
"JSON should contain field name in {} mode",
env
);
prop_assert!(
json.contains(&field_code),
"JSON should contain field code in {} mode",
env
);
prop_assert!(
json.contains(&field_message),
"JSON should contain field message in {} mode",
env
);
}
}
// NOTE(review): this test and the next one re-implement the value match from
// `Environment::from_env` inline instead of calling it. They therefore cannot
// detect a regression in `from_env` itself. Calling it would require mutating
// `RUSTAPI_ENV`, which other tests in this module read indirectly through
// `ErrorResponse::from` -> `get_environment()`.
#[test]
fn test_environment_from_env_production() {
assert!(matches!(
match "production".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Production
));
assert!(matches!(
match "prod".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Production
));
assert!(matches!(
match "PRODUCTION".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Production
));
assert!(matches!(
match "PROD".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Production
));
}
// See the NOTE(review) on the previous test: the match is duplicated, not
// exercised via `from_env`.
#[test]
fn test_environment_from_env_development() {
assert!(matches!(
match "development".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Development
));
assert!(matches!(
match "dev".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Development
));
assert!(matches!(
match "test".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Development
));
assert!(matches!(
match "anything_else".to_lowercase().as_str() {
"production" | "prod" => Environment::Production,
_ => Environment::Development,
},
Environment::Development
));
}
#[test]
fn test_environment_default_is_development() {
assert_eq!(Environment::default(), Environment::Development);
}
#[test]
fn test_environment_display() {
assert_eq!(format!("{}", Environment::Development), "development");
assert_eq!(format!("{}", Environment::Production), "production");
}
#[test]
fn test_environment_is_methods() {
assert!(Environment::Production.is_production());
assert!(!Environment::Production.is_development());
assert!(Environment::Development.is_development());
assert!(!Environment::Development.is_production());
}
// A 500 in production must not leak its (here: credential-bearing) message.
#[test]
fn test_production_masks_5xx_errors() {
let error =
ApiError::internal("Sensitive database connection string: postgres://user:pass@host");
let response = ErrorResponse::from_api_error(error, Environment::Production);
assert_eq!(response.error.message, "An internal error occurred");
assert!(!response.error.message.contains("postgres"));
}
// Masking targets 5xx only; 4xx messages pass through even in production.
#[test]
fn test_production_shows_4xx_errors() {
let error = ApiError::bad_request("Invalid email format");
let response = ErrorResponse::from_api_error(error, Environment::Production);
assert_eq!(response.error.message, "Invalid email format");
}
#[test]
fn test_development_shows_all_errors() {
let error = ApiError::internal("Detailed error: connection refused to 192.168.1.1:5432");
let response = ErrorResponse::from_api_error(error, Environment::Development);
assert_eq!(
response.error.message,
"Detailed error: connection refused to 192.168.1.1:5432"
);
}
// `ApiError::validation` produces a 422. Production masking only applies to
// 5xx, so all fields (and their order) survive in both environments.
#[test]
fn test_validation_errors_always_show_fields() {
let fields = vec![
FieldError {
field: "email".to_string(),
code: "invalid_format".to_string(),
message: "Invalid email format".to_string(),
},
FieldError {
field: "age".to_string(),
code: "min".to_string(),
message: "Must be at least 18".to_string(),
},
];
let error = ApiError::validation(fields.clone());
let prod_response = ErrorResponse::from_api_error(error.clone(), Environment::Production);
assert!(prod_response.error.fields.is_some());
let prod_fields = prod_response.error.fields.unwrap();
assert_eq!(prod_fields.len(), 2);
assert_eq!(prod_fields[0].field, "email");
assert_eq!(prod_fields[1].field, "age");
let dev_response = ErrorResponse::from_api_error(error, Environment::Development);
assert!(dev_response.error.fields.is_some());
let dev_fields = dev_response.error.fields.unwrap();
assert_eq!(dev_fields.len(), 2);
}
}