extensor_macros/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::Span;
3use quote::quote;
4use syn::{parse_macro_input, DeriveInput, GenericParam, Ident};
5
6#[proc_macro_derive(MultilinearMap)]
7pub fn multilinear_map_derive(input: TokenStream) -> TokenStream {
8    let ast = parse_macro_input!(input as DeriveInput);
9    let struct_name = &ast.ident;
10
11    let generics = &ast.generics;
12    let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();
13
14    // Type?
15    let const_generics: Vec<_> = generics
16        .params
17        .iter()
18        .filter_map(|param| match param {
19            GenericParam::Const(const_param) => Some(&const_param.ident),
20            _ => None,
21        })
22        .collect();
23
24    // use proc_macro2::{Ident, Span};
25    //
26    // Span::call_site()
27    //
28    // pub fn multilinear_map(
29    //     &self,
30    //     v_0: V<M, F>,
31    //     v_1: V<N, F>,
32    //     v_2: V<P, F>,
33    // )
34    let input_params = const_generics.iter().enumerate().map(|(i, ident)| {
35        // Is this identifier actually scoped correctly to the function? Or, by using
36        // ::new raw, have we just made a global or something? span => scope?
37        //
38        // let param_name = Ident::new(&format!("v_{}", i), ident.span());
39        let param_name = Ident::new(&format!("v_{}", i), Span::call_site());
40        quote! { #param_name: V<#ident, F> }
41    });
42
43    let loop_indices: Vec<_> = (0..const_generics.len())
44        .map(|i| Ident::new(&format!("i_{}", i), proc_macro2::Span::call_site()))
45        .collect();
46
47    let component_product = loop_indices.iter().zip(0..).map(|(index, i)| {
48        let param_name = Ident::new(&format!("v_{}", i), index.span());
49        quote! { * #param_name.0[#index] }
50    });
51
52    // Add the calculation to the innermost loop
53    let coefficient_access =
54        loop_indices
55            .iter()
56            .fold(quote! { self.coefficients }, |acc, index| {
57                quote! { #acc.0[#index] }
58            });
59
60    let mut loop_nest = quote! {
61        sum += #coefficient_access #(#component_product)*;
62    };
63
64    for (index, ident) in loop_indices.iter().rev().zip(const_generics.iter().rev()) {
65        loop_nest = quote! {
66            for #index in 0..#ident {
67                #loop_nest
68            }
69        };
70    }
71
72    loop_nest = quote! {
73        #loop_nest
74
75    };
76
77    let expanded = quote! {
78        impl #impl_generics #struct_name #ty_generics #where_clause {
79            pub fn multilinear_map(&self, #(#input_params),*) -> F {
80                let mut sum = F::default();
81                #loop_nest
82                sum
83            }
84        }
85    };
86
87    TokenStream::from(expanded)
88}