1use std::marker::PhantomData;
4
5use axum_core::response::IntoResponse;
6use bytes::Bytes;
7use futures_core::future::BoxFuture;
8use http_body::Body;
9
10use crate::{extension::Extension, Marker, State};
11
/// A tower layer that gives every request its own `Extension` built from a shared
/// `State`, and turns a failure to resolve that extension into an `E` response.
///
/// `E` is a marker only (held in `PhantomData`); it names the type that errors are
/// converted into (via `Into<E>`) before being rendered with `IntoResponse`.
pub struct Layer<DB: Marker, E> {
    /// Shared state; a clone is handed to every `Service` this layer produces.
    state: State<DB>,
    /// Marker for the error type `E`; no value of `E` is stored.
    _error: PhantomData<E>,
}
27
28impl<DB: Marker, E> Layer<DB, E>
29where
30 E: IntoResponse,
31 sqlx::Error: Into<E>,
32{
33 pub(crate) fn new(state: State<DB>) -> Self {
34 Self {
35 state,
36 _error: PhantomData,
37 }
38 }
39}
40
41impl<DB: Marker, E> Clone for Layer<DB, E> {
42 fn clone(&self) -> Self {
43 Self {
44 state: self.state.clone(),
45 _error: self._error,
46 }
47 }
48}
49
50impl<DB: Marker, S, E> tower_layer::Layer<S> for Layer<DB, E>
51where
52 E: IntoResponse,
53 sqlx::Error: Into<E>,
54{
55 type Service = Service<DB, S, E>;
56
57 fn layer(&self, inner: S) -> Self::Service {
58 Service {
59 state: self.state.clone(),
60 inner,
61 _error: self._error,
62 }
63 }
64}
65
/// The service produced by [`Layer`].
///
/// It wraps `inner` and inserts a fresh `Extension` into each request. Unless the
/// inner response is a 4xx/5xx, it then resolves that extension.
pub struct Service<DB: Marker, S, E> {
    /// State used to build a new `Extension` for every request.
    state: State<DB>,
    /// The wrapped service.
    inner: S,
    /// Marker for the error type `E`; no value of `E` is stored.
    _error: PhantomData<E>,
}
74
75impl<DB: Marker, S: Clone, E> Clone for Service<DB, S, E> {
77 fn clone(&self) -> Self {
78 Self {
79 state: self.state.clone(),
80 inner: self.inner.clone(),
81 _error: self._error,
82 }
83 }
84}
85
86impl<DB: Marker, S, E, ReqBody, ResBody> tower_service::Service<http::Request<ReqBody>>
87 for Service<DB, S, E>
88where
89 S: tower_service::Service<
90 http::Request<ReqBody>,
91 Response = http::Response<ResBody>,
92 Error = std::convert::Infallible,
93 >,
94 S::Future: Send + 'static,
95 E: IntoResponse,
96 sqlx::Error: Into<E>,
97 ResBody: Body<Data = Bytes> + Send + 'static,
98 ResBody::Error: Into<Box<dyn std::error::Error + Send + Sync + 'static>>,
99{
100 type Response = http::Response<axum_core::body::Body>;
101 type Error = S::Error;
102 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
103
104 fn poll_ready(
105 &mut self,
106 cx: &mut std::task::Context<'_>,
107 ) -> std::task::Poll<Result<(), Self::Error>> {
108 self.inner.poll_ready(cx).map_err(|err| match err {})
109 }
110
111 fn call(&mut self, mut req: http::Request<ReqBody>) -> Self::Future {
112 let ext = Extension::new(self.state.clone());
113 req.extensions_mut().insert(ext.clone());
114
115 let res = self.inner.call(req);
116
117 Box::pin(async move {
118 let res = res.await.unwrap(); if !res.status().is_server_error() && !res.status().is_client_error() {
121 if let Err(error) = ext.resolve().await {
122 return Ok(error.into().into_response());
123 }
124 }
125
126 Ok(res.map(axum_core::body::Body::new))
127 })
128 }
129}
130
#[cfg(test)]
mod tests {
    use tokio::net::TcpListener;

    use super::Layer;
    use crate::{Error, State};

    // Compile-time check only: this function has no `#[test]` attribute and is never
    // run (it would hit `todo!()` straight away). It exists so the build fails if
    // `Layer` can no longer be applied to an `axum::Router` that `axum::serve` accepts.
    #[allow(unused, unreachable_code, clippy::diverging_sub_expression)]
    fn layer_compiles() {
        let state: State<sqlx::Sqlite> = todo!();
        let listener: TcpListener = todo!();

        let router = axum::Router::new()
            .route("/", axum::routing::get(|| async { "hello" }))
            .layer(Layer::<_, Error>::new(state));

        axum::serve(listener, router);
    }
}