1use proc_macro2::{Span, TokenStream};
2use quote::quote;
3use syn::{FnArg, ItemFn, Pat, Result, ReturnType, Type, parse_quote};
4
5use super::parse::{ModuleAttrs, OutputSpec};
6
7pub fn expand(attrs: ModuleAttrs, input_fn: ItemFn) -> Result<TokenStream> {
8 let fn_name = &input_fn.sig.ident;
9 let fn_vis = &input_fn.vis;
10 let fn_block = &input_fn.block;
11 let fn_attrs = &input_fn.attrs;
12
13 let params: Vec<_> = input_fn
15 .sig
16 .inputs
17 .iter()
18 .filter_map(|arg| {
19 if let FnArg::Typed(pat_type) = arg
20 && let Pat::Ident(pat_ident) = &*pat_type.pat
21 {
22 let name = pat_ident.ident.clone();
23 let ty = (*pat_type.ty).clone();
24 let attrs = pat_type.attrs.clone();
25 let pat = pat_type.pat.clone();
26 return Some((name, ty, attrs, pat));
27 }
28 None
29 })
30 .collect();
31
32 let return_type = extract_result_ok_type(&input_fn.sig.output)?;
34
35 let (output_struct, output_mapping, output_name) =
37 generate_output(&attrs.output, &return_type)?;
38
39 let call_args: Vec<_> = params
41 .iter()
42 .map(|(name, ty, _, _)| {
43 let name_str = name.to_string();
44 quote! { input.get_value::<#ty>(#name_str).ok_or_else(|| ::pyroduct::CapturedError::new(format!("Missing {}", #name_str)))? }
45 })
46 .collect();
47
48 let original_fn_params: Vec<_> = params
50 .iter()
51 .map(|(_, ty, attrs, pat)| quote! { #(#attrs)* #pat: #ty })
52 .collect();
53
54 let expanded = quote! {
55 #[unsafe(no_mangle)]
56 pub extern "C" fn call_extern(input_ptr: *mut u8) -> *const u8 {
57 #output_struct
58
59 let call = |input: ::pyroduct::PyroRow<'_>| {
60 #fn_name(#(#call_args),*).map(|result| {
61 #output_mapping
62 })
63 };
64
65 ::pyroduct::wasm::wasm_row_main::<#output_name, _>(input_ptr, call)
66 }
67
68 #(#fn_attrs)*
69 #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<#return_type>
70 #fn_block
71 };
72
73 Ok(expanded)
74}
75
76fn wrap_in_vec(ty: &Type) -> Type {
77 parse_quote!(Vec<#ty>)
78}
79
80pub fn expand_session(attrs: ModuleAttrs, input_fn: ItemFn) -> Result<TokenStream> {
81 let fn_name = &input_fn.sig.ident;
82 let fn_vis = &input_fn.vis;
83 let fn_block = &input_fn.block;
84 let fn_attrs = &input_fn.attrs;
85
86 let params: Vec<_> = input_fn
88 .sig
89 .inputs
90 .iter()
91 .filter_map(|arg| {
92 if let FnArg::Typed(pat_type) = arg
93 && let Pat::Ident(pat_ident) = &*pat_type.pat
94 {
95 let name = pat_ident.ident.clone();
96 let ty = (*pat_type.ty).clone();
97 let attrs = pat_type.attrs.clone();
98 let pat = pat_type.pat.clone();
99 return Some((name, ty, attrs, pat));
100 }
101 None
102 })
103 .collect();
104
105 let output_type = extract_session_inner_type(&input_fn.sig.output)?;
107 let output_vec = wrap_in_vec(&output_type);
108 let output_spec = attrs.output;
109 let (output_struct, output_mapping, output_name) = generate_output(&output_spec, &output_type)?;
110
111 let original_fn_params: Vec<_> = params
113 .iter()
114 .map(|(_, ty, attrs, pat)| quote! { #(#attrs)* #pat: #ty })
115 .collect();
116
117 let expanded = match params.len() {
119 2 => {
120 let input_vec = wrap_in_vec(¶ms[1].1);
121 if !(params[0].0 == "prior" || params[0].0 == "_prior") {
122 return Err(syn::Error::new(
123 params[0].0.span(),
124 "If 2 inputs, then the first parameter of session module must be named `prior`",
125 ));
126 }
127
128 if params[1].1 != output_type {
129 return Err(syn::Error::new(
130 params[1].0.span(),
131 "The type of the output must be the same as input and prior",
132 ));
133 }
134
135 if params[0].1 != input_vec || params[0].1 != output_vec {
136 return Err(syn::Error::new(
137 params[0].0.span(),
138 format!("The type of the prior must be type: {:?}", input_vec),
139 ));
140 }
141
142 quote! {
143 #[unsafe(no_mangle)]
144 pub extern "C" fn call_session_extern(session_id: u32) -> *const u8 {
145 #output_struct
146
147 let call = |prior: &[::pyroduct::PyroRow<'_>], input: ::pyroduct::PyroRow<'_>| {
148 let prior = prior.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#output_type>, _>>().map_err(|e| {
149 ::pyroduct::CapturedError::new("Unable to extract prior data")
150 .with_source(e)
151 })?;
152 let input = input.try_into().map_err(|e| {
153 ::pyroduct::CapturedError::new("Unable to extract input data")
154 .with_source(e)
155 })?;
156 #fn_name(prior, input).map(|result| {
157 match result {
158 ::pyroduct::session::SessionResponse::Continue(result) => {
159 ::pyroduct::session::SessionResponse::Continue(#output_mapping)
160 }
161 ::pyroduct::session::SessionResponse::End(result) => {
162 ::pyroduct::session::SessionResponse::End(#output_mapping)
163 }
164 ::pyroduct::session::SessionResponse::Terminate => {
165 ::pyroduct::session::SessionResponse::Terminate
166 }
167 }
168 })
169 };
170
171 ::pyroduct::wasm::wasm_row_main_session::<#output_name, _>(session_id, call)
172 }
173
174 #(#fn_attrs)*
175 #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<::pyroduct::session::SessionResponse<#output_type>>
176 #fn_block
177 }
178 }
179 3 => {
180 if !(params[0].0 == "prior_input" || params[0].0 == "_prior_input") {
181 return Err(syn::Error::new(
182 params[0].0.span(),
183 "If 3 inputs, then the first parameter of session module must be named `prior_input`",
184 ));
185 }
186 if !(params[1].0 == "prior_output" || params[1].0 == "_prior_output") {
187 return Err(syn::Error::new(
188 params[1].0.span(),
189 "If 3 inputs, then the second parameter of session module must be named `prior_output`",
190 ));
191 }
192 let input_type = ¶ms[2].1;
193 let input_vec = wrap_in_vec(¶ms[2].1);
194 if params[0].1 != input_vec {
195 return Err(syn::Error::new(
196 params[0].0.span(),
197 format!(
198 "First parameter of session module must have the type: {:?}",
199 input_vec
200 ),
201 ));
202 }
203 if params[1].1 != output_vec {
204 return Err(syn::Error::new(
205 params[1].0.span(),
206 format!(
207 "Second parameter of session module must have the type: {:?}",
208 output_type
209 ),
210 ));
211 }
212
213 quote! {
214 #[unsafe(no_mangle)]
215 pub extern "C" fn call_session_extern(session_id: u32) -> *const u8 {
216 #output_struct
217
218 let call = |prior_inputs: &[::pyroduct::PyroRow<'_>], prior_outputs: &[::pyroduct::PyroRow<'_>], input: ::pyroduct::PyroRow<'_>| {
219 let prior_inputs = prior_inputs.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#input_type>, _>>().map_err(|e| {
220 ::pyroduct::CapturedError::new("Unable to extract prior input data")
221 .with_source(e)
222 })?;
223 let prior_outputs = prior_outputs.iter().map(|p| p.clone().try_into()).collect::<Result<Vec<#output_type>, _>>().map_err(|e| {
224 ::pyroduct::CapturedError::new("Unable to extract prior output data")
225 .with_source(e)
226 })?;
227 let input = input.try_into().map_err(|e| {
228 ::pyroduct::CapturedError::new("Unable to extract input data")
229 .with_source(e)
230 })?;
231 #fn_name(prior_inputs, prior_outputs, input).map(|result| {
232 match result {
233 ::pyroduct::session::SessionResponse::Continue(result) => {
234 ::pyroduct::session::SessionResponse::Continue(#output_mapping)
235 }
236 ::pyroduct::session::SessionResponse::End(result) => {
237 ::pyroduct::session::SessionResponse::End(#output_mapping)
238 }
239 ::pyroduct::session::SessionResponse::Terminate => {
240 ::pyroduct::session::SessionResponse::Terminate
241 }
242 }
243 })
244 };
245
246 ::pyroduct::wasm::wasm_row_main_session_diff::<#output_name, _>(session_id, call)
247 }
248
249 #(#fn_attrs)*
250 #fn_vis fn #fn_name(#(#original_fn_params),*) -> ::pyroduct::wasm::ModuleResult<::pyroduct::session::SessionResponse<#output_type>>
251 #fn_block
252 }
253 }
254 _ => {
255 return Err(syn::Error::new(
256 Span::call_site(),
257 "Session module functions must have either 3 parameters (prior_input, prior_output, and input), or 2 parameters (prior, and input) with the same type for input and output",
258 ));
259 }
260 };
261
262 Ok(expanded)
263}
264
265fn extract_result_ok_type(ret: &ReturnType) -> Result<Type> {
267 match ret {
268 ReturnType::Default => Err(syn::Error::new(
269 Span::call_site(),
270 "Module function must return Result<T>",
271 )),
272 ReturnType::Type(_, ty) => {
273 if let Type::Path(type_path) = &**ty
274 && let Some(segment) = type_path.path.segments.last()
275 && segment.ident == "Result"
276 && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
277 && let Some(syn::GenericArgument::Type(ok_ty)) = args.args.first()
278 {
279 return Ok(ok_ty.clone());
280 }
281 Err(syn::Error::new(
282 Span::call_site(),
283 "Module function must return Result<T>",
284 ))
285 }
286 }
287}
288
289fn generate_output(
291 spec: &OutputSpec,
292 return_type: &Type,
293) -> Result<(TokenStream, TokenStream, Type)> {
294 match spec {
295 OutputSpec::SingleField(field_name) => {
297 let struct_def = quote! {
298 #[derive(::pyroduct::format::ToRow, ::pyroduct::format::Document)]
299 struct __Output {
300 #field_name: #return_type,
301 }
302 };
303
304 let mapping = quote! {
305 __Output {
306 #field_name: result,
307 }
308 };
309
310 let output_name = parse_quote!(__Output);
311
312 Ok((struct_def, mapping, output_name))
313 }
314
315 OutputSpec::TupleFields(field_names) => {
317 let tuple_types = extract_tuple_types(return_type)?;
319
320 if tuple_types.len() != field_names.len() {
321 return Err(syn::Error::new(
322 Span::call_site(),
323 format!(
324 "Output field count ({}) doesn't match tuple element count ({})",
325 field_names.len(),
326 tuple_types.len()
327 ),
328 ));
329 }
330
331 let field_defs: Vec<_> = field_names
332 .iter()
333 .zip(tuple_types.iter())
334 .map(|(name, ty)| quote! { #name: #ty })
335 .collect();
336
337 let struct_def = quote! {
338 #[derive(::pyroduct::format::ToRow, ::pyroduct::format::Document)]
339 struct __Output {
340 #(#field_defs,)*
341 }
342 };
343
344 let field_mappings: Vec<_> = field_names
345 .iter()
346 .enumerate()
347 .map(|(i, name)| {
348 let idx = syn::Index::from(i);
349 quote! { #name: result.#idx }
350 })
351 .collect();
352
353 let mapping = quote! {
354 __Output {
355 #(#field_mappings,)*
356 }
357 };
358
359 Ok((struct_def, mapping, parse_quote!(__Output)))
360 }
361
362 OutputSpec::Struct => Ok((quote! {}, quote! { result }, return_type.clone())),
364 }
365}
366
367fn extract_tuple_types(ty: &Type) -> Result<Vec<&Type>> {
369 if let Type::Tuple(tuple) = ty {
370 Ok(tuple.elems.iter().collect())
371 } else {
372 Err(syn::Error::new(
373 Span::call_site(),
374 "Expected tuple return type for multi-field output",
375 ))
376 }
377}
378
379fn extract_session_inner_type(ret: &ReturnType) -> Result<Type> {
381 match ret {
382 ReturnType::Default => Err(syn::Error::new(
383 Span::call_site(),
384 "Session module function must return Result<T>",
385 )),
386 ReturnType::Type(_, ty) => {
387 if let Type::Path(type_path) = &**ty
388 && let Some(segment) = type_path.path.segments.last()
389 && segment.ident == "Result"
390 && let syn::PathArguments::AngleBracketed(args) = &segment.arguments
391 && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first()
392 {
393 if let Type::Path(inner_path) = inner_ty
395 && let Some(seg) = inner_path.path.segments.last()
396 && seg.ident == "SessionResponse"
397 && let syn::PathArguments::AngleBracketed(inner_args) = &seg.arguments
398 && let Some(syn::GenericArgument::Type(output_ty)) = inner_args.args.first()
399 {
400 return Ok(output_ty.clone());
401 }
402 }
403 Err(syn::Error::new(
404 Span::call_site(),
405 "Session module must return Result<SessionResponse<T>>",
406 ))
407 }
408 }
409}