use std::collections::HashMap;
use std::time::Duration;
use axum::body::Body;
use axum::response::IntoResponse;
use http::header::{self};
use hyper::{Response, StatusCode};
use serde::{Deserialize, Serialize, Serializer};
use thiserror::Error;
/// A boxed, thread-safe (`Send + Sync`) error trait object.
///
/// Accepted by [`TwirpErrorResponse::with_generic_error`], which stores the error's `Debug`
/// rendering on the response.
pub type GenericError = Box<dyn std::error::Error + Send + Sync>;
// Generates, from a table of `(Variant, http_status, snake_case_name);` rows:
//   * the `TwirpErrorCode` enum (one variant per row, in row order),
//   * `TwirpErrorCode::http_status_code` and `TwirpErrorCode::twirp_code`,
//   * `From<StatusCode> for TwirpErrorCode`,
//   * one free constructor function per row, named after the snake_case name.
//
// Attributes (including doc comments) written in front of a row are forwarded to the
// corresponding enum variant.
macro_rules! twirp_error_codes {
    (
        $(
            $(#[$attr:meta])*
            ($variant:ident, $status:expr, $ctor:ident);
        )+
    ) => {
        /// The set of Twirp error codes declared in this module's code table.
        ///
        /// Deserialized from (and, via the manual `Serialize` impl, serialized to) the
        /// snake_case name of the code.
        #[derive(Clone, Copy, Debug, PartialEq, Eq, Deserialize)]
        #[serde(field_identifier, rename_all = "snake_case")]
        #[non_exhaustive]
        pub enum TwirpErrorCode {
            $(
                $(#[$attr])*
                $variant,
            )+
        }

        impl TwirpErrorCode {
            /// Returns the HTTP status associated with this code in the code table.
            pub fn http_status_code(&self) -> StatusCode {
                match self {
                    $(
                        Self::$variant => $status,
                    )+
                }
            }

            /// Returns the snake_case name of this code as a static string.
            pub fn twirp_code(&self) -> &'static str {
                match self {
                    $(
                        Self::$variant => stringify!($ctor),
                    )+
                }
            }
        }

        impl From<StatusCode> for TwirpErrorCode {
            /// Maps an HTTP status to the FIRST table row with that status.
            ///
            /// Several rows may share a status, so table order decides the winner. A status
            /// that appears in no row maps to `Unknown`.
            fn from(code: StatusCode) -> Self {
                $(
                    if code == $status {
                        return Self::$variant;
                    }
                )+
                Self::Unknown
            }
        }

        $(
            /// Builds a [`TwirpErrorResponse`] with this function's error code and the given
            /// message; metadata is empty and no retry hint or Rust error is attached.
            pub fn $ctor<T: ToString>(msg: T) -> TwirpErrorResponse {
                TwirpErrorResponse {
                    code: TwirpErrorCode::$variant,
                    msg: msg.to_string(),
                    meta: HashMap::new(),
                    retry_after: None,
                    rust_error: None,
                }
            }
        )+
    }
}
// Twirp error code table. Each row is `(enum variant, HTTP status, snake_case name)`; the
// snake_case name is both the string returned by `twirp_code()` and the name of the generated
// constructor function.
//
// ROW ORDER MATTERS: `From<StatusCode>` returns the first row whose status matches, and several
// rows share a status. With the order below:
//   400 -> InvalidArgument, 404 -> NotFound, 408 -> Canceled,
//   409 -> AlreadyExists,   500 -> Unknown.
// Reordering rows changes those reverse mappings.
twirp_error_codes! {
(Canceled, StatusCode::REQUEST_TIMEOUT, canceled);
(Unknown, StatusCode::INTERNAL_SERVER_ERROR, unknown);
(InvalidArgument, StatusCode::BAD_REQUEST, invalid_argument);
(Malformed, StatusCode::BAD_REQUEST, malformed);
(DeadlineExceeded, StatusCode::REQUEST_TIMEOUT, deadline_exceeded);
(NotFound, StatusCode::NOT_FOUND, not_found);
(BadRoute, StatusCode::NOT_FOUND, bad_route);
(AlreadyExists, StatusCode::CONFLICT, already_exists);
(PermissionDenied, StatusCode::FORBIDDEN, permission_denied);
(Unauthenticated, StatusCode::UNAUTHORIZED, unauthenticated);
(ResourceExhausted, StatusCode::TOO_MANY_REQUESTS, resource_exhausted);
(FailedPrecondition, StatusCode::PRECONDITION_FAILED, failed_precondition);
(Aborted, StatusCode::CONFLICT, aborted);
(OutOfRange, StatusCode::BAD_REQUEST, out_of_range);
(Unimplemented, StatusCode::NOT_IMPLEMENTED, unimplemented);
(Internal, StatusCode::INTERNAL_SERVER_ERROR, internal);
(Unavailable, StatusCode::SERVICE_UNAVAILABLE, unavailable);
(Dataloss, StatusCode::INTERNAL_SERVER_ERROR, dataloss);
}
impl Serialize for TwirpErrorCode {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(self.twirp_code())
}
}
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Error)]
pub struct TwirpErrorResponse {
pub code: TwirpErrorCode,
pub msg: String,
#[serde(skip_serializing_if = "HashMap::is_empty")]
#[serde(default)]
pub meta: HashMap<String, String>,
#[serde(skip_serializing)]
retry_after: Option<Duration>,
#[serde(skip_serializing)]
rust_error: Option<String>,
}
impl TwirpErrorResponse {
    /// Creates a response with the given code and message, empty metadata, and neither a
    /// retry hint nor a Rust error attached.
    pub fn new(code: TwirpErrorCode, msg: String) -> Self {
        Self {
            code,
            msg,
            meta: HashMap::default(),
            retry_after: None,
            rust_error: None,
        }
    }

    /// Returns the HTTP status that corresponds to this response's error code.
    pub fn http_status_code(&self) -> StatusCode {
        let code = &self.code;
        code.http_status_code()
    }

    /// Gives mutable access to the metadata map.
    pub fn meta_mut(&mut self) -> &mut HashMap<String, String> {
        &mut self.meta
    }

    /// Builder-style: inserts `key -> value` (both converted with `to_string`) into the
    /// metadata, replacing any existing entry for `key`.
    pub fn with_meta<S1: ToString, S2: ToString>(mut self, key: S1, value: S2) -> Self {
        let (name, content) = (key.to_string(), value.to_string());
        self.meta.insert(name, content);
        self
    }

    /// Returns the retry hint, if one was set.
    pub fn retry_after(&self) -> Option<Duration> {
        self.retry_after
    }

    /// Builder-style: attaches the `Debug` rendering of a boxed error.
    pub fn with_generic_error(self, err: GenericError) -> Self {
        let rendered = format!("{:?}", err);
        self.with_rust_error_string(rendered)
    }

    /// Builder-style: attaches the `Debug` rendering of any `std::error::Error`.
    pub fn with_rust_error<E: std::error::Error>(self, err: E) -> Self {
        let rendered = format!("{:?}", err);
        self.with_rust_error_string(rendered)
    }

    /// Builder-style: attaches an already-rendered error string, replacing any previous one.
    pub fn with_rust_error_string(self, rust_error: String) -> Self {
        Self {
            rust_error: Some(rust_error),
            ..self
        }
    }

    /// Returns the attached Rust error text, if any.
    pub fn rust_error(&self) -> Option<&String> {
        self.rust_error.as_ref()
    }

    /// Builder-style: sets (or, with `None`, clears) the retry hint.
    ///
    /// Durations shorter than one second are raised to exactly one second: the hint is
    /// emitted in whole seconds, so anything smaller would otherwise be written as `0`.
    pub fn with_retry_after(mut self, duration: impl Into<Option<Duration>>) -> Self {
        const MINIMUM: Duration = Duration::from_secs(1);
        self.retry_after = duration.into().map(|requested| requested.max(MINIMUM));
        self
    }
}
pub fn internal_server_error<E: std::error::Error>(err: E) -> TwirpErrorResponse {
internal("internal server error").with_rust_error(err)
}
impl From<prost::DecodeError> for TwirpErrorResponse {
    /// Converts a protobuf decode error into an `internal` error: the message is the error's
    /// `Display` text and the error itself is attached as the Rust error.
    fn from(e: prost::DecodeError) -> Self {
        let text = e.to_string();
        internal(text).with_rust_error(e)
    }
}
impl From<serde_json::Error> for TwirpErrorResponse {
    /// Converts a JSON error into an `internal` error: the message is the error's `Display`
    /// text and the error itself is attached as the Rust error.
    fn from(e: serde_json::Error) -> Self {
        let text = e.to_string();
        internal(text).with_rust_error(e)
    }
}
impl From<reqwest::Error> for TwirpErrorResponse {
    /// Converts a `reqwest` error, choosing the Twirp code from the error's kind:
    ///
    /// * builder errors become `invalid_argument`,
    /// * redirect, body and decode errors become `internal`,
    /// * everything else becomes `unavailable`.
    ///
    /// The message is the error's `Display` text and the error is attached as the Rust error.
    fn from(e: reqwest::Error) -> Self {
        let text = e.to_string();
        if e.is_builder() {
            return invalid_argument(text).with_rust_error(e);
        }

        let is_internal = e.is_redirect() || e.is_body() || e.is_decode();
        let response = if is_internal {
            internal(text)
        } else {
            unavailable(text)
        };
        response.with_rust_error(e)
    }
}
impl From<url::ParseError> for TwirpErrorResponse {
    /// Converts a URL parse error into an `invalid_argument` error: the message is the
    /// error's `Display` text and the error itself is attached as the Rust error.
    fn from(e: url::ParseError) -> Self {
        let text = e.to_string();
        invalid_argument(text).with_rust_error(e)
    }
}
impl From<header::InvalidHeaderValue> for TwirpErrorResponse {
    /// Converts an invalid-header-value error into an `invalid_argument` error whose message
    /// is the error's `Display` text.
    ///
    /// Unlike the other conversions in this file, no Rust error is attached.
    fn from(e: header::InvalidHeaderValue) -> Self {
        let text = e.to_string();
        invalid_argument(text)
    }
}
impl From<anyhow::Error> for TwirpErrorResponse {
fn from(err: anyhow::Error) -> Self {
internal("internal server error").with_rust_error_string(format!("{err:#}"))
}
}
impl IntoResponse for TwirpErrorResponse {
    /// Turns the error into an HTTP response.
    ///
    /// * Status: the HTTP status of the error code.
    /// * Headers: JSON content type, plus `Retry-After` (whole seconds) when a retry hint is set.
    /// * Body: the JSON serialization of the error (`code`, `msg`, non-empty `meta`).
    /// * Extensions: the full `TwirpErrorResponse`, including the fields that are not
    ///   serialized, so downstream layers can inspect it.
    ///
    /// # Panics
    ///
    /// Panics if JSON serialization or response construction fails; neither is expected.
    fn into_response(self) -> Response<Body> {
        // Read everything that is derived from `self` up front. That way `self` can be MOVED
        // into the response extensions at the end instead of being deep-cloned (message
        // string + metadata map) on every error response.
        let status = self.http_status_code();
        let retry_after = self.retry_after;
        let json = serde_json::to_string(&self)
            .expect("json serialization of a TwirpErrorResponse should not fail");

        let mut resp = Response::builder()
            .status(status)
            .header(header::CONTENT_TYPE, crate::headers::CONTENT_TYPE_JSON);
        if let Some(duration) = retry_after {
            resp = resp.header(header::RETRY_AFTER, duration.as_secs().to_string());
        }

        resp.extension(self)
            .body(Body::new(json))
            .expect("failed to build TwirpErrorResponse")
    }
}
impl std::fmt::Display for TwirpErrorResponse {
    /// Formats the error as `error <Code>: <msg>`, followed by optional parenthesised
    /// sections for metadata, the retry hint and the Rust error (each only when present).
    ///
    /// Metadata entries are written in ascending key order so that the same error always
    /// renders to the same string; `HashMap` iteration order alone is not stable between
    /// runs, which makes logs and string comparisons flaky.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "error {:?}: {}", self.code, self.msg)?;

        if !self.meta.is_empty() {
            // Keys are unique, so ordering by key alone is a total order.
            let mut entries: Vec<(&String, &String)> = self.meta.iter().collect();
            entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

            f.write_str(" (meta: {")?;
            for (index, (key, value)) in entries.into_iter().enumerate() {
                if index > 0 {
                    f.write_str(", ")?;
                }
                write!(f, "{key:?}: {value:?}")?;
            }
            f.write_str("})")?;
        }

        if let Some(retry_after) = &self.retry_after {
            write!(f, " (retry_after: {retry_after:?})")?;
        }
        if let Some(rust_error) = &self.rust_error {
            write!(f, " (rust_error: {rust_error:?})")?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use std::collections::HashMap;

    use crate::{TwirpErrorCode, TwirpErrorResponse};

    /// Every row of the error-code table maps to the expected wire name and HTTP status.
    #[test]
    fn twirp_status_mapping() {
        assert_code(TwirpErrorCode::Canceled, "canceled", 408);
        assert_code(TwirpErrorCode::Unknown, "unknown", 500);
        assert_code(TwirpErrorCode::InvalidArgument, "invalid_argument", 400);
        assert_code(TwirpErrorCode::Malformed, "malformed", 400);
        assert_code(TwirpErrorCode::Unauthenticated, "unauthenticated", 401);
        assert_code(TwirpErrorCode::PermissionDenied, "permission_denied", 403);
        assert_code(TwirpErrorCode::DeadlineExceeded, "deadline_exceeded", 408);
        assert_code(TwirpErrorCode::NotFound, "not_found", 404);
        assert_code(TwirpErrorCode::BadRoute, "bad_route", 404);
        assert_code(TwirpErrorCode::Unimplemented, "unimplemented", 501);
        assert_code(TwirpErrorCode::Internal, "internal", 500);
        assert_code(TwirpErrorCode::Unavailable, "unavailable", 503);
        // Rows that were previously not covered by this test.
        assert_code(TwirpErrorCode::AlreadyExists, "already_exists", 409);
        assert_code(TwirpErrorCode::ResourceExhausted, "resource_exhausted", 429);
        assert_code(TwirpErrorCode::FailedPrecondition, "failed_precondition", 412);
        assert_code(TwirpErrorCode::Aborted, "aborted", 409);
        assert_code(TwirpErrorCode::OutOfRange, "out_of_range", 400);
        assert_code(TwirpErrorCode::Dataloss, "dataloss", 500);
    }

    /// Asserts that `code` reports HTTP status `http` and wire name `msg`.
    fn assert_code(code: TwirpErrorCode, msg: &str, http: u16) {
        assert_eq!(
            code.http_status_code(),
            http,
            "expected http status code {} but got {}",
            http,
            code.http_status_code()
        );
        assert_eq!(
            code.twirp_code(),
            msg,
            "expected error message '{}' but got '{}'",
            msg,
            code.twirp_code()
        );
    }

    /// `From<StatusCode>` picks the first table row for a shared status and falls back to
    /// `Unknown` for statuses that are not in the table.
    #[test]
    fn status_code_maps_to_first_matching_twirp_code() {
        use hyper::StatusCode;

        assert_eq!(
            TwirpErrorCode::from(StatusCode::BAD_REQUEST),
            TwirpErrorCode::InvalidArgument
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::NOT_FOUND),
            TwirpErrorCode::NotFound
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::REQUEST_TIMEOUT),
            TwirpErrorCode::Canceled
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::CONFLICT),
            TwirpErrorCode::AlreadyExists
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::INTERNAL_SERVER_ERROR),
            TwirpErrorCode::Unknown
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::SERVICE_UNAVAILABLE),
            TwirpErrorCode::Unavailable
        );
        assert_eq!(
            TwirpErrorCode::from(StatusCode::IM_A_TEAPOT),
            TwirpErrorCode::Unknown
        );
    }

    /// Code, message and metadata survive a JSON round trip.
    #[test]
    fn twirp_error_response_serialization() {
        let meta = HashMap::from([
            ("key1".to_string(), "value1".to_string()),
            ("key2".to_string(), "value2".to_string()),
        ]);
        let response = TwirpErrorResponse {
            code: TwirpErrorCode::DeadlineExceeded,
            msg: "test".to_string(),
            meta,
            rust_error: None,
            retry_after: None,
        };

        let result = serde_json::to_string(&response).unwrap();
        assert!(result.contains(r#""code":"deadline_exceeded""#));
        assert!(result.contains(r#""msg":"test""#));
        assert!(result.contains(r#""key1":"value1""#));
        assert!(result.contains(r#""key2":"value2""#));

        let result = serde_json::from_str(&result).unwrap();
        assert_eq!(response, result);
    }

    /// A request that times out converts to `Unavailable`.
    #[tokio::test]
    async fn reqwest_timeout_error_maps_to_unavailable() {
        let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap();
        let addr = listener.local_addr().unwrap();
        // The server side accepts the connection and then never answers, so the client's
        // 1 ms timeout is what ends the request. The thread is detached, not joined.
        let _accept_thread = std::thread::spawn(move || {
            let (_stream, _) = listener.accept().unwrap();
            std::thread::sleep(std::time::Duration::from_secs(60));
        });

        let client = reqwest::Client::builder()
            .timeout(std::time::Duration::from_millis(1))
            .build()
            .unwrap();
        let err = client
            .get(format!("http://{addr}"))
            .send()
            .await
            .unwrap_err();

        let twirp_err: TwirpErrorResponse = err.into();
        assert_eq!(twirp_err.code, TwirpErrorCode::Unavailable);
    }

    /// A request that cannot even be built (empty URL) converts to `InvalidArgument`.
    #[test]
    fn reqwest_builder_error_maps_to_invalid_argument() {
        let err = reqwest::Client::builder()
            .build()
            .unwrap()
            .get("")
            .build()
            .unwrap_err();

        let twirp_err: TwirpErrorResponse = err.into();
        assert_eq!(twirp_err.code, TwirpErrorCode::InvalidArgument);
    }

    /// The attached Rust error never appears in the serialized JSON.
    #[test]
    fn twirp_error_response_serialization_skips_fields() {
        let response = TwirpErrorResponse {
            code: TwirpErrorCode::Unauthenticated,
            msg: "test".to_string(),
            meta: HashMap::new(),
            rust_error: Some("not included".to_string()),
            retry_after: None,
        };

        let result = serde_json::to_string(&response).unwrap();
        assert!(result.contains(r#""code":"unauthenticated""#));
        assert!(result.contains(r#""msg":"test""#));
        assert!(!result.contains(r#"rust_error"#));
    }
}