use super::fpr::Fpr;
use alloc::vec;
use alloc::vec::Vec;
/// A complex number whose components are [`Fpr`] values.
///
/// Every operation is a direct formula over the `Fpr` methods (`add`, `sub`,
/// `mul`, `div`, ...), so rounding is exactly whatever those methods provide.
#[derive(Clone, Copy)]
pub(crate) struct Cplx {
    /// Real part.
    pub(crate) re: Fpr,
    /// Imaginary part.
    pub(crate) im: Fpr,
}

impl Cplx {
    /// Builds the complex number `re + i·im`.
    #[inline]
    pub(crate) const fn new(re: Fpr, im: Fpr) -> Cplx {
        Cplx { re, im }
    }

    /// Returns `0 + 0i`.
    #[inline]
    pub(crate) fn zero() -> Cplx {
        Cplx::new(Fpr::from_f64(0.0), Fpr::from_f64(0.0))
    }

    /// Component-wise sum `self + o`.
    #[inline]
    pub(crate) fn add(self, o: Cplx) -> Cplx {
        Cplx::new(self.re.add(o.re), self.im.add(o.im))
    }

    /// Component-wise difference `self - o`.
    #[inline]
    pub(crate) fn sub(self, o: Cplx) -> Cplx {
        Cplx::new(self.re.sub(o.re), self.im.sub(o.im))
    }

    /// Complex product `self · o` (schoolbook formula, four real products).
    #[inline]
    pub(crate) fn mul(self, o: Cplx) -> Cplx {
        Cplx::new(
            self.re.mul(o.re).sub(self.im.mul(o.im)),
            self.re.mul(o.im).add(self.im.mul(o.re)),
        )
    }

    /// Complex conjugate `re - i·im`.
    #[inline]
    pub(crate) fn conj(self) -> Cplx {
        Cplx::new(self.re, self.im.neg())
    }

    /// Multiplies both components by the real scalar `s`.
    #[inline]
    pub(crate) fn scale(self, s: Fpr) -> Cplx {
        Cplx::new(self.re.mul(s), self.im.mul(s))
    }

    /// Complex quotient `self / o`, computed as `self · conj(o) / |o|²`.
    ///
    /// No check is made for `o == 0`; the result is then whatever `Fpr::div`
    /// produces for a zero divisor.
    #[inline]
    pub(crate) fn div(self, o: Cplx) -> Cplx {
        let d = o.re.mul(o.re).add(o.im.mul(o.im));
        Cplx::new(
            self.re.mul(o.re).add(self.im.mul(o.im)).div(d),
            self.im.mul(o.re).sub(self.re.mul(o.im)).div(d),
        )
    }

    /// Principal square root of a complex number assumed to lie on the unit circle.
    ///
    /// For `self = cos θ + i·sin θ` with `θ ∈ (-π, π]`, this returns
    /// `cos(θ/2) + i·sin(θ/2)`. The real part is non-negative. The imaginary
    /// part takes the sign of `self.im`, and a zero imaginary part counts as
    /// positive, so `-1` maps exactly to `i`.
    ///
    /// Only the component that is at least `1/√2` in magnitude comes from a
    /// half-angle square root, where `1 ± re` cannot cancel. The other
    /// component is recovered from `sin θ = 2·sin(θ/2)·cos(θ/2)`.
    ///
    /// Taking both components from `sqrt((1 ± re) / 2)` would lose most of the
    /// precision of the smaller one when `re` is close to `±1`. That case
    /// covers the small-angle twiddle factors of large transforms, and the
    /// error would then feed into every deeper level.
    fn unit_sqrt(self) -> Cplx {
        let zero = Fpr::from_f64(0.0);
        let one = Fpr::from_f64(1.0);
        if self.re.lt(zero) {
            // cos θ < 0, so |sin(θ/2)| ≥ 1/√2: take it from the square root,
            // then cos(θ/2) = sin θ / (2·sin(θ/2)), which is ≥ 0 because both
            // operands share a sign.
            let mut im = one.sub(self.re).half().sqrt();
            if self.im.lt(zero) {
                im = im.neg();
            }
            Cplx::new(self.im.half().div(im), im)
        } else {
            // cos θ ≥ 0, so cos(θ/2) ≥ 1/√2: take it from the square root;
            // the division carries the sign of sin θ over to sin(θ/2).
            let re = one.add(self.re).half().sqrt();
            Cplx::new(re, self.im.half().div(re))
        }
    }
}
/// Precomputed twiddle factors for a recursive complex FFT over the roots of
/// `x^n + 1`.
///
/// The transform of a polynomial `f` of degree `< n` is the vector of its
/// values at the `n` roots of `x^n + 1`, listed in the order built by
/// [`Fft::new`].
///
/// `rho[k]` (for `1 ≤ k ≤ log2 n`) holds the `2^(k-1)` twiddles that merge
/// two half-size transforms into one of size `2^k`. `rho[0]` is an empty
/// placeholder, so the level index equals `log2` of the transform size.
pub(crate) struct Fft {
    /// Transform size: number of coefficients and of evaluation points.
    pub(crate) n: usize,
    rho: Vec<Vec<Cplx>>,
}

impl Fft {
    /// Builds the twiddle tables for transforms of size `n`.
    ///
    /// The roots of `x^m + 1` are generated level by level, starting from
    /// `{-1}` for `m = 1`. Each root `e` of `x^(m/2) + 1` contributes its two
    /// square roots `r = e.unit_sqrt()` and `-r`, stored next to each other.
    /// The `r`s of level `m` form `rho[log2 m]`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a power of two or is smaller than 2.
    pub(crate) fn new(n: usize) -> Fft {
        // Checked in release builds too: for any other size the recursive
        // transforms silently drop coefficients or recurse without end.
        assert!(
            n.is_power_of_two() && n >= 2,
            "Fft::new: size must be a power of two >= 2, got {}",
            n
        );
        // Roots of x^1 + 1, in transform order.
        let mut eta: Vec<Cplx> = vec![Cplx::new(Fpr::from_f64(-1.0), Fpr::from_f64(0.0))];
        // Placeholder so that rho[log2 m] serves transforms of size m.
        let mut rho: Vec<Vec<Cplx>> = vec![Vec::new()];
        let mut m = 2;
        while m <= n {
            let half = m / 2;
            let mut rho_m = Vec::with_capacity(half);
            let mut next_eta = Vec::with_capacity(m);
            for e in &eta {
                let r = e.unit_sqrt();
                rho_m.push(r);
                next_eta.push(r);
                next_eta.push(r.neg_c());
            }
            rho.push(rho_m);
            eta = next_eta;
            m *= 2;
        }
        Fft { n, rho }
    }

    /// Forward transform of the real polynomial with coefficients `f`
    /// (constant term first).
    ///
    /// Entry `j` of the result is the value of `f` at the `j`-th root of
    /// `x^n + 1` in the order built by [`Fft::new`].
    pub(crate) fn fft(&self, f: &[Fpr]) -> Vec<Cplx> {
        debug_assert_eq!(f.len(), self.n);
        let zero = Fpr::from_f64(0.0);
        let cplx: Vec<Cplx> = f.iter().map(|&c| Cplx::new(c, zero)).collect();
        self.fft_rec(&cplx)
    }

    /// Recursive step of [`Fft::fft`].
    ///
    /// Splits `f` into even and odd coefficients, `f(x) = f0(x²) + x·f1(x²)`,
    /// transforms both halves and merges them.
    fn fft_rec(&self, f: &[Cplx]) -> Vec<Cplx> {
        if f.len() == 2 {
            // Base case, roots ±i of x^2 + 1: f(±i) = f[0] ± i·f[1].
            let (f0, f1) = (f[0], f[1]);
            return vec![
                Cplx::new(f0.re.sub(f1.im), f0.im.add(f1.re)),
                Cplx::new(f0.re.add(f1.im), f0.im.sub(f1.re)),
            ];
        }
        let f0: Vec<Cplx> = f.iter().step_by(2).copied().collect();
        let f1: Vec<Cplx> = f.iter().skip(1).step_by(2).copied().collect();
        self.merge_fft(&self.fft_rec(&f0), &self.fft_rec(&f1))
    }

    /// Inverse transform, returning real coefficients (constant term first).
    ///
    /// Only real parts are kept. The size-2 base case reads both coefficients
    /// from `fh[0]` alone, using `f(i) = f[0] + i·f[1]` with real `f[0]` and
    /// `f[1]`.
    ///
    /// This is exact when `fh` is the transform of a real polynomial, because
    /// then `f(-i) = conj(f(i))`. For any other input, the second base value
    /// is ignored rather than averaged in.
    pub(crate) fn ifft(&self, fh: &[Cplx]) -> Vec<Fpr> {
        debug_assert_eq!(fh.len(), self.n);
        self.ifft_rec(fh).into_iter().map(|z| z.re).collect()
    }

    /// Recursive step of [`Fft::ifft`].
    ///
    /// Splits `fh` into the transforms of the even/odd coefficient
    /// polynomials, inverts both and interleaves the results.
    fn ifft_rec(&self, fh: &[Cplx]) -> Vec<Cplx> {
        if fh.len() == 2 {
            // Real-input assumption, see `ifft`: fh[1] is not consulted.
            let zero = Fpr::from_f64(0.0);
            return vec![Cplx::new(fh[0].re, zero), Cplx::new(fh[0].im, zero)];
        }
        let (f0h, f1h) = self.split_fft(fh);
        let f0 = self.ifft_rec(&f0h);
        let f1 = self.ifft_rec(&f1h);
        let mut out = Vec::with_capacity(fh.len());
        for (&a, &b) in f0.iter().zip(&f1) {
            out.push(a);
            out.push(b);
        }
        out
    }

    /// Combines the size-`m/2` transforms of `f0` and `f1` into the size-`m`
    /// transform of `f(x) = f0(x²) + x·f1(x²)`.
    ///
    /// For the `i`-th twiddle `r`, the values `f(r) = f0h[i] + r·f1h[i]` and
    /// `f(-r) = f0h[i] - r·f1h[i]` land at positions `2i` and `2i + 1`.
    pub(crate) fn merge_fft(&self, f0h: &[Cplx], f1h: &[Cplx]) -> Vec<Cplx> {
        debug_assert_eq!(f0h.len(), f1h.len());
        let half = f0h.len();
        let m = 2 * half;
        let rho = &self.rho[m.trailing_zeros() as usize];
        let mut out = Vec::with_capacity(m);
        for i in 0..half {
            let t = rho[i].mul(f1h[i]);
            out.push(f0h[i].add(t));
            out.push(f0h[i].sub(t));
        }
        out
    }

    /// Inverse of [`Fft::merge_fft`].
    ///
    /// Recovers the size-`m/2` transforms of `f0` and `f1` from the size-`m`
    /// transform of `f(x) = f0(x²) + x·f1(x²)`.
    ///
    /// With `a = f(r)` and `b = f(-r)`, we have `f0(r²) = (a + b)/2` and
    /// `f1(r²) = (a - b)/(2r)`. Division by `r` is done by multiplying with
    /// `conj(r)`, which relies on the twiddles having unit modulus.
    pub(crate) fn split_fft(&self, fh: &[Cplx]) -> (Vec<Cplx>, Vec<Cplx>) {
        let m = fh.len();
        let half = m / 2;
        let rho = &self.rho[m.trailing_zeros() as usize];
        // Loop-invariant constant, built once instead of twice per iteration.
        let one_half = Fpr::from_f64(0.5);
        let mut f0 = Vec::with_capacity(half);
        let mut f1 = Vec::with_capacity(half);
        for i in 0..half {
            let a = fh[2 * i];
            let b = fh[2 * i + 1];
            f0.push(a.add(b).scale(one_half));
            f1.push(a.sub(b).scale(one_half).mul(rho[i].conj()));
        }
        (f0, f1)
    }
}
impl Cplx {
#[inline]
fn neg_c(self) -> Cplx {
Cplx::new(self.re.neg(), self.im.neg())
}
}
/// Pointwise product of two vectors in the FFT domain.
///
/// Both slices must have the same length. This is checked in debug builds;
/// otherwise `zip` would silently truncate to the shorter operand.
pub(crate) fn mul_fft(a: &[Cplx], b: &[Cplx]) -> Vec<Cplx> {
    debug_assert_eq!(a.len(), b.len(), "mul_fft: operand lengths differ");
    a.iter().zip(b).map(|(&x, &y)| x.mul(y)).collect()
}
/// Pointwise sum of two vectors in the FFT domain.
///
/// Both slices must have the same length. This is checked in debug builds;
/// otherwise `zip` would silently truncate to the shorter operand.
pub(crate) fn add_fft(a: &[Cplx], b: &[Cplx]) -> Vec<Cplx> {
    debug_assert_eq!(a.len(), b.len(), "add_fft: operand lengths differ");
    a.iter().zip(b).map(|(&x, &y)| x.add(y)).collect()
}
/// Pointwise difference `a - b` of two vectors in the FFT domain.
///
/// Both slices must have the same length. This is checked in debug builds;
/// otherwise `zip` would silently truncate to the shorter operand.
pub(crate) fn sub_fft(a: &[Cplx], b: &[Cplx]) -> Vec<Cplx> {
    debug_assert_eq!(a.len(), b.len(), "sub_fft: operand lengths differ");
    a.iter().zip(b).map(|(&x, &y)| x.sub(y)).collect()
}
pub(crate) fn adj_fft(a: &[Cplx]) -> Vec<Cplx> {
a.iter().map(|&x| x.conj()).collect()
}
/// Pointwise quotient `a / b` of two vectors in the FFT domain.
///
/// Zero entries in `b` are not checked for. Both slices must have the same
/// length; this is checked in debug builds, because `zip` would otherwise
/// silently truncate to the shorter operand.
pub(crate) fn div_fft(a: &[Cplx], b: &[Cplx]) -> Vec<Cplx> {
    debug_assert_eq!(a.len(), b.len(), "div_fft: operand lengths differ");
    a.iter().zip(b).map(|(&x, &y)| x.div(y)).collect()
}
// Unit tests for this module are kept in the sibling file `fft_tests.rs`
// and are compiled only for test builds.
#[cfg(test)]
#[path = "fft_tests.rs"]
mod fft_tests;