//! Source file: `resonate_sdk_macros/lib.rs`.

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use proc_macro_crate::{crate_name, FoundCrate};
4use quote::{format_ident, quote};
5use syn::{parse_macro_input, FnArg, ItemFn, Pat, Type};
6
7/// Returns the token path to the `resonate-sdk` crate root.
8/// - When compiled inside the `resonate-sdk` crate itself → `crate`
9/// - When compiled by an external consumer               → `resonate_sdk` (or the renamed alias)
10fn resonate_crate() -> TokenStream2 {
11    match crate_name("resonate-sdk") {
12        Ok(FoundCrate::Itself) => quote! { crate },
13        Ok(FoundCrate::Name(name)) => {
14            let ident = format_ident!("{}", name);
15            quote! { #ident }
16        }
17        // Fallback: assume external usage
18        Err(_) => quote! { resonate_sdk },
19    }
20}
21
22/// Attribute macro for registering durable functions.
23///
24/// Detects the function kind from the first parameter's type:
25/// - `&Context` → Workflow
26/// - `&Info` → Leaf with metadata
27/// - anything else → Pure leaf
28///
29/// Generates a PascalCase unit struct implementing `Durable`, plus a
30/// lowercase const alias matching the original function name. The original
31/// function is consumed — its body is inlined into the `Durable::execute` impl.
32///
33/// # Examples
34///
35/// ```ignore
36/// #[resonate_sdk::function]
37/// async fn my_leaf(x: i32) -> Result<i32> { Ok(x + 1) }
38///
39/// #[resonate_sdk::function]
40/// async fn my_workflow(ctx: &Context, x: i32) -> Result<i32> {
41///     ctx.run(my_leaf, x).await
42/// }
43///
44/// // Both lowercase const and PascalCase struct work:
45/// ctx.run(my_leaf, 42).await     // preferred
46/// ctx.run(MyLeaf, 42).await      // also works
47/// resonate.register(my_leaf)     // preferred
48/// resonate.register(MyLeaf)      // also works
49/// ```
50#[proc_macro_attribute]
51pub fn function(attr: TokenStream, item: TokenStream) -> TokenStream {
52    let input = parse_macro_input!(item as ItemFn);
53    let attrs = if attr.is_empty() {
54        None
55    } else {
56        Some(parse_macro_input!(attr as MacroAttrs))
57    };
58
59    match generate_durable_impl(input, attrs) {
60        Ok(tokens) => tokens.into(),
61        Err(e) => e.to_compile_error().into(),
62    }
63}
64
65/// Parsed macro attributes.
66struct MacroAttrs {
67    name: Option<String>,
68}
69
70impl syn::parse::Parse for MacroAttrs {
71    fn parse(input: syn::parse::ParseStream) -> syn::Result<Self> {
72        let mut name = None;
73
74        while !input.is_empty() {
75            let ident: syn::Ident = input.parse()?;
76            let _eq: syn::Token![=] = input.parse()?;
77
78            if ident == "name" {
79                let lit: syn::LitStr = input.parse()?;
80                name = Some(lit.value());
81            } else {
82                return Err(syn::Error::new(ident.span(), "unknown attribute"));
83            }
84
85            if input.peek(syn::Token![,]) {
86                let _comma: syn::Token![,] = input.parse()?;
87            }
88        }
89
90        Ok(MacroAttrs { name })
91    }
92}
93
/// Classification of a durable function, derived from the type of its first
/// parameter.
enum FunctionKind {
    /// No special first parameter; every parameter is a user argument.
    PureLeaf,
    /// The first parameter is `&Info`.
    LeafWithInfo,
    /// The first parameter is `&Context`.
    Workflow,
}
103
104fn generate_durable_impl(
105    input: ItemFn,
106    attrs: Option<MacroAttrs>,
107) -> syn::Result<proc_macro2::TokenStream> {
108    let fn_name = &input.sig.ident;
109    let vis = &input.vis;
110
111    // Generate struct name: snake_case → PascalCase
112    let struct_name = format_ident!("{}", to_pascal_case(&fn_name.to_string()));
113
114    // Determine the registered name
115    let registered_name = attrs
116        .and_then(|a| a.name)
117        .unwrap_or_else(|| fn_name.to_string());
118
119    // Analyze the function signature to detect the kind
120    let params: Vec<_> = input.sig.inputs.iter().collect();
121    let (kind, user_params) = detect_kind(&params)?;
122
123    // Extract user parameter types and names
124    let mut arg_types = Vec::new();
125    let mut arg_names = Vec::new();
126
127    for param in &user_params {
128        if let FnArg::Typed(pat_type) = param {
129            arg_types.push(pat_type.ty.as_ref().clone());
130            if let Pat::Ident(pat_ident) = pat_type.pat.as_ref() {
131                arg_names.push(pat_ident.ident.clone());
132            } else {
133                return Err(syn::Error::new_spanned(
134                    pat_type,
135                    "expected a simple identifier pattern",
136                ));
137            }
138        }
139    }
140
141    // Determine the Args and T types
142    // For the Durable trait, Args is a tuple of all user params, T is extracted from Result<T>
143    let args_type = if arg_types.is_empty() {
144        quote! { () }
145    } else if arg_types.len() == 1 {
146        let t = &arg_types[0];
147        quote! { #t }
148    } else {
149        quote! { (#(#arg_types),*) }
150    };
151
152    // Extract return type T from Result<T>
153    let return_type = extract_return_type(&input.sig.output)?;
154
155    // Resolve the crate path (crate:: when inside resonate-sdk, resonate_sdk:: externally)
156    let krate = resonate_crate();
157
158    // Generate the kind constant
159    let kind_token = match kind {
160        FunctionKind::PureLeaf | FunctionKind::LeafWithInfo => {
161            quote! { #krate::types::DurableKind::Function }
162        }
163        FunctionKind::Workflow => quote! { #krate::types::DurableKind::Workflow },
164    };
165
166    // Extract the function body and any attributes (e.g. #[allow(...)])
167    let fn_body = &input.block;
168    let fn_attrs = &input.attrs;
169
170    // Generate the execute body based on kind.
171    // Instead of delegating to the original function, we inline the body directly.
172    // We destructure the `args` tuple and set up `ctx`/`info` bindings as needed.
173    let execute_body = match kind {
174        FunctionKind::PureLeaf => {
175            let ignore_env = quote! { let _ = env; };
176            if arg_names.is_empty() {
177                quote! {
178                    #ignore_env
179                    async move #fn_body .await
180                }
181            } else if arg_names.len() == 1 {
182                let name = &arg_names[0];
183                quote! {
184                    #ignore_env
185                    let #name = args;
186                    async move #fn_body .await
187                }
188            } else {
189                let destructure: Vec<_> = arg_names
190                    .iter()
191                    .enumerate()
192                    .map(|(i, name)| {
193                        let idx = syn::Index::from(i);
194                        quote! { let #name = args.#idx; }
195                    })
196                    .collect();
197                quote! {
198                    #ignore_env
199                    #(#destructure)*
200                    async move #fn_body .await
201                }
202            }
203        }
204        FunctionKind::LeafWithInfo => {
205            let info_unwrap = quote! {
206                let info = env.into_info();
207            };
208            if arg_names.is_empty() {
209                quote! {
210                    #info_unwrap
211                    async move #fn_body .await
212                }
213            } else if arg_names.len() == 1 {
214                let name = &arg_names[0];
215                quote! {
216                    #info_unwrap
217                    let #name = args;
218                    async move #fn_body .await
219                }
220            } else {
221                let destructure: Vec<_> = arg_names
222                    .iter()
223                    .enumerate()
224                    .map(|(i, name)| {
225                        let idx = syn::Index::from(i);
226                        quote! { let #name = args.#idx; }
227                    })
228                    .collect();
229                quote! {
230                    #info_unwrap
231                    #(#destructure)*
232                    async move #fn_body .await
233                }
234            }
235        }
236        FunctionKind::Workflow => {
237            let ctx_unwrap = quote! {
238                let ctx = env.into_context();
239            };
240            if arg_names.is_empty() {
241                quote! {
242                    #ctx_unwrap
243                    async move #fn_body .await
244                }
245            } else if arg_names.len() == 1 {
246                let name = &arg_names[0];
247                quote! {
248                    #ctx_unwrap
249                    let #name = args;
250                    async move #fn_body .await
251                }
252            } else {
253                let destructure: Vec<_> = arg_names
254                    .iter()
255                    .enumerate()
256                    .map(|(i, name)| {
257                        let idx = syn::Index::from(i);
258                        quote! { let #name = args.#idx; }
259                    })
260                    .collect();
261                quote! {
262                    #ctx_unwrap
263                    #(#destructure)*
264                    async move #fn_body .await
265                }
266            }
267        }
268    };
269
270    let internal_struct_name = format_ident!("__Durable_{}", struct_name);
271
272    let output = quote! {
273        /// Generated durable function struct (internal — use the const `#fn_name` instead).
274        #(#fn_attrs)*
275        #[derive(Debug, Clone, Copy)]
276        #[doc(hidden)]
277        #vis struct #internal_struct_name;
278
279        /// Durable function handle. Use with `ctx.run(#fn_name, args)` or `resonate.register(#fn_name)`.
280        #[allow(non_upper_case_globals)]
281        #vis const #fn_name: #internal_struct_name = #internal_struct_name;
282
283        impl #krate::durable::Durable<#args_type, #return_type> for #internal_struct_name {
284            const NAME: &'static str = #registered_name;
285            const KIND: #krate::types::DurableKind = #kind_token;
286
287            async fn execute(
288                &self,
289                env: #krate::durable::ExecutionEnv<'_>,
290                args: #args_type,
291            ) -> #krate::error::Result<#return_type> {
292                #execute_body
293            }
294        }
295    };
296
297    Ok(output)
298}
299
300/// Detect the function kind by inspecting the first parameter's type.
301fn detect_kind<'a>(params: &'a [&'a FnArg]) -> syn::Result<(FunctionKind, Vec<&'a FnArg>)> {
302    if params.is_empty() {
303        return Ok((FunctionKind::PureLeaf, vec![]));
304    }
305
306    // Check the first parameter's type
307    if let FnArg::Typed(pat_type) = params[0] {
308        if is_reference_to(&pat_type.ty, "Context") {
309            return Ok((FunctionKind::Workflow, params[1..].to_vec()));
310        }
311        if is_reference_to(&pat_type.ty, "Info") {
312            return Ok((FunctionKind::LeafWithInfo, params[1..].to_vec()));
313        }
314    }
315
316    // No special first param — pure leaf
317    Ok((FunctionKind::PureLeaf, params.to_vec()))
318}
319
320/// Check if a type is a reference to a type with the given name.
321fn is_reference_to(ty: &Type, name: &str) -> bool {
322    if let Type::Reference(type_ref) = ty {
323        if let Type::Path(type_path) = type_ref.elem.as_ref() {
324            if let Some(segment) = type_path.path.segments.last() {
325                return segment.ident == name;
326            }
327        }
328    }
329    false
330}
331
332/// Extract the T from `-> Result<T>` return type.
333fn extract_return_type(output: &syn::ReturnType) -> syn::Result<proc_macro2::TokenStream> {
334    match output {
335        syn::ReturnType::Default => Ok(quote! { () }),
336        syn::ReturnType::Type(_, ty) => {
337            // Try to extract T from Result<T>
338            if let Type::Path(type_path) = ty.as_ref() {
339                if let Some(segment) = type_path.path.segments.last() {
340                    if segment.ident == "Result" {
341                        if let syn::PathArguments::AngleBracketed(args) = &segment.arguments {
342                            if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() {
343                                return Ok(quote! { #inner_ty });
344                            }
345                        }
346                    }
347                }
348            }
349            // If we can't extract, use the full type
350            Ok(quote! { #ty })
351        }
352    }
353}
354
/// Converts a snake_case name to PascalCase.
///
/// The input is split on `_`; the first character of every non-empty piece
/// is upper-cased and the rest of the piece is kept as written. Empty pieces
/// (leading, trailing or doubled underscores) contribute nothing.
///
/// A leading raw-identifier prefix `r#` is dropped first, so `r#type`
/// becomes `Type` rather than `R#type`, which would not be usable as an
/// identifier.
fn to_pascal_case(s: &str) -> String {
    let s = s.strip_prefix("r#").unwrap_or(s);

    let mut pascal = String::with_capacity(s.len());
    for word in s.split('_') {
        let mut chars = word.chars();
        if let Some(first) = chars.next() {
            // Upper-casing a single char may yield more than one char.
            pascal.extend(first.to_uppercase());
            pascal.push_str(chars.as_str());
        }
    }
    pascal
}