1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::{
5 parse::{Parse, ParseStream},
6 parse_macro_input, FnArg, ImplItem, ImplItemFn, ItemImpl, ItemStruct, Meta, Pat, PatType,
7 Token, Type,
8};
9
10fn parse_controller_path(attr: TokenStream) -> String {
11 if attr.is_empty() {
12 return String::new();
13 }
14 let meta: Meta = syn::parse(attr).expect("expected `path = \"...\"`");
15 match meta {
16 Meta::NameValue(nv) if nv.path.is_ident("path") => match &nv.value {
17 syn::Expr::Lit(expr_lit) => match &expr_lit.lit {
18 syn::Lit::Str(s) => s.value(),
19 _ => panic!("expected string literal for path"),
20 },
21 _ => panic!("expected string literal for path"),
22 },
23 _ => panic!("expected `path = \"...\"`"),
24 }
25}
26
27fn route_path_from_attr(attr: TokenStream) -> String {
28 let s: syn::LitStr = syn::parse(attr).expect("expected path string like `\"/path\"`");
29 s.value()
30}
31
/// Maps a lowercase HTTP method name to the numeric code stored in generated
/// route metadata: `get` = 0, `post` = 1, `put` = 2, `delete` = 3, `patch` = 4.
///
/// Any other name (including differently-cased ones such as `"GET"`) yields 255.
fn method_code(http: &str) -> u8 {
    // The index of a name in this table is its code.
    const METHODS: [&str; 5] = ["get", "post", "put", "delete", "patch"];

    METHODS
        .iter()
        .position(|&candidate| candidate == http)
        // The table has 5 entries, so the index always fits in a u8.
        .map_or(255, |index| index as u8)
}
42
43fn code_to_ident(code: u8) -> TokenStream2 {
44 match code {
45 0 => quote! { ::axum::routing::get },
46 1 => quote! { ::axum::routing::post },
47 2 => quote! { ::axum::routing::put },
48 3 => quote! { ::axum::routing::delete },
49 4 => quote! { ::axum::routing::patch },
50 _ => unreachable!(),
51 }
52}
53
54fn extract_route_info(attr: &syn::Attribute) -> (String, String) {
55 let method_name = attr.path().segments.last().unwrap().ident.to_string();
56
57 let path = match &attr.meta {
58 Meta::List(meta_list) => {
59 let lit: syn::LitStr =
60 syn::parse2(meta_list.tokens.clone()).expect("expected path string");
61 lit.value()
62 }
63 _ => panic!("expected #[method(\"path\")]"),
64 };
65
66 (method_name, path)
67}
68
69fn is_route_attr(attr: &syn::Attribute) -> bool {
70 let ident = attr.path().segments.last().unwrap().ident.to_string();
71 matches!(ident.as_str(), "get" | "post" | "put" | "delete" | "patch")
72}
73
74fn controller_on_struct(path: String, s: ItemStruct) -> TokenStream {
76 let name = &s.ident;
77
78 quote! {
79 #s
80 impl #name { pub const __CONTROLLER_PATH: &str = #path; }
81 impl ::desert_framework::ControllerRoutes for #name {
82 const CONTROLLER_PATH: &'static str = #path;
83 }
84 }
85 .into()
86}
87
88fn controller_on_impl(impl_block: ItemImpl) -> TokenStream {
90 if impl_block.trait_.is_some() {
91 panic!("#[controller] on impl block is only supported for bare impls (not trait impls)");
92 }
93
94 let self_type = &impl_block.self_ty;
95
96 let type_name = match self_type.as_ref() {
97 Type::Path(type_path) => type_path.path.segments.last().unwrap().ident.clone(),
98 _ => panic!("#[controller] on impl block requires a named type"),
99 };
100
101 let mut cleaned_methods: Vec<TokenStream2> = Vec::new();
102 let mut factory_fns: Vec<TokenStream2> = Vec::new();
103 let mut inventory_submits: Vec<TokenStream2> = Vec::new();
104
105 for item in &impl_block.items {
106 if let ImplItem::Fn(method) = item {
107 let route_attr = method.attrs.iter().find(|a| is_route_attr(a));
108
109 if let Some(attr) = route_attr {
110 let (http_method, route_path) = extract_route_info(attr);
111 let code = method_code(&http_method);
112 let name = &method.sig.ident;
113 let is_async = method.sig.asyncness.is_some();
114 let router_fn = code_to_ident(code);
115
116 let extra: Vec<&FnArg> = method
117 .sig
118 .inputs
119 .iter()
120 .filter(|a| !matches!(a, FnArg::Receiver(_)))
121 .collect();
122
123 let pats: Vec<&Pat> = extra
124 .iter()
125 .map(|a| match a {
126 FnArg::Typed(PatType { pat, .. }) => pat.as_ref(),
127 _ => unreachable!(),
128 })
129 .collect();
130
131 let tys: Vec<&Type> = extra
132 .iter()
133 .map(|a| match a {
134 FnArg::Typed(PatType { ty, .. }) => ty.as_ref(),
135 _ => unreachable!(),
136 })
137 .collect();
138
139 let closure = if extra.is_empty() {
140 if is_async {
141 quote! { move || async move { state.#name().await } }
142 } else {
143 quote! { move || { state.#name() } }
144 }
145 } else if is_async {
146 quote! {
147 move |#(#pats: #tys),*| async move {
148 state.#name(#(#pats),*).await
149 }
150 }
151 } else {
152 quote! {
153 move |#(#pats: #tys),*| {
154 state.#name(#(#pats),*)
155 }
156 }
157 };
158
159 let factory_name = syn::Ident::new(&format!("__make_route_{}", name), name.span());
160
161 let non_route_attrs: Vec<_> =
163 method.attrs.iter().filter(|a| !is_route_attr(a)).collect();
164 let vis = &method.vis;
165 let sig = &method.sig;
166 let block = &method.block;
167
168 cleaned_methods.push(quote! {
169 #(#non_route_attrs)*
170 #vis #sig #block
171 });
172
173 factory_fns.push(quote! {
175 fn #factory_name(
176 state: ::std::sync::Arc<dyn ::std::any::Any + Send + Sync>,
177 ) -> ::axum::routing::MethodRouter<()> {
178 let state = state.downcast::<#type_name>().unwrap();
179 #router_fn(#closure)
180 }
181 });
182
183 inventory_submits.push(quote! {
185 ::desert_framework::inventory::submit! {
186 ::desert_framework::RouteEntry {
187 controller_type_id: ::std::any::TypeId::of::<#type_name>(),
188 path: #route_path,
189 method: #code,
190 make_route: #factory_name,
191 }
192 }
193 });
194 } else {
195 cleaned_methods.push(quote! { #method });
196 }
197 } else {
198 cleaned_methods.push(quote! { #item });
199 }
200 }
201
202 let defaultness = &impl_block.defaultness;
203 let generics = &impl_block.generics;
204 let self_ty = &impl_block.self_ty;
205 let where_clause = &generics.where_clause;
206
207 quote! {
208 #defaultness impl #generics #self_ty #where_clause {
209 #(#cleaned_methods)*
210 }
211
212 #(#factory_fns)*
213 #(#inventory_submits)*
214 }
215 .into()
216}
217
218#[proc_macro_attribute]
221pub fn controller(attr: TokenStream, item: TokenStream) -> TokenStream {
222 let input = item.clone();
223 if let Ok(s) = syn::parse::<ItemStruct>(input) {
224 let path = parse_controller_path(attr);
225 return controller_on_struct(path, s);
226 }
227
228 let input = item.clone();
229 if let Ok(impl_block) = syn::parse::<ItemImpl>(input) {
230 return controller_on_impl(impl_block);
231 }
232
233 panic!("#[controller] can only be applied to structs or impl blocks");
234}
235
236fn process_route_method(http: &str, attr: TokenStream, item: TokenStream) -> TokenStream {
239 let method = parse_macro_input!(item as ImplItemFn);
240 let name = &method.sig.ident;
241 let is_async = method.sig.asyncness.is_some();
242 let code = method_code(http);
243 let path = route_path_from_attr(attr);
244
245 let extra: Vec<&FnArg> = method
246 .sig
247 .inputs
248 .iter()
249 .filter(|a| !matches!(a, FnArg::Receiver(_)))
250 .collect();
251
252 let pats: Vec<&Pat> = extra
253 .iter()
254 .map(|a| match a {
255 FnArg::Typed(PatType { pat, .. }) => pat.as_ref(),
256 _ => unreachable!(),
257 })
258 .collect();
259
260 let tys: Vec<&Type> = extra
261 .iter()
262 .map(|a| match a {
263 FnArg::Typed(PatType { ty, .. }) => ty.as_ref(),
264 _ => unreachable!(),
265 })
266 .collect();
267
268 let router_fn = code_to_ident(code);
269
270 let closure = if extra.is_empty() {
271 if is_async {
272 quote! { move || async move { state.#name().await } }
273 } else {
274 quote! { move || { state.#name() } }
275 }
276 } else if is_async {
277 quote! {
278 move |#(#pats: #tys),*| async move {
279 state.#name(#(#pats),*).await
280 }
281 }
282 } else {
283 quote! {
284 move |#(#pats: #tys),*| {
285 state.#name(#(#pats),*)
286 }
287 }
288 };
289
290 let factory_name = syn::Ident::new(&format!("__make_route_{}", name), name.span());
291 let method_const = syn::Ident::new(&format!("__ROUTE_METHOD_{}", name), name.span());
292 let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", name), name.span());
293
294 quote! {
295 #method
296
297 #[allow(non_upper_case_globals)]
298 pub const #method_const: u8 = #code;
299 #[allow(non_upper_case_globals)]
300 pub const #path_const: &str = #path;
301
302 pub fn #factory_name(state: std::sync::Arc<Self>) -> ::axum::routing::MethodRouter<()> {
303 #router_fn(#closure)
304 }
305 }
306 .into()
307}
308
309#[proc_macro_attribute]
310pub fn get(attr: TokenStream, item: TokenStream) -> TokenStream {
311 process_route_method("get", attr, item)
312}
313
314#[proc_macro_attribute]
315pub fn post(attr: TokenStream, item: TokenStream) -> TokenStream {
316 process_route_method("post", attr, item)
317}
318
319#[proc_macro_attribute]
320pub fn put(attr: TokenStream, item: TokenStream) -> TokenStream {
321 process_route_method("put", attr, item)
322}
323
324#[proc_macro_attribute]
325pub fn delete(attr: TokenStream, item: TokenStream) -> TokenStream {
326 process_route_method("delete", attr, item)
327}
328
329#[proc_macro_attribute]
330pub fn patch(attr: TokenStream, item: TokenStream) -> TokenStream {
331 process_route_method("patch", attr, item)
332}
333
/// Parsed input of `impl_routes!`: a controller type path, an optional comma,
/// then a bracketed, comma-separated list of method names, for example
/// `impl_routes!(MyController, [index, show])`.
struct ImplRoutesInput {
    /// The controller type that `get_router` is generated for.
    type_: syn::Path,
    /// Method names whose `__make_route_<name>` / `__ROUTE_PATH_<name>` items are
    /// looked up. These are presumably generated by the standalone route attributes.
    methods: Vec<syn::Ident>,
}
340
341impl Parse for ImplRoutesInput {
342 fn parse(input: ParseStream) -> syn::Result<Self> {
343 let type_: syn::Path = input.parse()?;
344 let _: Option<Token![,]> = input.parse()?;
345 let content;
346 syn::bracketed!(content in input);
347 let methods = content.parse_terminated(syn::Ident::parse, Token![,])?;
348 Ok(ImplRoutesInput {
349 type_,
350 methods: methods.into_iter().collect(),
351 })
352 }
353}
354
355#[proc_macro]
356pub fn impl_routes(input: TokenStream) -> TokenStream {
357 let input = parse_macro_input!(input as ImplRoutesInput);
358 let ty = &input.type_;
359 let methods = &input.methods;
360
361 let entries: Vec<TokenStream2> = methods
362 .iter()
363 .map(|m| {
364 let factory = syn::Ident::new(&format!("__make_route_{}", m), m.span());
365 let path_const = syn::Ident::new(&format!("__ROUTE_PATH_{}", m), m.span());
366
367 quote! {
368 {
369 let __path_suffix = <#ty>::#path_const;
370 let __full_path = ::std::format!("{}{}", <#ty>::__CONTROLLER_PATH, __path_suffix);
371 let __mr = <#ty>::#factory(state.clone());
372 router = router.route(&__full_path, __mr);
373 }
374 }
375 })
376 .collect();
377
378 quote! {
379 impl #ty {
380 pub fn get_router(self) -> ::axum::Router {
381 let state = ::std::sync::Arc::new(self);
382 let mut router = ::axum::Router::new();
383 #(#entries)*
384 router
385 }
386 }
387 }
388 .into()
389}