use crate::dft::problem::Sign;
use crate::kernel::{is_prime, primitive_root};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::bluestein::BluesteinSolver;
/// Prime-length DFT solver based on Rader's algorithm.
///
/// The non-zero input/output indices are re-ordered by powers of a
/// generator `g` of the multiplicative group modulo `p`, which lets the
/// transform be evaluated as one cyclic convolution of length `p - 1`.
/// That convolution is carried out in the frequency domain with
/// `conv_solver`.
pub struct RaderSolver<T: Float> {
/// Transform length; `new` only accepts primes `p >= 3`.
p: usize,
/// Generator obtained from `primitive_root(p)`.
g: usize,
/// `g_powers[k] = g^k mod p` for `k` in `0..p - 1`.
g_powers: Vec<usize>,
/// `g_inv_powers[k] = g_powers[(n - k) % n]` with `n = p - 1`,
/// i.e. the powers of `g` walked in reverse (`g^(-k) mod p`).
g_inv_powers: Vec<usize>,
/// `conv_solver` forward transform of `cis(-2*pi * g^k / p)`;
/// selected by `Sign::Forward`.
twiddle_fft_fwd: Vec<Complex<T>>,
/// `conv_solver` forward transform of `cis(+2*pi * g^k / p)`;
/// selected by `Sign::Backward`.
twiddle_fft_bwd: Vec<Complex<T>>,
/// Length `p - 1` solver used for every transform inside the convolution.
conv_solver: BluesteinSolver<T>,
/// Reusable scratch (length `p - 1`) for the permuted input.
/// `execute` takes it with `try_lock` and allocates a temporary instead
/// when the lock is not available.
#[cfg(feature = "std")]
work_a: Mutex<Vec<Complex<T>>>,
/// Reusable scratch (length `p - 1`) for the spectrum of the permuted input.
#[cfg(feature = "std")]
work_a_fft: Mutex<Vec<Complex<T>>>,
/// Reusable scratch (length `p - 1`) for the convolution result.
#[cfg(feature = "std")]
work_conv: Mutex<Vec<Complex<T>>>,
}
impl<T: Float> RaderSolver<T> {
#[must_use]
pub fn new(p: usize) -> Option<Self> {
if p < 3 || !is_prime(p) {
return None;
}
let g = primitive_root(p)?;
let n = p - 1;
let mut g_powers = Vec::with_capacity(n);
let mut g_inv_powers = Vec::with_capacity(n);
let mut power = 1usize;
for _ in 0..n {
g_powers.push(power);
power = (power * g) % p;
}
for k in 0..n {
g_inv_powers.push(g_powers[(n - k) % n]);
}
let conv_solver = BluesteinSolver::new(n);
let mut twiddles_fwd = Vec::with_capacity(n);
for k in 0..n {
let exp = g_powers[k];
let angle = -<T as Float>::TWO_PI * T::from_usize(exp) / T::from_usize(p);
twiddles_fwd.push(Complex::cis(angle));
}
let mut twiddle_fft_fwd = vec![Complex::zero(); n];
conv_solver.execute(&twiddles_fwd, &mut twiddle_fft_fwd, Sign::Forward);
let mut twiddles_bwd = Vec::with_capacity(n);
for k in 0..n {
let exp = g_powers[k];
let angle = <T as Float>::TWO_PI * T::from_usize(exp) / T::from_usize(p);
twiddles_bwd.push(Complex::cis(angle));
}
let mut twiddle_fft_bwd = vec![Complex::zero(); n];
conv_solver.execute(&twiddles_bwd, &mut twiddle_fft_bwd, Sign::Forward);
Some(Self {
p,
g,
g_powers,
g_inv_powers,
twiddle_fft_fwd,
twiddle_fft_bwd,
conv_solver,
#[cfg(feature = "std")]
work_a: Mutex::new(vec![Complex::zero(); n]),
#[cfg(feature = "std")]
work_a_fft: Mutex::new(vec![Complex::zero(); n]),
#[cfg(feature = "std")]
work_conv: Mutex::new(vec![Complex::zero(); n]),
})
}
#[must_use]
pub fn name(&self) -> &'static str {
"dft-rader"
}
#[must_use]
pub fn size(&self) -> usize {
self.p
}
#[must_use]
#[allow(dead_code)]
pub fn primitive_root(&self) -> usize {
self.g
}
#[must_use]
pub fn applicable(p: usize) -> bool {
p >= 3 && is_prime(p)
}
fn execute_with_buffers(
&self,
input: &[Complex<T>],
output: &mut [Complex<T>],
sign: Sign,
a: &mut [Complex<T>],
a_fft: &mut [Complex<T>],
conv: &mut [Complex<T>],
) {
let p = self.p;
let n = p - 1;
let mut sum = Complex::zero();
for x in input {
sum = sum + *x;
}
for j in 0..n {
a[j] = input[self.g_inv_powers[j]];
}
self.conv_solver.execute(a, a_fft, Sign::Forward);
let twiddle_fft = match sign {
Sign::Forward => &self.twiddle_fft_fwd,
Sign::Backward => &self.twiddle_fft_bwd,
};
for i in 0..n {
a_fft[i] = a_fft[i] * twiddle_fft[i];
}
self.conv_solver.execute(a_fft, conv, Sign::Backward);
let n_inv = T::ONE / T::from_usize(n);
for x in conv.iter_mut().take(n) {
*x = *x * n_inv;
}
output[0] = sum;
for k in 0..n {
let idx = self.g_powers[k];
output[idx] = input[0] + conv[k];
}
}
#[cfg(feature = "std")]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
let p = self.p;
let n = p - 1;
debug_assert_eq!(input.len(), p);
debug_assert_eq!(output.len(), p);
let a_guard = self.work_a.try_lock();
let a_fft_guard = self.work_a_fft.try_lock();
let conv_guard = self.work_conv.try_lock();
if let (Ok(mut a), Ok(mut a_fft), Ok(mut conv)) = (a_guard, a_fft_guard, conv_guard) {
self.execute_with_buffers(input, output, sign, &mut a, &mut a_fft, &mut conv);
} else {
let mut a = vec![Complex::zero(); n];
let mut a_fft = vec![Complex::zero(); n];
let mut conv = vec![Complex::zero(); n];
self.execute_with_buffers(input, output, sign, &mut a, &mut a_fft, &mut conv);
}
}
#[cfg(not(feature = "std"))]
pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
let p = self.p;
let n = p - 1;
debug_assert_eq!(input.len(), p);
debug_assert_eq!(output.len(), p);
let mut a = vec![Complex::zero(); n];
let mut a_fft = vec![Complex::zero(); n];
let mut conv = vec![Complex::zero(); n];
self.execute_with_buffers(input, output, sign, &mut a, &mut a_fft, &mut conv);
}
pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
let p = self.p;
debug_assert_eq!(data.len(), p);
let input: Vec<Complex<T>> = data.to_vec();
self.execute(&input, data, sign);
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::DirectSolver;

    /// `true` when `a` and `b` differ by less than `eps`.
    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    /// Component-wise `approx_eq` on the real and imaginary parts.
    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Test signal `x[i] = sin(i) + j*cos(i)` for `i` in `0..len`.
    fn sin_cos_signal(len: i32) -> Vec<Complex<f64>> {
        (0..len)
            .map(|i| {
                let t = f64::from(i);
                Complex::new(t.sin(), t.cos())
            })
            .collect()
    }

    /// Runs the forward transform of `input` through both `RaderSolver` and
    /// `DirectSolver` and asserts that every output bin agrees within `eps`.
    fn assert_matches_direct(input: &[Complex<f64>], eps: f64) {
        let len = input.len();
        let mut output_rader = vec![Complex::zero(); len];
        let mut output_direct = vec![Complex::zero(); len];

        RaderSolver::new(len)
            .unwrap()
            .execute(input, &mut output_rader, Sign::Forward);
        DirectSolver::new().execute(input, &mut output_direct, Sign::Forward);

        for (a, b) in output_rader.iter().zip(output_direct.iter()) {
            assert!(complex_approx_eq(*a, *b, eps));
        }
    }

    #[test]
    fn test_rader_applicable() {
        for &p in &[0usize, 1, 2, 4, 6] {
            assert!(!RaderSolver::<f64>::applicable(p));
        }
        for &p in &[3usize, 5, 7, 11, 13] {
            assert!(RaderSolver::<f64>::applicable(p));
        }
    }

    #[test]
    fn test_rader_size_3() {
        let input: Vec<Complex<f64>> = (0..3).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_5() {
        let input = sin_cos_signal(5);
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_7() {
        let input: Vec<Complex<f64>> = (0..7)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect();
        assert_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_13() {
        let input = sin_cos_signal(13);
        assert_matches_direct(&input, 1e-8);
    }

    #[test]
    fn test_rader_inverse_recovers_input() {
        let original = sin_cos_signal(11);
        let mut transformed = vec![Complex::zero(); 11];
        let mut recovered = vec![Complex::zero(); 11];

        let solver = RaderSolver::new(11).unwrap();
        solver.execute(&original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);

        // The backward transform is unnormalised, so divide by the length.
        let n = original.len() as f64;
        for x in &mut recovered {
            *x = *x / n;
        }

        for (a, b) in original.iter().zip(recovered.iter()) {
            assert!(complex_approx_eq(*a, *b, 1e-9));
        }
    }

    #[test]
    fn test_rader_inplace() {
        let original: Vec<Complex<f64>> = (0..7).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let solver = RaderSolver::new(7).unwrap();

        let mut out_of_place = vec![Complex::zero(); 7];
        solver.execute(&original, &mut out_of_place, Sign::Forward);

        let mut in_place = original;
        solver.execute_inplace(&mut in_place, Sign::Forward);

        for (a, b) in out_of_place.iter().zip(in_place.iter()) {
            assert!(complex_approx_eq(*a, *b, 1e-10));
        }
    }
}