Skip to main content

hetero_cartesian/
lib.rs

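// hetero-cartesian/src/lib.rs
//
// Proc-macro implementation of `#[cartesian_fn]`.
//
// Syntax:
//
//   #[cartesian_fn(
//       func(args) => HandlerTrait::method<CallGenerics>(param: Type);
//       func(args) => HandlerTrait::method              (param: Type);  // no call-generics
//       func(args) => HandlerTrait                      (param: Type);  // method defaults to `call`
//       ...
//   )]
//   fn name<OuterGenerics>(env_param1: T1, env_param2: T2, ...) {
//       // body — executed once for each tuple of the cartesian product
//   }
//
// Every layer entry ends with `;`. Each layer's `func` is called with a generated handler in
// the `_` slot (passed as `&mut handler`); that handler implements `HandlerTrait::method`, and
// each layer's code runs inside the previous layer's handler method. When the handler path has
// two or more segments, the last segment is taken as the method name, so a module-qualified
// trait must spell out its method explicitly (e.g. `my_mod::Trait::call`).
//
// Inside the body, an original parameter `p: P` is visible as `&mut U` when `P = &mut U`, as
// `&U` when `P = &U`, and otherwise as a fresh `.clone()` of the value (so changes to by-value
// parameters do not carry over from one tuple to the next).
//
// Non-hole args in the attribute (`_` is the hole):
//   name: &mut T   — passed as an extra fn param; for layers after the first it is captured as a
//                    thin raw pointer and cast back to `&mut T` (so there `T` must be `Sized`)
//   name: &T       — same, but stores the address of the `&T` binding (fat-pointer safe)
//   name: T        — passed as an extra fn param; for layers after the first it is stored in a
//                    typed struct field and recovered via `.clone()` (requires `T: Clone`)
//
// The macro appends one extra function parameter per named non-hole arg, in layer-declaration
// order, so these names must be unique across all layers. Callers supply all of these
// arguments at the call site — no special macros are needed there.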
25
26use proc_macro::TokenStream;
27use proc_macro2::TokenStream as TokenStream2;
28use quote::{format_ident, quote};
29use syn::{
30    Block, FnArg, GenericParam, Generics, Ident, ItemFn, Pat, Path, Result, Token, Type,
31    parenthesized,
32    parse::{Parse, ParseStream},
33    parse_macro_input,
34};
35
36// ─── AST types ───────────────────────────────────────────────────────────────
37
38/// One argument in `#[cartesian_fn]`: either `_` (hole) or `name: Type` (named, typed).
39enum AttrFuncArg {
40    Hole,
41    Named(Ident, Type),
42}
43
44/// One layer: `func(args) => HandlerTrait::method<call_generics>(param: Type)`
45struct AttrLayer {
46    func: Ident,
47    func_args: Vec<AttrFuncArg>,
48    handler: Path,
49    method: Ident,
50    call_generics: Vec<GenericParam>,
51    param: Ident,
52    param_ty: Type,
53}
54
55/// The full `#[cartesian_fn(...)]` attribute body.
56struct CartesianAttrInput {
57    layers: Vec<AttrLayer>,
58}
59
60/// Synthetic env capture derived from the fn's original parameter list.
61struct EnvCapture {
62    pat: Pat,
63    ty: Type,
64}
65
66// ─── Parsing ─────────────────────────────────────────────────────────────────
67
68fn parse_attr_func_arg(input: ParseStream) -> Result<AttrFuncArg> {
69    if input.peek(Token![_]) {
70        input.parse::<Token![_]>()?;
71        Ok(AttrFuncArg::Hole)
72    } else {
73        let name: Ident = input.parse()?;
74        input.parse::<Token![:]>()?;
75        let ty: Type = input.parse()?;
76        Ok(AttrFuncArg::Named(name, ty))
77    }
78}
79
80impl Parse for CartesianAttrInput {
81    fn parse(input: ParseStream) -> Result<Self> {
82        let mut layers = vec![];
83        while !input.is_empty() {
84            let func: Ident = input.parse()?;
85
86            let args_buf;
87            parenthesized!(args_buf in input);
88            let mut func_args = vec![];
89            loop {
90                if args_buf.is_empty() {
91                    break;
92                }
93                func_args.push(parse_attr_func_arg(&args_buf)?);
94                if args_buf.peek(Token![,]) {
95                    args_buf.parse::<Token![,]>()?;
96                } else {
97                    break;
98                }
99            }
100
101            input.parse::<Token![=>]>()?;
102
103            let mut handler: Path = input.call(Path::parse_mod_style)?;
104            let method: Ident = if handler.segments.len() > 1 {
105                let seg = handler.segments.pop().unwrap().into_value();
106                handler.segments.pop_punct();
107                seg.ident
108            } else {
109                format_ident!("call")
110            };
111
112            let call_generics = if input.peek(Token![<]) {
113                let generics: Generics = input.parse()?;
114                generics.params.into_iter().collect()
115            } else {
116                vec![]
117            };
118
119            let param_buf;
120            parenthesized!(param_buf in input);
121            let param: Ident = param_buf.parse()?;
122            param_buf.parse::<Token![:]>()?;
123            let param_ty: Type = param_buf.parse()?;
124
125            input.parse::<Token![;]>()?;
126
127            layers.push(AttrLayer {
128                func,
129                func_args,
130                handler,
131                method,
132                call_generics,
133                param,
134                param_ty,
135            });
136        }
137        Ok(CartesianAttrInput { layers })
138    }
139}
140
141// ─── Code generation helpers ─────────────────────────────────────────────────
142
143/// Turn generic *declarations* into the corresponding *arguments* (just the idents).
144fn params_to_args(params: &[&GenericParam]) -> Vec<TokenStream2> {
145    params
146        .iter()
147        .map(|p| match p {
148            GenericParam::Type(t) => {
149                let id = &t.ident;
150                quote! { #id }
151            }
152            GenericParam::Const(c) => {
153                let id = &c.ident;
154                quote! { #id }
155            }
156            GenericParam::Lifetime(l) => {
157                let lt = &l.lifetime;
158                quote! { #lt }
159            }
160        })
161        .collect()
162}
163
164/// Build the PhantomData type that references the *outer* generics only.
165fn phantom_type(outer_generics: &[GenericParam]) -> TokenStream2 {
166    let tys: Vec<TokenStream2> = outer_generics
167        .iter()
168        .filter_map(|p| match p {
169            GenericParam::Type(t) => {
170                let id = &t.ident;
171                Some(quote! { #id })
172            }
173            GenericParam::Lifetime(l) => {
174                let lt = &l.lifetime;
175                Some(quote! { &#lt () })
176            }
177            GenericParam::Const(_) => None,
178        })
179        .collect();
180    quote! { (#(#tys,)*) }
181}
182
183/// Walk a pattern and collect every bound identifier (skips wildcards).
184fn pat_idents(pat: &Pat) -> Vec<Ident> {
185    match pat {
186        Pat::Ident(p) if p.ident != "_" => vec![p.ident.clone()],
187        Pat::Tuple(p) => p.elems.iter().flat_map(pat_idents).collect(),
188        Pat::Wild(_) => vec![],
189        Pat::Reference(r) => pat_idents(&r.pat),
190        _ => vec![],
191    }
192}
193
194/// The three shadow_env traits used to coerce `&mut FieldType` → the right reference kind.
195fn shadow_env_traits() -> TokenStream2 {
196    quote! {
197        #[allow(dead_code)]
198        struct __CartesianWrap<T>(T);
199
200        #[allow(dead_code)]
201        trait __ShadowMutMut { type Out; fn shadow_env(self) -> Self::Out; }
202        impl<'__a, '__b, T: ?Sized> __ShadowMutMut for __CartesianWrap<&'__a mut &'__b mut T> {
203            type Out = &'__a mut T;
204            #[inline(always)] fn shadow_env(self) -> Self::Out { self.0 }
205        }
206
207        #[allow(dead_code)]
208        trait __ShadowMutRef { type Out; fn shadow_env(self) -> Self::Out; }
209        impl<'__a, '__b, T: ?Sized> __ShadowMutRef for __CartesianWrap<&'__a mut &'__b T> {
210            type Out = &'__b T;
211            #[inline(always)] fn shadow_env(self) -> Self::Out { *self.0 }
212        }
213
214        #[allow(dead_code)]
215        trait __ShadowVal { type Out; fn shadow_env(self) -> Self::Out; }
216        impl<'__a, T: ::core::clone::Clone> __ShadowVal for &__CartesianWrap<&'__a mut T> {
217            type Out = T;
218            #[inline(always)] fn shadow_env(self) -> Self::Out { self.0.clone() }
219        }
220    }
221}
222
223// ─── Func-arg capture helpers ─────────────────────────────────────────────────
224
225/// Classification of a non-hole `AttrFuncArg` for the capture strategy.
226enum ArgCaptureTyped<'a> {
227    /// `name: &mut T` — store `name as *mut T as *mut ()`, recover as `&mut T`.
228    MutRef(&'a Ident, &'a Type),
229    /// `name: &T` — store `(&name) as *const _ as *mut ()` (binding addr), recover as `&T`.
230    SharedRef(&'a Ident, &'a Type),
231    /// `name: T` — store as typed struct field, recover via `.clone()`.
232    Value(&'a Ident, &'a Type),
233}
234
235/// Returns (non_hole_idx, ArgCaptureTyped) for every Named arg of `layer`.
236fn capturable_args_fn(layer: &AttrLayer) -> Vec<(usize, ArgCaptureTyped<'_>)> {
237    let mut result = Vec::new();
238    let mut nh = 0usize;
239    for arg in &layer.func_args {
240        match arg {
241            AttrFuncArg::Hole => {}
242            AttrFuncArg::Named(name, ty) => {
243                let cap = match ty {
244                    Type::Reference(r) if r.mutability.is_some() => {
245                        ArgCaptureTyped::MutRef(name, &*r.elem)
246                    }
247                    Type::Reference(r) => ArgCaptureTyped::SharedRef(name, &*r.elem),
248                    _ => ArgCaptureTyped::Value(name, ty),
249                };
250                result.push((nh, cap));
251                nh += 1;
252            }
253        }
254    }
255    result
256}
257
258// ─── Innermost body generation ────────────────────────────────────────────────
259
260/// Unpack the env pointer back into typed references, then run the user's body.
261fn gen_body_with_env(env: Option<&EnvCapture>, body: &Block) -> TokenStream2 {
262    let Some(env) = env else {
263        return quote! { #body };
264    };
265
266    let env_ty = &env.ty;
267    let env_pat = &env.pat;
268    let vars = pat_idents(env_pat);
269
270    let traits = shadow_env_traits();
271
272    let unpack = if vars.is_empty() {
273        quote! {}
274    } else if vars.len() == 1 {
275        quote! {
276            let __cartesian_env_ref = self.__env as *mut #env_ty;
277            #[allow(unused_variables)]
278            let #env_pat = unsafe { &mut *__cartesian_env_ref };
279            #[allow(unused_variables)]
280            let #env_pat = __CartesianWrap(#env_pat).shadow_env();
281        }
282    } else {
283        let shadow_calls: Vec<_> = vars
284            .iter()
285            .map(|v| quote! { __CartesianWrap(#v).shadow_env() })
286            .collect();
287        quote! {
288            let __cartesian_env_ref = self.__env as *mut #env_ty;
289            #[allow(unused_variables)]
290            let #env_pat = unsafe { &mut *__cartesian_env_ref };
291            #[allow(unused_variables)]
292            let (#(#vars,)*) = (#(#shadow_calls,)*);
293        }
294    };
295
296    quote! { #traits #unpack #body }
297}
298
299// ─── Core recursive generator ─────────────────────────────────────────────────
300
301struct CtxFn<'a> {
302    layers: &'a [AttrLayer],
303    outer_generics: &'a [GenericParam],
304    env_capture: Option<&'a EnvCapture>,
305    fn_body: &'a Block,
306    depth: usize,
307    acc_call_generics: Vec<GenericParam>,
308    /// For each outer layer: (struct_field_ident, user_param_ident, param_type).
309    captured: Vec<(Ident, Ident, Type)>,
310    /// Expression yielding the `*mut ()` env pointer for the current struct init.
311    env_ptr: TokenStream2,
312}
313
314fn gen_layer_fn(ctx: &CtxFn) -> TokenStream2 {
315    let depth = ctx.depth;
316    let layer = &ctx.layers[depth];
317    let struct_name = format_ident!("__CartesianL{}", depth);
318
319    // ── Generics ─────────────────────────────────────────────────────────────
320    let outer_g = ctx.outer_generics;
321    let all_g: Vec<&GenericParam> = outer_g.iter().chain(ctx.acc_call_generics.iter()).collect();
322    let all_g_args = params_to_args(&all_g);
323    let phantom = phantom_type(outer_g);
324
325    // ── Struct field definitions ──────────────────────────────────────────────
326    let mut field_defs: Vec<TokenStream2> = ctx
327        .captured
328        .iter()
329        .map(|(f, _, ty)| quote! { #f: #ty })
330        .collect();
331
332    for l in (depth + 1)..ctx.layers.len() {
333        for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
334            match cap {
335                ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
336                    let f = format_ident!("__l{}_a{}", l, i);
337                    field_defs.push(quote! { #f: *mut () });
338                }
339                ArgCaptureTyped::Value(_, ty) => {
340                    let f = format_ident!("__l{}_v{}", l, i);
341                    field_defs.push(quote! { #f: #ty });
342                }
343            }
344        }
345    }
346
347    let struct_def = if all_g.is_empty() {
348        quote! {
349            #[allow(non_local_definitions)]
350            struct #struct_name {
351                __env:    *mut (),
352                __marker: ::core::marker::PhantomData<#phantom>,
353                #(#field_defs,)*
354            }
355        }
356    } else {
357        quote! {
358            #[allow(non_local_definitions)]
359            struct #struct_name<#(#all_g),*> {
360                __env:    *mut (),
361                __marker: ::core::marker::PhantomData<#phantom>,
362                #(#field_defs,)*
363            }
364        }
365    };
366
367    // ── Next-layer state ──────────────────────────────────────────────────────
368    let handler = &layer.handler;
369    let method = &layer.method;
370    let call_generics = &layer.call_generics;
371    let param = &layer.param;
372    let param_ty = &layer.param_ty;
373
374    let field_name = format_ident!("__cartesian_p{}", depth);
375
376    let mut new_captured = ctx.captured.clone();
377    new_captured.push((field_name.clone(), param.clone(), param_ty.clone()));
378
379    let mut new_acc_generics = ctx.acc_call_generics.clone();
380    new_acc_generics.extend(call_generics.iter().cloned());
381
382    // ── call() body ───────────────────────────────────────────────────────────
383    let clone_stmts: Vec<_> = ctx
384        .captured
385        .iter()
386        .map(|(f, name, _)| quote! { let #name = self.#f.clone(); })
387        .collect();
388
389    let call_body = if depth + 1 == ctx.layers.len() {
390        let body_code = gen_body_with_env(ctx.env_capture, ctx.fn_body);
391        quote! { #(#clone_stmts)* #body_code }
392    } else {
393        let next_l = depth + 1;
394        let recovery_stmts: Vec<TokenStream2> = capturable_args_fn(&ctx.layers[next_l])
395            .into_iter()
396            .map(|(i, cap)| match cap {
397                ArgCaptureTyped::MutRef(_, inner_ty) => {
398                    let field = format_ident!("__l{}_a{}", next_l, i);
399                    let local = format_ident!("__l{}_a{}_local", next_l, i);
400                    quote! { let #local = unsafe { &mut *(self.#field as *mut #inner_ty) }; }
401                }
402                ArgCaptureTyped::SharedRef(_, inner_ty) => {
403                    let field = format_ident!("__l{}_a{}", next_l, i);
404                    let local = format_ident!("__l{}_a{}_local", next_l, i);
405                    // Dereferences *const &T → copies the &T value.
406                    quote! { let #local = unsafe { *(self.#field as *const &#inner_ty) }; }
407                }
408                ArgCaptureTyped::Value(_, _) => {
409                    let field = format_ident!("__l{}_v{}", next_l, i);
410                    let local = format_ident!("__l{}_v{}_local", next_l, i);
411                    quote! { let #local = self.#field.clone(); }
412                }
413            })
414            .collect();
415
416        let next = gen_layer_fn(&CtxFn {
417            layers: ctx.layers,
418            outer_generics: ctx.outer_generics,
419            env_capture: ctx.env_capture,
420            fn_body: ctx.fn_body,
421            depth: depth + 1,
422            acc_call_generics: new_acc_generics,
423            captured: new_captured.clone(),
424            env_ptr: quote! { self.__env },
425        });
426        quote! { #(#clone_stmts)* #(#recovery_stmts)* #next }
427    };
428
429    // ── impl block ────────────────────────────────────────────────────────────
430    let call_generic_decl = if call_generics.is_empty() {
431        quote! {}
432    } else {
433        quote! { <#(#call_generics),*> }
434    };
435
436    let impl_block = if all_g.is_empty() {
437        quote! {
438            #[allow(non_local_definitions)]
439            impl #handler for #struct_name {
440                fn #method #call_generic_decl (&mut self, #param: #param_ty) {
441                    #call_body
442                }
443            }
444        }
445    } else {
446        quote! {
447            #[allow(non_local_definitions)]
448            impl<#(#all_g),*> #handler for #struct_name<#(#all_g_args),*> {
449                fn #method #call_generic_decl (&mut self, #param: #param_ty) {
450                    #call_body
451                }
452            }
453        }
454    };
455
456    // ── Struct initializer + function call ────────────────────────────────────
457    let env_ptr = &ctx.env_ptr;
458    let captured_init: Vec<_> = ctx
459        .captured
460        .iter()
461        .map(|(f, name, _)| quote! { #f: #name })
462        .collect();
463
464    let handler_binding = format_ident!("__cartesian_handler_{}", depth);
465
466    let mut ptr_field_inits: Vec<TokenStream2> = Vec::new();
467    for l in (depth + 1)..ctx.layers.len() {
468        for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
469            match cap {
470                ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
471                    let f = format_ident!("__l{}_a{}", l, i);
472                    if depth == 0 {
473                        ptr_field_inits.push(quote! { #f: #f });
474                    } else {
475                        ptr_field_inits.push(quote! { #f: self.#f });
476                    }
477                }
478                ArgCaptureTyped::Value(name, _) => {
479                    let f = format_ident!("__l{}_v{}", l, i);
480                    if depth == 0 {
481                        ptr_field_inits.push(quote! { #f: #name });
482                    } else {
483                        ptr_field_inits.push(quote! { #f: self.#f.clone() });
484                    }
485                }
486            }
487        }
488    }
489
490    let handler_init = if all_g.is_empty() {
491        quote! {
492            let mut #handler_binding = #struct_name {
493                __env:    #env_ptr,
494                __marker: ::core::marker::PhantomData,
495                #(#captured_init,)*
496                #(#ptr_field_inits,)*
497            };
498        }
499    } else {
500        quote! {
501            let mut #handler_binding: #struct_name<#(#all_g_args),*> = #struct_name {
502                __env:    #env_ptr,
503                __marker: ::core::marker::PhantomData,
504                #(#captured_init,)*
505                #(#ptr_field_inits,)*
506            };
507        }
508    };
509
510    let func = &layer.func;
511    let caps = capturable_args_fn(layer);
512    let func_args: Vec<_> = {
513        let mut cap_iter = caps.iter();
514        layer
515            .func_args
516            .iter()
517            .map(|arg| match arg {
518                AttrFuncArg::Hole => quote! { &mut #handler_binding },
519                AttrFuncArg::Named(name, _) => {
520                    let (nh, cap) = cap_iter.next().unwrap();
521                    if depth > 0 {
522                        match cap {
523                            ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
524                                let local = format_ident!("__l{}_a{}_local", depth, nh);
525                                quote! { #local }
526                            }
527                            ArgCaptureTyped::Value(_, _) => {
528                                let local = format_ident!("__l{}_v{}_local", depth, nh);
529                                quote! { #local }
530                            }
531                        }
532                    } else {
533                        quote! { #name }
534                    }
535                }
536            })
537            .collect()
538    };
539
540    quote! {
541        #struct_def
542        #impl_block
543        #handler_init
544        #func(#(#func_args),*)
545    }
546}
547
548// ─── Entry point ──────────────────────────────────────────────────────────────
549
550#[proc_macro_attribute]
551pub fn cartesian_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
552    let parsed_attr = parse_macro_input!(attr as CartesianAttrInput);
553    let mut parsed_fn = parse_macro_input!(item as ItemFn);
554
555    if parsed_attr.layers.is_empty() {
556        return quote! { compile_error!("cartesian_fn requires at least one layer") }.into();
557    }
558
559    // Outer generics from the fn signature.
560    let outer_generics: Vec<GenericParam> = parsed_fn.sig.generics.params.iter().cloned().collect();
561
562    // Original fn params → the "env" accessible in the body.
563    let env_params: Vec<(Ident, Type)> = parsed_fn
564        .sig
565        .inputs
566        .iter()
567        .filter_map(|arg| {
568            if let FnArg::Typed(pt) = arg {
569                if let Pat::Ident(pi) = &*pt.pat {
570                    Some((pi.ident.clone(), (*pt.ty).clone()))
571                } else {
572                    None
573                }
574            } else {
575                None
576            }
577        })
578        .collect();
579
580    // Append one extra fn param per non-hole named arg, across all layers in declaration order.
581    for layer in &parsed_attr.layers {
582        for arg in &layer.func_args {
583            if let AttrFuncArg::Named(name, ty) = arg {
584                parsed_fn.sig.inputs.push(syn::parse_quote! { #name: #ty });
585            }
586        }
587    }
588
589    // Build synthetic EnvCapture from the original fn params.
590    let env_capture: Option<EnvCapture> = match env_params.len() {
591        0 => None,
592        1 => {
593            let (name, ty) = &env_params[0];
594            let pat: Pat = syn::parse_quote! { #name };
595            Some(EnvCapture {
596                pat,
597                ty: ty.clone(),
598            })
599        }
600        _ => {
601            let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
602            let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
603            let pat: Pat = syn::parse_quote! { (#(#names),*) };
604            let ty: Type = syn::parse_quote! { (#(#tys),*) };
605            Some(EnvCapture { pat, ty })
606        }
607    };
608
609    // Env setup code: move original params into the type-erased env tuple.
610    let env_setup: TokenStream2 = match env_params.len() {
611        0 => quote! {
612            let __cartesian_env_ptr: *mut () = ::core::ptr::null_mut();
613        },
614        1 => {
615            let (name, ty) = &env_params[0];
616            quote! {
617                let mut __cartesian_env_val: #ty = #name;
618                let __cartesian_env_ptr: *mut () =
619                    &mut __cartesian_env_val as *mut _ as *mut ();
620            }
621        }
622        _ => {
623            let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
624            let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
625            quote! {
626                let mut __cartesian_env_val: (#(#tys),*) = (#(#names),*);
627                let __cartesian_env_ptr: *mut () =
628                    &mut __cartesian_env_val as *mut _ as *mut ();
629            }
630        }
631    };
632
633    // Preamble: raw-pointer bindings for MutRef / SharedRef args of layers 1..N.
634    // (Value args are stored directly as struct fields — no preamble needed.)
635    let mut arg_preamble = TokenStream2::new();
636    for l in 1..parsed_attr.layers.len() {
637        for (i, cap) in capturable_args_fn(&parsed_attr.layers[l]) {
638            match cap {
639                ArgCaptureTyped::MutRef(name, inner_ty) => {
640                    let binding = format_ident!("__l{}_a{}", l, i);
641                    arg_preamble.extend(quote! {
642                        let #binding: *mut () = #name as *mut #inner_ty as *mut ();
643                    });
644                }
645                ArgCaptureTyped::SharedRef(name, _) => {
646                    let binding = format_ident!("__l{}_a{}", l, i);
647                    // Store the address of the &T binding so fat-pointer types work.
648                    arg_preamble.extend(quote! {
649                        let #binding: *mut () = (&#name) as *const _ as *mut ();
650                    });
651                }
652                ArgCaptureTyped::Value(_, _) => {} // stored in struct field, no preamble
653            }
654        }
655    }
656
657    let fn_body = &parsed_fn.block;
658    let code = gen_layer_fn(&CtxFn {
659        layers: &parsed_attr.layers,
660        outer_generics: &outer_generics,
661        env_capture: env_capture.as_ref(),
662        fn_body,
663        depth: 0,
664        acc_call_generics: vec![],
665        captured: vec![],
666        env_ptr: quote! { __cartesian_env_ptr },
667    });
668
669    *parsed_fn.block = syn::parse_quote! {{
670        #env_setup
671        #arg_preamble
672        #code
673    }};
674
675    quote! { #parsed_fn }.into()
676}