use std::{
error::Error as StdError,
future::Future,
marker::PhantomData,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use thiserror::Error;
use tower::Service;
use crate::worker::WorkerError;
/// A type-erased error: any `std::error::Error` implementor that is also
/// `Send + Sync + 'static`, stored on the heap.
pub type BoxDynError = Box<dyn StdError + Send + Sync + 'static>;
#[derive(Error, Debug, Clone)]
#[non_exhaustive]
pub enum Error {
#[error("FailedError: {0}")]
Failed(#[source] Arc<BoxDynError>),
#[error("AbortError: {0}")]
Abort(#[source] Arc<BoxDynError>),
#[doc(hidden)]
#[error("WorkerError: {0}")]
WorkerError(WorkerError),
#[error("MissingDataError: {0}")]
MissingData(String),
#[doc(hidden)]
#[error("Encountered an error during service execution")]
ServiceError(#[source] Arc<BoxDynError>),
#[doc(hidden)]
#[error("Encountered an error during streaming")]
SourceError(#[source] Arc<BoxDynError>),
}
impl From<BoxDynError> for Error {
fn from(err: BoxDynError) -> Self {
if let Some(e) = err.downcast_ref::<Error>() {
e.clone()
} else {
Error::Failed(Arc::new(err))
}
}
}
/// Layer that wraps an inner service in an [`ErrorHandlingService`].
///
/// The layer carries no configuration; the zero-sized marker field only keeps
/// the struct from being constructible outside this module.
#[derive(Debug, Clone, Default)]
pub struct ErrorHandlingLayer {
    _p: PhantomData<()>,
}

impl ErrorHandlingLayer {
    /// Creates a new layer; equivalent to [`ErrorHandlingLayer::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
/// Applying the layer moves the inner service into an
/// [`ErrorHandlingService`] without any further setup.
impl<S> tower::layer::Layer<S> for ErrorHandlingLayer {
    type Service = ErrorHandlingService<S>;

    fn layer(&self, inner: S) -> Self::Service {
        ErrorHandlingService { service: inner }
    }
}
/// Service produced by [`ErrorHandlingLayer`]: it forwards every call to the
/// wrapped service and maps that service's errors into [`Error`].
#[derive(Debug, Clone)]
pub struct ErrorHandlingService<S> {
    /// The wrapped inner service.
    service: S,
}
impl<S, Request> Service<Request> for ErrorHandlingService<S>
where
S: Service<Request>,
S::Error: Into<BoxDynError>,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.service.poll_ready(cx).map_err(|e| {
let boxed_error: BoxDynError = e.into();
boxed_error.into()
})
}
fn call(&mut self, req: Request) -> Self::Future {
let fut = self.service.call(req);
Box::pin(async move {
fut.await.map_err(|e| {
let boxed_error: BoxDynError = e.into();
boxed_error.into()
})
})
}
}