use std::{convert::Infallible, future::Future, mem, pin::Pin};
use crate::{
gen::{self, in_context},
openapi::{Components, OpenApi, PathItem, ReferenceOr, SchemaObject},
operation::OperationHandler,
util::merge_paths,
OperationInput, OperationOutput,
};
use axum::{
body::{Body, HttpBody},
extract::connect_info::IntoMakeServiceWithConnectInfo,
handler::Handler,
http::Request,
response::IntoResponse,
routing::{IntoMakeService, Route},
Router,
};
use indexmap::IndexMap;
use tower_layer::Layer;
use tower_service::Service;
use crate::{
transform::{TransformOpenApi, TransformPathItem},
util::path_colon_params,
};
use self::routing::ApiMethodRouter;
mod inputs;
mod outputs;
pub mod routing;
/// A wrapper around an axum [`Router`] that also records OpenAPI path
/// documentation for the routes registered through it.
///
/// Documented path items are kept in insertion order (an [`IndexMap`]) and are
/// written into an [`OpenApi`] document by `finish_api` / `finish_api_with`.
#[must_use]
#[derive(Debug)]
pub struct ApiRouter<S = (), B = Body> {
    // Documented path items, keyed by the route path exactly as it was passed
    // to the router (axum `:param` syntax; converted only when finishing).
    paths: IndexMap<String, PathItem>,
    // The underlying axum router that actually serves requests.
    router: Router<S, B>,
}
impl<S, B> Clone for ApiRouter<S, B> {
fn clone(&self) -> Self {
Self {
paths: self.paths.clone(),
router: self.router.clone(),
}
}
}
impl<B> Service<Request<B>> for ApiRouter<(), B>
where
B: HttpBody + Send + 'static,
{
type Response = axum::response::Response;
type Error = Infallible;
type Future = axum::routing::future::RouteFuture<B, Infallible>;
#[inline]
fn poll_ready(
&mut self,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Result<(), Self::Error>> {
self.router.poll_ready(cx)
}
#[inline]
fn call(&mut self, req: Request<B>) -> Self::Future {
self.router.call(req)
}
}
#[allow(clippy::mismatching_type_param_order)]
impl<B> Default for ApiRouter<(), B>
where
B: HttpBody + Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
impl<S, B> ApiRouter<S, B>
where
    B: HttpBody + Send + 'static,
    S: Clone + Send + Sync + 'static,
{
    /// Creates an empty router with no routes and no documented paths.
    pub fn new() -> Self {
        Self {
            paths: IndexMap::new(),
            router: Router::new(),
        }
    }

    /// Supplies the state of the inner axum router, producing an
    /// `ApiRouter<S2, B>`.
    ///
    /// The documented paths are carried over unchanged.
    pub fn with_state<S2>(self, state: S) -> ApiRouter<S2, B> {
        ApiRouter {
            paths: self.paths,
            router: self.router.with_state(state),
        }
    }

    /// Adds a documented route at `path`.
    ///
    /// The path item is taken out of `method_router`. If `path` is already
    /// documented, the new item is merged into the existing one with
    /// `merge_paths`; otherwise it is inserted. The remaining axum method
    /// router is registered on the inner router.
    #[tracing::instrument(skip_all, fields(%path))]
    pub fn api_route(mut self, path: &str, mut method_router: ApiMethodRouter<S, B>) -> Self {
        in_context(|ctx| {
            let new_path_item = method_router.take_path_item();
            if let Some(path_item) = self.paths.get_mut(path) {
                merge_paths(ctx, path, path_item, new_path_item);
            } else {
                self.paths.insert(path.into(), new_path_item);
            }
        });
        self.router = self.router.route(path, method_router.router);
        self
    }

    /// Like [`api_route`](Self::api_route), but lets `transform` customise the
    /// path item before it is recorded.
    ///
    /// If the transform marks the item as hidden, nothing is documented for
    /// this call; the route is still registered on the axum router either way.
    /// A visible item is merged with any documentation already recorded for
    /// `path` (for example other methods added by an earlier `api_route`),
    /// exactly like `api_route` does, instead of replacing it.
    #[tracing::instrument(skip_all, fields(%path))]
    pub fn api_route_with(
        mut self,
        path: &str,
        mut method_router: ApiMethodRouter<S, B>,
        transform: impl FnOnce(TransformPathItem) -> TransformPathItem,
    ) -> Self {
        let mut new_path_item = method_router.take_path_item();
        // The user transform runs outside `in_context`, as before, so it is
        // free to access the generation context itself (NOTE(review): assumes
        // `in_context` is not re-entrant — confirm).
        let hidden = transform(TransformPathItem::new(&mut new_path_item)).hidden;
        if !hidden {
            // Previously this was a plain `insert`, which silently discarded
            // documentation already recorded for the same path.
            in_context(|ctx| {
                if let Some(path_item) = self.paths.get_mut(path) {
                    merge_paths(ctx, path, path_item, new_path_item);
                } else {
                    self.paths.insert(path.into(), new_path_item);
                }
            });
        }
        self.router = self.router.route(path, method_router.router);
        self
    }

    /// Writes the collected documentation into `api` and returns the inner
    /// axum router.
    ///
    /// Any paths already present in `api` are replaced by the ones recorded on
    /// this router (see `merge_api`).
    #[tracing::instrument(skip_all)]
    pub fn finish_api(mut self, api: &mut OpenApi) -> Router<S, B> {
        self.merge_api(api);
        self.router
    }

    /// Like [`finish_api`](Self::finish_api), then applies `transform` to the
    /// resulting OpenAPI document.
    #[tracing::instrument(skip_all)]
    pub fn finish_api_with<F>(mut self, api: &mut OpenApi, transform: F) -> Router<S, B>
    where
        F: FnOnce(TransformOpenApi) -> TransformOpenApi,
    {
        self.merge_api(api);
        let _ = transform(TransformOpenApi::new(api));
        self.router
    }

    /// Moves the recorded paths (and, if enabled, extracted schemas) into `api`.
    ///
    /// Route keys are converted with `path_colon_params` before insertion, and
    /// the whole `paths` map of `api` is overwritten. When the generation
    /// context has `extract_schemas` set, its schema definitions are appended
    /// to `api.components.schemas` and the context is reset afterwards.
    fn merge_api(&mut self, api: &mut OpenApi) {
        let paths = api.paths.get_or_insert_with(Default::default);
        paths.paths = mem::take(&mut self.paths)
            .into_iter()
            .map(|(route, path)| {
                (
                    path_colon_params(&route).into_owned(),
                    ReferenceOr::Item(path),
                )
            })
            .collect();

        let needs_reset = in_context(|ctx| {
            if !ctx.extract_schemas {
                return false;
            }
            let components = api.components.get_or_insert_with(Components::default);
            components
                .schemas
                .extend(ctx.schema.take_definitions().into_iter().map(
                    |(name, json_schema)| {
                        (
                            name,
                            SchemaObject {
                                json_schema,
                                example: None,
                                external_docs: None,
                            },
                        )
                    },
                ));
            true
        });

        // The definitions were taken out of the context above; resetting is
        // done outside `in_context` (NOTE(review): presumably because
        // `reset_context` needs the context not to be borrowed — confirm).
        if needs_reset {
            gen::reset_context();
        }
    }
}
impl<S, B> ApiRouter<S, B>
where
    B: HttpBody + Send + 'static,
    S: Clone + Send + Sync + 'static,
{
    /// Registers `method_router` at `path` without documenting it.
    ///
    /// Only the axum part of the converted [`ApiMethodRouter`] is used; any
    /// path documentation it carries is discarded. Use `api_route` to keep it.
    #[tracing::instrument(skip_all)]
    pub fn route(mut self, path: &str, method_router: impl Into<ApiMethodRouter<S, B>>) -> Self {
        self.router = self.router.route(path, method_router.into().router);
        self
    }

    /// Registers an arbitrary tower service at `path`; the route is not
    /// documented.
    #[tracing::instrument(skip_all)]
    pub fn route_service<T>(mut self, path: &str, service: T) -> Self
    where
        T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
    {
        self.router = self.router.route_service(path, service);
        self
    }

    /// Nests `router` under the prefix `path`, keeping its documentation.
    ///
    /// Each documented path of the nested router is re-keyed with the prefix
    /// (see `nested_route_path`; in particular a nested `/` route is
    /// documented at `path` itself rather than at `path/`). Paths that are
    /// already documented here are merged with `merge_paths` instead of being
    /// overwritten.
    #[tracing::instrument(skip_all)]
    pub fn nest(mut self, path: &str, router: ApiRouter<S, B>) -> Self {
        self.router = self.router.nest(path, router.router);
        self.merge_path_items(
            router
                .paths
                .into_iter()
                .map(|(route, path_item)| (Self::nested_route_path(path, &route), path_item)),
        );
        self
    }

    /// Nests a state-less [`ApiRouter`] as a service under `path`, keeping its
    /// documentation.
    ///
    /// Trailing slashes are trimmed from `path`; the trimmed prefix is used
    /// both for the axum `nest_service` call and as the prefix of the
    /// documented paths. Paths that are already documented here are merged
    /// with `merge_paths` instead of being overwritten.
    pub fn nest_api_service(mut self, path: &str, service: impl Into<ApiRouter<(), B>>) -> Self {
        let router: ApiRouter<(), B> = service.into();
        let prefix = path.trim_end_matches('/');
        self.merge_path_items(
            router
                .paths
                .into_iter()
                .map(|(route, path_item)| (prefix.to_string() + &route, path_item)),
        );
        self.router = self.router.nest_service(prefix, router.router);
        self
    }

    /// Nests an arbitrary tower service under `path`; nothing is documented.
    pub fn nest_service<T>(mut self, path: &str, svc: T) -> Self
    where
        T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
    {
        self.router = self.router.nest_service(path, svc);
        self
    }

    /// Merges the routes and documentation of `other` into this router.
    ///
    /// A path documented by both routers is combined with `merge_paths`
    /// (axum likewise merges method routers registered at the same path)
    /// instead of `other`'s item replacing this router's.
    pub fn merge<R>(mut self, other: R) -> Self
    where
        R: Into<ApiRouter<S, B>>,
    {
        let other: ApiRouter<S, B> = other.into();
        self.merge_path_items(other.paths);
        self.router = self.router.merge(other.router);
        self
    }

    /// Applies `layer` to every route via axum's `Router::layer`.
    ///
    /// The documented paths are carried over unchanged; the request body type
    /// becomes `NewReqBody`.
    pub fn layer<L, NewReqBody>(self, layer: L) -> ApiRouter<S, NewReqBody>
    where
        L: Layer<Route<B>> + Clone + Send + 'static,
        L::Service: Service<Request<NewReqBody>> + Clone + Send + 'static,
        <L::Service as Service<Request<NewReqBody>>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request<NewReqBody>>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request<NewReqBody>>>::Future: Send + 'static,
        NewReqBody: HttpBody + 'static,
    {
        ApiRouter {
            paths: self.paths,
            router: self.router.layer(layer),
        }
    }

    /// Applies `layer` via axum's `Router::route_layer`; documentation is
    /// unchanged.
    pub fn route_layer<L>(mut self, layer: L) -> Self
    where
        L: Layer<Route<B>> + Clone + Send + 'static,
        L::Service: Service<Request<B>> + Clone + Send + 'static,
        <L::Service as Service<Request<B>>>::Response: IntoResponse + 'static,
        <L::Service as Service<Request<B>>>::Error: Into<Infallible> + 'static,
        <L::Service as Service<Request<B>>>::Future: Send + 'static,
    {
        self.router = self.router.route_layer(layer);
        self
    }

    /// Sets the axum fallback handler; the fallback is not documented.
    pub fn fallback<H, T>(mut self, handler: H) -> Self
    where
        H: Handler<T, S, B>,
        T: 'static,
    {
        self.router = self.router.fallback(handler);
        self
    }

    /// Sets the axum fallback service; the fallback is not documented.
    pub fn fallback_service<T>(mut self, svc: T) -> Self
    where
        T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
        T::Response: IntoResponse,
        T::Future: Send + 'static,
    {
        self.router = self.router.fallback_service(svc);
        self
    }

    /// Adds `items` to the documented paths, merging each one into an existing
    /// entry for the same path with `merge_paths`, or inserting it otherwise.
    fn merge_path_items(&mut self, items: impl IntoIterator<Item = (String, PathItem)>) {
        in_context(|ctx| {
            for (path, path_item) in items {
                if let Some(existing) = self.paths.get_mut(&path) {
                    merge_paths(ctx, &path, existing, path_item);
                } else {
                    self.paths.insert(path, path_item);
                }
            }
        });
    }

    /// Joins a nesting prefix and the path of a route inside the nested router.
    ///
    /// NOTE(review): intended to mirror how axum 0.6 builds the full path of a
    /// route in a nested `Router`. An empty prefix behaves like `/`, a prefix
    /// ending in `/` absorbs the route's leading slashes, and a nested `/`
    /// route maps to the prefix itself. Confirm against the axum version in use.
    fn nested_route_path(prefix: &str, route: &str) -> String {
        let prefix = if prefix.is_empty() { "/" } else { prefix };
        if prefix.ends_with('/') {
            format!("{prefix}{}", route.trim_start_matches('/'))
        } else if route == "/" {
            prefix.to_owned()
        } else {
            format!("{prefix}{route}")
        }
    }
}
impl<B> ApiRouter<(), B>
where
    B: HttpBody + Send + 'static,
{
    /// Turns the inner axum router into a make-service for serving.
    ///
    /// The recorded documentation is dropped; call `finish_api` first if it
    /// is needed.
    #[tracing::instrument(skip_all)]
    #[must_use]
    pub fn into_make_service(self) -> IntoMakeService<Router<(), B>> {
        let Self { router, .. } = self;
        router.into_make_service()
    }

    /// Like `into_make_service`, but the produced services have access to
    /// connection info of type `C`.
    #[tracing::instrument(skip_all)]
    #[must_use]
    pub fn into_make_service_with_connect_info<C>(
        self,
    ) -> IntoMakeServiceWithConnectInfo<Router<(), B>, C> {
        let Self { router, .. } = self;
        router.into_make_service_with_connect_info()
    }
}
/// Wraps a plain axum router; none of its existing routes are documented.
impl<S, B> From<Router<S, B>> for ApiRouter<S, B> {
    fn from(router: Router<S, B>) -> Self {
        Self {
            router,
            paths: IndexMap::default(),
        }
    }
}
impl<S, B> From<ApiRouter<S, B>> for Router<S, B> {
fn from(api: ApiRouter<S, B>) -> Self {
api.router
}
}
/// Shorthand for types that are both an axum response ([`IntoResponse`]) and
/// a documentable operation output ([`OperationOutput`]).
///
/// It is implemented automatically for every type satisfying both traits, so
/// it can be used as `impl IntoApiResponse` in handler return positions.
pub trait IntoApiResponse: IntoResponse + OperationOutput {}
// Blanket impl: every `IntoResponse + OperationOutput` type qualifies.
impl<T> IntoApiResponse for T where T: IntoResponse + OperationOutput {}
/// Extension methods for turning a plain axum [`Router`] into an
/// [`ApiRouter`].
///
/// The trait is sealed through `private::Sealed`, so it cannot be implemented
/// outside this module.
pub trait RouterExt<S, B>: private::Sealed + Sized {
    /// Converts the router into an [`ApiRouter`].
    fn into_api(self) -> ApiRouter<S, B>;
    /// Converts the router into an [`ApiRouter`] and adds a documented route
    /// at `path`.
    fn api_route(self, path: &str, method_router: ApiMethodRouter<S, B>) -> ApiRouter<S, B>;
}
impl<S, B> RouterExt<S, B> for Router<S, B>
where
    B: HttpBody + Send + 'static,
    S: Clone + Send + Sync + 'static,
{
    /// Wraps `self` without documenting any of its existing routes.
    #[tracing::instrument(skip_all)]
    fn into_api(self) -> ApiRouter<S, B> {
        self.into()
    }

    /// Wraps `self`, then documents and registers `method_router` at `path`.
    #[tracing::instrument(skip_all)]
    fn api_route(self, path: &str, method_router: ApiMethodRouter<S, B>) -> ApiRouter<S, B> {
        let wrapped: ApiRouter<S, B> = self.into();
        wrapped.api_route(path, method_router)
    }
}
// Marks axum's `Router` as eligible for the sealed `RouterExt` trait.
impl<S, B> private::Sealed for Router<S, B> {}
/// Either an arbitrary tower service or a state-less [`ApiRouter`].
///
/// Built through its `From` impls. NOTE(review): no consumer is visible in
/// this chunk; presumably used by an API that accepts either form — confirm.
#[doc(hidden)]
pub enum ServiceOrApiRouter<B, T> {
    // A plain service of type `T`.
    Service(T),
    // A documented router without state.
    Router(ApiRouter<(), B>),
}
impl<T, B> From<T> for ServiceOrApiRouter<B, T>
where
T: Service<Request<B>, Error = Infallible> + Clone + Send + 'static,
T::Response: IntoResponse,
T::Future: Send + 'static,
{
fn from(v: T) -> Self {
Self::Service(v)
}
}
impl<B> From<ApiRouter<(), B>> for ServiceOrApiRouter<B, DefinitelyNotService> {
fn from(v: ApiRouter<(), B>) -> Self {
Self::Router(v)
}
}
/// An uninhabited placeholder service type; it has no variants and can never
/// be constructed.
///
/// It fills the `T` parameter of `ServiceOrApiRouter` when converting from an
/// [`ApiRouter`]. Presumably it is a concrete type distinct from `ApiRouter` so
/// that this conversion cannot overlap with the blanket `From<T>` impl for
/// services.
#[derive(Clone)]
#[doc(hidden)]
pub enum DefinitelyNotService {}
/// Satisfies `Service` bounds for the placeholder type.
///
/// Because `DefinitelyNotService` has no values, neither method can ever run.
/// Matching on the empty enum proves this to the compiler instead of
/// panicking at runtime.
impl<B> Service<Request<B>> for DefinitelyNotService {
    type Response = String;
    type Error = Infallible;
    type Future =
        Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send + Sync + 'static>>;

    fn poll_ready(
        &mut self,
        _cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), Self::Error>> {
        match *self {}
    }

    fn call(&mut self, _req: Request<B>) -> Self::Future {
        match *self {}
    }
}
// Holds the supertrait that seals `RouterExt`. The module is private, so the
// trait cannot be named, and therefore not implemented, from outside.
mod private {
    pub trait Sealed {}
}
/// A handler usable both as an axum [`Handler`] and as an
/// [`OperationHandler`] with documentable input `I` and output `O`.
///
/// This is a shorthand bound; it is implemented for every type that satisfies
/// both traits.
pub trait AxumOperationHandler<I, O, T, S, B>: Handler<T, S, B> + OperationHandler<I, O>
where
    I: OperationInput,
    O: OperationOutput,
{
}
// Blanket impl: any `H` that is both an axum `Handler<T, S, B>` and an
// `OperationHandler<I, O>` (with documentable `I` / `O`) qualifies.
impl<H, I, O, T, S, B> AxumOperationHandler<I, O, T, S, B> for H
where
    H: Handler<T, S, B> + OperationHandler<I, O>,
    I: OperationInput,
    O: OperationOutput,
{
}
#[cfg(test)]
#[allow(clippy::unused_async)]
mod tests {
    //! Compile-level checks that `ApiRouter` state handling type-checks when
    //! nesting routers and attaching state to individual method routers.
    use crate::axum::{routing, ApiRouter};
    use axum::extract::State;

    /// Handler that extracts the whole `TestState`.
    async fn test_handler1(State(_): State<TestState>) {}

    /// Handler that extracts a `u8` state (the type of `TestState::field1`).
    async fn test_handler2(State(_): State<u8>) {}

    #[derive(Clone, Copy)]
    struct TestState {
        field1: u8,
    }

    #[test]
    fn test_nesting_with_nondefault_state() {
        // The nested router receives its own `isize` state before nesting; the
        // outer router receives its `usize` state afterwards.
        let nested = ApiRouter::new().with_state(1_isize);
        let outer = ApiRouter::new().nest_api_service("/", nested);
        let _app: ApiRouter = outer.with_state(1_usize);
    }

    #[test]
    fn test_method_router_with_state() {
        let stateful: ApiRouter<TestState> =
            ApiRouter::new().api_route("/", routing::get(test_handler1));
        let stateless: ApiRouter = stateful.with_state(TestState { field1: 0 });
        let _service = stateless.into_make_service();
    }

    #[test]
    fn test_router_with_different_states() {
        let state = TestState { field1: 0 };
        let app: ApiRouter = ApiRouter::new()
            .api_route("/test1", routing::get(test_handler1))
            // This method router carries its own `u8` state, so it does not
            // depend on the router-level `TestState`.
            .api_route("/test2", routing::get(test_handler2).with_state(state.field1))
            .with_state(state);
        let _service = app.into_make_service();
    }
}