1use std::{convert::Infallible, future::Future, mem, pin::Pin};
172
173use crate::{
174 generate::{self, in_context},
175 openapi::{OpenApi, PathItem, ReferenceOr, SchemaObject},
176 operation::OperationHandler,
177 util::{merge_paths, path_for_nested_route},
178 OperationInput, OperationOutput,
179};
180#[cfg(feature = "axum-tokio")]
181use axum::extract::connect_info::IntoMakeServiceWithConnectInfo;
182use axum::{
183 body::{Body, Bytes, HttpBody},
184 handler::Handler,
185 http::Request,
186 response::IntoResponse,
187 routing::{IntoMakeService, Route, RouterAsService, RouterIntoService},
188 Router,
189};
190use indexmap::map::Entry;
191use indexmap::IndexMap;
192use tower_layer::Layer;
193use tower_service::Service;
194
195#[cfg(feature = "axum-extra")]
196use axum_extra::routing::RouterExt as _;
197
198use self::routing::ApiMethodRouter;
199use crate::transform::{TransformOpenApi, TransformPathItem};
200
201mod inputs;
202mod outputs;
203
204pub mod routing;
205
/// A wrapper around an axum [`Router`] that also records an OpenAPI
/// [`PathItem`] for every route registered through the `api_*` methods.
///
/// The recorded path items are written into an [`OpenApi`] document by
/// [`ApiRouter::finish_api`] / [`ApiRouter::finish_api_with`], which then
/// hand back the underlying [`Router`].
///
/// `S` is the router state type and defaults to `()`, mirroring [`Router`].
#[must_use]
#[derive(Debug)]
pub struct ApiRouter<S = ()> {
    // Documented path items keyed by route path (`IndexMap` keeps insertion order).
    paths: IndexMap<String, PathItem>,
    // The axum router that actually serves the requests.
    router: Router<S>,
}
214
215impl<S> Clone for ApiRouter<S> {
216 fn clone(&self) -> Self {
217 Self {
218 paths: self.paths.clone(),
219 router: self.router.clone(),
220 }
221 }
222}
223
224impl<B> Service<Request<B>> for ApiRouter<()>
225where
226 B: HttpBody<Data = Bytes> + Send + 'static,
227 B::Error: Into<axum::BoxError>,
228{
229 type Response = axum::response::Response;
230 type Error = Infallible;
231 type Future = axum::routing::future::RouteFuture<Infallible>;
232
233 #[inline]
234 fn poll_ready(
235 &mut self,
236 cx: &mut std::task::Context<'_>,
237 ) -> std::task::Poll<Result<(), Self::Error>> {
238 Service::<Request<B>>::poll_ready(&mut self.router, cx)
239 }
240
241 #[inline]
242 fn call(&mut self, req: Request<B>) -> Self::Future {
243 self.router.call(req)
244 }
245}
246
247#[allow(clippy::mismatching_type_param_order)]
248impl Default for ApiRouter<()> {
249 fn default() -> Self {
250 Self::new()
251 }
252}
253
254impl<S> ApiRouter<S>
255where
256 S: Clone + Send + Sync + 'static,
257{
258 pub fn new() -> Self {
262 Self {
263 paths: IndexMap::new(),
264 router: Router::new(),
265 }
266 }
267
268 pub fn with_state<S2>(self, state: S) -> ApiRouter<S2> {
272 ApiRouter {
273 paths: self.paths,
274 router: self.router.with_state(state),
275 }
276 }
277
278 pub fn with_path_items(
282 mut self,
283 mut transform: impl FnMut(TransformPathItem) -> TransformPathItem,
284 ) -> Self {
285 for (_, item) in &mut self.paths {
286 let _ = transform(TransformPathItem::new(item));
287 }
288 self
289 }
290
291 #[tracing::instrument(skip_all, fields(path = path))]
298 pub fn api_route(mut self, path: &str, mut method_router: ApiMethodRouter<S>) -> Self {
299 in_context(|ctx| {
300 let new_path_item = method_router.take_path_item();
301
302 if let Some(path_item) = self.paths.get_mut(path) {
303 merge_paths(ctx, path, path_item, new_path_item);
304 } else {
305 self.paths.insert(path.into(), new_path_item);
306 }
307 });
308
309 self.router = self.router.route(path, method_router.router);
310 self
311 }
312
313 #[cfg(feature = "axum-extra")]
314 #[tracing::instrument(skip_all, fields(path = path))]
321 pub fn api_route_with_tsr(mut self, path: &str, mut method_router: ApiMethodRouter<S>) -> Self {
322 in_context(|ctx| {
323 let new_path_item = method_router.take_path_item();
324
325 if let Some(path_item) = self.paths.get_mut(path) {
326 merge_paths(ctx, path, path_item, new_path_item);
327 } else {
328 self.paths.insert(path.into(), new_path_item);
329 }
330 });
331
332 self.router = self.router.route_with_tsr(path, method_router.router);
333 self
334 }
335
336 #[tracing::instrument(skip_all, fields(path = path))]
344 pub fn api_route_with(
345 mut self,
346 path: &str,
347 mut method_router: ApiMethodRouter<S>,
348 transform: impl FnOnce(TransformPathItem) -> TransformPathItem,
349 ) -> Self {
350 in_context(|ctx| {
351 let mut p = method_router.take_path_item();
352 let t = transform(TransformPathItem::new(&mut p));
353
354 if !t.hidden {
355 if let Some(path_item) = self.paths.get_mut(path) {
356 merge_paths(ctx, path, path_item, p);
357 } else {
358 self.paths.insert(path.into(), p);
359 }
360 }
361 });
362
363 self.router = self.router.route(path, method_router.router);
364 self
365 }
366
367 #[cfg(feature = "axum-extra")]
368 #[tracing::instrument(skip_all, fields(path = path))]
376 pub fn api_route_with_tsr_and(
377 mut self,
378 path: &str,
379 mut method_router: ApiMethodRouter<S>,
380 transform: impl FnOnce(TransformPathItem) -> TransformPathItem,
381 ) -> Self {
382 in_context(|ctx| {
383 let mut p = method_router.take_path_item();
384 let t = transform(TransformPathItem::new(&mut p));
385
386 if !t.hidden {
387 if let Some(path_item) = self.paths.get_mut(path) {
388 merge_paths(ctx, path, path_item, p);
389 } else {
390 self.paths.insert(path.into(), p);
391 }
392 }
393 });
394
395 self.router = self.router.route_with_tsr(path, method_router.router);
396 self
397 }
398
399 #[tracing::instrument(skip_all)]
402 pub fn finish_api(mut self, api: &mut OpenApi) -> Router<S> {
403 self.merge_api(api);
404 self.router
405 }
406
407 #[tracing::instrument(skip_all)]
413 pub fn finish_api_with<F>(mut self, api: &mut OpenApi, transform: F) -> Router<S>
414 where
415 F: FnOnce(TransformOpenApi) -> TransformOpenApi,
416 {
417 self.merge_api_with(api, transform);
418 self.router
419 }
420
421 fn merge_api(&mut self, api: &mut OpenApi) {
422 self.merge_api_with(api, |x| x);
423 }
424
425 fn merge_api_with<F>(&mut self, api: &mut OpenApi, transform: F)
426 where
427 F: FnOnce(TransformOpenApi) -> TransformOpenApi,
428 {
429 if api.paths.is_none() {
430 api.paths = Some(Default::default());
431 }
432
433 let paths = api.paths.as_mut().unwrap();
434
435 paths.paths = mem::take(&mut self.paths)
436 .into_iter()
437 .map(|(route, path)| (route, ReferenceOr::Item(path)))
438 .collect();
439
440 let _ = transform(TransformOpenApi::new(api));
441
442 let needs_reset =
443 in_context(|ctx| {
444 if !ctx.extract_schemas {
445 return false;
446 }
447
448 let components = api.components.get_or_insert_with(Default::default);
449
450 components
451 .schemas
452 .extend(ctx.schema.take_definitions().into_iter().map(
453 |(name, json_schema)| {
454 (
455 name,
456 SchemaObject {
457 json_schema,
458 example: None,
459 external_docs: None,
460 },
461 )
462 },
463 ));
464
465 true
466 });
467
468 if needs_reset {
469 generate::reset_context();
470 }
471 }
472}
473
474impl<S> ApiRouter<S>
476where
477 S: Clone + Send + Sync + 'static,
478{
479 #[tracing::instrument(skip_all)]
483 pub fn route(mut self, path: &str, method_router: impl Into<ApiMethodRouter<S>>) -> Self {
484 self.router = self.router.route(path, method_router.into().router);
485 self
486 }
487
488 #[cfg(feature = "axum-extra")]
492 #[tracing::instrument(skip_all)]
493 pub fn route_with_tsr(
494 mut self,
495 path: &str,
496 method_router: impl Into<ApiMethodRouter<S>>,
497 ) -> Self {
498 self.router = self.router.route(path, method_router.into().router);
499 self
500 }
501
502 #[tracing::instrument(skip_all)]
504 pub fn route_service<T>(mut self, path: &str, service: T) -> Self
505 where
506 T: Service<Request<Body>, Error = Infallible> + Clone + Send + Sync + 'static,
507 T::Response: IntoResponse,
508 T::Future: Send + 'static,
509 {
510 self.router = self.router.route_service(path, service);
511 self
512 }
513
514 #[cfg(feature = "axum-extra")]
516 #[tracing::instrument(skip_all)]
517 pub fn route_service_with_tsr<T>(mut self, path: &str, service: T) -> Self
518 where
519 T: Service<axum::extract::Request, Error = Infallible> + Clone + Send + Sync + 'static,
520 T::Response: IntoResponse,
521 T::Future: Send + 'static,
522 Self: Sized,
523 {
524 self.router = self.router.route_service_with_tsr(path, service);
525 self
526 }
527
528 #[tracing::instrument(skip_all)]
532 pub fn nest(mut self, path: &str, router: ApiRouter<S>) -> Self {
533 self.router = self.router.nest(path, router.router);
534
535 self.paths.extend(
536 router
537 .paths
538 .into_iter()
539 .map(|(route, path_item)| (path_for_nested_route(path, &route), path_item)),
540 );
541
542 self
543 }
544
545 pub fn nest_api_service(mut self, path: &str, service: impl Into<ApiRouter<()>>) -> Self {
555 let router: ApiRouter<()> = service.into();
556
557 self.paths.extend(
558 router
559 .paths
560 .into_iter()
561 .map(|(route, path_item)| (path_for_nested_route(path, &route), path_item)),
562 );
563 self.router = self.router.nest_service(path, router.router);
564 self
565 }
566
567 pub fn nest_service<T>(mut self, path: &str, svc: T) -> Self
570 where
571 T: Service<Request<Body>, Error = Infallible> + Clone + Send + Sync + 'static,
572 T::Response: IntoResponse,
573 T::Future: Send + 'static,
574 {
575 self.router = self.router.nest_service(path, svc);
576
577 self
578 }
579
580 pub fn merge<R>(mut self, other: R) -> Self
585 where
586 R: Into<ApiRouter<S>>,
587 {
588 let other: ApiRouter<S> = other.into();
589
590 for (key, path) in other.paths {
591 match self.paths.entry(key) {
592 Entry::Occupied(mut o) => {
593 o.get_mut().merge_with(path);
594 }
595 Entry::Vacant(v) => {
596 v.insert(path);
597 }
598 }
599 }
600 self.router = self.router.merge(other.router);
601 self
602 }
603
604 pub fn layer<L>(self, layer: L) -> ApiRouter<S>
606 where
607 L: Layer<Route> + Clone + Send + Sync + 'static,
608 L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
609 <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
610 <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
611 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
612 {
613 ApiRouter {
614 paths: self.paths,
615 router: self.router.layer(layer),
616 }
617 }
618
619 pub fn route_layer<L>(mut self, layer: L) -> Self
621 where
622 L: Layer<Route> + Clone + Send + Sync + 'static,
623 L::Service: Service<Request<Body>> + Clone + Send + Sync + 'static,
624 <L::Service as Service<Request<Body>>>::Response: IntoResponse + 'static,
625 <L::Service as Service<Request<Body>>>::Error: Into<Infallible> + 'static,
626 <L::Service as Service<Request<Body>>>::Future: Send + 'static,
627 {
628 self.router = self.router.route_layer(layer);
629 self
630 }
631
632 pub fn fallback<H, T>(mut self, handler: H) -> Self
634 where
635 H: Handler<T, S>,
636 T: 'static,
637 {
638 self.router = self.router.fallback(handler);
639 self
640 }
641
642 pub fn fallback_service<T>(mut self, svc: T) -> Self
644 where
645 T: Service<Request<Body>, Error = Infallible> + Clone + Send + Sync + 'static,
646 T::Response: IntoResponse,
647 T::Future: Send + 'static,
648 {
649 self.router = self.router.fallback_service(svc);
650 self
651 }
652
653 #[must_use]
657 pub fn as_service<B>(&mut self) -> RouterAsService<'_, B, S> {
658 self.router.as_service()
659 }
660
661 #[must_use]
665 pub fn into_service<B>(self) -> RouterIntoService<B, S> {
666 self.router.into_service()
667 }
668}
669
670impl ApiRouter<()> {
671 #[tracing::instrument(skip_all)]
673 #[must_use]
674 pub fn into_make_service(self) -> IntoMakeService<Router<()>> {
675 self.router.into_make_service()
676 }
677
678 #[tracing::instrument(skip_all)]
680 #[must_use]
681 #[cfg(feature = "axum-tokio")]
682 pub fn into_make_service_with_connect_info<C>(
683 self,
684 ) -> IntoMakeServiceWithConnectInfo<Router<()>, C> {
685 self.router.into_make_service_with_connect_info()
686 }
687}
688
689impl<S> From<Router<S>> for ApiRouter<S> {
690 fn from(router: Router<S>) -> Self {
691 ApiRouter {
692 paths: IndexMap::new(),
693 router,
694 }
695 }
696}
697
698impl<S> From<ApiRouter<S>> for Router<S> {
699 fn from(api: ApiRouter<S>) -> Self {
700 api.router
701 }
702}
703
704pub trait IntoApiResponse: IntoResponse + OperationOutput {}
713
714impl<T> IntoApiResponse for T where T: IntoResponse + OperationOutput {}
715
/// Extension methods that turn an axum [`Router`] into an [`ApiRouter`].
///
/// The trait is sealed through `private::Sealed`, which cannot be named
/// outside this module tree, so downstream crates cannot implement it.
pub trait RouterExt<S>: private::Sealed + Sized {
    /// Wraps the router in an [`ApiRouter`] with no documented paths.
    fn into_api(self) -> ApiRouter<S>;
    /// Wraps the router and adds a documented route, like
    /// [`ApiRouter::api_route`].
    fn api_route(self, path: &str, method_router: ApiMethodRouter<S>) -> ApiRouter<S>;
    /// Wraps the router and adds a documented route with a trailing-slash
    /// redirect, like [`ApiRouter::api_route_with_tsr`].
    #[cfg(feature = "axum-extra")]
    fn api_route_with_tsr(self, path: &str, method_router: ApiMethodRouter<S>) -> ApiRouter<S>;
}
730
731impl<S> RouterExt<S> for Router<S>
732where
733 S: Clone + Send + Sync + 'static,
734{
735 #[tracing::instrument(skip_all)]
736 fn into_api(self) -> ApiRouter<S> {
737 ApiRouter::from(self)
738 }
739
740 #[tracing::instrument(skip_all)]
741 fn api_route(self, path: &str, method_router: ApiMethodRouter<S>) -> ApiRouter<S> {
742 ApiRouter::from(self).api_route(path, method_router)
743 }
744
745 #[cfg(feature = "axum-extra")]
746 #[tracing::instrument(skip_all)]
747 fn api_route_with_tsr(self, path: &str, method_router: ApiMethodRouter<S>) -> ApiRouter<S> {
748 ApiRouter::from(self).api_route_with_tsr(path, method_router)
749 }
750}
751
// Marks axum's `Router` as a permitted implementor of the sealed `RouterExt`.
impl<S> private::Sealed for Router<S> {}
753
/// Holds either a plain tower service or a stateless [`ApiRouter`].
///
/// Values are built through the `From` impls in this module. In the router
/// case, [`DefinitelyNotService`] fills the type parameter.
/// NOTE(review): the consumer of this type is not visible here — presumably
/// it lets one argument accept either kind; confirm.
#[doc(hidden)]
pub enum ServiceOrApiRouter<T> {
    Service(T),
    Router(ApiRouter<()>),
}
758}
759
760impl<T> From<T> for ServiceOrApiRouter<T>
761where
762 T: Service<Request<Body>, Error = Infallible> + Clone + Send + 'static,
763 T::Response: IntoResponse,
764 T::Future: Send + 'static,
765{
766 fn from(v: T) -> Self {
767 Self::Service(v)
768 }
769}
770
771impl From<ApiRouter<()>> for ServiceOrApiRouter<DefinitelyNotService> {
772 fn from(v: ApiRouter<()>) -> Self {
773 Self::Router(v)
774 }
775}
776
/// Uninhabited placeholder used as the type parameter of
/// [`ServiceOrApiRouter`] when it holds an [`ApiRouter`].
///
/// The enum has no variants, so no value of it can ever exist.
#[derive(Clone)]
#[doc(hidden)]
pub enum DefinitelyNotService {}
781
782impl Service<Request<Body>> for DefinitelyNotService {
783 type Response = String;
784
785 type Error = Infallible;
786
787 type Future =
788 Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync + 'static>>;
789
790 fn poll_ready(
791 &mut self,
792 _cx: &mut std::task::Context<'_>,
793 ) -> std::task::Poll<Result<(), Self::Error>> {
794 unreachable!()
795 }
796
797 fn call(&mut self, _req: Request<Body>) -> Self::Future {
798 unreachable!()
799 }
800}
801
/// Holds the supertrait that seals `RouterExt`.
mod private {
    /// Supertrait of `RouterExt`. Because this module is private, the trait
    /// cannot be named (and so cannot be implemented) outside this module
    /// tree.
    pub trait Sealed {}
}
805
806impl<I, O, L, H, T, S> OperationHandler<I, O> for axum::handler::Layered<L, H, T, S>
807where
808 H: OperationHandler<I, O>,
809 I: OperationInput,
810 O: OperationOutput,
811{
812}
813
/// Combines axum's [`Handler`] with [`OperationHandler`], so one bound covers
/// handlers that can be both routed and documented.
///
/// A blanket impl in this module implements it for every type implementing
/// both traits.
pub trait AxumOperationHandler<I, O, T, S>: Handler<T, S> + OperationHandler<I, O>
where
    I: OperationInput,
    O: OperationOutput,
{
}
825
826impl<H, I, O, T, S> AxumOperationHandler<I, O, T, S> for H
827where
828 H: Handler<T, S> + OperationHandler<I, O>,
829 I: OperationInput,
830 O: OperationOutput,
831{
832}
833
#[cfg(test)]
#[allow(clippy::unused_async)]
mod tests {
    use crate::axum::{routing, ApiRouter};
    use axum::{extract::State, handler::Handler};

    /// State type for the stateful test routers.
    #[derive(Clone, Copy)]
    struct TestState {
        field1: u8,
    }

    /// Handler extracting the whole `TestState`.
    async fn test_handler1(State(_): State<TestState>) {}

    /// Handler extracting a `u8` state.
    async fn test_handler2(State(_): State<u8>) {}

    /// Handler with no extractors.
    async fn test_handler3() {}

    /// A stateless router documenting `POST` on `/`, `/test1` and `/test2/`.
    fn nested_route() -> ApiRouter {
        ["/", "/test1", "/test2/"]
            .into_iter()
            .fold(ApiRouter::new(), |router, route| {
                router.api_route_with(route, routing::post(test_handler3), |t| t)
            })
    }

    #[test]
    fn test_nesting_with_nondefault_state() {
        let _app: ApiRouter = ApiRouter::new()
            .nest_api_service("/home", ApiRouter::new().with_state(1_isize))
            .with_state(1_usize);
    }

    #[test]
    fn test_method_router_with_state() {
        let stateful: ApiRouter<TestState> =
            ApiRouter::new().api_route("/", routing::get(test_handler1));
        let stateless: ApiRouter = stateful.with_state(TestState { field1: 0 });
        let _service = stateless.into_make_service();
    }

    #[test]
    fn test_router_with_different_states() {
        let shared = TestState { field1: 0 };
        let router: ApiRouter = ApiRouter::new()
            .api_route("/test1", routing::get(test_handler1))
            .api_route(
                "/test2",
                routing::get(test_handler2).with_state(shared.field1),
            )
            .with_state(shared);
        let _service = router.into_make_service();
    }

    #[test]
    fn test_api_route_with_same_router_different_methods() {
        let router: ApiRouter = ApiRouter::new()
            .api_route_with("/test1", routing::post(test_handler3), |t| t)
            .api_route_with("/test1", routing::get(test_handler3), |t| t);

        let merged = router
            .paths
            .get("/test1")
            .expect("should contain handler for /test1");

        // Both methods must survive the merge into a single path item.
        assert!(merged.post.is_some());
        assert!(merged.get.is_some());
    }

    #[test]
    fn test_nested_routing() {
        let router: ApiRouter = ApiRouter::new().nest("/app", nested_route());

        for expected in ["/app", "/app/test1", "/app/test2/"] {
            assert!(router.paths.contains_key(expected));
        }
        // The nested root maps to the prefix itself, without a trailing slash.
        assert!(!router.paths.contains_key("/app/"));
    }

    #[test]
    fn test_layered_handler() {
        let layered = test_handler3.layer(tower_layer::Identity::new());
        let _app: ApiRouter = ApiRouter::new().api_route("/test-route", routing::get(layered));
    }
}