1use proc_macro::{Span, TokenStream};
2use quote::quote;
3use syn::Ident;
4
/// Transform lengths, 2^1 through 2^11, for which `generate_powers_of_two` emits an
/// `fft{n}` function. The size-1 base case `fft1` is emitted separately.
const SIZES: [usize; 11] = [
    1 << 1,
    1 << 2,
    1 << 3,
    1 << 4,
    1 << 5,
    1 << 6,
    1 << 7,
    1 << 8,
    1 << 9,
    1 << 10,
    1 << 11,
];
6
7#[proc_macro]
8pub fn generate_powers_of_two(_input: TokenStream) -> TokenStream {
9 let ss = SIZES.clone().into_iter().map(|s| {
10 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
11 let half = s / 2;
12 let half_butterfly = Ident::new(&format!("fft{}", half), Span::call_site().into());
13 let half_butterfly_even_idx = (0..s).step_by(2).map(|f|{
14 quote! {
15 x[#f],
16 }
17 });
18 let half_butterfly_odd_idx = (1..s).step_by(2).map(|f|{
19 quote! {
20 x[#f],
21 }
22 });
23
24 let t_s = (0..half).map(|tt|
25 quote! {
26 Complex::exp(Complex::<T>::i() * T::from(-2.0).unwrap() * T::PI() * T::from(#tt).unwrap() / T::from(n).unwrap()) * odd[#tt]
27 }
28 );
29
30 let sum_halves = (0..half).map(|t_e| quote! {
31 even[#t_e] + t[#t_e],
32 }
33 );
34 let sub_halves = (0..half).map(|t_o| quote! {
35 even[#t_o] - t[#t_o],
36 });
37
38 quote! {
39 #[inline]
40 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
41 let n = #s;
42 let x = input.as_ref();
43 assert_eq!(n, x.len());
44
45 let even: [Complex<T>; #half] = #half_butterfly([
46 #(#half_butterfly_even_idx)*
47 ]);
48 let odd: [Complex<T>; #half] = #half_butterfly([
49 #(#half_butterfly_odd_idx)*
50 ]);
51
52 let t: [Complex<T>; #half] = [
53 #(#t_s),*
54 ];
55
56 [
57 #(#sum_halves)*
58 #(#sub_halves)*
59 ]
60 }
61 }
62 });
63
64 let expanded = quote! {
65 #[inline]
66 pub fn fft1<T: Float>(x: [Complex<T>; 1]) -> [Complex<T>; 1] {
67 x
68 }
69
70 #(#ss)*
71 };
72 proc_macro::TokenStream::from(expanded)
73}