// pyro_macro/module/codegen.rs

1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{FnArg, ItemFn, Pat, Result, ReturnType, Type, parse_quote};
4
5use super::parse::{ModuleAttrs, OutputSpec};
6
7pub fn expand(attrs: ModuleAttrs, input_fn: ItemFn) -> Result<TokenStream> {
8    let fn_name = &input_fn.sig.ident;
9    let fn_vis = &input_fn.vis;
10    let fn_block = &input_fn.block;
11    let fn_attrs = &input_fn.attrs;
12
13    // Extract parameters
14    let params: Vec<_> = input_fn
15        .sig
16        .inputs
17        .iter()
18        .filter_map(|arg| {
19            if let FnArg::Typed(pat_type) = arg
20                && let Pat::Ident(pat_ident) = &*pat_type.pat
21            {
22                let name = pat_ident.ident.clone();
23                let ty = (*pat_type.ty).clone();
24                let attrs = pat_type.attrs.clone();
25                let pat = pat_type.pat.clone();
26                return Some((name, ty, attrs, pat));
27            }
28            None
29        })
30        .collect();
31
32    // Validate module requirements and extract types
33    let return_type = extract_result_ok_type(&input_fn.sig.output)?;
34
35    // Generate __Output struct and mapping based on output spec
36    let (output_struct, output_mapping, output_name) =
37        generate_output(&attrs.output, &return_type)?;
38
39    // Generate the call arguments (extract from input struct)
40    let call_args: Vec<_> = params
41        .iter()
42        .map(|(name, ty, _, _)| {
43            let name_str = name.to_string();
44            quote! { input.get_value::<#ty>(#name_str).ok_or_else(|| ::pyroduct::CapturedError::new(format!("Missing {}", #name_str)))? }
45        })
46        .collect();
47
48    // Generate the original function parameters
49    let original_fn_params: Vec<_> = params
50        .iter()
51        .map(|(_, ty, attrs, pat)| quote! { #(#attrs)* #pat: #ty })
52        .collect();
53
54    let expanded = quote! {
55        #[unsafe(no_mangle)]
56        pub extern "C" fn call_extern(input_ptr: *mut u8) -> *const u8 {
57            #output_struct
58
59            let call = |input: ::pyroduct::PyroRow<'_>| {
60                #fn_name(#(#call_args),*).map(|result| {
61                    #output_mapping
62                })
63            };
64
65            ::pyroduct::wasm::wasm_row_main::<#output_name, _>(input_ptr, call)
66        }
67
68        #(#fn_attrs)*
69        #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<#return_type>
70        #fn_block
71    };
72
73    Ok(expanded)
74}
75
76fn wrap_in_vec(ty: &Type) -> Type {
77    parse_quote!(Vec<#ty>)
78}
79
80pub fn expand_session(attrs: ModuleAttrs, input_fn: ItemFn) -> Result<TokenStream> {
81    let fn_name = &input_fn.sig.ident;
82    let fn_vis = &input_fn.vis;
83    let fn_block = &input_fn.block;
84    let fn_attrs = &input_fn.attrs;
85
86    // Extract parameters
87    let params: Vec<_> = input_fn
88        .sig
89        .inputs
90        .iter()
91        .filter_map(|arg| {
92            if let FnArg::Typed(pat_type) = arg
93                && let Pat::Ident(pat_ident) = &*pat_type.pat
94            {
95                let name = pat_ident.ident.clone();
96                let ty = (*pat_type.ty).clone();
97                let attrs = pat_type.attrs.clone();
98                let pat = pat_type.pat.clone();
99                return Some((name, ty, attrs, pat));
100            }
101            None
102        })
103        .collect();
104
105    // Generate __Output struct and mapping based on output spec
106    let output_type = extract_session_inner_type(&input_fn.sig.output)?;
107    let output_vec = wrap_in_vec(&output_type);
108    let output_spec = attrs.output;
109    let (output_struct, output_mapping, output_name) = generate_output(&output_spec, &output_type)?;
110
111    // Generate the original function parameters
112    let original_fn_params: Vec<_> = params
113        .iter()
114        .map(|(_, ty, attrs, pat)| quote! { #(#attrs)* #pat: #ty })
115        .collect();
116
117    // Validate session module requirements and extract types
118    let expanded = match params.len() {
119        2 => {
120            let input_vec = wrap_in_vec(&params[1].1);
121            if !(params[0].0 == "prior" || params[0].0 == "_prior") {
122                return Err(syn::Error::new(
123                    params[0].0.span(),
124                    "If 2 inputs, then the first parameter of session module must be named `prior`",
125                ));
126            }
127
128            if params[1].1 != output_type {
129                return Err(syn::Error::new(
130                    params[1].0.span(),
131                    "The type of the output must be the same as input and prior",
132                ));
133            }
134
135            if params[0].1 != input_vec || params[0].1 != output_vec {
136                return Err(syn::Error::new(
137                    params[0].0.span(),
138                    format!("The type of the prior must be type: {:?}", input_vec),
139                ));
140            }
141
142            quote! {
143                #[unsafe(no_mangle)]
144                pub extern "C" fn call_session_extern(session_id: u32) -> *const u8 {
145                    #output_struct
146
147                    let call = |prior: &[::pyroduct::PyroRow<'_>], input: ::pyroduct::PyroRow<'_>| {
148                        let prior = prior.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#output_type>, _>>().map_err(|e| {
149                            ::pyroduct::CapturedError::new("Unable to extract prior data")
150                                .with_source(e)
151                        })?;
152                        let input = input.try_into().map_err(|e| {
153                            ::pyroduct::CapturedError::new("Unable to extract input data")
154                                .with_source(e)
155                        })?;
156                        #fn_name(prior, input).map(|result| {
157                            match result {
158                                ::pyroduct::session::SessionResponse::Continue(result) => {
159                                    ::pyroduct::session::SessionResponse::Continue(#output_mapping)
160                                }
161                                ::pyroduct::session::SessionResponse::End(result) => {
162                                    ::pyroduct::session::SessionResponse::End(#output_mapping)
163                                }
164                                ::pyroduct::session::SessionResponse::Terminate => {
165                                    ::pyroduct::session::SessionResponse::Terminate
166                                }
167                            }
168                        })
169                    };
170
171                    ::pyroduct::wasm::wasm_row_main_session::<#output_name, _>(session_id, call)
172                }
173
174                #(#fn_attrs)*
175                #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<::pyroduct::session::SessionResponse<#output_type>>
176                #fn_block
177            }
178        }
179        3 => {
180            if !(params[0].0 == "prior_input" || params[0].0 == "_prior_input") {
181                return Err(syn::Error::new(
182                    params[0].0.span(),
183                    "If 3 inputs, then the first parameter of session module must be named `prior_input`",
184                ));
185            }
186            if !(params[1].0 == "prior_output" || params[1].0 == "_prior_output") {
187                return Err(syn::Error::new(
188                    params[1].0.span(),
189                    "If 3 inputs, then the second parameter of session module must be named `prior_output`",
190                ));
191            }
192            let input_type = &params[2].1;
193            let input_vec = wrap_in_vec(&params[2].1);
194            if params[0].1 != input_vec {
195                return Err(syn::Error::new(
196                    params[0].0.span(),
197                    format!(
198                        "First parameter of session module must have the type: {:?}",
199                        input_vec
200                    ),
201                ));
202            }
203            if params[1].1 != output_vec {
204                return Err(syn::Error::new(
205                    params[1].0.span(),
206                    format!(
207                        "Second parameter of session module must have the type: {:?}",
208                        output_type
209                    ),
210                ));
211            }
212
213            quote! {
214                #[unsafe(no_mangle)]
215                pub extern "C" fn call_session_extern(session_id: u32) -> *const u8 {
216                    #output_struct
217
218                    let call = |prior_inputs: &[::pyroduct::PyroRow<'_>], prior_outputs: &[::pyroduct::PyroRow<'_>], input: ::pyroduct::PyroRow<'_>| {
219                        let prior_inputs = prior_inputs.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#input_type>, _>>().map_err(|e| {
220                            ::pyroduct::CapturedError::new("Unable to extract prior input data")
221                                .with_source(e)
222                        })?;
223                        let prior_outputs = prior_outputs.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#output_type>, _>>().map_err(|e| {
224                            ::pyroduct::CapturedError::new("Unable to extract prior output data")
225                                .with_source(e)
226                        })?;
227                        let input = input.try_into().map_err(|e| {
228                            ::pyroduct::CapturedError::new("Unable to extract input data")
229                                .with_source(e)
230                        })?;
231                        #fn_name(prior_inputs, prior_outputs, input).map(|result| {
232                            match result {
233                                ::pyroduct::session::SessionResponse::Continue(result) => {
234                                    ::pyroduct::session::SessionResponse::Continue(#output_mapping)
235                                }
236                                ::pyroduct::session::SessionResponse::End(result) => {
237                                    ::pyroduct::session::SessionResponse::End(#output_mapping)
238                                }
239                                ::pyroduct::session::SessionResponse::Terminate => {
240                                    ::pyroduct::session::SessionResponse::Terminate
241                                }
242                            }
243                        })
244                    };
245
246                    ::pyroduct::wasm::wasm_row_main_session_diff::<#output_name, _>(session_id, call)
247                }
248
249                #(#fn_attrs)*
250                #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<::pyroduct::session::SessionResponse<#output_type>>
251                #fn_block
252            }
253        }
254        _ => {
255            return Err(syn::Error::new(
256                Span::call_site(),
257                "Session module functions must have either 3 parameters (prior_input, prior_output, and input), or 2 parameters (prior, and input) with the same type for input and output",
258            ));
259        }
260    };
261
262    Ok(expanded)
263}
264
265/// Extract the Ok type from Result<T, E>
266fn extract_result_ok_type(ret: &ReturnType) -> Result<Type> {
267    match ret {
268        ReturnType::Default => Err(syn::Error::new(
269            Span::call_site(),
270            "Module function must return Result<T>",
271        )),
272        ReturnType::Type(_, ty) => {
273            if let Type::Path(type_path) = &**ty
274                && let Some(segment) = type_path.path.segments.last()
275                && segment.ident == "Result"
276                && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
277                && let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first()
278            {
279                return Ok(ok_ty.clone());
280            }
281            Err(syn::Error::new(
282                Span::call_site(),
283                "Module function must return Result<T>",
284            ))
285        }
286    }
287}
288
289/// Generate the __Output struct and the mapping expression
290fn generate_output(
291    spec: &OutputSpec,
292    return_type: &Type,
293) -> Result<(TokenStream, TokenStream, Type)> {
294    match spec {
295        // Pattern 1: Single named field
296        OutputSpec::SingleField(field_name) => {
297            let struct_def = quote! {
298                #[derive(::pyroduct::format::ToRow, ::pyroduct::format::Document)]
299                struct __Output {
300                    #field_name: #return_type,
301                }
302            };
303
304            let mapping = quote! {
305                __Output {
306                    #field_name: result,
307                }
308            };
309
310            let output_name = parse_quote!(__Output);
311
312            Ok((struct_def, mapping, output_name))
313        }
314
315        // Pattern 2: Tuple with named fields
316        OutputSpec::TupleFields(field_names) => {
317            // Extract tuple element types from return_type
318            let tuple_types = extract_tuple_types(return_type)?;
319
320            if tuple_types.len() != field_names.len() {
321                return Err(syn::Error::new(
322                    Span::call_site(),
323                    format!(
324                        "Output field count ({}) doesn't match tuple element count ({})",
325                        field_names.len(),
326                        tuple_types.len()
327                    ),
328                ));
329            }
330
331            let field_defs: Vec<_> = field_names
332                .iter()
333                .zip(tuple_types.iter())
334                .map(|(name, ty)| quote! { #name: #ty })
335                .collect();
336
337            let struct_def = quote! {
338                #[derive(::pyroduct::format::ToRow, ::pyroduct::format::Document)]
339                struct __Output {
340                    #(#field_defs,)*
341                }
342            };
343
344            let field_mappings: Vec<_> = field_names
345                .iter()
346                .enumerate()
347                .map(|(i, name)| {
348                    let idx = syn::Index::from(i);
349                    quote! { #name: result.#idx }
350                })
351                .collect();
352
353            let mapping = quote! {
354                __Output {
355                    #(#field_mappings,)*
356                }
357            };
358
359            Ok((struct_def, mapping, parse_quote!(__Output)))
360        }
361
362        // Pattern 3: Existing struct that implements ToRow
363        OutputSpec::Struct => Ok((quote! {}, quote! { result }, return_type.clone())),
364    }
365}
366
367/// Extract element types from a tuple type
368fn extract_tuple_types(ty: &Type) -> Result<Vec<&Type>> {
369    if let Type::Tuple(tuple) = ty {
370        Ok(tuple.elems.iter().collect())
371    } else {
372        Err(syn::Error::new(
373            Span::call_site(),
374            "Expected tuple return type for multi-field output",
375        ))
376    }
377}
378
379/// Extract the inner type T from Result<SessionResponse<T>>
380fn extract_session_inner_type(ret: &ReturnType) -> Result<Type> {
381    match ret {
382        ReturnType::Default => Err(syn::Error::new(
383            Span::call_site(),
384            "Session module function must return Result<T>",
385        )),
386        ReturnType::Type(_, ty) => {
387            if let Type::Path(type_path) = &**ty
388                && let Some(segment) = type_path.path.segments.last()
389                && segment.ident == "Result"
390                && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
391                && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
392            {
393                // inner_ty should be SessionResponse<T>
394                if let Type::Path(inner_path) = inner_ty
395                    && let Some(seg) = inner_path.path.segments.last()
396                    && seg.ident == "SessionResponse"
397                    && let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments
398                    && let Some(syn::GenericArgument::Type(output_ty)) = inner_args.args.first()
399                {
400                    return Ok(output_ty.clone());
401                }
402            }
403            Err(syn::Error::new(
404                Span::call_site(),
405                "Session module must return Result<SessionResponse<T>>",
406            ))
407        }
408    }
409}