use super::{Complex, Float};
use crate::prelude::*;
/// Cache of twiddle-factor tables, keyed by the `(n, k)` pair handed to
/// [`TwiddleCache::get`].
///
/// Each stored table is whatever [`compute_twiddles`] returned for that key.
pub struct TwiddleCache<T: Float> {
/// Stored tables; the key is `(n, k)` exactly as passed to `get`.
cache: HashMap<(usize, usize), Vec<Complex<T>>>,
}
impl<T: Float> Default for TwiddleCache<T> {
fn default() -> Self {
Self::new()
}
}
impl<T> TwiddleCache<T>
where
    T: Float,
{
    /// Creates a cache with no stored twiddle tables.
    #[must_use]
    pub fn new() -> Self {
        let cache = HashMap::new();
        Self { cache }
    }

    /// Returns the twiddle table for the key `(n, k)`.
    ///
    /// The first request for a key builds the table with
    /// [`compute_twiddles`] and stores it; later requests with the same key
    /// hand back the stored table without recomputing it.
    pub fn get(&mut self, n: usize, k: usize) -> &[Complex<T>] {
        let key = (n, k);
        let table = self
            .cache
            .entry(key)
            .or_insert_with(|| compute_twiddles(n, k));
        table.as_slice()
    }

    /// Removes every stored table, leaving the cache empty.
    pub fn clear(&mut self) {
        self.cache.clear();
    }
}
/// Builds a table of `k` twiddle factors for the divisor `n`.
///
/// Entry `j` (for `j` in `0..k`) is `Complex::cis((-TWO_PI / n) * j)`, so the
/// returned vector always has exactly `k` elements.
///
/// `n == 0` is not guarded: the step is then `-TWO_PI` divided by
/// `T::from_usize(0)`, whatever that evaluates to for `T`.
#[must_use]
pub fn compute_twiddles<T: Float>(n: usize, k: usize) -> Vec<Complex<T>> {
    // Angle between consecutive entries; the sign is opposite to the one
    // used by `twiddle_inverse`.
    let step = -T::TWO_PI / T::from_usize(n);
    (0..k)
        .map(|j| Complex::cis(step * T::from_usize(j)))
        .collect()
}
/// Computes the single twiddle factor `Complex::cis(-TWO_PI * k / n)`.
///
/// For `n = 4` the values at `k = 0, 1, 2` are `1`, `-i` and `-1`.
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle<T: Float>(n: usize, k: usize) -> Complex<T> {
    // Multiply first, then divide, to keep the original rounding order.
    let scaled = -T::TWO_PI * T::from_usize(k);
    let angle = scaled / T::from_usize(n);
    Complex::cis(angle)
}
/// Computes `Complex::cis(TWO_PI * k / n)`: the same angle as [`twiddle`]
/// with the sign flipped.
#[allow(dead_code)]
#[inline]
#[must_use]
pub fn twiddle_inverse<T: Float>(n: usize, k: usize) -> Complex<T> {
    // Multiply first, then divide, to keep the original rounding order.
    let scaled = T::TWO_PI * T::from_usize(k);
    let angle = scaled / T::from_usize(n);
    Complex::cis(angle)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every comparison in this module.
    const TOL: f64 = 1e-10;

    /// True when both components of `z` lie within `TOL` of the expected pair.
    fn is_close(z: &Complex<f64>, re: f64, im: f64) -> bool {
        (z.re - re).abs() < TOL && (z.im - im).abs() < TOL
    }

    /// The first three factors for `n = 4` are `1`, `-i` and `-1`.
    #[test]
    fn test_twiddle_w4() {
        assert!(is_close(&twiddle(4, 0), 1.0, 0.0));
        assert!(is_close(&twiddle(4, 1), 0.0, -1.0));
        assert!(is_close(&twiddle(4, 2), -1.0, 0.0));
    }

    /// The table has the requested length and starts at `1 + 0i`.
    #[test]
    fn test_compute_twiddles() {
        let table: Vec<Complex<f64>> = compute_twiddles(8, 4);
        assert_eq!(table.len(), 4);
        assert!(is_close(&table[0], 1.0, 0.0));
    }
}