monarch_derive/
lib.rs

1use proc_macro::{Span, TokenStream};
2use quote::quote;
3use syn::Ident;
4
/// Transform lengths for which `generate_powers_of_two` emits an `fftN` function.
///
/// Each generated `fftN` delegates to `fft{N/2}`, so every entry's half must be either 1
/// (the hand-written `fft1` base case) or another entry of this list.
const SIZES: [usize; 11] = [
    1 << 1,
    1 << 2,
    1 << 3,
    1 << 4,
    1 << 5,
    1 << 6,
    1 << 7,
    1 << 8,
    1 << 9,
    1 << 10,
    1 << 11,
];
6
7#[proc_macro]
8pub fn generate_powers_of_two(_input: TokenStream) -> TokenStream {
9    let ss = SIZES.clone().into_iter().map(|s| {
10        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
11        let half = s / 2;
12        let half_butterfly = Ident::new(&format!("fft{}", half), Span::call_site().into());
13        let half_butterfly_even_idx = (0..s).step_by(2).map(|f|{
14            quote! {
15                x[#f],
16            }
17        });
18        let half_butterfly_odd_idx = (1..s).step_by(2).map(|f|{
19            quote! {
20                x[#f],
21            }
22        });
23
24        let t_s = (0..half).map(|tt|
25            quote! {
26                Complex::exp(Complex::<T>::i() * T::from(-2.0).unwrap() * T::PI() * T::from(#tt).unwrap() / T::from(n).unwrap()) * odd[#tt]
27            }
28        );
29
30        let sum_halves = (0..half).map(|t_e| quote! {
31            even[#t_e] + t[#t_e],
32        }
33        );
34        let sub_halves = (0..half).map(|t_o| quote! {
35            even[#t_o] - t[#t_o],
36        });
37
38        quote! {
39            #[inline]
40            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
41                let n = #s;
42                let x = input.as_ref();
43                assert_eq!(n, x.len());
44
45                let even: [Complex<T>; #half] = #half_butterfly([
46                    #(#half_butterfly_even_idx)*
47                ]);
48                let odd: [Complex<T>; #half] = #half_butterfly([
49                    #(#half_butterfly_odd_idx)*
50                ]);
51
52                let t: [Complex<T>; #half] = [
53                    #(#t_s),*
54                ];
55
56                [
57                    #(#sum_halves)*
58                    #(#sub_halves)*
59                ]
60            }
61        }
62    });
63
64    let expanded = quote! {
65        #[inline]
66        pub fn fft1<T: Float>(x: [Complex<T>; 1]) -> [Complex<T>; 1] {
67            x
68        }
69
70        #(#ss)*
71    };
72    proc_macro::TokenStream::from(expanded)
73}