1#![cfg_attr(
2 not(any(
3 feature = "bincode",
4 feature = "bitcode",
5 feature = "serde",
6 feature = "aide",
7 feature = "validator"
8 )),
9 allow(unused_variables, dead_code)
10)]
11
12use proc_macro::TokenStream;
13use syn::parse::Parse;
14
15mod apply;
16mod attr_parsing;
17mod debug_handler;
18mod with_position;
19
20#[proc_macro_attribute]
23pub fn apply(
24 attr: proc_macro::TokenStream,
25 input: proc_macro::TokenStream,
26) -> proc_macro::TokenStream {
27 apply::apply(attr, input)
28}
29
30#[proc_macro_attribute]
34pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream {
35 #[cfg(not(debug_assertions))]
36 return input;
37
38 #[cfg(debug_assertions)]
39 return expand_attr_with(_attr, input, |attrs, item_fn| {
40 debug_handler::expand(attrs, item_fn, debug_handler::FunctionKind::Handler)
41 });
42}
43
44#[proc_macro_attribute]
48pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream {
49 #[cfg(not(debug_assertions))]
50 return input;
51
52 #[cfg(debug_assertions)]
53 return expand_attr_with(_attr, input, |attrs, item_fn| {
54 debug_handler::expand(attrs, item_fn, debug_handler::FunctionKind::Middleware)
55 });
56}
57
58fn expand_attr_with<F, A, I, K>(attr: TokenStream, input: TokenStream, f: F) -> TokenStream
59where
60 F: FnOnce(A, I) -> K,
61 A: Parse,
62 I: Parse,
63 K: quote::ToTokens,
64{
65 let expand_result = (|| {
66 let attr = syn::parse(attr)?;
67 let input = syn::parse(input)?;
68 Ok(f(attr, input))
69 })();
70 expand(expand_result)
71}
72
73fn expand<T>(result: syn::Result<T>) -> TokenStream
74where
75 T: quote::ToTokens,
76{
77 match result {
78 Ok(tokens) => {
79 let tokens = (quote::quote! { #tokens }).into();
80 if std::env::var_os("AXUM_MACROS_DEBUG").is_some() {
81 eprintln!("{tokens}");
82 }
83 tokens
84 }
85 Err(err) => err.into_compile_error().into(),
86 }
87}
88
89fn infer_state_types<'a, I>(types: I) -> impl Iterator<Item = syn::Type> + 'a
90where
91 I: Iterator<Item = &'a syn::Type> + 'a,
92{
93 types
94 .filter_map(|ty| {
95 if let syn::Type::Path(path) = ty {
96 Some(&path.path)
97 } else {
98 None
99 }
100 })
101 .filter_map(|path| {
102 if let Some(last_segment) = path.segments.last() {
103 if last_segment.ident != "State" {
104 return None;
105 }
106
107 match &last_segment.arguments {
108 syn::PathArguments::AngleBracketed(args) if args.args.len() == 1 => {
109 Some(args.args.first().unwrap())
110 }
111 _ => None,
112 }
113 } else {
114 None
115 }
116 })
117 .filter_map(|generic_arg| {
118 if let syn::GenericArgument::Type(ty) = generic_arg {
119 Some(ty)
120 } else {
121 None
122 }
123 })
124 .cloned()
125}
126
127#[doc(hidden)]
128#[proc_macro]
129pub fn __private_decode_trait(input: TokenStream) -> TokenStream {
130 __private::decode_trait(input.into()).into()
131}
132
133#[doc(hidden)]
134#[proc_macro]
135pub fn __private_encode_trait(input: TokenStream) -> TokenStream {
136 __private::encode_trait(input.into()).into()
137}
138
139#[allow(unused_imports, unused_mut)]
140mod __private {
141 use proc_macro2::TokenStream;
142 use quote::quote;
143
144 pub fn decode_trait(input: TokenStream) -> TokenStream {
145 let mut codec_trait = TokenStream::default();
146 let mut codec_impl = TokenStream::default();
147
148 codec_trait.extend(quote! {
149 #input
150 #[diagnostic::on_unimplemented(
151 note = "If you're looking for a zero-copy extractor, use `BorrowCodec`"
152 )]
153 pub trait CodecDecode<'de>
154 });
155
156 codec_impl.extend(quote! {
157 impl<'de, T> CodecDecode<'de> for T
158 });
159
160 #[cfg(any(
161 feature = "bincode",
162 feature = "bitcode",
163 feature = "serde",
164 feature = "aide",
165 feature = "validator"
166 ))]
167 {
168 codec_trait.extend(quote! {
169 :
170 });
171
172 codec_impl.extend(quote! {
173 where T:
174 });
175 }
176
177 let mut constraints = TokenStream::default();
178
179 #[cfg(feature = "serde")]
180 {
181 if !constraints.is_empty() {
182 constraints.extend(quote! { + });
183 }
184
185 constraints.extend(quote! {
186 serde::de::Deserialize<'de>
187 });
188 }
189
190 #[cfg(feature = "bincode")]
191 {
192 if !constraints.is_empty() {
193 constraints.extend(quote! { + });
194 }
195
196 constraints.extend(quote! {
197 bincode::BorrowDecode<'de>
198 });
199 }
200
201 #[cfg(feature = "bitcode")]
202 {
203 if !constraints.is_empty() {
204 constraints.extend(quote! { + });
205 }
206
207 constraints.extend(quote! {
208 bitcode::Decode<'de>
209 });
210 }
211
212 #[cfg(feature = "validator")]
213 {
214 if !constraints.is_empty() {
215 constraints.extend(quote! { + });
216 }
217
218 constraints.extend(quote! {
219 validator::Validate
220 });
221 }
222
223 codec_trait.extend(constraints.clone());
224 codec_impl.extend(constraints);
225
226 codec_trait.extend(quote!({}));
227 codec_impl.extend(quote!({}));
228
229 codec_trait.extend(codec_impl);
230 codec_trait
231 }
232
233 pub fn encode_trait(input: TokenStream) -> TokenStream {
234 let mut codec_trait = TokenStream::default();
235 let mut codec_impl = TokenStream::default();
236
237 codec_trait.extend(quote! {
238 #input
239 pub trait CodecEncode
240 });
241
242 codec_impl.extend(quote! {
243 impl<T> CodecEncode for T
244 });
245
246 #[cfg(any(
247 feature = "bincode",
248 feature = "bitcode",
249 feature = "serde",
250 feature = "aide",
251 feature = "validator"
252 ))]
253 {
254 codec_trait.extend(quote! {
255 :
256 });
257
258 codec_impl.extend(quote! {
259 where T:
260 });
261 }
262
263 let mut constraints = TokenStream::default();
264
265 #[cfg(feature = "serde")]
266 {
267 if !constraints.is_empty() {
268 constraints.extend(quote! { + });
269 }
270
271 constraints.extend(quote! {
272 serde::Serialize
273 });
274 }
275
276 #[cfg(feature = "bincode")]
277 {
278 if !constraints.is_empty() {
279 constraints.extend(quote! { + });
280 }
281
282 constraints.extend(quote! {
283 bincode::Encode
284 });
285 }
286
287 #[cfg(feature = "bitcode")]
288 {
289 if !constraints.is_empty() {
290 constraints.extend(quote! { + });
291 }
292
293 constraints.extend(quote! {
294 bitcode::Encode
295 });
296 }
297
298 codec_trait.extend(constraints.clone());
299 codec_impl.extend(constraints);
300
301 codec_trait.extend(quote!({}));
302 codec_impl.extend(quote!({}));
303
304 codec_trait.extend(codec_impl);
305 codec_trait
306 }
307}