spring_axum_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::{parse_macro_input, Ident, ItemFn, ItemStruct, ItemImpl, ImplItem, LitStr, Expr, FnArg, Type, TypePath};
4use syn::parse::{Parse, ParseStream};
5use syn::punctuated::Punctuated;
6use syn::Token;
7
8fn route_macro_impl(method_ident: &str, path: LitStr, input: ItemFn, layer: Option<Expr>) -> TokenStream {
9    let fn_name: Ident = input.sig.ident.clone();
10    let route_fn_name = Ident::new(&format!("__spring_axum_route_{}", fn_name), fn_name.span());
11    let method_ident = Ident::new(method_ident, fn_name.span());
12    let base_route = quote! { ::spring_axum::Router::new().route(#path, ::spring_axum::#method_ident(#fn_name)) };
13    let router_stmt = if let Some(layer_expr) = layer {
14        quote! { #base_route.route_layer(#layer_expr) }
15    } else {
16        base_route
17    };
18    let expanded = quote! {
19        #input
20
21        #[allow(non_snake_case)]
22        pub fn #route_fn_name() -> ::spring_axum::Router {
23            #router_stmt
24        }
25    };
26    expanded.into()
27}
28
29// Parser for route attributes supporting: path literal and optional `layer = <expr>`
30struct RouteArgs {
31    path: LitStr,
32    layer: Option<Expr>,
33}
34
35impl Parse for RouteArgs {
36    fn parse(input: ParseStream) -> syn::Result<Self> {
37        let path: LitStr = input.parse()?;
38        let mut layer: Option<Expr> = None;
39        if input.peek(Token![,]) {
40            let _comma: Token![,] = input.parse()?;
41            let key: Ident = input.parse()?;
42            if key == "layer" {
43                let _eq: Token![=] = input.parse()?;
44                let expr: Expr = input.parse()?;
45                layer = Some(expr);
46            } else {
47                return Err(syn::Error::new(key.span(), "unsupported key, expected `layer`"));
48            }
49        }
50        Ok(Self { path, layer })
51    }
52}
53
54#[proc_macro_attribute]
55pub fn route_get(args: TokenStream, input: TokenStream) -> TokenStream {
56    let args = parse_macro_input!(args as RouteArgs);
57    let input = parse_macro_input!(input as ItemFn);
58    route_macro_impl("get", args.path, input, args.layer)
59}
60
61#[proc_macro_attribute]
62pub fn route_post(args: TokenStream, input: TokenStream) -> TokenStream {
63    let args = parse_macro_input!(args as RouteArgs);
64    let input = parse_macro_input!(input as ItemFn);
65    route_macro_impl("post", args.path, input, args.layer)
66}
67
68#[proc_macro_attribute]
69pub fn route_put(args: TokenStream, input: TokenStream) -> TokenStream {
70    let args = parse_macro_input!(args as RouteArgs);
71    let input = parse_macro_input!(input as ItemFn);
72    route_macro_impl("put", args.path, input, args.layer)
73}
74
75#[proc_macro_attribute]
76pub fn route_delete(args: TokenStream, input: TokenStream) -> TokenStream {
77    let args = parse_macro_input!(args as RouteArgs);
78    let input = parse_macro_input!(input as ItemFn);
79    route_macro_impl("delete", args.path, input, args.layer)
80}
81
82fn replace_type_path_to_validated(tp: TypePath, target: &str, replacement: &str) -> Option<Type> {
83    // Match last segment ident (Json or Query), replace with ::spring_axum::<ValidatedJson/ValidatedQuery>
84    let last = tp.path.segments.last()?.ident.to_string();
85    if last == target {
86        // Extract generic argument inside <T>
87        if let syn::PathArguments::AngleBracketed(args) = &tp.path.segments.last()?.arguments {
88            if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
89                let new_ty: Type = syn::parse_str(&format!("::spring_axum::{}<{}>", replacement, quote!(#inner_ty))).ok()?;
90                return Some(new_ty);
91            }
92        }
93    }
94    None
95}
96
97fn validate_transform(input: ItemFn, kind: &str) -> TokenStream {
98    let mut func = input.clone();
99    for arg in func.sig.inputs.iter_mut() {
100        if let FnArg::Typed(pt) = arg {
101            if let Type::Path(tp) = &*pt.ty {
102                let (target, repl) = match kind {
103                    "json" => ("Json", "ValidatedJson"),
104                    "query" => ("Query", "ValidatedQuery"),
105                    "form" => ("Form", "ValidatedForm"),
106                    "json_stream" => ("Json", "ValidatedJsonStream"),
107                    _ => unreachable!(),
108                };
109                if let Some(new_ty) = replace_type_path_to_validated(tp.clone(), target, repl) {
110                    pt.ty = Box::new(new_ty);
111                }
112            }
113        }
114    }
115    quote! { #func }.into()
116}
117
118#[proc_macro_attribute]
119pub fn validate_json(_args: TokenStream, input: TokenStream) -> TokenStream {
120    let input = parse_macro_input!(input as ItemFn);
121    validate_transform(input, "json")
122}
123
124#[proc_macro_attribute]
125pub fn validate_query(_args: TokenStream, input: TokenStream) -> TokenStream {
126    let input = parse_macro_input!(input as ItemFn);
127    validate_transform(input, "query")
128}
129
130#[proc_macro_attribute]
131pub fn validate_form(_args: TokenStream, input: TokenStream) -> TokenStream {
132    let input = parse_macro_input!(input as ItemFn);
133    validate_transform(input, "form")
134}
135
136#[proc_macro_attribute]
137pub fn validate_json_stream(_args: TokenStream, input: TokenStream) -> TokenStream {
138    let input = parse_macro_input!(input as ItemFn);
139    validate_transform(input, "json_stream")
140}
141
142struct IdentList(Punctuated<Ident, Token![,]>);
143impl Parse for IdentList {
144    fn parse(input: ParseStream) -> syn::Result<Self> {
145        Ok(Self(Punctuated::parse_terminated(input)?))
146    }
147}
148
149#[proc_macro_attribute]
150pub fn controller(args: TokenStream, input: TokenStream) -> TokenStream {
151    let input_struct = parse_macro_input!(input as ItemStruct);
152    let idents = parse_macro_input!(args as IdentList).0;
153
154    let routes: Vec<Ident> = idents
155        .iter()
156        .map(|i| Ident::new(&format!("__spring_axum_route_{}", i), i.span()))
157        .collect();
158
159    let name = input_struct.ident.clone();
160    let mut merge_stmts = quote! { let mut router = ::spring_axum::Router::new(); };
161    for r in routes {
162        merge_stmts = quote! {
163            #merge_stmts
164            router = router.merge(#r());
165        };
166    }
167
168    let expanded = quote! {
169        #input_struct
170
171        impl ::spring_axum::Controller for #name {
172            fn routes(&self) -> ::spring_axum::Router {
173                #merge_stmts
174                router
175            }
176        }
177
178        // Auto-register this controller's routes into inventory for discovery
179        inventory::submit!(::spring_axum::ControllerRouterRegistration {
180            init: || -> ::spring_axum::Router {
181                #merge_stmts
182                router
183            },
184        });
185    };
186    expanded.into()
187}
188
189#[proc_macro_attribute]
190pub fn component(_args: TokenStream, input: TokenStream) -> TokenStream {
191    let input_item = parse_macro_input!(input as ItemStruct);
192    let name = input_item.ident.clone();
193
194    // Generate inventory registration requiring Default
195    let expanded = quote! {
196        #input_item
197
198        inventory::submit!(::spring_axum::ComponentRegistration {
199            init: |_: &::spring_axum::ApplicationContext| -> (::std::any::TypeId, Box<dyn ::std::any::Any + Send + Sync>) {
200                let value: #name = ::std::default::Default::default();
201                let arc: ::std::sync::Arc<#name> = ::std::sync::Arc::new(value);
202                (::std::any::TypeId::of::<#name>(), Box::new(arc))
203            },
204        });
205    };
206    expanded.into()
207}
208
209#[proc_macro_attribute]
210pub fn interceptor(_args: TokenStream, input: TokenStream) -> TokenStream {
211    let input_item = parse_macro_input!(input as ItemStruct);
212    let name = input_item.ident.clone();
213
214    let expanded = quote! {
215        #input_item
216
217        // Auto-register interceptor application into inventory (requires Default)
218        inventory::submit!(::spring_axum::InterceptorRegistration {
219            apply: |router: ::spring_axum::Router| -> ::spring_axum::Router {
220                router.layer(::spring_axum::InterceptorLayer::new(#name::default()))
221            },
222        });
223    };
224    expanded.into()
225}
226
227// ---------------- Transactions & Cache Macros -----------------
228
229#[proc_macro_attribute]
230pub fn transactional(_args: TokenStream, input: TokenStream) -> TokenStream {
231    let mut func = parse_macro_input!(input as ItemFn);
232    let body = func.block.clone();
233    func.block = Box::new(syn::parse_quote!({
234        ::spring_axum::transaction(|| async move { #body }).await
235    }));
236    quote! { #func }.into()
237}
238
239// Attribute macro marker to opt-out of transaction wrapping inside #[tx_service]
240#[proc_macro_attribute]
241pub fn non_tx(_args: TokenStream, input: TokenStream) -> TokenStream {
242    // Acts as a marker; actual stripping happens in tx_service. Leave item unchanged.
243    input
244}
245
246// Apply transactional wrapping to methods within an impl block by default.
247// Methods annotated with #[non_tx] are left unchanged.
248#[proc_macro_attribute]
249pub fn tx_service(_args: TokenStream, input: TokenStream) -> TokenStream {
250    let mut item_impl = parse_macro_input!(input as ItemImpl);
251
252    for impl_item in item_impl.items.iter_mut() {
253        if let ImplItem::Fn(method) = impl_item {
254            // Detect #[non_tx] marker and strip it
255            let has_non_tx = method.attrs.iter().any(|a| a.path().is_ident("non_tx"));
256            if has_non_tx {
257                method.attrs.retain(|a| !a.path().is_ident("non_tx"));
258                continue;
259            }
260
261            // Wrap body in spring_axum::transaction(async move { ... }).await
262            let body = method.block.clone();
263            method.block = syn::parse_quote!({
264                ::spring_axum::transaction(|| async move { #body }).await
265            });
266
267            // Ensure method is async to allow .await in body
268            if method.sig.asyncness.is_none() {
269                method.sig.asyncness = Some(syn::token::Async { span: method.sig.fn_token.span });
270            }
271        }
272    }
273
274    quote! { #item_impl }.into()
275}
276
277// Parse optional args: `ttl = <secs>`
278struct CacheArgs { ttl_secs: Option<u64> }
279impl Parse for CacheArgs {
280    fn parse(input: ParseStream) -> syn::Result<Self> {
281        if input.is_empty() { return Ok(Self { ttl_secs: None }); }
282        let key: Ident = input.parse()?;
283        if key != "ttl" { return Err(syn::Error::new(key.span(), "expected ttl = <secs>")); }
284        let _eq: Token![=] = input.parse()?;
285        let lit: syn::LitInt = input.parse()?;
286        let secs = lit.base10_parse::<u64>()?;
287        Ok(Self { ttl_secs: Some(secs) })
288    }
289}
290
291fn extract_app_result_inner_ty(fn_item: &ItemFn) -> Option<Type> {
292    if let syn::ReturnType::Type(_, ty_box) = &fn_item.sig.output {
293        if let Type::Path(tp) = &**ty_box {
294            if let Some(seg) = tp.path.segments.last() {
295                // Expect something like AppResult<T>
296                if let syn::PathArguments::AngleBracketed(args) = &seg.arguments {
297                    if let Some(syn::GenericArgument::Type(inner)) = args.args.first() {
298                        return Some(inner.clone());
299                    }
300                }
301            }
302        }
303    }
304    None
305}
306
307#[proc_macro_attribute]
308pub fn cacheable(args: TokenStream, input: TokenStream) -> TokenStream {
309    let args = parse_macro_input!(args as CacheArgs);
310    let mut func = parse_macro_input!(input as ItemFn);
311    let fn_name = func.sig.ident.to_string();
312    let inner_ty = extract_app_result_inner_ty(&func).expect("cacheable requires return type AppResult<T>");
313    let body = func.block.clone();
314    let ttl_expr = if let Some(secs) = args.ttl_secs { quote! { Some(::std::time::Duration::from_secs(#secs)) } } else { quote! { None } };
315    // Collect args into json map
316    let mut fields = Vec::<proc_macro2::TokenStream>::new();
317    for arg in func.sig.inputs.iter() {
318        if let FnArg::Typed(pt) = arg {
319            if let syn::Pat::Ident(pident) = &*pt.pat {
320                let ident = &pident.ident;
321                fields.push(quote! { stringify!(#ident) : &#ident });
322            }
323        }
324    }
325    let expanded_body = quote!({
326        let __args_json = ::serde_json::json!({ #(#fields),* });
327        let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
328        if let Some(__cached) = ::spring_axum::cache_instance().get_typed::<#inner_ty>(&__key) {
329            return Ok(__cached);
330        }
331        let __res: ::spring_axum::AppResult<#inner_ty> = (async move { #body }).await;
332        match __res {
333            Ok(__val) => {
334                ::spring_axum::cache_instance().put_typed(__key, __val.clone(), #ttl_expr);
335                Ok(__val)
336            }
337            Err(e) => Err(e),
338        }
339    });
340    func.block = Box::new(syn::parse_quote! { #expanded_body });
341    quote! { #func }.into()
342}
343
344#[proc_macro_attribute]
345pub fn cache_evict(_args: TokenStream, input: TokenStream) -> TokenStream {
346    let mut func = parse_macro_input!(input as ItemFn);
347    let fn_name = func.sig.ident.to_string();
348    let body = func.block.clone();
349    let mut fields = Vec::<proc_macro2::TokenStream>::new();
350    for arg in func.sig.inputs.iter() {
351        if let FnArg::Typed(pt) = arg {
352            if let syn::Pat::Ident(pident) = &*pt.pat {
353                let ident = &pident.ident;
354                fields.push(quote! { stringify!(#ident) : &#ident });
355            }
356        }
357    }
358    let expanded_body = quote!({
359        let __args_json = ::serde_json::json!({ #(#fields),* });
360        let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
361        ::spring_axum::cache_instance().evict(&__key);
362        (async move { #body }).await
363    });
364    func.block = Box::new(syn::parse_quote! { #expanded_body });
365    quote! { #func }.into()
366}
367
368#[proc_macro_attribute]
369pub fn cache_put(args: TokenStream, input: TokenStream) -> TokenStream {
370    let args = parse_macro_input!(args as CacheArgs);
371    let mut func = parse_macro_input!(input as ItemFn);
372    let fn_name = func.sig.ident.to_string();
373    let inner_ty = extract_app_result_inner_ty(&func).expect("cache_put requires return type AppResult<T>");
374    let body = func.block.clone();
375    let ttl_expr = if let Some(secs) = args.ttl_secs { quote! { Some(::std::time::Duration::from_secs(#secs)) } } else { quote! { None } };
376    let mut fields = Vec::<proc_macro2::TokenStream>::new();
377    for arg in func.sig.inputs.iter() {
378        if let FnArg::Typed(pt) = arg {
379            if let syn::Pat::Ident(pident) = &*pt.pat {
380                let ident = &pident.ident;
381                fields.push(quote! { stringify!(#ident) : &#ident });
382            }
383        }
384    }
385    let expanded_body = quote!({
386        let __args_json = ::serde_json::json!({ #(#fields),* });
387        let __key = ::spring_axum::default_cache_key(#fn_name, &__args_json);
388        let __res: ::spring_axum::AppResult<#inner_ty> = (async move { #body }).await;
389        match __res {
390            Ok(__val) => {
391                ::spring_axum::cache_instance().put_typed(__key, __val.clone(), #ttl_expr);
392                Ok(__val)
393            }
394            Err(e) => Err(e),
395        }
396    });
397    func.block = Box::new(syn::parse_quote! { #expanded_body });
398    quote! { #func }.into()
399}
400
401// ---------------- Application Event Listener Macro -----------------
402struct EventTypePath(TypePath);
403impl Parse for EventTypePath {
404    fn parse(input: ParseStream) -> syn::Result<Self> { Ok(Self(input.parse()?)) }
405}
406
407#[proc_macro_attribute]
408pub fn event_listener(args: TokenStream, input: TokenStream) -> TokenStream {
409    let event_ty = parse_macro_input!(args as EventTypePath).0;
410    let func = parse_macro_input!(input as ItemFn);
411    let name = func.sig.ident.clone();
412    let expanded = quote! {
413        #func
414        inventory::submit!(::spring_axum::EventListenerRegistration {
415            type_id: ::std::any::TypeId::of::<#event_ty>(),
416            handle: |ev: &dyn ::std::any::Any, ctx: &::spring_axum::ApplicationContext| {
417                if let Some(e) = ev.downcast_ref::<#event_ty>() {
418                    #name(e, ctx);
419                }
420            },
421        });
422    };
423    expanded.into()
424}
425
426// ---------------- Mapper & SQL method macros ----------------
427
428// Marker attribute for methods to be turned into SQL calls.
429// This attribute itself doesn't transform; `#[mapper]` on the impl does.
430#[proc_macro_attribute]
431pub fn sql(_args: TokenStream, input: TokenStream) -> TokenStream {
432    input
433}
434
435// `#[mapper]` on an impl block binds XML namespace to the impl type name by default,
436// and rewrites `#[sql]` methods to call into MyBatis using the method name.
437// Optional usage: #[mapper(namespace = "CustomNS")] to override.
438#[proc_macro_attribute]
439pub fn mapper(args: TokenStream, input: TokenStream) -> TokenStream {
440    // Parse optional `namespace = "..."`
441    #[derive(Default)]
442    struct MapperArgs { namespace: Option<String> }
443    impl Parse for MapperArgs {
444        fn parse(input: ParseStream) -> syn::Result<Self> {
445            if input.is_empty() { return Ok(MapperArgs::default()); }
446            let key: Ident = input.parse()?;
447            if key != "namespace" { return Err(syn::Error::new(key.span(), "expected `namespace`")); }
448            let _eq: Token![=] = input.parse()?;
449            let lit: LitStr = input.parse()?;
450            Ok(MapperArgs { namespace: Some(lit.value()) })
451        }
452    }
453    let parsed_args = parse_macro_input!(args as MapperArgs);
454    let mut item_impl = parse_macro_input!(input as syn::ItemImpl);
455
456    // Determine namespace: override or type ident
457    let ns = if let Some(ns) = parsed_args.namespace {
458        ns
459    } else {
460        // Extract last ident from self type path
461        match &*item_impl.self_ty {
462            Type::Path(tp) => tp.path.segments.last().map(|s| s.ident.to_string()).unwrap_or_default(),
463            _ => String::new(),
464        }
465    };
466    let ns_lit = LitStr::new(&ns, proc_macro2::Span::call_site());
467
468    // Transform methods marked with #[sql]
469    let mut new_items: Vec<syn::ImplItem> = Vec::new();
470    for it in item_impl.items.into_iter() {
471        if let syn::ImplItem::Fn(mut m) = it {
472            let has_sql = m.attrs.iter().any(|a| a.path().segments.last().map(|s| s.ident == "sql").unwrap_or(false));
473            if has_sql {
474                // Remove the marker attribute to avoid unused warnings
475                m.attrs.retain(|a| !a.path().segments.last().map(|s| s.ident == "sql").unwrap_or(false));
476
477                // Collect parameter idents (skip receiver)
478                let mut param_idents: Vec<Ident> = Vec::new();
479                for arg in m.sig.inputs.iter() {
480                    if let FnArg::Typed(pt) = arg {
481                        if let syn::Pat::Ident(pi) = &*pt.pat {
482                            param_idents.push(pi.ident.clone());
483                        }
484                    }
485                }
486
487                // Build JSON pairs like "username": username
488                let json_pairs: Vec<proc_macro2::TokenStream> = param_idents
489                    .iter()
490                    .map(|id| {
491                        let key = LitStr::new(&id.to_string(), id.span());
492                        quote! { #key : #id }
493                    })
494                    .collect();
495
496                let method_name = m.sig.ident.clone();
497                let stmt_id = quote! { concat!(#ns_lit, ".", stringify!(#method_name)) };
498
499                // Replace body with execution
500                m.block = syn::parse_quote!({
501                    let params = ::serde_json::json!({ #(#json_pairs),* });
502                    let exec = ::spring_axum_mybatis::NoopExecutor::default();
503                    ::spring_axum::mybatis_exec!(exec, #stmt_id, params);
504                    Ok(())
505                });
506
507                new_items.push(syn::ImplItem::Fn(m));
508            } else {
509                new_items.push(syn::ImplItem::Fn(m));
510            }
511        } else {
512            new_items.push(it);
513        }
514    }
515
516    item_impl.items = new_items;
517    quote! { #item_impl }.into()
518}