use std::fmt;
use serde::{Deserialize, Serialize};
/// Stable error category carried by [`WireError`], shared by the JSON (serde)
/// and protobuf encodings.
///
/// Serde serializes each variant as the snake_case string returned by
/// [`WireErrorCode::as_str`]; the protobuf form is [`ProtoWireErrorCode`].
/// Both representations are pinned by this file's tests, so renaming or
/// removing a variant changes the wire format.
#[derive(Serialize, Deserialize, Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum WireErrorCode {
/// Wire string `"not_found"`; proto value 1.
NotFound,
/// Wire string `"namespace_denied"`; proto value 2.
NamespaceDenied,
/// Wire string `"sequence_conflict"`; proto value 3.
SequenceConflict,
/// Wire string `"unknown_query"`; proto value 4.
UnknownQuery,
/// Wire string `"query_timeout"`; proto value 5.
QueryTimeout,
/// Wire string `"not_running"`; proto value 6.
NotRunning,
/// Wire string `"lagged"`; proto value 7.
Lagged,
/// Wire string `"invalid_input"`; proto value 8.
InvalidInput,
/// Wire string `"backend"`; proto value 9. Also produced locally when a
/// protobuf error code cannot be decoded.
Backend,
/// Wire string `"query_failed"`; proto value 10.
QueryFailed,
}
impl WireErrorCode {
#[must_use]
pub const fn as_str(self) -> &'static str {
match self {
Self::NotFound => "not_found",
Self::NamespaceDenied => "namespace_denied",
Self::SequenceConflict => "sequence_conflict",
Self::UnknownQuery => "unknown_query",
Self::QueryTimeout => "query_timeout",
Self::NotRunning => "not_running",
Self::Lagged => "lagged",
Self::InvalidInput => "invalid_input",
Self::Backend => "backend",
Self::QueryFailed => "query_failed",
}
}
}
impl fmt::Display for WireErrorCode {
    /// Writes the snake_case identifier from [`WireErrorCode::as_str`].
    ///
    /// The text is written verbatim; width/fill/alignment flags are ignored.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let code = *self;
        f.write_str(code.as_str())
    }
}
/// Error value exchanged across the wire.
///
/// Converts infallibly into [`ProtoWireError`] and fallibly back via `TryFrom`.
/// `Display` (derived by `thiserror`) renders as `"{code}: {message}"`, where the
/// code is written with [`WireErrorCode::as_str`]. In serde output, `error_type`
/// is omitted when it is `None`.
#[derive(thiserror::Error, Serialize, Deserialize, Clone, Debug, PartialEq, Eq)]
#[error("{code}: {message}")]
pub struct WireError {
/// Category of the error; serialized as its snake_case string.
pub code: WireErrorCode,
/// Free-form description of what went wrong.
pub message: String,
/// Optional producer-supplied type tag; skipped during serialization when `None`.
#[serde(skip_serializing_if = "Option::is_none")]
pub error_type: Option<String>,
}
/// Constructors and builder-style setters for [`WireError`].
impl WireError {
    /// Creates an error with the given `code` and `message` and no `error_type`.
    #[must_use]
    pub fn new(code: WireErrorCode, message: impl Into<String>) -> Self {
        let message = message.into();
        Self {
            code,
            message,
            error_type: None,
        }
    }

    /// Sets `error_type`, replacing any previous value.
    #[must_use]
    pub fn with_error_type(self, error_type: impl Into<String>) -> Self {
        self.with_optional_error_type(Some(error_type.into()))
    }

    /// Replaces `error_type` with the given value; passing `None` clears it.
    #[must_use]
    pub fn with_optional_error_type(self, error_type: Option<String>) -> Self {
        Self { error_type, ..self }
    }

    /// Creates an error carrying both a `code` and an `error_type` tag.
    ///
    /// Note the argument order: `error_type` comes before `message`.
    #[must_use]
    pub fn new_with_type(
        code: WireErrorCode,
        error_type: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        // Convert `message` first to keep the original conversion order.
        let message = message.into();
        Self {
            code,
            message,
            error_type: Some(error_type.into()),
        }
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::NotFound`].
    #[must_use]
    pub fn not_found(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::NotFound, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::NamespaceDenied`].
    #[must_use]
    pub fn namespace_denied(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::NamespaceDenied, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::SequenceConflict`].
    #[must_use]
    pub fn sequence_conflict(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::SequenceConflict, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::UnknownQuery`].
    #[must_use]
    pub fn unknown_query(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::UnknownQuery, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::QueryTimeout`].
    #[must_use]
    pub fn query_timeout(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::QueryTimeout, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::NotRunning`].
    #[must_use]
    pub fn not_running(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::NotRunning, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::Lagged`].
    #[must_use]
    pub fn lagged(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::Lagged, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::InvalidInput`].
    #[must_use]
    pub fn invalid_input(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::InvalidInput, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::Backend`].
    #[must_use]
    pub fn backend(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::Backend, message)
    }

    /// Shorthand for [`WireError::new`] with [`WireErrorCode::QueryFailed`].
    #[must_use]
    pub fn query_failed(message: impl Into<String>) -> Self {
        Self::new(WireErrorCode::QueryFailed, message)
    }

    /// Shorthand for [`WireError::new_with_type`] with [`WireErrorCode::NotFound`].
    #[must_use]
    pub fn not_found_with_type(error_type: impl Into<String>, message: impl Into<String>) -> Self {
        Self::new_with_type(WireErrorCode::NotFound, error_type, message)
    }

    /// Shorthand for [`WireError::new_with_type`] with [`WireErrorCode::NotRunning`].
    #[must_use]
    pub fn not_running_with_type(
        error_type: impl Into<String>,
        message: impl Into<String>,
    ) -> Self {
        Self::new_with_type(WireErrorCode::NotRunning, error_type, message)
    }

    /// Shorthand for [`WireError::new_with_type`] with [`WireErrorCode::Backend`].
    #[must_use]
    pub fn backend_with_type(error_type: impl Into<String>, message: impl Into<String>) -> Self {
        Self::new_with_type(WireErrorCode::Backend, error_type, message)
    }
}
/// Protobuf enumeration mirroring [`WireErrorCode`].
///
/// The discriminants are the encoded values and are pinned by this file's
/// tests; never renumber or reuse them. `Unspecified` has no [`WireErrorCode`]
/// counterpart and is rejected when decoding.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize, prost::Enumeration)]
#[repr(i32)]
pub enum ProtoWireErrorCode {
/// Protobuf zero value; decoding it reports "wire error code is missing".
Unspecified = 0,
NotFound = 1,
NamespaceDenied = 2,
SequenceConflict = 3,
UnknownQuery = 4,
QueryTimeout = 5,
NotRunning = 6,
Lagged = 7,
InvalidInput = 8,
Backend = 9,
QueryFailed = 10,
}
/// Protobuf message form of [`WireError`] (fields 1-3).
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize, prost::Message)]
pub struct ProtoWireError {
/// Raw [`ProtoWireErrorCode`] value. prost keeps enumeration fields as `i32`,
/// so this may hold numbers this build does not know; converting into
/// [`WireError`] rejects those.
#[prost(enumeration = "ProtoWireErrorCode", tag = "1")]
pub code: i32,
/// Message text, copied verbatim to and from `WireError::message`.
#[prost(string, tag = "2")]
pub message: String,
/// Optional type tag; field presence is preserved, so `None` stays `None`.
#[prost(string, optional, tag = "3")]
pub error_type: Option<String>,
}
impl From<WireErrorCode> for ProtoWireErrorCode {
fn from(value: WireErrorCode) -> Self {
match value {
WireErrorCode::NotFound => Self::NotFound,
WireErrorCode::NamespaceDenied => Self::NamespaceDenied,
WireErrorCode::SequenceConflict => Self::SequenceConflict,
WireErrorCode::UnknownQuery => Self::UnknownQuery,
WireErrorCode::QueryTimeout => Self::QueryTimeout,
WireErrorCode::NotRunning => Self::NotRunning,
WireErrorCode::Lagged => Self::Lagged,
WireErrorCode::InvalidInput => Self::InvalidInput,
WireErrorCode::Backend => Self::Backend,
WireErrorCode::QueryFailed => Self::QueryFailed,
}
}
}
impl TryFrom<ProtoWireErrorCode> for WireErrorCode {
type Error = WireError;
fn try_from(value: ProtoWireErrorCode) -> Result<Self, Self::Error> {
match value {
ProtoWireErrorCode::Unspecified => {
Err(WireError::backend("wire error code is missing"))
}
ProtoWireErrorCode::NotFound => Ok(Self::NotFound),
ProtoWireErrorCode::NamespaceDenied => Ok(Self::NamespaceDenied),
ProtoWireErrorCode::SequenceConflict => Ok(Self::SequenceConflict),
ProtoWireErrorCode::UnknownQuery => Ok(Self::UnknownQuery),
ProtoWireErrorCode::QueryTimeout => Ok(Self::QueryTimeout),
ProtoWireErrorCode::NotRunning => Ok(Self::NotRunning),
ProtoWireErrorCode::Lagged => Ok(Self::Lagged),
ProtoWireErrorCode::InvalidInput => Ok(Self::InvalidInput),
ProtoWireErrorCode::Backend => Ok(Self::Backend),
ProtoWireErrorCode::QueryFailed => Ok(Self::QueryFailed),
}
}
}
impl From<WireError> for ProtoWireError {
fn from(value: WireError) -> Self {
let code = ProtoWireErrorCode::from(value.code) as i32;
Self {
code,
message: value.message,
error_type: value.error_type,
}
}
}
impl TryFrom<ProtoWireError> for WireError {
type Error = WireError;
fn try_from(value: ProtoWireError) -> Result<Self, Self::Error> {
let code = ProtoWireErrorCode::try_from(value.code)
.map_err(|_| WireError::backend("wire error code is unknown"))?;
Ok(Self::new(WireErrorCode::try_from(code)?, value.message)
.with_optional_error_type(value.error_type))
}
}
#[cfg(test)]
mod tests {
    use super::{ProtoWireError, ProtoWireErrorCode, WireError, WireErrorCode};

    fn assert_send_sync<T: Send + Sync>() {}

    /// Successor of `code` in declaration order. The match is exhaustive, so
    /// adding a variant without extending this chain is a compile error.
    const fn next_code(code: WireErrorCode) -> Option<WireErrorCode> {
        match code {
            WireErrorCode::NotFound => Some(WireErrorCode::NamespaceDenied),
            WireErrorCode::NamespaceDenied => Some(WireErrorCode::SequenceConflict),
            WireErrorCode::SequenceConflict => Some(WireErrorCode::UnknownQuery),
            WireErrorCode::UnknownQuery => Some(WireErrorCode::QueryTimeout),
            WireErrorCode::QueryTimeout => Some(WireErrorCode::NotRunning),
            WireErrorCode::NotRunning => Some(WireErrorCode::Lagged),
            WireErrorCode::Lagged => Some(WireErrorCode::InvalidInput),
            WireErrorCode::InvalidInput => Some(WireErrorCode::Backend),
            WireErrorCode::Backend => Some(WireErrorCode::QueryFailed),
            WireErrorCode::QueryFailed => None,
        }
    }

    /// Every `WireErrorCode` variant, in declaration order.
    fn all_codes() -> Vec<WireErrorCode> {
        std::iter::successors(Some(WireErrorCode::NotFound), |&code| next_code(code)).collect()
    }

    #[test]
    fn wire_error_is_send_sync() {
        assert_send_sync::<WireError>();
    }

    #[test]
    fn proto_numeric_values_are_pinned() {
        let expected: &[(WireErrorCode, i32)] = &[
            (WireErrorCode::NotFound, 1),
            (WireErrorCode::NamespaceDenied, 2),
            (WireErrorCode::SequenceConflict, 3),
            (WireErrorCode::UnknownQuery, 4),
            (WireErrorCode::QueryTimeout, 5),
            (WireErrorCode::NotRunning, 6),
            (WireErrorCode::Lagged, 7),
            (WireErrorCode::InvalidInput, 8),
            (WireErrorCode::Backend, 9),
            (WireErrorCode::QueryFailed, 10),
        ];
        assert_eq!(
            expected.len(),
            all_codes().len(),
            "every WireErrorCode variant must have a pinned numeric value"
        );
        for &(code, number) in expected {
            assert_eq!(
                ProtoWireErrorCode::from(code) as i32,
                number,
                "{code:?} must keep proto enum value {number}",
            );
        }
    }

    #[test]
    fn string_codes_are_pinned() {
        let expected: &[(WireErrorCode, &str)] = &[
            (WireErrorCode::NotFound, "not_found"),
            (WireErrorCode::NamespaceDenied, "namespace_denied"),
            (WireErrorCode::SequenceConflict, "sequence_conflict"),
            (WireErrorCode::UnknownQuery, "unknown_query"),
            (WireErrorCode::QueryTimeout, "query_timeout"),
            (WireErrorCode::NotRunning, "not_running"),
            (WireErrorCode::Lagged, "lagged"),
            (WireErrorCode::InvalidInput, "invalid_input"),
            (WireErrorCode::Backend, "backend"),
            (WireErrorCode::QueryFailed, "query_failed"),
        ];
        assert_eq!(
            expected.len(),
            all_codes().len(),
            "every WireErrorCode variant must have a pinned string code"
        );
        for &(code, string) in expected {
            assert_eq!(code.as_str(), string, "{code:?} must keep code {string}");
        }
    }

    #[test]
    fn display_matches_as_str_and_error_format() {
        for code in all_codes() {
            assert_eq!(code.to_string(), code.as_str(), "{code:?} must display as as_str()");
        }
        assert_eq!(
            WireError::not_found("workflow was not found").to_string(),
            "not_found: workflow was not found"
        );
    }

    #[test]
    fn json_codes_match_as_str_and_round_trip() -> Result<(), serde_json::Error> {
        for code in all_codes() {
            let serialized = serde_json::to_value(code)?;
            assert_eq!(
                serialized,
                serde_json::Value::String(code.as_str().to_owned()),
                "JSON serialization of {code:?} must equal as_str()",
            );
            let deserialized: WireErrorCode =
                serde_json::from_value(serde_json::Value::String(code.as_str().to_owned()))?;
            assert_eq!(deserialized, code, "{code:?} must round-trip through JSON");
            let error = WireError::new(code, format!("message for {}", code.as_str()));
            let body = serde_json::to_value(&error)?;
            assert_eq!(
                body.get("code"),
                Some(&serde_json::Value::String(code.as_str().to_owned())),
                "WireError JSON body must carry the snake_case code for {code:?}",
            );
            let decoded: WireError = serde_json::from_value(body)?;
            assert_eq!(decoded, error);
        }
        Ok(())
    }

    #[test]
    fn json_omits_absent_error_type() -> Result<(), serde_json::Error> {
        let untyped = serde_json::to_value(WireError::backend("plain"))?;
        assert_eq!(untyped.get("error_type"), None);
        let typed = serde_json::to_value(WireError::backend_with_type("StoreError", "typed"))?;
        assert_eq!(
            typed.get("error_type"),
            Some(&serde_json::Value::String(String::from("StoreError")))
        );
        Ok(())
    }

    #[test]
    fn proto_round_trips_every_code() -> Result<(), WireError> {
        for code in all_codes() {
            let error = WireError::new_with_type(
                code,
                format!("{}Variant", code.as_str()),
                format!("message for {}", code.as_str()),
            );
            let proto = ProtoWireError::from(error.clone());
            let decoded = WireError::try_from(proto)?;
            assert_eq!(decoded, error);
        }
        Ok(())
    }

    #[test]
    fn proto_round_trip_preserves_absent_error_type() -> Result<(), WireError> {
        let error = WireError::backend("untyped failure");
        let proto = ProtoWireError::from(error.clone());
        assert_eq!(proto.error_type, None);
        let decoded = WireError::try_from(proto)?;
        assert_eq!(decoded, error);
        Ok(())
    }

    #[test]
    fn rejects_unspecified_proto_code() {
        let proto = ProtoWireError {
            code: 0,
            message: String::from("missing"),
            error_type: None,
        };
        let result = WireError::try_from(proto);
        assert_eq!(
            result,
            Err(WireError::backend("wire error code is missing"))
        );
    }

    #[test]
    fn rejects_unknown_proto_codes() {
        // Asserts on the code only, so the exact diagnostic text may evolve.
        for raw in [-1, i32::MAX] {
            let proto = ProtoWireError {
                code: raw,
                message: String::from("from a newer peer"),
                error_type: None,
            };
            let result = WireError::try_from(proto).map_err(|error| error.code);
            assert_eq!(result, Err(WireErrorCode::Backend), "code {raw} must be rejected");
        }
    }

    #[test]
    fn representative_documented_mappings_use_stable_codes() {
        let engine_unknown_workflow = WireError::not_found("workflow was not found");
        let store_sequence_conflict = WireError::sequence_conflict("event sequence conflicted");
        assert_eq!(engine_unknown_workflow.code, WireErrorCode::NotFound);
        assert_eq!(
            store_sequence_conflict.code,
            WireErrorCode::SequenceConflict
        );
        assert_eq!(
            WireError::namespace_denied("denied").code,
            WireErrorCode::NamespaceDenied
        );
        assert_eq!(
            WireError::query_timeout("timeout").code,
            WireErrorCode::QueryTimeout
        );
        assert_eq!(
            WireError::unknown_query("unknown").code,
            WireErrorCode::UnknownQuery
        );
        assert_eq!(
            WireError::not_running("terminal").code,
            WireErrorCode::NotRunning
        );
        assert_eq!(
            WireError::invalid_input("malformed").code,
            WireErrorCode::InvalidInput
        );
        assert_eq!(
            WireError::query_failed("handler raised").code,
            WireErrorCode::QueryFailed
        );
    }
}