use std::sync::Arc;
use bytes::Bytes;
use serde_json::{json, Map, Value};
use tork_core::constants::APPLICATION_JSON;
use tork_core::{
bytes_response, BoxFuture, HandlerFn, Method, OpenApiProvider, RequestBodyKind, RequestContext,
Response, Result, Route, StatusCode,
};
/// OpenAPI specification version written to the `openapi` field of the generated document.
const OPENAPI_VERSION: &str = "3.1.0";
/// Path the JSON specification is served from unless [`OpenApi::json`] overrides it.
const DEFAULT_JSON_PATH: &str = "/openapi.json";
/// Shareable predicate over the incoming request; when it returns `false`,
/// [`check_guard`] answers with a "not found" error.
pub(crate) type DocGuard = Arc<dyn Fn(&RequestContext) -> bool + Send + Sync>;
/// Configuration for the generated OpenAPI document and the routes that serve it.
///
/// Built with [`OpenApi::new`] and refined through the chained setter methods.
pub struct OpenApi {
/// Text written to `info.title`; also passed to the docs route.
title: String,
/// Text written to `info.version`.
version: String,
/// Optional text written to `info.description`.
description: Option<String>,
/// Path at which the JSON specification route is registered.
json_path: String,
/// Path of the docs route; no docs route is registered while this is `None`.
docs_path: Option<String>,
/// Optional predicate handed to every documentation route.
guard: Option<DocGuard>,
}
impl Default for OpenApi {
fn default() -> Self {
Self::new()
}
}
impl OpenApi {
pub fn new() -> Self {
Self {
title: "API".to_owned(),
version: "0.1.0".to_owned(),
description: None,
json_path: DEFAULT_JSON_PATH.to_owned(),
docs_path: None,
guard: None,
}
}
pub fn title(mut self, title: impl Into<String>) -> Self {
self.title = title.into();
self
}
pub fn version(mut self, version: impl Into<String>) -> Self {
self.version = version.into();
self
}
pub fn description(mut self, description: impl Into<String>) -> Self {
self.description = Some(description.into());
self
}
pub fn json(mut self, path: impl Into<String>) -> Self {
self.json_path = path.into();
self
}
pub fn docs(mut self, path: impl Into<String>) -> Self {
self.docs_path = Some(path.into());
self
}
pub fn protect<F>(mut self, predicate: F) -> Self
where
F: Fn(&RequestContext) -> bool + Send + Sync + 'static,
{
self.guard = Some(Arc::new(predicate));
self
}
pub fn build_document(&self, routes: &[Route]) -> Value {
build_document(self, routes)
}
}
impl OpenApiProvider for OpenApi {
    /// Returns the routes that expose the documentation for `registered`.
    ///
    /// The first route always serves the JSON specification, serialized once
    /// here. A docs route follows only when a docs path was configured. Both
    /// receive a clone of the configured guard.
    fn documentation_routes(&self, registered: &[Route]) -> Vec<Route> {
        // If serialization fails the route serves an empty body.
        let serialized = serde_json::to_vec(&build_document(self, registered)).unwrap_or_default();
        let spec = spec_route(&self.json_path, Bytes::from(serialized), self.guard.clone());

        let docs = self.docs_path.as_ref().map(|docs_path| {
            crate::docs::docs_route(docs_path, &self.title, &self.json_path, self.guard.clone())
        });

        std::iter::once(spec).chain(docs).collect()
    }
}
pub(crate) fn check_guard(guard: &Option<DocGuard>, ctx: &RequestContext) -> Result<()> {
match guard {
Some(guard) if !guard(ctx) => Err(tork_core::Error::not_found("not found")),
_ => Ok(()),
}
}
/// Builds the `GET` route that serves the pre-serialized specification `body`.
///
/// Each request first passes through [`check_guard`]; accepted requests get a
/// `200 OK` response carrying `body` with the JSON content type.
fn spec_route(path: &str, body: Bytes, guard: Option<DocGuard>) -> Route {
    let serve = move |ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
        // The closure may run many times, so each future gets its own handles.
        let (payload, gate) = (body.clone(), guard.clone());
        Box::pin(async move {
            check_guard(&gate, &ctx)?;
            Ok(bytes_response(StatusCode::OK, APPLICATION_JSON, payload))
        })
    };
    let handler: HandlerFn = Arc::new(serve);
    Route::new(Method::GET, path.to_owned(), handler).summary("OpenAPI specification")
}
fn build_document(api: &OpenApi, routes: &[Route]) -> Value {
let mut generator = schemars::generate::SchemaSettings::openapi3().into_generator();
let mut paths: Map<String, Value> = Map::new();
for route in routes {
let path = route.path().to_owned();
let method = route.method().as_str().to_lowercase();
let meta = route.meta();
let mut operation = Map::new();
if let Some(summary) = &meta.summary {
operation.insert("summary".to_owned(), json!(sanitize_doc_text(summary)));
}
if let Some(description) = &meta.description {
operation.insert(
"description".to_owned(),
json!(sanitize_doc_text(description)),
);
}
if !meta.tags.is_empty() {
let tags: Vec<String> = meta.tags.iter().map(|tag| sanitize_doc_text(tag)).collect();
operation.insert("tags".to_owned(), json!(tags));
}
operation.insert(
"operationId".to_owned(),
json!(operation_id(&method, &path)),
);
let parameters: Vec<Value> = placeholder_names(&path)
.into_iter()
.map(|name| {
json!({
"name": name,
"in": "path",
"required": true,
"schema": { "type": "string" },
})
})
.collect();
if !parameters.is_empty() {
operation.insert("parameters".to_owned(), json!(parameters));
}
if let Some(request_schema) = meta.request_schema {
let schema = request_schema(&mut generator).as_value().clone();
let media_type = match meta.request_kind {
RequestBodyKind::Json => "application/json",
RequestBodyKind::Form => "application/x-www-form-urlencoded",
RequestBodyKind::Multipart => "multipart/form-data",
};
operation.insert(
"requestBody".to_owned(),
json!({
"required": true,
"content": { media_type: { "schema": schema } },
}),
);
}
let status = meta.status_code.as_u16().to_string();
let mut response = Map::new();
let schema = meta
.response_schema
.map(|thunk| thunk(&mut generator).as_value().clone());
if meta.streaming {
response.insert("description".to_owned(), json!("Server-Sent Events stream"));
if let Some(schema) = schema {
response.insert(
"content".to_owned(),
json!({ "text/event-stream": { "schema": schema } }),
);
}
} else {
let reason = meta.status_code.canonical_reason().unwrap_or("Response");
response.insert("description".to_owned(), json!(reason));
if let Some(schema) = schema {
response.insert(
"content".to_owned(),
json!({ "application/json": { "schema": schema } }),
);
}
}
operation.insert(
"responses".to_owned(),
json!({ status: Value::Object(response) }),
);
let entry = paths
.entry(path)
.or_insert_with(|| Value::Object(Map::new()));
if let Some(object) = entry.as_object_mut() {
object.insert(method, Value::Object(operation));
}
}
let mut info = Map::new();
info.insert("title".to_owned(), json!(sanitize_doc_text(&api.title)));
info.insert("version".to_owned(), json!(api.version));
if let Some(description) = &api.description {
info.insert(
"description".to_owned(),
json!(sanitize_doc_text(description)),
);
}
let mut document = json!({
"openapi": OPENAPI_VERSION,
"info": Value::Object(info),
"paths": Value::Object(paths),
});
let definitions = generator.take_definitions(true);
if !definitions.is_empty() {
document["components"] = json!({ "schemas": Value::Object(definitions) });
}
document
}
/// Makes free-form documentation text safe to embed in HTML.
///
/// * `&`, `<`, `>`, `"`, `'` and `` ` `` are replaced by their HTML entities,
///   so the text cannot open a tag, close an attribute or start a template
///   literal.
/// * Newline, carriage return and tab are kept as they are.
/// * Every other control character is replaced by a single space.
/// * All remaining characters are copied unchanged.
///
/// Returns a newly allocated string; an empty input yields an empty string.
pub(crate) fn sanitize_doc_text(value: &str) -> String {
    // Escaping only ever grows the text, so the input length is a lower bound.
    let mut sanitized = String::with_capacity(value.len());
    for ch in value.chars() {
        match ch {
            '&' => sanitized.push_str("&amp;"),
            '<' => sanitized.push_str("&lt;"),
            '>' => sanitized.push_str("&gt;"),
            '"' => sanitized.push_str("&quot;"),
            '\'' => sanitized.push_str("&#x27;"),
            '`' => sanitized.push_str("&#x60;"),
            // Layout whitespace is a control character but must survive.
            '\n' | '\r' | '\t' => sanitized.push(ch),
            ch if ch.is_control() => sanitized.push(' '),
            _ => sanitized.push(ch),
        }
    }
    sanitized
}
/// Derives an `operationId` from an HTTP method and a route path.
///
/// The id starts with `method`. Each non-empty `/`-separated path segment is
/// appended after an underscore, with every character that is not ASCII
/// alphanumeric replaced by an underscore.
fn operation_id(method: &str, path: &str) -> String {
    path.split('/')
        .filter(|segment| !segment.is_empty())
        .fold(String::from(method), |mut id, segment| {
            id.push('_');
            id.extend(
                segment
                    .chars()
                    .map(|ch| if ch.is_ascii_alphanumeric() { ch } else { '_' }),
            );
            id
        })
}
/// Collects the names of the `{placeholder}` parts of a route path, in order.
///
/// A placeholder runs from a `{` to the first `}` after it; leading `*`
/// characters are stripped from the name. A `{` with no closing `}` is ignored.
fn placeholder_names(path: &str) -> Vec<String> {
    let mut names = Vec::new();
    let mut remaining = path;
    while let Some(open) = remaining.find('{') {
        let after_open = &remaining[open + 1..];
        // With no `}` left, no later `{` can be closed either.
        let Some(close) = after_open.find('}') else {
            break;
        };
        names.push(after_open[..close].trim_start_matches('*').to_owned());
        remaining = &after_open[close + 1..];
    }
    names
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Handler that always answers `200 OK` with an empty JSON body; the tests
    /// only inspect route metadata, never the response.
    fn dummy_handler() -> HandlerFn {
        Arc::new(
            |_ctx: RequestContext| -> BoxFuture<'static, Result<Response>> {
                Box::pin(async {
                    Ok(bytes_response(
                        StatusCode::OK,
                        APPLICATION_JSON,
                        Bytes::new(),
                    ))
                })
            },
        )
    }

    /// Info fields, summary, tags, path parameters and the default response
    /// all appear in the generated document.
    #[test]
    fn document_describes_routes() {
        let routes = vec![Route::new(Method::GET, "/users/{user_id}", dummy_handler())
            .summary("Get user")
            .tag("users")];
        let document = OpenApi::new()
            .title("My API")
            .version("1.0.0")
            .build_document(&routes);
        assert_eq!(document["openapi"], OPENAPI_VERSION);
        assert_eq!(document["info"]["title"], "My API");
        assert_eq!(document["info"]["version"], "1.0.0");
        let operation = &document["paths"]["/users/{user_id}"]["get"];
        assert_eq!(operation["summary"], "Get user");
        assert_eq!(operation["tags"][0], "users");
        assert_eq!(operation["parameters"][0]["name"], "user_id");
        assert_eq!(operation["parameters"][0]["in"], "path");
        assert!(operation["responses"]["200"].is_object());
    }

    // The fields of the fixture types below are only read by the schema derive.
    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Sample {
        id: i64,
        label: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Inner {
        value: String,
    }

    #[derive(schemars::JsonSchema)]
    #[allow(dead_code)]
    struct Outer {
        inner: Inner,
    }

    /// A schema referenced from another schema is registered as a component too.
    #[test]
    fn nested_models_are_registered_as_components() {
        let routes =
            vec![Route::new(Method::GET, "/outer", dummy_handler()).response_schema::<Outer>()];
        let schemas = &OpenApi::new().build_document(&routes)["components"]["schemas"];
        assert!(schemas["Outer"].is_object(), "outer missing: {schemas}");
        assert!(
            schemas["Inner"].is_object(),
            "nested inner missing: {schemas}"
        );
    }

    /// Request and response bodies point at the shared component schema.
    #[test]
    fn document_includes_component_schemas() {
        let routes = vec![Route::new(Method::POST, "/samples", dummy_handler())
            .request_schema::<Sample>()
            .response_schema::<Sample>()];
        let document = OpenApi::new().build_document(&routes);
        assert!(
            document["components"]["schemas"]["Sample"].is_object(),
            "document: {document}"
        );
        let operation = &document["paths"]["/samples"]["post"];
        let request_ref =
            &operation["requestBody"]["content"]["application/json"]["schema"]["$ref"];
        let response_ref =
            &operation["responses"]["200"]["content"]["application/json"]["schema"]["$ref"];
        assert_eq!(request_ref, "#/components/schemas/Sample");
        assert_eq!(response_ref, "#/components/schemas/Sample");
    }

    /// Multipart request bodies are documented under `multipart/form-data`.
    #[test]
    fn multipart_route_documents_form_data_with_binary_file() {
        fn form_schema(_generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
            schemars::Schema::try_from(json!({
                "type": "object",
                "properties": {
                    "token": { "type": "string" },
                    "file": { "type": "string", "format": "binary" },
                },
                "required": ["token", "file"],
            }))
            .unwrap()
        }
        let routes = vec![Route::new(Method::POST, "/files", dummy_handler())
            .request_schema_fn(form_schema)
            .request_kind(RequestBodyKind::Multipart)];
        let document = OpenApi::new().build_document(&routes);
        let content = &document["paths"]["/files"]["post"]["requestBody"]["content"];
        let schema = &content["multipart/form-data"]["schema"];
        assert_eq!(schema["properties"]["file"]["format"], "binary");
        assert!(
            content["application/json"].is_null(),
            "multipart body must not be JSON: {content}"
        );
    }

    /// Form request bodies are documented under the urlencoded media type.
    #[test]
    fn urlencoded_route_documents_form_content_type() {
        let routes = vec![Route::new(Method::POST, "/login", dummy_handler())
            .request_schema::<Sample>()
            .request_kind(RequestBodyKind::Form)];
        let document = OpenApi::new().build_document(&routes);
        let content = &document["paths"]["/login"]["post"]["requestBody"]["content"];
        assert!(
            content["application/x-www-form-urlencoded"]["schema"].is_object(),
            "expected urlencoded body: {content}"
        );
        assert!(content["application/json"].is_null());
    }

    /// Streaming routes are documented as `text/event-stream`, not JSON.
    #[test]
    fn streaming_route_documents_event_stream() {
        let routes = vec![Route::new(Method::GET, "/stream", dummy_handler())
            .response_schema::<Sample>()
            .streaming()];
        let document = OpenApi::new().build_document(&routes);
        let response = &document["paths"]["/stream"]["get"]["responses"]["200"];
        assert_eq!(response["description"], "Server-Sent Events stream");
        assert_eq!(
            response["content"]["text/event-stream"]["schema"]["$ref"],
            "#/components/schemas/Sample"
        );
        assert!(
            response["content"]["application/json"].is_null(),
            "streaming response must not be JSON: {response}"
        );
    }

    /// The provider registers the spec route first and the docs route second.
    #[test]
    fn provider_registers_spec_and_docs_routes() {
        let provider = OpenApi::new()
            .title("Docs")
            .version("1.2.3")
            .json("/schema.json")
            .docs("/docs");
        let routes = provider.documentation_routes(&[]);
        assert_eq!(routes.len(), 2);
        assert_eq!(routes[0].path(), "/schema.json");
        assert_eq!(routes[1].path(), "/docs");
    }

    /// Root paths, non-alphanumeric characters and catch-all placeholders.
    #[test]
    fn operation_id_and_placeholder_helpers_cover_edge_cases() {
        assert_eq!(operation_id("patch", "/"), "patch");
        assert_eq!(
            operation_id("get", "/teams/{team-id}/members/{*rest}"),
            "get_teams__team_id__members___rest_"
        );
        assert_eq!(
            placeholder_names("/teams/{team_id}/members/{*rest}"),
            vec!["team_id".to_owned(), "rest".to_owned()]
        );
    }

    /// Markup-significant characters in free-text fields come out as HTML
    /// entities, and stray control characters become spaces.
    #[test]
    fn document_sanitizes_route_and_info_text_fields() {
        let routes = vec![Route::new(Method::GET, "/users/{user_id}", dummy_handler())
            .summary("<script>alert(1)</script>")
            .description("bad\u{0007}`quote`")
            .tag("ops<script>")];
        let document = OpenApi::new()
            .title("Docs <unsafe>")
            .description("line\u{0001}two")
            .build_document(&routes);
        let operation = &document["paths"]["/users/{user_id}"]["get"];
        assert_eq!(
            operation["summary"],
            "&lt;script&gt;alert(1)&lt;/script&gt;"
        );
        assert_eq!(operation["description"], "bad &#x60;quote&#x60;");
        assert_eq!(operation["tags"][0], "ops&lt;script&gt;");
        assert_eq!(document["info"]["title"], "Docs &lt;unsafe&gt;");
        assert_eq!(document["info"]["description"], "line two");
    }
}