#[cfg(feature = "swagger-ui")]
pub mod swagger_ui;
use std::marker::PhantomData;
use std::pin::Pin;
use aide::openapi::{
MediaType, Operation, Parameter, ParameterData, ParameterSchemaOrContent, PathItem, PathStyle,
QueryStyle, ReferenceOr, RequestBody, StatusCode,
};
use http::request::Parts;
use indexmap::IndexMap;
use schemars::schema::{InstanceType, Schema, SchemaObject, SingleOrVec};
use schemars::{JsonSchema, SchemaGenerator};
use crate::auth::Auth;
use crate::form::Form;
use crate::handler::BoxRequestHandler;
use crate::json::Json;
use crate::request::Request;
use crate::request::extractors::{FromRequest, FromRequestParts, Path, RequestForm, UrlQuery};
use crate::response::{Response, WithExtension};
use crate::router::Urls;
use crate::session::Session;
use crate::{Method, RequestHandler};
/// Information about the route for which an OpenAPI operation is generated.
///
/// The struct is `#[non_exhaustive]`, so code outside this crate creates it
/// with [`RouteContext::new`] (or [`Default`]) and then assigns the public
/// fields it needs.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct RouteContext<'a> {
/// HTTP method of the route, if known. `RequestForm` uses it to choose
/// between query parameters (`GET`/`HEAD`) and a request body.
pub method: Option<Method>,
/// Names of the path parameters declared in the route URL. Tuple `Path`
/// types are matched against these positionally, so the order is
/// significant (presumably URL order — confirm against the router).
pub param_names: &'a [&'a str],
}
impl RouteContext<'_> {
    /// Creates an empty context: no HTTP method and no path parameter names.
    #[must_use]
    pub fn new() -> Self {
        let param_names: &[&str] = &[];
        Self {
            method: None,
            param_names,
        }
    }
}
impl Default for RouteContext<'_> {
fn default() -> Self {
Self::new()
}
}
/// A type that can describe a whole route as an OpenAPI path item.
pub trait AsApiRoute {
/// Builds the OpenAPI [`PathItem`] for this route, using `schema_generator`
/// to produce any JSON schemas needed by its operations.
fn as_api_route(
&self,
route_context: &RouteContext<'_>,
schema_generator: &mut SchemaGenerator,
) -> PathItem;
}
/// A type that can be described as a single OpenAPI operation.
///
/// The type parameter `T` (default `()`) only disambiguates blanket
/// implementations. The handler-function impls use the tuple of their
/// parameter types here so that impls for different arities do not overlap.
pub trait AsApiOperation<T = ()> {
/// Builds the OpenAPI operation for this handler, using `schema_generator`
/// to produce any JSON schemas it needs.
///
/// Returns `None` when no operation should be documented (as [`NoApi`]
/// does).
fn as_api_operation(
&self,
route_context: &RouteContext<'_>,
schema_generator: &mut SchemaGenerator,
) -> Option<Operation>;
}
/// A type-erasable request handler that can also describe itself as an
/// OpenAPI operation; values are produced by [`into_box_api_request_handler`].
pub(crate) trait BoxApiRequestHandler: BoxRequestHandler + AsApiOperation {}
/// Wraps `handler` so it can be used as a [`BoxApiRequestHandler`].
///
/// The generic parameter types `HandlerParams` and `ApiParams` are erased.
/// They only select which [`RequestHandler`] and [`AsApiOperation`] impls
/// of `H` are used.
pub(crate) fn into_box_api_request_handler<HandlerParams, ApiParams, H>(
    handler: H,
) -> impl BoxApiRequestHandler
where
    H: RequestHandler<HandlerParams> + AsApiOperation<ApiParams> + Send + Sync,
{
    // Adapter carrying the erased parameter types only at the type level.
    // `fn() -> _` is always `Send + Sync`, so the phantom never restricts the
    // auto traits of the wrapper.
    struct ApiHandler<HandlerParams, ApiParams, H> {
        handler: H,
        _params: PhantomData<fn() -> (HandlerParams, ApiParams)>,
    }

    impl<HandlerParams, ApiParams, H> BoxRequestHandler for ApiHandler<HandlerParams, ApiParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiOperation<ApiParams> + Send + Sync,
    {
        fn handle(
            &self,
            request: Request,
        ) -> Pin<Box<dyn Future<Output = cot::Result<Response>> + Send + '_>> {
            let response_future = self.handler.handle(request);
            Box::pin(response_future)
        }
    }

    impl<HandlerParams, ApiParams, H> AsApiOperation for ApiHandler<HandlerParams, ApiParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiOperation<ApiParams> + Send + Sync,
    {
        fn as_api_operation(
            &self,
            route_context: &RouteContext<'_>,
            schema_generator: &mut SchemaGenerator,
        ) -> Option<Operation> {
            <H as AsApiOperation<ApiParams>>::as_api_operation(
                &self.handler,
                route_context,
                schema_generator,
            )
        }
    }

    impl<HandlerParams, ApiParams, H> BoxApiRequestHandler
        for ApiHandler<HandlerParams, ApiParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiOperation<ApiParams> + Send + Sync,
    {
    }

    ApiHandler {
        handler,
        _params: PhantomData,
    }
}
/// A type-erasable request handler that describes a whole route (path item);
/// values are produced by [`into_box_api_endpoint_request_handler`].
pub(crate) trait BoxApiEndpointRequestHandler: BoxRequestHandler + AsApiRoute {
/// Returns this handler viewed as a plain [`BoxRequestHandler`] trait object.
fn as_box_request_handler(&self) -> &(dyn BoxRequestHandler + Send + Sync);
}
/// Wraps `handler` so it can be used as a [`BoxApiEndpointRequestHandler`].
///
/// The `HandlerParams` type is erased; it only selects which
/// [`RequestHandler`] impl of `H` is used.
pub(crate) fn into_box_api_endpoint_request_handler<HandlerParams, H>(
    handler: H,
) -> impl BoxApiEndpointRequestHandler
where
    H: RequestHandler<HandlerParams> + AsApiRoute + Send + Sync,
{
    // Adapter carrying `HandlerParams` only at the type level. `fn() -> _` is
    // always `Send + Sync`, so the phantom never restricts the auto traits.
    struct EndpointHandler<HandlerParams, H> {
        handler: H,
        _params: PhantomData<fn() -> HandlerParams>,
    }

    impl<HandlerParams, H> BoxRequestHandler for EndpointHandler<HandlerParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiRoute + Send + Sync,
    {
        fn handle(
            &self,
            request: Request,
        ) -> Pin<Box<dyn Future<Output = cot::Result<Response>> + Send + '_>> {
            let response_future = self.handler.handle(request);
            Box::pin(response_future)
        }
    }

    impl<HandlerParams, H> AsApiRoute for EndpointHandler<HandlerParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiRoute + Send + Sync,
    {
        fn as_api_route(
            &self,
            route_context: &RouteContext<'_>,
            schema_generator: &mut SchemaGenerator,
        ) -> PathItem {
            <H as AsApiRoute>::as_api_route(&self.handler, route_context, schema_generator)
        }
    }

    impl<HandlerParams, H> BoxApiEndpointRequestHandler for EndpointHandler<HandlerParams, H>
    where
        H: RequestHandler<HandlerParams> + AsApiRoute + Send + Sync,
    {
        fn as_box_request_handler(&self) -> &(dyn BoxRequestHandler + Send + Sync) {
            self
        }
    }

    EndpointHandler {
        handler,
        _params: PhantomData,
    }
}
/// Wrapper that hides a handler, extractor or parameter from OpenAPI
/// generation.
///
/// `NoApi` forwards request handling and extraction to the wrapped value
/// unchanged. As a handler it yields no OpenAPI operation. As a parameter it
/// leaves the operation untouched.
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NoApi<T>(pub T);
impl<HandlerParams, H> RequestHandler<HandlerParams> for NoApi<H>
where
    H: RequestHandler<HandlerParams>,
{
    /// Handles the request with the wrapped handler; `NoApi` only affects
    /// OpenAPI documentation, never request processing.
    fn handle(&self, request: Request) -> impl Future<Output = cot::Result<Response>> + Send {
        let Self(inner) = self;
        inner.handle(request)
    }
}
impl<T: FromRequest> FromRequest for NoApi<T> {
    /// Extracts the wrapped value and wraps it in `NoApi`, propagating any
    /// extraction error unchanged.
    async fn from_request(request: Request) -> cot::Result<Self> {
        let inner = T::from_request(request).await?;
        Ok(Self(inner))
    }
}
impl<T: FromRequestParts> FromRequestParts for NoApi<T> {
    /// Extracts the wrapped value from the request parts and wraps it in
    /// `NoApi`, propagating any extraction error unchanged.
    async fn from_request_parts(parts: &mut Parts) -> cot::Result<Self> {
        let inner = T::from_request_parts(parts).await?;
        Ok(Self(inner))
    }
}
/// As a handler parameter, `NoApi` uses the default no-op implementation, so
/// the wrapped extractor adds nothing to the operation.
impl<T> ApiOperationPart for NoApi<T> {}
/// As a handler, `NoApi` always returns `None`: the wrapped handler
/// contributes no OpenAPI operation.
impl<T> AsApiOperation for NoApi<T> {
fn as_api_operation(
&self,
_route_context: &RouteContext<'_>,
_schema_generator: &mut SchemaGenerator,
) -> Option<Operation> {
None
}
}
// Implements `AsApiOperation` for every handler function whose parameters all
// implement `ApiOperationPart` and whose future's output implements
// `ApiOperationResponse`.
//
// The tuple of parameter types is used as the `AsApiOperation` type parameter
// so that the impls generated for different arities do not overlap.
macro_rules! impl_as_openapi_operation {
    ($($ty:ident),*) => {
        impl<T, $($ty,)* R, Resp> AsApiOperation<($($ty,)*)> for T
        where
            T: Fn($($ty,)*) -> R + Clone + Send + Sync + 'static,
            $($ty: ApiOperationPart,)*
            R: Future<Output = Resp> + Send,
            Resp: ApiOperationResponse,
        {
            #[allow(non_snake_case)]
            fn as_api_operation(
                &self,
                route_context: &RouteContext<'_>,
                schema_generator: &mut SchemaGenerator,
            ) -> Option<Operation> {
                let mut operation = Operation::default();

                // Let every parameter describe itself (path/query params,
                // request body, ...), in declaration order.
                $(
                    $ty::modify_api_operation(&mut operation, route_context, schema_generator);
                )*

                let responses =
                    Resp::api_operation_responses(&mut operation, route_context, schema_generator);
                let operation_responses = operation.responses.get_or_insert_default();
                for (response_code, response) in responses {
                    let response = ReferenceOr::Item(response);
                    match response_code {
                        Some(code) => {
                            operation_responses.responses.insert(code, response);
                        }
                        // A response without a status code documents the
                        // catch-all `default` response.
                        None => operation_responses.default = Some(response),
                    }
                }

                Some(operation)
            }
        }
    };
}
// `handle_all_parameters!` (defined elsewhere) presumably expands the macro
// once per supported handler arity.
handle_all_parameters!(impl_as_openapi_operation);
/// Implemented by handler parameters (extractors) that can describe
/// themselves on an OpenAPI operation.
pub trait ApiOperationPart {
/// Adds whatever this parameter contributes (path/query parameters, request
/// body, ...) to `operation`.
///
/// The default implementation leaves the operation unchanged.
#[expect(unused)]
fn modify_api_operation(
operation: &mut Operation,
route_context: &RouteContext<'_>,
schema_generator: &mut SchemaGenerator,
) {
}
}
/// Implemented by handler return types that can describe the responses they
/// produce.
pub trait ApiOperationResponse {
/// Returns the documented responses as `(status code, response)` pairs.
///
/// A `None` status code stands for the operation's `default` response. The
/// default implementation documents no responses.
#[expect(unused)]
fn api_operation_responses(
operation: &mut Operation,
route_context: &RouteContext<'_>,
schema_generator: &mut SchemaGenerator,
) -> Vec<(Option<StatusCode>, aide::openapi::Response)> {
Vec::new()
}
}
// These extractors do not correspond to anything in the OpenAPI operation, so
// they rely on the default no-op `modify_api_operation`.
impl ApiOperationPart for Request {}
impl ApiOperationPart for Urls {}
impl ApiOperationPart for Method {}
impl ApiOperationPart for Session {}
impl ApiOperationPart for Auth {}
#[cfg(feature = "db")]
impl ApiOperationPart for crate::request::extractors::RequestDb {}
impl<D: JsonSchema> ApiOperationPart for Json<D> {
    /// Declares a required request body with the JSON content type whose
    /// schema is generated from `D`.
    fn modify_api_operation(
        operation: &mut Operation,
        _route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) {
        let payload_schema = aide::openapi::SchemaObject {
            json_schema: D::json_schema(schema_generator),
            external_docs: None,
            example: None,
        };
        let json_media = MediaType {
            schema: Some(payload_schema),
            ..Default::default()
        };
        let body = RequestBody {
            content: IndexMap::from([(
                crate::headers::JSON_CONTENT_TYPE.to_string(),
                json_media,
            )]),
            required: true,
            ..Default::default()
        };
        operation.request_body = Some(ReferenceOr::Item(body));
    }
}
impl<D: JsonSchema> ApiOperationPart for Path<D> {
    /// Documents the path parameters extracted by `Path<D>`.
    ///
    /// Three schema shapes of `D` are handled:
    /// - an array schema with per-position items (e.g. a tuple): items are
    ///   matched positionally with the route's parameter names;
    /// - an object schema (e.g. a struct): property names must match the
    ///   route's parameter names exactly, in any order;
    /// - any other schema with an instance type (a single value): the route
    ///   must have exactly one parameter.
    ///
    /// Other schemas add nothing to the operation.
    ///
    /// # Panics
    ///
    /// Panics if the route's parameter names do not match the shape of `D`
    /// as described above.
    #[track_caller]
    fn modify_api_operation(
        operation: &mut Operation,
        route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) {
        let schema = D::json_schema(schema_generator).into_object();
        if let Some(array) = schema.array {
            // A single `items` schema describes no fixed positions, so it
            // cannot be mapped onto named route parameters and is skipped.
            if let Some(SingleOrVec::Vec(item_list)) = array.items {
                assert_eq!(
                    route_context.param_names.len(),
                    item_list.len(),
                    "the number of path parameters in the route URL must match \
                     the number of params in the Path type (found path params: {:?})",
                    route_context.param_names,
                );
                for (&param_name, item) in route_context.param_names.iter().zip(item_list) {
                    add_path_param(operation, item.into_object(), param_name.to_owned());
                }
            }
        } else if let Some(object) = schema.object {
            let mut route_context_sorted = route_context.param_names.to_vec();
            route_context_sorted.sort_unstable();
            let mut object_props_sorted = object.properties.keys().collect::<Vec<_>>();
            object_props_sorted.sort_unstable();
            assert_eq!(
                route_context_sorted, object_props_sorted,
                "Path parameters in the route info must exactly match parameters \
                 in the Path type. Make sure that the type you pass to Path contains \
                 all the parameters for the route, and that the names match exactly."
            );
            for (key, item) in object.properties {
                add_path_param(operation, item.into_object(), key);
            }
        } else if schema.instance_type.is_some() {
            assert_eq!(
                route_context.param_names.len(),
                1,
                "the number of path parameters in the route URL must equal \
                 to 1 if a single parameter was passed to the Path type (found path params: {:?})",
                route_context.param_names,
            );
            add_path_param(operation, schema, route_context.param_names[0].to_owned());
        }
    }
}
impl<D: JsonSchema> ApiOperationPart for UrlQuery<D> {
    /// Documents each top-level property of `D`'s object schema as a query
    /// parameter; a schema without object validation adds nothing.
    fn modify_api_operation(
        operation: &mut Operation,
        _route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) {
        let Some(object) = D::json_schema(schema_generator).into_object().object else {
            return;
        };
        for (field_name, field_schema) in object.properties {
            add_query_param(operation, field_schema.into_object(), field_name);
        }
    }
}
impl<F: Form + JsonSchema> ApiOperationPart for RequestForm<F> {
    /// Describes the form `F` on the operation.
    ///
    /// For `GET` and `HEAD` routes, the form fields are documented as query
    /// parameters. For any other or unknown method, they become a required
    /// request body with the form content type.
    fn modify_api_operation(
        operation: &mut Operation,
        route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) {
        let uses_query_string = route_context
            .method
            .as_ref()
            .is_some_and(|method| *method == Method::GET || *method == Method::HEAD);

        if !uses_query_string {
            let form_media = MediaType {
                schema: Some(aide::openapi::SchemaObject {
                    json_schema: F::json_schema(schema_generator),
                    external_docs: None,
                    example: None,
                }),
                ..Default::default()
            };
            operation.request_body = Some(ReferenceOr::Item(RequestBody {
                content: IndexMap::from([(
                    crate::headers::FORM_CONTENT_TYPE.to_string(),
                    form_media,
                )]),
                required: true,
                ..Default::default()
            }));
            return;
        }

        let Some(object) = F::json_schema(schema_generator).into_object().object else {
            return;
        };
        for (field_name, field_schema) in object.properties {
            add_query_param(operation, field_schema.into_object(), field_name);
        }
    }
}
/// Appends a path parameter called `param_name`, described by `schema`, to
/// `operation`.
///
/// Any `null` entry is stripped from the schema's type list. The parameter is
/// nevertheless always marked as required. The OpenAPI specification mandates
/// `required: true` for every parameter located in the path, because a path
/// segment cannot be absent from a matched URL.
fn add_path_param(operation: &mut Operation, mut schema: SchemaObject, param_name: String) {
    // Called only for its side effect of removing `null` from the type list;
    // nullability must not make a path parameter optional.
    let _ = extract_is_required(&mut schema);
    operation
        .parameters
        .push(ReferenceOr::Item(Parameter::Path {
            parameter_data: param_with_name(param_name, schema, true),
            style: PathStyle::default(),
        }));
}
/// Appends a query-string parameter called `param_name`, described by
/// `schema`, to `operation`.
///
/// The parameter is optional when the schema's multi-type list contains
/// `null`. That `null` entry is removed from the schema before it is stored.
fn add_query_param(operation: &mut Operation, mut schema: SchemaObject, param_name: String) {
    let is_required = extract_is_required(&mut schema);
    let parameter_data = param_with_name(param_name, schema, is_required);
    let query_param = Parameter::Query {
        parameter_data,
        allow_reserved: false,
        style: QueryStyle::default(),
        allow_empty_value: None,
    };
    operation.parameters.push(ReferenceOr::Item(query_param));
}
/// Decides whether a parameter described by `object_item` is required.
///
/// If the schema's `instance_type` is a list, every `null` entry is removed
/// from it. The function returns `false` when at least one was present.
/// A single type or a missing type leaves the schema untouched and counts as
/// required.
fn extract_is_required(object_item: &mut SchemaObject) -> bool {
    let Some(SingleOrVec::Vec(types)) = object_item.instance_type.as_mut() else {
        return true;
    };
    let type_count_before = types.len();
    types.retain(|ty| *ty != InstanceType::Null);
    // Nothing removed => the type list never allowed `null`.
    types.len() == type_count_before
}
/// Builds [`ParameterData`] named `param_name` whose format is the given
/// schema.
///
/// Only the name, the `required` flag and the schema are populated. All other
/// metadata (description, examples, extensions, ...) is left empty.
fn param_with_name(
    param_name: String,
    schema_object: SchemaObject,
    required: bool,
) -> ParameterData {
    let format = ParameterSchemaOrContent::Schema(aide::openapi::SchemaObject {
        json_schema: Schema::Object(schema_object),
        external_docs: None,
        example: None,
    });
    ParameterData {
        name: param_name,
        required,
        format,
        description: None,
        deprecated: None,
        example: None,
        examples: IndexMap::default(),
        explode: None,
        extensions: IndexMap::default(),
    }
}
impl<S: JsonSchema> ApiOperationResponse for Json<S> {
    /// Documents a single `200 OK` response whose JSON body follows the
    /// schema generated from `S`.
    fn api_operation_responses(
        _operation: &mut Operation,
        _route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) -> Vec<(Option<StatusCode>, aide::openapi::Response)> {
        let body = MediaType {
            schema: Some(aide::openapi::SchemaObject {
                json_schema: S::json_schema(schema_generator),
                external_docs: None,
                example: None,
            }),
            ..Default::default()
        };
        let ok_response = aide::openapi::Response {
            description: "OK".to_string(),
            content: IndexMap::from([(crate::headers::JSON_CONTENT_TYPE.to_string(), body)]),
            ..Default::default()
        };
        let status = StatusCode::Code(http::StatusCode::OK.as_u16());
        vec![(Some(status), ok_response)]
    }
}
impl<T: ApiOperationResponse, D> ApiOperationResponse for WithExtension<T, D> {
    /// Forwards to the wrapped response type: attaching an extension does not
    /// change the documented responses.
    fn api_operation_responses(
        operation: &mut Operation,
        route_context: &RouteContext<'_>,
        schema_generator: &mut SchemaGenerator,
    ) -> Vec<(Option<StatusCode>, aide::openapi::Response)> {
        <T as ApiOperationResponse>::api_operation_responses(
            operation,
            route_context,
            schema_generator,
        )
    }
}
impl ApiOperationResponse for crate::Result<Response> {
    /// A raw [`Response`] may carry any status code and body. Only one entry
    /// without a status code (the `default` response) is returned, with a
    /// placeholder description.
    fn api_operation_responses(
        _operation: &mut Operation,
        _route_context: &RouteContext<'_>,
        _schema_generator: &mut SchemaGenerator,
    ) -> Vec<(Option<StatusCode>, aide::openapi::Response)> {
        let placeholder = aide::openapi::Response {
            description: String::from("*<unspecified>*"),
            ..Default::default()
        };
        vec![(None, placeholder)]
    }
}
#[cfg(test)]
mod tests {
    use aide::openapi::{Operation, Parameter};
    use schemars::SchemaGenerator;
    use schemars::schema::Schema;
    use serde::{Deserialize, Serialize};

    use super::*;
    use crate::html::Html;
    use crate::json::Json;
    use crate::request::extractors::{Path, UrlQuery};

    /// Payload used for JSON bodies, query strings and JSON responses.
    #[derive(Deserialize, Serialize, schemars::JsonSchema)]
    struct TestRequest {
        field1: String,
        field2: i32,
        optional_field: Option<bool>,
    }

    /// Form with the same fields as [`TestRequest`].
    #[derive(Form, schemars::JsonSchema)]
    struct TestForm {
        field1: String,
        field2: i32,
        optional_field: Option<bool>,
    }

    /// Struct-shaped path parameters; only its schema is ever inspected.
    #[derive(schemars::JsonSchema)]
    #[expect(dead_code)]
    struct TestPath {
        field1: String,
        field2: i32,
    }

    async fn test_handler() -> Html {
        Html::new("test")
    }

    /// Property names shared by [`TestRequest`] and [`TestForm`].
    const TEST_FIELDS: [&str; 3] = ["field1", "field2", "optional_field"];

    /// Asserts that `schema` is an object schema exposing all [`TEST_FIELDS`].
    fn assert_has_test_fields(schema: &Schema) {
        let Schema::Object(obj) = schema else {
            panic!("Expected object schema");
        };
        let Some(object) = &obj.object else {
            panic!("Expected properties in schema");
        };
        for field in TEST_FIELDS {
            assert!(
                object.properties.contains_key(field),
                "missing property `{field}`"
            );
        }
    }

    /// Asserts that `operation` has exactly the query parameters derived from
    /// the test structs, with only `optional_field` being optional.
    fn assert_test_query_params(operation: &Operation) {
        assert_eq!(operation.parameters.len(), TEST_FIELDS.len());
        for param in &operation.parameters {
            let ReferenceOr::Item(Parameter::Query { parameter_data, .. }) = param else {
                panic!("Expected query parameter");
            };
            match parameter_data.name.as_str() {
                "optional_field" => assert!(!parameter_data.required),
                "field1" | "field2" => assert!(parameter_data.required),
                other => panic!("unexpected query parameter `{other}`"),
            }
        }
    }

    /// Asserts that `param` is a required path parameter called `name`.
    fn assert_required_path_param(param: &ReferenceOr<Parameter>, name: &str) {
        let ReferenceOr::Item(Parameter::Path { parameter_data, .. }) = param else {
            panic!("Expected path parameter");
        };
        assert_eq!(parameter_data.name, name);
        assert!(parameter_data.required);
    }

    /// Returns the schema registered for `content_type` in the request body
    /// of `operation`, panicking if any piece is missing.
    fn request_body_schema<'a>(operation: &'a Operation, content_type: &str) -> &'a Schema {
        let Some(ReferenceOr::Item(request_body)) = &operation.request_body else {
            panic!("Expected request body: {:?}", &operation.request_body);
        };
        let media_type = request_body
            .content
            .get(content_type)
            .unwrap_or_else(|| panic!("Expected `{content_type}` content"));
        &media_type
            .schema
            .as_ref()
            .expect("Expected a schema for the request body")
            .json_schema
    }

    #[test]
    fn route_context() {
        let context = RouteContext::default();
        assert!(context.method.is_none());
        assert!(context.param_names.is_empty());

        let context = RouteContext::new();
        assert!(context.method.is_none());
        assert!(context.param_names.is_empty());
    }

    #[test]
    fn no_api_handler() {
        let handler = NoApi(test_handler);
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        let operation = handler.as_api_operation(&route_context, &mut schema_generator);

        assert!(operation.is_none());
    }

    #[test]
    fn no_api_param() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        NoApi::<()>::modify_api_operation(&mut operation, &route_context, &mut schema_generator);

        assert_eq!(operation, Operation::default());
    }

    #[test]
    fn api_operation_part_for_json() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        Json::<TestRequest>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_has_test_fields(request_body_schema(&operation, "application/json"));
    }

    #[test]
    fn api_operation_part_for_form_get() {
        let mut operation = Operation::default();
        let mut route_context = RouteContext::new();
        route_context.method = Some(Method::GET);
        let mut schema_generator = SchemaGenerator::default();

        RequestForm::<TestForm>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_test_query_params(&operation);
    }

    #[test]
    fn api_operation_part_for_form_post() {
        let mut operation = Operation::default();
        let mut route_context = RouteContext::new();
        route_context.method = Some(Method::POST);
        let mut schema_generator = SchemaGenerator::default();

        RequestForm::<TestForm>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_has_test_fields(request_body_schema(
            &operation,
            "application/x-www-form-urlencoded",
        ));
    }

    #[test]
    fn api_operation_part_for_path_single() {
        let mut operation = Operation::default();
        let mut route_context = RouteContext::new();
        route_context.param_names = &["id"];
        let mut schema_generator = SchemaGenerator::default();

        Path::<i32>::modify_api_operation(&mut operation, &route_context, &mut schema_generator);

        assert_eq!(operation.parameters.len(), 1);
        assert_required_path_param(&operation.parameters[0], "id");
    }

    #[test]
    fn api_operation_part_for_path_tuple() {
        let mut operation = Operation::default();
        let mut route_context = RouteContext::new();
        route_context.param_names = &["id", "id2"];
        let mut schema_generator = SchemaGenerator::default();

        Path::<(i32, i32)>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_eq!(operation.parameters.len(), 2);
        assert_required_path_param(&operation.parameters[0], "id");
        assert_required_path_param(&operation.parameters[1], "id2");
    }

    #[test]
    fn api_operation_part_for_path_object() {
        let mut operation = Operation::default();
        let mut route_context = RouteContext::new();
        route_context.param_names = &["field1", "field2"];
        let mut schema_generator = SchemaGenerator::default();

        Path::<TestPath>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_eq!(operation.parameters.len(), 2);
        assert_required_path_param(&operation.parameters[0], "field1");
        assert_required_path_param(&operation.parameters[1], "field2");
    }

    #[test]
    #[should_panic(
        expected = "Path parameters in the route info must exactly match parameters in the Path"
    )]
    fn api_operation_part_for_path_object_invalid_route_info() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        Path::<TestPath>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );
    }

    #[test]
    fn api_operation_part_for_query() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        UrlQuery::<TestRequest>::modify_api_operation(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_test_query_params(&operation);
    }

    #[test]
    fn api_operation_response_for_json() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        let responses = Json::<TestRequest>::api_operation_responses(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_eq!(responses.len(), 1);
        let (status_code, response) = &responses[0];
        assert_eq!(status_code, &Some(StatusCode::Code(200)));
        assert_eq!(response.description, "OK");
        let content = response
            .content
            .get("application/json")
            .expect("Expected `application/json` content");
        let schema = content.schema.as_ref().expect("Expected response schema");
        assert_has_test_fields(&schema.json_schema);
    }

    #[test]
    fn api_operation_response_for_with_extension() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        let responses = WithExtension::<Json<TestRequest>, ()>::api_operation_responses(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_eq!(responses.len(), 1);
        let (status_code, _) = &responses[0];
        assert_eq!(status_code, &Some(StatusCode::Code(200)));
    }

    #[test]
    fn api_operation_response_for_result() {
        let mut operation = Operation::default();
        let route_context = RouteContext::new();
        let mut schema_generator = SchemaGenerator::default();

        let responses = <crate::Result<Response>>::api_operation_responses(
            &mut operation,
            &route_context,
            &mut schema_generator,
        );

        assert_eq!(responses.len(), 1);
        let (status_code, response) = &responses[0];
        assert_eq!(status_code, &None);
        assert_eq!(response.description, "*<unspecified>*");
        assert!(response.content.is_empty());
    }
}