monarch_derive/
lib.rs

//! This crate is not meant to be used on its own. It contains the
//! under-the-hood procedural macros that autogenerate the FFT kernels for
//! [monarch-butterfly](https://crates.io/crates/monarch-butterfly).
4
5use num_complex::Complex;
6use num_integer::Integer;
7use num_traits::{Float, FloatConst};
8use proc_macro::{Span, TokenStream};
9use quote::quote;
10use syn::Ident;
11
/// Lengths whose kernels are emitted by the code generators in this crate:
/// every length in `2..=200` except the ones listed in [`HAND_GEN`].
const SIZES: [usize; 194] = {
    let mut table = [0usize; 194];
    let mut filled = 0;
    let mut len = 2;
    while len <= 200 {
        // Lengths listed in `HAND_GEN` are not generated here.
        let mut is_hand_gen = false;
        let mut k = 0;
        while k < HAND_GEN.len() {
            if HAND_GEN[k] == len {
                is_hand_gen = true;
            }
            k += 1;
        }
        if !is_hand_gen {
            table[filled] = len;
            filled += 1;
        }
        len += 1;
    }
    // Checked during const evaluation: the table must be filled exactly.
    assert!(filled == table.len());
    table
};

/// Lengths whose forward kernels `fft{n}` are not emitted by this crate, even
/// though `generate_switch` dispatches to them and `generate_iffts` wraps them.
/// Presumably they are hand-written in the parent crate (TODO confirm).
const HAND_GEN: [usize; 5] = [3, 9, 18, 27, 125];
26
27#[derive(PartialEq, Debug)]
28enum FFTType {
29    PowerOfTwo,
30    Prime,
31    Coprime,
32    Mixed,
33}
34
35impl FFTType {
36    fn compute_type(n: usize) -> FFTType {
37        if n.is_power_of_two() {
38            FFTType::PowerOfTwo
39        } else {
40            // Check if it has an integer square root
41            let sqr = (n as f64).sqrt().round() as usize;
42            if sqr * sqr == n && sqr % 2 == 1 {
43                FFTType::Mixed
44            } else {
45                let (v1, _v2) = compute_coprimes(n);
46                if v1 == 1 {
47                    FFTType::Prime
48                } else {
49                    FFTType::Coprime
50                }
51            }
52        }
53    }
54}
55
56fn compute_coprimes(n: usize) -> (usize, usize) {
57    let sqr = (n as f32).sqrt().ceil() as usize;
58    for v1 in (1..sqr).rev() {
59        let v2 = n / v1;
60        if v2 * v1 == n {
61            if v1.gcd(&v2) == 1 {
62                return (usize::min(v1, v2), usize::max(v1, v2));
63            }
64        }
65    }
66    todo!()
67}
68
69fn compute_twiddle_forward<T: Float + FloatConst>(index: usize, fft_len: usize) -> Complex<T> {
70    let constant = T::from(-2.0).unwrap() * T::PI() / T::from(fft_len).unwrap();
71    // index * -2PI / fft_len
72    let angle = constant * T::from(index).unwrap();
73
74    Complex::new(angle.cos(), angle.sin())
75}
76
77#[proc_macro]
78pub fn generate_switch(_input: TokenStream) -> TokenStream {
79    let mut all_sizes: Vec<_> = SIZES
80        .clone()
81        .into_iter()
82        .chain(HAND_GEN.clone().into_iter())
83        .collect();
84    all_sizes.sort();
85
86    let ss_forward = all_sizes.clone().into_iter().map(|s| {
87        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
88
89        quote! {
90            #s => {
91                let x = #func(x_in);
92                core::array::from_fn(|i| x[i])
93             },
94        }
95    });
96    let ss_inverse = all_sizes.into_iter().map(|s| {
97        let func = Ident::new(&format!("ifft{}", s), Span::call_site().into());
98
99        quote! {
100            #s => {
101                let x = #func(x_in);
102                core::array::from_fn(|i| x[i])
103             },
104        }
105    });
106
107    let expanded = quote! {
108        /// Top level FFT function
109        /// 
110        /// ```
111        /// use monarch_butterfly::*;
112        /// use num_complex::Complex;
113        /// 
114        /// let input: Vec<_> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
115        /// let output = fft::<8, _, _>(input);
116        /// ```
117        #[inline(always)]
118        pub fn fft<const N: usize, T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; N] {
119            let x_in = input.as_ref();
120            assert_eq!(x_in.len(), N);
121
122            match N {
123                1 => { core::array::from_fn(|i| x_in[i]) },
124                #(#ss_forward)*
125                _ => unimplemented!(),
126            }
127        }
128
129         /// Top level iFFT function
130        /// 
131        /// ```
132        /// use monarch_butterfly::*;
133        /// use num_complex::Complex;
134        /// 
135        /// let input: Vec<_> = (0..8).map(|i| Complex::new(i as f32, 0.0)).collect();
136        /// let output = ifft::<8, _, _>(input);
137        /// ```
138        #[inline(always)]
139        pub fn ifft<const N: usize, T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; N] {
140            let x_in = input.as_ref();
141            assert_eq!(x_in.len(), N);
142
143            match N {
144                1 => { core::array::from_fn(|i| x_in[i]) },
145                #(#ss_inverse)*
146                _ => unimplemented!(),
147            }
148        }
149    };
150    proc_macro::TokenStream::from(expanded)
151}
152
153#[proc_macro]
154pub fn generate_powers_of_two(_input: TokenStream) -> TokenStream {
155    let sizes = SIZES
156        .clone()
157        .into_iter()
158        .filter(|n| FFTType::compute_type(*n) == FFTType::PowerOfTwo);
159    let ss = sizes.map(|s| {
160        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
161        let half = s / 2;
162        let half_butterfly = Ident::new(&format!("fft{}", half), Span::call_site().into());
163        let half_butterfly_even_idx = (0..s).step_by(2).map(|f|{
164            quote! {
165                x[#f],
166            }
167        });
168        let half_butterfly_odd_idx = (1..s).step_by(2).map(|f|{
169            quote! {
170                x[#f],
171            }
172        });
173
174        let t_s = (0..half).map(|tt|
175            quote! {
176                Complex::exp(Complex::<T>::i() * T::from(-2.0).unwrap() * T::PI() * T::from(#tt).unwrap() / T::from(n).unwrap()) * odd[#tt]
177            }
178        );
179
180        let sum_halves = (0..half).map(|t_e| quote! {
181            even[#t_e] + t[#t_e],
182        }
183        );
184        let sub_halves = (0..half).map(|t_o| quote! {
185            even[#t_o] - t[#t_o],
186        });
187
188        quote! {
189            #[doc = concat!("Inner FFT")]
190            #[inline(always)]
191            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
192                let n = #s;
193                let x = input.as_ref();
194                assert_eq!(n, x.len());
195
196                let even: [Complex<T>; #half] = #half_butterfly([
197                    #(#half_butterfly_even_idx)*
198                ]);
199                let odd: [Complex<T>; #half] = #half_butterfly([
200                    #(#half_butterfly_odd_idx)*
201                ]);
202
203                let t: [Complex<T>; #half] = [
204                    #(#t_s),*
205                ];
206
207                [
208                    #(#sum_halves)*
209                    #(#sub_halves)*
210                ]
211            }
212        }
213    });
214
215    let expanded = quote! {
216        
217        #[inline(always)]
218        pub fn fft1<T: Float, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; 1] {
219            let n = 1;
220            let x = input.as_ref();
221            assert_eq!(n, x.len());
222
223            [x[0]]
224        }
225
226        #(#ss)*
227    };
228    proc_macro::TokenStream::from(expanded)
229}
230
231#[proc_macro]
232pub fn generate_coprimes(_input: TokenStream) -> TokenStream {
233    let sizes = SIZES
234        .clone()
235        .into_iter()
236        .filter(|n| FFTType::compute_type(*n) == FFTType::Coprime);
237    let ss = sizes.map(|s| {
238        let (c1, c2) = compute_coprimes(s);
239        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
240        let func1 = Ident::new(&format!("fft{}", c1), Span::call_site().into());
241        let func2 = Ident::new(&format!("fft{}", c2), Span::call_site().into());
242
243        let rows = (0..c2).map(|i|  {
244            let mut start = c1 * i;
245            let idx = (0..c1).map(|_| {
246                let index = start;
247                start = (start + c2) % s;
248                quote! {
249                    x[#index],
250                }}
251            );
252            let row_call = Ident::new(&format!("row{}", i), Span::call_site().into());
253
254            quote! {
255                let #row_call = #func1([ #(#idx)* ]);
256        }});
257
258        let cols = (0..c1).map(|i| {
259            let idx = (0..c2).map(|ii| {
260                let row_call = Ident::new(&format!("row{}", ii), Span::call_site().into());
261                quote! {
262                    #row_call[#i]
263                }
264            });
265
266            let col_call = Ident::new(&format!("col{}", i), Span::call_site().into());
267
268            quote! {
269                let #col_call = #func2([ #(#idx),*]);
270            }
271        });
272
273        let combine = (0..s).map(|i| {
274            let col = i % c1;
275            let idx = i % c2;
276            let f = Ident::new(&format!("col{}", col), Span::call_site().into());
277            quote! {
278                #f[#idx],
279            }
280        });
281
282        quote! {
283            #[doc = concat!("Inner FFT")]
284            #[inline(always)]
285            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
286                let n = #s;
287                let x = input.as_ref();
288                assert_eq!(n, x.len());
289
290                #(#rows)*
291                #(#cols)*
292
293
294                [#(#combine)*]
295
296            }
297        }
298    });
299
300    let expanded = quote! {
301        #(#ss)*
302    };
303    proc_macro::TokenStream::from(expanded)
304}
305
306#[proc_macro]
307pub fn generate_mixed_radix(_input: TokenStream) -> TokenStream {
308    let sizes = SIZES
309        .clone()
310        .into_iter()
311        .filter(|n| FFTType::compute_type(*n) == FFTType::Mixed);
312    let ss = sizes.map(|s| {
313        let c1 = (s as f64).sqrt().round() as usize;
314        let c2 = c1;
315        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
316        let func1 = Ident::new(&format!("fft{}", c1), Span::call_site().into());
317        let func2 = Ident::new(&format!("fft{}", c2), Span::call_site().into());
318
319        let rows = (0..c1).map(|i|  {
320            let idx = (i..s).step_by(c1).map(|xx| {
321                let index = xx % s;
322                quote! {
323                    x[#index],
324                }}
325            );
326            let row_call = Ident::new(&format!("row{}", i), Span::call_site().into());
327
328            quote! {
329                let #row_call = #func2([ #(#idx)* ]);
330        }});
331
332        let mut twiddles = vec![Complex::<f64>::new(0.0, 0.0); s];
333        for (x, twiddle_chunk) in twiddles.chunks_exact_mut(c2).enumerate() {
334            for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
335                *twiddle_element = compute_twiddle_forward(x * y, s);
336            }
337        }
338
339        let cols = (0..c2).map(|i| {
340            let mut start_idx = i;
341            let idx = (0..c1).map(|ii| {
342                let row_call = Ident::new(&format!("row{}", ii), Span::call_site().into());
343                let re = twiddles[start_idx].re;
344                let im = twiddles[start_idx].im;
345                start_idx += c2;
346                quote! {
347                    #row_call[#i] * Complex::new(T::from(#re).unwrap(), T::from(#im).unwrap())
348                }
349            });
350
351            let col_call = Ident::new(&format!("col{}", i), Span::call_site().into());
352
353            quote! {
354                let #col_call = #func1([ #(#idx),*]);
355            }
356        });
357
358        let combine = (0..s).map(|i| {
359            let col = i % c2;
360            let idx = i / c2;
361            let f = Ident::new(&format!("col{}", col), Span::call_site().into());
362            quote! {
363                #f[#idx],
364            }
365        });
366
367        quote! {
368            #[doc = concat!("Inner FFT")]
369            #[inline(always)]
370            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
371                let n = #s;
372                let x = input.as_ref();
373                assert_eq!(n, x.len());
374
375                #(#rows)*
376                #(#cols)*
377
378
379                [#(#combine)*]
380
381            }
382        }
383    });
384
385    let expanded = quote! {
386        #(#ss)*
387    };
388    proc_macro::TokenStream::from(expanded)
389}
390
/// Generates a direct DFT kernel `fft{n}` for every length in `SIZES` that
/// `FFTType::compute_type` classifies as `Prime` (no split into two coprime
/// factors greater than one).
///
/// The kernel pairs each input `x[m]` with its mirror `x[n - m]`. Define
/// `p_m = x[m] + x[n - m]`, `d_m = x[m] - x[n - m]` and
/// `w = exp(-2πi·m·k/n)`. Then every output with `1 <= k <= n/2` is
///
/// ```text
/// X[k] = x[0] + Σ_{m=1}^{(n-1)/2} ( p_m · Re(w) + i · d_m · Im(w) )
/// ```
///
/// The mirrored output `X[n - k]` is the same sum with `Im(w)` negated, and
/// `X[0]` is the plain sum `x[0] + Σ p_m`.
///
/// Only the twiddles `exp(-2πi·j/n)` for `1 <= j < (n+1)/2` are emitted. A
/// larger exponent `j` is folded to `n - j`, which keeps the real part and
/// flips the sign of the imaginary part.
///
/// The mirror pairing covers every index only for odd `n`; for even `n` the
/// middle sample `x[n/2]` would be dropped. Even lengths never reach this
/// generator: powers of two are `PowerOfTwo`, and any other even
/// `n = 2^a · q` has the coprime split `(2^a, q)`.
#[proc_macro]
pub fn generate_primes(_input: TokenStream) -> TokenStream {
    let sizes = SIZES
        .clone()
        .into_iter()
        .filter(|n| FFTType::compute_type(*n) == FFTType::Prime);
    let ss = sizes.map(|s| {
        let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
        // Number of mirror pairs is halflen - 1 = (s - 1) / 2 for odd s.
        let halflen = (s + 1) / 2;
        // `twiddle{j}` = exp(-2πi·j/s) for 1 <= j < halflen, baked in as f64
        // constants at expansion time and cast to `T`.
        let twiddles = (1..halflen).map(|n| {
            let var = Ident::new(&format!("twiddle{}", n), Span::call_site().into());
            let val: Complex<f64> = compute_twiddle_forward(n, s);
            let re = val.re;
            let im = val.im;
            quote! {
                let #var = Complex::new(T::from(#re).unwrap(), T::from(#im).unwrap());
            }
        });
        // Mirror sums/differences: `x{m}{s-m}p` = x[m] + x[s-m] and
        // `x{m}{s-m}n` = x[m] - x[s-m].
        let first_codegen = (1..halflen).map(|n| {
            let var1 = Ident::new(&format!("x{}{}p", n, s - n), Span::call_site().into());
            let var2 = Ident::new(&format!("x{}{}n", n, s - n), Span::call_site().into());
            quote! {
                let #var1 = x[#n] + x[#s - #n];
                let #var2 = x[#n] - x[#s - #n];
            }
        });
        // Terms of X[0] = x[0] + Σ p_m.
        let second_codegen = (1..halflen).map(|n| {
            let var = Ident::new(&format!("x{}{}p", n, s - n), Span::call_site().into());
            quote! {
                + #var
            }
        });
        // `b{k}{s-k}re_a` = x[0].re + Σ_m Re(w)·p_m.re. Folding the exponent
        // does not change Re(w), so no sign handling is needed here.
        let third_codegen = (1..halflen).map(|n| {
            let var1 = Ident::new(&format!("b{}{}re_a", n, s - n), Span::call_site().into());
            let sub1 = (1..halflen).map(|m| {

                let mut mn = (m * n) % s;
                if mn > s / 2 {
                    mn = s - mn;
                }
                let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                let var3 = Ident::new(&format!("x{}{}p", m, s - m), Span::call_site().into());
                quote! {
                    + #var2.re * #var3.re
                }
            });

            quote! {
                let #var1 = x[0].re #(#sub1)* ;
            }
        });

        // `b{k}{s-k}re_b` = Σ_m Im(w)·d_m.im. When m·k mod s is folded to
        // s - (m·k mod s), Im(w) changes sign, hence the `-` branch.
        let fourth_codegen = (1..halflen).map(|n| {
            let var1 = Ident::new(&format!("b{}{}re_b", n, s - n), Span::call_site().into());
            let sub1 = (1..halflen).map(|m| {

                let mut mn = (m * n) % s;
                if mn > s / 2 {
                    mn = s - mn;
                    let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                    let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
                    quote! {
                        - #var2.im * #var3.im
                    }
                } else {
                    let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                    let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
                    quote! {
                        + #var2.im * #var3.im
                    }
                }
            });

            quote! {
                let #var1 = T::zero() #(#sub1)* ;
            }
        });
        // `b{k}{s-k}im_a` = x[0].im + Σ_m Re(w)·p_m.im (no sign handling needed).
        let fifth_codegen = (1..halflen).map(|n| {
            let var1 = Ident::new(&format!("b{}{}im_a", n, s - n), Span::call_site().into());
            let sub1 = (1..halflen).map(|m| {

                let mut mn = (m * n) % s;
                if mn > s / 2 {
                    mn = s - mn;
                }
                let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                let var3 = Ident::new(&format!("x{}{}p", m, s - m), Span::call_site().into());
                quote! {
                    + #var2.re * #var3.im
                }
            });

            quote! {
                let #var1 = x[0].im #(#sub1)* ;
            }
        });

        // `b{k}{s-k}im_b` = Σ_m Im(w)·d_m.re, with the same fold sign flip as
        // `re_b`.
        let sixth_codegen = (1..halflen).map(|n| {
            let var1 = Ident::new(&format!("b{}{}im_b", n, s - n), Span::call_site().into());
            let sub1 = (1..halflen).map(|m| {

                let mut mn = (m * n) % s;
                if mn > s / 2 {
                    mn = s - mn;
                    let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                    let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
                    quote! {
                        - #var2.im * #var3.re
                    }
                } else {
                    let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
                    let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
                    quote! {
                        + #var2.im * #var3.re
                    }
                }
            });

            quote! {
                let #var1 = T::zero() #(#sub1)* ;
            }
        });

        // Assemble X[k] for 1 <= k < s. For k <= s/2: Re = re_a - re_b and
        // Im = im_a + im_b. For k > s/2, the partial sums of the mirrored
        // index s - k are reused with Im(w) negated: Re = re_a + re_b and
        // Im = im_a - im_b.
        let seventh_codegen = (1..s).map(|n| {
            let mut nfold = n;
            if n > s / 2 {
                nfold = s - n;
                let var1 = Ident::new(&format!("out{}re", n), Span::call_site().into());
                let var2 = Ident::new(&format!("out{}im", n), Span::call_site().into());
                let var3 = Ident::new(&format!("b{}{}re_a", nfold, s-nfold), Span::call_site().into());
                let var4 = Ident::new(&format!("b{}{}re_b", nfold, s-nfold), Span::call_site().into());
                let var5 = Ident::new(&format!("b{}{}im_a", nfold, s-nfold), Span::call_site().into());
                let var6 = Ident::new(&format!("b{}{}im_b", nfold, s-nfold), Span::call_site().into());
                quote! {
                    let #var1 = #var3 + #var4;
                    let #var2 = #var5 - #var6;
                }
            } else {
                let var1 = Ident::new(&format!("out{}re", n), Span::call_site().into());
                let var2 = Ident::new(&format!("out{}im", n), Span::call_site().into());
                let var3 = Ident::new(&format!("b{}{}re_a", nfold, s-nfold), Span::call_site().into());
                let var4 = Ident::new(&format!("b{}{}re_b", nfold, s-nfold), Span::call_site().into());
                let var5 = Ident::new(&format!("b{}{}im_a", nfold, s-nfold), Span::call_site().into());
                let var6 = Ident::new(&format!("b{}{}im_b", nfold, s-nfold), Span::call_site().into());
                quote!{
                    let #var1 = #var3 - #var4;
                    let #var2 = #var5 + #var6;
                }
            }
        });
        // Output elements X[1..s] in order; X[0] (`sum`) is prepended below.
        let eigth_codegen = (1..s).map(|n| {
            let var_re: Ident = Ident::new(&format!("out{}re", n), Span::call_site().into());
            let var_im: Ident = Ident::new(&format!("out{}im", n), Span::call_site().into());
            quote! {
                Complex::new(#var_re, #var_im),
            }
        });

        quote! {
            #[doc = concat!("Inner FFT")]
            #[inline(always)]
            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
                let n = #s;
                let x = input.as_ref();
                assert_eq!(n, x.len());

                #(#twiddles)*

                #(#first_codegen)*
                let sum = x[0] #(#second_codegen)* ;
                #(#third_codegen)*
                #(#fourth_codegen)*
                #(#fifth_codegen)*
                #(#sixth_codegen)*
                #(#seventh_codegen)*

                [
                    sum,
                    #(#eigth_codegen)*
                ]
            }

    }});

    let expanded = quote! {
        #(#ss)*
    };
    proc_macro::TokenStream::from(expanded)
}
580
581#[proc_macro]
582pub fn generate_iffts(_input: TokenStream) -> TokenStream {
583    let mut all_sizes: Vec<_> = SIZES
584        .clone()
585        .into_iter()
586        .chain(HAND_GEN.clone().into_iter())
587        .collect();
588    all_sizes.sort();
589    let iffts = all_sizes.into_iter().map(|n| {
590        let func = Ident::new(&format!("ifft{}", n), Span::call_site().into());
591        let input_args = (0..n).map(|i| {
592
593            quote! {
594                x[#i].conj(),
595            }
596        });
597        let output_args = (0..n).map(|i| {
598
599            quote! {
600                out[#i].conj(),
601            }
602        });
603
604        quote! {
605            #[doc = concat!("Inner iFFT")]
606            #[inline(always)]
607            pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #n] {
608                let x = input.as_ref();
609                assert_eq!(x.len(), #n);
610
611                let out: [Complex<T>; #n] = fft::<#n, _, _>([
612                    #(#input_args)*
613                ]);
614                [
615                    #(#output_args)*
616                ]
617
618            }
619        }
620    });
621
622    let expanded = quote! {
623        #[doc = concat!("Inner iFFT")]
624        #[inline(always)]
625        pub fn ifft1<T: Float, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; 1] {
626            let n = 1;
627            let x = input.as_ref();
628            assert_eq!(n, x.len());
629
630            [x[0]]
631        }
632
633        #(#iffts)*
634    };
635    proc_macro::TokenStream::from(expanded)
636}
637
#[cfg(test)]
mod tests {
    use crate::*;
    use num_integer::Integer;

    /// `compute_coprimes` picks the coprime split whose smaller factor is the
    /// largest one below `sqrt(n)`.
    #[test]
    fn test_coprimes() {
        let coprimes = [
            (2, 3),
            (2, 5),
            (3, 4),
            (2, 7),
            (3, 5),
            (4, 5),
            (3, 7),
            (2, 11),
            (3, 8),
            (2, 13),
            (4, 7),
            (5, 6),
            (3, 11),
            (2, 17),
            (5, 7),
            (2, 19),
            (3, 13),
            (5, 8),
            (6, 7),
            (4, 11),
            (5, 9),
            (2, 23),
            (2, 25),
            (3, 17),
            (4, 13),
            (2, 27),
            (5, 11),
            (7, 8),
            (3, 19),
            (2, 29),
            (5, 12),
            (2, 31),
            (7, 9),
            (5, 13),
            (6, 11),
            (4, 17),
            (3, 23),
            (7, 10),
            (8, 9),
            (2, 37),
            (3, 25),
            (4, 19),
            (7, 11),
            (6, 13),
            (5, 16),
            (2, 41),
            (7, 12),
            (5, 17),
            (2, 43),
            (3, 29),
        ];
        for (v1, v2) in coprimes {
            let n = v1 * v2;
            assert_eq!(
                compute_coprimes(n),
                (v1, v2),
                "unexpected coprime split for n = {}",
                n
            );
        }
    }

    /// Every `FFTType` variant is reachable and classified as expected.
    #[test]
    fn test_fft_type() {
        assert_eq!(FFTType::compute_type(25), FFTType::Mixed);
        assert_eq!(FFTType::compute_type(169), FFTType::Mixed);
        assert_eq!(FFTType::compute_type(36), FFTType::Coprime);
        assert_eq!(FFTType::compute_type(64), FFTType::PowerOfTwo);
        assert_eq!(FFTType::compute_type(13), FFTType::Prime);
    }

    /// Every generated length classifies without panicking, and each coprime
    /// split is a genuine factorization into two coprime factors above one.
    #[test]
    fn test_all_sizes_classify() {
        for &n in SIZES.iter() {
            if FFTType::compute_type(n) == FFTType::Coprime {
                let (c1, c2) = compute_coprimes(n);
                assert_eq!(c1 * c2, n, "bad split for n = {}", n);
                assert!(c1 > 1 && c1 < c2, "bad split for n = {}", n);
                assert_eq!(c1.gcd(&c2), 1, "factors not coprime for n = {}", n);
            }
        }
    }
}