use http::{header, Method, StatusCode};
use rustauth_core::api::{
api_error, build_openapi_schema, core_auth_async_endpoints, create_auth_endpoint, ApiErrorCode,
ApiResponse, AsyncAuthEndpoint, AuthEndpointOptions, OpenApiOperation,
};
use rustauth_core::context::AuthContext;
use rustauth_core::error::RustAuthError;
use rustauth_core::plugin::AuthPlugin;
use serde::{Deserialize, Serialize};
use serde_json::json;
pub const UPSTREAM_PLUGIN_ID: &str = "open-api";
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct OpenApiOptions {
pub path: String,
pub disable_default_reference: bool,
pub theme: String,
pub nonce: Option<String>,
}
impl Default for OpenApiOptions {
fn default() -> Self {
Self {
path: "/reference".to_owned(),
disable_default_reference: false,
theme: "default".to_owned(),
nonce: None,
}
}
}
impl OpenApiOptions {
#[must_use]
pub fn builder() -> OpenApiOptionsBuilder {
OpenApiOptionsBuilder::default()
}
#[must_use]
pub fn path(mut self, path: impl Into<String>) -> Self {
self.path = normalize_path(path.into());
self
}
#[must_use]
pub fn disable_default_reference(mut self, disabled: bool) -> Self {
self.disable_default_reference = disabled;
self
}
#[must_use]
pub fn theme(mut self, theme: impl Into<String>) -> Self {
self.theme = theme.into();
self
}
#[must_use]
pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
self.nonce = Some(nonce.into());
self
}
}
#[derive(Debug, Clone, Default)]
pub struct OpenApiOptionsBuilder {
path: Option<String>,
disable_default_reference: Option<bool>,
theme: Option<String>,
nonce: Option<Option<String>>,
}
impl OpenApiOptionsBuilder {
#[must_use]
pub fn path(mut self, path: impl Into<String>) -> Self {
self.path = Some(normalize_path(path.into()));
self
}
#[must_use]
pub fn disable_default_reference(mut self, disabled: bool) -> Self {
self.disable_default_reference = Some(disabled);
self
}
#[must_use]
pub fn theme(mut self, theme: impl Into<String>) -> Self {
self.theme = Some(theme.into());
self
}
#[must_use]
pub fn nonce(mut self, nonce: impl Into<String>) -> Self {
self.nonce = Some(Some(nonce.into()));
self
}
#[must_use]
pub fn build(self) -> OpenApiOptions {
let defaults = OpenApiOptions::default();
OpenApiOptions {
path: self.path.unwrap_or(defaults.path),
disable_default_reference: self
.disable_default_reference
.unwrap_or(defaults.disable_default_reference),
theme: self.theme.unwrap_or(defaults.theme),
nonce: self.nonce.unwrap_or(defaults.nonce),
}
}
}
#[must_use]
pub fn open_api(options: OpenApiOptions) -> AuthPlugin {
AuthPlugin::new(UPSTREAM_PLUGIN_ID)
.with_version(crate::VERSION)
.with_options(serde_json::to_value(&options).unwrap_or(serde_json::Value::Null))
.with_endpoint(generate_schema_endpoint())
.with_endpoint(reference_endpoint(options))
}
fn generate_schema_endpoint() -> AsyncAuthEndpoint {
create_auth_endpoint(
"/open-api/generate-schema",
Method::GET,
AuthEndpointOptions::new()
.operation_id("generateOpenAPISchema")
.openapi(
OpenApiOperation::new("generateOpenAPISchema")
.description("Generate the OpenAPI schema for this RustAuth instance")
.response(
"200",
json!({
"description": "OpenAPI schema",
"content": {
"application/json": {
"schema": {
"type": "object"
}
}
}
}),
),
),
|context, _request| async move {
json_response(
StatusCode::OK,
serde_json::to_vec(&schema_for_context(&context))
.map_err(|error| RustAuthError::Api(error.to_string()))?,
)
},
)
}
fn reference_endpoint(options: OpenApiOptions) -> AsyncAuthEndpoint {
let path = options.path.clone();
create_auth_endpoint(
path,
Method::GET,
AuthEndpointOptions::new()
.operation_id("openApiReference")
.hide_from_openapi()
.openapi(
OpenApiOperation::new("openApiReference")
.summary("OpenAPI reference")
.description("Serve the interactive OpenAPI reference"),
),
move |context, _request| {
let options = options.clone();
async move {
if options.disable_default_reference {
return api_error(StatusCode::NOT_FOUND, ApiErrorCode::NotFound);
}
html_response(get_html(
&schema_for_context(&context),
&options.theme,
options.nonce.as_deref(),
))
}
},
)
}
fn schema_for_context(context: &AuthContext) -> serde_json::Value {
let mut endpoints = core_auth_async_endpoints();
for plugin in &context.plugins {
endpoints.extend(plugin.endpoints.iter().cloned());
}
build_openapi_schema(context, &endpoints)
}
fn get_html(api_reference: &serde_json::Value, theme: &str, nonce: Option<&str>) -> String {
let nonce_attr = nonce
.map(|nonce| format!(" nonce=\"{}\"", escape_html_attr(nonce)))
.unwrap_or_default();
let api_reference = escape_script_json(&api_reference.to_string());
let theme = escape_js_string(theme);
format!(
r#"<!doctype html>
<html>
<head>
<title>RustAuth API Reference</title>
<meta charset="utf-8" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
</head>
<body>
<script id="api-reference" type="application/json">{api_reference}</script>
<script{nonce_attr}>
var configuration = {{
theme: "{theme}",
metaData: {{
title: "RustAuth API",
description: "API Reference for your RustAuth instance"
}}
}}
document.getElementById("api-reference").dataset.configuration =
JSON.stringify(configuration)
</script>
<script src="https://cdn.jsdelivr.net/npm/@scalar/api-reference"{nonce_attr}></script>
</body>
</html>"#,
api_reference = api_reference,
theme = theme,
nonce_attr = nonce_attr,
)
}
fn json_response(status: StatusCode, body: Vec<u8>) -> Result<ApiResponse, RustAuthError> {
http::Response::builder()
.status(status)
.header(header::CONTENT_TYPE, "application/json")
.body(body)
.map_err(|error| RustAuthError::Api(error.to_string()))
}
fn html_response(body: String) -> Result<ApiResponse, RustAuthError> {
http::Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/html; charset=utf-8")
.body(body.into_bytes())
.map_err(|error| RustAuthError::Api(error.to_string()))
}
fn normalize_path(path: String) -> String {
if path.starts_with('/') {
path
} else {
format!("/{path}")
}
}
fn escape_html_attr(value: &str) -> String {
value
.replace('&', "&")
.replace('"', """)
.replace('<', "<")
.replace('>', ">")
}
fn escape_js_string(value: &str) -> String {
let mut escaped = String::with_capacity(value.len());
for character in value.chars() {
match character {
'\\' => escaped.push_str("\\\\"),
'"' => escaped.push_str("\\\""),
'&' => escaped.push_str("\\u0026"),
'<' => escaped.push_str("\\u003c"),
'>' => escaped.push_str("\\u003e"),
'\u{2028}' => escaped.push_str("\\u2028"),
'\u{2029}' => escaped.push_str("\\u2029"),
character => escaped.push(character),
}
}
escaped
}
fn escape_script_json(value: &str) -> String {
value
.replace('&', "\\u0026")
.replace('<', "\\u003c")
.replace('>', "\\u003e")
.replace('\u{2028}', "\\u2028")
.replace('\u{2029}', "\\u2029")
}