use std::fmt::{Debug, Formatter};
use aide::openapi::Operation;
use cot::openapi::RouteContext;
use cot::request::Request;
use cot::response::Response;
use cot::router::method::InnerHandler;
use schemars::SchemaGenerator;
use crate::RequestHandler;
use crate::openapi::{
AsApiOperation, AsApiRoute, BoxApiRequestHandler, into_box_api_request_handler,
};
use crate::router::method::InnerMethodRouter;
/// A router that dispatches requests by HTTP method and can also describe its
/// handlers as an OpenAPI path item.
///
/// Handlers for `GET`, `HEAD`, `DELETE`, `OPTIONS`, `PATCH`, `POST`, `PUT` and
/// `TRACE` must implement [`AsApiOperation`] in addition to
/// [`RequestHandler`]; their operations are emitted by
/// [`AsApiRoute::as_api_route`]. The `CONNECT` and fallback handlers only need
/// to implement [`RequestHandler`] and are not included in the OpenAPI output.
#[derive(Debug)]
#[must_use]
pub struct ApiMethodRouter {
    inner: InnerMethodRouter<InnerApiHandler>,
}
// Generates a builder method on `ApiMethodRouter` (e.g. `get`) that registers
// an OpenAPI-aware handler for one HTTP method.
//
// `$name` is the setter name and `$method` the matching `cot::Method` constant.
macro_rules! define_method {
    ($name:ident => $method:ident) => {
        #[doc = concat!("Set a handler for the [`",
            stringify!($method),
            "`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
            stringify!($method),
            ") HTTP method with OpenAPI specification.")]
        ///
        /// The handler must implement both [`RequestHandler`], which serves the
        /// request, and [`AsApiOperation`], which describes it in the OpenAPI
        /// path item generated for this router. Calling this method again
        /// replaces any handler previously registered for the same HTTP method.
        ///
        /// # Examples
        ///
        /// ```ignore
        #[doc = concat!(
            "let method_router = ApiMethodRouter::new().",
            stringify!($name),
            "(test_handler);"
        )]
        #[doc = concat!(
            "# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
            stringify!($method),
            ").build();"
        )]
        /// ```
        pub fn $name<HandlerParams, ApiParams, H>(mut self, handler: H) -> Self
        where
            HandlerParams: 'static,
            ApiParams: 'static,
            H: RequestHandler<HandlerParams>
                + AsApiOperation<ApiParams>
                + Send
                + Sync
                + 'static,
        {
            self.inner.$name = Some(InnerApiHandler::new(handler));
            self
        }
    };
}
impl Default for ApiMethodRouter {
fn default() -> Self {
Self::new()
}
}
impl ApiMethodRouter {
    /// Create a new [`ApiMethodRouter`] with no handlers registered.
    ///
    /// Until a handler is added, every request is answered by the default
    /// fallback, which fails with a `405 Method Not Allowed` error.
    pub fn new() -> Self {
        let inner = InnerMethodRouter::new();
        Self { inner }
    }

    define_method!(get => GET);
    define_method!(head => HEAD);
    define_method!(delete => DELETE);
    define_method!(options => OPTIONS);
    define_method!(patch => PATCH);
    define_method!(post => POST);
    define_method!(put => PUT);
    define_method!(trace => TRACE);

    /// Set a handler for the
    /// [`CONNECT`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/CONNECT)
    /// HTTP method.
    ///
    /// Unlike the other method setters, the handler does not need to implement
    /// [`AsApiOperation`]: `CONNECT` is not part of the path item produced by
    /// [`AsApiRoute::as_api_route`], so no API description is collected for it.
    pub fn connect<HandlerParams, H>(mut self, handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let connect_handler = InnerHandler::new(handler);
        self.inner.connect = Some(connect_handler);
        self
    }

    /// Set the handler used for requests whose HTTP method has no registered
    /// handler, replacing the default `405 Method Not Allowed` response.
    ///
    /// A `HEAD` request is still served by the `GET` handler when no `HEAD`
    /// handler is set. The fallback handler is not described in the OpenAPI
    /// output.
    pub fn fallback<HandlerParams, H>(mut self, handler: H) -> Self
    where
        HandlerParams: 'static,
        H: RequestHandler<HandlerParams> + Send + Sync + 'static,
    {
        let fallback_handler = InnerHandler::new(handler);
        self.inner.fallback = fallback_handler;
        self
    }
}
impl RequestHandler for ApiMethodRouter {
    /// Dispatch `request` to the handler registered for its HTTP method,
    /// or to the fallback handler if there is none.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self { inner } = self;
        inner.handle(request)
    }
}
impl AsApiRoute for ApiMethodRouter {
fn as_api_route(
&self,
route_context: &RouteContext<'_>,
schema_generator: &mut SchemaGenerator,
) -> aide::openapi::PathItem {
macro_rules! add_method {
($path_item:ident, $method_func:ident, $method:ident) => {
if let Some(handler) = &self.inner.$method_func {
let mut route_context = route_context.clone();
route_context.method = Some(cot::Method::$method);
$path_item.$method_func =
handler.as_api_operation(&route_context, schema_generator);
}
};
}
let mut path_item = aide::openapi::PathItem::default();
add_method!(path_item, get, GET);
add_method!(path_item, head, HEAD);
add_method!(path_item, delete, DELETE);
add_method!(path_item, options, OPTIONS);
add_method!(path_item, patch, PATCH);
add_method!(path_item, post, POST);
add_method!(path_item, put, PUT);
add_method!(path_item, trace, TRACE);
path_item
}
}
/// Type-erased handler stored for each OpenAPI-aware HTTP method of an
/// [`ApiMethodRouter`].
///
/// Boxing the handler behind `dyn BoxApiRequestHandler` lets handlers of
/// different concrete types be stored in the same field type. The wrapper can
/// still serve requests ([`RequestHandler`]) and describe itself
/// ([`AsApiOperation`]).
struct InnerApiHandler(Box<dyn BoxApiRequestHandler + Send + Sync>);
impl InnerApiHandler {
fn new<HandlerParams, ApiParams, H>(handler: H) -> Self
where
HandlerParams: 'static,
ApiParams: 'static,
H: RequestHandler<HandlerParams> + AsApiOperation<ApiParams> + Send + Sync + 'static,
{
Self(Box::new(into_box_api_request_handler(handler)))
}
}
impl Debug for InnerApiHandler {
    /// Formats as `InnerApiHandler(..)`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The wrapped handler is an opaque trait object, so only the type name
        // is printed; `..` marks the elided contents.
        f.write_str("InnerApiHandler(..)")
    }
}
impl RequestHandler for InnerApiHandler {
    /// Forward `request` to the wrapped, type-erased handler.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self(erased) = self;
        erased.handle(request)
    }
}
impl AsApiOperation for InnerApiHandler {
    /// Ask the wrapped handler for its OpenAPI operation, if it provides one.
    fn as_api_operation(
        &self,
        route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) -> Option<Operation> {
        let Self(erased) = self;
        erased.as_api_operation(route_context, schema_generator)
    }
}
// Generates a free function (e.g. `api_get`) that creates an `ApiMethodRouter`
// with one OpenAPI-aware handler registered for a single HTTP method.
//
// `$func_name` is the generated function name, `$name` is the `ApiMethodRouter`
// setter it forwards to, and `$method` is the matching `cot::Method` constant.
macro_rules! define_method_router {
    ($func_name:ident, $name:ident => $method:ident) => {
        #[doc = concat!(
            "Create a new [`ApiMethodRouter`] with a [`",
            stringify!($method),
            "`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/",
            stringify!($method),
            ") handler."
        )]
        ///
        #[doc = concat!(
            "This is a shorthand to call [`ApiMethodRouter::new`] and then [`ApiMethodRouter::",
            stringify!($name),
            "`]."
        )]
        /// The handler must implement both [`RequestHandler`] and
        /// [`AsApiOperation`].
        ///
        /// # Examples
        ///
        /// ```ignore
        #[doc = concat!("use cot::router::method::openapi::", stringify!($func_name), ";")]
        ///
        #[doc = concat!("let method_router = ", stringify!($func_name), "(test_handler);")]
        #[doc = concat!(
            "# let request = cot::test::TestRequestBuilder::with_method(\"/\", cot::Method::",
            stringify!($method),
            ").build();"
        )]
        /// ```
        pub fn $func_name<HandlerParams, ApiParams, H>(handler: H) -> ApiMethodRouter
        where
            HandlerParams: 'static,
            ApiParams: 'static,
            H: RequestHandler<HandlerParams>
                + AsApiOperation<ApiParams>
                + Send
                + Sync
                + 'static,
        {
            ApiMethodRouter::new().$name(handler)
        }
    };
}
// One shorthand constructor per OpenAPI-describable HTTP method. `api_connect`
// is written by hand instead, because `CONNECT` handlers are not required to
// implement `AsApiOperation`.
define_method_router!(api_get, get => GET);
define_method_router!(api_head, head => HEAD);
define_method_router!(api_delete, delete => DELETE);
define_method_router!(api_options, options => OPTIONS);
define_method_router!(api_patch, patch => PATCH);
define_method_router!(api_post, post => POST);
define_method_router!(api_put, put => PUT);
define_method_router!(api_trace, trace => TRACE);
/// Create a new [`ApiMethodRouter`] with a
/// [`CONNECT`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Methods/CONNECT)
/// handler.
///
/// This is a shorthand to call [`ApiMethodRouter::new`] and then
/// [`ApiMethodRouter::connect`]. The handler does not need to implement
/// [`AsApiOperation`], and it is not included in the generated OpenAPI path
/// item.
pub fn api_connect<HandlerParams, H>(handler: H) -> ApiMethodRouter
where
    HandlerParams: 'static,
    H: RequestHandler<HandlerParams> + Send + Sync + 'static,
{
    let router = ApiMethodRouter::new();
    router.connect(handler)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::MethodNotAllowed;
    use crate::html::Html;
    use crate::json::Json;
    use crate::request::extractors::Path;
    use crate::response::{IntoResponse, Response};
    use crate::test::TestRequestBuilder;
    use crate::{Method, StatusCode};

    /// Echoes the request's HTTP method name back as an HTML response body.
    async fn test_handler(method: Method) -> cot::Result<Response> {
        Html::new(method.as_str()).into_response()
    }

    // The Debug output hides the boxed handler and shows only the type name.
    #[test]
    fn inner_api_handler_debug() {
        let handler = InnerApiHandler::new(test_handler);
        let debug_str = format!("{handler:?}");
        assert_eq!(debug_str, "InnerApiHandler(..)");
    }

    // A router with no handlers rejects requests with 405 Method Not Allowed.
    #[cot::test]
    async fn api_method_router_fallback() {
        let router = ApiMethodRouter::new();
        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();
        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
    }

    // `Default` behaves like `new`: no handlers, so 405.
    #[cot::test]
    async fn api_method_router_default_fallback() {
        let router = ApiMethodRouter::default();
        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();
        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
    }

    // A custom fallback serves requests whose method has no handler.
    #[cot::test]
    async fn api_method_router_custom_fallback() {
        let router = ApiMethodRouter::new().fallback(test_handler);
        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        assert_eq!(response.into_body().into_bytes().await.unwrap(), "GET");
    }

    // A GET-only router serves GET and rejects every other method. HEAD is
    // left out here because it is served by the GET handler (see
    // `api_method_router_default_head`).
    #[cot::test]
    async fn api_method_router_router_get() {
        let router = api_get(test_handler);
        let request = TestRequestBuilder::get("/").build();
        let response = router.handle(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let methods = [
            Method::DELETE,
            Method::OPTIONS,
            Method::PATCH,
            Method::POST,
            Method::PUT,
            Method::TRACE,
            Method::CONNECT,
        ];
        for method in methods {
            let request = TestRequestBuilder::with_method("/", method).build();
            let response = router.handle(request).await.unwrap_err();
            let inner = response.inner();
            assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
            assert!(inner.is::<MethodNotAllowed>());
        }
    }

    // Generates a test checking that the `$constructor_name` shorthand serves
    // `$method_name` requests successfully.
    macro_rules! test_api_method_router {
        ($test_name:ident, $constructor_name:ident, $method_name:ident) => {
            #[cot::test]
            async fn $test_name() {
                let router = $constructor_name(test_handler);
                let request = TestRequestBuilder::with_method("/", Method::$method_name).build();
                let response = router.handle(request).await.unwrap();
                assert_eq!(response.status(), StatusCode::OK);
            }
        };
    }

    test_api_method_router!(method_api_router_head, api_head, HEAD);
    test_api_method_router!(method_api_router_delete, api_delete, DELETE);
    test_api_method_router!(method_api_router_options, api_options, OPTIONS);
    test_api_method_router!(method_api_router_patch, api_patch, PATCH);
    test_api_method_router!(method_api_router_post, api_post, POST);
    test_api_method_router!(method_api_router_put, api_put, PUT);
    test_api_method_router!(method_api_router_trace, api_trace, TRACE);
    test_api_method_router!(method_api_router_connect, api_connect, CONNECT);

    // HEAD is rejected on an empty router but is served by the GET handler
    // when only GET is registered.
    #[cot::test]
    async fn api_method_router_default_head() {
        let router = ApiMethodRouter::new();
        let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
        let response = router.handle(request).await.unwrap_err();
        let inner = response.inner();
        assert_eq!(inner.status_code(), StatusCode::METHOD_NOT_ALLOWED);
        assert!(inner.is::<MethodNotAllowed>());
        let router = api_get(test_handler);
        let request = TestRequestBuilder::with_method("/", Method::HEAD).build();
        let response = router.handle(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    // Every method registered on one router is dispatched, and the echoed
    // body matches the request method.
    #[cot::test]
    async fn api_method_router_multiple() {
        let router = ApiMethodRouter::new()
            .get(test_handler)
            .head(test_handler)
            .delete(test_handler)
            .options(test_handler)
            .patch(test_handler)
            .post(test_handler)
            .put(test_handler)
            .trace(test_handler)
            .connect(test_handler);
        for (method, expected_string) in [
            (Method::GET, "GET"),
            (Method::HEAD, "HEAD"),
            (Method::DELETE, "DELETE"),
            (Method::OPTIONS, "OPTIONS"),
            (Method::PATCH, "PATCH"),
            (Method::POST, "POST"),
            (Method::PUT, "PUT"),
            (Method::TRACE, "TRACE"),
            (Method::CONNECT, "CONNECT"),
        ] {
            let request = TestRequestBuilder::with_method("/", method).build();
            let response = router.handle(request).await.unwrap();
            assert_eq!(response.status(), StatusCode::OK);
            assert_eq!(
                response.into_body().into_bytes().await.unwrap(),
                expected_string
            );
        }
    }

    /// Handler with one path parameter and a JSON body. It is used to check that
    /// extractors show up in the generated OpenAPI operation.
    async fn test_handler_with_params(
        Path(_): Path<i32>,
        Json(_): Json<String>,
    ) -> cot::Result<Response> {
        Html::new("").into_response()
    }

    // An empty router produces a path item with no operations.
    #[test]
    fn openapi_empty() {
        let router = ApiMethodRouter::new();
        let route_context = RouteContext::new();
        let endpoint = router.as_api_route(&route_context, &mut SchemaGenerator::default());
        assert!(endpoint.get.is_none());
        assert!(endpoint.head.is_none());
        assert!(endpoint.delete.is_none());
        assert!(endpoint.options.is_none());
        assert!(endpoint.patch.is_none());
        assert!(endpoint.post.is_none());
        assert!(endpoint.put.is_none());
        assert!(endpoint.trace.is_none());
    }

    // The path parameter and JSON body of a POST handler appear in its operation.
    #[test]
    fn openapi_post() {
        let router = api_post(test_handler_with_params);
        let mut route_context = RouteContext::new();
        route_context.param_names = &["123"];
        let endpoint = router.as_api_route(&route_context, &mut SchemaGenerator::default());
        assert!(endpoint.post.is_some());
        let operation = endpoint.post.unwrap();
        assert_eq!(operation.parameters.len(), 1);
        assert!(operation.request_body.is_some());
    }

    // Only methods with registered handlers get operations, and each one is
    // described independently.
    #[test]
    fn openapi_multiple() {
        let router = api_post(test_handler_with_params).put(test_handler_with_params);
        let mut route_context = RouteContext::new();
        route_context.param_names = &["123"];
        let endpoint = router.as_api_route(&route_context, &mut SchemaGenerator::default());
        assert!(endpoint.post.is_some());
        let post_operation = endpoint.post.unwrap();
        assert_eq!(post_operation.parameters.len(), 1);
        assert!(post_operation.request_body.is_some());
        assert!(endpoint.put.is_some());
        let put = endpoint.put.unwrap();
        assert_eq!(put.parameters.len(), 1);
        assert!(put.request_body.is_some());
        assert!(endpoint.get.is_none());
        assert!(endpoint.head.is_none());
        assert!(endpoint.delete.is_none());
        assert!(endpoint.options.is_none());
        assert!(endpoint.patch.is_none());
        assert!(endpoint.trace.is_none());
    }
}