use std::marker::PhantomData;
use axum_core::response::IntoResponse;
use bytes::Bytes;
use futures_core::future::BoxFuture;
use http_body::Body;
use crate::{extension::Extension, Marker, State};
/// A [`tower_layer::Layer`] that wraps inner services in [`Service`].
///
/// Every wrapped service receives its own clone of the database [`State`] held here.
/// `E` is the error type used to build a response when resolving a request's
/// [`Extension`] fails; it exists only at the type level.
pub struct Layer<DB, E>
where
    DB: Marker,
{
    // Handed (cloned) to every `Service` this layer produces.
    state: State<DB>,
    // Zero-sized marker that records the error type; no `E` value is stored.
    _error: PhantomData<E>,
}
impl<DB, E> Layer<DB, E>
where
    DB: Marker,
    E: IntoResponse,
    sqlx::Error: Into<E>,
{
    /// Creates a layer whose wrapped services will each get a clone of `state`.
    ///
    /// The bounds on `E` ensure that a `sqlx::Error` can be converted into `E` and
    /// rendered as an HTTP response.
    pub(crate) fn new(state: State<DB>) -> Self {
        let _error = PhantomData;
        Self { state, _error }
    }
}
impl<DB: Marker, E> Clone for Layer<DB, E> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
_error: self._error,
}
}
}
impl<DB: Marker, S, E> tower_layer::Layer<S> for Layer<DB, E>
where
E: IntoResponse,
sqlx::Error: Into<E>,
{
type Service = Service<DB, S, E>;
fn layer(&self, inner: S) -> Self::Service {
Service {
state: self.state.clone(),
inner,
_error: self._error,
}
}
}
/// Service produced by [`Layer`].
///
/// It wraps an inner service `S` and inserts an [`Extension`], built from the database
/// [`State`], into every request it forwards.
pub struct Service<DB, S, E>
where
    DB: Marker,
{
    // Source of the per-request `Extension`.
    state: State<DB>,
    // The wrapped service that actually handles the request.
    inner: S,
    // Zero-sized marker that records the error type; no `E` value is stored.
    _error: PhantomData<E>,
}
impl<DB: Marker, S: Clone, E> Clone for Service<DB, S, E> {
fn clone(&self) -> Self {
Self {
state: self.state.clone(),
inner: self.inner.clone(),
_error: self._error,
}
}
}
/// Forwards requests to the inner service and resolves the per-request [`Extension`]
/// according to the response status.
///
/// Before forwarding, a fresh [`Extension`] wrapping a clone of this service's [`State`] is
/// inserted into the request's extensions. Once the inner service has responded:
///
/// - If the status is neither a client error (4xx) nor a server error (5xx), the extension
///   is resolved via `Extension::resolve`. If that fails, the `sqlx::Error` is converted
///   into `E` and its [`IntoResponse`] output replaces the inner response.
/// - Otherwise the extension is not resolved here.
///   NOTE(review): presumably it is then discarded on drop; confirm in `Extension`.
///
/// On the non-error path, the inner response body is converted into
/// [`axum_core::body::Body`].
impl<DB: Marker, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
    for Service<DB, S, E>
where
    S: tower_service::Service<
        http::Request<ReqBody>,
        Response = http::Response<ResBody>,
        Error = std::convert::Infallible,
    >,
    S::Future: Send + 'static,
    E: IntoResponse,
    sqlx::Error: Into<E>,
    ResBody: Body<Data = Bytes> + Send + 'static,
    ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
{
    type Response = http::Response<axum_core::body::Body>;
    type Error = S::Error;
    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;

    fn poll_ready(
        &mut self,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        // `Self::Error` is exactly `S::Error`, so readiness is forwarded unchanged.
        self.inner.poll_ready(cx)
    }

    fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
        let ext = Extension::new(self.state.clone());
        req.extensions_mut().insert(ext.clone());
        let res = self.inner.call(req);

        Box::pin(async move {
            // The inner error type is `Infallible`, so the `Err` arm is uninhabited. An
            // empty match proves this to the compiler instead of hiding it behind `unwrap`.
            let res = match res.await {
                Ok(res) => res,
                Err(never) => match never {},
            };

            let status = res.status();
            if !status.is_client_error() && !status.is_server_error() {
                if let Err(error) = ext.resolve().await {
                    return Ok(error.into().into_response());
                }
            }

            Ok(res.map(axum_core::body::Body::new))
        })
    }
}
#[cfg(test)]
mod tests {
use tokio::net::TcpListener;
use crate::{Error, State};
use super::Layer;
// Compile-only check: `layer_compiles` has no `#[test]` attribute and is never run.
// It exists so that the test build fails to type-check if `Layer` stops being
// accepted by `axum::Router::layer`, or if the resulting router stops being
// accepted by `axum::serve`.
// The `todo!()` calls stand in for a `State` and a `TcpListener` that cannot be
// built here. The `allow` attribute silences the unused / unreachable-code /
// diverging-expression lints that these placeholders trigger.
#[allow(unused, unreachable_code, clippy::diverging_sub_expression)]
fn layer_compiles() {
let state: State<sqlx::Sqlite> = todo!();
// Uses the crate's own `Error` as the response error type `E`.
let layer = Layer::<_, Error>::new(state);
let app = axum::Router::new()
.route("/", axum::routing::get(|| async { "hello" }))
.layer(layer);
let listener: TcpListener = todo!();
// The returned future is intentionally not awaited; only the types matter here.
axum::serve(listener, app);
}
}