use std::collections::BTreeMap;
use bytes::Bytes;
use uuid::Uuid;
use crate::apidocs::{ApiDocGenerator, ApiMeta, DocViewer};
use crate::auth::AuthUser;
use crate::callables::Operation;
use crate::routes::AxumRouter;
use crate::Site;
use super::{Bundle, BundleError};
/// Configuration for publishing a generated OpenAPI document and, optionally,
/// an HTML viewer page for it.
///
/// Start from [`Default`] and adjust it with the builder-style methods on this type.
#[derive(Clone, Debug)]
pub struct OpenApiConf {
/// Route path of the HTML viewer page; `None` means no viewer route is registered.
pub doc_path: Option<String>,
/// Route path under which the generated JSON document is served.
pub spec_path: String,
/// Metadata handed to the document generator (title, version, description, tags).
pub meta: ApiMeta,
/// Which viewer the HTML page is rendered with.
pub viewer: DocViewer,
/// Optional access predicate. When set, both routes resolve the requesting user
/// and answer `403 Forbidden` if the predicate returns `false`.
pub auth: Option<fn(&AuthUser) -> bool>,
}
impl Default for OpenApiConf {
fn default() -> Self {
Self {
spec_path: "/openapi.json".to_string(),
doc_path: None,
meta: ApiMeta::default(),
viewer: DocViewer::Swagger,
auth: None,
}
}
}
/// Builder-style setters: each one consumes the configuration, replaces a
/// single setting and hands the configuration back so calls can be chained.
impl OpenApiConf {
    /// Sets the route path under which the JSON document is served.
    pub fn spec(self, path: impl Into<String>) -> Self {
        Self {
            spec_path: path.into(),
            ..self
        }
    }

    /// Enables the HTML viewer page and sets its route path.
    pub fn doc(self, path: impl Into<String>) -> Self {
        Self {
            doc_path: Some(path.into()),
            ..self
        }
    }

    /// Selects the viewer used to render the HTML page.
    pub fn viewer(self, viewer: DocViewer) -> Self {
        Self { viewer, ..self }
    }

    /// Sets the title stored in the document metadata.
    pub fn title(self, title: impl Into<String>) -> Self {
        let mut meta = self.meta;
        meta.title = title.into();
        Self { meta, ..self }
    }

    /// Sets the version stored in the document metadata.
    pub fn version(self, version: impl Into<String>) -> Self {
        let mut meta = self.meta;
        meta.version = version.into();
        Self { meta, ..self }
    }

    /// Sets the description stored in the document metadata.
    pub fn description(self, description: impl Into<String>) -> Self {
        let mut meta = self.meta;
        meta.description = Some(description.into());
        Self { meta, ..self }
    }

    /// Replaces the tag list stored in the document metadata.
    pub fn tags(self, tags: Vec<crate::apidocs::TagInfo>) -> Self {
        let mut meta = self.meta;
        meta.tags = tags;
        Self { meta, ..self }
    }

    /// Installs the access predicate that guards the document and viewer routes.
    pub fn auth(self, pred: fn(&AuthUser) -> bool) -> Self {
        Self {
            auth: Some(pred),
            ..self
        }
    }
}
/// One registered OpenAPI publication: which operations to document and
/// where the document (and optional viewer page) is served.
pub(super) struct DocNode {
// Id of the operation whose `path` is used as the route for the JSON document.
spec_op_id: Uuid,
// Id of the operation whose `path` is used as the route for the viewer page, if any.
doc_op_id: Option<Uuid>,
// Ids of the operations to include in the generated document.
operation_ids: Vec<Uuid>,
// Metadata passed to the document generator.
meta: ApiMeta,
// Viewer used when rendering the HTML page.
viewer: DocViewer,
// Optional access predicate applied by both route handlers.
auth: Option<fn(&AuthUser) -> bool>,
}
/// Collects the [`DocNode`]s that have been registered and installs their
/// routes on a router in `setup`.
pub(crate) struct DocEngine {
// Registered publications, in registration/merge order.
nodes: Vec<DocNode>,
}
impl DocEngine {
pub(super) fn new() -> Self {
Self { nodes: Vec::new() }
}
pub(super) fn register(&mut self, node: DocNode) {
self.nodes.push(node);
}
pub(crate) fn merge(&mut self, other: DocEngine) {
self.nodes.extend(other.nodes);
}
pub(crate) fn setup(
&self,
router: &mut AxumRouter<Site>,
ops: &BTreeMap<Uuid, Operation>,
) -> Result<(), BundleError> {
for node in &self.nodes {
let spec_path = ops
.get(&node.spec_op_id)
.map(|op| op.path.clone())
.unwrap_or_else(|| node.spec_op_id.to_string());
let views: Vec<&Operation> = node
.operation_ids
.iter()
.filter_map(|id| ops.get(id))
.collect();
let spec_bytes = generate_spec(&views, &node.meta)?;
let auth_pred = node.auth;
let spec_route = {
let b = spec_bytes;
axum::routing::get(move |axum::extract::State(site): axum::extract::State<Site>, req: axum::extract::Request| {
let body = b.clone();
async move {
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
if let Some(pred) = auth_pred {
let (parts, _) = req.into_parts();
match site.authenticator().extract_user(&parts, &[], false) {
Err(e) => return e.into_response(),
Ok(user) if !pred(&user) => {
return StatusCode::FORBIDDEN.into_response();
}
Ok(_) => {}
}
}
(StatusCode::OK, [(header::CONTENT_TYPE, "application/json")], body).into_response()
}
})
};
*router = std::mem::take(router).route(&spec_path, spec_route);
if let Some(doc_op_id) = node.doc_op_id {
let doc_path = ops
.get(&doc_op_id)
.map(|op| op.path.clone())
.unwrap_or_else(|| doc_op_id.to_string());
let viewer_html = generate_viewer(&doc_path, &spec_path, node.viewer);
let viewer_route = {
let h = viewer_html;
axum::routing::get(move |axum::extract::State(site): axum::extract::State<Site>, req: axum::extract::Request| {
let body = h.clone();
async move {
use axum::http::{StatusCode, header};
use axum::response::IntoResponse;
if let Some(pred) = auth_pred {
let (parts, _) = req.into_parts();
match site.authenticator().extract_user(&parts, &[], false) {
Err(e) => return e.into_response(),
Ok(user) if !pred(&user) => {
return StatusCode::FORBIDDEN.into_response();
}
Ok(_) => {}
}
}
(StatusCode::OK, [(header::CONTENT_TYPE, "text/html; charset=utf-8")], body).into_response()
}
})
};
*router = std::mem::take(router).route(&doc_path, viewer_route);
}
}
Ok(())
}
}
impl Bundle {
pub fn with_openapi(mut self, conf: OpenApiConf) -> Self {
let operation_ids: Vec<Uuid> = self
.ops
.values()
.filter(|op| !op.hidden)
.map(|op| op.id)
.collect();
let spec_op = crate::callables::Operation::from_api_doc(
&format!("__spec__{}", conf.spec_path),
&conf.spec_path,
);
let spec_op_id = spec_op.id;
self.ops.insert(spec_op_id, spec_op);
let doc_op_id = conf.doc_path.as_deref().map(|path| {
let op = crate::callables::Operation::from_api_doc(
&format!("__doc__{}", path),
path,
);
let id = op.id;
self.ops.insert(id, op);
id
});
self.doc_engine.register(DocNode {
spec_op_id,
doc_op_id,
operation_ids,
meta: conf.meta,
viewer: conf.viewer,
auth: conf.auth,
});
self
}
}
/// Generates the API document for `views` using a copy of `meta` and returns
/// it serialised as JSON bytes.
///
/// # Errors
///
/// A failure of either the generator or the JSON serialiser is reported as
/// `BundleError::DocGen` carrying the underlying error's string form.
fn generate_spec(views: &[&Operation], meta: &ApiMeta) -> Result<Bytes, BundleError> {
    /// Wraps any stringifiable error into the doc-generation error variant.
    fn doc_gen_error<E: ToString>(err: E) -> BundleError {
        BundleError::DocGen(err.to_string())
    }

    let generator = ApiDocGenerator::new(meta.clone());
    let document = generator.generate(views).map_err(doc_gen_error)?;
    let encoded = serde_json::to_vec(&document).map_err(doc_gen_error)?;
    Ok(encoded.into())
}
/// Renders the HTML viewer page served at `doc_path` for the document served
/// at `spec_path`.
///
/// The page is given a reference to the document that is relative to the
/// viewer's directory whenever the document lives underneath it (presumably so
/// the page keeps working when the site is mounted under a path prefix);
/// otherwise the absolute `spec_path` is used unchanged.
fn generate_viewer(doc_path: &str, spec_path: &str, viewer: DocViewer) -> String {
    /// Computes the URL reference the viewer page should use to fetch the spec.
    fn relative_spec_ref<'a>(doc_path: &str, spec_path: &'a str) -> &'a str {
        // Directory of the viewer page: everything up to and including the
        // last '/'. A path without any '/' is treated as living at the root.
        let from_dir = match doc_path.rfind('/') {
            Some(last_slash) => &doc_path[..=last_slash],
            None => "/",
        };

        match spec_path.strip_prefix(from_dir) {
            // The spec path *is* the viewer's directory. An empty reference
            // would resolve to the viewer page itself, so point at the
            // directory explicitly instead.
            Some("") => "./",
            Some(rest) => rest,
            None => spec_path,
        }
    }

    let spec_ref = relative_spec_ref(doc_path, spec_path);
    ApiDocGenerator::serve_doc(spec_ref, viewer).0
}