use crate::error::{FFTError, FFTResult};
use crate::fft::{fft, ifft};
use scirs2_core::numeric::{Complex64, Zero};
use std::f64::consts::PI;
/// Returns the smallest power of two that is `>= n`, with `n == 0` mapped to 1.
///
/// Used to size the zero-padded FFT buffers in the Bluestein convolution.
fn next_pow2(n: usize) -> usize {
    // `max(1)` folds the 0 and 1 cases together; both yield 1.
    n.max(1).next_power_of_two()
}
/// Returns `true` when `n` is prime, using trial division by odd divisors.
///
/// The loop bound is written as `i <= n / i` rather than `i * i <= n` so that
/// it can never overflow, even for `n` close to `usize::MAX`.
fn is_prime(n: usize) -> bool {
    match n {
        0 | 1 => false,
        2 | 3 => true,
        _ if n % 2 == 0 => false,
        _ => {
            let mut i = 3usize;
            while i <= n / i {
                if n % i == 0 {
                    return false;
                }
                i += 2;
            }
            true
        }
    }
}
/// Finds the smallest primitive root modulo the prime `p`.
///
/// A candidate `g` generates the multiplicative group mod `p` exactly when
/// `g^((p-1)/q) != 1 (mod p)` for every distinct prime factor `q` of `p - 1`.
/// For `p == 2` the answer is `1`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `p` is not prime, or if the search over
/// `2..p` finds no generator.
pub fn primitive_root(p: usize) -> FFTResult<usize> {
    if !is_prime(p) {
        return Err(FFTError::ValueError(format!("{p} is not prime")));
    }
    if p == 2 {
        return Ok(1);
    }
    let order = p - 1;

    // Distinct prime factors of the group order, found by trial division.
    let mut prime_factors: Vec<usize> = Vec::new();
    let mut remaining = order;
    let mut d = 2usize;
    while d * d <= remaining {
        if remaining % d == 0 {
            prime_factors.push(d);
            while remaining % d == 0 {
                remaining /= d;
            }
        }
        d += 1;
    }
    if remaining > 1 {
        // Whatever survives the trial division is itself a prime factor.
        prime_factors.push(remaining);
    }

    (2..p)
        .find(|&candidate| {
            prime_factors
                .iter()
                .all(|&q| modpow_usize(candidate, order / q, p) != 1)
        })
        .ok_or_else(|| FFTError::ValueError(format!("No primitive root found for {p}")))
}
/// Computes `base^exp mod modulus` by binary (square-and-multiply) exponentiation.
///
/// Intermediate products are formed in `u128`. The result is therefore exact for
/// every `usize` modulus; the previous `usize` products overflowed once
/// `modulus > 2^32`. `modulus` must be non-zero, and `modulus == 1` yields 0.
fn modpow_usize(base: usize, mut exp: usize, modulus: usize) -> usize {
    let m = modulus as u128;
    let mut b = base as u128 % m;
    // `1 % m` makes the empty product correct when `modulus == 1`.
    let mut result = 1u128 % m;
    while exp > 0 {
        if exp & 1 == 1 {
            result = result * b % m;
        }
        exp >>= 1;
        b = b * b % m;
    }
    // `result < modulus`, so narrowing back to `usize` is lossless.
    result as usize
}
/// Returns the first `n` powers `[1, g, g^2, …, g^(n-1)]`, each reduced mod `p`.
///
/// Each step multiplies in `u128`. That keeps the sequence exact for moduli
/// above 2^32, where the former `usize` product `cur * g` overflowed.
fn powers_of_g(g: usize, n: usize, p: usize) -> Vec<usize> {
    let modulus = p as u128;
    let step = g as u128;
    std::iter::successors(Some(1u128), |&cur| Some(cur * step % modulus))
        .take(n)
        // Every element after the first is `< p`, and the first is 1, so this fits.
        .map(|v| v as usize)
        .collect()
}
/// Circular (cyclic) convolution of two equal-length sequences.
///
/// Uses the convolution theorem: `ifft(fft(a) ⊙ fft(b))`, where `⊙` is the
/// element-wise product.
///
/// # Errors
/// Returns `FFTError::ValueError` when `a` and `b` differ in length. Any error
/// from `fft`/`ifft` is propagated.
fn cyclic_convolve(a: &[Complex64], b: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    if a.len() != b.len() {
        return Err(FFTError::ValueError("cyclic_convolve: length mismatch".into()));
    }
    let spec_a = fft(a, None)?;
    let spec_b = fft(b, None)?;
    let pointwise = spec_a
        .iter()
        .zip(&spec_b)
        .map(|(&u, &v)| u * v)
        .collect::<Vec<Complex64>>();
    let convolved = ifft(&pointwise, None)?;
    Ok(convolved)
}
/// Forward DFT of a prime-length signal via Rader's algorithm.
///
/// With `g` a primitive root mod `p`, every non-zero index can be written as a
/// power of `g`. That rewrites the non-DC bins as a cyclic convolution of length
/// `p - 1`:
///
/// `X[g^{-q}] = x[0] + Σ_m x[g^m] · exp(-2πi·g^{-(q-m)}/p)`.
///
/// `X[0]` is the plain sum of the input. Lengths 1 and 2 are handled directly.
///
/// # Errors
/// Returns `FFTError::ValueError` for empty or non-prime input. Errors from
/// the primitive-root search or the convolution are propagated.
pub fn rader_fft(signal: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    let p = signal.len();
    match p {
        0 => return Err(FFTError::ValueError("rader_fft: empty input".into())),
        1 => return Ok(signal.to_vec()),
        2 => return Ok(vec![signal[0] + signal[1], signal[0] - signal[1]]),
        _ => {}
    }
    if !is_prime(p) {
        return Err(FFTError::ValueError(format!(
            "rader_fft: length {p} is not prime. Use bluestein_fft for arbitrary lengths."
        )));
    }

    let g = primitive_root(p)?;
    let order = p - 1;
    // fwd[m] = g^m mod p, and bwd[q] = g^{-q} mod p (g^{p-2} is g's inverse).
    let fwd = powers_of_g(g, order, p);
    let g_inverse = modpow_usize(g, order - 1, p);
    let bwd = powers_of_g(g_inverse, order, p);

    let dc = signal.iter().fold(Complex64::zero(), |acc, &v| acc + v);
    // Input permuted by ascending powers of g.
    let permuted: Vec<Complex64> = fwd.iter().map(|&idx| signal[idx]).collect();
    // Twiddles permuted by descending powers of g.
    let kernel: Vec<Complex64> = bwd
        .iter()
        .map(|&e| {
            let phase = -2.0 * PI * e as f64 / p as f64;
            Complex64::new(phase.cos(), phase.sin())
        })
        .collect();
    let conv = cyclic_convolve(&permuted, &kernel)?;

    let mut spectrum = vec![Complex64::zero(); p];
    spectrum[0] = dc;
    for (q, &k) in bwd.iter().enumerate() {
        spectrum[k] = signal[0] + conv[q];
    }
    Ok(spectrum)
}
/// Chirp-z transform of `x` at `m` points via Bluestein's algorithm.
///
/// Evaluates the z-transform of `x` on the spiral contour `z_k = a · w^k`
/// for `k = 0, …, m-1`:
///
/// `X[k] = Σ_{n=0}^{N-1} x[n] · z_k^{-n} = Σ_n x[n] · a^{-n} · w^{-n·k}`.
///
/// With `a = 1`, `w = exp(+2πi/N)` and `m = N` this is the forward DFT. This is
/// the convention `dft_via_czt` and `zoom_fft_band` rely on. Note that SciPy's
/// `czt` uses `z_k = A·W^{-k}`, so its `w` is the reciprocal of this one.
///
/// Both the angle and the magnitude of `w` are honoured. Half-integer powers
/// `w^{j²/2}` are taken on the principal polar branch `|w|^s · exp(i·s·arg w)`.
/// A magnitude within a few ulps of 1 is snapped to exactly 1, so unit-circle
/// contours do not accumulate a spurious `|w|^{n·k}` drift.
///
/// The identity `n·k = (n² + k² − (k−n)²)/2` turns the sum into a linear
/// convolution. That convolution is computed as a cyclic one of power-of-two
/// length `L >= N + m - 1`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `x` is empty, `m == 0`, `|a| = 0` or
/// `|w| = 0`. Errors from `fft`/`ifft` are propagated.
pub fn bluestein_chirp_z(
    x: &[Complex64],
    m: usize,
    a: Complex64,
    w: Complex64,
) -> FFTResult<Vec<Complex64>> {
    let n = x.len();
    if n == 0 {
        return Err(FFTError::ValueError("bluestein_chirp_z: input is empty".into()));
    }
    if m == 0 {
        return Err(FFTError::ValueError("bluestein_chirp_z: m must be > 0".into()));
    }
    let a_norm_sq = a.norm_sqr();
    if a_norm_sq == 0.0 {
        return Err(FFTError::ValueError("bluestein_chirp_z: |A| = 0".into()));
    }
    let w_norm = w.norm();
    if w_norm == 0.0 {
        return Err(FFTError::ValueError("bluestein_chirp_z: |W| = 0".into()));
    }

    // Treat |W| that differs from 1 only by rounding as exactly 1 (DFT / zoom-FFT case).
    let w_mag = if (w_norm - 1.0).abs() <= 4.0 * f64::EPSILON {
        1.0
    } else {
        w_norm
    };
    let w_angle = w.im.atan2(w.re);
    // W^s on the principal polar branch: |W|^s · e^{i·s·arg W}.
    let w_pow = |s: f64| -> Complex64 {
        let mag = w_mag.powf(s);
        let ang = w_angle * s;
        Complex64::new(mag * ang.cos(), mag * ang.sin())
    };
    // j²/2 as an f64 exponent.
    let half_sq = |j: usize| -> f64 {
        let jf = j as f64;
        0.5 * jf * jf
    };

    let l = next_pow2(n + m - 1);

    // y[n] = x[n] · A^{-n} · W^{-n²/2}, zero-padded to L.
    let a_inv_step = Complex64::new(a.re / a_norm_sq, -a.im / a_norm_sq);
    let mut yn = vec![Complex64::zero(); l];
    let mut a_inv_pow = Complex64::new(1.0, 0.0);
    for (i, (slot, &xi)) in yn.iter_mut().zip(x.iter()).enumerate() {
        if i > 0 {
            a_inv_pow = a_inv_pow * a_inv_step;
        }
        *slot = xi * a_inv_pow * w_pow(-half_sq(i));
    }

    // h[j] = W^{j²/2} for lags j = -(N-1) … m-1; negative lags wrap to the tail.
    // Since L >= N + m - 1, the head [0, m) and tail [L-N+1, L) never overlap.
    let mut hn = vec![Complex64::zero(); l];
    for (j, slot) in hn.iter_mut().enumerate().take(m) {
        *slot = w_pow(half_sq(j));
    }
    for j in 1..n {
        hn[l - j] = w_pow(half_sq(j));
    }

    let yn_fft = fft(&yn, None)?;
    let hn_fft = fft(&hn, None)?;
    let prod: Vec<Complex64> = yn_fft
        .iter()
        .zip(hn_fft.iter())
        .map(|(&u, &v)| u * v)
        .collect();
    let conv = ifft(&prod, None)?;

    // X[k] = W^{-k²/2} · (y ⊛ h)[k].
    Ok((0..m).map(|k| conv[k] * w_pow(-half_sq(k))).collect())
}
/// Chirp-z transform with the contour parameters given in polar form.
///
/// Builds `A = a_mag·e^{i·a_angle}` and `W = w_mag·e^{i·w_angle}`, then
/// delegates to [`bluestein_chirp_z`] to evaluate `m` output points.
///
/// # Errors
/// Returns `FFTError::ValueError` if `x` is empty or `m == 0`. Any error from
/// `bluestein_chirp_z` is propagated.
pub fn chirp_z_transform(
    x: &[Complex64],
    m: usize,
    a_mag: f64,
    a_angle: f64,
    w_mag: f64,
    w_angle: f64,
) -> FFTResult<Vec<Complex64>> {
    if x.is_empty() {
        return Err(FFTError::ValueError("chirp_z_transform: input is empty".into()));
    }
    if m == 0 {
        return Err(FFTError::ValueError("chirp_z_transform: m must be > 0".into()));
    }
    let polar = |r: f64, theta: f64| Complex64::new(r * theta.cos(), r * theta.sin());
    bluestein_chirp_z(x, m, polar(a_mag, a_angle), polar(w_mag, w_angle))
}
/// Length-N DFT of `x` computed through the chirp-z transform.
///
/// Calls [`chirp_z_transform`] with `A = 1` and `W = e^{+2πi/N}` (unit
/// magnitude), evaluating `N` points. Under the
/// `X[k] = Σ x[n]·A^{-n}·W^{-n·k}` convention expected of `bluestein_chirp_z`,
/// this is the forward DFT for any `N`, prime or not.
///
/// # Errors
/// Returns `FFTError::ValueError` for empty input. Errors from the CZT are propagated.
pub fn dft_via_czt(x: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    if x.is_empty() {
        return Err(FFTError::ValueError("dft_via_czt: empty input".into()));
    }
    let len = x.len();
    let bin_angle = 2.0 * PI / len as f64;
    chirp_z_transform(x, len, 1.0, 0.0, 1.0, bin_angle)
}
/// Zoomed spectrum of `x` over a normalised frequency band.
///
/// Computes `m` bins starting at `f_low`, spaced `(f_high - f_low) / m` apart
/// (frequencies in cycles per sample). The contour starts at angle
/// `2π·f_low` and advances by `2π·(f_high - f_low)/m` per bin on the unit
/// circle. The work is delegated to [`chirp_z_transform`].
///
/// # Errors
/// Returns `FFTError::ValueError` in these cases:
/// - `x` is empty;
/// - `m == 0`;
/// - `f_low` is outside `[0, 0.5]`;
/// - `f_high` is outside `[f_low, 0.5]`.
///
/// NaN bounds fail these range checks too.
pub fn zoom_fft_band(
    x: &[Complex64],
    m: usize,
    f_low: f64,
    f_high: f64,
) -> FFTResult<Vec<Complex64>> {
    if x.is_empty() {
        return Err(FFTError::ValueError("zoom_fft_band: empty input".into()));
    }
    if m == 0 {
        return Err(FFTError::ValueError("zoom_fft_band: m must be > 0".into()));
    }
    let within_nyquist_band = |f: f64| (0.0..=0.5).contains(&f);
    if !within_nyquist_band(f_low) {
        return Err(FFTError::ValueError(format!("zoom_fft_band: f_low={f_low} out of [0, 0.5]")));
    }
    if !(f_low..=0.5).contains(&f_high) {
        return Err(FFTError::ValueError(format!(
            "zoom_fft_band: f_high={f_high} must be in [{f_low}, 0.5]"
        )));
    }
    let bin_width = (f_high - f_low) / m as f64;
    let start_angle = 2.0 * PI * f_low;
    let step_angle = 2.0 * PI * bin_width;
    chirp_z_transform(x, m, 1.0, start_angle, 1.0, step_angle)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Reference O(N²) forward DFT: `X[k] = Σ_m x[m]·exp(-2πi·k·m/N)`.
    fn brute_dft(x: &[Complex64]) -> Vec<Complex64> {
        let n = x.len();
        (0..n)
            .map(|k| {
                x.iter().enumerate().fold(Complex64::zero(), |acc, (m, &xm)| {
                    let phase = -2.0 * PI * k as f64 * m as f64 / n as f64;
                    acc + xm * Complex64::new(phase.cos(), phase.sin())
                })
            })
            .collect()
    }

    /// Asserts equal length and component-wise agreement within the absolute
    /// tolerance `tol`, naming the offending bin on failure.
    ///
    /// `approx::assert_relative_eq!` only accepts `epsilon` / `max_relative`
    /// options (the former `var_name = …` did not compile), so the per-bin
    /// label is carried by a plain `assert!` message instead.
    fn assert_complex_close(a: &[Complex64], b: &[Complex64], tol: f64) {
        assert_eq!(a.len(), b.len(), "length mismatch");
        for (i, (ai, bi)) in a.iter().zip(b.iter()).enumerate() {
            let d_re = (ai.re - bi.re).abs();
            let d_im = (ai.im - bi.im).abs();
            // `<=` is false for NaN, so NaN outputs fail as they should.
            assert!(
                d_re <= tol,
                "bin {i} re: left = {}, right = {}, |diff| = {d_re} > {tol}",
                ai.re,
                bi.re
            );
            assert!(
                d_im <= tol,
                "bin {i} im: left = {}, right = {}, |diff| = {d_im} > {tol}",
                ai.im,
                bi.im
            );
        }
    }

    #[test]
    fn test_primitive_root_5() {
        let g = primitive_root(5).expect("primitive root");
        assert!(g == 2 || g == 3, "unexpected primitive root {g} for p=5");
    }

    #[test]
    fn test_primitive_root_7() {
        let g = primitive_root(7).expect("primitive root");
        assert!(g == 3 || g == 5, "unexpected primitive root {g} for p=7");
    }

    #[test]
    fn test_primitive_root_nonprime_error() {
        assert!(primitive_root(4).is_err());
        assert!(primitive_root(9).is_err());
    }

    #[test]
    fn test_rader_prime_7() {
        let p = 7;
        let signal: Vec<Complex64> = (0..p)
            .map(|k| Complex64::new(k as f64, (k as f64) * 0.5))
            .collect();
        let rader = rader_fft(&signal).expect("rader_fft");
        let brute = brute_dft(&signal);
        assert_complex_close(&rader, &brute, 1e-9);
    }

    #[test]
    fn test_rader_prime_11() {
        let p = 11;
        let signal: Vec<Complex64> = (0..p)
            .map(|k| Complex64::new((k as f64 / p as f64).sin(), 0.0))
            .collect();
        let rader = rader_fft(&signal).expect("rader_fft");
        let brute = brute_dft(&signal);
        assert_complex_close(&rader, &brute, 1e-8);
    }

    #[test]
    fn test_rader_prime_13() {
        let p = 13;
        let signal: Vec<Complex64> = (0..p)
            .map(|k| Complex64::new(1.0, -(k as f64)))
            .collect();
        let rader = rader_fft(&signal).expect("rader_fft");
        let brute = brute_dft(&signal);
        assert_complex_close(&rader, &brute, 1e-8);
    }

    #[test]
    fn test_rader_length_2() {
        let signal = vec![Complex64::new(1.0, 0.0), Complex64::new(-1.0, 0.0)];
        let out = rader_fft(&signal).expect("rader_fft length 2");
        assert_relative_eq!(out[0].re, 0.0, epsilon = 1e-12);
        assert_relative_eq!(out[0].im, 0.0, epsilon = 1e-12);
        assert_relative_eq!(out[1].re, 2.0, epsilon = 1e-12);
        assert_relative_eq!(out[1].im, 0.0, epsilon = 1e-12);
    }

    #[test]
    fn test_rader_nonprime_error() {
        let signal: Vec<Complex64> = (0..6).map(|k| Complex64::new(k as f64, 0.0)).collect();
        assert!(rader_fft(&signal).is_err());
    }

    #[test]
    fn test_rader_empty_error() {
        assert!(rader_fft(&[]).is_err());
    }

    #[test]
    fn test_bluestein_chirp_z_equals_dft_prime_7() {
        let p = 7;
        let signal: Vec<Complex64> = (0..p)
            .map(|k| Complex64::new(k as f64, 0.0))
            .collect();
        let w_angle = 2.0 * PI / p as f64;
        let a = Complex64::new(1.0, 0.0);
        let w = Complex64::new(w_angle.cos(), w_angle.sin());
        let czt = bluestein_chirp_z(&signal, p, a, w).expect("czt");
        let brute = brute_dft(&signal);
        assert_complex_close(&czt, &brute, 1e-9);
    }

    #[test]
    fn test_bluestein_chirp_z_more_output_bins() {
        let n = 8;
        let signal: Vec<Complex64> = (0..n).map(|k| Complex64::new(k as f64, 0.0)).collect();
        let m = 12;
        let w_angle = 2.0 * PI / n as f64;
        let a = Complex64::new(1.0, 0.0);
        let w = Complex64::new(w_angle.cos(), w_angle.sin());
        let out = bluestein_chirp_z(&signal, m, a, w).expect("czt m>n");
        assert_eq!(out.len(), m);
    }

    #[test]
    fn test_bluestein_chirp_z_empty_error() {
        let a = Complex64::new(1.0, 0.0);
        let w = Complex64::new(1.0, 0.0);
        assert!(bluestein_chirp_z(&[], 8, a, w).is_err());
    }

    #[test]
    fn test_chirp_z_transform_equals_dft() {
        let n = 8;
        let signal: Vec<Complex64> = (0..n)
            .map(|k| Complex64::new(k as f64, 0.0))
            .collect();
        let spec = chirp_z_transform(&signal, n, 1.0, 0.0, 1.0, 2.0 * PI / n as f64)
            .expect("czt");
        let brute = brute_dft(&signal);
        assert_complex_close(&spec, &brute, 1e-9);
    }

    #[test]
    fn test_dft_via_czt_all_ones() {
        let n = 7;
        let signal = vec![Complex64::new(1.0, 0.0); n];
        let spec = dft_via_czt(&signal).expect("dft_via_czt");
        assert_relative_eq!(spec[0].re, n as f64, epsilon = 1e-9);
        for bin in spec.iter().skip(1) {
            assert_relative_eq!(bin.re.abs(), 0.0, epsilon = 1e-9);
            assert_relative_eq!(bin.im.abs(), 0.0, epsilon = 1e-9);
        }
    }

    #[test]
    fn test_zoom_fft_band_length() {
        let n = 64;
        let signal = vec![Complex64::new(1.0, 0.0); n];
        let spec = zoom_fft_band(&signal, 32, 0.1, 0.4).expect("zoom_fft");
        assert_eq!(spec.len(), 32);
    }

    #[test]
    fn test_zoom_fft_band_dc_when_full_range() {
        let n = 16;
        let signal: Vec<Complex64> = (0..n).map(|k| Complex64::new(k as f64, 0.0)).collect();
        let spec = zoom_fft_band(&signal, n, 0.0, 0.5).expect("zoom_fft full");
        assert_eq!(spec.len(), n);
    }

    #[test]
    fn test_zoom_fft_invalid_args() {
        let sig = vec![Complex64::new(1.0, 0.0); 8];
        assert!(zoom_fft_band(&sig, 0, 0.1, 0.4).is_err());
        assert!(zoom_fft_band(&sig, 8, 0.6, 0.4).is_err());
        assert!(zoom_fft_band(&sig, 8, 0.1, 0.6).is_err());
    }
}