okapi_operation/axum_integration/
method_router.rs

1use std::{collections::HashMap, convert::Infallible, fmt};
2
3use axum::{
4    error_handling::HandleError,
5    extract::Request,
6    handler::Handler,
7    http::Method,
8    response::IntoResponse,
9    routing::{MethodFilter, MethodRouter as AxumMethodRouter, Route},
10};
11use tower::{Layer, Service};
12
13use super::handler_traits::{HandlerWithOperation, ServiceWithOperation};
14use crate::OperationGenerator;
15
/// Generates a free function that builds a [`MethodRouter`] routing a single
/// HTTP method to a service.
///
/// `$fn_name` is the name of the generated function and `$filter` is the
/// [`MethodFilter`] associated constant it routes. Attributes (including doc
/// comments) written before the name are forwarded to the generated function.
macro_rules! top_level_service_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, Svc, S, E>(svc: I) -> MethodRouter<S, E>
        where
            I: Into<ServiceWithOperation<Svc, E>>,
            Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
            Svc::Response: IntoResponse + 'static,
            Svc::Future: Send + 'static,
            S: Clone,
        {
            // Same as the free `on_service` helper: start from an empty router
            // and register the service for exactly this method.
            MethodRouter::new().on_service(MethodFilter::$filter, svc)
        }
    };
}
34
/// Generates a free function that builds a [`MethodRouter`] routing a single
/// HTTP method to a handler.
///
/// `$fn_name` is the name of the generated function and `$filter` is the
/// [`MethodFilter`] associated constant it routes. Attributes (including doc
/// comments) written before the name are forwarded to the generated function.
macro_rules! top_level_handler_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, H, T, S>(handler: I) -> MethodRouter<S, Infallible>
        where
            I: Into<HandlerWithOperation<H, T, S>>,
            H: Handler<T, S>,
            T: 'static,
            S: Clone + Send + Sync + 'static,
        {
            // Same as the free `on` helper: start from an empty router and
            // register the handler for exactly this method.
            MethodRouter::new().on(MethodFilter::$filter, handler)
        }
    };
}
52
/// Generates a chaining service method on [`MethodRouter`].
///
/// The generated method `$fn_name` consumes the router and forwards to
/// `on_service` with `MethodFilter::$filter`. Attributes (including doc
/// comments) written before the name are forwarded to the generated method.
macro_rules! chained_service_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, Svc>(self, svc: I) -> Self
        where
            I: Into<ServiceWithOperation<Svc, E>>,
            Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
            Svc::Response: IntoResponse + 'static,
            Svc::Future: Send + 'static,
        {
            Self::on_service(self, MethodFilter::$filter, svc)
        }
    };
}
71
/// Generates a chaining handler method on [`MethodRouter`].
///
/// The generated method `$fn_name` consumes the router and forwards to `on`
/// with `MethodFilter::$filter`. Attributes (including doc comments) written
/// before the name are forwarded to the generated method.
macro_rules! chained_handler_fn {
    (
        $(#[$attr:meta])*
        $fn_name:ident, $filter:ident
    ) => {
        $(#[$attr])*
        pub fn $fn_name<I, H, T>(self, handler: I) -> Self
        where
            I: Into<HandlerWithOperation<H, T, S>>,
            H: Handler<T, S>,
            T: 'static,
            S: Send + Sync + 'static,
        {
            Self::on(self, MethodFilter::$filter, handler)
        }
    };
}
90
91// TODO: check whether E generic parameter is redundant
92pub fn on_service<I, Svc, S, E>(filter: MethodFilter, svc: I) -> MethodRouter<S, E>
93where
94    I: Into<ServiceWithOperation<Svc, E>>,
95    Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
96    Svc::Response: IntoResponse + 'static,
97    Svc::Future: Send + 'static,
98    S: Clone,
99{
100    MethodRouter::new().on_service(filter, svc)
101}
102
103top_level_service_fn!(delete_service, DELETE);
104top_level_service_fn!(get_service, GET);
105top_level_service_fn!(head_service, HEAD);
106top_level_service_fn!(options_service, OPTIONS);
107top_level_service_fn!(patch_service, PATCH);
108top_level_service_fn!(post_service, POST);
109top_level_service_fn!(put_service, PUT);
110top_level_service_fn!(trace_service, TRACE);
111
112pub fn on<I, H, T, S>(filter: MethodFilter, handler: I) -> MethodRouter<S, Infallible>
113where
114    I: Into<HandlerWithOperation<H, T, S>>,
115    H: Handler<T, S>,
116    T: 'static,
117    S: Clone + Send + Sync + 'static,
118{
119    MethodRouter::new().on(filter, handler)
120}
121
122top_level_handler_fn!(delete, DELETE);
123top_level_handler_fn!(get, GET);
124top_level_handler_fn!(head, HEAD);
125top_level_handler_fn!(options, OPTIONS);
126top_level_handler_fn!(patch, PATCH);
127top_level_handler_fn!(post, POST);
128top_level_handler_fn!(put, PUT);
129top_level_handler_fn!(trace, TRACE);
130
131#[derive(Clone, Default)]
132pub(super) struct MethodRouterOperations {
133    get: Option<OperationGenerator>,
134    head: Option<OperationGenerator>,
135    delete: Option<OperationGenerator>,
136    options: Option<OperationGenerator>,
137    patch: Option<OperationGenerator>,
138    post: Option<OperationGenerator>,
139    put: Option<OperationGenerator>,
140    trace: Option<OperationGenerator>,
141}
142
143impl MethodRouterOperations {
144    fn on(mut self, filter: MethodFilter, operation: Option<OperationGenerator>) -> Self {
145        if is_filter_present(filter, MethodFilter::GET) {
146            self.get = operation;
147        }
148        if is_filter_present(filter, MethodFilter::HEAD) {
149            self.head = operation;
150        }
151        if is_filter_present(filter, MethodFilter::DELETE) {
152            self.delete = operation;
153        }
154        if is_filter_present(filter, MethodFilter::OPTIONS) {
155            self.options = operation;
156        }
157        if is_filter_present(filter, MethodFilter::PATCH) {
158            self.patch = operation;
159        }
160        if is_filter_present(filter, MethodFilter::POST) {
161            self.post = operation;
162        }
163        if is_filter_present(filter, MethodFilter::PUT) {
164            self.put = operation;
165        }
166        if is_filter_present(filter, MethodFilter::TRACE) {
167            self.trace = operation;
168        }
169        self
170    }
171
172    pub(super) fn merge(self, other: Self) -> Self {
173        macro_rules! merge {
174            ( $first:ident, $second:ident ) => {
175                match ($first, $second) {
176                    (Some(_), Some(_)) => panic!(concat!(
177                        "Overlapping method operation. Cannot merge two method operation that both define `",
178                        stringify!($first),
179                        "`"
180                    )),
181                    (Some(svc), None) => Some(svc),
182                    (None, Some(svc)) => Some(svc),
183                    (None, None) => None,
184                }
185            };
186        }
187
188        let Self {
189            get,
190            head,
191            delete,
192            options,
193            patch,
194            post,
195            put,
196            trace,
197        } = self;
198
199        let Self {
200            get: get_other,
201            head: head_other,
202            delete: delete_other,
203            options: options_other,
204            patch: patch_other,
205            post: post_other,
206            put: put_other,
207            trace: trace_other,
208        } = other;
209
210        let get = merge!(get, get_other);
211        let head = merge!(head, head_other);
212        let delete = merge!(delete, delete_other);
213        let options = merge!(options, options_other);
214        let patch = merge!(patch, patch_other);
215        let post = merge!(post, post_other);
216        let put = merge!(put, put_other);
217        let trace = merge!(trace, trace_other);
218
219        Self {
220            get,
221            head,
222            delete,
223            options,
224            patch,
225            post,
226            put,
227            trace,
228        }
229    }
230
231    pub(crate) fn into_map(self) -> HashMap<Method, OperationGenerator> {
232        let mut map = HashMap::new();
233        if let Some(m) = self.get {
234            let _ = map.insert(Method::GET, m);
235        }
236        if let Some(m) = self.head {
237            let _ = map.insert(Method::HEAD, m);
238        }
239        if let Some(m) = self.delete {
240            let _ = map.insert(Method::DELETE, m);
241        }
242        if let Some(m) = self.options {
243            let _ = map.insert(Method::OPTIONS, m);
244        }
245        if let Some(m) = self.patch {
246            let _ = map.insert(Method::PATCH, m);
247        }
248        if let Some(m) = self.post {
249            let _ = map.insert(Method::POST, m);
250        }
251        if let Some(m) = self.put {
252            let _ = map.insert(Method::PUT, m);
253        }
254        if let Some(m) = self.trace {
255            let _ = map.insert(Method::TRACE, m);
256        }
257        map
258    }
259}
260
261/// Drop-in replacement for [`axum::routing::MethodRouter`], which supports
262/// OpenAPI definitions of handlers or services.
263pub struct MethodRouter<S = (), E = Infallible> {
264    pub(super) axum_method_router: AxumMethodRouter<S, E>,
265    pub(super) operations: MethodRouterOperations,
266}
267
268impl<S, E> fmt::Debug for MethodRouter<S, E> {
269    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
270        self.axum_method_router.fmt(f)
271    }
272}
273
274impl<S, E> Default for MethodRouter<S, E>
275where
276    S: Clone,
277{
278    fn default() -> Self {
279        Self::new()
280    }
281}
282
283impl<S, E> From<AxumMethodRouter<S, E>> for MethodRouter<S, E> {
284    fn from(value: AxumMethodRouter<S, E>) -> Self {
285        Self {
286            axum_method_router: value,
287            operations: Default::default(),
288        }
289    }
290}
291
292impl<S> MethodRouter<S, Infallible>
293where
294    S: Clone,
295{
296    pub fn on<I, H, T>(self, filter: MethodFilter, handler: I) -> Self
297    where
298        I: Into<HandlerWithOperation<H, T, S>>,
299        H: Handler<T, S>,
300        T: 'static,
301        S: Send + Sync + 'static,
302    {
303        let HandlerWithOperation {
304            handler, operation, ..
305        } = handler.into();
306
307        Self {
308            axum_method_router: self.axum_method_router.on(filter, handler),
309            operations: self.operations.on(filter, operation),
310        }
311    }
312
313    chained_handler_fn!(delete, DELETE);
314    chained_handler_fn!(get, GET);
315    chained_handler_fn!(head, HEAD);
316    chained_handler_fn!(options, OPTIONS);
317    chained_handler_fn!(patch, PATCH);
318    chained_handler_fn!(post, POST);
319    chained_handler_fn!(put, PUT);
320    chained_handler_fn!(trace, TRACE);
321
322    pub fn fallback<H, T>(self, handler: H) -> Self
323    where
324        H: Handler<T, S>,
325        T: 'static,
326        S: Send + Sync + 'static,
327    {
328        Self {
329            axum_method_router: self.axum_method_router.fallback(handler),
330            ..self
331        }
332    }
333}
334
335impl<S, E> MethodRouter<S, E>
336where
337    S: Clone,
338{
339    pub fn new() -> Self {
340        Self {
341            axum_method_router: AxumMethodRouter::new(),
342            operations: Default::default(),
343        }
344    }
345
346    /// Convert method router into [`axum::routing::MethodRouter`], dropping related OpenAPI definitions.
347    pub fn into_axum(self) -> AxumMethodRouter<S, E> {
348        self.axum_method_router
349    }
350
351    pub fn on_service<I, Svc>(self, filter: MethodFilter, svc: I) -> Self
352    where
353        I: Into<ServiceWithOperation<Svc, E>>,
354        Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
355        Svc::Response: IntoResponse + 'static,
356        Svc::Future: Send + 'static,
357    {
358        let ServiceWithOperation {
359            service, operation, ..
360        } = svc.into();
361        Self {
362            axum_method_router: self.axum_method_router.on_service(filter, service),
363            operations: self.operations.on(filter, operation),
364        }
365    }
366
367    chained_service_fn!(delete_service, DELETE);
368    chained_service_fn!(get_service, GET);
369    chained_service_fn!(head_service, HEAD);
370    chained_service_fn!(options_service, OPTIONS);
371    chained_service_fn!(patch_service, PATCH);
372    chained_service_fn!(post_service, POST);
373    chained_service_fn!(put_service, PUT);
374    chained_service_fn!(trace_service, TRACE);
375
376    pub fn fallback_service<Svc>(self, svc: Svc) -> Self
377    where
378        Svc: Service<Request, Error = E> + Clone + Send + Sync + 'static,
379        Svc::Response: IntoResponse + 'static,
380        Svc::Future: Send + 'static,
381    {
382        Self {
383            axum_method_router: self.axum_method_router.fallback_service(svc),
384            ..self
385        }
386    }
387
388    pub fn layer<L, NewError>(self, layer: L) -> MethodRouter<S, NewError>
389    where
390        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
391        L::Service: Service<Request> + Clone + Send + Sync + 'static,
392        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
393        <L::Service as Service<Request>>::Error: Into<NewError> + 'static,
394        <L::Service as Service<Request>>::Future: Send + 'static,
395        E: 'static,
396        S: 'static,
397        NewError: 'static,
398    {
399        MethodRouter {
400            axum_method_router: self.axum_method_router.layer(layer),
401            operations: self.operations,
402        }
403    }
404
405    pub fn route_layer<L>(self, layer: L) -> MethodRouter<S, E>
406    where
407        L: Layer<Route<E>> + Clone + Send + Sync + 'static,
408        L::Service: Service<Request, Error = E> + Clone + Send + Sync + 'static,
409        <L::Service as Service<Request>>::Response: IntoResponse + 'static,
410        <L::Service as Service<Request>>::Future: Send + 'static,
411        E: 'static,
412        S: 'static,
413    {
414        MethodRouter {
415            axum_method_router: self.axum_method_router.route_layer(layer),
416            operations: self.operations,
417        }
418    }
419
420    pub fn merge(self, other: MethodRouter<S, E>) -> Self {
421        MethodRouter {
422            axum_method_router: self.axum_method_router.merge(other.axum_method_router),
423            operations: self.operations.merge(other.operations),
424        }
425    }
426
427    pub fn handle_error<F, T>(self, f: F) -> MethodRouter<S, Infallible>
428    where
429        F: Clone + Send + Sync + 'static,
430        HandleError<Route<E>, F, T>: Service<Request, Error = Infallible>,
431        <HandleError<Route<E>, F, T> as Service<Request>>::Future: Send,
432        <HandleError<Route<E>, F, T> as Service<Request>>::Response: IntoResponse + Send,
433        T: 'static,
434        E: 'static,
435        S: 'static,
436    {
437        MethodRouter {
438            axum_method_router: self.axum_method_router.handle_error(f),
439            operations: self.operations,
440        }
441    }
442
443    pub fn with_state<S2>(self, state: S) -> MethodRouter<S2, E> {
444        MethodRouter {
445            axum_method_router: self.axum_method_router.with_state(state),
446            operations: self.operations,
447        }
448    }
449}
450
451fn is_filter_present(lhs: MethodFilter, rhs: MethodFilter) -> bool {
452    lhs.or(rhs) == lhs
453}
454
#[test]
fn test_is_filter_present() {
    // Each pair is (filter under inspection, method expected to be inside it).
    let contained = [
        (MethodFilter::DELETE, MethodFilter::DELETE),
        (
            MethodFilter::DELETE.or(MethodFilter::GET),
            MethodFilter::DELETE,
        ),
        (
            MethodFilter::GET.or(MethodFilter::DELETE),
            MethodFilter::DELETE,
        ),
        (
            MethodFilter::DELETE.or(MethodFilter::DELETE),
            MethodFilter::DELETE,
        ),
    ];
    for (filter, method) in contained {
        assert!(is_filter_present(filter, method));
    }

    // A filter must not report a method it was never combined with.
    let absent = is_filter_present(MethodFilter::GET, MethodFilter::DELETE);
    assert!(!absent);
}