Skip to main content

desert_framework_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, FnArg, ImplItem, ImplItemFn, ItemImpl, ItemStruct, Meta, Pat, PatType,
7    Token, Type,
8};
9
10fn parse_controller_path(attr: TokenStream) -> String {
11    if attr.is_empty() {
12        return String::new();
13    }
14    let meta: Meta = syn::parse(attr).expect("expected `path = \"...\"`");
15    match meta {
16        Meta::NameValue(nv) if nv.path.is_ident("path") => match &nv.value {
17            syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
18                syn::Lit::Str(s) => s.value(),
19                _ => panic!("expected string literal for path"),
20            },
21            _ => panic!("expected string literal for path"),
22        },
23        _ => panic!("expected `path = \"...\"`"),
24    }
25}
26
27fn route_path_from_attr(attr: TokenStream) -> String {
28    let s: syn::LitStr = syn::parse(attr).expect("expected path string like `\"/path\"`");
29    s.value()
30}
31
/// Maps a lowercase HTTP verb name to the numeric method code stored in route metadata.
///
/// `"get"` → 0, `"post"` → 1, `"put"` → 2, `"delete"` → 3, `"patch"` → 4.
/// Any other string (including other casings such as `"GET"`) yields `255`.
fn method_code(http: &str) -> u8 {
    const VERBS: [&str; 5] = ["get", "post", "put", "delete", "patch"];
    VERBS
        .iter()
        .position(|verb| *verb == http)
        // The index is at most 4, so the narrowing cast cannot truncate.
        .map_or(255, |index| index as u8)
}
42
43fn code_to_ident(code: u8) -> TokenStream2 {
44    match code {
45        0 => quote! { ::axum::routing::get },
46        1 => quote! { ::axum::routing::post },
47        2 => quote! { ::axum::routing::put },
48        3 => quote! { ::axum::routing::delete },
49        4 => quote! { ::axum::routing::patch },
50        _ => unreachable!(),
51    }
52}
53
54fn extract_route_info(attr: &syn::Attribute) -> (String, String) {
55    let method_name = attr.path().segments.last().unwrap().ident.to_string();
56
57    let path = match &attr.meta {
58        Meta::List(meta_list) => {
59            let lit: syn::LitStr =
60                syn::parse2(meta_list.tokens.clone()).expect("expected path string");
61            lit.value()
62        }
63        _ => panic!("expected #[method(\"path\")]"),
64    };
65
66    (method_name, path)
67}
68
69fn is_route_attr(attr: &syn::Attribute) -> bool {
70    let ident = attr.path().segments.last().unwrap().ident.to_string();
71    matches!(ident.as_str(), "get" | "post" | "put" | "delete" | "patch")
72}
73
74/// #[controller(path = "/api")] on struct
75fn controller_on_struct(path: String, s: ItemStruct) -> TokenStream {
76    let name = &s.ident;
77
78    quote! {
79        #s
80        impl #name { pub const __CONTROLLER_PATH: &str = #path; }
81        impl ::desert_framework::ControllerRoutes for #name {
82            const CONTROLLER_PATH: &'static str = #path;
83        }
84    }
85    .into()
86}
87
88/// #[controller] on impl block — discovers route methods automatically
89fn controller_on_impl(impl_block: ItemImpl) -> TokenStream {
90    if impl_block.trait_.is_some() {
91        panic!("#[controller] on impl block is only supported for bare impls (not trait impls)");
92    }
93
94    let self_type = &impl_block.self_ty;
95
96    let type_name = match self_type.as_ref() {
97        Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
98        _ => panic!("#[controller] on impl block requires a named type"),
99    };
100
101    let mut cleaned_methods: Vec<TokenStream2> = Vec::new();
102    let mut factory_fns: Vec<TokenStream2> = Vec::new();
103    let mut inventory_submits: Vec<TokenStream2> = Vec::new();
104
105    for item in &impl_block.items {
106        if let ImplItem::Fn(method) = item {
107            let route_attr = method.attrs.iter().find(|a| is_route_attr(a));
108
109            if let Some(attr) = route_attr {
110                let (http_method, route_path) = extract_route_info(attr);
111                let code = method_code(&http_method);
112                let name = &method.sig.ident;
113                let is_async = method.sig.asyncness.is_some();
114                let router_fn = code_to_ident(code);
115
116                let extra: Vec<&FnArg> = method
117                    .sig
118                    .inputs
119                    .iter()
120                    .filter(|a| !matches!(a, FnArg::Receiver(_)))
121                    .collect();
122
123                let pats: Vec<&Pat> = extra
124                    .iter()
125                    .map(|a| match a {
126                        FnArg::Typed(PatType { pat, .. }) => pat.as_ref(),
127                        _ => unreachable!(),
128                    })
129                    .collect();
130
131                let tys: Vec<&Type> = extra
132                    .iter()
133                    .map(|a| match a {
134                        FnArg::Typed(PatType { ty, .. }) => ty.as_ref(),
135                        _ => unreachable!(),
136                    })
137                    .collect();
138
139                let closure = if extra.is_empty() {
140                    if is_async {
141                        quote! { move || async move { state.#name().await } }
142                    } else {
143                        quote! { move || { state.#name() } }
144                    }
145                } else if is_async {
146                    quote! {
147                        move |#(#pats: #tys),*| async move {
148                            state.#name(#(#pats),*).await
149                        }
150                    }
151                } else {
152                    quote! {
153                        move |#(#pats: #tys),*| {
154                            state.#name(#(#pats),*)
155                        }
156                    }
157                };
158
159                let factory_name = syn::Ident::new(&format!("__make_route_{}", name), name.span());
160
161                // Cleaned method (without route attribute)
162                let non_route_attrs: Vec<_> =
163                    method.attrs.iter().filter(|a| !is_route_attr(a)).collect();
164                let vis = &method.vis;
165                let sig = &method.sig;
166                let block = &method.block;
167
168                cleaned_methods.push(quote! {
169                    #(#non_route_attrs)*
170                    #vis #sig #block
171                });
172
173                // Factory function (type-erased)
174                factory_fns.push(quote! {
175                    fn #factory_name(
176                        state: ::std::sync::Arc<dyn ::std::any::Any + Send + Sync>,
177                    ) -> ::axum::routing::MethodRouter<()> {
178                        let state = state.downcast::<#type_name>().unwrap();
179                        #router_fn(#closure)
180                    }
181                });
182
183                // inventory::submit!
184                inventory_submits.push(quote! {
185                    ::desert_framework::inventory::submit! {
186                        ::desert_framework::RouteEntry {
187                            controller_type_id: ::std::any::TypeId::of::<#type_name>(),
188                            path: #route_path,
189                            method: #code,
190                            make_route: #factory_name,
191                        }
192                    }
193                });
194            } else {
195                cleaned_methods.push(quote! { #method });
196            }
197        } else {
198            cleaned_methods.push(quote! { #item });
199        }
200    }
201
202    let defaultness = &impl_block.defaultness;
203    let generics = &impl_block.generics;
204    let self_ty = &impl_block.self_ty;
205    let where_clause = &generics.where_clause;
206
207    quote! {
208        #defaultness impl #generics #self_ty #where_clause {
209            #(#cleaned_methods)*
210        }
211
212        #(#factory_fns)*
213        #(#inventory_submits)*
214    }
215    .into()
216}
217
218// ─── #[controller] dispatch ───
219
220#[proc_macro_attribute]
221pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
222    let input = item.clone();
223    if let Ok(s) = syn::parse::<ItemStruct>(input) {
224        let path = parse_controller_path(attr);
225        return controller_on_struct(path, s);
226    }
227
228    let input = item.clone();
229    if let Ok(impl_block) = syn::parse::<ItemImpl>(input) {
230        return controller_on_impl(impl_block);
231    }
232
233    panic!("#[controller] can only be applied to structs or impl blocks");
234}
235
236// ─── Standalone route attributes (for backward compat) ───
237
238fn process_route_method(http: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
239    let method = parse_macro_input!(item as ImplItemFn);
240    let name = &method.sig.ident;
241    let is_async = method.sig.asyncness.is_some();
242    let code = method_code(http);
243    let path = route_path_from_attr(attr);
244
245    let extra: Vec<&FnArg> = method
246        .sig
247        .inputs
248        .iter()
249        .filter(|a| !matches!(a, FnArg::Receiver(_)))
250        .collect();
251
252    let pats: Vec<&Pat> = extra
253        .iter()
254        .map(|a| match a {
255            FnArg::Typed(PatType { pat, .. }) => pat.as_ref(),
256            _ => unreachable!(),
257        })
258        .collect();
259
260    let tys: Vec<&Type> = extra
261        .iter()
262        .map(|a| match a {
263            FnArg::Typed(PatType { ty, .. }) => ty.as_ref(),
264            _ => unreachable!(),
265        })
266        .collect();
267
268    let router_fn = code_to_ident(code);
269
270    let closure = if extra.is_empty() {
271        if is_async {
272            quote! { move || async move { state.#name().await } }
273        } else {
274            quote! { move || { state.#name() } }
275        }
276    } else if is_async {
277        quote! {
278            move |#(#pats: #tys),*| async move {
279                state.#name(#(#pats),*).await
280            }
281        }
282    } else {
283        quote! {
284            move |#(#pats: #tys),*| {
285                state.#name(#(#pats),*)
286            }
287        }
288    };
289
290    let factory_name = syn::Ident::new(&format!("__make_route_{}", name), name.span());
291    let method_const = syn::Ident::new(&format!("__ROUTE_METHOD_{}", name), name.span());
292    let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", name), name.span());
293
294    quote! {
295        #method
296
297        #[allow(non_upper_case_globals)]
298        pub const #method_const: u8 = #code;
299        #[allow(non_upper_case_globals)]
300        pub const #path_const: &str = #path;
301
302        pub fn #factory_name(state: std::sync::Arc<Self>) -> ::axum::routing::MethodRouter<()> {
303            #router_fn(#closure)
304        }
305    }
306    .into()
307}
308
309#[proc_macro_attribute]
310pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
311    process_route_method("get", attr, item)
312}
313
314#[proc_macro_attribute]
315pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
316    process_route_method("post", attr, item)
317}
318
319#[proc_macro_attribute]
320pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
321    process_route_method("put", attr, item)
322}
323
324#[proc_macro_attribute]
325pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
326    process_route_method("delete", attr, item)
327}
328
329#[proc_macro_attribute]
330pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
331    process_route_method("patch", attr, item)
332}
333
334// ─── impl_routes! (backward compat) ───
335
336struct ImplRoutesInput {
337    type_: syn::Path,
338    methods: Vec<syn::Ident>,
339}
340
341impl Parse for ImplRoutesInput {
342    fn parse(input: ParseStream) -> syn::Result<Self> {
343        let type_: syn::Path = input.parse()?;
344        let _: Option<Token![,]> = input.parse()?;
345        let content;
346        syn::bracketed!(content in input);
347        let methods = content.parse_terminated(syn::Ident::parse, Token![,])?;
348        Ok(ImplRoutesInput {
349            type_,
350            methods: methods.into_iter().collect(),
351        })
352    }
353}
354
355#[proc_macro]
356pub fn impl_routes(input: TokenStream) -> TokenStream {
357    let input = parse_macro_input!(input as ImplRoutesInput);
358    let ty = &input.type_;
359    let methods = &input.methods;
360
361    let entries: Vec<TokenStream2> = methods
362        .iter()
363        .map(|m| {
364            let factory = syn::Ident::new(&format!("__make_route_{}", m), m.span());
365            let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", m), m.span());
366
367            quote! {
368                {
369                    let __path_suffix = <#ty>::#path_const;
370                    let __full_path = ::std::format!("{}{}", <#ty>::__CONTROLLER_PATH, __path_suffix);
371                    let __mr = <#ty>::#factory(state.clone());
372                    router = router.route(&__full_path, __mr);
373                }
374            }
375        })
376        .collect();
377
378    quote! {
379        impl #ty {
380            pub fn get_router(self) -> ::axum::Router {
381                let state = ::std::sync::Arc::new(self);
382                let mut router = ::axum::Router::new();
383                #(#entries)*
384                router
385            }
386        }
387    }
388    .into()
389}