use crate::error::{FFTError, FFTResult};
use crate::fft::fft;
use crate::helper::next_fast_len;
use crate::window_functions::dpss;
use scirs2_core::numeric::Complex64;
use std::f64::consts::PI;
/// Builds a symmetric Hann window with `n` coefficients.
///
/// Coefficient `i` is `0.5 * (1 - cos(2*pi*i / (n - 1)))`, so the first and
/// last coefficients are zero whenever `n >= 2`. A single-point window is
/// defined as `[1.0]`, and `n == 0` yields an empty vector.
fn hann_window(n: usize) -> Vec<f64> {
    match n {
        // The general formula would divide by zero for a single point.
        1 => vec![1.0],
        _ => {
            let span = n as f64 - 1.0;
            (0..n)
                .map(|idx| {
                    let phase = 2.0 * PI * idx as f64 / span;
                    0.5 * (1.0 - phase.cos())
                })
                .collect()
        }
    }
}
/// Tapers `segment` with `window`, zero-pads the product to `fft_len` samples
/// and passes the padded buffer to `fft`.
///
/// Only the first `min(segment.len(), window.len())` samples are tapered; the
/// remainder of the buffer stays zero.
///
/// # Panics
///
/// Panics if `fft_len` is smaller than the number of tapered samples.
fn windowed_fft(segment: &[f64], window: &[f64], fft_len: usize) -> FFTResult<Vec<Complex64>> {
    let tapered_len = segment.len().min(window.len());
    let mut padded = vec![0.0_f64; fft_len];
    // Slicing to `tapered_len` keeps the out-of-range panic for a too-small `fft_len`.
    let head = &mut padded[..tapered_len];
    for (slot, (&sample, &weight)) in head.iter_mut().zip(segment.iter().zip(window.iter())) {
        *slot = sample * weight;
    }
    fft(&padded, None)
}
/// Returns the frequency axis of a one-sided spectrum of an `nfft`-point
/// transform sampled at `fs`: `k * fs / nfft` for `k = 0..=nfft / 2`.
fn rfft_freqs(nfft: usize, fs: f64) -> Vec<f64> {
    let transform_len = nfft as f64;
    let last_bin = nfft / 2;
    (0..=last_bin)
        .map(|bin| bin as f64 * fs / transform_len)
        .collect()
}
/// Estimates the one-sided power spectral density of a real signal with
/// Welch's method.
///
/// `signal` is cut into segments of `window_size` samples that advance by
/// `window_size - overlap` samples. Each segment is tapered with a symmetric
/// Hann window, zero-padded to `next_fast_len(window_size, true)` points and
/// transformed; the squared magnitudes are averaged over all segments.
/// Trailing samples that do not fill a complete segment are ignored.
///
/// Returns `(freqs, psd)`, both of length `nfft / 2 + 1`, where
/// `freqs[k] = k * fs / nfft`. The estimate uses density scaling
/// (`|X|^2 / (fs * sum(w^2))`), so summing `psd` times the bin width `fs / nfft`
/// approximates the mean square of the signal.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `signal` is empty or shorter than
/// `window_size`, when `window_size` is zero, when `overlap >= window_size`,
/// when `fs` is not a positive finite number, or when the Hann window of the
/// requested size has zero energy (`window_size == 2`). Errors from the
/// underlying FFT are propagated.
pub fn welch_psd(
    signal: &[f64],
    window_size: usize,
    overlap: usize,
    fs: f64,
) -> FFTResult<(Vec<f64>, Vec<f64>)> {
    if signal.is_empty() {
        return Err(FFTError::ValueError("welch_psd: signal is empty".into()));
    }
    if window_size == 0 {
        return Err(FFTError::ValueError(
            "welch_psd: window_size must be positive".into(),
        ));
    }
    if overlap >= window_size {
        return Err(FFTError::ValueError(
            "welch_psd: overlap must be < window_size".into(),
        ));
    }
    if !fs.is_finite() || fs <= 0.0 {
        return Err(FFTError::ValueError(
            "welch_psd: fs must be positive and finite".into(),
        ));
    }
    if signal.len() < window_size {
        return Err(FFTError::ValueError(
            "welch_psd: signal is shorter than window_size".into(),
        ));
    }

    let nfft = next_fast_len(window_size, true);
    let win = hann_window(window_size);
    let win_power = win.iter().map(|&w| w * w).sum::<f64>();
    if win_power <= 0.0 {
        // A two-point symmetric Hann window is identically zero; dividing by
        // its energy would turn every bin into NaN.
        return Err(FFTError::ValueError(
            "welch_psd: Hann window of this window_size has zero energy".into(),
        ));
    }

    let hop = window_size - overlap;
    let n_freqs = nfft / 2 + 1;
    // Only an even-length transform has a Nyquist bin. For odd `nfft` the last
    // retained bin still has a mirrored negative-frequency partner and must be
    // doubled like every other interior bin.
    let has_nyquist_bin = nfft % 2 == 0;

    let mut psd_acc = vec![0.0_f64; n_freqs];
    let mut n_segments = 0usize;
    for seg in signal.windows(window_size).step_by(hop) {
        let spectrum = windowed_fft(seg, &win, nfft)?;
        for (k, (acc, bin)) in psd_acc.iter_mut().zip(spectrum.iter()).enumerate() {
            let power = bin.norm_sqr();
            let unpaired = k == 0 || (has_nyquist_bin && k == nfft / 2);
            *acc += if unpaired { power } else { 2.0 * power };
        }
        n_segments += 1;
    }

    // Density scaling for an unnormalised forward FFT, averaged over segments.
    // NOTE(review): assumes `fft` applies no 1/N factor — confirm against the
    // crate's `fft` documentation.
    let scale = fs * win_power * n_segments as f64;
    let psd: Vec<f64> = psd_acc.iter().map(|&p| p / scale).collect();
    let freqs = rfft_freqs(nfft, fs);
    Ok((freqs, psd))
}
/// Estimates the power spectral density by averaging periodograms of
/// adjacent, non-overlapping segments of `segment_len` samples.
///
/// This delegates to [`welch_psd`] with an overlap of zero, so the segments
/// are Hann-tapered and the return value and error conditions are those of
/// [`welch_psd`].
pub fn bartlett_psd(
    signal: &[f64],
    segment_len: usize,
    fs: f64,
) -> FFTResult<(Vec<f64>, Vec<f64>)> {
    let no_overlap = 0;
    welch_psd(signal, segment_len, no_overlap, fs)
}
/// Computes the normalised Lomb-Scargle periodogram of unevenly sampled data.
///
/// * `t` - sample times.
/// * `y` - sample values, one per entry of `t`; the mean is removed internally.
/// * `freqs` - **angular** frequencies (radians per unit of `t`) at which the
///   periodogram is evaluated.
///
/// Returns one power value per entry of `freqs`, normalised by twice the
/// (population) variance of `y`. A signal whose variance is below `1e-30`
/// yields all zeros, and a frequency of exactly zero yields zero power
/// because the mean has been removed.
///
/// # Errors
///
/// Returns `FFTError::DimensionError` when `t` and `y` differ in length and
/// `FFTError::ValueError` when fewer than two samples or no frequencies are
/// supplied.
pub fn lomb_scargle(t: &[f64], y: &[f64], freqs: &[f64]) -> FFTResult<Vec<f64>> {
    if t.len() != y.len() {
        return Err(FFTError::DimensionError(format!(
            "lomb_scargle: t.len()={} != y.len()={}",
            t.len(),
            y.len()
        )));
    }
    if t.len() < 2 {
        return Err(FFTError::ValueError(
            "lomb_scargle: need at least 2 data points".into(),
        ));
    }
    if freqs.is_empty() {
        return Err(FFTError::ValueError(
            "lomb_scargle: freqs must be non-empty".into(),
        ));
    }

    let n = t.len() as f64;
    let y_mean = y.iter().sum::<f64>() / n;
    let y_c: Vec<f64> = y.iter().map(|&yi| yi - y_mean).collect();
    let y_var = y_c.iter().map(|&yi| yi * yi).sum::<f64>() / n;
    if y_var < 1e-30 {
        // A (numerically) constant signal carries no periodic power.
        return Ok(vec![0.0; freqs.len()]);
    }

    let power = freqs
        .iter()
        .map(|&omega| {
            // Time offset tau that decouples the sine and cosine terms:
            // tan(2*omega*tau) = sum(sin(2*omega*t)) / sum(cos(2*omega*t)).
            let (sum_sin2, sum_cos2) = t.iter().fold((0.0_f64, 0.0_f64), |(s, c), &ti| {
                let (sin2, cos2) = (2.0 * omega * ti).sin_cos();
                (s + sin2, c + cos2)
            });
            // At omega == 0 the expression below is inf * 0 = NaN. Any finite
            // tau works there because every phase is zero anyway.
            let tau = if omega == 0.0 {
                0.0
            } else {
                (0.5 / omega) * sum_sin2.atan2(sum_cos2)
            };

            // One pass accumulates both projections and both normalisers.
            let mut y_cos = 0.0_f64;
            let mut y_sin = 0.0_f64;
            let mut cos_sq = 0.0_f64;
            let mut sin_sq = 0.0_f64;
            for (&ti, &yi) in t.iter().zip(y_c.iter()) {
                let (s, c) = (omega * (ti - tau)).sin_cos();
                y_cos += yi * c;
                y_sin += yi * s;
                cos_sq += c * c;
                sin_sq += s * s;
            }

            if cos_sq < 1e-30 || sin_sq < 1e-30 {
                // Degenerate basis (e.g. omega == 0): report zero power
                // instead of dividing by ~0.
                0.0
            } else {
                (y_cos * y_cos / cos_sq + y_sin * y_sin / sin_sq) / (2.0 * y_var)
            }
        })
        .collect();
    Ok(power)
}
/// Estimates the one-sided power spectral density of a real signal with the
/// multitaper method.
///
/// The signal is multiplied by each of the `k` tapers returned by
/// `dpss(signal.len(), nw, k)`, zero-padded to
/// `next_fast_len(signal.len(), true)` points and transformed. Each tapered
/// periodogram is normalised by its own taper energy and the `k` periodograms
/// are averaged.
///
/// Returns `(freqs, psd)`, both of length `nfft / 2 + 1`, where
/// `freqs[i] = i * fs / nfft`.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `signal` has fewer than two samples,
/// `nw` is not positive, `k` is zero or exceeds `floor(2 * nw - 1)`, `fs` is
/// not a positive finite number, or the taper set is empty or contains a
/// zero-energy taper. Errors from `dpss` and the FFT are propagated.
pub fn multitaper_psd(
    signal: &[f64],
    nw: f64,
    k: usize,
    fs: f64,
) -> FFTResult<(Vec<f64>, Vec<f64>)> {
    if signal.len() < 2 {
        return Err(FFTError::ValueError(
            "multitaper_psd: signal must have at least 2 samples".into(),
        ));
    }
    if nw <= 0.0 {
        return Err(FFTError::ValueError(
            "multitaper_psd: nw must be positive".into(),
        ));
    }
    if k == 0 {
        return Err(FFTError::ValueError(
            "multitaper_psd: k must be at least 1".into(),
        ));
    }
    // Float-to-usize `as` saturates, so a negative or NaN bound becomes 0 and
    // is rejected by the comparison below.
    let k_max = (2.0 * nw - 1.0).floor() as usize;
    if k > k_max {
        return Err(FFTError::ValueError(format!(
            "multitaper_psd: k={k} exceeds 2*nw-1={k_max}; tapers would be ill-conditioned"
        )));
    }
    if !fs.is_finite() || fs <= 0.0 {
        return Err(FFTError::ValueError(
            "multitaper_psd: fs must be positive and finite".into(),
        ));
    }

    let n = signal.len();
    let nfft = next_fast_len(n, true);
    let n_freqs = nfft / 2 + 1;
    // Only an even-length transform has a Nyquist bin; for odd `nfft` the last
    // retained bin must be doubled like the other interior bins.
    let has_nyquist_bin = nfft % 2 == 0;

    let tapers = dpss(n, nw, k)?;
    let n_tapers = tapers.len();
    if n_tapers == 0 {
        return Err(FFTError::ValueError(
            "multitaper_psd: dpss returned no tapers".into(),
        ));
    }

    let mut psd_acc = vec![0.0_f64; n_freqs];
    // One zero-padded work buffer shared by all tapers.
    let mut buf = vec![0.0_f64; nfft];
    for taper in &tapers {
        buf.fill(0.0);
        let mut energy = 0.0_f64;
        for (slot, (&s, &w)) in buf.iter_mut().zip(signal.iter().zip(taper.iter())) {
            *slot = s * w;
            energy += w * w;
        }
        if energy <= 0.0 || energy.is_nan() {
            return Err(FFTError::ValueError(
                "multitaper_psd: DPSS taper has zero energy".into(),
            ));
        }

        let spectrum = fft(&buf, None)?;
        for (bin, (acc, coeff)) in psd_acc.iter_mut().zip(spectrum.iter()).enumerate() {
            // Dividing by the taper's own energy makes the estimate
            // independent of how `dpss` normalises its output.
            let power = coeff.norm_sqr() / energy;
            let unpaired = bin == 0 || (has_nyquist_bin && bin == nfft / 2);
            *acc += if unpaired { power } else { 2.0 * power };
        }
    }

    // Density scaling for an unnormalised forward FFT, averaged over tapers.
    // NOTE(review): assumes `fft` applies no 1/N factor — confirm against the
    // crate's `fft` documentation.
    let scale = fs * n_tapers as f64;
    let psd: Vec<f64> = psd_acc.iter().map(|&p| p / scale).collect();
    let freqs = rfft_freqs(nfft, fs);
    Ok((freqs, psd))
}
/// Estimates the magnitude-squared coherence `|Pxy|^2 / (Pxx * Pyy)` between
/// two equally long real signals.
///
/// Auto- and cross-spectra are accumulated over Hann-tapered segments of
/// `min(256, x.len())` samples that overlap by half a segment. `fs` is used
/// only to build the returned frequency axis.
///
/// Returns `(freqs, coh)`, both of length `nfft / 2 + 1` with
/// `nfft = next_fast_len(window_size, true)`. Bins where `Pxx * Pyy` is below
/// `1e-60` are reported as `0.0`. When the input holds only a single segment
/// the ratio is 1 (up to rounding) at every bin that is not zeroed.
///
/// # Errors
///
/// Returns `FFTError::ValueError` when `x` is empty and
/// `FFTError::DimensionError` when `x` and `y` differ in length. Errors from
/// the underlying FFT are propagated.
pub fn coherence(x: &[f64], y: &[f64], fs: f64) -> FFTResult<(Vec<f64>, Vec<f64>)> {
    if x.is_empty() {
        return Err(FFTError::ValueError("coherence: x is empty".into()));
    }
    if x.len() != y.len() {
        return Err(FFTError::DimensionError(format!(
            "coherence: x.len()={} != y.len()={}",
            x.len(),
            y.len()
        )));
    }

    let window_size = x.len().min(256);
    let overlap = window_size / 2;
    let hop = window_size - overlap;
    let nfft = next_fast_len(window_size, true);
    let n_freqs = nfft / 2 + 1;
    let taper = hann_window(window_size);

    let mut pxx = vec![0.0_f64; n_freqs];
    let mut pyy = vec![0.0_f64; n_freqs];
    let mut pxy = vec![Complex64::new(0.0, 0.0); n_freqs];
    let mut n_segments = 0usize;

    // Both inputs have the same length, so zipping their sliding windows
    // walks the two signals segment by segment in lock-step.
    let segments = x.windows(window_size).zip(y.windows(window_size));
    for (seg_x, seg_y) in segments.step_by(hop) {
        let spec_x = windowed_fft(seg_x, &taper, nfft)?;
        let spec_y = windowed_fft(seg_y, &taper, nfft)?;
        for bin in 0..n_freqs {
            // One-sided weighting, applied identically to all three spectra.
            let weight = if bin == 0 || bin == nfft / 2 { 1.0 } else { 2.0 };
            let cross = spec_x[bin].conj() * spec_y[bin];
            pxx[bin] += weight * spec_x[bin].norm_sqr();
            pyy[bin] += weight * spec_y[bin].norm_sqr();
            pxy[bin] = Complex64::new(
                pxy[bin].re + weight * cross.re,
                pxy[bin].im + weight * cross.im,
            );
        }
        n_segments += 1;
    }
    if n_segments == 0 {
        return Err(FFTError::ValueError(
            "coherence: signal is shorter than window_size".into(),
        ));
    }

    let coh: Vec<f64> = pxx
        .iter()
        .zip(&pyy)
        .zip(&pxy)
        .map(|((&sxx, &syy), sxy)| {
            let denom = sxx * syy;
            if denom < 1e-60 {
                0.0
            } else {
                sxy.norm_sqr() / denom
            }
        })
        .collect();
    Ok((rfft_freqs(nfft, fs), coh))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Samples `n` points of a unit-amplitude sine of `freq_hz` at rate `fs`.
    fn sine_wave(freq_hz: f64, n: usize, fs: f64) -> Vec<f64> {
        (0..n)
            .map(|i| (2.0 * PI * freq_hz * i as f64 / fs).sin())
            .collect()
    }

    /// Index of the largest value; panics on an empty slice or on NaN.
    fn peak_index(values: &[f64]) -> usize {
        values
            .iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).expect("comparison"))
            .map(|(i, _)| i)
            .expect("peak")
    }

    #[test]
    fn test_welch_psd_output_shape() {
        let sig = sine_wave(100.0, 4096, 1000.0);
        let (freqs, psd) = welch_psd(&sig, 256, 128, 1000.0).expect("welch");
        assert_eq!(freqs.len(), psd.len());
        assert_eq!(freqs.len(), 256 / 2 + 1);
    }

    #[test]
    fn test_welch_psd_peak_near_signal_freq() {
        let fs = 1000.0_f64;
        let sig = sine_wave(100.0, 8192, fs);
        let (freqs, psd) = welch_psd(&sig, 512, 256, fs).expect("welch");
        let peak_freq = freqs[peak_index(&psd)];
        assert!(
            (peak_freq - 100.0).abs() < 5.0,
            "Peak at {peak_freq} Hz, expected ~100 Hz"
        );
    }

    #[test]
    fn test_welch_psd_rejects_invalid_arguments() {
        let sig = sine_wave(100.0, 512, 1000.0);
        assert!(welch_psd(&[], 64, 32, 1000.0).is_err());
        assert!(welch_psd(&sig, 0, 0, 1000.0).is_err());
        assert!(welch_psd(&sig, 64, 64, 1000.0).is_err());
        assert!(welch_psd(&sig, 1024, 0, 1000.0).is_err());
    }

    #[test]
    fn test_bartlett_psd_no_overlap() {
        let fs = 1000.0_f64;
        let sig = sine_wave(200.0, 4096, fs);
        let (f1, p1) = welch_psd(&sig, 256, 0, fs).expect("welch no-overlap");
        let (f2, p2) = bartlett_psd(&sig, 256, fs).expect("bartlett");
        assert_eq!(f1.len(), f2.len());
        for (a, b) in p1.iter().zip(p2.iter()) {
            assert!((a - b).abs() < 1e-20, "bartlett mismatch: {a} vs {b}");
        }
    }

    #[test]
    fn test_lomb_scargle_peak_at_signal_freq() {
        let t: Vec<f64> = (0..50).map(|i| i as f64 * 0.1).collect();
        let y: Vec<f64> = t.iter().map(|&ti| (2.0 * PI * 1.0 * ti).sin()).collect();
        // Angular frequencies for 0.5 Hz, 1.0 Hz, ..., 10 Hz.
        let freqs: Vec<f64> = (1..=20).map(|k| 2.0 * PI * k as f64 * 0.5).collect();
        let power = lomb_scargle(&t, &y, &freqs).expect("lomb");
        let peak_hz = freqs[peak_index(&power)] / (2.0 * PI);
        assert!(
            (peak_hz - 1.0).abs() < 0.6,
            "LS peak at {peak_hz} Hz, expected ~1 Hz"
        );
    }

    #[test]
    fn test_lomb_scargle_rejects_length_mismatch() {
        assert!(lomb_scargle(&[0.0, 1.0], &[1.0], &[1.0]).is_err());
    }

    #[test]
    fn test_multitaper_psd_shape() {
        let sig = sine_wave(100.0, 512, 1000.0);
        let (freqs, psd) = multitaper_psd(&sig, 4.0, 7, 1000.0).expect("multitaper");
        assert_eq!(freqs.len(), psd.len());
    }

    #[test]
    fn test_multitaper_psd_peak() {
        let fs = 1000.0_f64;
        let sig = sine_wave(100.0, 512, fs);
        let (freqs, psd) = multitaper_psd(&sig, 4.0, 7, fs).expect("multitaper");
        let peak_freq = freqs[peak_index(&psd)];
        assert!(
            (peak_freq - 100.0).abs() < 15.0,
            "Multitaper peak at {peak_freq}, expected ~100 Hz"
        );
    }

    #[test]
    fn test_multitaper_psd_rejects_too_many_tapers() {
        let sig = sine_wave(100.0, 512, 1000.0);
        // floor(2 * 4.0 - 1) = 7, so eight tapers must be refused.
        assert!(multitaper_psd(&sig, 4.0, 8, 1000.0).is_err());
    }

    #[test]
    fn test_coherence_identical_signals() {
        let fs = 1000.0_f64;
        let sig = sine_wave(100.0, 4096, fs);
        let (freqs, coh) = coherence(&sig, &sig, fs).expect("coherence");
        assert_eq!(freqs.len(), coh.len());
        let peak = coh.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        assert!(peak > 0.99, "coherence peak={peak}, expected near 1");
    }

    #[test]
    fn test_coherence_unrelated_signals() {
        let fs = 1000.0_f64;
        let n = 4096;
        let x = sine_wave(100.0, n, fs);
        let y = sine_wave(300.0, n, fs);
        let (freqs, coh) = coherence(&x, &y, fs).expect("coherence");
        assert_eq!(freqs.len(), coh.len());
        for &c in &coh {
            assert!(
                (-1e-9..=1.0 + 1e-9).contains(&c),
                "coherence out of [0,1]: {c}"
            );
        }
    }

    #[test]
    fn test_lomb_scargle_constant_signal_zero_power() {
        let t: Vec<f64> = (0..20).map(|i| i as f64 * 0.1).collect();
        // An arbitrary constant level, chosen to be exactly representable and
        // to avoid clippy's deny-by-default `approx_constant` lint.
        let y = vec![3.5_f64; 20];
        let freqs = vec![2.0 * PI * 1.0, 2.0 * PI * 2.0];
        let power = lomb_scargle(&t, &y, &freqs).expect("constant");
        for &p in &power {
            assert!(
                p.abs() < 1e-10,
                "constant signal should have zero LS power: {p}"
            );
        }
    }
}