use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use server_less_parse::{MethodInfo, extract_groups, extract_methods, get_impl_name, resolve_method_group};
use syn::{ItemImpl, Token, parse::Parse};
use crate::openapi_gen::{ResponseOverride, RouteOverride, generate_openapi_spec};
use crate::server_attrs::{has_server_hidden, has_server_skip};
/// Arguments accepted by the `#[openapi(...)]` attribute, e.g.
/// `#[openapi(prefix = "/api/v1")]`.
///
/// The `Default` value has no arguments set.
#[derive(Default)]
pub(crate) struct OpenApiArgs {
/// String given as `prefix = "..."`, or `None` when the argument is absent.
/// `expand_openapi` forwards it (empty string when absent) to
/// `generate_openapi_spec` when it builds the spec directly from the methods.
pub prefix: Option<String>,
}
impl Parse for OpenApiArgs {
    /// Parses a comma-separated list of `name = value` arguments.
    ///
    /// The only recognised argument is `prefix`, whose value must be a string
    /// literal. A trailing comma is accepted; an empty input yields the default
    /// (no prefix).
    ///
    /// # Errors
    ///
    /// Returns a `syn::Error` when an argument is not of the form
    /// `ident = value`, when the argument name is unknown (with a
    /// "did you mean" hint where one is available), when the `prefix` value is
    /// not a string literal, or when two arguments are not separated by a comma.
    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
        let mut args = OpenApiArgs::default();
        while !input.is_empty() {
            let ident: syn::Ident = input.parse()?;
            input.parse::<Token![=]>()?;
            match ident.to_string().as_str() {
                "prefix" => {
                    let lit: syn::LitStr = input.parse()?;
                    args.prefix = Some(lit.value());
                }
                other => {
                    const VALID_ARGS: &[&str] = &["prefix"];
                    let suggestion = crate::did_you_mean(other, VALID_ARGS)
                        .map(|s| format!(" — did you mean `{s}`?"))
                        .unwrap_or_default();
                    return Err(syn::Error::new(
                        ident.span(),
                        format!(
                            "unknown argument `{other}`{suggestion}\n\
                             Valid arguments: prefix\n\
                             Example: #[openapi(prefix = \"/api/v1\")]"
                        ),
                    ));
                }
            }
            // The comma is optional only after the last argument. Anything else
            // following a value must be a `,`; previously a missing separator
            // was silently tolerated (`prefix = "a" prefix = "b"`).
            if input.is_empty() {
                break;
            }
            input.parse::<Token![,]>()?;
        }
        Ok(args)
    }
}
/// Records which protocol attributes (`http`, `jsonrpc`, `ws`, `graphql`)
/// were found in an attribute list.
struct DetectedProtocols {
/// An `http` attribute is present.
http: bool,
/// A `jsonrpc` attribute is present.
jsonrpc: bool,
/// A `ws` attribute is present.
ws: bool,
/// A `graphql` attribute is present.
graphql: bool,
}
impl DetectedProtocols {
    /// Builds the flag set by looking for attributes whose path is exactly one
    /// of the bare identifiers `http`, `jsonrpc`, `ws` or `graphql`.
    ///
    /// Multi-segment paths (e.g. `foo::http`) and any other attribute are ignored.
    fn from_attrs(attrs: &[syn::Attribute]) -> Self {
        let present = |name: &str| attrs.iter().any(|attr| attr.path().is_ident(name));
        Self {
            http: present("http"),
            jsonrpc: present("jsonrpc"),
            ws: present("ws"),
            graphql: present("graphql"),
        }
    }

    /// Returns `true` when at least one protocol flag is set.
    fn any_detected(&self) -> bool {
        [self.http, self.jsonrpc, self.ws, self.graphql].contains(&true)
    }

    /// Produces the chain of `.merge_paths(Self::<protocol>_openapi_paths())`
    /// calls for every detected protocol, in the order http, jsonrpc, graphql, ws.
    ///
    /// The result is empty when no protocol is detected.
    fn generate_merges(&self) -> TokenStream2 {
        // `None` interpolates as no tokens, so disabled protocols simply vanish.
        let http_merge = self
            .http
            .then(|| quote! { .merge_paths(Self::http_openapi_paths()) });
        let jsonrpc_merge = self
            .jsonrpc
            .then(|| quote! { .merge_paths(Self::jsonrpc_openapi_paths()) });
        let graphql_merge = self
            .graphql
            .then(|| quote! { .merge_paths(Self::graphql_openapi_paths()) });
        let ws_merge = self
            .ws
            .then(|| quote! { .merge_paths(Self::ws_openapi_paths()) });
        quote! { #http_merge #jsonrpc_merge #graphql_merge #ws_merge }
    }
}
/// Expands `#[openapi]` on an impl block by adding an associated
/// `openapi_spec()` function returning `::server_less::serde_json::Value`.
///
/// Two modes are selected from the attributes present on the impl block:
///
/// * **Composed** — when an `http`, `jsonrpc`, `ws` or `graphql` attribute is
///   present, the original impl block is re-emitted and `openapi_spec()` builds
///   the document with `::server_less::OpenApiBuilder`, merging
///   `Self::<protocol>_openapi_paths()` for each detected protocol. Title and
///   version come from the app metadata, falling back to the type name and
///   `CARGO_PKG_VERSION` respectively.
/// * **Standalone** — otherwise the spec body is generated by
///   `generate_openapi_spec` from the impl's methods, leaving out methods
///   marked skip/hidden and prepending each method's resolved group name to
///   its tags. The impl block itself is re-emitted (with `server`, `route`,
///   `response` and `param` attributes removed) only when
///   `crate::is_protocol_impl_emitter` says this macro is the emitter.
///
/// # Errors
///
/// Propagates errors from `crate::reject_generic_impl`, `get_impl_name`,
/// `extract_methods`, `extract_groups`, the route/response override parsers,
/// `resolve_method_group` and `generate_openapi_spec`.
pub(crate) fn expand_openapi(args: OpenApiArgs, mut impl_block: ItemImpl) -> syn::Result<TokenStream2> {
    crate::reject_generic_impl(&impl_block)?;
    let app_meta = crate::app::extract_app_meta(&mut impl_block.attrs);
    let struct_name = get_impl_name(&impl_block)?;
    // Cloned so the split pieces remain usable after `impl_block` is moved below.
    let generics = impl_block.generics.clone();
    let (impl_generics, _ty_generics, where_clause) = generics.split_for_impl();
    let self_ty = impl_block.self_ty.clone();
    let protocols = DetectedProtocols::from_attrs(&impl_block.attrs);

    if protocols.any_detected() {
        // NOTE(review): `args.prefix` is not consulted in this mode, so a
        // user-supplied prefix is silently ignored — confirm this is intended.
        let merges = protocols.generate_merges();
        let detected_list: Vec<&str> = [
            (protocols.http, "HTTP"),
            (protocols.jsonrpc, "JSON-RPC"),
            (protocols.ws, "WebSocket"),
            (protocols.graphql, "GraphQL"),
        ]
        .iter()
        .filter(|(enabled, _)| *enabled)
        .map(|(_, label)| *label)
        .collect();
        let openapi_doc = format!(
            "Get combined OpenAPI 3.0 specification.\n\n\
             Composed from {} protocol{}: {}.",
            detected_list.len(),
            if detected_list.len() == 1 { "" } else { "s" },
            detected_list.join(", ")
        );
        // The type name is only stringified when no explicit app name exists.
        let openapi_title = app_meta.name.unwrap_or_else(|| struct_name.to_string());
        let openapi_version = match app_meta.version.into_explicit() {
            Some(v) => quote! { #v },
            None => quote! { ::std::env!("CARGO_PKG_VERSION") },
        };
        return Ok(quote! {
            #impl_block
            impl #impl_generics #self_ty #where_clause {
                #[doc = #openapi_doc]
                pub fn openapi_spec() -> ::server_less::serde_json::Value {
                    ::server_less::OpenApiBuilder::new()
                        .title(#openapi_title)
                        .version(#openapi_version)
                        #merges
                        .build()
                }
            }
        });
    }

    let methods = extract_methods(&impl_block)?;
    let prefix = args.prefix.unwrap_or_default();
    let group_registry = extract_groups(&impl_block)?;
    let mut openapi_methods: Vec<(MethodInfo, RouteOverride, ResponseOverride)> = Vec::new();
    // `methods` is consumed here so each `MethodInfo` is moved into the list
    // instead of being cloned.
    for method in methods {
        let mut overrides = RouteOverride::parse_from_attrs(&method.method.attrs)?;
        let response_overrides = ResponseOverride::parse_from_attrs(&method.method.attrs)?;
        if overrides.skip || overrides.hidden || has_server_skip(&method) || has_server_hidden(&method) {
            continue;
        }
        // The group name, when there is one, becomes the first tag.
        if let Some(group_name) = resolve_method_group(&method, &group_registry)? {
            overrides.tags.insert(0, group_name);
        }
        openapi_methods.push((method, overrides, response_overrides));
    }
    let openapi_fn = generate_openapi_spec(&struct_name, &prefix, &openapi_methods)?;
    let standalone_doc = format!(
        "Get OpenAPI 3.0 specification for this service ({} endpoint{}).",
        openapi_methods.len(),
        if openapi_methods.len() == 1 { "" } else { "s" }
    );

    let mut clean_impl = impl_block;
    strip_openapi_helper_attrs(&mut clean_impl);
    // The emitter check runs on the stripped impl block.
    let maybe_impl = if crate::is_protocol_impl_emitter(&clean_impl, "openapi") {
        quote! { #clean_impl }
    } else {
        quote! {}
    };
    Ok(quote! {
        #maybe_impl
        impl #impl_generics #self_ty #where_clause {
            #[doc = #standalone_doc]
            pub fn openapi_spec() -> ::server_less::serde_json::Value {
                #openapi_fn
            }
        }
    })
}

/// Removes, in place, `server` attributes from the impl block, `route` and
/// `response` attributes from its methods, and `param` attributes from typed
/// method parameters.
fn strip_openapi_helper_attrs(impl_block: &mut ItemImpl) {
    impl_block.attrs.retain(|attr| !attr.path().is_ident("server"));
    for item in &mut impl_block.items {
        if let syn::ImplItem::Fn(method) = item {
            method
                .attrs
                .retain(|attr| !attr.path().is_ident("route") && !attr.path().is_ident("response"));
            for input in &mut method.sig.inputs {
                if let syn::FnArg::Typed(pat_type) = input {
                    pat_type.attrs.retain(|attr| !attr.path().is_ident("param"));
                }
            }
        }
    }
}