1use num_complex::Complex;
6use num_integer::Integer;
7use num_traits::{Float, FloatConst};
8use proc_macro::{Span, TokenStream};
9use quote::quote;
10use syn::Ident;
11
/// Transform lengths whose `fft{n}` kernels are emitted by the `generate_*` macros in this file.
///
/// This is every length in `2..=200` except the five listed in `HAND_GEN`. Each length is
/// routed to exactly one generator by `FFTType::compute_type`.
const SIZES: [usize; 194] = [
    2, 4, 5, 6, 7, 8, 10, 11, 12, 13, 14, 15, 16, 17, 19, 20, 21, 22, 23, 24, 25, 26, 28, 29, 30,
    31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51, 52, 53, 54,
    55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78,
    79, 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101,
    102, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120,
    121, 122, 123, 124, 126, 127, 128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140,
    141, 142, 143, 144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
    160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178,
    179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191, 192, 193, 194, 195, 196, 197,
    198, 199, 200,
];

/// Lengths left out of `SIZES`: no generator in this file emits `fft{n}` for them.
///
/// `generate_switch` still dispatches to `fft{n}` for these lengths, and other kernels call
/// them as sub-transforms (e.g. `fft9` for length 81, `fft27` for length 54). They are
/// presumably written by hand elsewhere in the crate — TODO confirm. `generate_iffts` does
/// emit `ifft{n}` for them.
const HAND_GEN: [usize; 5] = [3, 9, 18, 27, 125];
26
27#[derive(PartialEq, Debug)]
28enum FFTType {
29 PowerOfTwo,
30 Prime,
31 Coprime,
32 Mixed,
33}
34
35impl FFTType {
36 fn compute_type(n: usize) -> FFTType {
37 if n.is_power_of_two() {
38 FFTType::PowerOfTwo
39 } else {
40 let sqr = (n as f64).sqrt().round() as usize;
42 if sqr * sqr == n && sqr % 2 == 1 {
43 FFTType::Mixed
44 } else {
45 let (v1, _v2) = compute_coprimes(n);
46 if v1 == 1 {
47 FFTType::Prime
48 } else {
49 FFTType::Coprime
50 }
51 }
52 }
53 }
54}
55
56fn compute_coprimes(n: usize) -> (usize, usize) {
57 let sqr = (n as f32).sqrt().ceil() as usize;
58 for v1 in (1..sqr).rev() {
59 let v2 = n / v1;
60 if v2 * v1 == n {
61 if v1.gcd(&v2) == 1 {
62 return (usize::min(v1, v2), usize::max(v1, v2));
63 }
64 }
65 }
66 todo!()
67}
68
69fn compute_twiddle_forward<T: Float + FloatConst>(index: usize, fft_len: usize) -> Complex<T> {
70 let constant = T::from(-2.0).unwrap() * T::PI() / T::from(fft_len).unwrap();
71 let angle = constant * T::from(index).unwrap();
73
74 Complex::new(angle.cos(), angle.sin())
75}
76
77#[proc_macro]
78pub fn generate_switch(_input: TokenStream) -> TokenStream {
79 let mut all_sizes: Vec<_> = SIZES
80 .clone()
81 .into_iter()
82 .chain(HAND_GEN.clone().into_iter())
83 .collect();
84 all_sizes.sort();
85
86 let ss_forward = all_sizes.clone().into_iter().map(|s| {
87 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
88
89 quote! {
90 #s => {
91 let x = #func(x_in);
92 core::array::from_fn(|i| x[i])
93 },
94 }
95 });
96 let ss_inverse = all_sizes.into_iter().map(|s| {
97 let func = Ident::new(&format!("ifft{}", s), Span::call_site().into());
98
99 quote! {
100 #s => {
101 let x = #func(x_in);
102 core::array::from_fn(|i| x[i])
103 },
104 }
105 });
106
107 let expanded = quote! {
108 #[inline(always)]
118 pub fn fft<const N: usize, T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; N] {
119 let x_in = input.as_ref();
120 assert_eq!(x_in.len(), N);
121
122 match N {
123 1 => { core::array::from_fn(|i| x_in[i]) },
124 #(#ss_forward)*
125 _ => unimplemented!(),
126 }
127 }
128
129 #[inline(always)]
139 pub fn ifft<const N: usize, T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; N] {
140 let x_in = input.as_ref();
141 assert_eq!(x_in.len(), N);
142
143 match N {
144 1 => { core::array::from_fn(|i| x_in[i]) },
145 #(#ss_inverse)*
146 _ => unimplemented!(),
147 }
148 }
149 };
150 proc_macro::TokenStream::from(expanded)
151}
152
153#[proc_macro]
154pub fn generate_powers_of_two(_input: TokenStream) -> TokenStream {
155 let sizes = SIZES
156 .clone()
157 .into_iter()
158 .filter(|n| FFTType::compute_type(*n) == FFTType::PowerOfTwo);
159 let ss = sizes.map(|s| {
160 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
161 let half = s / 2;
162 let half_butterfly = Ident::new(&format!("fft{}", half), Span::call_site().into());
163 let half_butterfly_even_idx = (0..s).step_by(2).map(|f|{
164 quote! {
165 x[#f],
166 }
167 });
168 let half_butterfly_odd_idx = (1..s).step_by(2).map(|f|{
169 quote! {
170 x[#f],
171 }
172 });
173
174 let t_s = (0..half).map(|tt|
175 quote! {
176 Complex::exp(Complex::<T>::i() * T::from(-2.0).unwrap() * T::PI() * T::from(#tt).unwrap() / T::from(n).unwrap()) * odd[#tt]
177 }
178 );
179
180 let sum_halves = (0..half).map(|t_e| quote! {
181 even[#t_e] + t[#t_e],
182 }
183 );
184 let sub_halves = (0..half).map(|t_o| quote! {
185 even[#t_o] - t[#t_o],
186 });
187
188 quote! {
189 #[doc = concat!("Inner FFT")]
190 #[inline(always)]
191 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
192 let n = #s;
193 let x = input.as_ref();
194 assert_eq!(n, x.len());
195
196 let even: [Complex<T>; #half] = #half_butterfly([
197 #(#half_butterfly_even_idx)*
198 ]);
199 let odd: [Complex<T>; #half] = #half_butterfly([
200 #(#half_butterfly_odd_idx)*
201 ]);
202
203 let t: [Complex<T>; #half] = [
204 #(#t_s),*
205 ];
206
207 [
208 #(#sum_halves)*
209 #(#sub_halves)*
210 ]
211 }
212 }
213 });
214
215 let expanded = quote! {
216
217 #[inline(always)]
218 pub fn fft1<T: Float, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; 1] {
219 let n = 1;
220 let x = input.as_ref();
221 assert_eq!(n, x.len());
222
223 [x[0]]
224 }
225
226 #(#ss)*
227 };
228 proc_macro::TokenStream::from(expanded)
229}
230
231#[proc_macro]
232pub fn generate_coprimes(_input: TokenStream) -> TokenStream {
233 let sizes = SIZES
234 .clone()
235 .into_iter()
236 .filter(|n| FFTType::compute_type(*n) == FFTType::Coprime);
237 let ss = sizes.map(|s| {
238 let (c1, c2) = compute_coprimes(s);
239 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
240 let func1 = Ident::new(&format!("fft{}", c1), Span::call_site().into());
241 let func2 = Ident::new(&format!("fft{}", c2), Span::call_site().into());
242
243 let rows = (0..c2).map(|i| {
244 let mut start = c1 * i;
245 let idx = (0..c1).map(|_| {
246 let index = start;
247 start = (start + c2) % s;
248 quote! {
249 x[#index],
250 }}
251 );
252 let row_call = Ident::new(&format!("row{}", i), Span::call_site().into());
253
254 quote! {
255 let #row_call = #func1([ #(#idx)* ]);
256 }});
257
258 let cols = (0..c1).map(|i| {
259 let idx = (0..c2).map(|ii| {
260 let row_call = Ident::new(&format!("row{}", ii), Span::call_site().into());
261 quote! {
262 #row_call[#i]
263 }
264 });
265
266 let col_call = Ident::new(&format!("col{}", i), Span::call_site().into());
267
268 quote! {
269 let #col_call = #func2([ #(#idx),*]);
270 }
271 });
272
273 let combine = (0..s).map(|i| {
274 let col = i % c1;
275 let idx = i % c2;
276 let f = Ident::new(&format!("col{}", col), Span::call_site().into());
277 quote! {
278 #f[#idx],
279 }
280 });
281
282 quote! {
283 #[doc = concat!("Inner FFT")]
284 #[inline(always)]
285 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
286 let n = #s;
287 let x = input.as_ref();
288 assert_eq!(n, x.len());
289
290 #(#rows)*
291 #(#cols)*
292
293
294 [#(#combine)*]
295
296 }
297 }
298 });
299
300 let expanded = quote! {
301 #(#ss)*
302 };
303 proc_macro::TokenStream::from(expanded)
304}
305
306#[proc_macro]
307pub fn generate_mixed_radix(_input: TokenStream) -> TokenStream {
308 let sizes = SIZES
309 .clone()
310 .into_iter()
311 .filter(|n| FFTType::compute_type(*n) == FFTType::Mixed);
312 let ss = sizes.map(|s| {
313 let c1 = (s as f64).sqrt().round() as usize;
314 let c2 = c1;
315 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
316 let func1 = Ident::new(&format!("fft{}", c1), Span::call_site().into());
317 let func2 = Ident::new(&format!("fft{}", c2), Span::call_site().into());
318
319 let rows = (0..c1).map(|i| {
320 let idx = (i..s).step_by(c1).map(|xx| {
321 let index = xx % s;
322 quote! {
323 x[#index],
324 }}
325 );
326 let row_call = Ident::new(&format!("row{}", i), Span::call_site().into());
327
328 quote! {
329 let #row_call = #func2([ #(#idx)* ]);
330 }});
331
332 let mut twiddles = vec![Complex::<f64>::new(0.0, 0.0); s];
333 for (x, twiddle_chunk) in twiddles.chunks_exact_mut(c2).enumerate() {
334 for (y, twiddle_element) in twiddle_chunk.iter_mut().enumerate() {
335 *twiddle_element = compute_twiddle_forward(x * y, s);
336 }
337 }
338
339 let cols = (0..c2).map(|i| {
340 let mut start_idx = i;
341 let idx = (0..c1).map(|ii| {
342 let row_call = Ident::new(&format!("row{}", ii), Span::call_site().into());
343 let re = twiddles[start_idx].re;
344 let im = twiddles[start_idx].im;
345 start_idx += c2;
346 quote! {
347 #row_call[#i] * Complex::new(T::from(#re).unwrap(), T::from(#im).unwrap())
348 }
349 });
350
351 let col_call = Ident::new(&format!("col{}", i), Span::call_site().into());
352
353 quote! {
354 let #col_call = #func1([ #(#idx),*]);
355 }
356 });
357
358 let combine = (0..s).map(|i| {
359 let col = i % c2;
360 let idx = i / c2;
361 let f = Ident::new(&format!("col{}", col), Span::call_site().into());
362 quote! {
363 #f[#idx],
364 }
365 });
366
367 quote! {
368 #[doc = concat!("Inner FFT")]
369 #[inline(always)]
370 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
371 let n = #s;
372 let x = input.as_ref();
373 assert_eq!(n, x.len());
374
375 #(#rows)*
376 #(#cols)*
377
378
379 [#(#combine)*]
380
381 }
382 }
383 });
384
385 let expanded = quote! {
386 #(#ss)*
387 };
388 proc_macro::TokenStream::from(expanded)
389}
390
391#[proc_macro]
392pub fn generate_primes(_input: TokenStream) -> TokenStream {
393 let sizes = SIZES
394 .clone()
395 .into_iter()
396 .filter(|n| FFTType::compute_type(*n) == FFTType::Prime);
397 let ss = sizes.map(|s| {
398 let func = Ident::new(&format!("fft{}", s), Span::call_site().into());
399 let halflen = (s + 1) / 2;
400 let twiddles = (1..halflen).map(|n| {
401 let var = Ident::new(&format!("twiddle{}", n), Span::call_site().into());
402 let val: Complex<f64> = compute_twiddle_forward(n, s);
403 let re = val.re;
404 let im = val.im;
405 quote! {
406 let #var = Complex::new(T::from(#re).unwrap(), T::from(#im).unwrap());
407 }
408 });
409 let first_codegen = (1..halflen).map(|n| {
410 let var1 = Ident::new(&format!("x{}{}p", n, s - n), Span::call_site().into());
411 let var2 = Ident::new(&format!("x{}{}n", n, s - n), Span::call_site().into());
412 quote! {
413 let #var1 = x[#n] + x[#s - #n];
414 let #var2 = x[#n] - x[#s - #n];
415 }
416 });
417 let second_codegen = (1..halflen).map(|n| {
418 let var = Ident::new(&format!("x{}{}p", n, s - n), Span::call_site().into());
419 quote! {
420 + #var
421 }
422 });
423 let third_codegen = (1..halflen).map(|n| {
424 let var1 = Ident::new(&format!("b{}{}re_a", n, s - n), Span::call_site().into());
425 let sub1 = (1..halflen).map(|m| {
426
427 let mut mn = (m * n) % s;
428 if mn > s / 2 {
429 mn = s - mn;
430 }
431 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
432 let var3 = Ident::new(&format!("x{}{}p", m, s - m), Span::call_site().into());
433 quote! {
434 + #var2.re * #var3.re
435 }
436 });
437
438 quote! {
439 let #var1 = x[0].re #(#sub1)* ;
440 }
441 });
442
443 let fourth_codegen = (1..halflen).map(|n| {
444 let var1 = Ident::new(&format!("b{}{}re_b", n, s - n), Span::call_site().into());
445 let sub1 = (1..halflen).map(|m| {
446
447 let mut mn = (m * n) % s;
448 if mn > s / 2 {
449 mn = s - mn;
450 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
451 let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
452 quote! {
453 - #var2.im * #var3.im
454 }
455 } else {
456 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
457 let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
458 quote! {
459 + #var2.im * #var3.im
460 }
461 }
462 });
463
464 quote! {
465 let #var1 = T::zero() #(#sub1)* ;
466 }
467 });
468 let fifth_codegen = (1..halflen).map(|n| {
469 let var1 = Ident::new(&format!("b{}{}im_a", n, s - n), Span::call_site().into());
470 let sub1 = (1..halflen).map(|m| {
471
472 let mut mn = (m * n) % s;
473 if mn > s / 2 {
474 mn = s - mn;
475 }
476 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
477 let var3 = Ident::new(&format!("x{}{}p", m, s - m), Span::call_site().into());
478 quote! {
479 + #var2.re * #var3.im
480 }
481 });
482
483 quote! {
484 let #var1 = x[0].im #(#sub1)* ;
485 }
486 });
487
488 let sixth_codegen = (1..halflen).map(|n| {
489 let var1 = Ident::new(&format!("b{}{}im_b", n, s - n), Span::call_site().into());
490 let sub1 = (1..halflen).map(|m| {
491
492 let mut mn = (m * n) % s;
493 if mn > s / 2 {
494 mn = s - mn;
495 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
496 let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
497 quote! {
498 - #var2.im * #var3.re
499 }
500 } else {
501 let var2 = Ident::new(&format!("twiddle{}", mn), Span::call_site().into());
502 let var3 = Ident::new(&format!("x{}{}n", m, s - m), Span::call_site().into());
503 quote! {
504 + #var2.im * #var3.re
505 }
506 }
507 });
508
509 quote! {
510 let #var1 = T::zero() #(#sub1)* ;
511 }
512 });
513
514 let seventh_codegen = (1..s).map(|n| {
515 let mut nfold = n;
516 if n > s / 2 {
517 nfold = s - n;
518 let var1 = Ident::new(&format!("out{}re", n), Span::call_site().into());
519 let var2 = Ident::new(&format!("out{}im", n), Span::call_site().into());
520 let var3 = Ident::new(&format!("b{}{}re_a", nfold, s-nfold), Span::call_site().into());
521 let var4 = Ident::new(&format!("b{}{}re_b", nfold, s-nfold), Span::call_site().into());
522 let var5 = Ident::new(&format!("b{}{}im_a", nfold, s-nfold), Span::call_site().into());
523 let var6 = Ident::new(&format!("b{}{}im_b", nfold, s-nfold), Span::call_site().into());
524 quote! {
525 let #var1 = #var3 + #var4;
526 let #var2 = #var5 - #var6;
527 }
528 } else {
529 let var1 = Ident::new(&format!("out{}re", n), Span::call_site().into());
530 let var2 = Ident::new(&format!("out{}im", n), Span::call_site().into());
531 let var3 = Ident::new(&format!("b{}{}re_a", nfold, s-nfold), Span::call_site().into());
532 let var4 = Ident::new(&format!("b{}{}re_b", nfold, s-nfold), Span::call_site().into());
533 let var5 = Ident::new(&format!("b{}{}im_a", nfold, s-nfold), Span::call_site().into());
534 let var6 = Ident::new(&format!("b{}{}im_b", nfold, s-nfold), Span::call_site().into());
535 quote!{
536 let #var1 = #var3 - #var4;
537 let #var2 = #var5 + #var6;
538 }
539 }
540 });
541 let eigth_codegen = (1..s).map(|n| {
542 let var_re: Ident = Ident::new(&format!("out{}re", n), Span::call_site().into());
543 let var_im: Ident = Ident::new(&format!("out{}im", n), Span::call_site().into());
544 quote! {
545 Complex::new(#var_re, #var_im),
546 }
547 });
548
549 quote! {
550 #[doc = concat!("Inner FFT")]
551 #[inline(always)]
552 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #s] {
553 let n = #s;
554 let x = input.as_ref();
555 assert_eq!(n, x.len());
556
557 #(#twiddles)*
558
559 #(#first_codegen)*
560 let sum = x[0] #(#second_codegen)* ;
561 #(#third_codegen)*
562 #(#fourth_codegen)*
563 #(#fifth_codegen)*
564 #(#sixth_codegen)*
565 #(#seventh_codegen)*
566
567 [
568 sum,
569 #(#eigth_codegen)*
570 ]
571 }
572
573 }});
574
575 let expanded = quote! {
576 #(#ss)*
577 };
578 proc_macro::TokenStream::from(expanded)
579}
580
581#[proc_macro]
582pub fn generate_iffts(_input: TokenStream) -> TokenStream {
583 let mut all_sizes: Vec<_> = SIZES
584 .clone()
585 .into_iter()
586 .chain(HAND_GEN.clone().into_iter())
587 .collect();
588 all_sizes.sort();
589 let iffts = all_sizes.into_iter().map(|n| {
590 let func = Ident::new(&format!("ifft{}", n), Span::call_site().into());
591 let input_args = (0..n).map(|i| {
592
593 quote! {
594 x[#i].conj(),
595 }
596 });
597 let output_args = (0..n).map(|i| {
598
599 quote! {
600 out[#i].conj(),
601 }
602 });
603
604 quote! {
605 #[doc = concat!("Inner iFFT")]
606 #[inline(always)]
607 pub fn #func<T: Float + FloatConst, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; #n] {
608 let x = input.as_ref();
609 assert_eq!(x.len(), #n);
610
611 let out: [Complex<T>; #n] = fft::<#n, _, _>([
612 #(#input_args)*
613 ]);
614 [
615 #(#output_args)*
616 ]
617
618 }
619 }
620 });
621
622 let expanded = quote! {
623 #[doc = concat!("Inner iFFT")]
624 #[inline(always)]
625 pub fn ifft1<T: Float, A: AsRef<[Complex<T>]>>(input: A) -> [Complex<T>; 1] {
626 let n = 1;
627 let x = input.as_ref();
628 assert_eq!(n, x.len());
629
630 [x[0]]
631 }
632
633 #(#iffts)*
634 };
635 proc_macro::TokenStream::from(expanded)
636}
637
#[cfg(test)]
mod tests {
    use crate::*;

    /// Euclid's algorithm, kept local so the checks do not depend on the code under test.
    fn gcd(a: usize, b: usize) -> usize {
        if b == 0 {
            a
        } else {
            gcd(b, a % b)
        }
    }

    /// True when a kernel `fft{len}` is available.
    ///
    /// That is `fft1` (from `generate_powers_of_two`), a generated size in `SIZES`, or a
    /// hand-written size in `HAND_GEN`.
    fn has_kernel(len: usize) -> bool {
        len == 1 || SIZES.contains(&len) || HAND_GEN.contains(&len)
    }

    #[test]
    fn test_coprimes() {
        let coprimes = [
            (2, 3),
            (2, 5),
            (3, 4),
            (2, 7),
            (3, 5),
            (4, 5),
            (3, 7),
            (2, 11),
            (3, 8),
            (2, 13),
            (4, 7),
            (5, 6),
            (3, 11),
            (2, 17),
            (5, 7),
            (2, 19),
            (3, 13),
            (5, 8),
            (6, 7),
            (4, 11),
            (5, 9),
            (2, 23),
            (2, 25),
            (3, 17),
            (4, 13),
            (2, 27),
            (5, 11),
            (7, 8),
            (3, 19),
            (2, 29),
            (5, 12),
            (2, 31),
            (7, 9),
            (5, 13),
            (6, 11),
            (4, 17),
            (3, 23),
            (7, 10),
            (8, 9),
            (2, 37),
            (3, 25),
            (4, 19),
            (7, 11),
            (6, 13),
            (5, 16),
            (2, 41),
            (7, 12),
            (5, 17),
            (2, 43),
            (3, 29),
        ];
        for (v1, v2) in coprimes {
            let n = v1 * v2;
            assert_eq!(compute_coprimes(n), (v1, v2), "coprime split of {}", n);
        }
    }

    #[test]
    fn test_fft_type() {
        assert_eq!(FFTType::compute_type(25), FFTType::Mixed);
        assert_eq!(FFTType::compute_type(36), FFTType::Coprime);
        assert_eq!(FFTType::compute_type(64), FFTType::PowerOfTwo);
        assert_eq!(FFTType::compute_type(97), FFTType::Prime);
    }

    /// Every generated size must be classified consistently.
    ///
    /// Every sub-transform it calls must also exist, otherwise the generated code would
    /// reference a missing `fft{k}`.
    #[test]
    fn test_every_size_decomposes_into_existing_kernels() {
        for n in SIZES {
            match FFTType::compute_type(n) {
                FFTType::PowerOfTwo => {
                    assert!(n.is_power_of_two(), "{} is not a power of two", n);
                    assert!(has_kernel(n / 2), "missing fft{} for {}", n / 2, n);
                }
                FFTType::Mixed => {
                    let root = (1..=n).find(|r| r * r >= n).unwrap();
                    assert_eq!(root * root, n, "{} is not a perfect square", n);
                    assert_eq!(root % 2, 1, "{} is not an odd square", n);
                    assert!(has_kernel(root), "missing fft{} for {}", root, n);
                }
                FFTType::Coprime => {
                    let (a, b) = compute_coprimes(n);
                    assert_eq!(a * b, n);
                    assert_eq!(gcd(a, b), 1, "{} = {} * {} is not coprime", n, a, b);
                    assert!(a > 1 && b > 1, "trivial split for {}", n);
                    assert!(has_kernel(a) && has_kernel(b), "missing factor kernel for {}", n);
                }
                FFTType::Prime => {
                    assert!((2..n).all(|d| n % d != 0), "{} classified as prime", n);
                }
            }
        }
    }
}