use aion_proto::{ProtoWireError, WireError, WireErrorCode};
use prost::Message;
use tonic::Code;
/// Payload carried by every [`ClientError`] variant.
///
/// It pairs a free-form message with an optional type discriminator.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ErrorDetail {
    /// Free-form description of what went wrong.
    pub message: String,
    /// Optional type discriminator. When present, the `Display` impl renders it after the
    /// message as ` [<error_type>]`.
    pub error_type: Option<String>,
}
impl ErrorDetail {
    /// Builds an untyped detail: only `message` is set and `error_type` is `None`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            message,
            error_type: None,
        }
    }

    /// Builds a detail that carries both a message and a type discriminator.
    #[must_use]
    pub fn with_type(message: impl Into<String>, error_type: impl Into<String>) -> Self {
        // Convert `message` first so conversions run in the same order as the arguments.
        let untyped = Self::new(message);
        Self {
            error_type: Some(error_type.into()),
            ..untyped
        }
    }
}
impl std::fmt::Display for ErrorDetail {
    /// Writes the message and, if a type discriminator is present, appends ` [<type>]`.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str(&self.message)?;
        if let Some(kind) = self.error_type.as_deref() {
            write!(out, " [{kind}]")?;
        }
        Ok(())
    }
}
impl From<String> for ErrorDetail {
fn from(message: String) -> Self {
Self::new(message)
}
}
impl From<&str> for ErrorDetail {
fn from(message: &str) -> Self {
Self::new(message)
}
}
impl From<WireError> for ErrorDetail {
fn from(error: WireError) -> Self {
Self {
message: error.message,
error_type: error.error_type,
}
}
}
/// Error returned by client operations.
///
/// Every variant wraps an [`ErrorDetail`]. The `Display` output is `<class>: <detail>`,
/// where `<class>` is the string returned by `ClientError::class`.
#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
pub enum ClientError {
    /// Not-found error (gRPC `NotFound`, wire `NotFound`).
    #[error("not_found: {detail}")]
    NotFound { detail: ErrorDetail },

    /// Idempotency conflict (gRPC `AlreadyExists`).
    #[error("already_exists: {detail}")]
    AlreadyExists { detail: ErrorDetail },

    /// Query-handler failure (wire `QueryFailed`).
    #[error("query_failed: {detail}")]
    QueryFailed { detail: ErrorDetail },

    /// Query timeout (gRPC `DeadlineExceeded`, wire `QueryTimeout`).
    #[error("query_timeout: {detail}")]
    QueryTimeout { detail: ErrorDetail },

    /// Unknown-query error (wire `UnknownQuery`).
    #[error("unknown_query: {detail}")]
    UnknownQuery { detail: ErrorDetail },

    /// Not-running error (gRPC `FailedPrecondition`, wire `NotRunning`).
    #[error("not_running: {detail}")]
    NotRunning { detail: ErrorDetail },

    /// Cancellation (gRPC `Cancelled`).
    #[error("cancelled: {detail}")]
    Cancelled { detail: ErrorDetail },

    /// Unavailability, covering:
    /// - gRPC `Unavailable` and `ResourceExhausted`;
    /// - wire `Lagged`;
    /// - transport failures.
    #[error("unavailable: {detail}")]
    Unavailable { detail: ErrorDetail },

    /// Credential rejection (gRPC `Unauthenticated`).
    #[error("unauthenticated: {detail}")]
    Unauthenticated { detail: ErrorDetail },

    /// Namespace-grant denial (gRPC `PermissionDenied`, wire `NamespaceDenied`).
    #[error("namespace_denied: {detail}")]
    NamespaceDenied { detail: ErrorDetail },

    /// Invalid request input, classed as `invalid_input` (gRPC `InvalidArgument`,
    /// wire `InvalidInput`).
    #[error("invalid_input: {detail}")]
    InvalidArgument { detail: ErrorDetail },

    /// Unexpected server-side or local-conversion failure, classed as `backend`.
    #[error("backend: {detail}")]
    Server { detail: ErrorDetail },
}
// Expands each `(method, Variant, "doc")` tuple into a `#[must_use]` associated fn.
// Each fn takes anything convertible into `ErrorDetail` and wraps it in the struct-like
// variant `Self::Variant { detail }`. A trailing comma after the last tuple is accepted.
macro_rules! detail_constructors {
    ($( ($method:ident, $variant:ident, $summary:literal) ),+ $(,)?) => {
        $(
            #[doc = $summary]
            #[must_use]
            pub fn $method(detail: impl Into<ErrorDetail>) -> Self {
                let detail: ErrorDetail = detail.into();
                Self::$variant { detail }
            }
        )+
    };
}
impl ClientError {
detail_constructors!(
(not_found, NotFound, "Creates a not-found error."),
(
already_exists,
AlreadyExists,
"Creates an idempotency-conflict error."
),
(
query_failed,
QueryFailed,
"Creates a query-handler failure."
),
(query_timeout, QueryTimeout, "Creates a query timeout."),
(
unknown_query,
UnknownQuery,
"Creates an unknown-query error."
),
(not_running, NotRunning, "Creates a not-running error."),
(cancelled, Cancelled, "Creates a cancellation error."),
(
unavailable,
Unavailable,
"Creates a transport-unavailable error."
),
(
unauthenticated,
Unauthenticated,
"Creates a credential-rejection error."
),
(
namespace_denied,
NamespaceDenied,
"Creates a namespace-grant denial."
),
(
invalid_argument,
InvalidArgument,
"Creates an [`ClientError::InvalidArgument`] carrying a precise message."
),
(
server,
Server,
"Creates an unexpected-server-failure error from a local conversion or server detail."
),
);
#[must_use]
pub const fn class(&self) -> &'static str {
match self {
Self::NotFound { .. } => "not_found",
Self::AlreadyExists { .. } => "already_exists",
Self::QueryFailed { .. } => "query_failed",
Self::QueryTimeout { .. } => "query_timeout",
Self::UnknownQuery { .. } => "unknown_query",
Self::NotRunning { .. } => "not_running",
Self::Cancelled { .. } => "cancelled",
Self::Unavailable { .. } => "unavailable",
Self::Unauthenticated { .. } => "unauthenticated",
Self::NamespaceDenied { .. } => "namespace_denied",
Self::InvalidArgument { .. } => "invalid_input",
Self::Server { .. } => "backend",
}
}
#[must_use]
pub const fn detail(&self) -> &ErrorDetail {
match self {
Self::NotFound { detail }
| Self::AlreadyExists { detail }
| Self::QueryFailed { detail }
| Self::QueryTimeout { detail }
| Self::UnknownQuery { detail }
| Self::NotRunning { detail }
| Self::Cancelled { detail }
| Self::Unavailable { detail }
| Self::Unauthenticated { detail }
| Self::NamespaceDenied { detail }
| Self::InvalidArgument { detail }
| Self::Server { detail } => detail,
}
}
#[must_use]
pub fn from_wire_error(error: WireError) -> Self {
let code = error.code;
let detail = ErrorDetail::from(error);
match code {
WireErrorCode::NotFound => Self::NotFound { detail },
WireErrorCode::NamespaceDenied => Self::NamespaceDenied { detail },
WireErrorCode::UnknownQuery => Self::UnknownQuery { detail },
WireErrorCode::NotRunning => Self::NotRunning { detail },
WireErrorCode::InvalidInput => Self::InvalidArgument { detail },
WireErrorCode::SequenceConflict | WireErrorCode::Backend => Self::Server { detail },
WireErrorCode::QueryFailed => Self::QueryFailed { detail },
WireErrorCode::QueryTimeout => Self::QueryTimeout { detail },
WireErrorCode::Lagged => Self::Unavailable { detail },
}
}
#[must_use]
pub fn from_proto_wire_error(error: ProtoWireError) -> Self {
match WireError::try_from(error) {
Ok(error) | Err(error) => Self::from_wire_error(error),
}
}
#[must_use]
pub fn from_status(status: &tonic::Status) -> Self {
if let Some(error) = decode_status_details(status) {
return Self::from_proto_wire_error(error);
}
let detail = ErrorDetail::new(status.message());
match status.code() {
Code::NotFound => Self::NotFound { detail },
Code::AlreadyExists => Self::AlreadyExists { detail },
Code::DeadlineExceeded => Self::QueryTimeout { detail },
Code::Cancelled => Self::Cancelled { detail },
Code::Unavailable | Code::ResourceExhausted => Self::Unavailable { detail },
Code::Unauthenticated => Self::Unauthenticated { detail },
Code::PermissionDenied => Self::NamespaceDenied { detail },
Code::InvalidArgument => Self::InvalidArgument { detail },
Code::FailedPrecondition => Self::NotRunning { detail },
_ => Self::Server { detail },
}
}
#[must_use]
pub fn from_transport_error(error: &tonic::transport::Error) -> Self {
Self::Unavailable {
detail: ErrorDetail::new(source_chain(error)),
}
}
}
/// Renders `error` followed by every error in its `source()` chain, joined by `": "`.
///
/// An error displaying as `outer` whose source displays as `inner` yields
/// `"outer: inner"`. An error with no source yields just its own `Display` text.
fn source_chain(error: &(dyn std::error::Error + 'static)) -> String {
    // `|&link|` copies the inner reference out, so `source()` borrows for the chain's full
    // lifetime rather than for the closure argument's short borrow.
    std::iter::successors(Some(error), |&link| link.source())
        .map(|link| link.to_string())
        .collect::<Vec<_>>()
        .join(": ")
}
/// Decodes a `ProtoWireError` from a status's binary details.
///
/// Returns `None` when the details are empty or fail to decode.
fn decode_status_details(status: &tonic::Status) -> Option<ProtoWireError> {
    match status.details() {
        [] => None,
        bytes => ProtoWireError::decode(bytes).ok(),
    }
}
#[cfg(test)]
mod tests {
    use super::{source_chain, ClientError, ErrorDetail};

    fn assert_send_sync_static<T: Send + Sync + 'static>() {}

    #[test]
    fn client_error_is_send_sync_static() {
        assert_send_sync_static::<ClientError>();
    }

    /// One instance of every variant, in declaration order.
    fn all_variants() -> Vec<ClientError> {
        vec![
            ClientError::not_found("d"),
            ClientError::already_exists("d"),
            ClientError::query_failed("d"),
            ClientError::query_timeout("d"),
            ClientError::unknown_query("d"),
            ClientError::not_running("d"),
            ClientError::cancelled("d"),
            ClientError::unavailable("d"),
            ClientError::unauthenticated("d"),
            ClientError::namespace_denied("d"),
            ClientError::invalid_argument("d"),
            ClientError::server("d"),
        ]
    }

    #[test]
    fn display_is_class_colon_detail_for_every_variant() {
        let mut classes = Vec::new();
        for error in all_variants() {
            assert_eq!(
                error.to_string(),
                format!("{}: d", error.class()),
                "{error:?} Display must be `<class>: <detail>`",
            );
            assert_eq!(error.detail().message, "d");
            classes.push(error.class());
        }
        let expected = [
            "not_found",
            "already_exists",
            "query_failed",
            "query_timeout",
            "unknown_query",
            "not_running",
            "cancelled",
            "unavailable",
            "unauthenticated",
            "namespace_denied",
            "invalid_input",
            "backend",
        ];
        assert_eq!(classes, expected, "class strings are a pinned contract");
    }

    #[test]
    fn detail_display_appends_the_typed_discriminator() {
        assert_eq!(ErrorDetail::new("plain").to_string(), "plain");
        assert_eq!(
            ErrorDetail::with_type("store unavailable", "Durability").to_string(),
            "store unavailable [Durability]"
        );
        assert_eq!(
            ClientError::not_found(ErrorDetail::with_type(
                "workflow was not found",
                "WorkflowNotFound"
            ))
            .to_string(),
            "not_found: workflow was not found [WorkflowNotFound]"
        );
    }

    #[test]
    fn string_conversions_produce_untyped_details() {
        assert_eq!(ErrorDetail::from("borrowed"), ErrorDetail::new("borrowed"));
        assert_eq!(
            ErrorDetail::from(String::from("owned")),
            ErrorDetail::new("owned")
        );
        assert_eq!(ErrorDetail::new("plain").error_type, None);
    }

    /// A status with no details must be classified purely by its gRPC code.
    #[test]
    fn status_without_details_is_classified_by_grpc_code() {
        let cases = [
            (tonic::Code::NotFound, "not_found"),
            (tonic::Code::AlreadyExists, "already_exists"),
            (tonic::Code::DeadlineExceeded, "query_timeout"),
            (tonic::Code::Cancelled, "cancelled"),
            (tonic::Code::Unavailable, "unavailable"),
            (tonic::Code::ResourceExhausted, "unavailable"),
            (tonic::Code::Unauthenticated, "unauthenticated"),
            (tonic::Code::PermissionDenied, "namespace_denied"),
            (tonic::Code::InvalidArgument, "invalid_input"),
            (tonic::Code::FailedPrecondition, "not_running"),
            (tonic::Code::Internal, "backend"),
        ];
        for &(code, expected_class) in &cases {
            let error = ClientError::from_status(&tonic::Status::new(code, "boom"));
            assert_eq!(error.class(), expected_class, "{code:?}");
            assert_eq!(error.detail(), &ErrorDetail::new("boom"));
        }
    }

    #[derive(Debug)]
    struct Leaf;

    impl std::fmt::Display for Leaf {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("leaf")
        }
    }

    impl std::error::Error for Leaf {}

    #[derive(Debug)]
    struct Wrapper(Leaf);

    impl std::fmt::Display for Wrapper {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            f.write_str("wrapper")
        }
    }

    impl std::error::Error for Wrapper {
        fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
            Some(&self.0)
        }
    }

    #[test]
    fn source_chain_joins_every_cause_with_colons() {
        assert_eq!(source_chain(&Leaf), "leaf");
        assert_eq!(source_chain(&Wrapper(Leaf)), "wrapper: leaf");
    }
}