okapi_operation/axum_integration/
handler_traits.rs1use std::marker::PhantomData;
2
3use axum::{extract::Request, handler::Handler, response::IntoResponse};
4use tower::Service;
5
6use crate::OperationGenerator;
7
8pub struct HandlerWithOperation<H, T, S>
10where
11 H: Handler<T, S>,
12{
13 pub(super) handler: H,
14 pub(super) operation: Option<OperationGenerator>,
15 _t: PhantomData<T>,
16 _s: PhantomData<S>,
17}
18
19impl<H, T, S> From<H> for HandlerWithOperation<H, T, S>
20where
21 H: Handler<T, S>,
22{
23 fn from(value: H) -> Self {
24 Self {
25 handler: value,
26 operation: None,
27 _t: PhantomData,
28 _s: PhantomData,
29 }
30 }
31}
32
33impl<H, T, S> HandlerWithOperation<H, T, S>
34where
35 H: Handler<T, S>,
36{
37 pub fn new(handler: H, operation: Option<OperationGenerator>) -> Self {
38 Self {
39 handler,
40 operation,
41 _t: PhantomData,
42 _s: PhantomData,
43 }
44 }
45}
46
47pub trait HandlerExt<H, T, S>
49where
50 H: Handler<T, S>,
51{
52 fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S>;
53
54 fn with_openapi(self, operation: OperationGenerator) -> HandlerWithOperation<H, T, S>
56 where
57 Self: Sized,
58 {
59 let mut h = self.into_handler_with_operation();
60 h.operation = Some(operation);
61 h
62 }
63}
64
65impl<H, T, S> HandlerExt<H, T, S> for H
66where
67 H: Handler<T, S>,
68{
69 fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S> {
70 HandlerWithOperation::new(self, None)
71 }
72}
73
74impl<H, T, S> HandlerExt<H, T, S> for HandlerWithOperation<H, T, S>
75where
76 H: Handler<T, S>,
77{
78 fn into_handler_with_operation(self) -> HandlerWithOperation<H, T, S> {
79 self
80 }
81}
82
83pub struct ServiceWithOperation<Svc, E>
85where
86 Svc: Service<Request, Error = E> + Clone + Send + 'static,
87 Svc::Response: IntoResponse + 'static,
88 Svc::Future: Send + 'static,
89{
90 pub(crate) service: Svc,
91 pub(crate) operation: Option<OperationGenerator>,
92 _e: PhantomData<E>,
93}
94
95impl<Svc, E> ServiceWithOperation<Svc, E>
96where
97 Svc: Service<Request, Error = E> + Clone + Send + 'static,
98 Svc::Response: IntoResponse + 'static,
99 Svc::Future: Send + 'static,
100{
101 pub(crate) fn new(service: Svc, operation: Option<OperationGenerator>) -> Self {
102 Self {
103 service,
104 operation,
105 _e: PhantomData,
106 }
107 }
108}
109
110impl<Svc, E> From<Svc> for ServiceWithOperation<Svc, E>
111where
112 Svc: Service<Request, Error = E> + Clone + Send + 'static,
113 Svc::Response: IntoResponse + 'static,
114 Svc::Future: Send + 'static,
115{
116 fn from(value: Svc) -> Self {
117 Self::new(value, None)
118 }
119}
120
121pub trait ServiceExt<Svc, E>
123where
124 Svc: Service<Request, Error = E> + Clone + Send + 'static,
125 Svc::Response: IntoResponse + 'static,
126 Svc::Future: Send + 'static,
127{
128 fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E>
129where;
130
131 fn with_openapi(self, operation: OperationGenerator) -> ServiceWithOperation<Svc, E>
133 where
134 Self: Sized,
135 {
136 let mut h = self.into_service_with_operation();
137 h.operation = Some(operation);
138 h
139 }
140}
141
142impl<Svc, E> ServiceExt<Svc, E> for Svc
143where
144 Svc: Service<Request, Error = E> + Clone + Send + 'static,
145 Svc::Response: IntoResponse + 'static,
146 Svc::Future: Send + 'static,
147{
148 fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E> {
149 ServiceWithOperation::new(self, None)
150 }
151}
152
153impl<Svc, E> ServiceExt<Svc, E> for ServiceWithOperation<Svc, E>
154where
155 Svc: Service<Request, Error = E> + Clone + Send + 'static,
156 Svc::Response: IntoResponse + 'static,
157 Svc::Future: Send + 'static,
158{
159 fn into_service_with_operation(self) -> ServiceWithOperation<Svc, E> {
160 self
161 }
162}
163
#[cfg(test)]
mod tests {
    // The serve futures below are deliberately created and dropped without being awaited:
    // they only prove at compile time that the produced router can be passed to
    // `axum::serve`, so the "future is never polled" lint is expected here.
    #![allow(clippy::let_underscore_future)]

    use std::convert::Infallible;

    use axum::{
        body::Body, extract::Request, http::Method, response::Response, routing::MethodFilter,
    };
    use okapi::openapi3::Operation;
    use tokio::net::TcpListener;
    use tower::service_fn;

    use super::*;
    use crate::{
        Components,
        axum_integration::{MethodRouter, Router},
    };

    /// Placeholder generator that panics if called; the tests only check that an operation
    /// is registered for a route.
    fn openapi_generator(_: &mut Components) -> Result<Operation, anyhow::Error> {
        unimplemented!()
    }

    #[test]
    fn handler_with_operation() {
        async fn handler() {}

        // GET and POST carry an operation (POST exercises chained `with_openapi` calls);
        // PUT and DELETE are registered as bare handlers.
        let mr: MethodRouter = MethodRouter::new()
            .on(
                MethodFilter::GET,
                (|| async {}).with_openapi(openapi_generator),
            )
            .on(
                MethodFilter::POST,
                handler
                    .with_openapi(openapi_generator)
                    .with_openapi(openapi_generator),
            )
            .on(MethodFilter::PUT, handler)
            .on(MethodFilter::DELETE, || async {});
        let (app, ops) = Router::new().route("/", mr).into_parts();
        assert!(ops.get("/", &Method::GET).is_some());
        assert!(ops.get("/", &Method::POST).is_some());

        // Compile-only check that the router is servable; the future is never awaited.
        // The address is a valid ephemeral loopback port so the future would also be
        // correct if it were ever run.
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }

    #[test]
    fn service_with_operation() {
        async fn service(_request: Request) -> Result<Response<Body>, Infallible> {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        }

        let service2 = service_fn(|_request: Request| async {
            Ok::<_, Infallible>(Response::new(Body::empty()))
        });

        // GET and POST carry an operation (POST exercises chained `with_openapi` calls);
        // PUT and DELETE are registered as bare services.
        let mr: MethodRouter = MethodRouter::new()
            .on_service(
                MethodFilter::GET,
                service_fn(service).with_openapi(openapi_generator),
            )
            .on_service(
                MethodFilter::POST,
                service2
                    .with_openapi(openapi_generator)
                    .with_openapi(openapi_generator),
            )
            .on_service(MethodFilter::PUT, service_fn(service))
            .on_service(MethodFilter::DELETE, service2);
        let (app, ops) = Router::new().route("/", mr).into_parts();
        assert!(ops.get("/", &Method::GET).is_some());
        assert!(ops.get("/", &Method::POST).is_some());

        // Compile-only check that the router is servable; the future is never awaited.
        let make_service = app.into_make_service();
        let _ = async move {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            axum::serve(listener, make_service).await.unwrap()
        };
    }
}