use std::fmt;
/// Transport-agnostic error category.
///
/// Each variant maps to an HTTP status ([`ErrorCode::http_status`]), a CLI
/// process exit code ([`ErrorCode::exit_code`]) and a gRPC status name
/// ([`ErrorCode::grpc_code`]).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorCode {
    /// Malformed or invalid input (HTTP 400, gRPC `INVALID_ARGUMENT`).
    InvalidInput,
    /// Caller is not authenticated (HTTP 401, gRPC `UNAUTHENTICATED`).
    Unauthenticated,
    /// Caller lacks permission (HTTP 403, gRPC `PERMISSION_DENIED`).
    Forbidden,
    /// Requested entity was not found (HTTP 404, gRPC `NOT_FOUND`).
    NotFound,
    /// Conflicting request, e.g. a duplicate entity (HTTP 409, gRPC `ALREADY_EXISTS`).
    Conflict,
    /// The operation's preconditions were not met (HTTP 422, gRPC `FAILED_PRECONDITION`).
    FailedPrecondition,
    /// Request was throttled (HTTP 429, gRPC `RESOURCE_EXHAUSTED`).
    RateLimited,
    /// Unexpected failure; also the fallback category (HTTP 500, gRPC `INTERNAL`).
    Internal,
    /// Operation is not implemented (HTTP 501, gRPC `UNIMPLEMENTED`).
    NotImplemented,
    /// Service is unavailable (HTTP 503, gRPC `UNAVAILABLE`).
    Unavailable,
}

impl ErrorCode {
    /// Returns the HTTP status code for this category.
    pub fn http_status(&self) -> u16 {
        match self {
            ErrorCode::InvalidInput => 400,
            ErrorCode::Unauthenticated => 401,
            ErrorCode::Forbidden => 403,
            ErrorCode::NotFound => 404,
            ErrorCode::Conflict => 409,
            ErrorCode::FailedPrecondition => 422,
            ErrorCode::RateLimited => 429,
            ErrorCode::Internal => 500,
            ErrorCode::NotImplemented => 501,
            ErrorCode::Unavailable => 503,
        }
    }

    /// Returns the process exit code a CLI should use for this category.
    ///
    /// Codes are shared between categories: `3` covers both auth failures,
    /// `4` covers conflicts and failed preconditions, and `1` covers
    /// `NotFound` together with the server-side failures.
    pub fn exit_code(&self) -> i32 {
        match self {
            // NOTE(review): NotFound shares exit code 1 with the server-side
            // failures below; confirm this is intended.
            ErrorCode::NotFound
            | ErrorCode::Internal
            | ErrorCode::Unavailable
            | ErrorCode::NotImplemented => 1,
            ErrorCode::InvalidInput => 2,
            ErrorCode::Unauthenticated | ErrorCode::Forbidden => 3,
            ErrorCode::Conflict | ErrorCode::FailedPrecondition => 4,
            ErrorCode::RateLimited => 5,
        }
    }

    /// Returns the canonical gRPC status code name, e.g. `"NOT_FOUND"`.
    pub fn grpc_code(&self) -> &'static str {
        match self {
            ErrorCode::InvalidInput => "INVALID_ARGUMENT",
            ErrorCode::Unauthenticated => "UNAUTHENTICATED",
            ErrorCode::Forbidden => "PERMISSION_DENIED",
            ErrorCode::NotFound => "NOT_FOUND",
            ErrorCode::Conflict => "ALREADY_EXISTS",
            ErrorCode::FailedPrecondition => "FAILED_PRECONDITION",
            ErrorCode::RateLimited => "RESOURCE_EXHAUSTED",
            ErrorCode::Internal => "INTERNAL",
            ErrorCode::NotImplemented => "UNIMPLEMENTED",
            ErrorCode::Unavailable => "UNAVAILABLE",
        }
    }

    /// Best-effort guess of a category from an error type or variant name
    /// such as `"UserNotFound"`.
    ///
    /// The name is lower-cased and checked for substrings against keyword
    /// groups in order; the first group with a matching keyword wins.
    /// Names matching no group yield [`ErrorCode::Internal`]. Every variant's
    /// own `Debug` name maps back to that variant.
    pub fn infer_from_name(name: &str) -> Self {
        // Order matters: e.g. "InvalidPermission" is InvalidInput, not Forbidden.
        const RULES: &[(&[&str], ErrorCode)] = &[
            (&["notfound", "not_found", "missing"], ErrorCode::NotFound),
            (&["invalid", "validation", "parse"], ErrorCode::InvalidInput),
            (&["unauthorized", "unauthenticated"], ErrorCode::Unauthenticated),
            (&["forbidden", "permission", "denied"], ErrorCode::Forbidden),
            (&["conflict", "exists", "duplicate"], ErrorCode::Conflict),
            (&["ratelimit", "rate_limit", "throttle"], ErrorCode::RateLimited),
            (&["unavailable", "temporarily"], ErrorCode::Unavailable),
            // "notimplemented" is needed for the `NotImplemented` variant name itself.
            (
                &["unimplemented", "not_implemented", "notimplemented"],
                ErrorCode::NotImplemented,
            ),
            // Kept last so it never overrides a group that matched previously.
            (&["precondition"], ErrorCode::FailedPrecondition),
        ];

        let name_lower = name.to_lowercase();
        RULES
            .iter()
            .find(|(keywords, _)| keywords.iter().any(|kw| name_lower.contains(*kw)))
            .map_or(ErrorCode::Internal, |&(_, code)| code)
    }
}
/// Classifies an error into the crate-wide [`ErrorCode`] taxonomy and
/// provides a human-readable message for it.
pub trait IntoErrorCode {
    /// The [`ErrorCode`] category this error belongs to.
    fn error_code(&self) -> ErrorCode;
    /// A human-readable description of the error.
    fn message(&self) -> String;
}

impl IntoErrorCode for std::io::Error {
    /// Classifies by [`std::io::ErrorKind`]; kinds without a clear
    /// client-facing meaning fall back to [`ErrorCode::Internal`].
    fn error_code(&self) -> ErrorCode {
        use std::io::ErrorKind;
        match self.kind() {
            ErrorKind::NotFound => ErrorCode::NotFound,
            ErrorKind::PermissionDenied => ErrorCode::Forbidden,
            ErrorKind::InvalidInput | ErrorKind::InvalidData => ErrorCode::InvalidInput,
            // An entity that already exists is a conflict (409 / ALREADY_EXISTS),
            // not an internal server failure.
            ErrorKind::AlreadyExists => ErrorCode::Conflict,
            // `ErrorKind` is non-exhaustive, so a catch-all is required.
            _ => ErrorCode::Internal,
        }
    }

    /// Returns the I/O error's `Display` text.
    fn message(&self) -> String {
        self.to_string()
    }
}
/// Serializable error payload.
///
/// `code` is the [`ErrorCode`]'s `Debug` name upper-cased (for example
/// `ErrorCode::NotFound` becomes `"NOTFOUND"`). `details` is left out of the
/// serialized output when it is `None`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ErrorResponse {
    /// Upper-cased `Debug` name of the originating [`ErrorCode`].
    pub code: String,
    /// Human-readable error description.
    pub message: String,
    /// Optional structured JSON context; skipped when serializing if `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub details: Option<serde_json::Value>,
}

impl ErrorResponse {
    /// Builds a response for `code` carrying `message` and no details.
    pub fn new(code: ErrorCode, message: impl Into<String>) -> Self {
        // NOTE(review): the wire code has no word separators ("INVALIDINPUT"),
        // unlike `ErrorCode::grpc_code`; changing it would alter the public format.
        let mut code_name = format!("{:?}", code);
        code_name.make_ascii_uppercase();
        ErrorResponse {
            code: code_name,
            message: message.into(),
            details: None,
        }
    }

    /// Returns the response with `details` attached, replacing any previous value.
    pub fn with_details(self, details: serde_json::Value) -> Self {
        ErrorResponse {
            details: Some(details),
            ..self
        }
    }
}

impl fmt::Display for ErrorResponse {
    /// Renders as `<code>: <message>`; `details` is not included.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let ErrorResponse { code, message, .. } = self;
        write!(f, "{}: {}", code, message)
    }
}

impl std::error::Error for ErrorResponse {}
/// Differences found when comparing an expected schema against a generated one.
#[derive(Debug, Clone)]
pub struct SchemaValidationError {
    /// Label for the schema kind; its lower-cased form names the
    /// `write_<schema>()` hint in the rendered message.
    pub schema_type: String,
    /// Entries in the expected schema that the generated one lacks.
    pub missing_lines: Vec<String>,
    /// Entries in the generated schema that the expected one lacks.
    pub extra_lines: Vec<String>,
}

impl SchemaValidationError {
    /// Creates an error for `schema_type` with no recorded differences.
    pub fn new(schema_type: impl Into<String>) -> Self {
        let schema_type = schema_type.into();
        SchemaValidationError {
            schema_type,
            missing_lines: vec![],
            extra_lines: vec![],
        }
    }

    /// Records an expected entry that is absent from the generated schema.
    pub fn add_missing(&mut self, line: impl Into<String>) {
        let entry = line.into();
        self.missing_lines.push(entry);
    }

    /// Records a generated entry that is absent from the expected schema.
    pub fn add_extra(&mut self, line: impl Into<String>) {
        let entry = line.into();
        self.extra_lines.push(entry);
    }

    /// Returns `true` once any missing or extra entry has been recorded.
    pub fn has_differences(&self) -> bool {
        !(self.missing_lines.is_empty() && self.extra_lines.is_empty())
    }
}

impl fmt::Display for SchemaValidationError {
    /// Renders a multi-line report: a header, the missing (`-`) and extra (`+`)
    /// entries, then hints chosen by which lists are non-empty.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let has_missing = !self.missing_lines.is_empty();
        let has_extra = !self.extra_lines.is_empty();

        writeln!(f, "{} schema validation failed:", self.schema_type)?;

        if has_missing {
            writeln!(f, "\nExpected methods/messages not found in generated:")?;
            self.missing_lines
                .iter()
                .try_for_each(|line| writeln!(f, " - {}", line))?;
        }
        if has_extra {
            writeln!(f, "\nGenerated methods/messages not in expected:")?;
            self.extra_lines
                .iter()
                .try_for_each(|line| writeln!(f, " + {}", line))?;
        }

        writeln!(f)?;
        writeln!(f, "Hints:")?;
        // Both lists non-empty usually means an entry changed rather than vanished.
        if has_missing && has_extra {
            writeln!(
                f,
                " - Method signature or type may have changed. Check parameter names and types."
            )?;
        }
        if has_missing {
            writeln!(
                f,
                " - Missing items may indicate removed or renamed methods in Rust code."
            )?;
        }
        if has_extra {
            writeln!(
                f,
                " - Extra items may indicate new methods added. Update the schema file."
            )?;
        }
        writeln!(
            f,
            " - Run `write_{schema}()` to regenerate the schema file.",
            schema = self.schema_type.to_lowercase()
        )
    }
}

impl std::error::Error for SchemaValidationError {}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spot-checks keyword-based inference with one or two names per category.
    #[test]
    fn test_error_code_inference() {
        let cases = [
            ("NotFound", ErrorCode::NotFound),
            ("UserNotFound", ErrorCode::NotFound),
            ("InvalidEmail", ErrorCode::InvalidInput),
            ("Forbidden", ErrorCode::Forbidden),
            ("AlreadyExists", ErrorCode::Conflict),
        ];
        for &(name, expected) in cases.iter() {
            assert_eq!(
                ErrorCode::infer_from_name(name),
                expected,
                "inferring from {:?}",
                name
            );
        }
    }

    /// Verifies a few representative HTTP status mappings.
    #[test]
    fn test_http_status_codes() {
        let cases = [
            (ErrorCode::NotFound, 404),
            (ErrorCode::InvalidInput, 400),
            (ErrorCode::Internal, 500),
        ];
        for &(code, status) in cases.iter() {
            assert_eq!(code.http_status(), status, "status for {:?}", code);
        }
    }
}