1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput, GenericParam, Ident};
5
6#[proc_macro_derive(MultilinearMap)]
7pub fn multilinear_map_derive(input: TokenStream) -> TokenStream {
8 let ast = parse_macro_input!(input as DeriveInput);
9 let struct_name = &ast.ident;
10
11 let generics = &ast.generics;
12 let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
13
14 let const_generics: Vec<_> = generics
16 .params
17 .iter()
18 .filter_map(|param| match param {
19 GenericParam::Const(const_param) => Some(&const_param.ident),
20 _ => None,
21 })
22 .collect();
23
24 let input_params = const_generics.iter().enumerate().map(|(i, ident)| {
35 let param_name = Ident::new(&format!("v_{}", i), Span::call_site());
40 quote! { #param_name: V<#ident, F> }
41 });
42
43 let loop_indices: Vec<_> = (0..const_generics.len())
44 .map(|i| Ident::new(&format!("i_{}", i), proc_macro2::Span::call_site()))
45 .collect();
46
47 let component_product = loop_indices.iter().zip(0..).map(|(index, i)| {
48 let param_name = Ident::new(&format!("v_{}", i), index.span());
49 quote! { * #param_name.0[#index] }
50 });
51
52 let coefficient_access =
54 loop_indices
55 .iter()
56 .fold(quote! { self.coefficients }, |acc, index| {
57 quote! { #acc.0[#index] }
58 });
59
60 let mut loop_nest = quote! {
61 sum += #coefficient_access #(#component_product)*;
62 };
63
64 for (index, ident) in loop_indices.iter().rev().zip(const_generics.iter().rev()) {
65 loop_nest = quote! {
66 for #index in 0..#ident {
67 #loop_nest
68 }
69 };
70 }
71
72 loop_nest = quote! {
73 #loop_nest
74
75 };
76
77 let expanded = quote! {
78 impl #impl_generics #struct_name #ty_generics #where_clause {
79 pub fn multilinear_map(&self, #(#input_params),*) -> F {
80 let mut sum = F::default();
81 #loop_nest
82 sum
83 }
84 }
85 };
86
87 TokenStream::from(expanded)
88}