axum_accept_macros/
lib.rs

1//! The proc-macro crate of axum-accept.
2#![deny(warnings)]
3#![deny(clippy::pedantic, clippy::unwrap_used)]
4#![deny(missing_docs)]
5extern crate proc_macro;
6
7use std::collections::HashMap;
8
9use mediatype::MediaTypeBuf;
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{
13    Attribute, Data, DeriveInput, Fields, GenericParam, Ident, Lit, Meta, TypeParam,
14    TypeParamBound, parse_macro_input,
15};
16
17/// This is the proc macro for `AcceptExtractor`.
18///
19/// # Panics
20///
21/// If it fails to parse the attributes.
22#[proc_macro_derive(AcceptExtractor, attributes(accept))]
23pub fn derive_accept_extractor(input: TokenStream) -> TokenStream {
24    let input = parse_macro_input!(input as DeriveInput);
25
26    let name = &input.ident;
27
28    let (_, ty_generics, where_clause) = input.generics.split_for_impl();
29
30    let mut generics = input.generics.clone();
31
32    // we need to add <S: Send + Sync> to the impl generics for FromRequestParts
33    let s_param = GenericParam::Type(TypeParam {
34        attrs: vec![],
35        ident: Ident::new("S", proc_macro2::Span::call_site()),
36        colon_token: Some(syn::token::Colon::default()),
37        bounds: {
38            let mut bounds = syn::punctuated::Punctuated::new();
39            bounds.push(TypeParamBound::Trait(syn::TraitBound {
40                paren_token: None,
41                modifier: syn::TraitBoundModifier::None,
42                lifetimes: None,
43                path: syn::parse_str("Send").unwrap(),
44            }));
45            bounds.push(TypeParamBound::Trait(syn::TraitBound {
46                paren_token: None,
47                modifier: syn::TraitBoundModifier::None,
48                lifetimes: None,
49                path: syn::parse_str("Sync").unwrap(),
50            }));
51            bounds
52        },
53        eq_token: None,
54        default: None,
55    });
56    generics.params.push(s_param);
57
58    let (impl_generics, _, _) = generics.split_for_impl();
59
60    let Data::Enum(data) = &input.data else {
61        panic!("AcceptExtractor can only be derived for enums");
62    };
63
64    let has_default = data.variants.iter().any(|variant| {
65        variant.attrs.iter().any(|attr| match &attr.meta {
66            Meta::Path(path) => path.is_ident("default"),
67            _ => false,
68        })
69    });
70
71    // Match arms with ty, subty and suffix
72    let mut match_arms = Vec::new();
73    // Match arms with ty only (for checking mediatypes like text/*)
74    let mut match_arms_tys = HashMap::new();
75    // Store first variant to fall back to if we don't have a default.
76    let mut first_variant_name = None;
77
78    for variant in &data.variants {
79        let variant_name = &variant.ident;
80        let mediatype_raw = get_accept_mediatype(&variant.attrs);
81        let mediatype = MediaTypeBuf::from_string(mediatype_raw.clone()) // compile time so clone is fine
82            .expect("Failed to parse mediatype");
83        let (ty, subty, suffix) = (
84            mediatype.ty().as_str(),
85            mediatype.subty().as_str(),
86            mediatype.suffix().map(|s| s.as_str()),
87        );
88
89        if ty == "*" || subty == "*" {
90            panic!("Please use a concrete mediatype");
91        }
92
93        if first_variant_name.is_none() {
94            first_variant_name = Some(variant_name.clone());
95        }
96
97        match_arms_tys.insert(ty.to_string(), variant_name);
98
99        match &variant.fields {
100            Fields::Unit => {
101                // quote encodes None to empty string, so we need to take extra
102                // steps
103                if let Some(suffix) = suffix {
104                    match_arms.push(quote! {
105                        (#ty, #subty, Some(#suffix)) => return Ok(#name::#variant_name),
106                    });
107                } else {
108                    match_arms.push(quote! {
109                        (#ty, #subty, None) => return Ok(#name::#variant_name),
110                    });
111                }
112            }
113            _ => panic!("Only unit fields are supported"),
114        }
115    }
116
117    let check_and_return_default = if has_default {
118        Some(quote! {
119            if mediatypes.is_empty() {
120                return Ok(#name::default());
121            }
122        })
123    } else {
124        None
125    };
126
127    let handle_star_star = if has_default {
128        quote! {
129            return Ok(#name::default());
130        }
131    } else {
132        quote! {
133            return Ok(#name::#first_variant_name);
134        }
135    };
136
137    let match_arms_tys = match_arms_tys.iter().map(|(ty, variant_name)| {
138        quote! {
139            (#ty) => return Ok(#name::#variant_name),
140        }
141    });
142
143    let expanded = quote! {
144        impl #impl_generics axum::extract::FromRequestParts<S> for #name #ty_generics #where_clause {
145            type Rejection = axum_accept::AcceptRejection;
146
147            async fn from_request_parts(parts: &mut axum::http::request::Parts, _state: &S) -> Result<Self, Self::Rejection> {
148                let mediatypes = axum_accept::parse_mediatypes(&parts.headers)?;
149                #check_and_return_default
150                for mt in mediatypes {
151                    match (mt.ty.as_str(), mt.subty.as_str()) {
152                        ("*", "*") => {
153                            // return either the the default or the first
154                            // variant
155                            #handle_star_star
156                        },
157                        // do we have any mediatype that shares the main type?
158                        // e.g. we offer text/plain and get accept: text/*
159                        (_, "*") => match (mt.ty.as_str()) {
160                            #(#match_arms_tys)*
161                            _ => {} // continue searching
162                        },
163                        // do proper matching
164                        _ =>  match (mt.ty.as_str(), mt.subty.as_str(), mt.suffix.map(|s| s.as_str())) {
165                            #(#match_arms)*
166                            _ => {} // continue searching
167                        },
168                    }
169                }
170
171                Err(axum_accept::AcceptRejection::NoSupportedMediaTypeFound)
172            }
173        }
174    };
175
176    TokenStream::from(expanded)
177}
178
179fn get_accept_mediatype(attrs: &[Attribute]) -> String {
180    for attr in attrs {
181        if attr.path().is_ident("accept") {
182            if let Meta::List(meta_list) = &attr.meta {
183                for nested in meta_list
184                    .parse_args_with(
185                        syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
186                    )
187                    .expect("Failed to parse args")
188                {
189                    if let syn::Meta::NameValue(name_value) = nested {
190                        if name_value.path.is_ident("mediatype") {
191                            if let syn::Expr::Lit(expr_lit) = &name_value.value {
192                                if let Lit::Str(lit_str) = &expr_lit.lit {
193                                    return lit_str.value();
194                                }
195                            }
196                        }
197                    }
198                }
199            }
200        }
201    }
202
203    panic!(r#"Missing #[accept(mediatype = "...")]"#)
204}