Skip to main content

desert_framework_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5    parse::{Parse, ParseStream},
6    parse_macro_input, FnArg, ImplItemFn, ItemStruct, Meta, Pat, PatType, Token, Type,
7};
8
9fn parse_controller_path(attr: TokenStream) -> String {
10    if attr.is_empty() {
11        return String::new();
12    }
13    let meta: Meta = syn::parse(attr).expect("expected `path = \"...\"`");
14    match meta {
15        Meta::NameValue(nv) if nv.path.is_ident("path") => match &nv.value {
16            syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
17                syn::Lit::Str(s) => s.value(),
18                _ => panic!("expected string literal for path"),
19            },
20            _ => panic!("expected string literal for path"),
21        },
22        _ => panic!("expected `path = \"...\"`"),
23    }
24}
25
26fn route_path_from_attr(attr: TokenStream) -> String {
27    let s: syn::LitStr = syn::parse(attr).expect("expected path string like `\"/path\"`");
28    s.value()
29}
30
/// Maps a lowercase HTTP verb name to its numeric route code.
///
/// `get`/`post`/`put`/`delete`/`patch` map to `0..=4` in that order; any other
/// string (including different casing) maps to `255`.
fn method_code(http: &str) -> u8 {
    const VERBS: [&str; 5] = ["get", "post", "put", "delete", "patch"];
    VERBS
        .iter()
        .position(|&verb| verb == http)
        .and_then(|index| u8::try_from(index).ok())
        .unwrap_or(u8::MAX)
}
41
42/// #[controller(path = "/api")] on struct
43#[proc_macro_attribute]
44pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
45    let path = parse_controller_path(attr);
46    let s = parse_macro_input!(item as ItemStruct);
47    let name = &s.ident;
48
49    quote! {
50        #s
51        impl #name { pub const __CONTROLLER_PATH: &str = #path; }
52    }
53    .into()
54}
55
56/// Generate cleaned method + metadata consts + handler factory
57fn process_route_method(http: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
58    let method = parse_macro_input!(item as ImplItemFn);
59    let name = &method.sig.ident;
60    let is_async = method.sig.asyncness.is_some();
61    let code = method_code(http);
62    let path = route_path_from_attr(attr);
63
64    let extra: Vec<&FnArg> = method
65        .sig
66        .inputs
67        .iter()
68        .filter(|a| !matches!(a, FnArg::Receiver(_)))
69        .collect();
70
71    let pats: Vec<&Pat> = extra
72        .iter()
73        .map(|a| match a {
74            FnArg::Typed(PatType { pat, .. }) => pat.as_ref(),
75            _ => unreachable!(),
76        })
77        .collect();
78
79    let tys: Vec<&Type> = extra
80        .iter()
81        .map(|a| match a {
82            FnArg::Typed(PatType { ty, .. }) => ty.as_ref(),
83            _ => unreachable!(),
84        })
85        .collect();
86
87    let router_fn = code_to_ident(code);
88
89    let closure = if extra.is_empty() {
90        if is_async {
91            quote! { move || async move { state.#name().await } }
92        } else {
93            quote! { move || { state.#name() } }
94        }
95    } else if is_async {
96        quote! {
97            move |#(#pats: #tys),*| async move {
98                state.#name(#(#pats),*).await
99            }
100        }
101    } else {
102        quote! {
103            move |#(#pats: #tys),*| {
104                state.#name(#(#pats),*)
105            }
106        }
107    };
108
109    let factory_name = syn::Ident::new(&format!("__make_route_{}", name), name.span());
110    let method_const = syn::Ident::new(&format!("__ROUTE_METHOD_{}", name), name.span());
111    let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", name), name.span());
112
113    quote! {
114        #method
115
116        #[allow(non_upper_case_globals)]
117        pub const #method_const: u8 = #code;
118        #[allow(non_upper_case_globals)]
119        pub const #path_const: &str = #path;
120
121        pub fn #factory_name(state: std::sync::Arc<Self>) -> ::axum::routing::MethodRouter<()> {
122            #router_fn(#closure)
123        }
124    }
125    .into()
126}
127
128fn code_to_ident(code: u8) -> TokenStream2 {
129    match code {
130        0 => quote! { ::axum::routing::get },
131        1 => quote! { ::axum::routing::post },
132        2 => quote! { ::axum::routing::put },
133        3 => quote! { ::axum::routing::delete },
134        4 => quote! { ::axum::routing::patch },
135        _ => unreachable!(),
136    }
137}
138
139#[proc_macro_attribute]
140pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
141    process_route_method("get", attr, item)
142}
143
144#[proc_macro_attribute]
145pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
146    process_route_method("post", attr, item)
147}
148
149#[proc_macro_attribute]
150pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
151    process_route_method("put", attr, item)
152}
153
154#[proc_macro_attribute]
155pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
156    process_route_method("delete", attr, item)
157}
158
159#[proc_macro_attribute]
160pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
161    process_route_method("patch", attr, item)
162}
163
164struct ImplRoutesInput {
165    type_: syn::Path,
166    methods: Vec<syn::Ident>,
167}
168
169impl Parse for ImplRoutesInput {
170    fn parse(input: ParseStream) -> syn::Result<Self> {
171        let type_: syn::Path = input.parse()?;
172        let _: Option<Token![,]> = input.parse()?;
173        let content;
174        syn::bracketed!(content in input);
175        let methods = content.parse_terminated(syn::Ident::parse, Token![,])?;
176        Ok(ImplRoutesInput {
177            type_,
178            methods: methods.into_iter().collect(),
179        })
180    }
181}
182
183/// impl_routes!(MyCtrl, [hello, login])
184#[proc_macro]
185pub fn impl_routes(input: TokenStream) -> TokenStream {
186    let input = parse_macro_input!(input as ImplRoutesInput);
187    let ty = &input.type_;
188    let methods = &input.methods;
189
190    let entries: Vec<TokenStream2> = methods
191        .iter()
192        .map(|m| {
193            let factory = syn::Ident::new(&format!("__make_route_{}", m), m.span());
194            let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", m), m.span());
195
196            quote! {
197                {
198                    let __path_suffix = <#ty>::#path_const;
199                    let __full_path = ::std::format!("{}{}", <#ty>::__CONTROLLER_PATH, __path_suffix);
200                    let __mr = <#ty>::#factory(state.clone());
201                    router = router.route(&__full_path, __mr);
202                }
203            }
204        })
205        .collect();
206
207    quote! {
208        impl #ty {
209            pub fn get_router(self) -> ::axum::Router {
210                let state = ::std::sync::Arc::new(self);
211                let mut router = ::axum::Router::new();
212                #(#entries)*
213                router
214            }
215        }
216    }
217    .into()
218}