1use proc_macro::TokenStream;
27use proc_macro2::TokenStream as TokenStream2;
28use quote::{format_ident, quote};
29use syn::{
30 Block, FnArg, GenericParam, Generics, Ident, ItemFn, Pat, Path, Result, Token, Type,
31 parenthesized,
32 parse::{Parse, ParseStream},
33 parse_macro_input,
34};
35
36enum AttrFuncArg {
40 Hole,
41 Named(Ident, Type),
42}
43
44struct AttrLayer {
46 func: Ident,
47 func_args: Vec<AttrFuncArg>,
48 handler: Path,
49 method: Ident,
50 call_generics: Vec<GenericParam>,
51 param: Ident,
52 param_ty: Type,
53}
54
55struct CartesianAttrInput {
57 layers: Vec<AttrLayer>,
58}
59
60struct EnvCapture {
62 pat: Pat,
63 ty: Type,
64}
65
66fn parse_attr_func_arg(input: ParseStream) -> Result<AttrFuncArg> {
69 if input.peek(Token![_]) {
70 input.parse::<Token![_]>()?;
71 Ok(AttrFuncArg::Hole)
72 } else {
73 let name: Ident = input.parse()?;
74 input.parse::<Token![:]>()?;
75 let ty: Type = input.parse()?;
76 Ok(AttrFuncArg::Named(name, ty))
77 }
78}
79
80impl Parse for CartesianAttrInput {
81 fn parse(input: ParseStream) -> Result<Self> {
82 let mut layers = vec![];
83 while !input.is_empty() {
84 let func: Ident = input.parse()?;
85
86 let args_buf;
87 parenthesized!(args_buf in input);
88 let mut func_args = vec![];
89 loop {
90 if args_buf.is_empty() {
91 break;
92 }
93 func_args.push(parse_attr_func_arg(&args_buf)?);
94 if args_buf.peek(Token![,]) {
95 args_buf.parse::<Token![,]>()?;
96 } else {
97 break;
98 }
99 }
100
101 input.parse::<Token![=>]>()?;
102
103 let mut handler: Path = input.call(Path::parse_mod_style)?;
104 let method: Ident = if handler.segments.len() > 1 {
105 let seg = handler.segments.pop().unwrap().into_value();
106 handler.segments.pop_punct();
107 seg.ident
108 } else {
109 format_ident!("call")
110 };
111
112 let call_generics = if input.peek(Token![<]) {
113 let generics: Generics = input.parse()?;
114 generics.params.into_iter().collect()
115 } else {
116 vec![]
117 };
118
119 let param_buf;
120 parenthesized!(param_buf in input);
121 let param: Ident = param_buf.parse()?;
122 param_buf.parse::<Token![:]>()?;
123 let param_ty: Type = param_buf.parse()?;
124
125 input.parse::<Token![;]>()?;
126
127 layers.push(AttrLayer {
128 func,
129 func_args,
130 handler,
131 method,
132 call_generics,
133 param,
134 param_ty,
135 });
136 }
137 Ok(CartesianAttrInput { layers })
138 }
139}
140
141fn params_to_args(params: &[&GenericParam]) -> Vec<TokenStream2> {
145 params
146 .iter()
147 .map(|p| match p {
148 GenericParam::Type(t) => {
149 let id = &t.ident;
150 quote! { #id }
151 }
152 GenericParam::Const(c) => {
153 let id = &c.ident;
154 quote! { #id }
155 }
156 GenericParam::Lifetime(l) => {
157 let lt = &l.lifetime;
158 quote! { #lt }
159 }
160 })
161 .collect()
162}
163
164fn phantom_type(outer_generics: &[GenericParam]) -> TokenStream2 {
166 let tys: Vec<TokenStream2> = outer_generics
167 .iter()
168 .filter_map(|p| match p {
169 GenericParam::Type(t) => {
170 let id = &t.ident;
171 Some(quote! { #id })
172 }
173 GenericParam::Lifetime(l) => {
174 let lt = &l.lifetime;
175 Some(quote! { &#lt () })
176 }
177 GenericParam::Const(_) => None,
178 })
179 .collect();
180 quote! { (#(#tys,)*) }
181}
182
183fn pat_idents(pat: &Pat) -> Vec<Ident> {
185 match pat {
186 Pat::Ident(p) if p.ident != "_" => vec![p.ident.clone()],
187 Pat::Tuple(p) => p.elems.iter().flat_map(pat_idents).collect(),
188 Pat::Wild(_) => vec![],
189 Pat::Reference(r) => pat_idents(&r.pat),
190 _ => vec![],
191 }
192}
193
194fn shadow_env_traits() -> TokenStream2 {
196 quote! {
197 #[allow(dead_code)]
198 struct __CartesianWrap<T>(T);
199
200 #[allow(dead_code)]
201 trait __ShadowMutMut { type Out; fn shadow_env(self) -> Self::Out; }
202 impl<'__a, '__b, T: ?Sized> __ShadowMutMut for __CartesianWrap<&'__a mut &'__b mut T> {
203 type Out = &'__a mut T;
204 #[inline(always)] fn shadow_env(self) -> Self::Out { self.0 }
205 }
206
207 #[allow(dead_code)]
208 trait __ShadowMutRef { type Out; fn shadow_env(self) -> Self::Out; }
209 impl<'__a, '__b, T: ?Sized> __ShadowMutRef for __CartesianWrap<&'__a mut &'__b T> {
210 type Out = &'__b T;
211 #[inline(always)] fn shadow_env(self) -> Self::Out { *self.0 }
212 }
213
214 #[allow(dead_code)]
215 trait __ShadowVal { type Out; fn shadow_env(self) -> Self::Out; }
216 impl<'__a, T: ::core::clone::Clone> __ShadowVal for &__CartesianWrap<&'__a mut T> {
217 type Out = T;
218 #[inline(always)] fn shadow_env(self) -> Self::Out { self.0.clone() }
219 }
220 }
221}
222
223enum ArgCaptureTyped<'a> {
227 MutRef(&'a Ident, &'a Type),
229 SharedRef(&'a Ident, &'a Type),
231 Value(&'a Ident, &'a Type),
233}
234
235fn capturable_args_fn(layer: &AttrLayer) -> Vec<(usize, ArgCaptureTyped<'_>)> {
237 let mut result = Vec::new();
238 let mut nh = 0usize;
239 for arg in &layer.func_args {
240 match arg {
241 AttrFuncArg::Hole => {}
242 AttrFuncArg::Named(name, ty) => {
243 let cap = match ty {
244 Type::Reference(r) if r.mutability.is_some() => {
245 ArgCaptureTyped::MutRef(name, &*r.elem)
246 }
247 Type::Reference(r) => ArgCaptureTyped::SharedRef(name, &*r.elem),
248 _ => ArgCaptureTyped::Value(name, ty),
249 };
250 result.push((nh, cap));
251 nh += 1;
252 }
253 }
254 }
255 result
256}
257
258fn gen_body_with_env(env: Option<&EnvCapture>, body: &Block) -> TokenStream2 {
262 let Some(env) = env else {
263 return quote! { #body };
264 };
265
266 let env_ty = &env.ty;
267 let env_pat = &env.pat;
268 let vars = pat_idents(env_pat);
269
270 let traits = shadow_env_traits();
271
272 let unpack = if vars.is_empty() {
273 quote! {}
274 } else if vars.len() == 1 {
275 quote! {
276 let __cartesian_env_ref = self.__env as *mut #env_ty;
277 #[allow(unused_variables)]
278 let #env_pat = unsafe { &mut *__cartesian_env_ref };
279 #[allow(unused_variables)]
280 let #env_pat = __CartesianWrap(#env_pat).shadow_env();
281 }
282 } else {
283 let shadow_calls: Vec<_> = vars
284 .iter()
285 .map(|v| quote! { __CartesianWrap(#v).shadow_env() })
286 .collect();
287 quote! {
288 let __cartesian_env_ref = self.__env as *mut #env_ty;
289 #[allow(unused_variables)]
290 let #env_pat = unsafe { &mut *__cartesian_env_ref };
291 #[allow(unused_variables)]
292 let (#(#vars,)*) = (#(#shadow_calls,)*);
293 }
294 };
295
296 quote! { #traits #unpack #body }
297}
298
299struct CtxFn<'a> {
302 layers: &'a [AttrLayer],
303 outer_generics: &'a [GenericParam],
304 env_capture: Option<&'a EnvCapture>,
305 fn_body: &'a Block,
306 depth: usize,
307 acc_call_generics: Vec<GenericParam>,
308 captured: Vec<(Ident, Ident, Type)>,
310 env_ptr: TokenStream2,
312}
313
314fn gen_layer_fn(ctx: &CtxFn) -> TokenStream2 {
315 let depth = ctx.depth;
316 let layer = &ctx.layers[depth];
317 let struct_name = format_ident!("__CartesianL{}", depth);
318
319 let outer_g = ctx.outer_generics;
321 let all_g: Vec<&GenericParam> = outer_g.iter().chain(ctx.acc_call_generics.iter()).collect();
322 let all_g_args = params_to_args(&all_g);
323 let phantom = phantom_type(outer_g);
324
325 let mut field_defs: Vec<TokenStream2> = ctx
327 .captured
328 .iter()
329 .map(|(f, _, ty)| quote! { #f: #ty })
330 .collect();
331
332 for l in (depth + 1)..ctx.layers.len() {
333 for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
334 match cap {
335 ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
336 let f = format_ident!("__l{}_a{}", l, i);
337 field_defs.push(quote! { #f: *mut () });
338 }
339 ArgCaptureTyped::Value(_, ty) => {
340 let f = format_ident!("__l{}_v{}", l, i);
341 field_defs.push(quote! { #f: #ty });
342 }
343 }
344 }
345 }
346
347 let struct_def = if all_g.is_empty() {
348 quote! {
349 #[allow(non_local_definitions)]
350 struct #struct_name {
351 __env: *mut (),
352 __marker: ::core::marker::PhantomData<#phantom>,
353 #(#field_defs,)*
354 }
355 }
356 } else {
357 quote! {
358 #[allow(non_local_definitions)]
359 struct #struct_name<#(#all_g),*> {
360 __env: *mut (),
361 __marker: ::core::marker::PhantomData<#phantom>,
362 #(#field_defs,)*
363 }
364 }
365 };
366
367 let handler = &layer.handler;
369 let method = &layer.method;
370 let call_generics = &layer.call_generics;
371 let param = &layer.param;
372 let param_ty = &layer.param_ty;
373
374 let field_name = format_ident!("__cartesian_p{}", depth);
375
376 let mut new_captured = ctx.captured.clone();
377 new_captured.push((field_name.clone(), param.clone(), param_ty.clone()));
378
379 let mut new_acc_generics = ctx.acc_call_generics.clone();
380 new_acc_generics.extend(call_generics.iter().cloned());
381
382 let clone_stmts: Vec<_> = ctx
384 .captured
385 .iter()
386 .map(|(f, name, _)| quote! { let #name = self.#f.clone(); })
387 .collect();
388
389 let call_body = if depth + 1 == ctx.layers.len() {
390 let body_code = gen_body_with_env(ctx.env_capture, ctx.fn_body);
391 quote! { #(#clone_stmts)* #body_code }
392 } else {
393 let next_l = depth + 1;
394 let recovery_stmts: Vec<TokenStream2> = capturable_args_fn(&ctx.layers[next_l])
395 .into_iter()
396 .map(|(i, cap)| match cap {
397 ArgCaptureTyped::MutRef(_, inner_ty) => {
398 let field = format_ident!("__l{}_a{}", next_l, i);
399 let local = format_ident!("__l{}_a{}_local", next_l, i);
400 quote! { let #local = unsafe { &mut *(self.#field as *mut #inner_ty) }; }
401 }
402 ArgCaptureTyped::SharedRef(_, inner_ty) => {
403 let field = format_ident!("__l{}_a{}", next_l, i);
404 let local = format_ident!("__l{}_a{}_local", next_l, i);
405 quote! { let #local = unsafe { *(self.#field as *const &#inner_ty) }; }
407 }
408 ArgCaptureTyped::Value(_, _) => {
409 let field = format_ident!("__l{}_v{}", next_l, i);
410 let local = format_ident!("__l{}_v{}_local", next_l, i);
411 quote! { let #local = self.#field.clone(); }
412 }
413 })
414 .collect();
415
416 let next = gen_layer_fn(&CtxFn {
417 layers: ctx.layers,
418 outer_generics: ctx.outer_generics,
419 env_capture: ctx.env_capture,
420 fn_body: ctx.fn_body,
421 depth: depth + 1,
422 acc_call_generics: new_acc_generics,
423 captured: new_captured.clone(),
424 env_ptr: quote! { self.__env },
425 });
426 quote! { #(#clone_stmts)* #(#recovery_stmts)* #next }
427 };
428
429 let call_generic_decl = if call_generics.is_empty() {
431 quote! {}
432 } else {
433 quote! { <#(#call_generics),*> }
434 };
435
436 let impl_block = if all_g.is_empty() {
437 quote! {
438 #[allow(non_local_definitions)]
439 impl #handler for #struct_name {
440 fn #method #call_generic_decl (&mut self, #param: #param_ty) {
441 #call_body
442 }
443 }
444 }
445 } else {
446 quote! {
447 #[allow(non_local_definitions)]
448 impl<#(#all_g),*> #handler for #struct_name<#(#all_g_args),*> {
449 fn #method #call_generic_decl (&mut self, #param: #param_ty) {
450 #call_body
451 }
452 }
453 }
454 };
455
456 let env_ptr = &ctx.env_ptr;
458 let captured_init: Vec<_> = ctx
459 .captured
460 .iter()
461 .map(|(f, name, _)| quote! { #f: #name })
462 .collect();
463
464 let handler_binding = format_ident!("__cartesian_handler_{}", depth);
465
466 let mut ptr_field_inits: Vec<TokenStream2> = Vec::new();
467 for l in (depth + 1)..ctx.layers.len() {
468 for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
469 match cap {
470 ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
471 let f = format_ident!("__l{}_a{}", l, i);
472 if depth == 0 {
473 ptr_field_inits.push(quote! { #f: #f });
474 } else {
475 ptr_field_inits.push(quote! { #f: self.#f });
476 }
477 }
478 ArgCaptureTyped::Value(name, _) => {
479 let f = format_ident!("__l{}_v{}", l, i);
480 if depth == 0 {
481 ptr_field_inits.push(quote! { #f: #name });
482 } else {
483 ptr_field_inits.push(quote! { #f: self.#f.clone() });
484 }
485 }
486 }
487 }
488 }
489
490 let handler_init = if all_g.is_empty() {
491 quote! {
492 let mut #handler_binding = #struct_name {
493 __env: #env_ptr,
494 __marker: ::core::marker::PhantomData,
495 #(#captured_init,)*
496 #(#ptr_field_inits,)*
497 };
498 }
499 } else {
500 quote! {
501 let mut #handler_binding: #struct_name<#(#all_g_args),*> = #struct_name {
502 __env: #env_ptr,
503 __marker: ::core::marker::PhantomData,
504 #(#captured_init,)*
505 #(#ptr_field_inits,)*
506 };
507 }
508 };
509
510 let func = &layer.func;
511 let caps = capturable_args_fn(layer);
512 let func_args: Vec<_> = {
513 let mut cap_iter = caps.iter();
514 layer
515 .func_args
516 .iter()
517 .map(|arg| match arg {
518 AttrFuncArg::Hole => quote! { &mut #handler_binding },
519 AttrFuncArg::Named(name, _) => {
520 let (nh, cap) = cap_iter.next().unwrap();
521 if depth > 0 {
522 match cap {
523 ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
524 let local = format_ident!("__l{}_a{}_local", depth, nh);
525 quote! { #local }
526 }
527 ArgCaptureTyped::Value(_, _) => {
528 let local = format_ident!("__l{}_v{}_local", depth, nh);
529 quote! { #local }
530 }
531 }
532 } else {
533 quote! { #name }
534 }
535 }
536 })
537 .collect()
538 };
539
540 quote! {
541 #struct_def
542 #impl_block
543 #handler_init
544 #func(#(#func_args),*)
545 }
546}
547
548#[proc_macro_attribute]
551pub fn cartesian_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
552 let parsed_attr = parse_macro_input!(attr as CartesianAttrInput);
553 let mut parsed_fn = parse_macro_input!(item as ItemFn);
554
555 if parsed_attr.layers.is_empty() {
556 return quote! { compile_error!("cartesian_fn requires at least one layer") }.into();
557 }
558
559 let outer_generics: Vec<GenericParam> = parsed_fn.sig.generics.params.iter().cloned().collect();
561
562 let env_params: Vec<(Ident, Type)> = parsed_fn
564 .sig
565 .inputs
566 .iter()
567 .filter_map(|arg| {
568 if let FnArg::Typed(pt) = arg {
569 if let Pat::Ident(pi) = &*pt.pat {
570 Some((pi.ident.clone(), (*pt.ty).clone()))
571 } else {
572 None
573 }
574 } else {
575 None
576 }
577 })
578 .collect();
579
580 for layer in &parsed_attr.layers {
582 for arg in &layer.func_args {
583 if let AttrFuncArg::Named(name, ty) = arg {
584 parsed_fn.sig.inputs.push(syn::parse_quote! { #name: #ty });
585 }
586 }
587 }
588
589 let env_capture: Option<EnvCapture> = match env_params.len() {
591 0 => None,
592 1 => {
593 let (name, ty) = &env_params[0];
594 let pat: Pat = syn::parse_quote! { #name };
595 Some(EnvCapture {
596 pat,
597 ty: ty.clone(),
598 })
599 }
600 _ => {
601 let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
602 let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
603 let pat: Pat = syn::parse_quote! { (#(#names),*) };
604 let ty: Type = syn::parse_quote! { (#(#tys),*) };
605 Some(EnvCapture { pat, ty })
606 }
607 };
608
609 let env_setup: TokenStream2 = match env_params.len() {
611 0 => quote! {
612 let __cartesian_env_ptr: *mut () = ::core::ptr::null_mut();
613 },
614 1 => {
615 let (name, ty) = &env_params[0];
616 quote! {
617 let mut __cartesian_env_val: #ty = #name;
618 let __cartesian_env_ptr: *mut () =
619 &mut __cartesian_env_val as *mut _ as *mut ();
620 }
621 }
622 _ => {
623 let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
624 let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
625 quote! {
626 let mut __cartesian_env_val: (#(#tys),*) = (#(#names),*);
627 let __cartesian_env_ptr: *mut () =
628 &mut __cartesian_env_val as *mut _ as *mut ();
629 }
630 }
631 };
632
633 let mut arg_preamble = TokenStream2::new();
636 for l in 1..parsed_attr.layers.len() {
637 for (i, cap) in capturable_args_fn(&parsed_attr.layers[l]) {
638 match cap {
639 ArgCaptureTyped::MutRef(name, inner_ty) => {
640 let binding = format_ident!("__l{}_a{}", l, i);
641 arg_preamble.extend(quote! {
642 let #binding: *mut () = #name as *mut #inner_ty as *mut ();
643 });
644 }
645 ArgCaptureTyped::SharedRef(name, _) => {
646 let binding = format_ident!("__l{}_a{}", l, i);
647 arg_preamble.extend(quote! {
649 let #binding: *mut () = (&#name) as *const _ as *mut ();
650 });
651 }
652 ArgCaptureTyped::Value(_, _) => {} }
654 }
655 }
656
657 let fn_body = &parsed_fn.block;
658 let code = gen_layer_fn(&CtxFn {
659 layers: &parsed_attr.layers,
660 outer_generics: &outer_generics,
661 env_capture: env_capture.as_ref(),
662 fn_body,
663 depth: 0,
664 acc_call_generics: vec![],
665 captured: vec![],
666 env_ptr: quote! { __cartesian_env_ptr },
667 });
668
669 *parsed_fn.block = syn::parse_quote! {{
670 #env_setup
671 #arg_preamble
672 #code
673 }};
674
675 quote! { #parsed_fn }.into()
676}