Skip to main content

ranvier_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{ToTokens, quote};
3use std::collections::HashSet;
4use syn::{
5    DeriveInput, FnArg, GenericArgument, ItemFn, PathArguments, ReturnType, Type,
6    parse_macro_input,
7};
8
9/// Attribute macro to transform an async function into a `Transition` implementation.
10#[proc_macro_attribute]
11pub fn transition(attr: TokenStream, item: TokenStream) -> TokenStream {
12    let mut input_fn = parse_macro_input!(item as ItemFn);
13    let original_ident = input_fn.sig.ident.clone();
14    let vis = &input_fn.vis;
15    let block = &input_fn.block;
16    let inputs = &input_fn.sig.inputs;
17
18    // We don't rename the function here, instead we use a prefix for the struct.
19    // However, to make .then(multiply_by_res) work, multiply_by_res MUST be the struct name.
20    // So we rename the FUNCTION and keep the name for the STRUCT.
21    let internal_fn_ident = quote::format_ident!("__ranvier_fn_{}", original_ident);
22    input_fn.sig.ident = internal_fn_ident.clone();
23
24    // Parse attribute for explicit resource type override
25    let mut res_override = None;
26    let mut bus_allow_types: Vec<Type> = Vec::new();
27    let mut bus_deny_types: Vec<Type> = Vec::new();
28    let mut bus_allow_specified = false;
29    let mut bus_deny_specified = false;
30    let mut x_pos = None;
31    let mut y_pos = None;
32    let mut schema_flag = false;
33    if !attr.is_empty() {
34        let parser = syn::punctuated::Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated;
35        if let Ok(metas) = syn::parse::Parser::parse2(parser, attr.into()) {
36            for meta in metas {
37                match meta {
38                    syn::Meta::Path(path) if path.is_ident("schema") => {
39                        schema_flag = true;
40                    }
41                    syn::Meta::NameValue(nv) => {
42                        if nv.path.is_ident("res") {
43                            res_override = Some(nv.value);
44                        } else if nv.path.is_ident("bus_allow") {
45                            bus_allow_specified = true;
46                            match parse_type_array_expr(&nv.value) {
47                                Ok(types) => bus_allow_types = types,
48                                Err(err) => return err.to_compile_error().into(),
49                            }
50                        } else if nv.path.is_ident("bus_deny") {
51                            bus_deny_specified = true;
52                            match parse_type_array_expr(&nv.value) {
53                                Ok(types) => bus_deny_types = types,
54                                Err(err) => return err.to_compile_error().into(),
55                            }
56                        } else if nv.path.is_ident("x") {
57                            x_pos = Some(nv.value);
58                        } else if nv.path.is_ident("y") {
59                            y_pos = Some(nv.value);
60                        }
61                    }
62                    _ => {}
63                }
64            }
65        }
66    }
67
68    if let Err(err) = validate_bus_policy_types(&bus_allow_types, &bus_deny_types) {
69        return err.to_compile_error().into();
70    }
71
72    // 1. Extract Input Type (From)
73    let input_type = if let Some(FnArg::Typed(pat_type)) = inputs.first() {
74        let ty = &pat_type.ty;
75        quote! { #ty }
76    } else {
77        quote! { () }
78    };
79
80    // 2. Extract Resources Type
81    let second_is_bus = inputs.get(1).map(is_bus_argument).unwrap_or(false);
82    let res_type = if let Some(res) = res_override {
83        quote! { #res }
84    } else if second_is_bus {
85        quote! { () }
86    } else if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
87        let ty = &pat_type.ty;
88        if let Type::Reference(type_ref) = &**ty {
89            let elem = &type_ref.elem;
90            quote! { #elem }
91        } else {
92            quote! { #ty }
93        }
94    } else {
95        quote! { () }
96    };
97
98    // 3. Extract Outcome Types
99    let (output_type, error_type) = if let ReturnType::Type(_, ty) = &input_fn.sig.output {
100        extract_outcome_types(ty).unwrap_or((quote! { () }, quote! { anyhow::Error }))
101    } else {
102        (quote! { () }, quote! { anyhow::Error })
103    };
104
105    // 4. Handle Arguments
106    let arg_count = inputs.len();
107    let run_body = match arg_count {
108        1 => {
109            if let Some(FnArg::Typed(pat_type)) = inputs.first() {
110                let pat = &pat_type.pat;
111                quote! { let #pat = input; #block }
112            } else {
113                quote! { #block }
114            }
115        }
116        2 => {
117            let mut bindings = quote! {};
118            if let Some(FnArg::Typed(pat_type)) = inputs.first() {
119                let pat = &pat_type.pat;
120                bindings.extend(quote! { let #pat = input; });
121            }
122            if second_is_bus {
123                if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
124                    let pat = &pat_type.pat;
125                    bindings.extend(quote! { let #pat = bus; });
126                }
127            } else if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
128                let pat = &pat_type.pat;
129                bindings.extend(quote! { let #pat = resources; });
130            }
131            quote! { #bindings #block }
132        }
133        3 => {
134            let mut bindings = quote! {};
135            if let Some(FnArg::Typed(pat_type)) = inputs.first() {
136                let pat = &pat_type.pat;
137                bindings.extend(quote! { let #pat = input; });
138            }
139            if let Some(FnArg::Typed(pat_type)) = inputs.get(1) {
140                let pat = &pat_type.pat;
141                bindings.extend(quote! { let #pat = resources; });
142            }
143            if let Some(FnArg::Typed(pat_type)) = inputs.get(2) {
144                let pat = &pat_type.pat;
145                bindings.extend(quote! { let #pat = bus; });
146            }
147            quote! { #bindings #block }
148        }
149        _ => quote! { #block },
150    };
151
152    let bus_policy_method = if bus_allow_specified || bus_deny_specified {
153        let allow_expr = if bus_allow_specified {
154            quote! {
155                Some(vec![#(ranvier_core::bus::BusTypeRef::of::<#bus_allow_types>()),*])
156            }
157        } else {
158            quote! { None }
159        };
160        let deny_expr = if bus_deny_specified {
161            quote! {
162                vec![#(ranvier_core::bus::BusTypeRef::of::<#bus_deny_types>()),*]
163            }
164        } else {
165            quote! { Vec::new() }
166        };
167        quote! {
168            fn bus_access_policy(&self) -> Option<ranvier_core::bus::BusAccessPolicy> {
169                Some(ranvier_core::bus::BusAccessPolicy {
170                    allow: #allow_expr,
171                    deny: #deny_expr,
172                })
173            }
174        }
175    } else {
176        quote! {}
177    };
178
179    let position_method = if let (Some(x), Some(y)) = (x_pos, y_pos) {
180        quote! {
181            fn position(&self) -> Option<(f32, f32)> {
182                Some((#x as f32, #y as f32))
183            }
184        }
185    } else {
186        quote! {}
187    };
188
189    let schema_method = if schema_flag {
190        quote! {
191            fn input_schema(&self) -> Option<serde_json::Value> {
192                let schema = schemars::schema_for!(#input_type);
193                serde_json::to_value(schema).ok()
194            }
195        }
196    } else {
197        quote! {}
198    };
199
200    let expanded = quote! {
201        #[derive(Clone, Default)]
202        #[allow(non_camel_case_types)]
203        #vis struct #original_ident;
204
205        #[::async_trait::async_trait]
206        impl ranvier_core::transition::Transition<#input_type, #output_type> for #original_ident {
207            type Error = #error_type;
208            type Resources = #res_type;
209
210            #bus_policy_method
211            #position_method
212            #schema_method
213
214            async fn run(
215                &self,
216                input: #input_type,
217                resources: &Self::Resources,
218                bus: &mut ranvier_core::bus::Bus,
219            ) -> ranvier_core::outcome::Outcome<#output_type, Self::Error> {
220                #run_body
221            }
222        }
223
224        #input_fn
225    };
226
227    TokenStream::from(expanded)
228}
229
230/// Attribute macro for HTTP route registration.
231#[proc_macro_attribute]
232pub fn route(attr: TokenStream, item: TokenStream) -> TokenStream {
233    let input_fn = parse_macro_input!(item as ItemFn);
234    let original_ident = input_fn.sig.ident.clone();
235    let vis = &input_fn.vis;
236
237    let parser = syn::punctuated::Punctuated::<syn::Expr, syn::Token![,]>::parse_terminated;
238    let attr_args = parse_macro_input!(attr with parser);
239
240    if attr_args.len() < 2 {
241        return TokenStream::from(
242            quote! { compile_error!("route attribute requires method and path"); },
243        );
244    }
245
246    let method = &attr_args[0];
247    let path = &attr_args[1];
248
249    // For routes, we keep the function name for the function, and use a prefix for the metadata struct.
250    let struct_name = quote::format_ident!("Route_{}", original_ident);
251
252    let expanded = quote! {
253        #input_fn
254
255        #[allow(non_camel_case_types)]
256        #vis struct #struct_name;
257
258        impl #struct_name {
259            pub const METHOD: &'static str = stringify!(#method);
260            pub const PATH: &'static str = #path;
261        }
262    };
263
264    TokenStream::from(expanded)
265}
266
267/// Macro to build a router from a list of circuit functions annotated with `#[route]`.
268#[proc_macro]
269pub fn ranvier_router(input: TokenStream) -> TokenStream {
270    let parser = syn::punctuated::Punctuated::<syn::Ident, syn::Token![,]>::parse_terminated;
271    let idents = parse_macro_input!(input with parser);
272
273    let mut registrations = quote! {};
274
275    for ident in idents {
276        let route_struct = quote::format_ident!("Route_{}", ident);
277        registrations.extend(quote! {
278            let method_str = #route_struct::METHOD;
279            let method = match method_str {
280                "GET" => http::Method::GET,
281                "POST" => http::Method::POST,
282                "PUT" => http::Method::PUT,
283                "DELETE" => http::Method::DELETE,
284                _ => http::Method::GET,
285            };
286            ingress = ingress.route_method(method, #route_struct::PATH, #ident().await);
287        });
288    }
289
290    let expanded = quote! {
291        {
292            let mut ingress = ranvier_http::HttpIngress::new();
293            #registrations
294            ingress
295        }
296    };
297
298    TokenStream::from(expanded)
299}
300
301fn extract_outcome_types(
302    ty: &Type,
303) -> Option<(quote::__private::TokenStream, quote::__private::TokenStream)> {
304    if let Type::Path(type_path) = ty
305        && let Some(segment) = type_path.path.segments.last()
306        && segment.ident == "Outcome"
307        && let PathArguments::AngleBracketed(args) = &segment.arguments
308    {
309        let mut type_args = args.args.iter();
310        if let (Some(GenericArgument::Type(to)), Some(GenericArgument::Type(err))) =
311            (type_args.next(), type_args.next())
312        {
313            return Some((quote! { #to }, quote! { #err }));
314        }
315    }
316    None
317}
318
319fn is_bus_argument(arg: &FnArg) -> bool {
320    let FnArg::Typed(pat_type) = arg else {
321        return false;
322    };
323    let Type::Reference(type_ref) = &*pat_type.ty else {
324        return false;
325    };
326    let Type::Path(type_path) = &*type_ref.elem else {
327        return false;
328    };
329    type_path
330        .path
331        .segments
332        .last()
333        .map(|segment| segment.ident == "Bus")
334        .unwrap_or(false)
335}
336
337fn parse_type_array_expr(expr: &syn::Expr) -> syn::Result<Vec<Type>> {
338    let syn::Expr::Array(array) = expr else {
339        return Err(syn::Error::new_spanned(
340            expr,
341            "expected array syntax: [TypeA, TypeB]",
342        ));
343    };
344
345    array
346        .elems
347        .iter()
348        .map(|elem| syn::parse2::<Type>(elem.to_token_stream()))
349        .collect()
350}
351
352/// Derive macro for the `ResourceRequirement` marker trait.
353///
354/// Generates a blanket `impl ResourceRequirement for YourType {}`.
355/// The type must also implement `Clone` (required by the Axon execution engine).
356///
357/// # Example
358///
359/// ```rust,ignore
360/// use ranvier::prelude::*;
361///
362/// #[derive(Clone, ResourceRequirement)]
363/// struct AppResources {
364///     pool: sqlx::PgPool,
365///     redis: redis::Client,
366/// }
367/// ```
368#[proc_macro_derive(ResourceRequirement)]
369pub fn derive_resource_requirement(input: TokenStream) -> TokenStream {
370    let input = parse_macro_input!(input as DeriveInput);
371    let name = &input.ident;
372    let (impl_generics, ty_generics, where_clause) = input.generics.split_for_impl();
373
374    let expanded = quote! {
375        impl #impl_generics ranvier_core::transition::ResourceRequirement for #name #ty_generics #where_clause {}
376    };
377
378    TokenStream::from(expanded)
379}
380
381fn validate_bus_policy_types(allow: &[Type], deny: &[Type]) -> syn::Result<()> {
382    let mut allow_keys = HashSet::new();
383    for ty in allow {
384        let key = ty.to_token_stream().to_string();
385        if !allow_keys.insert(key) {
386            return Err(syn::Error::new_spanned(
387                ty,
388                "duplicate type in bus_allow list",
389            ));
390        }
391    }
392
393    let mut deny_keys = HashSet::new();
394    for ty in deny {
395        let key = ty.to_token_stream().to_string();
396        if !deny_keys.insert(key) {
397            return Err(syn::Error::new_spanned(
398                ty,
399                "duplicate type in bus_deny list",
400            ));
401        }
402    }
403
404    for ty in allow {
405        let key = ty.to_token_stream().to_string();
406        if deny_keys.contains(&key) {
407            return Err(syn::Error::new_spanned(
408                ty,
409                "same type cannot be present in both bus_allow and bus_deny",
410            ));
411        }
412    }
413
414    Ok(())
415}
416
#[cfg(test)]
mod tests {
    //! Unit tests for the non-proc-macro helpers (argument classification,
    //! bus-policy parsing/validation, and `Outcome<T, E>` extraction).

    use super::{
        extract_outcome_types, is_bus_argument, parse_type_array_expr, validate_bus_policy_types,
    };
    use syn::{Expr, FnArg, Type, parse_quote};

    #[test]
    fn detects_mut_bus_reference_argument() {
        let arg: FnArg = parse_quote!(bus: &mut Bus);
        assert!(is_bus_argument(&arg));
    }

    #[test]
    fn detects_fully_qualified_bus_reference_argument() {
        let arg: FnArg = parse_quote!(bus: &mut ranvier_core::bus::Bus);
        assert!(is_bus_argument(&arg));
    }

    #[test]
    fn rejects_non_bus_argument() {
        let arg: FnArg = parse_quote!(res: &MyResources);
        assert!(!is_bus_argument(&arg));
    }

    #[test]
    fn rejects_bus_taken_by_value() {
        // Only references count; a by-value `Bus` is treated as a resource.
        let arg: FnArg = parse_quote!(bus: Bus);
        assert!(!is_bus_argument(&arg));
    }

    #[test]
    fn parses_type_array_expr_for_bus_policy() {
        let expr: Expr = parse_quote!([i32, alloc::string::String]);
        let parsed = parse_type_array_expr(&expr).expect("type array should parse");
        assert_eq!(parsed.len(), 2);
    }

    #[test]
    fn rejects_non_array_bus_policy_expr() {
        let expr: Expr = parse_quote!(i32);
        let err = parse_type_array_expr(&expr).expect_err("non-array should fail");
        assert!(err.to_string().contains("expected array syntax"));
    }

    #[test]
    fn validates_bus_policy_rejects_duplicate_allow() {
        let allow: Vec<Type> = vec![parse_quote!(i32), parse_quote!(i32)];
        let deny: Vec<Type> = Vec::new();
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(err.to_string().contains("duplicate type in bus_allow"));
    }

    #[test]
    fn validates_bus_policy_rejects_duplicate_deny() {
        let allow: Vec<Type> = Vec::new();
        let deny: Vec<Type> = vec![parse_quote!(i32), parse_quote!(i32)];
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(err.to_string().contains("duplicate type in bus_deny"));
    }

    #[test]
    fn validates_bus_policy_rejects_allow_deny_conflict() {
        let allow: Vec<Type> = vec![parse_quote!(i32)];
        let deny: Vec<Type> = vec![parse_quote!(i32)];
        let err = validate_bus_policy_types(&allow, &deny).expect_err("should fail");
        assert!(
            err.to_string()
                .contains("same type cannot be present in both bus_allow and bus_deny")
        );
    }

    #[test]
    fn validates_bus_policy_accepts_disjoint_lists() {
        let allow: Vec<Type> = vec![parse_quote!(i32)];
        let deny: Vec<Type> = vec![parse_quote!(u64)];
        assert!(validate_bus_policy_types(&allow, &deny).is_ok());
    }

    #[test]
    fn extracts_outcome_type_arguments() {
        let ty: Type = parse_quote!(Outcome<u32, MyError>);
        let (ok, err) = extract_outcome_types(&ty).expect("Outcome should be recognised");
        assert_eq!(ok.to_string(), "u32");
        assert_eq!(err.to_string(), "MyError");
    }

    #[test]
    fn ignores_non_outcome_return_type() {
        let ty: Type = parse_quote!(Result<u32, MyError>);
        assert!(extract_outcome_types(&ty).is_none());
    }
}
465}