axum_codec_macros/
lib.rs

// With none of the codec features enabled, the feature-gated code paths drop
// out and presumably leave some items and bindings unused; silence those two
// lints crate-wide instead of gating every affected item individually.
#![cfg_attr(
	not(any(
		feature = "aide",
		feature = "bincode",
		feature = "bitcode",
		feature = "serde",
		feature = "validator"
	)),
	allow(dead_code, unused_variables)
)]

12use proc_macro::TokenStream;
13use syn::parse::Parse;
14
15mod apply;
16mod attr_parsing;
17mod debug_handler;
18mod with_position;
19
20/// A utility macro for automatically deriving the correct traits
21/// depending on the enabled features.
22#[proc_macro_attribute]
23pub fn apply(
24	attr: proc_macro::TokenStream,
25	input: proc_macro::TokenStream,
26) -> proc_macro::TokenStream {
27	apply::apply(attr, input)
28}
29
30/// Generates better error messages when applied to handler functions.
31///
32/// For more information, see [`axum::debug_handler`](https://docs.rs/axum/latest/axum/attr.debug_handler.html).
33#[proc_macro_attribute]
34pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
35	#[cfg(not(debug_assertions))]
36	return input;
37
38	#[cfg(debug_assertions)]
39	return expand_attr_with(_attr, input, |attrs, item_fn| {
40		debug_handler::expand(attrs, item_fn, debug_handler::FunctionKind::Handler)
41	});
42}
43
44/// Generates better error messages when applied to middleware functions.
45///
46/// For more information, see [`axum::debug_middleware`](https://docs.rs/axum/latest/axum/attr.debug_middleware.html).
47#[proc_macro_attribute]
48pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream {
49	#[cfg(not(debug_assertions))]
50	return input;
51
52	#[cfg(debug_assertions)]
53	return expand_attr_with(_attr, input, |attrs, item_fn| {
54		debug_handler::expand(attrs, item_fn, debug_handler::FunctionKind::Middleware)
55	});
56}
57
58fn expand_attr_with<F, A, I, K>(attr: TokenStream, input: TokenStream, f: F) -> TokenStream
59where
60	F: FnOnce(A, I) -> K,
61	A: Parse,
62	I: Parse,
63	K: quote::ToTokens,
64{
65	let expand_result = (|| {
66		let attr = syn::parse(attr)?;
67		let input = syn::parse(input)?;
68		Ok(f(attr, input))
69	})();
70	expand(expand_result)
71}
72
73fn expand<T>(result: syn::Result<T>) -> TokenStream
74where
75	T: quote::ToTokens,
76{
77	match result {
78		Ok(tokens) => {
79			let tokens = (quote::quote! { #tokens }).into();
80			if std::env::var_os("AXUM_MACROS_DEBUG").is_some() {
81				eprintln!("{tokens}");
82			}
83			tokens
84		}
85		Err(err) => err.into_compile_error().into(),
86	}
87}
88
89fn infer_state_types<'a, I>(types: I) -> impl Iterator<Item = syn::Type> + 'a
90where
91	I: Iterator<Item = &'a syn::Type> + 'a,
92{
93	types
94		.filter_map(|ty| {
95			if let syn::Type::Path(path) = ty {
96				Some(&path.path)
97			} else {
98				None
99			}
100		})
101		.filter_map(|path| {
102			if let Some(last_segment) = path.segments.last() {
103				if last_segment.ident != "State" {
104					return None;
105				}
106
107				match &last_segment.arguments {
108					syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
109						Some(args.args.first().unwrap())
110					}
111					_ => None,
112				}
113			} else {
114				None
115			}
116		})
117		.filter_map(|generic_arg| {
118			if let syn::GenericArgument::Type(ty) = generic_arg {
119				Some(ty)
120			} else {
121				None
122			}
123		})
124		.cloned()
125}
126
127#[doc(hidden)]
128#[proc_macro]
129pub fn __private_decode_trait(input: TokenStream) -> TokenStream {
130	__private::decode_trait(input.into()).into()
131}
132
133#[doc(hidden)]
134#[proc_macro]
135pub fn __private_encode_trait(input: TokenStream) -> TokenStream {
136	__private::encode_trait(input.into()).into()
137}
138
139#[allow(unused_imports, unused_mut)]
140mod __private {
141	use proc_macro2::TokenStream;
142	use quote::quote;
143
144	pub fn decode_trait(input: TokenStream) -> TokenStream {
145		let mut codec_trait = TokenStream::default();
146		let mut codec_impl = TokenStream::default();
147
148		codec_trait.extend(quote! {
149			#input
150			#[diagnostic::on_unimplemented(
151				note = "If you're looking for a zero-copy extractor, use `BorrowCodec`"
152			)]
153			pub trait CodecDecode<'de>
154		});
155
156		codec_impl.extend(quote! {
157			impl<'de, T> CodecDecode<'de> for T
158		});
159
160		#[cfg(any(
161			feature = "bincode",
162			feature = "bitcode",
163			feature = "serde",
164			feature = "aide",
165			feature = "validator"
166		))]
167		{
168			codec_trait.extend(quote! {
169				:
170			});
171
172			codec_impl.extend(quote! {
173				where T:
174			});
175		}
176
177		let mut constraints = TokenStream::default();
178
179		#[cfg(feature = "serde")]
180		{
181			if !constraints.is_empty() {
182				constraints.extend(quote! { + });
183			}
184
185			constraints.extend(quote! {
186				serde::de::Deserialize<'de>
187			});
188		}
189
190		#[cfg(feature = "bincode")]
191		{
192			if !constraints.is_empty() {
193				constraints.extend(quote! { + });
194			}
195
196			constraints.extend(quote! {
197				bincode::BorrowDecode<'de>
198			});
199		}
200
201		#[cfg(feature = "bitcode")]
202		{
203			if !constraints.is_empty() {
204				constraints.extend(quote! { + });
205			}
206
207			constraints.extend(quote! {
208				bitcode::Decode<'de>
209			});
210		}
211
212		#[cfg(feature = "validator")]
213		{
214			if !constraints.is_empty() {
215				constraints.extend(quote! { + });
216			}
217
218			constraints.extend(quote! {
219				validator::Validate
220			});
221		}
222
223		codec_trait.extend(constraints.clone());
224		codec_impl.extend(constraints);
225
226		codec_trait.extend(quote!({}));
227		codec_impl.extend(quote!({}));
228
229		codec_trait.extend(codec_impl);
230		codec_trait
231	}
232
233	pub fn encode_trait(input: TokenStream) -> TokenStream {
234		let mut codec_trait = TokenStream::default();
235		let mut codec_impl = TokenStream::default();
236
237		codec_trait.extend(quote! {
238			#input
239			pub trait CodecEncode
240		});
241
242		codec_impl.extend(quote! {
243			impl<T> CodecEncode for T
244		});
245
246		#[cfg(any(
247			feature = "bincode",
248			feature = "bitcode",
249			feature = "serde",
250			feature = "aide",
251			feature = "validator"
252		))]
253		{
254			codec_trait.extend(quote! {
255				:
256			});
257
258			codec_impl.extend(quote! {
259				where T:
260			});
261		}
262
263		let mut constraints = TokenStream::default();
264
265		#[cfg(feature = "serde")]
266		{
267			if !constraints.is_empty() {
268				constraints.extend(quote! { + });
269			}
270
271			constraints.extend(quote! {
272				serde::Serialize
273			});
274		}
275
276		#[cfg(feature = "bincode")]
277		{
278			if !constraints.is_empty() {
279				constraints.extend(quote! { + });
280			}
281
282			constraints.extend(quote! {
283				bincode::Encode
284			});
285		}
286
287		#[cfg(feature = "bitcode")]
288		{
289			if !constraints.is_empty() {
290				constraints.extend(quote! { + });
291			}
292
293			constraints.extend(quote! {
294				bitcode::Encode
295			});
296		}
297
298		codec_trait.extend(constraints.clone());
299		codec_impl.extend(constraints);
300
301		codec_trait.extend(quote!({}));
302		codec_impl.extend(quote!({}));
303
304		codec_trait.extend(codec_impl);
305		codec_trait
306	}
307}