1use std::{collections::HashMap, convert::Infallible, fmt};
2
3use axum::{
4 error_handling::HandleError,
5 extract::Request,
6 handler::Handler,
7 http::Method,
8 response::IntoResponse,
9 routing::{MethodFilter, MethodRouter as AxumMethodRouter, Route},
10};
11use tower::{Layer, Service};
12
13use super::handler_traits::{HandlerWithOperation, ServiceWithOperation};
14use crate::OperationGenerator;
15
// Generates a free function (such as `get_service`) that builds a fresh
// `MethodRouter` routing exactly one HTTP method to a service, by delegating to
// the free `on_service` function. Attributes written before the function name
// (typically `///` doc comments) are forwarded onto the generated function.
macro_rules! top_level_service_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, Svc, S, E>(svc: I) -> MethodRouter<S, E>
        where
            S: Clone,
            I: Into<ServiceWithOperation<Svc, E>>,
            Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
            Svc::Future: Send + 'static,
            Svc::Response: IntoResponse + 'static,
        {
            on_service(MethodFilter::$filter, svc)
        }
    };
}
34
// Generates a free function (such as `get`) that builds a fresh `MethodRouter`
// routing exactly one HTTP method to a handler, by delegating to the free `on`
// function. Attributes written before the function name (typically `///` doc
// comments) are forwarded onto the generated function.
macro_rules! top_level_handler_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, H, T, S>(handler: I) -> MethodRouter<S, Infallible>
        where
            S: Clone + Send + Sync + 'static,
            I: Into<HandlerWithOperation<H, T, S>>,
            H: Handler<T, S>,
            T: 'static,
        {
            on(MethodFilter::$filter, handler)
        }
    };
}
52
// Generates a chained builder method (such as `get_service`) on
// `MethodRouter<S, E>` that adds a route for exactly one HTTP method by
// delegating to `self.on_service`. Attributes written before the method name
// (typically `///` doc comments) are forwarded onto the generated method.
//
// The generated method is `#[track_caller]`: adding a route for a method that
// is already routed panics inside axum, and the attribute lets that panic
// report the user's call site rather than this macro expansion.
macro_rules! chained_service_fn {
    (
        $(#[$m:meta])*
        $name:ident, $method:ident
    ) => {
        $(#[$m])*
        #[track_caller]
        pub fn $name<I, Svc>(self, svc: I) -> Self
        where
            I: Into<ServiceWithOperation<Svc, E>>,
            Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
            Svc::Response: IntoResponse + 'static,
            Svc::Future: Send + 'static,
        {
            self.on_service(MethodFilter::$method, svc)
        }
    };
}
71
// Generates a chained builder method (such as `get`) on
// `MethodRouter<S, Infallible>` that adds a handler for exactly one HTTP method
// by delegating to `self.on`. Attributes written before the method name
// (typically `///` doc comments) are forwarded onto the generated method.
//
// The generated method is `#[track_caller]`: adding a handler for a method that
// is already routed panics inside axum, and the attribute lets that panic
// report the user's call site rather than this macro expansion.
macro_rules! chained_handler_fn {
    (
        $(#[$m:meta])*
        $name:ident, $method:ident
    ) => {
        $(#[$m])*
        #[track_caller]
        pub fn $name<I, H, T>(self, handler: I) -> Self
        where
            I: Into<HandlerWithOperation<H, T, S>>,
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            self.on(MethodFilter::$method, handler)
        }
    };
}
90
91pub fn on_service<I, Svc, S, E>(filter: MethodFilter, svc: I) -> MethodRouter<S, E>
93where
94 I: Into<ServiceWithOperation<Svc, E>>,
95 Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
96 Svc::Response: IntoResponse + 'static,
97 Svc::Future: Send + 'static,
98 S: Clone,
99{
100 MethodRouter::new().on_service(filter, svc)
101}
102
103top_level_service_fn!(delete_service, DELETE);
104top_level_service_fn!(get_service, GET);
105top_level_service_fn!(head_service, HEAD);
106top_level_service_fn!(options_service, OPTIONS);
107top_level_service_fn!(patch_service, PATCH);
108top_level_service_fn!(post_service, POST);
109top_level_service_fn!(put_service, PUT);
110top_level_service_fn!(trace_service, TRACE);
111
112pub fn on<I, H, T, S>(filter: MethodFilter, handler: I) -> MethodRouter<S, Infallible>
113where
114 I: Into<HandlerWithOperation<H, T, S>>,
115 H: Handler<T, S>,
116 T: 'static,
117 S: Clone + Send + Sync + 'static,
118{
119 MethodRouter::new().on(filter, handler)
120}
121
122top_level_handler_fn!(delete, DELETE);
123top_level_handler_fn!(get, GET);
124top_level_handler_fn!(head, HEAD);
125top_level_handler_fn!(options, OPTIONS);
126top_level_handler_fn!(patch, PATCH);
127top_level_handler_fn!(post, POST);
128top_level_handler_fn!(put, PUT);
129top_level_handler_fn!(trace, TRACE);
130
131#[derive(Clone, Default)]
132pub(super) struct MethodRouterOperations {
133 get: Option<OperationGenerator>,
134 head: Option<OperationGenerator>,
135 delete: Option<OperationGenerator>,
136 options: Option<OperationGenerator>,
137 patch: Option<OperationGenerator>,
138 post: Option<OperationGenerator>,
139 put: Option<OperationGenerator>,
140 trace: Option<OperationGenerator>,
141}
142
143impl MethodRouterOperations {
144 fn on(mut self, filter: MethodFilter, operation: Option<OperationGenerator>) -> Self {
145 if is_filter_present(filter, MethodFilter::GET) {
146 self.get = operation;
147 }
148 if is_filter_present(filter, MethodFilter::HEAD) {
149 self.head = operation;
150 }
151 if is_filter_present(filter, MethodFilter::DELETE) {
152 self.delete = operation;
153 }
154 if is_filter_present(filter, MethodFilter::OPTIONS) {
155 self.options = operation;
156 }
157 if is_filter_present(filter, MethodFilter::PATCH) {
158 self.patch = operation;
159 }
160 if is_filter_present(filter, MethodFilter::POST) {
161 self.post = operation;
162 }
163 if is_filter_present(filter, MethodFilter::PUT) {
164 self.put = operation;
165 }
166 if is_filter_present(filter, MethodFilter::TRACE) {
167 self.trace = operation;
168 }
169 self
170 }
171
172 pub(super) fn merge(self, other: Self) -> Self {
173 macro_rules! merge {
174 ( $first:ident, $second:ident ) => {
175 match ($first, $second) {
176 (Some(_), Some(_)) => panic!(concat!(
177 "Overlapping method operation. Cannot merge two method operation that both define `",
178 stringify!($first),
179 "`"
180 )),
181 (Some(svc), None) => Some(svc),
182 (None, Some(svc)) => Some(svc),
183 (None, None) => None,
184 }
185 };
186 }
187
188 let Self {
189 get,
190 head,
191 delete,
192 options,
193 patch,
194 post,
195 put,
196 trace,
197 } = self;
198
199 let Self {
200 get: get_other,
201 head: head_other,
202 delete: delete_other,
203 options: options_other,
204 patch: patch_other,
205 post: post_other,
206 put: put_other,
207 trace: trace_other,
208 } = other;
209
210 let get = merge!(get, get_other);
211 let head = merge!(head, head_other);
212 let delete = merge!(delete, delete_other);
213 let options = merge!(options, options_other);
214 let patch = merge!(patch, patch_other);
215 let post = merge!(post, post_other);
216 let put = merge!(put, put_other);
217 let trace = merge!(trace, trace_other);
218
219 Self {
220 get,
221 head,
222 delete,
223 options,
224 patch,
225 post,
226 put,
227 trace,
228 }
229 }
230
231 pub(crate) fn into_map(self) -> HashMap<Method, OperationGenerator> {
232 let mut map = HashMap::new();
233 if let Some(m) = self.get {
234 let _ = map.insert(Method::GET, m);
235 }
236 if let Some(m) = self.head {
237 let _ = map.insert(Method::HEAD, m);
238 }
239 if let Some(m) = self.delete {
240 let _ = map.insert(Method::DELETE, m);
241 }
242 if let Some(m) = self.options {
243 let _ = map.insert(Method::OPTIONS, m);
244 }
245 if let Some(m) = self.patch {
246 let _ = map.insert(Method::PATCH, m);
247 }
248 if let Some(m) = self.post {
249 let _ = map.insert(Method::POST, m);
250 }
251 if let Some(m) = self.put {
252 let _ = map.insert(Method::PUT, m);
253 }
254 if let Some(m) = self.trace {
255 let _ = map.insert(Method::TRACE, m);
256 }
257 map
258 }
259}
260
261pub struct MethodRouter<S = (), E = Infallible> {
264 pub(super) axum_method_router: AxumMethodRouter<S, E>,
265 pub(super) operations: MethodRouterOperations,
266}
267
268impl<S, E> fmt::Debug for MethodRouter<S, E> {
269 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270 self.axum_method_router.fmt(f)
271 }
272}
273
274impl<S, E> Default for MethodRouter<S, E>
275where
276 S: Clone,
277{
278 fn default() -> Self {
279 Self::new()
280 }
281}
282
283impl<S, E> From<AxumMethodRouter<S, E>> for MethodRouter<S, E> {
284 fn from(value: AxumMethodRouter<S, E>) -> Self {
285 Self {
286 axum_method_router: value,
287 operations: Default::default(),
288 }
289 }
290}
291
292impl<S> MethodRouter<S, Infallible>
293where
294 S: Clone,
295{
296 pub fn on<I, H, T>(self, filter: MethodFilter, handler: I) -> Self
297 where
298 I: Into<HandlerWithOperation<H, T, S>>,
299 H: Handler<T, S>,
300 T: 'static,
301 S: Send + Sync + 'static,
302 {
303 let HandlerWithOperation {
304 handler, operation, ..
305 } = handler.into();
306
307 Self {
308 axum_method_router: self.axum_method_router.on(filter, handler),
309 operations: self.operations.on(filter, operation),
310 }
311 }
312
313 chained_handler_fn!(delete, DELETE);
314 chained_handler_fn!(get, GET);
315 chained_handler_fn!(head, HEAD);
316 chained_handler_fn!(options, OPTIONS);
317 chained_handler_fn!(patch, PATCH);
318 chained_handler_fn!(post, POST);
319 chained_handler_fn!(put, PUT);
320 chained_handler_fn!(trace, TRACE);
321
322 pub fn fallback<H, T>(self, handler: H) -> Self
323 where
324 H: Handler<T, S>,
325 T: 'static,
326 S: Send + Sync + 'static,
327 {
328 Self {
329 axum_method_router: self.axum_method_router.fallback(handler),
330 ..self
331 }
332 }
333}
334
335impl<S, E> MethodRouter<S, E>
336where
337 S: Clone,
338{
339 pub fn new() -> Self {
340 Self {
341 axum_method_router: AxumMethodRouter::new(),
342 operations: Default::default(),
343 }
344 }
345
346 pub fn into_axum(self) -> AxumMethodRouter<S, E> {
348 self.axum_method_router
349 }
350
351 pub fn on_service<I, Svc>(self, filter: MethodFilter, svc: I) -> Self
352 where
353 I: Into<ServiceWithOperation<Svc, E>>,
354 Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
355 Svc::Response: IntoResponse + 'static,
356 Svc::Future: Send + 'static,
357 {
358 let ServiceWithOperation {
359 service, operation, ..
360 } = svc.into();
361 Self {
362 axum_method_router: self.axum_method_router.on_service(filter, service),
363 operations: self.operations.on(filter, operation),
364 }
365 }
366
367 chained_service_fn!(delete_service, DELETE);
368 chained_service_fn!(get_service, GET);
369 chained_service_fn!(head_service, HEAD);
370 chained_service_fn!(options_service, OPTIONS);
371 chained_service_fn!(patch_service, PATCH);
372 chained_service_fn!(post_service, POST);
373 chained_service_fn!(put_service, PUT);
374 chained_service_fn!(trace_service, TRACE);
375
376 pub fn fallback_service<Svc>(self, svc: Svc) -> Self
377 where
378 Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
379 Svc::Response: IntoResponse + 'static,
380 Svc::Future: Send + 'static,
381 {
382 Self {
383 axum_method_router: self.axum_method_router.fallback_service(svc),
384 ..self
385 }
386 }
387
388 pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
389 where
390 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
391 L::Service: Service<Request> + Clone + Send + Sync + 'static,
392 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
393 <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
394 <L::Service as Service<Request>>::Future: Send + 'static,
395 E: 'static,
396 S: 'static,
397 NewError: 'static,
398 {
399 MethodRouter {
400 axum_method_router: self.axum_method_router.layer(layer),
401 operations: self.operations,
402 }
403 }
404
405 pub fn route_layer<L>(self, layer: L) -> MethodRouter<S, E>
406 where
407 L: Layer<Route<E>> + Clone + Send + Sync + 'static,
408 L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
409 <L::Service as Service<Request>>::Response: IntoResponse + 'static,
410 <L::Service as Service<Request>>::Future: Send + 'static,
411 E: 'static,
412 S: 'static,
413 {
414 MethodRouter {
415 axum_method_router: self.axum_method_router.route_layer(layer),
416 operations: self.operations,
417 }
418 }
419
420 pub fn merge(self, other: MethodRouter<S, E>) -> Self {
421 MethodRouter {
422 axum_method_router: self.axum_method_router.merge(other.axum_method_router),
423 operations: self.operations.merge(other.operations),
424 }
425 }
426
427 pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
428 where
429 F: Clone + Send + Sync + 'static,
430 HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
431 <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
432 <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
433 T: 'static,
434 E: 'static,
435 S: 'static,
436 {
437 MethodRouter {
438 axum_method_router: self.axum_method_router.handle_error(f),
439 operations: self.operations,
440 }
441 }
442
443 pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
444 MethodRouter {
445 axum_method_router: self.axum_method_router.with_state(state),
446 operations: self.operations,
447 }
448 }
449}
450
451fn is_filter_present(lhs: MethodFilter, rhs: MethodFilter) -> bool {
452 lhs.or(rhs) == lhs
453}
454
#[test]
fn test_is_filter_present() {
    // A single method is present in itself and absent from a different method.
    assert!(is_filter_present(
        MethodFilter::DELETE,
        MethodFilter::DELETE
    ));
    assert!(!is_filter_present(MethodFilter::GET, MethodFilter::DELETE));

    // A method is present in a combined filter, whatever the OR order, and
    // OR-ing a method with itself changes nothing.
    assert!(is_filter_present(
        MethodFilter::DELETE.or(MethodFilter::GET),
        MethodFilter::DELETE
    ));
    assert!(is_filter_present(
        MethodFilter::GET.or(MethodFilter::DELETE),
        MethodFilter::DELETE
    ));
    assert!(is_filter_present(
        MethodFilter::DELETE.or(MethodFilter::DELETE),
        MethodFilter::DELETE
    ));

    // A multi-method `rhs` is present only when *all* of its methods are in
    // `lhs` (subset semantics, which `MethodRouterOperations::on` relies on).
    assert!(is_filter_present(
        MethodFilter::GET.or(MethodFilter::POST).or(MethodFilter::PUT),
        MethodFilter::GET.or(MethodFilter::PUT)
    ));
    assert!(!is_filter_present(
        MethodFilter::GET,
        MethodFilter::GET.or(MethodFilter::POST)
    ));
}