axum_enum_response/
lib.rs

//! Easily create `axum::http::Response`s from enums!
2//! MSRV: 1.65.0
3//!
4//! # Example Usage
//! ```
//! use std::string::FromUtf8Error;
//!
//! #[derive(axum_enum_response::EnumIntoResponse)]
//! enum ErrorResponse {
//!     #[status_code(UNAUTHORIZED)]
//!     Unauthorized, // 401, empty body
//!     #[status_code(OK)]
//!     #[body("hello"=>"world")]
//!     Ok, // 200, body = {"hello": "world"}
//!     #[status_code(FORBIDDEN)]
//!     #[body("mew")]
//!     Forbidden, // 403, body = {"error": "mew"}
//!     #[status_code(INTERNAL_SERVER_ERROR)]
//!     FromUtf8Error(#[from] FromUtf8Error), // 500, body = {"error": FromUtf8Error::to_string()}
//!     #[status_code(INTERNAL_SERVER_ERROR)]
//!     InternalServerError(#[key("awwa")] String), // 500, body = {"awwa": STRING}
//! }
//! ```
22//!
23//! You can also use any struct that implements `serde::Serialize` as a field like this:
24//! ```no_run
25//! #[derive(serde::Serialize)]
26//! struct SomeData {
27//!     meow: String,
28//! }
29//!
30//! #[derive(axum_enum_response::EnumIntoResponse)]
31//! enum ErrorResponse {
32//!     #[status_code(BAD_REQUEST)]
33//!     BadRequest(SomeData), // 400, body = {"meow": STRING}
34//! }
35//! ```
36//!
37
38#![warn(clippy::pedantic)]
39
40use proc_macro::TokenStream;
41use quote::quote;
42use syn::{
43	parse::Parse, parse_macro_input, Attribute, Data, DeriveInput, Error, Ident, LitStr, Meta, Result, Token, Type,
44};
45
46type TokenStream2 = proc_macro2::TokenStream;
47
48#[proc_macro_derive(EnumIntoResponse, attributes(status_code, body, key, from))]
49pub fn enum_into_response(input: TokenStream) -> TokenStream {
50	let input = parse_macro_input!(input as DeriveInput);
51	match impl_enum_into_response(input) {
52		Ok(tokens) => tokens,
53		Err(err) => err.into_compile_error().into(),
54	}
55}
56
57fn impl_enum_into_response(input: DeriveInput) -> Result<TokenStream> {
58	let enum_name = input.ident;
59	let Data::Enum(data_enum) = input.data else {
60		return Err(Error::new_spanned(
61			enum_name,
62			"You may only use 'EnumIntoResponse' on enums",
63		));
64	};
65
66	let (match_branches, impls) = data_enum.variants.into_iter().map(|variant| {
67		let ident = &variant.ident;
68		let field_attributes = parse_field_attributes(&variant.fields)?;
69		let VariantAttributes { status_code, body } = parse_attributes(ident, &variant.attrs)?;
70
71		let match_branches = if let Some(FieldAttributes { key, from_ty }) = &field_attributes {
72			if from_ty.is_some() {
73				if let Some(key) = key {
74					quote! {
75						#enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, v.to_string())])).into_response())),
76					}
77				} else {
78					quote! {
79						#enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([("error", v.to_string())])).into_response())),
80					}
81				}
82			} else if let Some(key) = key {
83				quote! {
84					#enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, v)])).into_response())),
85				}
86			} else {
87				quote! {
88					#enum_name::#ident(v) => (::axum::http::StatusCode::#status_code, Some(::axum::Json(v).into_response())),
89				}
90			}
91		} else if let Some(BodyAttribute { key, value }) = body {
92			let key = key.unwrap_or_else(|| "error".to_string());
93			quote! {
94				#enum_name::#ident => (::axum::http::StatusCode::#status_code, Some(::axum::Json(::std::collections::HashMap::from([(#key, #value)])).into_response())),
95			}
96		} else {
97			quote! {
98				#enum_name::#ident => (::axum::http::StatusCode::#status_code, None),
99			}
100		};
101
102		Result::Ok((match_branches, if let Some(FieldAttributes { from_ty: Some(ty), .. }) = field_attributes {
103			Some(quote! {
104			impl From<#ty> for #enum_name {
105				fn from(value: #ty) -> Self {
106					Self::#ident(value)
107				}
108			}
109			})
110		} else {
111			None
112		}))
113	}).collect::<Result<(Vec<_>, Vec<_>)>>()?;
114
115	let output = quote! {
116		impl ::axum::response::IntoResponse for #enum_name {
117			fn into_response(self) -> ::axum::response::Response {
118				let (status_code, body): (::axum::http::StatusCode, Option<::axum::response::Response>) = match self {
119					#( #match_branches )*
120				};
121
122				let Some(body) = body else {
123					return status_code.into_response();
124				};
125
126				(status_code, body).into_response()
127			}
128		}
129
130		impl ::core::convert::From<#enum_name> for ::axum::response::Response {
131			fn from(value: #enum_name) -> ::axum::response::Response {
132				::axum::response::IntoResponse::into_response(value)
133			}
134		}
135
136		#( #impls )*
137	};
138
139	Ok(output.into())
140}
141
142struct FieldAttributes {
143	key: Option<TokenStream2>,
144	from_ty: Option<Type>,
145}
146
147fn parse_field_attributes(fields: &syn::Fields) -> Result<Option<FieldAttributes>> {
148	let mut fields = fields.iter();
149	let Some(field) = fields.next() else {
150		return Ok(None);
151	};
152
153	if field.ident.is_some() {
154		return Err(syn::Error::new_spanned(
155			field,
156			"EnumIntoResponse only supports unnamed fields.",
157		));
158	}
159
160	if let Some(field) = fields.next() {
161		return Err(syn::Error::new_spanned(
162			field,
163			"EnumIntoResponse only supports up to one unnamed field.",
164		));
165	}
166
167	let mut key = None;
168	let mut from_ty = None;
169
170	for attribute in &field.attrs {
171		let Some(iden) = attribute.path().get_ident() else {
172			return Err(Error::new_spanned(attribute, "You must name attributes"));
173		};
174
175		match iden.to_string().as_str() {
176			"key" => {
177				if let Meta::List(list) = &attribute.meta {
178					let tokens = &list.tokens;
179					key = Some(quote! {
180						#tokens
181					});
182				} else {
183					return Err(Error::new_spanned(attribute, "'key' attribute value must be a string"));
184				}
185			}
186
187			"from" => {
188				from_ty = Some(field.ty.clone());
189			}
190
191			_ => {}
192		}
193	}
194
195	Ok(Some(FieldAttributes { key, from_ty }))
196}
197
198struct VariantAttributes {
199	status_code: TokenStream2,
200	body: Option<BodyAttribute>,
201}
202
/// Parsed arguments of `#[body(...)]`: either `"value"` or `"key" => "value"`.
struct BodyAttribute {
	/// The key when written as `"key" => "value"`; `None` for the short form.
	key: Option<String>,
	/// The string value placed in the JSON body.
	value: String,
}
207
208impl Parse for BodyAttribute {
209	fn parse(input: syn::parse::ParseStream) -> Result<Self> {
210		let first = input.parse::<LitStr>()?;
211		let mut second: Option<LitStr> = None;
212
213		if input.peek(Token![=>]) {
214			input.parse::<Token![=>]>()?;
215			second = Some(input.parse::<LitStr>()?);
216		}
217
218		if let Some(value) = second {
219			Ok(Self {
220				key: Some(first.value()),
221				value: value.value(),
222			})
223		} else {
224			Ok(Self {
225				key: None,
226				value: first.value(),
227			})
228		}
229	}
230}
231
232fn parse_attributes(ident: &Ident, attributes: &Vec<Attribute>) -> Result<VariantAttributes> {
233	if attributes.is_empty() {
234		return Err(Error::new_spanned(
235			ident,
236			"You must specify the 'status_code' attribute",
237		));
238	}
239
240	let mut status_code = None;
241	let mut body = None;
242
243	for attribute in attributes {
244		let Some(iden) = attribute.path().get_ident() else {
245			return Err(Error::new_spanned(ident, "You must name attributes"));
246		};
247
248		match iden.to_string().as_str() {
249			"status_code" => {
250				status_code = Some(attribute.meta.require_list()?.tokens.clone());
251			}
252
253			"body" => {
254				body = Some(attribute.meta.require_list()?.parse_args::<BodyAttribute>()?);
255			}
256
257			_ => {}
258		}
259	}
260
261	let Some(status_code) = status_code else {
262		return Err(Error::new_spanned(ident, "'status_code' attribute must be specified"));
263	};
264
265	Ok(VariantAttributes { status_code, body })
266}