use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{Attribute, DeriveInput, parse_macro_input};
use crate::utils::docs::extract_doc_summary;
/// Derive entry point: parses the macro input as a [`DeriveInput`] and expands
/// it with [`generate`].
///
/// Any `syn::Error` produced during expansion is converted into a
/// compile-error token stream so it surfaces as a regular compiler diagnostic.
pub fn derive(input: TokenStream) -> TokenStream {
    let ast = parse_macro_input!(input as DeriveInput);
    generate(&ast)
        .unwrap_or_else(|error| error.to_compile_error())
        .into()
}
fn generate(input: &DeriveInput) -> syn::Result<TokenStream2> {
let name = &input.ident;
let vis = &input.vis;
let variants = match &input.data {
syn::Data::Enum(data) => &data.variants,
_ => {
return Err(syn::Error::new_spanned(
input,
"EntityError can only be derived for enums"
));
}
};
let error_variants: Vec<ErrorVariant> = variants
.iter()
.filter_map(|v| parse_error_variant(v).ok())
.collect();
if error_variants.is_empty() {
return Ok(TokenStream2::new());
}
let responses_struct = format_ident!("{}Responses", name);
let status_codes: Vec<u16> = error_variants.iter().map(|v| v.status).collect();
let descriptions: Vec<&String> = error_variants.iter().map(|v| &v.description).collect();
let doc = format!(
"OpenAPI error responses for `{}`.\\n\\n\
Use with `#[utoipa::path(responses(...))]`.",
name
);
Ok(quote! {
#[doc = #doc]
#vis struct #responses_struct;
impl #responses_struct {
#[must_use]
pub const fn status_codes() -> &'static [u16] {
&[#(#status_codes),*]
}
#[must_use]
pub fn descriptions() -> &'static [&'static str] {
&[#(#descriptions),*]
}
#[must_use]
pub fn utoipa_responses() -> Vec<(u16, &'static str)> {
vec![
#((#status_codes, #descriptions)),*
]
}
}
})
}
/// Data extracted from one enum variant that carries a `#[status(code)]` attribute.
struct ErrorVariant {
/// Status code taken from the variant's `#[status(code)]` attribute.
status: u16,
/// Doc summary of the variant, or `"<Variant> error"` when it has none.
description: String
}
/// Builds an [`ErrorVariant`] from a single enum variant.
///
/// The status code comes from the variant's `#[status(code)]` attribute. The
/// description is whatever `extract_doc_summary` returns for the variant's
/// attributes, falling back to `"<Variant> error"` when it returns `None`.
///
/// # Errors
///
/// Propagates the error from `parse_status_attr` when the attribute is missing
/// or malformed.
fn parse_error_variant(variant: &syn::Variant) -> syn::Result<ErrorVariant> {
    let attrs = &variant.attrs;

    // Resolve the status first so a variant without one bails out early.
    let status = parse_status_attr(attrs)?;

    let description = match extract_doc_summary(attrs) {
        Some(summary) => summary,
        None => format!("{} error", variant.ident)
    };

    Ok(ErrorVariant {
        status,
        description
    })
}
fn parse_status_attr(attrs: &[Attribute]) -> syn::Result<u16> {
for attr in attrs {
if attr.path().is_ident("status") {
let status: syn::LitInt = attr.parse_args()?;
return status.base10_parse();
}
}
Err(syn::Error::new(
proc_macro2::Span::call_site(),
"Missing #[status(code)] attribute"
))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the first variant of an enum input.
    ///
    /// Panics when the input is not an enum or has no variants, so a broken
    /// fixture fails the test instead of letting it pass without asserting.
    fn first_variant(input: &DeriveInput) -> &syn::Variant {
        match &input.data {
            syn::Data::Enum(data) => data
                .variants
                .first()
                .expect("test enum must have at least one variant"),
            _ => panic!("test input must be an enum")
        }
    }

    #[test]
    fn parse_status_code() {
        let input: DeriveInput = syn::parse_quote! {
            enum UserError {
                #[status(404)]
                NotFound,
            }
        };
        let status = parse_status_attr(&first_variant(&input).attrs).unwrap();
        assert_eq!(status, 404);
    }

    #[test]
    fn parse_error_variant_full() {
        // The doc comment is what supplies the expected description; without
        // it the fallback "EmailExists error" would be produced instead.
        let input: DeriveInput = syn::parse_quote! {
            enum UserError {
                /// User with this email already exists
                #[status(409)]
                EmailExists,
            }
        };
        let parsed = parse_error_variant(first_variant(&input)).unwrap();
        assert_eq!(parsed.status, 409);
        assert_eq!(parsed.description, "User with this email already exists");
    }

    #[test]
    fn parse_missing_status_fails() {
        let input: DeriveInput = syn::parse_quote! {
            enum UserError {
                NoStatus,
            }
        };
        let result = parse_error_variant(first_variant(&input));
        assert!(result.is_err());
    }

    #[test]
    fn parse_malformed_status_fails() {
        let input: DeriveInput = syn::parse_quote! {
            enum UserError {
                #[status("not a number")]
                Broken,
            }
        };
        let result = parse_status_attr(&first_variant(&input).attrs);
        assert!(result.is_err());
    }

    #[test]
    fn generate_for_non_enum_fails() {
        let input: DeriveInput = syn::parse_quote! {
            struct NotAnEnum {
                field: String,
            }
        };
        let result = generate(&input);
        assert!(result.is_err());
    }

    #[test]
    fn generate_empty_variants_returns_empty() {
        let input: DeriveInput = syn::parse_quote! {
            enum EmptyError {
                NoStatus,
            }
        };
        let result = generate(&input);
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn generate_multiple_variants() {
        let input: DeriveInput = syn::parse_quote! {
            enum UserError {
                #[status(404)]
                NotFound,
                #[status(409)]
                AlreadyExists,
                #[status(500)]
                Internal,
            }
        };
        let result = generate(&input);
        assert!(result.is_ok());
        let output = result.unwrap().to_string();
        assert!(output.contains("UserErrorResponses"));
        assert!(output.contains("status_codes"));
        assert!(output.contains("descriptions"));
        assert!(output.contains("utoipa_responses"));
        assert!(output.contains("404"));
        assert!(output.contains("409"));
        assert!(output.contains("500"));
    }

    #[test]
    fn parse_variant_without_doc_uses_default() {
        let input: DeriveInput = syn::parse_quote! {
            enum Error {
                #[status(400)]
                BadRequest,
            }
        };
        let parsed = parse_error_variant(first_variant(&input)).unwrap();
        assert_eq!(parsed.status, 400);
        assert!(parsed.description.contains("BadRequest"));
    }

    #[test]
    fn generate_public_visibility() {
        let input: DeriveInput = syn::parse_quote! {
            pub enum ApiError {
                #[status(404)]
                NotFound,
            }
        };
        let result = generate(&input);
        assert!(result.is_ok());
        let output = result.unwrap().to_string();
        assert!(output.contains("pub struct ApiErrorResponses"));
    }

    #[test]
    fn generate_private_visibility() {
        let input: DeriveInput = syn::parse_quote! {
            enum PrivateError {
                #[status(500)]
                Internal,
            }
        };
        let result = generate(&input);
        assert!(result.is_ok());
        let output = result.unwrap().to_string();
        assert!(output.contains("struct PrivateErrorResponses"));
        assert!(!output.contains("pub struct PrivateErrorResponses"));
    }

    #[test]
    fn status_code_parsing_various_codes() {
        let codes = [200_u16, 201, 400, 401, 403, 404, 409, 422, 500, 502, 503];
        for code in codes {
            let input: DeriveInput = syn::parse_quote! {
                enum Error {
                    #[status(#code)]
                    Test,
                }
            };
            let result = parse_status_attr(&first_variant(&input).attrs);
            assert!(result.is_ok(), "Should parse status code {}", code);
        }
    }
}