use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use crate::hilbert::analytic_signal;
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
use super::types::{WvdConfig, WvdResult};
/// Wigner–Ville time–frequency analyser.
///
/// Stateless: every setting is taken from the [`WvdConfig`] passed to each call.
#[derive(Debug, Clone, Copy, Default)]
pub struct WignerVille;

impl WignerVille {
    /// Creates an analyser; identical to `WignerVille::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self
    }

    /// Computes the Wigner–Ville distribution of `signal` sampled at `fs`.
    ///
    /// Lag smoothing is never applied here, whatever `config.smooth_window` holds;
    /// use [`WignerVille::compute_pwvd`] for the smoothed variant.
    ///
    /// # Errors
    /// Fails on an empty `signal`, a non-positive `fs`, or a resolved `n_freqs` of
    /// zero, and propagates analytic-signal / FFT errors.
    pub fn compute_wvd(
        &self,
        signal: &[f64],
        fs: f64,
        config: &WvdConfig,
    ) -> FFTResult<WvdResult> {
        compute_wvd_impl(signal, fs, config, false)
    }

    /// Computes the pseudo Wigner–Ville distribution of `signal` sampled at `fs`.
    ///
    /// Lags are weighted by a Gaussian window of half-length `config.smooth_window`;
    /// a value of 0 disables the weighting.
    ///
    /// # Errors
    /// Fails on an empty `signal`, a non-positive `fs`, or a resolved `n_freqs` of
    /// zero, and propagates analytic-signal / FFT errors.
    pub fn compute_pwvd(
        &self,
        signal: &[f64],
        fs: f64,
        config: &WvdConfig,
    ) -> FFTResult<WvdResult> {
        compute_wvd_impl(signal, fs, config, true)
    }
}
/// Shared engine for `WignerVille::compute_wvd` and `WignerVille::compute_pwvd`.
///
/// For each time index `t`, the instantaneous autocorrelation
/// `K[t, τ] = z[t + τ] · conj(z[t − τ])` is written onto a circular lag axis of
/// length `n_freqs` (see `fill_lag_kernel`) and transformed with an FFT. The real
/// part of that spectrum becomes row `t` of the distribution.
///
/// Because the kernel spans `2τ` samples at lag `τ`, a tone at `f0` peaks in bin
/// `k = 2 · f0 · n_freqs / fs`. Bin `k` therefore corresponds to
/// `k · fs / (2 · n_freqs)`, and the returned frequency axis covers `[0, fs/2)`.
///
/// # Arguments
/// * `signal` - real-valued input samples.
/// * `fs` - sampling frequency; `times[i] = i / fs`.
/// * `config` - `n_freqs` (number of frequency bins, defaults to `signal.len()`),
///   `analytic` (transform the analytic signal instead of the raw samples) and
///   `smooth_window` (half-length of the pseudo-WVD lag window).
/// * `apply_smooth` - when `true` and `config.smooth_window > 0`, lag `±τ` is weighted
///   by a Gaussian window of half-length `smooth_window`, and lags beyond it are
///   dropped (pseudo-WVD). Otherwise every lag the signal supports gets unit weight.
///
/// # Errors
/// `FFTError::ValueError` if `signal` is empty, `fs` is not a positive finite
/// number, or `n_freqs` is zero. Errors from `analytic_signal` and the FFT are
/// propagated.
fn compute_wvd_impl(
    signal: &[f64],
    fs: f64,
    config: &WvdConfig,
    apply_smooth: bool,
) -> FFTResult<WvdResult> {
    let n = signal.len();
    if n == 0 {
        return Err(FFTError::ValueError("Signal must not be empty".to_string()));
    }
    // Negated test so NaN (every comparison false) is rejected too; an infinite
    // `fs` would turn the time and frequency axes into 0 / NaN.
    if !(fs > 0.0 && fs.is_finite()) {
        return Err(FFTError::ValueError(
            "Sampling frequency must be positive and finite".to_string(),
        ));
    }
    let n_freqs = config.n_freqs.unwrap_or(n);
    if n_freqs == 0 {
        return Err(FFTError::ValueError("n_freqs must be > 0".to_string()));
    }
    let z: Vec<Complex64> = if config.analytic {
        analytic_signal(signal)?
    } else {
        signal.iter().map(|&s| Complex64::new(s, 0.0)).collect()
    };
    // Empty window = plain WVD: unit weight on every lag the signal supports.
    let lag_window: Vec<f64> = if apply_smooth && config.smooth_window > 0 {
        build_gaussian_window(config.smooth_window)
    } else {
        Vec::new()
    };
    // One kernel buffer reused for every time index instead of allocating per row.
    let mut kernel = vec![Complex64::new(0.0, 0.0); n_freqs];
    let mut wvd_matrix: Vec<Vec<f64>> = Vec::with_capacity(n);
    for t in 0..n {
        fill_lag_kernel(&z, t, &lag_window, &mut kernel);
        // The kernel is Hermitian in τ, so its spectrum is real up to rounding.
        let spectrum = fft_complex(&kernel)?;
        wvd_matrix.push(spectrum.iter().map(|c| c.re).collect());
    }
    let times: Vec<f64> = (0..n).map(|i| i as f64 / fs).collect();
    // Bin spacing is fs / (2 · n_freqs) because the kernel advances 2τ samples per lag.
    let bin_width = fs / (2.0 * n_freqs as f64);
    let frequencies: Vec<f64> = (0..n_freqs).map(|k| k as f64 * bin_width).collect();
    Ok(WvdResult {
        wvd: wvd_matrix,
        times,
        frequencies,
    })
}

/// Writes the instantaneous autocorrelation of `z` at time index `t` onto the
/// circular lag axis `kernel` (length `L = kernel.len()`, must be > 0).
///
/// Lag `+τ` is stored at index `τ` and lag `−τ` at index `L − τ`; all other bins are
/// zeroed. Lags are limited by the signal edges (`τ ≤ min(t, z.len() − 1 − t)`) and
/// by `L / 2`. When `lag_window` is non-empty, lags are also limited by
/// `lag_window.len() − 1`, with `lag_window[τ]` as the weight for `±τ`. An empty
/// `lag_window` means unit weights.
fn fill_lag_kernel(z: &[Complex64], t: usize, lag_window: &[f64], kernel: &mut [Complex64]) {
    let len = kernel.len();
    kernel.fill(Complex64::new(0.0, 0.0));
    let edge_limit = t.min(z.len() - 1 - t);
    // A finite-support window excludes larger lags entirely rather than reusing its last tap.
    let window_limit = if lag_window.is_empty() {
        usize::MAX
    } else {
        lag_window.len() - 1
    };
    let tau_max = edge_limit.min(len / 2).min(window_limit);
    for tau in 0..=tau_max {
        let weight = lag_window.get(tau).copied().unwrap_or(1.0);
        let r = z[t + tau] * z[t - tau].conj() * weight;
        if tau == 0 {
            kernel[0] = r;
        } else if 2 * tau == len {
            // Even L: +τ and −τ fold onto the same (Nyquist) bin. Store their mean,
            // Re(r), rather than their sum so this lag is not counted twice.
            kernel[tau] = Complex64::new(r.re, 0.0);
        } else {
            kernel[tau] = r;
            // K[t, −τ] = z[t − τ]·conj(z[t + τ]) = conj(K[t, τ]).
            kernel[len - tau] = r.conj();
        }
    }
}
/// Complex-input FFT assembled from two real-input transforms.
///
/// By linearity, `FFT(a + i·b) = FFT(a) + i·FFT(b)`, where `a` and `b` are the real
/// and imaginary parts of `signal`. An empty input yields an empty output.
///
/// # Errors
/// Propagates any error returned by `crate::fft::fft`.
fn fft_complex(signal: &[Complex64]) -> FFTResult<Vec<Complex64>> {
    if signal.is_empty() {
        return Ok(Vec::new());
    }
    let (real, imag): (Vec<f64>, Vec<f64>) = signal.iter().map(|c| (c.re, c.im)).unzip();
    let spec_real = crate::fft::fft(&real, None)?;
    let spec_imag = crate::fft::fft(&imag, None)?;
    // i·(x + i·y) = −y + i·x, added onto the real-part spectrum.
    Ok(spec_real
        .iter()
        .zip(spec_imag.iter())
        .map(|(a, b)| Complex64::new(a.re - b.im, a.im + b.re))
        .collect())
}
/// One-sided Gaussian lag window `w[τ] = exp(−τ² / (2σ²))` for `τ = 0..=half_len`.
///
/// `σ = half_len / 2`, so the last tap equals `exp(−2)`. A `half_len` of zero yields
/// the single-tap window `[1.0]`.
fn build_gaussian_window(half_len: usize) -> Vec<f64> {
    match half_len {
        0 => vec![1.0],
        len => {
            let sigma = len as f64 / 2.0;
            let denom = 2.0 * sigma * sigma;
            let mut taps = Vec::with_capacity(len + 1);
            for lag in 0..=len {
                let x = lag as f64;
                taps.push((-x * x / denom).exp());
            }
            taps
        }
    }
}
pub fn compute_wvd(signal: &[f64], fs: f64) -> FFTResult<WvdResult> {
let config = WvdConfig::default();
let wv = WignerVille::new();
wv.compute_wvd(signal, fs, &config)
}
pub fn compute_pwvd(signal: &[f64], fs: f64, smooth_window: usize) -> FFTResult<WvdResult> {
let config = WvdConfig {
smooth_window,
..WvdConfig::default()
};
let wv = WignerVille::new();
wv.compute_pwvd(signal, fs, &config)
}