okapi_operation/axum_integration/
handler_traits.rs

1use std::marker::PhantomData;
2
3use axum::{extract::Request, handler::Handler, response::IntoResponse};
4use tower::Service;
5
6use crate::OperationGenerator;
7
8/// Wrapper around [`axum::handler::Handler`] with associated OpenAPI [`OperationGenerator`].
9pub struct HandlerWithOperation<H, T, S>
10where
11    H: Handler<T, S>,
12{
13    pub(super) handler: H,
14    pub(super) operation: Option<OperationGenerator>,
15    _t: PhantomData<T>,
16    _s: PhantomData<S>,
17}
18
19impl<H, T, S> From<H> for HandlerWithOperation<H, T, S>
20where
21    H: Handler<T, S>,
22{
23    fn from(value: H) -> Self {
24        Self {
25            handler: value,
26            operation: None,
27            _t: PhantomData,
28            _s: PhantomData,
29        }
30    }
31}
32
33impl<H, T, S> HandlerWithOperation<H, T, S>
34where
35    H: Handler<T, S>,
36{
37    pub fn new(handler: H, operation: Option<OperationGenerator>) -> Self {
38        Self {
39            handler,
40            operation,
41            _t: PhantomData,
42            _s: PhantomData,
43        }
44    }
45}
46
47/// Trait for converting [`axum::handler::Handler`] into wrapper.
48pub trait HandlerExt<H, T, S>
49where
50    H: Handler<T, S>,
51{
52    fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S>;
53
54    /// Add OpenAPI operation to handler.
55    fn with_openapi(self, operation: OperationGenerator) -> HandlerWithOperation<H, T, S>
56    where
57        Self: Sized,
58    {
59        let mut h = self.into_handler_with_operation();
60        h.operation = Some(operation);
61        h
62    }
63}
64
65impl<H, T, S> HandlerExt<H, T, S> for H
66where
67    H: Handler<T, S>,
68{
69    fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S> {
70        HandlerWithOperation::new(self, None)
71    }
72}
73
74impl<H, T, S> HandlerExt<H, T, S> for HandlerWithOperation<H, T, S>
75where
76    H: Handler<T, S>,
77{
78    fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S> {
79        self
80    }
81}
82
83/// Wrapper around [`Service`] with associated OpenAPI [`OperationGenerator`].
84pub struct ServiceWithOperation<Svc, E>
85where
86    Svc: Service<Request, Error = E> + Clone + Send + 'static,
87    Svc::Response: IntoResponse + 'static,
88    Svc::Future: Send + 'static,
89{
90    pub(crate) service: Svc,
91    pub(crate) operation: Option<OperationGenerator>,
92    _e: PhantomData<E>,
93}
94
95impl<Svc, E> ServiceWithOperation<Svc, E>
96where
97    Svc: Service<Request, Error = E> + Clone + Send + 'static,
98    Svc::Response: IntoResponse + 'static,
99    Svc::Future: Send + 'static,
100{
101    pub(crate) fn new(service: Svc, operation: Option<OperationGenerator>) -> Self {
102        Self {
103            service,
104            operation,
105            _e: PhantomData,
106        }
107    }
108}
109
110impl<Svc, E> From<Svc> for ServiceWithOperation<Svc, E>
111where
112    Svc: Service<Request, Error = E> + Clone + Send + 'static,
113    Svc::Response: IntoResponse + 'static,
114    Svc::Future: Send + 'static,
115{
116    fn from(value: Svc) -> Self {
117        Self::new(value, None)
118    }
119}
120
121/// Trait for converting [`Service`] into wrapper.
122pub trait ServiceExt<Svc, E>
123where
124    Svc: Service<Request, Error = E> + Clone + Send + 'static,
125    Svc::Response: IntoResponse + 'static,
126    Svc::Future: Send + 'static,
127{
128    fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E>
129where;
130
131    /// Add OpenAPI operation to service.
132    fn with_openapi(self, operation: OperationGenerator) -> ServiceWithOperation<Svc, E>
133    where
134        Self: Sized,
135    {
136        let mut h = self.into_service_with_operation();
137        h.operation = Some(operation);
138        h
139    }
140}
141
142impl<Svc, E> ServiceExt<Svc, E> for Svc
143where
144    Svc: Service<Request, Error = E> + Clone + Send + 'static,
145    Svc::Response: IntoResponse + 'static,
146    Svc::Future: Send + 'static,
147{
148    fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E> {
149        ServiceWithOperation::new(self, None)
150    }
151}
152
153impl<Svc, E> ServiceExt<Svc, E> for ServiceWithOperation<Svc, E>
154where
155    Svc: Service<Request, Error = E> + Clone + Send + 'static,
156    Svc::Response: IntoResponse + 'static,
157    Svc::Future: Send + 'static,
158{
159    fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E> {
160        self
161    }
162}
163
#[cfg(test)]
mod tests {
    // Each test builds a "serve" future purely so that it gets type-checked, then drops it
    // without awaiting; that is exactly what this lint would otherwise complain about.
    #![allow(clippy::let_underscore_future)]

    use std::convert::Infallible;

    use axum::{
        body::Body, extract::Request, http::Method, response::Response, routing::MethodFilter,
    };
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;
    use tower::service_fn;

    use super::*;
    use crate::{
        Components,
        axum_integration::{MethodRouter, Router},
    };

    /// Placeholder operation generator for the tests below.
    ///
    /// NOTE(review): presumably never invoked by these tests, since calling it would panic.
    fn openapi_generator(_: &mut Components) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    /// Plain handlers and handlers wrapped via `with_openapi` (once or twice) can all be
    /// registered on the same `MethodRouter`.
    #[test]
    fn handler_with_operation() {
        async fn handler() {}

        let method_router: MethodRouter = MethodRouter::new()
            // Closure handler with an operation attached.
            .on(
                MethodFilter::GET,
                (|| async {}).with_openapi(openapi_generator),
            )
            // Calling `with_openapi` on an already wrapped handler must also work.
            .on(
                MethodFilter::POST,
                handler
                    .with_openapi(openapi_generator)
                    .with_openapi(openapi_generator),
            )
            // Bare handlers, no operation.
            .on(MethodFilter::PUT, handler)
            .on(MethodFilter::DELETE, || async {});

        let (router, operations) = Router::new().route("/", method_router).into_parts();
        assert!(operations.get("/", &Method::GET).is_some());
        assert!(operations.get("/", &Method::POST).is_some());

        // The future is dropped without being polled, so no socket is bound; the block only
        // has to compile, proving the resulting router can be handed to `axum::serve`.
        let make_service = router.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    /// Plain services and services wrapped via `with_openapi` (once or twice) can all be
    /// registered on the same `MethodRouter`.
    #[test]
    fn service_with_operation() {
        async fn service(_request: Request) -> Result<Response<Body>, Infallible> {
            Ok(Response::new(Body::empty()))
        }

        let closure_service = service_fn(|_request: Request| async {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        let method_router: MethodRouter = MethodRouter::new()
            // Function-backed service with an operation attached.
            .on_service(
                MethodFilter::GET,
                service_fn(service).with_openapi(openapi_generator),
            )
            // Calling `with_openapi` on an already wrapped service must also work.
            .on_service(
                MethodFilter::POST,
                closure_service
                    .with_openapi(openapi_generator)
                    .with_openapi(openapi_generator),
            )
            // Bare services, no operation.
            .on_service(MethodFilter::PUT, service_fn(service))
            .on_service(MethodFilter::DELETE, closure_service);

        let (router, operations) = Router::new().route("/", method_router).into_parts();
        assert!(operations.get("/", &Method::GET).is_some());
        assert!(operations.get("/", &Method::POST).is_some());

        // The future is dropped without being polled, so no socket is bound; the block only
        // has to compile, proving the resulting router can be handed to `axum::serve`.
        let make_service = router.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}