use crate::complex::Complex;
use crate::fft_core::fft;
use crate::utils::{window_func, Window};
use numra_core::Scalar;
/// Estimate the one-sided power spectral density of a real signal with a single
/// windowed periodogram.
///
/// The whole signal is multiplied by `window`, transformed with [`fft`], and each
/// bin's squared magnitude is scaled by `1 / (fs * Σ w[i]²)` (density scaling).
/// Bins that have a mirror image in the negative-frequency half are doubled so the
/// one-sided spectrum carries the full signal power. The unique bins are DC and,
/// for even `n`, the Nyquist bin `n / 2`; they are not doubled.
///
/// # Arguments
/// * `x` - real-valued input samples.
/// * `fs` - sampling frequency.
/// * `window` - window applied to the full signal before the FFT.
///
/// # Returns
/// `(frequencies, psd)`, each of length `x.len() / 2 + 1`, where frequency bin `k`
/// is `k * fs / x.len()`. Empty input yields two empty vectors.
pub fn psd<S: Scalar>(x: &[S], fs: S, window: &Window) -> (Vec<S>, Vec<S>) {
    let n = x.len();
    if n == 0 {
        return (vec![], vec![]);
    }
    let w: Vec<S> = window_func(window, n);
    let win_sum_sq: S = w.iter().map(|&v| v * v).sum();
    let windowed: Vec<Complex<S>> = x
        .iter()
        .zip(w.iter())
        .map(|(&xi, &wi)| Complex::new(xi * wi, S::ZERO))
        .collect();
    let spectrum = fft(&windowed);
    let n_freq = n / 2 + 1;
    let scale = S::ONE / (fs * win_sum_sq);
    let psd_vals: Vec<S> = spectrum[..n_freq]
        .iter()
        .enumerate()
        .map(|(k, c)| {
            let power = c.norm_sqr() * scale;
            // Bin k has a negative-frequency twin unless it is DC (k == 0) or the
            // Nyquist bin (2k == n, only possible for even n). For odd n the last
            // bin is NOT Nyquist and must be doubled too. Since k <= n / 2, 2k
            // cannot overflow.
            if k > 0 && 2 * k < n {
                S::TWO * power
            } else {
                power
            }
        })
        .collect();
    let freq_step = fs / S::from_usize(n);
    let frequencies: Vec<S> = (0..n_freq).map(|k| S::from_usize(k) * freq_step).collect();
    (frequencies, psd_vals)
}
/// Estimate the one-sided power spectral density with Welch's method, averaging
/// windowed periodograms of overlapping segments.
///
/// The signal is split into segments of `nperseg` samples whose starts are
/// `nperseg - noverlap` samples apart. Only segments that fit entirely inside `x`
/// are used; there is no padding. Each segment is windowed, transformed with
/// [`fft`], and density-scaled by `1 / (fs * Σ w[i]²)`. Mirrored bins are doubled,
/// which means every bin except DC and, for even `nperseg`, the Nyquist bin. The
/// per-segment results are then averaged.
///
/// # Arguments
/// * `x` - real-valued input samples.
/// * `fs` - sampling frequency.
/// * `nperseg` - samples per segment.
/// * `noverlap` - samples shared by consecutive segments; must be `< nperseg`.
/// * `window` - window applied to every segment.
///
/// # Returns
/// `(frequencies, psd)`, each of length `nperseg / 2 + 1`, where frequency bin `k`
/// is `k * fs / nperseg`. Both are empty when `x` is empty, `nperseg` is 0 or
/// longer than `x`, or `noverlap >= nperseg`.
pub fn welch<S: Scalar>(
    x: &[S],
    fs: S,
    nperseg: usize,
    noverlap: usize,
    window: &Window,
) -> (Vec<S>, Vec<S>) {
    let n = x.len();
    // `noverlap >= nperseg` leaves no positive hop between segments. Rejecting it
    // here also keeps `nperseg - noverlap` from underflowing `usize`.
    if n == 0 || nperseg == 0 || nperseg > n || noverlap >= nperseg {
        return (vec![], vec![]);
    }
    let step = nperseg - noverlap;
    let w: Vec<S> = window_func(window, nperseg);
    let win_sum_sq: S = w.iter().map(|&v| v * v).sum();
    let n_freq = nperseg / 2 + 1;
    let scale = S::ONE / (fs * win_sum_sq);
    let mut psd_avg = vec![S::ZERO; n_freq];
    let mut n_segments: usize = 0;
    // `windows(nperseg).step_by(step)` yields exactly the segments starting at
    // 0, step, 2*step, ... that fit completely inside `x`.
    for segment in x.windows(nperseg).step_by(step) {
        let windowed: Vec<Complex<S>> = segment
            .iter()
            .zip(w.iter())
            .map(|(&xi, &wi)| Complex::new(xi * wi, S::ZERO))
            .collect();
        let spectrum = fft(&windowed);
        for (k, (acc, c)) in psd_avg.iter_mut().zip(&spectrum[..n_freq]).enumerate() {
            let power = c.norm_sqr() * scale;
            // Double every bin with a negative-frequency twin: all but DC and, for
            // even `nperseg`, the Nyquist bin (2k == nperseg).
            *acc += if k > 0 && 2 * k < nperseg {
                S::TWO * power
            } else {
                power
            };
        }
        n_segments += 1;
    }
    if n_segments > 0 {
        let inv_seg = S::ONE / S::from_usize(n_segments);
        for v in &mut psd_avg {
            *v *= inv_seg;
        }
    }
    let freq_step = fs / S::from_usize(nperseg);
    let frequencies: Vec<S> = (0..n_freq).map(|k| S::from_usize(k) * freq_step).collect();
    (frequencies, psd_avg)
}
/// Output of [`stft`]: one magnitude spectrum per analysed segment.
#[derive(Debug, Clone, PartialEq)]
pub struct StftResult<S: Scalar> {
    /// Centre time of each segment, computed as `(start + (nperseg - 1) / 2) / fs`.
    pub times: Vec<S>,
    /// Frequency of each bin, `k * fs / nperseg` for `k` in `0..=nperseg / 2`.
    pub frequencies: Vec<S>,
    /// `magnitude[i][k]` is `|X_k|` of segment `i`. There is one row per entry in
    /// `times`, and each row has one value per entry in `frequencies`.
    pub magnitude: Vec<Vec<S>>,
}
/// Compute the magnitude short-time Fourier transform of a real signal.
///
/// The signal is split into segments of `nperseg` samples whose starts are
/// `nperseg - noverlap` samples apart. Only segments that fit entirely inside `x`
/// are used; there is no padding. Each segment is multiplied by `window` and
/// transformed with [`fft`]. Only the non-negative-frequency bins
/// `0..=nperseg / 2` are kept.
///
/// # Arguments
/// * `x` - real-valued input samples.
/// * `fs` - sampling frequency.
/// * `nperseg` - samples per segment.
/// * `noverlap` - samples shared by consecutive segments; must be `< nperseg`.
/// * `window` - window applied to every segment.
///
/// # Returns
/// A [`StftResult`] containing:
/// * `times` - each segment's centre, `(start + (nperseg - 1) / 2) / fs`.
/// * `frequencies` - `k * fs / nperseg` for each bin `k`.
/// * `magnitude` - `magnitude[i][k]` is `|X_k|` for segment `i`.
///
/// All three are empty when `x` is empty, `nperseg` is 0 or longer than `x`, or
/// `noverlap >= nperseg`.
pub fn stft<S: Scalar>(
    x: &[S],
    fs: S,
    nperseg: usize,
    noverlap: usize,
    window: &Window,
) -> StftResult<S> {
    let n = x.len();
    // `noverlap >= nperseg` leaves no positive hop between segments. Rejecting it
    // here also keeps `nperseg - noverlap` from underflowing `usize`.
    if n == 0 || nperseg == 0 || nperseg > n || noverlap >= nperseg {
        return StftResult {
            times: vec![],
            frequencies: vec![],
            magnitude: vec![],
        };
    }
    let step = nperseg - noverlap;
    let w: Vec<S> = window_func(window, nperseg);
    let n_freq = nperseg / 2 + 1;
    let freq_step = fs / S::from_usize(nperseg);
    let frequencies: Vec<S> = (0..n_freq).map(|k| S::from_usize(k) * freq_step).collect();
    // Offset from a segment's first sample to its centre; the same for every segment.
    let half_span = (S::from_usize(nperseg) - S::ONE) / S::TWO;
    // Number of segments whose start is a multiple of `step` and that fit inside `x`.
    let n_segments = (n - nperseg) / step + 1;
    let mut times: Vec<S> = Vec::with_capacity(n_segments);
    let mut magnitude: Vec<Vec<S>> = Vec::with_capacity(n_segments);
    for (seg_idx, segment) in x.windows(nperseg).step_by(step).enumerate() {
        let windowed: Vec<Complex<S>> = segment
            .iter()
            .zip(w.iter())
            .map(|(&xi, &wi)| Complex::new(xi * wi, S::ZERO))
            .collect();
        let spectrum = fft(&windowed);
        let mag: Vec<S> = spectrum[..n_freq].iter().map(|c| c.abs()).collect();
        magnitude.push(mag);
        let start = seg_idx * step;
        times.push((S::from_usize(start) + half_span) / fs);
    }
    StftResult {
        times,
        frequencies,
        magnitude,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use core::f64::consts::PI;

    /// `len` samples of a unit-amplitude sine at `tone` Hz, sampled at `fs` Hz.
    fn sine(len: usize, tone: f64, fs: f64) -> Vec<f64> {
        (0..len)
            .map(|i| (2.0 * PI * tone * i as f64 / fs).sin())
            .collect()
    }

    #[test]
    fn test_psd_dc_signal() {
        // A constant signal puts all of its power in the DC bin.
        let ones = [1.0_f64; 64];
        let (freqs, power) = psd(&ones, 1.0, &Window::Rectangular);
        assert_eq!(freqs.len(), 33);
        assert!(power[0] > 0.0);
        for (bin, &p) in power.iter().enumerate().skip(1) {
            assert!(p < 1e-20, "bin {} = {}", bin, p);
        }
    }

    #[test]
    fn test_psd_single_tone() {
        let (len, fs, tone) = (256, 256.0, 10.0);
        let signal = sine(len, tone, fs);
        let (freqs, power) = psd(&signal, fs, &Window::Rectangular);
        let (peak, _) = power
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .unwrap();
        // The peak must land within one frequency bin of the tone.
        assert!((freqs[peak] - tone).abs() < fs / len as f64 + 0.01);
    }

    #[test]
    fn test_welch_reduces_variance() {
        let fs = 100.0;
        let signal = sine(512, 10.0, fs);
        let (freqs, power) = welch(&signal, fs, 128, 64, &Window::Hann);
        assert!(!freqs.is_empty());
        assert_eq!(freqs.len(), power.len());
        assert!(power.iter().all(|&v| v >= 0.0));
    }

    #[test]
    fn test_stft_basic() {
        let fs = 100.0;
        let signal = sine(256, 10.0, fs);
        let out = stft(&signal, fs, 64, 32, &Window::Hann);
        assert!(!out.times.is_empty());
        assert!(!out.frequencies.is_empty());
        assert_eq!(out.magnitude.len(), out.times.len());
        out.magnitude
            .iter()
            .for_each(|row| assert_eq!(row.len(), out.frequencies.len()));
    }

    #[test]
    fn test_psd_empty() {
        let (freqs, power) = psd::<f64>(&[], 1.0, &Window::Rectangular);
        assert!(freqs.is_empty());
        assert!(power.is_empty());
    }
}