use thiserror::Error;
use tonic::{Code, Status};
use crate::proto::errors::{ErrorCategory, ErrorCode};
/// Result alias whose error type is [`NetError`].
pub type NetResult<T> = Result<T, NetError>;
/// Errors produced by the networking layer.
///
/// Each variant's `Display` text comes from its `#[error(...)]` attribute, with the
/// payload appended after the prefix. Use [`NetError::error_code`] and
/// [`NetError::error_category`] to classify an error. `From` conversions exist in both
/// directions between `NetError` and `tonic::Status`.
#[derive(Debug, Error)]
pub enum NetError {
/// A network operation timed out.
#[error("Network timeout: {0}")]
Timeout(String),
/// The remote end refused the connection.
#[error("Connection refused: {0}")]
ConnectionRefused(String),
/// An established connection was reset.
#[error("Connection reset: {0}")]
ConnectionReset(String),
/// A host name could not be resolved.
#[error("DNS resolution failed: {0}")]
DnsFailure(String),
/// The TLS handshake did not complete.
#[error("TLS handshake failed: {0}")]
TlsHandshake(String),
/// The request was rejected as invalid.
#[error("Invalid request: {0}")]
InvalidRequest(String),
/// The peer uses a protocol version this side does not support.
#[error("Unsupported protocol version: {0}")]
UnsupportedVersion(String),
/// A message could not be decoded / was structurally invalid.
#[error("Malformed message: {0}")]
MalformedMessage(String),
/// A required field was absent; the payload names or describes the field.
#[error("Missing required field: {0}")]
MissingField(String),
/// Authentication was attempted and failed.
#[error("Authentication failed: {0}")]
AuthFailed(String),
/// Previously valid authentication has expired.
#[error("Authentication expired: {0}")]
AuthExpired(String),
/// The caller is authenticated but not allowed to perform the operation.
#[error("Insufficient permissions: {0}")]
InsufficientPermissions(String),
/// A presented certificate was rejected.
#[error("Invalid certificate: {0}")]
InvalidCertificate(String),
/// Any other TLS failure.
#[error("TLS error: {0}")]
TlsError(String),
/// An error from the storage layer (`amaters_core`); `#[from]` lets `?` convert it.
#[error("Storage error: {0}")]
Storage(#[from] amaters_core::error::AmateRSError),
/// An internal failure on the server side.
#[error("Server internal error: {0}")]
ServerInternal(String),
/// The server could not be reached or is not serving requests.
#[error("Server unavailable: {0}")]
ServerUnavailable(String),
/// The request was rejected by the rate limiter; `#[from]` lets `?` convert a
/// `RateLimitError` into this variant.
#[error("Rate limit exceeded: {0}")]
RateLimitExceeded(#[from] crate::rate_limiter::RateLimitError),
/// The server is refusing work because it is overloaded.
#[error("Server overloaded: {0}")]
ServerOverloaded(String),
/// The server is in the process of shutting down.
#[error("Server shutting down: {0}")]
ServerShuttingDown(String),
/// A `tonic` transport-level error; `#[from]` lets `?` convert it.
#[error("gRPC transport error: {0}")]
Transport(#[from] tonic::transport::Error),
/// A gRPC status error carried as a message string.
#[error("gRPC status error: {0}")]
GrpcStatus(String),
/// An error that does not fit any other variant.
#[error("Unknown error: {0}")]
Unknown(String),
}
impl NetError {
pub fn error_code(&self) -> ErrorCode {
match self {
NetError::Timeout(_) => ErrorCode::ErrorNetworkTimeout,
NetError::ConnectionRefused(_) => ErrorCode::ErrorNetworkConnectionRefused,
NetError::ConnectionReset(_) => ErrorCode::ErrorNetworkConnectionReset,
NetError::DnsFailure(_) => ErrorCode::ErrorNetworkDnsFailed,
NetError::TlsHandshake(_) => ErrorCode::ErrorNetworkTlsHandshake,
NetError::InvalidRequest(_) => ErrorCode::ErrorProtocolInvalidRequest,
NetError::UnsupportedVersion(_) => ErrorCode::ErrorProtocolUnsupportedVersion,
NetError::MalformedMessage(_) => ErrorCode::ErrorProtocolMalformedMessage,
NetError::MissingField(_) => ErrorCode::ErrorProtocolMissingField,
NetError::AuthFailed(_) => ErrorCode::ErrorAuthFailed,
NetError::AuthExpired(_) => ErrorCode::ErrorAuthExpired,
NetError::InsufficientPermissions(_) => ErrorCode::ErrorAuthInsufficientPermissions,
NetError::InvalidCertificate(_) | NetError::TlsError(_) => {
ErrorCode::ErrorAuthInvalidCertificate
}
NetError::RateLimitExceeded(_) => ErrorCode::ErrorServerOverloaded,
NetError::Storage(_) => ErrorCode::ErrorStorageIo,
NetError::ServerInternal(_) => ErrorCode::ErrorServerInternal,
NetError::ServerUnavailable(_) => ErrorCode::ErrorServerUnavailable,
NetError::ServerOverloaded(_) => ErrorCode::ErrorServerOverloaded,
NetError::ServerShuttingDown(_) => ErrorCode::ErrorServerShuttingDown,
NetError::Transport(_) | NetError::GrpcStatus(_) => ErrorCode::ErrorNetworkTimeout,
NetError::Unknown(_) => ErrorCode::ErrorUnknown,
}
}
pub fn error_category(&self) -> ErrorCategory {
match self {
NetError::Timeout(_)
| NetError::ConnectionRefused(_)
| NetError::ConnectionReset(_)
| NetError::ServerUnavailable(_)
| NetError::ServerOverloaded(_) => ErrorCategory::CategoryRetryable,
NetError::RateLimitExceeded(_) => ErrorCategory::CategoryRetryable,
NetError::AuthFailed(_) | NetError::AuthExpired(_) => ErrorCategory::CategoryAuth,
NetError::InvalidRequest(_)
| NetError::MalformedMessage(_)
| NetError::MissingField(_)
| NetError::InsufficientPermissions(_) => ErrorCategory::CategoryClientError,
NetError::ServerInternal(_) | NetError::ServerShuttingDown(_) => {
ErrorCategory::CategoryServerError
}
_ => ErrorCategory::CategoryNonRetryable,
}
}
pub fn is_retryable(&self) -> bool {
matches!(self.error_category(), ErrorCategory::CategoryRetryable)
}
}
impl From<NetError> for Status {
fn from(err: NetError) -> Self {
let code = match &err {
NetError::Timeout(_) => Code::DeadlineExceeded,
NetError::ConnectionRefused(_) | NetError::ConnectionReset(_) => Code::Unavailable,
NetError::DnsFailure(_) | NetError::TlsHandshake(_) => Code::Unavailable,
NetError::InvalidRequest(_)
| NetError::MalformedMessage(_)
| NetError::MissingField(_) => Code::InvalidArgument,
NetError::UnsupportedVersion(_) => Code::Unimplemented,
NetError::AuthFailed(_) | NetError::InvalidCertificate(_) | NetError::TlsError(_) => {
Code::Unauthenticated
}
NetError::AuthExpired(_) => Code::Unauthenticated,
NetError::InsufficientPermissions(_) => Code::PermissionDenied,
NetError::RateLimitExceeded(_) => Code::ResourceExhausted,
NetError::Storage(_) => Code::Internal,
NetError::ServerInternal(_) => Code::Internal,
NetError::ServerUnavailable(_) | NetError::ServerOverloaded(_) => Code::Unavailable,
NetError::ServerShuttingDown(_) => Code::Unavailable,
NetError::Transport(_) => Code::Unavailable,
NetError::GrpcStatus(_) => Code::Unknown,
NetError::Unknown(_) => Code::Unknown,
};
Status::new(code, err.to_string())
}
}
impl From<Status> for NetError {
    /// Converts a gRPC [`Status`] back into a [`NetError`].
    ///
    /// Only the status message is kept (as the variant's payload). Codes without a
    /// dedicated mapping become [`NetError::Unknown`].
    fn from(status: Status) -> Self {
        let msg = status.message().to_string();
        match status.code() {
            Code::DeadlineExceeded => NetError::Timeout(msg),
            Code::Unavailable => NetError::ServerUnavailable(msg),
            Code::InvalidArgument => NetError::InvalidRequest(msg),
            Code::Unimplemented => NetError::UnsupportedVersion(msg),
            Code::Unauthenticated => NetError::AuthFailed(msg),
            Code::PermissionDenied => NetError::InsufficientPermissions(msg),
            // `NetError::RateLimitExceeded` is sent as `ResourceExhausted`. A
            // `RateLimitError` cannot be rebuilt from a message string, so use
            // `ServerOverloaded`. It has the same error code (`ErrorServerOverloaded`)
            // and is likewise retryable; previously this fell through to `Unknown`,
            // which made rate-limited calls non-retryable.
            Code::ResourceExhausted => NetError::ServerOverloaded(msg),
            Code::Internal => NetError::ServerInternal(msg),
            _ => NetError::Unknown(msg),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Converts `err` to a `Status` and returns its code.
    fn status_code_of(err: NetError) -> Code {
        Status::from(err).code()
    }

    #[test]
    fn test_error_code_mapping() {
        let err = NetError::Timeout("test".to_string());
        assert_eq!(err.error_code(), ErrorCode::ErrorNetworkTimeout);
        let err = NetError::AuthFailed("test".to_string());
        assert_eq!(err.error_code(), ErrorCode::ErrorAuthFailed);
        let err = NetError::InsufficientPermissions("test".to_string());
        assert_eq!(
            err.error_code(),
            ErrorCode::ErrorAuthInsufficientPermissions
        );
    }

    #[test]
    fn test_error_category() {
        let err = NetError::Timeout("test".to_string());
        assert_eq!(err.error_category(), ErrorCategory::CategoryRetryable);
        assert!(err.is_retryable());
        let err = NetError::InvalidRequest("test".to_string());
        assert_eq!(err.error_category(), ErrorCategory::CategoryClientError);
        assert!(!err.is_retryable());
        let err = NetError::AuthExpired("test".to_string());
        assert_eq!(err.error_category(), ErrorCategory::CategoryAuth);
        assert!(!err.is_retryable());
        let err = NetError::ServerShuttingDown("test".to_string());
        assert_eq!(err.error_category(), ErrorCategory::CategoryServerError);
        let err = NetError::DnsFailure("test".to_string());
        assert_eq!(err.error_category(), ErrorCategory::CategoryNonRetryable);
        assert!(!err.is_retryable());
    }

    #[test]
    fn test_status_conversion() {
        let status: Status = NetError::Timeout("timeout".to_string()).into();
        assert_eq!(status.code(), Code::DeadlineExceeded);
        // The status message carries the full `Display` text, prefix included.
        assert_eq!(status.message(), "Network timeout: timeout");

        assert_eq!(
            status_code_of(NetError::MissingField("id".to_string())),
            Code::InvalidArgument
        );
        assert_eq!(
            status_code_of(NetError::InsufficientPermissions("x".to_string())),
            Code::PermissionDenied
        );
        assert_eq!(
            status_code_of(NetError::ServerShuttingDown("x".to_string())),
            Code::Unavailable
        );
        assert_eq!(
            status_code_of(NetError::ServerInternal("x".to_string())),
            Code::Internal
        );
    }

    #[test]
    fn test_status_from_error() {
        let status: Status = NetError::GrpcStatus("grpc error".to_string()).into();
        assert_eq!(status.code(), Code::Unknown);
    }

    #[test]
    fn test_net_error_from_status() {
        let err = NetError::from(Status::new(Code::DeadlineExceeded, "slow"));
        assert!(matches!(&err, NetError::Timeout(msg) if msg == "slow"));
        let err = NetError::from(Status::new(Code::PermissionDenied, "nope"));
        assert!(matches!(&err, NetError::InsufficientPermissions(msg) if msg == "nope"));
        let err = NetError::from(Status::new(Code::Internal, "boom"));
        assert!(matches!(&err, NetError::ServerInternal(msg) if msg == "boom"));
        // Codes with no dedicated mapping fall back to `Unknown`.
        let err = NetError::from(Status::new(Code::NotFound, "missing"));
        assert!(matches!(&err, NetError::Unknown(msg) if msg == "missing"));
    }
}