use std::f64::consts::PI;
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::Complex64;
use crate::error::{FFTError, FFTResult};
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
pub enum WindowType {
Gaussian,
Hamming,
Hann,
Blackman,
Rectangular,
}
impl WindowType {
pub fn samples(&self, size: usize) -> Vec<f64> {
let l = size as f64;
match self {
WindowType::Gaussian => {
let sigma = l / 6.0;
let centre = (l - 1.0) / 2.0;
(0..size)
.map(|n| {
let x = (n as f64 - centre) / sigma;
(-0.5 * x * x).exp()
})
.collect()
}
WindowType::Hamming => (0..size)
.map(|n| 0.54 - 0.46 * (2.0 * PI * n as f64 / (l - 1.0)).cos())
.collect(),
WindowType::Hann => (0..size)
.map(|n| 0.5 * (1.0 - (2.0 * PI * n as f64 / (l - 1.0)).cos()))
.collect(),
WindowType::Blackman => (0..size)
.map(|n| {
let t = 2.0 * PI * n as f64 / (l - 1.0);
0.42 - 0.5 * t.cos() + 0.08 * (2.0 * t).cos()
})
.collect(),
WindowType::Rectangular => vec![1.0; size],
}
}
}
#[derive(Debug, Clone)]
pub struct StfrftConfig {
pub alpha: f64,
pub window_size: usize,
pub hop_size: usize,
pub window_type: WindowType,
pub oversample: bool,
}
impl Default for StfrftConfig {
fn default() -> Self {
Self {
alpha: 1.0,
window_size: 256,
hop_size: 64,
window_type: WindowType::Gaussian,
oversample: false,
}
}
}
#[derive(Debug, Clone)]
pub struct StfrftResult {
pub coefficients: Array2<Complex64>,
pub time_centers: Vec<f64>,
pub fractional_freqs: Vec<f64>,
pub alpha: f64,
}
fn grunbaum_eigendecomp(n: usize) -> FFTResult<(Vec<Vec<f64>>, Vec<i32>)> {
if n == 0 {
return Ok((vec![], vec![]));
}
if n == 1 {
return Ok((vec![vec![1.0]], vec![0]));
}
let mut mat = vec![0.0_f64; n * n];
for j in 0..n {
mat[j * n + j] = 2.0 * (2.0 * PI * j as f64 / n as f64).cos();
if j + 1 < n {
mat[j * n + j + 1] = 1.0;
mat[(j + 1) * n + j] = 1.0;
}
}
mat[n - 1] = 1.0;
mat[(n - 1) * n] = 1.0;
let (eigenvalues, eigenvectors) = symmetric_jacobi_eig(&mut mat, n);
let mut order_idx: Vec<usize> = (0..n).collect();
order_idx.sort_by(|&a, &b| {
eigenvalues[b]
.partial_cmp(&eigenvalues[a])
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut ev_orders = vec![0_i32; n];
for (rank, &idx) in order_idx.iter().enumerate() {
ev_orders[idx] = (rank % 4) as i32;
}
Ok((eigenvectors, ev_orders))
}
fn symmetric_jacobi_eig(mat_flat: &mut [f64], n: usize) -> (Vec<f64>, Vec<Vec<f64>>) {
let mut z: Vec<f64> = (0..n * n)
.map(|k| if k / n == k % n { 1.0 } else { 0.0 })
.collect();
const MAX_SWEEP: usize = 100;
let eps = 1e-15_f64;
for _ in 0..MAX_SWEEP {
let mut max_val = 0.0_f64;
let mut p_idx = 0;
let mut q_idx = 1;
for i in 0..n {
for j in (i + 1)..n {
let v = mat_flat[i * n + j].abs();
if v > max_val {
max_val = v;
p_idx = i;
q_idx = j;
}
}
}
if max_val < eps {
break;
}
let p = p_idx;
let q = q_idx;
let app = mat_flat[p * n + p];
let aqq = mat_flat[q * n + q];
let apq = mat_flat[p * n + q];
let tau = (aqq - app) / (2.0 * apq);
let t = if tau >= 0.0 {
1.0 / (tau + (1.0 + tau * tau).sqrt())
} else {
-1.0 / (-tau + (1.0 + tau * tau).sqrt())
};
let c = 1.0 / (1.0 + t * t).sqrt();
let s = t * c;
mat_flat[p * n + p] = app - t * apq;
mat_flat[q * n + q] = aqq + t * apq;
mat_flat[p * n + q] = 0.0;
mat_flat[q * n + p] = 0.0;
for r in 0..n {
if r == p || r == q {
continue;
}
let arp = mat_flat[r * n + p];
let arq = mat_flat[r * n + q];
let new_rp = c * arp - s * arq;
let new_rq = s * arp + c * arq;
mat_flat[r * n + p] = new_rp;
mat_flat[p * n + r] = new_rp;
mat_flat[r * n + q] = new_rq;
mat_flat[q * n + r] = new_rq;
}
for r in 0..n {
let zrp = z[r * n + p];
let zrq = z[r * n + q];
z[r * n + p] = c * zrp - s * zrq;
z[r * n + q] = s * zrp + c * zrq;
}
}
let eigenvalues: Vec<f64> = (0..n).map(|i| mat_flat[i * n + i]).collect();
let eigenvectors: Vec<Vec<f64>> = (0..n)
.map(|j| (0..n).map(|i| z[i * n + j]).collect())
.collect();
(eigenvalues, eigenvectors)
}
pub fn dfrft(signal: &[Complex64], alpha: f64) -> FFTResult<Vec<Complex64>> {
let n = signal.len();
if n == 0 {
return Ok(vec![]);
}
let (eigvecs, ev_orders) = grunbaum_eigendecomp(n)?;
let mut vhx = vec![Complex64::new(0.0, 0.0); n];
for k in 0..n {
let mut sum = Complex64::new(0.0, 0.0);
for j in 0..n {
sum += Complex64::new(eigvecs[k][j], 0.0) * signal[j];
}
vhx[k] = sum;
}
let mut dvhx = vec![Complex64::new(0.0, 0.0); n];
for k in 0..n {
let angle = -PI / 2.0 * ev_orders[k] as f64 * alpha;
let frac_eig = Complex64::new(angle.cos(), angle.sin());
dvhx[k] = frac_eig * vhx[k];
}
let mut result = vec![Complex64::new(0.0, 0.0); n];
for j in 0..n {
let mut sum = Complex64::new(0.0, 0.0);
for k in 0..n {
sum += Complex64::new(eigvecs[k][j], 0.0) * dvhx[k];
}
result[j] = sum;
}
Ok(result)
}
pub fn stfrft(signal: &[f64], config: &StfrftConfig) -> FFTResult<StfrftResult> {
let sig_len = signal.len();
if sig_len == 0 {
return Err(FFTError::ValueError("Signal must be non-empty".to_string()));
}
let win_size = config.window_size;
if win_size == 0 {
return Err(FFTError::ValueError("Window size must be > 0".to_string()));
}
let hop = config.hop_size;
if hop == 0 {
return Err(FFTError::ValueError("Hop size must be > 0".to_string()));
}
let window = config.window_type.samples(win_size);
let n_frames = if sig_len <= win_size {
1
} else {
(sig_len - win_size) / hop + 1
};
let mut coefficients = Array2::zeros((n_frames, win_size));
let mut time_centers = Vec::with_capacity(n_frames);
for frame_idx in 0..n_frames {
let start = frame_idx * hop;
let centre = start as f64 + win_size as f64 / 2.0;
time_centers.push(centre);
let frame_complex: Vec<Complex64> = (0..win_size)
.map(|i| {
let sig_idx = start + i;
let sample = if sig_idx < sig_len {
signal[sig_idx]
} else {
0.0
};
Complex64::new(sample * window[i], 0.0)
})
.collect();
let dfrft_out = dfrft(&frame_complex, config.alpha)?;
for (k, val) in dfrft_out.into_iter().enumerate() {
coefficients[[frame_idx, k]] = val;
}
}
let fractional_freqs: Vec<f64> = (0..win_size).map(|k| k as f64 / win_size as f64).collect();
Ok(StfrftResult {
coefficients,
time_centers,
fractional_freqs,
alpha: config.alpha,
})
}
pub fn istfrft(
result: &StfrftResult,
signal_length: usize,
hop_size: usize,
) -> FFTResult<Vec<f64>> {
let n_frames = result.coefficients.shape()[0];
let win_size = result.coefficients.shape()[1];
let hop = if hop_size > 0 { hop_size } else { 1 };
let mut output = vec![0.0_f64; signal_length];
let mut norm = vec![0.0_f64; signal_length];
let inv_alpha = -result.alpha;
for frame_idx in 0..n_frames {
let start = frame_idx * hop;
let frame: Vec<Complex64> = (0..win_size)
.map(|k| result.coefficients[[frame_idx, k]])
.collect();
let recon = dfrft(&frame, inv_alpha)?;
for i in 0..win_size {
let sig_idx = start + i;
if sig_idx < signal_length {
output[sig_idx] += recon[i].re;
norm[sig_idx] += 1.0;
}
}
}
for i in 0..signal_length {
if norm[i] > 0.0 {
output[i] /= norm[i];
}
}
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_relative_eq;
use std::f64::consts::PI;
fn naive_dft(signal: &[f64]) -> Vec<Complex64> {
let n = signal.len();
(0..n)
.map(|k| {
(0..n).fold(Complex64::new(0.0, 0.0), |acc, j| {
let angle = -2.0 * PI * (j * k) as f64 / n as f64;
acc + Complex64::new(signal[j] * angle.cos(), signal[j] * angle.sin())
})
})
.collect()
}
#[test]
fn test_dfrft_alpha_0_identity() {
let n = 16;
let signal: Vec<Complex64> = (0..n)
.map(|i| Complex64::new((2.0 * PI * i as f64 / n as f64).sin(), 0.0))
.collect();
let out = dfrft(&signal, 0.0).expect("dfrft failed");
for (a, b) in signal.iter().zip(out.iter()) {
assert!((a.re - b.re).abs() < 1e-6, "real mismatch alpha=0 at re");
assert!((a.im - b.im).abs() < 1e-6, "imag mismatch alpha=0 at im");
}
}
#[test]
fn test_dfrft_alpha_1_matches_dft() {
let n = 8;
let signal: Vec<f64> = (0..n)
.map(|i| (2.0 * PI * i as f64 / n as f64).cos())
.collect();
let signal_complex: Vec<Complex64> =
signal.iter().map(|&v| Complex64::new(v, 0.0)).collect();
let dfrft_out = dfrft(&signal_complex, 1.0).expect("dfrft failed");
let dft_ref = naive_dft(&signal);
let dfrft_mags: Vec<f64> = dfrft_out
.iter()
.map(|c| (c.re * c.re + c.im * c.im).sqrt())
.collect();
let dft_mags: Vec<f64> = dft_ref
.iter()
.map(|c| (c.re * c.re + c.im * c.im).sqrt())
.collect();
let dfrft_peak = dfrft_mags
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
let dft_peak = dft_mags
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
assert_eq!(
dfrft_peak, dft_peak,
"DFrFT and DFT should have same peak frequency bin"
);
}
#[test]
fn test_dfrft_alpha_2_time_reversal() {
let n = 8;
let signal: Vec<Complex64> = (0..n)
.map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
.collect();
let step1 = dfrft(&signal, 1.0).expect("dfrft step1 failed");
let step2 = dfrft(&step1, 1.0).expect("dfrft step2 failed");
let step3 = dfrft(&step2, 1.0).expect("dfrft step3 failed");
let step4 = dfrft(&step3, 1.0).expect("dfrft step4 failed");
for i in 0..n {
let diff = (step4[i].re - signal[i].re).abs() + (step4[i].im - signal[i].im).abs();
assert!(
diff < 0.3,
"Group property DFrFT(4)≈I violated at index {i}: diff={diff}"
);
}
}
#[test]
fn test_stfrft_output_shape() {
let signal: Vec<f64> = (0..512).map(|i| (i as f64 * 0.1).sin()).collect();
let cfg = StfrftConfig {
alpha: 0.8,
window_size: 64,
hop_size: 16,
window_type: WindowType::Hann,
oversample: false,
};
let result = stfrft(&signal, &cfg).expect("stfrft failed");
let n_frames = result.coefficients.shape()[0];
let n_freqs = result.coefficients.shape()[1];
assert_eq!(
n_freqs, 64,
"Expected window_size={} frequency bins",
cfg.window_size
);
assert!(n_frames > 0, "Expected at least one frame");
assert_eq!(result.time_centers.len(), n_frames);
assert_eq!(result.fractional_freqs.len(), n_freqs);
assert_eq!(result.alpha, 0.8);
}
#[test]
fn test_stfrft_alpha_1_resembles_stft() {
let n = 256;
let signal: Vec<f64> = (0..n)
.map(|i| (2.0 * PI * 5.0 * i as f64 / n as f64).sin())
.collect();
let cfg = StfrftConfig {
alpha: 1.0,
window_size: 32,
hop_size: 8,
window_type: WindowType::Rectangular,
oversample: false,
};
let result = stfrft(&signal, &cfg).expect("stfrft failed");
assert_eq!(result.coefficients.shape()[1], 32);
assert!(result.coefficients.shape()[0] > 1);
}
#[test]
fn test_stfrft_alpha_0_recovers_windowed_signal() {
let n = 128;
let signal: Vec<f64> = (0..n).map(|i| i as f64 * 0.01).collect();
let win_size = 16;
let hop = 4;
let cfg = StfrftConfig {
alpha: 0.0,
window_size: win_size,
hop_size: hop,
window_type: WindowType::Rectangular,
oversample: false,
};
let result = stfrft(&signal, &cfg).expect("stfrft failed");
let frame_0: Vec<f64> = (0..win_size)
.map(|k| result.coefficients[[0, k]].re)
.collect();
for k in 0..win_size {
assert!(
(frame_0[k] - signal[k]).abs() < 1e-6,
"Frame 0 coefficient mismatch at k={k}: {} vs {}",
frame_0[k],
signal[k]
);
}
}
#[test]
fn test_istfrft_output_length() {
let n = 256;
let signal: Vec<f64> = (0..n).map(|i| (i as f64 * 0.05).cos()).collect();
let cfg = StfrftConfig {
alpha: 0.5,
window_size: 32,
hop_size: 16,
window_type: WindowType::Hamming,
oversample: false,
};
let result = stfrft(&signal, &cfg).expect("stfrft failed");
let recon = istfrft(&result, n, cfg.hop_size).expect("istfrft failed");
assert_eq!(recon.len(), n, "Reconstructed signal has wrong length");
}
#[test]
fn test_istfrft_roundtrip_rectangular_window() {
let n = 64;
let signal: Vec<f64> = (0..n)
.map(|i| (2.0 * PI * 3.0 * i as f64 / n as f64).sin())
.collect();
let win_size = 16;
let hop = 4;
let cfg = StfrftConfig {
alpha: 0.0,
window_size: win_size,
hop_size: hop,
window_type: WindowType::Rectangular,
oversample: false,
};
let result = stfrft(&signal, &cfg).expect("stfrft failed");
let recon = istfrft(&result, n, hop).expect("istfrft failed");
let start = win_size;
let end = n.saturating_sub(win_size);
if start < end {
for i in start..end {
assert!(
(recon[i] - signal[i]).abs() < 0.1,
"Roundtrip mismatch at index {i}: {} vs {}",
recon[i],
signal[i]
);
}
}
}
#[test]
fn test_window_type_samples_correct_length() {
for size in [8, 16, 64, 256] {
for wt in [
WindowType::Gaussian,
WindowType::Hamming,
WindowType::Hann,
WindowType::Blackman,
WindowType::Rectangular,
] {
let samples = wt.samples(size);
assert_eq!(
samples.len(),
size,
"Window {wt:?} has wrong sample count for size={size}"
);
}
}
}
#[test]
fn test_dfrft_energy_approximately_preserved() {
let n = 16;
let signal: Vec<Complex64> = (0..n)
.map(|i| Complex64::new((PI * i as f64 / n as f64).sin(), 0.0))
.collect();
let out = dfrft(&signal, 1.0).expect("dfrft failed");
let energy_in: f64 = signal.iter().map(|c| c.re * c.re + c.im * c.im).sum();
let energy_out: f64 = out.iter().map(|c| c.re * c.re + c.im * c.im).sum();
let ratio = energy_out / energy_in;
assert!(
ratio > 0.1,
"Energy ratio {ratio} too small — DFrFT destroyed energy"
);
}
}