axum_accept_macros/
lib.rs1#![deny(warnings)]
3#![deny(clippy::pedantic, clippy::unwrap_used)]
4#![deny(missing_docs)]
5extern crate proc_macro;
6
7use std::collections::HashMap;
8
9use mediatype::MediaTypeBuf;
10use proc_macro::TokenStream;
11use quote::quote;
12use syn::{
13 Attribute, Data, DeriveInput, Fields, GenericParam, Ident, Lit, Meta, TypeParam,
14 TypeParamBound, parse_macro_input,
15};
16
17#[proc_macro_derive(AcceptExtractor, attributes(accept))]
23pub fn derive_accept_extractor(input: TokenStream) -> TokenStream {
24 let input = parse_macro_input!(input as DeriveInput);
25
26 let name = &input.ident;
27
28 let (_, ty_generics, where_clause) = input.generics.split_for_impl();
29
30 let mut generics = input.generics.clone();
31
32 let s_param = GenericParam::Type(TypeParam {
34 attrs: vec![],
35 ident: Ident::new("S", proc_macro2::Span::call_site()),
36 colon_token: Some(syn::token::Colon::default()),
37 bounds: {
38 let mut bounds = syn::punctuated::Punctuated::new();
39 bounds.push(TypeParamBound::Trait(syn::TraitBound {
40 paren_token: None,
41 modifier: syn::TraitBoundModifier::None,
42 lifetimes: None,
43 path: syn::parse_str("Send").unwrap(),
44 }));
45 bounds.push(TypeParamBound::Trait(syn::TraitBound {
46 paren_token: None,
47 modifier: syn::TraitBoundModifier::None,
48 lifetimes: None,
49 path: syn::parse_str("Sync").unwrap(),
50 }));
51 bounds
52 },
53 eq_token: None,
54 default: None,
55 });
56 generics.params.push(s_param);
57
58 let (impl_generics, _, _) = generics.split_for_impl();
59
60 let Data::Enum(data) = &input.data else {
61 panic!("AcceptExtractor can only be derived for enums");
62 };
63
64 let has_default = data.variants.iter().any(|variant| {
65 variant.attrs.iter().any(|attr| match &attr.meta {
66 Meta::Path(path) => path.is_ident("default"),
67 _ => false,
68 })
69 });
70
71 let mut match_arms = Vec::new();
73 let mut match_arms_tys = HashMap::new();
75 let mut first_variant_name = None;
77
78 for variant in &data.variants {
79 let variant_name = &variant.ident;
80 let mediatype_raw = get_accept_mediatype(&variant.attrs);
81 let mediatype = MediaTypeBuf::from_string(mediatype_raw.clone()) .expect("Failed to parse mediatype");
83 let (ty, subty, suffix) = (
84 mediatype.ty().as_str(),
85 mediatype.subty().as_str(),
86 mediatype.suffix().map(|s| s.as_str()),
87 );
88
89 if ty == "*" || subty == "*" {
90 panic!("Please use a concrete mediatype");
91 }
92
93 if first_variant_name.is_none() {
94 first_variant_name = Some(variant_name.clone());
95 }
96
97 match_arms_tys.insert(ty.to_string(), variant_name);
98
99 match &variant.fields {
100 Fields::Unit => {
101 if let Some(suffix) = suffix {
104 match_arms.push(quote! {
105 (#ty, #subty, Some(#suffix)) => return Ok(#name::#variant_name),
106 });
107 } else {
108 match_arms.push(quote! {
109 (#ty, #subty, None) => return Ok(#name::#variant_name),
110 });
111 }
112 }
113 _ => panic!("Only unit fields are supported"),
114 }
115 }
116
117 let check_and_return_default = if has_default {
118 Some(quote! {
119 if mediatypes.is_empty() {
120 return Ok(#name::default());
121 }
122 })
123 } else {
124 None
125 };
126
127 let handle_star_star = if has_default {
128 quote! {
129 return Ok(#name::default());
130 }
131 } else {
132 quote! {
133 return Ok(#name::#first_variant_name);
134 }
135 };
136
137 let match_arms_tys = match_arms_tys.iter().map(|(ty, variant_name)| {
138 quote! {
139 (#ty) => return Ok(#name::#variant_name),
140 }
141 });
142
143 let expanded = quote! {
144 impl #impl_generics axum::extract::FromRequestParts<S> for #name #ty_generics #where_clause {
145 type Rejection = axum_accept::AcceptRejection;
146
147 async fn from_request_parts(parts: &mut axum::http::request::Parts, _state: &S) -> Result<Self, Self::Rejection> {
148 let mediatypes = axum_accept::parse_mediatypes(&parts.headers)?;
149 #check_and_return_default
150 for mt in mediatypes {
151 match (mt.ty.as_str(), mt.subty.as_str()) {
152 ("*", "*") => {
153 #handle_star_star
156 },
157 (_, "*") => match (mt.ty.as_str()) {
160 #(#match_arms_tys)*
161 _ => {} },
163 _ => match (mt.ty.as_str(), mt.subty.as_str(), mt.suffix.map(|s| s.as_str())) {
165 #(#match_arms)*
166 _ => {} },
168 }
169 }
170
171 Err(axum_accept::AcceptRejection::NoSupportedMediaTypeFound)
172 }
173 }
174 };
175
176 TokenStream::from(expanded)
177}
178
179fn get_accept_mediatype(attrs: &[Attribute]) -> String {
180 for attr in attrs {
181 if attr.path().is_ident("accept") {
182 if let Meta::List(meta_list) = &attr.meta {
183 for nested in meta_list
184 .parse_args_with(
185 syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated,
186 )
187 .expect("Failed to parse args")
188 {
189 if let syn::Meta::NameValue(name_value) = nested {
190 if name_value.path.is_ident("mediatype") {
191 if let syn::Expr::Lit(expr_lit) = &name_value.value {
192 if let Lit::Str(lit_str) = &expr_lit.lit {
193 return lit_str.value();
194 }
195 }
196 }
197 }
198 }
199 }
200 }
201 }
202
203 panic!(r#"Missing #[accept(mediatype = "...")]"#)
204}