use proc_macro::TokenStream;
use quote::quote;
use syn::{parse_macro_input, LitInt};
pub fn generate(input: TokenStream) -> TokenStream {
let radix = parse_macro_input!(input as LitInt);
let r: usize = radix.base10_parse().expect("Invalid radix literal");
match r {
2 => gen_twiddle_2(),
4 => gen_twiddle_4(),
8 => gen_twiddle_8(),
_ => panic!("Unsupported twiddle radix: {r}"),
}
}
/// Emits `codelet_twiddle_2`, an in-place radix-2 butterfly.
///
/// The generated function multiplies `x[1]` by `twiddle`, then overwrites
/// `x[0]` with the sum and `x[1]` with the difference of the pair.
/// It debug-asserts that `x` holds at least two elements.
fn gen_twiddle_2() -> TokenStream {
    quote! {
        #[inline(always)]
        pub fn codelet_twiddle_2<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            twiddle: crate::kernel::Complex<T>,
        ) {
            debug_assert!(x.len() >= 2);
            // Read both inputs before writing either output.
            let top = x[0];
            let bottom = x[1] * twiddle;
            x[0] = top + bottom;
            x[1] = top - bottom;
        }
    }
    .into()
}
/// Emits `codelet_twiddle_4`, an in-place radix-4 butterfly.
///
/// The generated function scales `x[1..4]` by `tw1..tw3`. It then combines the
/// even pair (`x[0]`, `x[2]`) and the odd pair (`x[1]`, `x[3]`) with sum and
/// difference butterflies. The odd difference is rotated by `-i` when
/// `sign < 0` and by `+i` otherwise.
fn gen_twiddle_4() -> TokenStream {
    quote! {
        #[inline(always)]
        pub fn codelet_twiddle_4<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            tw1: crate::kernel::Complex<T>,
            tw2: crate::kernel::Complex<T>,
            tw3: crate::kernel::Complex<T>,
            sign: i32,
        ) {
            debug_assert!(x.len() >= 4);
            let a0 = x[0];
            let a1 = x[1] * tw1;
            let a2 = x[2] * tw2;
            let a3 = x[3] * tw3;

            let even_sum = a0 + a2;
            let even_diff = a0 - a2;
            let odd_sum = a1 + a3;
            let odd_diff = a1 - a3;

            // Multiply by +i (sign >= 0) or -i (sign < 0).
            let rotated = if sign >= 0 {
                crate::kernel::Complex::new(-odd_diff.im, odd_diff.re)
            } else {
                crate::kernel::Complex::new(odd_diff.im, -odd_diff.re)
            };

            x[0] = even_sum + odd_sum;
            x[2] = even_sum - odd_sum;
            x[1] = even_diff + rotated;
            x[3] = even_diff - rotated;
        }
    }
    .into()
}
/// Emits the radix-8 codelets `codelet_twiddle_8` and `codelet_twiddle_8_inline`.
///
/// `codelet_twiddle_8` multiplies `x[1..8]` by `twiddles[0..7]`, then computes
/// an 8-point DFT in place. It does this as two 4-point DFTs:
/// - even inputs `x0, x2, x4, x6` → `E[k]`;
/// - odd inputs `x1, x3, x5, x7` → `O[k]`.
///
/// These are recombined as `X[k] = E[k] + W8^k·O[k]` and
/// `X[k+4] = E[k] - W8^k·O[k]`. `sign < 0` uses `W8 = e^{-iπ/4}`; any other
/// value uses `W8 = e^{+iπ/4}`.
///
/// `codelet_twiddle_8_inline` builds twiddle `k` (k = 1..=7) as
/// `(cos(k·angle_step), sin(k·angle_step))` and forwards to `codelet_twiddle_8`.
fn gen_twiddle_8() -> TokenStream {
    let expanded = quote! {
        #[inline(always)]
        pub fn codelet_twiddle_8<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            twiddles: &[crate::kernel::Complex<T>; 7],
            sign: i32,
        ) {
            debug_assert!(x.len() >= 8);
            // Apply per-input twiddles; x[0] implicitly has twiddle 1.
            let x0 = x[0];
            let x1 = x[1] * twiddles[0];
            let x2 = x[2] * twiddles[1];
            let x3 = x[3] * twiddles[2];
            let x4 = x[4] * twiddles[3];
            let x5 = x[5] * twiddles[4];
            let x6 = x[6] * twiddles[5];
            let x7 = x[7] * twiddles[6];

            // Stage 1: radix-2 butterflies on the pairs (n, n + 4).
            let t0 = x0 + x4;
            let t1 = x0 - x4;
            let t2 = x2 + x6;
            let t3 = x2 - x6;
            let t4 = x1 + x5;
            let t5 = x1 - x5;
            let t6 = x3 + x7;
            let t7 = x3 - x7;

            // W8^2 = -i when sign < 0, +i otherwise.
            let t3_rot = if sign < 0 {
                crate::kernel::Complex::new(t3.im, -t3.re)
            } else {
                crate::kernel::Complex::new(-t3.im, t3.re)
            };
            let t7_rot = if sign < 0 {
                crate::kernel::Complex::new(t7.im, -t7.re)
            } else {
                crate::kernel::Complex::new(-t7.im, t7.re)
            };

            // Stage 2: finish both 4-point DFTs.
            // Even half: u0 = E[0], u4 = E[1], u1 = E[2], u5 = E[3].
            // Odd half:  u2 = O[0], u6 = O[1], u3 = O[2], u7 = O[3].
            let u0 = t0 + t2;
            let u1 = t0 - t2;
            let u2 = t4 + t6;
            let u3 = t4 - t6;
            let u4 = t1 + t3_rot;
            let u5 = t1 - t3_rot;
            let u6 = t5 + t7_rot;
            let u7 = t5 - t7_rot;

            // Stage 3: multiply O[k] by W8^k and combine.
            let sqrt2_inv = T::ONE / T::TWO.sqrt();
            let w8_re = sqrt2_inv;
            let w8_im = if sign < 0 { -sqrt2_inv } else { sqrt2_inv };

            // W8^2 · O[2]
            let u3_rot = if sign < 0 {
                crate::kernel::Complex::new(u3.im, -u3.re)
            } else {
                crate::kernel::Complex::new(-u3.im, u3.re)
            };
            // W8^1 · O[1], with W8 = w8_re + i·w8_im.
            let u6_tw = crate::kernel::Complex::new(
                u6.re * w8_re - u6.im * w8_im,
                u6.re * w8_im + u6.im * w8_re,
            );
            // W8^3 · O[3]. W8^3 = W8^2 · W8 and W8^2 = ±i follows `sign`, so
            // W8^3 = -w8_re + i·w8_im in both directions:
            //   e^{-3iπ/4} = (-1/√2, -1/√2),  e^{+3iπ/4} = (-1/√2, +1/√2).
            // Hard-coding W8^2 = +i here would swap X[3]/X[7] for sign < 0.
            let u7_tw = crate::kernel::Complex::new(
                u7.re * (-w8_re) - u7.im * w8_im,
                u7.re * w8_im + u7.im * (-w8_re),
            );

            x[0] = u0 + u2;
            x[4] = u0 - u2;
            x[2] = u1 + u3_rot;
            x[6] = u1 - u3_rot;
            x[1] = u4 + u6_tw;
            x[5] = u4 - u6_tw;
            x[3] = u5 + u7_tw;
            x[7] = u5 - u7_tw;
        }

        #[inline(always)]
        pub fn codelet_twiddle_8_inline<T: crate::kernel::Float>(
            x: &mut [crate::kernel::Complex<T>],
            angle_step: T,
            sign: i32,
        ) {
            debug_assert!(x.len() >= 8);
            // Twiddle k is (cos(k·angle_step), sin(k·angle_step)), k = 1..=7.
            // NOTE(review): `angle_step` is used as given; whether the caller
            // already folded the transform direction into it is not visible here.
            let tw1 = crate::kernel::Complex::new((angle_step).cos(), (angle_step).sin());
            let tw2 = crate::kernel::Complex::new((angle_step * T::TWO).cos(), (angle_step * T::TWO).sin());
            let tw3 = crate::kernel::Complex::new((angle_step * T::from_usize(3)).cos(), (angle_step * T::from_usize(3)).sin());
            let tw4 = crate::kernel::Complex::new((angle_step * T::from_usize(4)).cos(), (angle_step * T::from_usize(4)).sin());
            let tw5 = crate::kernel::Complex::new((angle_step * T::from_usize(5)).cos(), (angle_step * T::from_usize(5)).sin());
            let tw6 = crate::kernel::Complex::new((angle_step * T::from_usize(6)).cos(), (angle_step * T::from_usize(6)).sin());
            let tw7 = crate::kernel::Complex::new((angle_step * T::from_usize(7)).cos(), (angle_step * T::from_usize(7)).sin());
            let twiddles = [tw1, tw2, tw3, tw4, tw5, tw6, tw7];
            codelet_twiddle_8(x, &twiddles, sign);
        }
    };
    TokenStream::from(expanded)
}