axum_enum_response/
lib.rs1#![warn(clippy::pedantic)]
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{
43 parse::Parse, parse_macro_input, Attribute, Data, DeriveInput, Error, Ident, LitStr, Meta, Result, Token, Type,
44};
45
46type TokenStream2 = proc_macro2::TokenStream;
47
48#[proc_macro_derive(EnumIntoResponse, attributes(status_code, body, key, from))]
49pub fn enum_into_response(input: TokenStream) -> TokenStream {
50 let input = parse_macro_input!(input as DeriveInput);
51 match impl_enum_into_response(input) {
52 Ok(tokens) => tokens,
53 Err(err) => err.into_compile_error().into(),
54 }
55}
56
57fn impl_enum_into_response(input: DeriveInput) -> Result<TokenStream> {
58 let enum_name = input.ident;
59 let Data::Enum(data_enum) = input.data else {
60 return Err(Error::new_spanned(
61 enum_name,
62 "You may only use 'EnumIntoResponse' on enums",
63 ));
64 };
65
66 let (match_branches, impls) = data_enum.variants.into_iter().map(|variant| {
67 let ident = &variant.ident;
68 let field_attributes = parse_field_attributes(&variant.fields)?;
69 let VariantAttributes { status_code, body } = parse_attributes(ident, &variant.attrs)?;
70
71 let match_branches = if let Some(FieldAttributes { key, from_ty }) = &field_attributes {
72 if from_ty.is_some() {
73 if let Some(key) = key {
74 quote! {
75 #enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, v.to_string())])).into_response())),
76 }
77 } else {
78 quote! {
79 #enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([("error", v.to_string())])).into_response())),
80 }
81 }
82 } else if let Some(key) = key {
83 quote! {
84 #enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, v)])).into_response())),
85 }
86 } else {
87 quote! {
88 #enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(v).into_response())),
89 }
90 }
91 } else if let Some(BodyAttribute { key, value }) = body {
92 let key = key.unwrap_or_else(|| "error".to_string());
93 quote! {
94 #enum_name::#ident => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, #value)])).into_response())),
95 }
96 } else {
97 quote! {
98 #enum_name::#ident => (::axum::http::StatusCode::#status_code, None),
99 }
100 };
101
102 Result::Ok((match_branches, if let Some(FieldAttributes { from_ty: Some(ty), .. }) = field_attributes {
103 Some(quote! {
104 impl From<#ty> for #enum_name {
105 fn from(value: #ty) -> Self {
106 Self::#ident(value)
107 }
108 }
109 })
110 } else {
111 None
112 }))
113 }).collect::<Result<(Vec<_>, Vec<_>)>>()?;
114
115 let output = quote! {
116 impl ::axum::response::IntoResponse for #enum_name {
117 fn into_response(self) -> ::axum::response::Response {
118 let (status_code, body): (::axum::http::StatusCode, Option<::axum::response::Response>) = match self {
119 #( #match_branches )*
120 };
121
122 let Some(body) = body else {
123 return status_code.into_response();
124 };
125
126 (status_code, body).into_response()
127 }
128 }
129
130 impl ::core::convert::From<#enum_name> for ::axum::response::Response {
131 fn from(value: #enum_name) -> ::axum::response::Response {
132 ::axum::response::IntoResponse::into_response(value)
133 }
134 }
135
136 #( #impls )*
137 };
138
139 Ok(output.into())
140}
141
142struct FieldAttributes {
143 key: Option<TokenStream2>,
144 from_ty: Option<Type>,
145}
146
147fn parse_field_attributes(fields: &syn::Fields) -> Result<Option<FieldAttributes>> {
148 let mut fields = fields.iter();
149 let Some(field) = fields.next() else {
150 return Ok(None);
151 };
152
153 if field.ident.is_some() {
154 return Err(syn::Error::new_spanned(
155 field,
156 "EnumIntoResponse only supports unnamed fields.",
157 ));
158 }
159
160 if let Some(field) = fields.next() {
161 return Err(syn::Error::new_spanned(
162 field,
163 "EnumIntoResponse only supports up to one unnamed field.",
164 ));
165 }
166
167 let mut key = None;
168 let mut from_ty = None;
169
170 for attribute in &field.attrs {
171 let Some(iden) = attribute.path().get_ident() else {
172 return Err(Error::new_spanned(attribute, "You must name attributes"));
173 };
174
175 match iden.to_string().as_str() {
176 "key" => {
177 if let Meta::List(list) = &attribute.meta {
178 let tokens = &list.tokens;
179 key = Some(quote! {
180 #tokens
181 });
182 } else {
183 return Err(Error::new_spanned(attribute, "'key' attribute value must be a string"));
184 }
185 }
186
187 "from" => {
188 from_ty = Some(field.ty.clone());
189 }
190
191 _ => {}
192 }
193 }
194
195 Ok(Some(FieldAttributes { key, from_ty }))
196}
197
198struct VariantAttributes {
199 status_code: TokenStream2,
200 body: Option<BodyAttribute>,
201}
202
203struct BodyAttribute {
204 key: Option<String>,
205 value: String,
206}
207
208impl Parse for BodyAttribute {
209 fn parse(input: syn::parse::ParseStream) -> Result<Self> {
210 let first = input.parse::<LitStr>()?;
211 let mut second: Option<LitStr> = None;
212
213 if input.peek(Token![=>]) {
214 input.parse::<Token![=>]>()?;
215 second = Some(input.parse::<LitStr>()?);
216 }
217
218 if let Some(value) = second {
219 Ok(Self {
220 key: Some(first.value()),
221 value: value.value(),
222 })
223 } else {
224 Ok(Self {
225 key: None,
226 value: first.value(),
227 })
228 }
229 }
230}
231
232fn parse_attributes(ident: &Ident, attributes: &Vec<Attribute>) -> Result<VariantAttributes> {
233 if attributes.is_empty() {
234 return Err(Error::new_spanned(
235 ident,
236 "You must specify the 'status_code' attribute",
237 ));
238 }
239
240 let mut status_code = None;
241 let mut body = None;
242
243 for attribute in attributes {
244 let Some(iden) = attribute.path().get_ident() else {
245 return Err(Error::new_spanned(ident, "You must name attributes"));
246 };
247
248 match iden.to_string().as_str() {
249 "status_code" => {
250 status_code = Some(attribute.meta.require_list()?.tokens.clone());
251 }
252
253 "body" => {
254 body = Some(attribute.meta.require_list()?.parse_args::<BodyAttribute>()?);
255 }
256
257 _ => {}
258 }
259 }
260
261 let Some(status_code) = status_code else {
262 return Err(Error::new_spanned(ident, "'status_code' attribute must be specified"));
263 };
264
265 Ok(VariantAttributes { status_code, body })
266}