use heck::{ToPascalCase, ToSnakeCase};
use openapiv3::{APIKeyLocation, OpenAPI, Operation, ReferenceOr, SecurityScheme};
use proc_macro2::TokenStream;
use quote::{format_ident, quote};
/// Where an `apiKey` credential is placed in a request.
///
/// Mirrors `openapiv3::APIKeyLocation` one-to-one. `PartialEq`/`Eq`/`Hash` are
/// derived so locations can be compared and used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ApiKeyIn {
    /// Corresponds to `APIKeyLocation::Header`.
    Header,
    /// Corresponds to `APIKeyLocation::Query`.
    Query,
    /// Corresponds to `APIKeyLocation::Cookie`.
    Cookie,
}
/// The subset of OpenAPI security-scheme kinds the generator can model.
#[derive(Clone, Debug)]
pub enum SchemeKind {
    /// An `apiKey` scheme.
    ApiKey {
        /// The scheme's `name` field: the header, query parameter, or cookie
        /// that carries the key.
        key: String,
        /// Where in the request the key is sent.
        location: ApiKeyIn,
    },
    /// An `http` scheme whose `scheme` is `bearer` (case-insensitive).
    HttpBearer,
    /// An `http` scheme whose `scheme` is `basic` (case-insensitive).
    HttpBasic,
}
/// A supported security scheme plus the Rust names derived from its OpenAPI key.
#[derive(Clone, Debug)]
pub struct SchemeInfo {
    /// The key under `components.securitySchemes` in the OpenAPI document.
    pub name: String,
    /// `name` in PascalCase; used as the generated credential type's name.
    pub ident: syn::Ident,
    /// `name` in snake_case.
    pub snake: String,
    /// The kind of credential the scheme expects.
    pub kind: SchemeKind,
}
/// The security requirements of one operation, reduced to what the generator supports.
///
/// `PartialEq`/`Eq` are derived so resolved results can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct OpSecurity {
    /// Names of supported schemes the operation accepts; any single one
    /// satisfies the requirement (logical OR).
    pub alternatives: Vec<String>,
    /// Set when at least one requirement object named several schemes at
    /// once (logical AND), which is not supported and was skipped.
    pub had_unsupported_and: bool,
}
/// Gathers the security schemes under `components.securitySchemes` that the
/// generator can model.
///
/// Supported entries are:
/// * inline `apiKey` schemes;
/// * `http` schemes whose `scheme` value is `bearer` or `basic`
///   (ASCII case-insensitive).
///
/// `$ref` entries, any other `http` scheme, `oauth2`, and `openIdConnect` are
/// skipped. Results follow the iteration order of `security_schemes`. Returns
/// an empty vector when the document has no `components` section.
///
/// # Panics
///
/// Panics (inside `format_ident!`) if a scheme name's PascalCase form is not a
/// valid Rust identifier, for example when it starts with a digit.
#[must_use]
pub fn collect_schemes(openapi: &OpenAPI) -> Vec<SchemeInfo> {
    openapi
        .components
        .as_ref()
        .map(|components| {
            components
                .security_schemes
                .iter()
                .filter_map(|(scheme_name, entry)| {
                    // `$ref` entries are not followed.
                    let ReferenceOr::Item(scheme) = entry else {
                        return None;
                    };
                    let kind = match scheme {
                        SecurityScheme::APIKey {
                            location,
                            name: key_name,
                            ..
                        } => SchemeKind::ApiKey {
                            key: key_name.clone(),
                            location: match location {
                                APIKeyLocation::Header => ApiKeyIn::Header,
                                APIKeyLocation::Query => ApiKeyIn::Query,
                                APIKeyLocation::Cookie => ApiKeyIn::Cookie,
                            },
                        },
                        SecurityScheme::HTTP {
                            scheme: http_scheme,
                            ..
                        } => {
                            if http_scheme.eq_ignore_ascii_case("bearer") {
                                SchemeKind::HttpBearer
                            } else if http_scheme.eq_ignore_ascii_case("basic") {
                                SchemeKind::HttpBasic
                            } else {
                                return None;
                            }
                        }
                        SecurityScheme::OAuth2 { .. } | SecurityScheme::OpenIDConnect { .. } => {
                            return None;
                        }
                    };
                    Some(SchemeInfo {
                        ident: format_ident!("{}", scheme_name.to_pascal_case()),
                        snake: scheme_name.to_snake_case(),
                        name: scheme_name.clone(),
                        kind,
                    })
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
}
/// Resolves which supported security schemes an operation accepts.
///
/// The operation's own `security` list takes precedence. When it is absent,
/// the document-level `security` list is used. An explicitly empty list
/// therefore yields no alternatives.
///
/// Each requirement object is one alternative (logical OR):
/// * Objects naming more than one scheme (logical AND) are unsupported. They
///   are skipped and flagged via [`OpSecurity::had_unsupported_and`].
/// * Schemes missing from `schemes` (unsupported kinds) are ignored.
/// * A scheme named by several requirement objects (for example, the same
///   scheme with different scopes) is listed only once. Repeats would
///   otherwise produce duplicate variants in the per-operation auth enum,
///   and the generated code would not compile.
///
/// NOTE(review): an empty requirement object (`{}`) is skipped here. In
/// OpenAPI that object means "no authentication", so optional auth is not
/// represented. Confirm callers do not need it.
#[must_use]
pub fn resolve_op_security(
    op: &Operation,
    openapi: &OpenAPI,
    schemes: &[SchemeInfo],
) -> OpSecurity {
    let Some(requirements) = op.security.as_ref().or(openapi.security.as_ref()) else {
        return OpSecurity::default();
    };
    let mut out = OpSecurity::default();
    for req in requirements {
        if req.len() > 1 {
            out.had_unsupported_and = true;
            continue;
        }
        let Some((name, _scopes)) = req.iter().next() else {
            continue;
        };
        let already_listed = out.alternatives.iter().any(|existing| existing == name);
        if !already_listed && scheme_by_name(schemes, name).is_some() {
            out.alternatives.push(name.clone());
        }
    }
    out
}
/// Looks up a scheme by its OpenAPI key (exact, case-sensitive match).
///
/// Returns the first entry whose `name` equals `name`, or `None` if there is none.
#[must_use]
pub fn scheme_by_name<'a>(schemes: &'a [SchemeInfo], name: &str) -> Option<&'a SchemeInfo> {
    for candidate in schemes {
        if candidate.name == name {
            return Some(candidate);
        }
    }
    None
}
/// Emits one public credential type per supported security scheme.
///
/// * `apiKey` and HTTP bearer schemes become a tuple struct wrapping the raw
///   credential string.
/// * HTTP basic schemes become a struct with `username` and `password` fields.
///
/// Every type derives `Clone`. Instead of a derived `Debug` (which would print
/// the secret), each type gets a hand-written `Debug` impl that shows
/// `<redacted>` in place of the credential. This keeps keys, tokens and
/// passwords out of logs, including when a containing type derives `Debug`.
/// Basic-auth usernames are still shown.
#[must_use]
pub fn generate_scheme_types(schemes: &[SchemeInfo]) -> TokenStream {
    let items: Vec<TokenStream> = schemes
        .iter()
        .map(|s| {
            let ident = &s.ident;
            // Rendered as a string literal for `debug_tuple`/`debug_struct`.
            let type_name = ident.to_string();
            match &s.kind {
                SchemeKind::ApiKey { .. } | SchemeKind::HttpBearer => quote! {
                    #[derive(::core::clone::Clone)]
                    pub struct #ident(pub ::std::string::String);

                    impl ::core::fmt::Debug for #ident {
                        fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                            f.debug_tuple(#type_name)
                                .field(&"<redacted>")
                                .finish()
                        }
                    }
                },
                SchemeKind::HttpBasic => quote! {
                    #[derive(::core::clone::Clone)]
                    pub struct #ident {
                        pub username: ::std::string::String,
                        pub password: ::std::string::String,
                    }

                    impl ::core::fmt::Debug for #ident {
                        fn fmt(&self, f: &mut ::core::fmt::Formatter<'_>) -> ::core::fmt::Result {
                            f.debug_struct(#type_name)
                                .field("username", &self.username)
                                .field("password", &"<redacted>")
                                .finish()
                        }
                    }
                },
            }
        })
        .collect();
    quote! { #(#items)* }
}
/// Builds the per-operation auth enum for an operation with several acceptable schemes.
///
/// The enum is named by [`auth_enum_ident`] and has one variant per entry of
/// `alternatives`, in order. Each variant carries the same identifier as the
/// scheme's credential type and wraps a value of that type. Returns `None`
/// when fewer than two alternatives are given.
#[must_use]
pub fn generate_op_auth_enum(op_id: &str, alternatives: &[&SchemeInfo]) -> Option<TokenStream> {
    if alternatives.len() >= 2 {
        let enum_ident = auth_enum_ident(op_id);
        // The variant name and its payload type are both the scheme ident.
        let variants = alternatives.iter().map(|scheme| {
            let scheme_ident = &scheme.ident;
            quote! { #scheme_ident(#scheme_ident), }
        });
        Some(quote! {
            #[derive(::core::fmt::Debug, ::core::clone::Clone)]
            pub enum #enum_ident {
                #(#variants)*
            }
        })
    } else {
        None
    }
}
#[must_use]
pub fn auth_enum_ident(op_id: &str) -> syn::Ident {
format_ident!("{}Auth", op_id.to_pascal_case())
}
/// Describes the type of an operation's auth parameter, based on how many
/// alternatives it accepts.
///
/// * no alternatives → `None`;
/// * exactly one → `Some` of an *empty* token stream. This function is not
///   given the scheme table, so it cannot name that scheme's type.
///   NOTE(review): callers presumably substitute the single scheme's type
///   when they see an empty stream; confirm;
/// * two or more → `Some` of the enum identifier from [`auth_enum_ident`].
#[must_use]
pub fn auth_param_type(op_id: &str, op_security: &OpSecurity) -> Option<TokenStream> {
    match op_security.alternatives.len() {
        0 => None,
        1 => Some(TokenStream::new()),
        _ => {
            let enum_ident = auth_enum_ident(op_id);
            Some(quote! { #enum_ident })
        }
    }
}
/// Maps an operation's alternative scheme names to their [`SchemeInfo`] entries.
///
/// Order follows `op_security.alternatives`. Names with no matching scheme are
/// silently dropped.
#[must_use]
pub fn resolve_alternatives<'a>(
    op_security: &'a OpSecurity,
    schemes: &'a [SchemeInfo],
) -> Vec<&'a SchemeInfo> {
    let mut resolved = Vec::new();
    for scheme_name in &op_security.alternatives {
        if let Some(info) = scheme_by_name(schemes, scheme_name) {
            resolved.push(info);
        }
    }
    resolved
}
/// Builds the identifier `<Module>AuthState` from a module identifier, with the
/// module name converted to PascalCase (e.g. `petstore` → `PetstoreAuthState`).
#[must_use]
pub fn auth_state_ident(mod_ident: &syn::Ident) -> syn::Ident {
    let module_pascal = mod_ident.to_string().to_pascal_case();
    format_ident!("{}AuthState", module_pascal)
}
/// Builds the identifier `<Module>ClientAuth` from a module identifier, with the
/// module name converted to PascalCase. Judging by the function name, this
/// presumably names the client-side auth trait.
#[must_use]
pub fn client_auth_trait_ident(mod_ident: &syn::Ident) -> syn::Ident {
    let module_pascal = mod_ident.to_string().to_pascal_case();
    format_ident!("{}ClientAuth", module_pascal)
}
/// Identifier built from a scheme's snake_case name (presumably used as a
/// field name in generated code; confirm against callers).
#[must_use]
pub fn scheme_field_ident(scheme: &SchemeInfo) -> syn::Ident {
    let field_name: &str = scheme.snake.as_str();
    format_ident!("{}", field_name)
}
/// Emits one `compile_error!` per operation that requires several security
/// schemes at once (logical AND), which this version does not support.
///
/// Returns an empty token stream when `ops_with_and` is empty.
#[must_use]
pub fn generate_unsupported_and_errors(ops_with_and: &[String]) -> TokenStream {
    let mut errors = TokenStream::new();
    for op_id in ops_with_and {
        let msg = format!(
            "openapi-trait: operation `{op_id}` requires multiple security schemes simultaneously (AND); v0.1 only supports OR of single-scheme alternatives"
        );
        errors.extend(quote! { ::core::compile_error!(#msg); });
    }
    errors
}