use proc_macro2::TokenStream;
use quote::quote;
/// Emits the definition of `codelet_simd_2_scalar`, a generic 2-point
/// butterfly over `crate::kernel::Complex<T>`.
///
/// The generated function overwrites `data[0]` with the sum and `data[1]`
/// with the difference of the first two elements. Its `_sign` argument is
/// accepted but never read, since the 2-point butterfly needs no twiddle.
/// Indexing panics if `data` holds fewer than two elements.
pub(super) fn gen_scalar_size_2() -> TokenStream {
    quote! {
        #[inline(always)]
        fn codelet_simd_2_scalar<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            _sign: i32,
        ) {
            // Capture both inputs before either slot is overwritten.
            let (first, second) = (data[0], data[1]);
            data[0] = first + second;
            data[1] = first - second;
        }
    }
}
/// Emits the definition of `codelet_simd_4_scalar`, a generic in-place
/// 4-point transform over `crate::kernel::Complex<T>`.
///
/// The generated function combines the even pair (`data[0]`, `data[2]`) and
/// the odd pair (`data[1]`, `data[3]`) into sums and differences, rotates the
/// odd difference by a quarter turn, and writes the four results back in
/// natural order. A negative `sign` multiplies the odd difference by `-i`;
/// any other value multiplies it by `+i`. Indexing panics if `data` holds
/// fewer than four elements.
pub(super) fn gen_scalar_size_4() -> TokenStream {
    quote! {
        #[inline(always)]
        fn codelet_simd_4_scalar<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            use crate::kernel::Complex;

            let [x0, x1, x2, x3] = [data[0], data[1], data[2], data[3]];

            let even_sum = x0 + x2;
            let even_diff = x0 - x2;
            let odd_sum = x1 + x3;
            let odd_diff = x1 - x3;

            // Quarter-turn of the odd difference: +i when sign >= 0, -i otherwise.
            let odd_rot = if sign >= 0 {
                Complex::new(-odd_diff.im, odd_diff.re)
            } else {
                Complex::new(odd_diff.im, -odd_diff.re)
            };

            data[0] = even_sum + odd_sum;
            data[1] = even_diff + odd_rot;
            data[2] = even_sum - odd_sum;
            data[3] = even_diff - odd_rot;
        }
    }
}
/// Emits the definition of `codelet_simd_8_scalar`, a generic in-place
/// 8-point transform over `crate::kernel::Complex<T>`.
///
/// The generated function loads the inputs in bit-reversed order
/// (0, 4, 2, 6, 1, 5, 3, 7), runs three butterfly stages of span 1, 2 and 4,
/// and copies the result back into `data[0..8]`. A negative `sign` applies
/// the twiddles `exp(-i*pi*k/4)`; any other value applies their conjugates.
/// Indexing panics if `data` holds fewer than eight elements.
pub(super) fn gen_scalar_size_8() -> TokenStream {
    quote! {
        #[inline(always)]
        fn codelet_simd_8_scalar<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            use crate::kernel::Complex;

            // 1/sqrt(2): magnitude of the real and imaginary parts of the odd
            // eighth-turn twiddles.
            let inv_sqrt2 = T::from_f64(0.707_106_781_186_547_6_f64);

            // Multiplication by -i (sign < 0) or +i (otherwise).
            let quarter_turn = |z: Complex<T>| {
                if sign < 0 {
                    Complex::new(z.im, -z.re)
                } else {
                    Complex::new(-z.im, z.re)
                }
            };
            // Multiplication by (1 - i)/sqrt(2) (sign < 0) or (1 + i)/sqrt(2).
            let eighth_turn = |z: Complex<T>| {
                if sign < 0 {
                    Complex::new((z.re + z.im) * inv_sqrt2, (z.im - z.re) * inv_sqrt2)
                } else {
                    Complex::new((z.re - z.im) * inv_sqrt2, (z.im + z.re) * inv_sqrt2)
                }
            };
            // Multiplication by (-1 - i)/sqrt(2) (sign < 0) or (-1 + i)/sqrt(2).
            let three_eighths_turn = |z: Complex<T>| {
                if sign < 0 {
                    Complex::new((-z.re + z.im) * inv_sqrt2, (-z.im - z.re) * inv_sqrt2)
                } else {
                    Complex::new((-z.re - z.im) * inv_sqrt2, (-z.im + z.re) * inv_sqrt2)
                }
            };

            // Bit-reversed load into a local working buffer.
            let mut buf = [
                data[0], data[4], data[2], data[6],
                data[1], data[5], data[3], data[7],
            ];

            // Stage 1: span-1 butterflies on adjacent pairs.
            for pair in buf.chunks_exact_mut(2) {
                let (upper, lower) = (pair[0], pair[1]);
                pair[0] = upper + lower;
                pair[1] = upper - lower;
            }

            // Stage 2: span-2 butterflies; the second one in each group of
            // four uses the quarter-turn twiddle.
            for quad in buf.chunks_exact_mut(4) {
                let (upper, lower) = (quad[0], quad[2]);
                quad[0] = upper + lower;
                quad[2] = upper - lower;

                let (upper, lower) = (quad[1], quarter_turn(quad[3]));
                quad[1] = upper + lower;
                quad[3] = upper - lower;
            }

            // Stage 3: span-4 butterflies. Each butterfly touches only
            // buf[j] and buf[j + 4], so the twiddled upper half can be
            // computed up front.
            let twiddled = [
                buf[4],
                eighth_turn(buf[5]),
                quarter_turn(buf[6]),
                three_eighths_turn(buf[7]),
            ];
            for j in 0..4usize {
                let upper = buf[j];
                buf[j] = upper + twiddled[j];
                buf[j + 4] = upper - twiddled[j];
            }

            data[..8].copy_from_slice(&buf);
        }
    }
}
/// Emits the definition of `codelet_simd_16_scalar`, a generic in-place
/// 16-point transform over `crate::kernel::Complex<T>`.
///
/// The generated function loads the inputs in bit-reversed order, runs four
/// butterfly stages of span 1, 2, 4 and 8, and copies the result back into
/// `data[0..16]`. A negative `sign` applies the twiddles `exp(-i*pi*k/8)`;
/// any other value applies their conjugates. Indexing panics if `data` holds
/// fewer than sixteen elements.
///
/// All twiddle factors are compile-time constants, so the generated code
/// performs no trigonometric calls at runtime.
// The whole butterfly network is deliberately emitted as a single function.
#[allow(clippy::too_many_lines)]
pub(super) fn gen_scalar_size_16() -> TokenStream {
    quote! {
        #[inline(always)]
        #[allow(clippy::too_many_lines)]
        fn codelet_simd_16_scalar<T: crate::kernel::Float>(
            data: &mut [crate::kernel::Complex<T>],
            sign: i32,
        ) {
            use crate::kernel::Complex;

            const FRAC_1_SQRT_2: f64 = core::f64::consts::FRAC_1_SQRT_2;
            // cos(pi/8) and sin(pi/8), rounded to the nearest f64.
            const COS_PI_8: f64 = 0.923_879_532_511_286_7_f64;
            const SIN_PI_8: f64 = 0.382_683_432_365_089_8_f64;
            // cos(k*pi/8) for k = 0..8: real part of the last-stage twiddles.
            const TW_RE: [f64; 8] = [
                1.0, COS_PI_8, FRAC_1_SQRT_2, SIN_PI_8,
                0.0, -SIN_PI_8, -FRAC_1_SQRT_2, -COS_PI_8,
            ];
            // sin(k*pi/8) for k = 0..8: imaginary part before the sign is applied.
            const TW_IM: [f64; 8] = [
                0.0, SIN_PI_8, FRAC_1_SQRT_2, COS_PI_8,
                1.0, COS_PI_8, FRAC_1_SQRT_2, SIN_PI_8,
            ];

            // -1 selects exp(-i*theta) twiddles, +1 selects exp(+i*theta).
            let sign_f = if sign < 0 { T::from_f64(-1.0) } else { T::from_f64(1.0) };

            // Bit-reversed load into a local working buffer.
            let mut a = [Complex::<T>::zero(); 16];
            a[0] = data[0]; a[1] = data[8];
            a[2] = data[4]; a[3] = data[12];
            a[4] = data[2]; a[5] = data[10];
            a[6] = data[6]; a[7] = data[14];
            a[8] = data[1]; a[9] = data[9];
            a[10] = data[5]; a[11] = data[13];
            a[12] = data[3]; a[13] = data[11];
            a[14] = data[7]; a[15] = data[15];

            // Stage 1: span-1 butterflies on adjacent pairs.
            for i in (0..16usize).step_by(2) {
                let t = a[i + 1];
                a[i + 1] = a[i] - t;
                a[i] = a[i] + t;
            }

            // Stage 2: span-2 butterflies; the second one in each group of
            // four multiplies by -i (sign < 0) or +i.
            for group in (0..16usize).step_by(4) {
                let t = a[group + 2];
                a[group + 2] = a[group] - t;
                a[group] = a[group] + t;

                let t = a[group + 3];
                let t_tw = if sign < 0 {
                    Complex::new(t.im, -t.re)
                } else {
                    Complex::new(-t.im, t.re)
                };
                a[group + 3] = a[group + 1] - t_tw;
                a[group + 1] = a[group + 1] + t_tw;
            }

            // Stage 3: span-4 butterflies with eighth-turn twiddles written
            // out as explicit real arithmetic.
            let c2 = T::from_f64(FRAC_1_SQRT_2);
            for group in (0..16usize).step_by(8) {
                let t = a[group + 4];
                a[group + 4] = a[group] - t;
                a[group] = a[group] + t;

                let t = a[group + 5];
                let t_tw = if sign < 0 {
                    Complex::new((t.re + t.im) * c2, (t.im - t.re) * c2)
                } else {
                    Complex::new((t.re - t.im) * c2, (t.im + t.re) * c2)
                };
                a[group + 5] = a[group + 1] - t_tw;
                a[group + 1] = a[group + 1] + t_tw;

                let t = a[group + 6];
                let t_tw = if sign < 0 {
                    Complex::new(t.im, -t.re)
                } else {
                    Complex::new(-t.im, t.re)
                };
                a[group + 6] = a[group + 2] - t_tw;
                a[group + 2] = a[group + 2] + t_tw;

                let t = a[group + 7];
                let t_tw = if sign < 0 {
                    Complex::new((-t.re + t.im) * c2, (-t.im - t.re) * c2)
                } else {
                    Complex::new((-t.re - t.im) * c2, (-t.im + t.re) * c2)
                };
                a[group + 7] = a[group + 3] - t_tw;
                a[group + 3] = a[group + 3] + t_tw;
            }

            // Stage 4: span-8 butterflies. The twiddle for index k is
            // cos(k*pi/8) + i*sign*sin(k*pi/8), read from the constant
            // tables instead of being recomputed with cos/sin on each call.
            for k in 0..8usize {
                let tw = Complex::new(
                    T::from_f64(TW_RE[k]),
                    sign_f * T::from_f64(TW_IM[k]),
                );
                let t_tw = a[k + 8] * tw;
                a[k + 8] = a[k] - t_tw;
                a[k] = a[k] + t_tw;
            }

            data[..16].copy_from_slice(&a);
        }
    }
}