use std::error::Error;
use std::fmt;
use std::panic::Location;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use axum::response::Response;
#[cfg(feature = "opentelemetry")]
use opentelemetry::trace::TraceId;
#[cfg(feature = "rorm")]
use rorm::db::transaction::TransactionError;
use schemars::schema::Schema;
use thiserror::Error;
use tracing::debug;
use tracing::error;
use crate::handler::context::EndpointContext;
use crate::handler::response_body::ResponseBody;
use crate::handler::response_body::ShouldBeResponseBody;
use crate::stuff::api_json::ApiJson;
use crate::stuff::schema::ApiErrorResponse;
/// Shorthand for a `Result` whose error side is a [`CoreApiError`].
pub type CoreApiResult<T> = core::result::Result<T, CoreApiError>;
/// Error type for API handlers.
///
/// `Display` is implemented by hand below (no `#[error(...)]` attribute is used).
/// The type is turned into an HTTP response through its `IntoResponse` impl.
/// That response body only receives the trace id; context, location and source are only logged.
#[derive(Debug, Error)]
pub struct CoreApiError {
/// Error category; mapped to an HTTP status via `ApiErrorStatusCode::to_http`.
pub status_code: ApiErrorStatusCode,
/// Optional static description of what went wrong (shown in `Display` and in logs).
pub context: Option<&'static str>,
/// Source location where the error was created (constructors use `#[track_caller]`).
pub location: &'static Location<'static>,
/// Underlying cause, if any. thiserror treats a field named `source` as the `Error::source`.
pub source: Option<Box<dyn Error + Send + Sync + 'static>>,
/// Trace id captured when the error was created; serialized into the response body.
#[cfg(feature = "opentelemetry")]
pub trace_id: TraceId,
}
/// Coarse category of an API error; determines the HTTP status returned to the client.
///
/// `PartialEq`/`Eq`/`Hash` are derived so callers can compare categories and use them as keys.
#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)]
pub enum ApiErrorStatusCode {
    /// Maps to HTTP 400 Bad Request.
    BadRequest,
    /// Maps to HTTP 500 Internal Server Error.
    ServerError,
    /// Maps to HTTP 401 Unauthorized.
    Unauthorized,
}
impl ApiErrorStatusCode {
    /// Returns the HTTP status code sent to the client for this error category.
    pub fn to_http(&self) -> StatusCode {
        match *self {
            Self::BadRequest => StatusCode::BAD_REQUEST,
            Self::ServerError => StatusCode::INTERNAL_SERVER_ERROR,
            Self::Unauthorized => StatusCode::UNAUTHORIZED,
        }
    }

    /// Iterates over every variant in declaration order.
    ///
    /// Used to list every status an endpoint returning [`CoreApiError`] may respond with.
    pub fn all() -> impl Iterator<Item = Self> {
        const VARIANTS: [ApiErrorStatusCode; 3] = [
            ApiErrorStatusCode::BadRequest,
            ApiErrorStatusCode::ServerError,
            ApiErrorStatusCode::Unauthorized,
        ];
        VARIANTS.into_iter()
    }
}
impl fmt::Display for CoreApiError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self.status_code {
ApiErrorStatusCode::Unauthorized => write!(f, "Unauthorized")?,
ApiErrorStatusCode::BadRequest => write!(f, "Bad Request")?,
ApiErrorStatusCode::ServerError => write!(f, "Server Error")?,
}
if let Some(context) = self.context {
write!(f, " '{context}'")?;
}
if let Some(source) = &self.source {
write!(f, " cause by '{source}'")?;
}
write!(f, " at '{}'", self.location)
}
}
impl CoreApiError {
#[track_caller]
pub fn bad_request(context: &'static str) -> Self {
Self::new(ApiErrorStatusCode::BadRequest, Some(context))
}
#[track_caller]
pub fn server_error(context: &'static str) -> Self {
Self::new(ApiErrorStatusCode::ServerError, Some(context))
}
#[track_caller]
pub fn unauthorized(context: &'static str) -> Self {
Self::new(ApiErrorStatusCode::Unauthorized, Some(context))
}
pub fn with_source(self, source: impl Error + Send + Sync + 'static) -> Self {
self.with_boxed_source(source.into())
}
pub fn with_boxed_source(mut self, source: Box<dyn Error + Send + Sync + 'static>) -> Self {
self.source = Some(source);
self
}
pub fn with_manual_location(mut self, location: &'static Location<'static>) -> Self {
self.location = location;
self
}
#[track_caller]
pub fn map_server_error<E: Error + Send + Sync + 'static>(
context: &'static str,
) -> impl Fn(E) -> Self {
let location = Location::caller();
move |error| {
Self::server_error(context)
.with_source(error)
.with_manual_location(location)
}
}
pub fn emit_tracing_event(&self) {
let Self {
status_code,
context,
location,
source,
#[cfg(feature = "opentelemetry")]
trace_id: _, } = &self;
match status_code {
ApiErrorStatusCode::Unauthorized | ApiErrorStatusCode::BadRequest => {
debug!(
error.status_code = status_code.to_http().as_u16(),
error.status_message = status_code.to_http().as_str(),
error.context = context,
error.file = location.file(),
error.line = location.line(),
error.column = location.column(),
error.display = source.as_ref().map(tracing::field::display),
error.debug = source.as_ref().map(tracing::field::debug),
"Client error"
);
}
ApiErrorStatusCode::ServerError => {
error!(
error.status_code = status_code.to_http().as_u16(),
error.status_message = status_code.to_http().as_str(),
error.context = context,
error.file = location.file(),
error.line = location.line(),
error.column = location.column(),
error.display = source.as_ref().map(tracing::field::display),
error.debug = source.as_ref().map(tracing::field::debug),
"Server error"
);
}
}
}
#[cfg(feature = "opentelemetry")]
pub fn get_trace_id() -> TraceId {
use opentelemetry::trace::TraceContextExt;
use tracing::Span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
Span::current().context().span().span_context().trace_id()
}
#[track_caller]
fn new(status_code: ApiErrorStatusCode, context: Option<&'static str>) -> Self {
Self {
status_code,
context,
location: Location::caller(),
source: None,
#[cfg(feature = "opentelemetry")]
trace_id: Self::get_trace_id(),
}
}
}
impl IntoResponse for CoreApiError {
fn into_response(self) -> Response {
self.emit_tracing_event();
let response = ApiErrorResponse {
#[cfg(feature = "opentelemetry")]
trace_id: self.trace_id.to_string(),
};
(self.status_code.to_http(), ApiJson(response)).into_response()
}
}
// Empty impl: relies on the trait's default items (if any).
// The declared response shapes come from the `ResponseBody` impl that follows.
impl ShouldBeResponseBody for CoreApiError {}
impl ResponseBody for CoreApiError {
    /// Declares one `application/json` response per [`ApiErrorStatusCode`] variant.
    ///
    /// All of these responses share the [`ApiErrorResponse`] schema.
    fn body(ctx: &mut EndpointContext) -> Vec<(StatusCode, Option<(mime::Mime, Option<Schema>)>)> {
        // Generate the schema once; every status code receives its own clone.
        let error_schema = ctx.generator.generate::<ApiErrorResponse>();
        ApiErrorStatusCode::all()
            .map(|kind| kind.to_http())
            .map(|http_status| {
                let content = (mime::APPLICATION_JSON, Some(error_schema.clone()));
                (http_status, Some(content))
            })
            .collect()
    }
}
#[cfg(feature = "rorm")]
impl<'a, E, M> From<rorm::crud::update::UpdateBuilder<'a, E, M, rorm::crud::update::columns::Empty>>
    for CoreApiError
{
    /// Converts an `UpdateBuilder` whose column set is `Empty` into a bad request.
    ///
    /// The error message is "Nothing to update"; the recorded location is the conversion site.
    #[track_caller]
    fn from(
        _builder: rorm::crud::update::UpdateBuilder<
            'a,
            E,
            M,
            rorm::crud::update::columns::Empty,
        >,
    ) -> Self {
        CoreApiError::bad_request("Nothing to update")
    }
}
/// Marker for error types that are converted into a [`CoreApiError`] with status
/// [`ApiErrorStatusCode::ServerError`] and no context, through the blanket `From` impl.
/// The error itself is boxed and stored as the `source`.
trait IntoServerError: Into<Box<dyn Error + Send + Sync + 'static>> {}
impl<E: IntoServerError> From<E> for CoreApiError {
#[track_caller]
fn from(value: E) -> Self {
Self {
status_code: ApiErrorStatusCode::ServerError,
context: None,
location: Location::caller(),
source: Some(value.into()),
#[cfg(feature = "opentelemetry")]
trace_id: Self::get_trace_id(),
}
}
}
// rorm database errors become server errors via the blanket `From<E: IntoServerError>` impl.
#[cfg(feature = "rorm")]
impl IntoServerError for rorm::Error {}
#[cfg(feature = "rorm")]
impl From<TransactionError> for CoreApiError {
    /// Converts a failed transaction into a context-less server error.
    ///
    /// Both variants are unwrapped so the inner error becomes the source: a database error is
    /// boxed, and a hook error is used as-is.
    #[track_caller]
    fn from(value: TransactionError) -> Self {
        let inner: Box<dyn Error + Send + Sync + 'static> = match value {
            TransactionError::Database(db_err) => db_err.into(),
            TransactionError::Hook(hook_err) => hook_err,
        };
        Self::new(ApiErrorStatusCode::ServerError, None).with_boxed_source(inner)
    }
}
// `tower_sessions` session errors (only with the `sessions` feature) are reported as server errors.
#[cfg(feature = "sessions")]
impl IntoServerError for tower_sessions::session::Error {}
// `anyhow::Error` is accepted as a catch-all server error source.
impl IntoServerError for anyhow::Error {}