use core::sync::atomic::{AtomicU64, Ordering};
use crate::dft::problem::Sign;
use crate::kernel::complex_mul::complex_mul_aos;
use crate::kernel::{is_prime, primitive_root};
use crate::kernel::{Complex, Float};
use crate::prelude::*;
use super::bluestein::BluesteinSolver;
/// Source of unique ids for [`RaderSolver`] instances.
static RADER_ID_COUNTER: AtomicU64 = AtomicU64::new(0);

/// Prime-length DFT solver implementing Rader's algorithm.
///
/// For a prime `p`, output bin 0 is the plain sum of the inputs. The remaining
/// `p - 1` bins are obtained by re-indexing inputs and outputs with powers of a
/// primitive root `g` (mod `p`). That turns the problem into a cyclic
/// convolution of length `p - 1`, which is evaluated with a
/// [`BluesteinSolver`] of that length and precomputed twiddle spectra.
pub struct RaderSolver<T: Float> {
    /// Transform length (a prime `>= 3`).
    p: usize,
    /// Primitive root modulo `p`, as returned by `primitive_root`.
    g: usize,
    /// `g_powers[k] = g^k mod p` for `k in 0..p-1`; output permutation.
    g_powers: Vec<usize>,
    /// `g_inv_powers[k] = g^(-k) mod p` for `k in 0..p-1`; input permutation.
    g_inv_powers: Vec<usize>,
    /// `Sign::Forward` transform of `cis(-2π·g^k/p)`; used for `Sign::Forward`.
    twiddle_fft_fwd: Vec<Complex<T>>,
    /// `Sign::Forward` transform of `cis(+2π·g^k/p)`; used for `Sign::Backward`.
    twiddle_fft_bwd: Vec<Complex<T>>,
    /// Length-`p - 1` solver that evaluates the cyclic convolution.
    conv_solver: BluesteinSolver<T>,
    pub(crate) solver_id: u64,
    /// Cached scratch (length `p - 1`) reused by `execute` when uncontended.
    #[cfg(feature = "std")]
    work_a: Mutex<Vec<Complex<T>>>,
    /// Cached scratch (length `p - 1`) reused by `execute` when uncontended.
    #[cfg(feature = "std")]
    work_a_fft: Mutex<Vec<Complex<T>>>,
    /// Cached scratch (length `p - 1`) reused by `execute` when uncontended.
    #[cfg(feature = "std")]
    work_conv: Mutex<Vec<Complex<T>>>,
    /// Cached copy-of-input buffer (length `p`) used by `execute_inplace`.
    #[cfg(feature = "std")]
    work_inplace: Mutex<Vec<Complex<T>>>,
}

impl<T: Float> RaderSolver<T> {
    /// Creates a solver for a length-`p` transform.
    ///
    /// Returns `None` when `p` is not a prime `>= 3` (see [`Self::applicable`])
    /// or when `primitive_root(p)` yields no root.
    #[must_use]
    pub fn new(p: usize) -> Option<Self> {
        if !Self::applicable(p) {
            return None;
        }
        let g = primitive_root(p)?;
        // Take an id only after the last fallible step, so rejected sizes do
        // not consume ids.
        let solver_id = RADER_ID_COUNTER.fetch_add(1, Ordering::Relaxed);
        let n = p - 1;

        // g_powers[k] = g^k mod p. When `g` is a primitive root, these values
        // are a permutation of 1..p.
        let mut g_powers = Vec::with_capacity(n);
        let mut power = 1usize;
        for _ in 0..n {
            g_powers.push(power);
            power = (power * g) % p;
        }

        // g^(-k) = g^(n - k) because g^n = 1 (mod p) by Fermat's little theorem.
        let g_inv_powers: Vec<usize> = (0..n).map(|k| g_powers[(n - k) % n]).collect();

        let conv_solver = BluesteinSolver::new(n);
        let twiddle_fft_fwd = Self::twiddle_spectrum(&conv_solver, &g_powers, p, Sign::Forward);
        let twiddle_fft_bwd = Self::twiddle_spectrum(&conv_solver, &g_powers, p, Sign::Backward);

        Some(Self {
            p,
            g,
            g_powers,
            g_inv_powers,
            twiddle_fft_fwd,
            twiddle_fft_bwd,
            conv_solver,
            solver_id,
            #[cfg(feature = "std")]
            work_a: Mutex::new(vec![Complex::zero(); n]),
            #[cfg(feature = "std")]
            work_a_fft: Mutex::new(vec![Complex::zero(); n]),
            #[cfg(feature = "std")]
            work_conv: Mutex::new(vec![Complex::zero(); n]),
            #[cfg(feature = "std")]
            work_inplace: Mutex::new(vec![Complex::zero(); p]),
        })
    }

    /// Builds the Rader twiddle sequence `b[k] = cis(s·2π·g^k / p)`, where
    /// `s = -1` for `Sign::Forward` and `s = +1` for `Sign::Backward`.
    /// Returns its `Sign::Forward` transform computed by `conv_solver`.
    fn twiddle_spectrum(
        conv_solver: &BluesteinSolver<T>,
        g_powers: &[usize],
        p: usize,
        sign: Sign,
    ) -> Vec<Complex<T>> {
        let signed_two_pi = match sign {
            Sign::Forward => -<T as Float>::TWO_PI,
            Sign::Backward => <T as Float>::TWO_PI,
        };
        let twiddles: Vec<Complex<T>> = g_powers
            .iter()
            .map(|&exp| Complex::cis(signed_two_pi * T::from_usize(exp) / T::from_usize(p)))
            .collect();
        let mut spectrum = vec![Complex::zero(); twiddles.len()];
        conv_solver.execute(&twiddles, &mut spectrum, Sign::Forward);
        spectrum
    }

    /// Solver name used for identification.
    #[must_use]
    pub fn name(&self) -> &'static str {
        "dft-rader"
    }

    /// Transform length `p`.
    #[must_use]
    pub fn size(&self) -> usize {
        self.p
    }

    /// Primitive root modulo `p` used for the index permutation.
    #[must_use]
    pub fn primitive_root(&self) -> usize {
        self.g
    }

    /// Unique id assigned at construction.
    #[must_use]
    pub fn id(&self) -> u64 {
        self.solver_id
    }

    /// Returns `true` when `p` is a prime `>= 3`, i.e. when [`Self::new`] can
    /// build a solver for it.
    #[must_use]
    pub fn applicable(p: usize) -> bool {
        p >= 3 && is_prime(p)
    }

    /// Core of the algorithm. It writes the length-`p` transform of
    /// `input[..p]` into `output`, using caller-provided scratch.
    ///
    /// `a`, `a_fft` and `conv` must each hold at least `p - 1` elements. Their
    /// previous contents are overwritten.
    fn execute_with_buffers(
        &self,
        input: &[Complex<T>],
        output: &mut [Complex<T>],
        sign: Sign,
        a: &mut [Complex<T>],
        a_fft: &mut [Complex<T>],
        conv: &mut [Complex<T>],
    ) {
        let p = self.p;
        let n = p - 1;
        debug_assert!(a.len() >= n && a_fft.len() >= n && conv.len() >= n);
        // Use exactly the p samples the permutation below reads, so the DC
        // sum is consistent with the other bins.
        let input = &input[..p];

        // Bin 0 is the plain sum of all inputs.
        let mut sum = Complex::zero();
        for x in input {
            sum = sum + *x;
        }

        // a[j] = x[g^(-j)]: inputs permuted into convolution order.
        for j in 0..n {
            a[j] = input[self.g_inv_powers[j]];
        }
        self.conv_solver.execute(a, a_fft, Sign::Forward);

        // Pointwise product with the precomputed twiddle spectrum for this
        // direction.
        let twiddle_fft = match sign {
            Sign::Forward => &self.twiddle_fft_fwd,
            Sign::Backward => &self.twiddle_fft_bwd,
        };
        complex_mul_aos(&mut conv[..n], a_fft, twiddle_fft);
        // `execute` takes separate input and output slices, so stage the
        // product back into `a_fft` before the backward pass into `conv`.
        a_fft[..n].copy_from_slice(&conv[..n]);
        self.conv_solver.execute(a_fft, conv, Sign::Backward);
        // Normalize the backward pass to obtain the cyclic convolution.
        let n_inv = T::ONE / T::from_usize(n);
        for x in conv.iter_mut().take(n) {
            *x = *x * n_inv;
        }

        // X[g^k] = x[0] + conv[k].
        output[0] = sum;
        for k in 0..n {
            let idx = self.g_powers[k];
            output[idx] = input[0] + conv[k];
        }
    }

    /// Runs the transform with freshly allocated scratch buffers.
    fn execute_allocating(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        let n = self.p - 1;
        let mut a = vec![Complex::zero(); n];
        let mut a_fft = vec![Complex::zero(); n];
        let mut conv = vec![Complex::zero(); n];
        self.execute_with_buffers(input, output, sign, &mut a, &mut a_fft, &mut conv);
    }

    /// Computes the unnormalized length-`p` DFT of `input` into `output`.
    ///
    /// The cached scratch buffers are reused when all three can be locked
    /// without blocking. Otherwise (for example, under concurrent use) the
    /// call falls back to temporary allocations.
    ///
    /// # Panics
    /// Panics if `input` or `output` is shorter than `p`. Debug builds also
    /// assert that both have length exactly `p`.
    #[cfg(feature = "std")]
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        debug_assert_eq!(input.len(), self.p);
        debug_assert_eq!(output.len(), self.p);
        // This `if let` has no `else`, so any guards that were acquired are
        // dropped at the end of the statement. The fallback below therefore
        // never holds a partially acquired lock.
        if let (Ok(mut a), Ok(mut a_fft), Ok(mut conv)) = (
            self.work_a.try_lock(),
            self.work_a_fft.try_lock(),
            self.work_conv.try_lock(),
        ) {
            self.execute_with_buffers(input, output, sign, &mut a, &mut a_fft, &mut conv);
            return;
        }
        self.execute_allocating(input, output, sign);
    }

    /// Computes the unnormalized length-`p` DFT of `input` into `output`,
    /// using temporary scratch buffers.
    ///
    /// # Panics
    /// Panics if `input` or `output` is shorter than `p`. Debug builds also
    /// assert that both have length exactly `p`.
    #[cfg(not(feature = "std"))]
    pub fn execute(&self, input: &[Complex<T>], output: &mut [Complex<T>], sign: Sign) {
        debug_assert_eq!(input.len(), self.p);
        debug_assert_eq!(output.len(), self.p);
        self.execute_allocating(input, output, sign);
    }

    /// Computes the transform of `data` in place.
    ///
    /// The input is first copied to a separate buffer: the cached one when it
    /// can be locked without blocking (std builds), otherwise a temporary
    /// `Vec`.
    pub fn execute_inplace(&self, data: &mut [Complex<T>], sign: Sign) {
        let p = self.p;
        debug_assert_eq!(data.len(), p);
        #[cfg(feature = "std")]
        {
            if let Ok(mut snapshot) = self.work_inplace.try_lock() {
                if snapshot.len() < p {
                    snapshot.resize(p, Complex::zero());
                }
                snapshot[..p].copy_from_slice(data);
                // The guarded buffer and `data` are distinct allocations, so a
                // plain shared borrow is enough; no raw-pointer slice needed.
                self.execute(&snapshot[..p], data, sign);
                return;
            }
        }
        let snapshot: Vec<Complex<T>> = data.to_vec();
        self.execute(&snapshot, data, sign);
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dft::solvers::direct::DirectSolver;

    fn approx_eq(a: f64, b: f64, eps: f64) -> bool {
        (a - b).abs() < eps
    }

    fn complex_approx_eq(a: Complex<f64>, b: Complex<f64>, eps: f64) -> bool {
        approx_eq(a.re, b.re, eps) && approx_eq(a.im, b.im, eps)
    }

    /// Complex conjugate built from the public `re`/`im` fields.
    fn conj(z: Complex<f64>) -> Complex<f64> {
        Complex::new(z.re, -z.im)
    }

    /// Asserts element-wise agreement of two spectra within `eps`. On
    /// failure, reports the first mismatching bin and both values.
    fn assert_spectra_close(
        label: &str,
        p: usize,
        got: &[Complex<f64>],
        want: &[Complex<f64>],
        eps: f64,
    ) {
        assert_eq!(got.len(), want.len(), "{label} p={p}: length mismatch");
        for (k, (got_k, want_k)) in got.iter().zip(want.iter()).enumerate() {
            assert!(
                complex_approx_eq(*got_k, *want_k, eps),
                "{label} p={p} bin {k}: got ({}, {}), want ({}, {})",
                got_k.re,
                got_k.im,
                want_k.re,
                want_k.im
            );
        }
    }

    /// Forward Rader transform of `input`, checked against the direct DFT.
    fn assert_forward_matches_direct(input: &[Complex<f64>], eps: f64) {
        let p = input.len();
        let solver = RaderSolver::new(p).expect("test sizes are prime");
        let mut rader = vec![Complex::zero(); p];
        let mut direct = vec![Complex::zero(); p];
        solver.execute(input, &mut rader, Sign::Forward);
        DirectSolver::new().execute(input, &mut direct, Sign::Forward);
        assert_spectra_close("forward vs direct", p, &rader, &direct, eps);
    }

    #[test]
    fn test_rader_applicable() {
        assert!(!RaderSolver::<f64>::applicable(0));
        assert!(!RaderSolver::<f64>::applicable(1));
        assert!(!RaderSolver::<f64>::applicable(2));
        assert!(RaderSolver::<f64>::applicable(3));
        assert!(!RaderSolver::<f64>::applicable(4));
        assert!(RaderSolver::<f64>::applicable(5));
        assert!(!RaderSolver::<f64>::applicable(6));
        assert!(RaderSolver::<f64>::applicable(7));
        assert!(RaderSolver::<f64>::applicable(11));
        assert!(RaderSolver::<f64>::applicable(13));
    }

    #[test]
    fn test_rader_new_rejects_non_prime_sizes() {
        for n in [0_usize, 1, 2, 4, 6, 9, 15, 21] {
            assert!(RaderSolver::<f64>::new(n).is_none(), "n={n} must be rejected");
        }
        let solver = RaderSolver::<f64>::new(7).expect("p=7 is prime");
        assert_eq!(solver.size(), 7);
    }

    #[test]
    fn test_rader_ids_are_unique() {
        let first = RaderSolver::<f64>::new(5).expect("p=5 is prime");
        let second = RaderSolver::<f64>::new(5).expect("p=5 is prime");
        assert_ne!(first.id(), second.id());
    }

    #[test]
    fn test_rader_size_3() {
        let input: Vec<Complex<f64>> = (0..3).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        assert_forward_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_5() {
        let input: Vec<Complex<f64>> = (0..5)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        assert_forward_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_7() {
        let input: Vec<Complex<f64>> = (0..7)
            .map(|i| Complex::new(f64::from(i), f64::from(i) * 0.5))
            .collect();
        assert_forward_matches_direct(&input, 1e-9);
    }

    #[test]
    fn test_rader_size_13() {
        let input: Vec<Complex<f64>> = (0..13)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        assert_forward_matches_direct(&input, 1e-8);
    }

    #[test]
    fn test_rader_backward_matches_conjugated_direct() {
        // For unnormalized transforms, Backward(x) = conj(Forward(conj(x))).
        let input: Vec<Complex<f64>> = (0..11)
            .map(|i| Complex::new(f64::from(i).cos(), 0.25 * f64::from(i)))
            .collect();
        let p = input.len();
        let solver = RaderSolver::new(p).expect("p=11 is prime");
        let mut rader = vec![Complex::zero(); p];
        solver.execute(&input, &mut rader, Sign::Backward);

        let conj_input: Vec<Complex<f64>> = input.iter().map(|&z| conj(z)).collect();
        let mut direct = vec![Complex::zero(); p];
        DirectSolver::new().execute(&conj_input, &mut direct, Sign::Forward);
        let expected: Vec<Complex<f64>> = direct.into_iter().map(conj).collect();

        assert_spectra_close("backward vs conj(direct(conj))", p, &rader, &expected, 1e-9);
    }

    #[test]
    fn test_rader_inverse_recovers_input() {
        let original: Vec<Complex<f64>> = (0..11)
            .map(|i| Complex::new(f64::from(i).sin(), f64::from(i).cos()))
            .collect();
        let mut transformed = vec![Complex::zero(); 11];
        let mut recovered = vec![Complex::zero(); 11];
        let solver = RaderSolver::new(11).expect("p=11 is prime");
        solver.execute(&original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let n = original.len() as f64;
        for x in &mut recovered {
            *x = *x / n;
        }
        assert_spectra_close("round-trip", 11, &recovered, &original, 1e-9);
    }

    #[test]
    fn test_rader_inplace() {
        let original: Vec<Complex<f64>> = (0..7).map(|i| Complex::new(f64::from(i), 0.0)).collect();
        let mut out_of_place = vec![Complex::zero(); 7];
        let solver = RaderSolver::new(7).expect("p=7 is prime");
        solver.execute(&original, &mut out_of_place, Sign::Forward);
        let mut in_place = original;
        solver.execute_inplace(&mut in_place, Sign::Forward);
        assert_spectra_close("inplace forward", 7, &in_place, &out_of_place, 1e-10);
    }

    #[test]
    fn test_rader_inplace_backward_repeated() {
        let original: Vec<Complex<f64>> = (0..13)
            .map(|i| Complex::new(f64::from(i).sin(), 1.0 - 0.1 * f64::from(i)))
            .collect();
        let solver = RaderSolver::new(13).expect("p=13 is prime");
        let mut out_of_place = vec![Complex::zero(); 13];
        solver.execute(&original, &mut out_of_place, Sign::Backward);

        let mut first = original.clone();
        solver.execute_inplace(&mut first, Sign::Backward);
        // A second call exercises reuse of the cached in-place buffer.
        let mut second = original;
        solver.execute_inplace(&mut second, Sign::Backward);

        assert_spectra_close("inplace backward #1", 13, &first, &out_of_place, 1e-10);
        assert_spectra_close("inplace backward #2", 13, &second, &out_of_place, 1e-10);
    }

    fn roundtrip_f64(n: usize) {
        let original: Vec<Complex<f64>> = (0..n)
            .map(|i| Complex::new((i as f64).sin(), (i as f64 * 0.7).cos()))
            .collect();
        let mut transformed = vec![Complex::zero(); n];
        let mut recovered = vec![Complex::zero(); n];
        let solver = RaderSolver::new(n).expect("n must be prime");
        solver.execute(&original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let n_f = n as f64;
        let mut max_rel = 0.0_f64;
        for (orig, rec) in original.iter().zip(recovered.iter()) {
            let rec_scaled = *rec / n_f;
            let re_err = (orig.re - rec_scaled.re).abs();
            let im_err = (orig.im - rec_scaled.im).abs();
            let norm = (orig.re * orig.re + orig.im * orig.im).sqrt().max(1e-30);
            max_rel = max_rel.max((re_err + im_err) / norm);
        }
        assert!(
            max_rel < 1e-12,
            "rader f64 round-trip n={n}: max_rel={max_rel} (must be < 1e-12)"
        );
    }

    fn roundtrip_f32(n: usize) {
        let original: Vec<Complex<f32>> = (0..n)
            .map(|i| Complex::new((i as f32).sin(), (i as f32 * 0.7).cos()))
            .collect();
        let mut transformed = vec![Complex::new(0.0_f32, 0.0); n];
        let mut recovered = vec![Complex::new(0.0_f32, 0.0); n];
        let solver = RaderSolver::<f32>::new(n).expect("n must be prime");
        solver.execute(&original, &mut transformed, Sign::Forward);
        solver.execute(&transformed, &mut recovered, Sign::Backward);
        let n_f = n as f32;
        let mut max_rel = 0.0_f32;
        for (orig, rec) in original.iter().zip(recovered.iter()) {
            let rec_scaled = *rec / n_f;
            let re_err = (orig.re - rec_scaled.re).abs();
            let im_err = (orig.im - rec_scaled.im).abs();
            let norm = (orig.re * orig.re + orig.im * orig.im)
                .sqrt()
                .max(1e-10_f32);
            max_rel = max_rel.max((re_err + im_err) / norm);
        }
        assert!(
            max_rel < 1e-3,
            "rader f32 round-trip n={n}: max_rel={max_rel} (must be < 1e-3)"
        );
    }

    #[test]
    fn roundtrip_prime_17_f64() {
        roundtrip_f64(17);
    }

    #[test]
    fn roundtrip_prime_61_f64() {
        roundtrip_f64(61);
    }

    #[test]
    fn roundtrip_prime_127_f64() {
        roundtrip_f64(127);
    }

    #[test]
    fn roundtrip_prime_257_f64() {
        roundtrip_f64(257);
    }

    #[test]
    fn roundtrip_prime_509_f64() {
        roundtrip_f64(509);
    }

    #[test]
    fn roundtrip_prime_1009_f64() {
        roundtrip_f64(1009);
    }

    #[test]
    fn roundtrip_prime_17_f32() {
        roundtrip_f32(17);
    }

    #[test]
    fn roundtrip_prime_61_f32() {
        roundtrip_f32(61);
    }

    #[test]
    fn roundtrip_prime_127_f32() {
        roundtrip_f32(127);
    }

    #[test]
    fn roundtrip_prime_257_f32() {
        roundtrip_f32(257);
    }

    #[test]
    fn roundtrip_prime_509_f32() {
        roundtrip_f32(509);
    }

    #[test]
    fn roundtrip_prime_1009_f32() {
        roundtrip_f32(1009);
    }

    #[cfg(feature = "threading")]
    #[test]
    fn parallel_shared_rader_correctness() {
        use rayon::prelude::*;
        let p = 61_usize;
        let solver = std::sync::Arc::new(RaderSolver::new(p).expect("p=61 is prime"));
        let input: Vec<Complex<f64>> = (0..p)
            .map(|i| Complex::new((i as f64).sin(), (i as f64).cos()))
            .collect();
        let mut reference = vec![Complex::zero(); p];
        solver.execute(&input, &mut reference, Sign::Forward);
        let results: Vec<Vec<Complex<f64>>> = (0..16_usize)
            .into_par_iter()
            .map(|_| {
                let mut out = vec![Complex::zero(); p];
                solver.execute(&input, &mut out, Sign::Forward);
                out
            })
            .collect();
        for (thread_idx, result) in results.iter().enumerate() {
            for (k, (r, rr)) in result.iter().zip(reference.iter()).enumerate() {
                let err = ((r.re - rr.re).abs() + (r.im - rr.im).abs())
                    / (rr.re * rr.re + rr.im * rr.im).sqrt().max(1e-30);
                assert!(
                    err < 1e-12,
                    "parallel thread {thread_idx} output[{k}] diverged: err={err}"
                );
            }
        }
    }
}