use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::{format_ident, quote};
use syn::{
Block, FnArg, GenericParam, Generics, Ident, ItemFn, Pat, Path, Result, Token, Type,
parenthesized,
parse::{Parse, ParseStream},
parse_macro_input,
};
/// One entry in a layer's argument list, e.g. the `_` or `v: &mut Vec<u8>` in
/// `func(_, v: &mut Vec<u8>) => ...`.
enum AttrFuncArg {
    /// `_`: the slot where the generated handler (`&mut __cartesian_handler_N`) is passed.
    Hole,
    /// `name: Type`: an extra argument that is appended to the annotated function's
    /// parameter list and forwarded to the layer function.
    Named(Ident, Type),
}
/// One parsed layer of the attribute:
/// `func(args) => Handler[::method][<generics>](param: Type);`
struct AttrLayer {
    /// Function invoked for this layer.
    func: Ident,
    /// Its argument list; `_` marks where the handler is passed.
    func_args: Vec<AttrFuncArg>,
    /// Trait implemented by the generated handler struct (method segment stripped).
    handler: Path,
    /// Trait method name: the last path segment when the path had more than one
    /// segment, otherwise `call`.
    method: Ident,
    /// Generic parameters declared on the trait method (`<...>` after the path).
    call_generics: Vec<GenericParam>,
    /// Name of the method's single parameter.
    param: Ident,
    /// Type of the method's single parameter.
    param_ty: Type,
}
/// The whole `#[cartesian_fn(...)]` attribute: layers in outermost-first order.
struct CartesianAttrInput {
    layers: Vec<AttrLayer>,
}
/// The annotated function's own parameters bundled into a single value.
///
/// With one parameter `pat` is that identifier and `ty` its type. With several, `pat` is
/// a tuple pattern of the identifiers and `ty` the matching tuple type.
struct EnvCapture {
    pat: Pat,
    ty: Type,
}
/// Parses a single layer argument: either the `_` handler placeholder or a
/// `name: Type` pair.
fn parse_attr_func_arg(input: ParseStream) -> Result<AttrFuncArg> {
    if !input.peek(Token![_]) {
        let name: Ident = input.parse()?;
        let _colon: Token![:] = input.parse()?;
        let ty: Type = input.parse()?;
        return Ok(AttrFuncArg::Named(name, ty));
    }
    let _placeholder: Token![_] = input.parse()?;
    Ok(AttrFuncArg::Hole)
}
impl Parse for CartesianAttrInput {
fn parse(input: ParseStream) -> Result<Self> {
let mut layers = vec![];
while !input.is_empty() {
let func: Ident = input.parse()?;
let args_buf;
parenthesized!(args_buf in input);
let mut func_args = vec![];
loop {
if args_buf.is_empty() {
break;
}
func_args.push(parse_attr_func_arg(&args_buf)?);
if args_buf.peek(Token![,]) {
args_buf.parse::<Token![,]>()?;
} else {
break;
}
}
input.parse::<Token![=>]>()?;
let mut handler: Path = input.call(Path::parse_mod_style)?;
let method: Ident = if handler.segments.len() > 1 {
let seg = handler.segments.pop().unwrap().into_value();
handler.segments.pop_punct();
seg.ident
} else {
format_ident!("call")
};
let call_generics = if input.peek(Token![<]) {
let generics: Generics = input.parse()?;
generics.params.into_iter().collect()
} else {
vec![]
};
let param_buf;
parenthesized!(param_buf in input);
let param: Ident = param_buf.parse()?;
param_buf.parse::<Token![:]>()?;
let param_ty: Type = param_buf.parse()?;
input.parse::<Token![;]>()?;
layers.push(AttrLayer {
func,
func_args,
handler,
method,
call_generics,
param,
param_ty,
});
}
Ok(CartesianAttrInput { layers })
}
}
/// Turns generic parameter declarations (`T: Bound`, `'a: 'b`, `const N: usize`) into
/// the bare arguments used to name an instantiation (`T`, `'a`, `N`).
fn params_to_args(params: &[&GenericParam]) -> Vec<TokenStream2> {
    let mut args = Vec::with_capacity(params.len());
    for param in params {
        let arg = match *param {
            GenericParam::Lifetime(lifetime_param) => {
                let lifetime = &lifetime_param.lifetime;
                quote! { #lifetime }
            }
            GenericParam::Type(type_param) => {
                let ident = &type_param.ident;
                quote! { #ident }
            }
            GenericParam::Const(const_param) => {
                let ident = &const_param.ident;
                quote! { #ident }
            }
        };
        args.push(arg);
    }
    args
}
fn phantom_type(outer_generics: &[GenericParam]) -> TokenStream2 {
let tys: Vec<TokenStream2> = outer_generics
.iter()
.filter_map(|p| match p {
GenericParam::Type(t) => {
let id = &t.ident;
Some(quote! { #id })
}
GenericParam::Lifetime(l) => {
let lt = &l.lifetime;
Some(quote! { &#lt () })
}
GenericParam::Const(_) => None,
})
.collect();
quote! { (#(#tys,)*) }
}
/// Collects the identifiers bound by `pat`, in source order.
///
/// Descends into tuple and reference patterns. Identifier patterns named `_`, wildcards,
/// and every other pattern kind contribute nothing.
fn pat_idents(pat: &Pat) -> Vec<Ident> {
    fn walk(pat: &Pat, found: &mut Vec<Ident>) {
        match pat {
            Pat::Ident(binding) => {
                if binding.ident != "_" {
                    found.push(binding.ident.clone());
                }
            }
            Pat::Tuple(tuple) => {
                for elem in &tuple.elems {
                    walk(elem, found);
                }
            }
            Pat::Reference(reference) => walk(&reference.pat, found),
            _ => {}
        }
    }

    let mut found = Vec::new();
    walk(pat, &mut found);
    found
}
/// Emits the local helper items `gen_body_with_env` uses to re-bind the annotated
/// function's parameters inside the innermost generated handler method.
///
/// Each parameter is first bound as `&mut P` into the shared env value and then passed
/// through `__CartesianWrap(x).shadow_env()`. This is the "autoref specialization" trick:
/// method resolution tries by-value receivers before auto-referenced ones, so
/// * `P = &mut T` matches `__ShadowMutMut` (by value) and yields `&mut T`;
/// * `P = &T` matches `__ShadowMutRef` (by value) and yields the original `&T`;
/// * anything else falls through to `__ShadowVal`, which is implemented on
///   `&__CartesianWrap<..>` and therefore only reached via auto-ref. It yields
///   `P::clone()`, so this case requires `P: Clone`.
fn shadow_env_traits() -> TokenStream2 {
    quote! {
        #[allow(dead_code)]
        struct __CartesianWrap<T>(T);
        // `&mut &mut T` -> reborrowed `&mut T`.
        #[allow(dead_code)]
        trait __ShadowMutMut { type Out; fn shadow_env(self) -> Self::Out; }
        impl<'__a, '__b, T: ?Sized> __ShadowMutMut for __CartesianWrap<&'__a mut &'__b mut T> {
            type Out = &'__a mut T;
            #[inline(always)] fn shadow_env(self) -> Self::Out { self.0 }
        }
        // `&mut &T` -> copy of the inner shared reference.
        #[allow(dead_code)]
        trait __ShadowMutRef { type Out; fn shadow_env(self) -> Self::Out; }
        impl<'__a, '__b, T: ?Sized> __ShadowMutRef for __CartesianWrap<&'__a mut &'__b T> {
            type Out = &'__b T;
            #[inline(always)] fn shadow_env(self) -> Self::Out { *self.0 }
        }
        // Fallback, implemented on `&Wrap` so it loses to the by-value impls above.
        #[allow(dead_code)]
        trait __ShadowVal { type Out; fn shadow_env(self) -> Self::Out; }
        impl<'__a, T: ::core::clone::Clone> __ShadowVal for &__CartesianWrap<&'__a mut T> {
            type Out = T;
            #[inline(always)] fn shadow_env(self) -> Self::Out { self.0.clone() }
        }
    }
}
/// Classification of a named layer argument (`name: Type`) by its declared type.
/// This determines how the argument is carried through the handler structs.
enum ArgCaptureTyped<'a> {
    /// `name: &mut T`; holds the name and the referent type `T`.
    MutRef(&'a Ident, &'a Type),
    /// `name: &T`; holds the name and the referent type `T`.
    SharedRef(&'a Ident, &'a Type),
    /// Any non-reference type; holds the name and the full type.
    Value(&'a Ident, &'a Type),
}
/// Lists the named (non-`_`) arguments of `layer`, each paired with its index among the
/// named arguments (holes are not counted) and its capture classification.
fn capturable_args_fn(layer: &AttrLayer) -> Vec<(usize, ArgCaptureTyped<'_>)> {
    layer
        .func_args
        .iter()
        .filter_map(|arg| match arg {
            AttrFuncArg::Named(name, ty) => Some((name, ty)),
            AttrFuncArg::Hole => None,
        })
        .enumerate()
        .map(|(named_index, (name, ty))| {
            let capture = match ty {
                Type::Reference(reference) => match reference.mutability {
                    Some(_) => ArgCaptureTyped::MutRef(name, &*reference.elem),
                    None => ArgCaptureTyped::SharedRef(name, &*reference.elem),
                },
                other => ArgCaptureTyped::Value(name, other),
            };
            (named_index, capture)
        })
        .collect()
}
/// Wraps the annotated function's original body so it can use the function's own
/// parameters from inside the innermost generated handler method.
///
/// With `env == None` the body is emitted unchanged. Otherwise the generated code:
/// 1. casts `self.__env` (a type-erased `*mut ()`) back to `*mut EnvTy` and destructures
///    it, binding every parameter name as `&mut Param`;
/// 2. re-binds each name via `__CartesianWrap(name).shadow_env()` (see
///    `shadow_env_traits`). `&mut T` parameters stay `&mut T`, `&T` parameters stay `&T`,
///    and by-value parameters become a fresh clone for this invocation.
///
/// The re-bound names are declared `mut` (with `unused_mut` allowed). The `mut` keyword of
/// the original parameter is not available here, and without it a body written against
/// `mut x: T` could not mutate its per-invocation copy.
fn gen_body_with_env(env: Option<&EnvCapture>, body: &Block) -> TokenStream2 {
    let Some(env) = env else {
        return quote! { #body };
    };
    let env_ty = &env.ty;
    let env_pat = &env.pat;
    let vars = pat_idents(env_pat);
    let traits = shadow_env_traits();
    let unpack = if vars.is_empty() {
        quote! {}
    } else {
        // SAFETY (of the generated `unsafe`): `self.__env` is only dereferenced when the
        // function has parameters. In that case it is the address of the
        // `__cartesian_env_val` local that `cartesian_fn` sets up in the expanded
        // function, whose frame outlives every nested handler call.
        quote! {
            let __cartesian_env_ref = self.__env as *mut #env_ty;
            #[allow(unused_variables)]
            let #env_pat = unsafe { &mut *__cartesian_env_ref };
            #(
                #[allow(unused_variables, unused_mut)]
                let mut #vars = __CartesianWrap(#vars).shadow_env();
            )*
        }
    };
    quote! { #traits #unpack #body }
}
/// Recursion state for `gen_layer_fn`; one value per generated layer.
struct CtxFn<'a> {
    /// All parsed layers, outermost first.
    layers: &'a [AttrLayer],
    /// Generic parameters of the annotated function, redeclared on every handler.
    outer_generics: &'a [GenericParam],
    /// How to re-bind the function's parameters in the innermost body (`None` if it has none).
    env_capture: Option<&'a EnvCapture>,
    /// The annotated function's original body.
    fn_body: &'a Block,
    /// Index of the layer being generated.
    depth: usize,
    /// Call generics of layers `0..depth`, accumulated as the recursion descends.
    acc_call_generics: Vec<GenericParam>,
    /// `(field name, parameter name, parameter type)` for the method parameters of
    /// layers `0..depth`.
    captured: Vec<(Ident, Ident, Type)>,
    /// Expression yielding the `*mut ()` env pointer: `__cartesian_env_ptr` at depth 0,
    /// `self.__env` deeper down.
    env_ptr: TokenStream2,
}
/// Generates the code for layer `d = ctx.depth` and, recursively, every deeper layer.
///
/// The output is a statement sequence ending in an expression:
///
/// ```text
/// struct __CartesianL{d}<..> { __env, __marker, __call_marker, <captured>, <later args> }
/// impl<..> Handler for __CartesianL{d}<..> {
///     fn method<..>(&mut self, param: Ty) { /* layer d + 1, or the original body */ }
/// }
/// let mut __cartesian_handler_{d} = __CartesianL{d} { .. };
/// func(.., &mut __cartesian_handler_{d}, ..)
/// ```
///
/// Struct fields:
/// * `__env` is the type-erased pointer to the bundled function parameters (`ctx.env_ptr`).
/// * `__cartesian_p{k}` is the method parameter of each earlier layer `k < d`. It is cloned
///   back into scope under its original name on every call.
/// * `__l{l}_a{i}` / `__l{l}_v{i}` is the `i`-th named argument of each later layer `l > d`.
///   `&T` / `&mut T` arguments are stored as a type-erased pointer; other arguments are
///   stored by value and cloned when handed on.
///
/// Deeper layers are emitted inside the enclosing method body. Each struct therefore
/// redeclares the outer function's generics plus the call generics of layers `0..d`.
fn gen_layer_fn(ctx: &CtxFn) -> TokenStream2 {
    let depth = ctx.depth;
    let layer = &ctx.layers[depth];
    let struct_name = format_ident!("__CartesianL{}", depth);
    let outer_g = ctx.outer_generics;
    let all_g: Vec<&GenericParam> = outer_g.iter().chain(ctx.acc_call_generics.iter()).collect();
    let all_g_args = params_to_args(&all_g);
    // `<params>` for struct/impl headers and `<args>` for naming the struct type. Both are
    // empty when there are no generics at all.
    let (generics_decl, generics_args) = if all_g.is_empty() {
        (quote! {}, quote! {})
    } else {
        (quote! { <#(#all_g),*> }, quote! { <#(#all_g_args),*> })
    };
    // `__marker` keeps the outer function's generics "used" even if no field mentions them.
    let phantom = phantom_type(outer_g);
    // The call generics of earlier layers are struct parameters too, but nothing guarantees
    // that a captured parameter type mentions them. Without this marker such a parameter is
    // rejected with E0392. `*const T` (rather than `T`) keeps `?Sized` parameters legal
    // anywhere in the tuple.
    let call_phantom_elems: Vec<TokenStream2> = ctx
        .acc_call_generics
        .iter()
        .filter_map(|p| match p {
            GenericParam::Type(t) => {
                let id = &t.ident;
                Some(quote! { *const #id })
            }
            GenericParam::Lifetime(l) => {
                let lt = &l.lifetime;
                Some(quote! { &#lt () })
            }
            GenericParam::Const(_) => None,
        })
        .collect();
    let call_phantom = quote! { (#(#call_phantom_elems,)*) };

    // Captured method parameters of earlier layers, then the arguments of later layers.
    let mut field_defs: Vec<TokenStream2> = ctx
        .captured
        .iter()
        .map(|(f, _, ty)| quote! { #f: #ty })
        .collect();
    for l in (depth + 1)..ctx.layers.len() {
        for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
            match cap {
                ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
                    let f = format_ident!("__l{}_a{}", l, i);
                    field_defs.push(quote! { #f: *mut () });
                }
                ArgCaptureTyped::Value(_, ty) => {
                    let f = format_ident!("__l{}_v{}", l, i);
                    field_defs.push(quote! { #f: #ty });
                }
            }
        }
    }
    let struct_def = quote! {
        #[allow(non_local_definitions)]
        struct #struct_name #generics_decl {
            __env: *mut (),
            __marker: ::core::marker::PhantomData<#phantom>,
            __call_marker: ::core::marker::PhantomData<#call_phantom>,
            #(#field_defs,)*
        }
    };

    let handler = &layer.handler;
    let method = &layer.method;
    let call_generics = &layer.call_generics;
    let param = &layer.param;
    let param_ty = &layer.param_ty;
    let field_name = format_ident!("__cartesian_p{}", depth);
    let mut new_captured = ctx.captured.clone();
    new_captured.push((field_name, param.clone(), param_ty.clone()));
    let mut new_acc_generics = ctx.acc_call_generics.clone();
    new_acc_generics.extend(call_generics.iter().cloned());
    // Re-bind earlier layers' parameters under their original names for this call.
    let clone_stmts: Vec<_> = ctx
        .captured
        .iter()
        .map(|(f, name, _)| quote! { let #name = self.#f.clone(); })
        .collect();
    let call_body = if depth + 1 == ctx.layers.len() {
        let body_code = gen_body_with_env(ctx.env_capture, ctx.fn_body);
        quote! { #(#clone_stmts)* #body_code }
    } else {
        let next_l = depth + 1;
        // Materialise the next layer's arguments as `__l{next}_{a,v}{i}_local`. The
        // recursive call below passes these locals to the next layer's function.
        // SAFETY (of the generated `unsafe`): the pointers were created in
        // `cartesian_fn`'s preamble from reference parameters of the expanded function
        // (or from locals holding them), which outlive every nested handler call.
        let recovery_stmts: Vec<TokenStream2> = capturable_args_fn(&ctx.layers[next_l])
            .into_iter()
            .map(|(i, cap)| match cap {
                ArgCaptureTyped::MutRef(_, inner_ty) => {
                    let field = format_ident!("__l{}_a{}", next_l, i);
                    let local = format_ident!("__l{}_a{}_local", next_l, i);
                    quote! { let #local = unsafe { &mut *(self.#field as *mut #inner_ty) }; }
                }
                ArgCaptureTyped::SharedRef(_, inner_ty) => {
                    let field = format_ident!("__l{}_a{}", next_l, i);
                    let local = format_ident!("__l{}_a{}_local", next_l, i);
                    quote! { let #local = unsafe { *(self.#field as *const &#inner_ty) }; }
                }
                ArgCaptureTyped::Value(_, _) => {
                    let field = format_ident!("__l{}_v{}", next_l, i);
                    let local = format_ident!("__l{}_v{}_local", next_l, i);
                    quote! { let #local = self.#field.clone(); }
                }
            })
            .collect();
        let next = gen_layer_fn(&CtxFn {
            layers: ctx.layers,
            outer_generics: ctx.outer_generics,
            env_capture: ctx.env_capture,
            fn_body: ctx.fn_body,
            depth: next_l,
            acc_call_generics: new_acc_generics,
            captured: new_captured,
            env_ptr: quote! { self.__env },
        });
        quote! { #(#clone_stmts)* #(#recovery_stmts)* #next }
    };
    let call_generic_decl = if call_generics.is_empty() {
        quote! {}
    } else {
        quote! { <#(#call_generics),*> }
    };
    let impl_block = quote! {
        #[allow(non_local_definitions)]
        impl #generics_decl #handler for #struct_name #generics_args {
            fn #method #call_generic_decl (&mut self, #param: #param_ty) {
                #call_body
            }
        }
    };

    let env_ptr = &ctx.env_ptr;
    let captured_init: Vec<_> = ctx
        .captured
        .iter()
        .map(|(f, name, _)| quote! { #f: #name })
        .collect();
    let handler_binding = format_ident!("__cartesian_handler_{}", depth);
    // Later layers' arguments: at depth 0 they come from the preamble bindings / function
    // parameters; deeper down they are forwarded from the enclosing handler's fields.
    let mut ptr_field_inits: Vec<TokenStream2> = Vec::new();
    for l in (depth + 1)..ctx.layers.len() {
        for (i, cap) in capturable_args_fn(&ctx.layers[l]) {
            match cap {
                ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
                    let f = format_ident!("__l{}_a{}", l, i);
                    if depth == 0 {
                        ptr_field_inits.push(quote! { #f: #f });
                    } else {
                        ptr_field_inits.push(quote! { #f: self.#f });
                    }
                }
                ArgCaptureTyped::Value(name, _) => {
                    let f = format_ident!("__l{}_v{}", l, i);
                    if depth == 0 {
                        ptr_field_inits.push(quote! { #f: #name });
                    } else {
                        ptr_field_inits.push(quote! { #f: self.#f.clone() });
                    }
                }
            }
        }
    }
    let binding_ty = if all_g.is_empty() {
        quote! {}
    } else {
        quote! { : #struct_name #generics_args }
    };
    let handler_init = quote! {
        let mut #handler_binding #binding_ty = #struct_name {
            __env: #env_ptr,
            __marker: ::core::marker::PhantomData,
            __call_marker: ::core::marker::PhantomData,
            #(#captured_init,)*
            #(#ptr_field_inits,)*
        };
    };

    // Arguments for this layer's function: `_` becomes the handler. Named arguments are the
    // function parameters themselves at depth 0 and the recovered locals deeper down.
    let func = &layer.func;
    let caps = capturable_args_fn(layer);
    let func_args: Vec<_> = {
        let mut cap_iter = caps.iter();
        layer
            .func_args
            .iter()
            .map(|arg| match arg {
                AttrFuncArg::Hole => quote! { &mut #handler_binding },
                AttrFuncArg::Named(name, _) => {
                    let (nh, cap) = cap_iter
                        .next()
                        .expect("one capture entry per named argument");
                    if depth > 0 {
                        match cap {
                            ArgCaptureTyped::MutRef(_, _) | ArgCaptureTyped::SharedRef(_, _) => {
                                let local = format_ident!("__l{}_a{}_local", depth, nh);
                                quote! { #local }
                            }
                            ArgCaptureTyped::Value(_, _) => {
                                let local = format_ident!("__l{}_v{}_local", depth, nh);
                                quote! { #local }
                            }
                        }
                    } else {
                        quote! { #name }
                    }
                }
            })
            .collect()
    };
    quote! {
        #struct_def
        #impl_block
        #handler_init
        #func(#(#func_args),*)
    }
}
#[proc_macro_attribute]
pub fn cartesian_fn(attr: TokenStream, item: TokenStream) -> TokenStream {
let parsed_attr = parse_macro_input!(attr as CartesianAttrInput);
let mut parsed_fn = parse_macro_input!(item as ItemFn);
if parsed_attr.layers.is_empty() {
return quote! { compile_error!("cartesian_fn requires at least one layer") }.into();
}
let outer_generics: Vec<GenericParam> = parsed_fn.sig.generics.params.iter().cloned().collect();
let env_params: Vec<(Ident, Type)> = parsed_fn
.sig
.inputs
.iter()
.filter_map(|arg| {
if let FnArg::Typed(pt) = arg {
if let Pat::Ident(pi) = &*pt.pat {
Some((pi.ident.clone(), (*pt.ty).clone()))
} else {
None
}
} else {
None
}
})
.collect();
for layer in &parsed_attr.layers {
for arg in &layer.func_args {
if let AttrFuncArg::Named(name, ty) = arg {
parsed_fn.sig.inputs.push(syn::parse_quote! { #name: #ty });
}
}
}
let env_capture: Option<EnvCapture> = match env_params.len() {
0 => None,
1 => {
let (name, ty) = &env_params[0];
let pat: Pat = syn::parse_quote! { #name };
Some(EnvCapture {
pat,
ty: ty.clone(),
})
}
_ => {
let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
let pat: Pat = syn::parse_quote! { (#(#names),*) };
let ty: Type = syn::parse_quote! { (#(#tys),*) };
Some(EnvCapture { pat, ty })
}
};
let env_setup: TokenStream2 = match env_params.len() {
0 => quote! {
let __cartesian_env_ptr: *mut () = ::core::ptr::null_mut();
},
1 => {
let (name, ty) = &env_params[0];
quote! {
let mut __cartesian_env_val: #ty = #name;
let __cartesian_env_ptr: *mut () =
&mut __cartesian_env_val as *mut _ as *mut ();
}
}
_ => {
let names: Vec<_> = env_params.iter().map(|(n, _)| n).collect();
let tys: Vec<_> = env_params.iter().map(|(_, t)| t).collect();
quote! {
let mut __cartesian_env_val: (#(#tys),*) = (#(#names),*);
let __cartesian_env_ptr: *mut () =
&mut __cartesian_env_val as *mut _ as *mut ();
}
}
};
let mut arg_preamble = TokenStream2::new();
for l in 1..parsed_attr.layers.len() {
for (i, cap) in capturable_args_fn(&parsed_attr.layers[l]) {
match cap {
ArgCaptureTyped::MutRef(name, inner_ty) => {
let binding = format_ident!("__l{}_a{}", l, i);
arg_preamble.extend(quote! {
let #binding: *mut () = #name as *mut #inner_ty as *mut ();
});
}
ArgCaptureTyped::SharedRef(name, _) => {
let binding = format_ident!("__l{}_a{}", l, i);
arg_preamble.extend(quote! {
let #binding: *mut () = (&#name) as *const _ as *mut ();
});
}
ArgCaptureTyped::Value(_, _) => {} }
}
}
let fn_body = &parsed_fn.block;
let code = gen_layer_fn(&CtxFn {
layers: &parsed_attr.layers,
outer_generics: &outer_generics,
env_capture: env_capture.as_ref(),
fn_body,
depth: 0,
acc_call_generics: vec![],
captured: vec![],
env_ptr: quote! { __cartesian_env_ptr },
});
*parsed_fn.block = syn::parse_quote! {{
#env_setup
#arg_preamble
#code
}};
quote! { #parsed_fn }.into()
}