use crate::error::{FFTError, FFTResult};
use std::f64::consts::PI;
/// Symmetric Hann window of length `n`.
///
/// Sample `k` is `0.5 * (1 - cos(2πk / (n - 1)))`, so both end points are 0
/// and the centre (odd `n`) is 1. A single-sample window is `[1.0]`.
///
/// # Errors
///
/// `FFTError::ValueError` when `n == 0`.
pub fn hann(n: usize) -> FFTResult<Vec<f64>> {
    match n {
        0 => Err(FFTError::ValueError("hann: n must be positive".into())),
        1 => Ok(vec![1.0]),
        _ => {
            let span = n as f64 - 1.0;
            let window = (0..n)
                .map(|k| {
                    let phase = 2.0 * PI * k as f64 / span;
                    0.5 * (1.0 - phase.cos())
                })
                .collect();
            Ok(window)
        }
    }
}
/// Symmetric Hamming window of length `n`.
///
/// Sample `k` is `0.54 - 0.46 * cos(2πk / (n - 1))`; unlike Hann, the end
/// points stay at 0.08 rather than reaching 0. A single-sample window is
/// `[1.0]`.
///
/// # Errors
///
/// `FFTError::ValueError` when `n == 0`.
pub fn hamming(n: usize) -> FFTResult<Vec<f64>> {
    if n < 2 {
        return if n == 0 {
            Err(FFTError::ValueError("hamming: n must be positive".into()))
        } else {
            Ok(vec![1.0])
        };
    }
    let span = n as f64 - 1.0;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let angle = 2.0 * PI * k as f64 / span;
        window.push(0.54 - 0.46 * angle.cos());
    }
    Ok(window)
}
/// Symmetric (exact-coefficient "classic") Blackman window of length `n`.
///
/// With `θ = 2πk / (n - 1)`, sample `k` is
/// `0.42 - 0.5 cos θ + 0.08 cos 2θ`. The end points evaluate to 0 up to
/// rounding (they may be a few ULPs below zero). A single-sample window is
/// `[1.0]`.
///
/// # Errors
///
/// `FFTError::ValueError` when `n == 0`.
pub fn blackman(n: usize) -> FFTResult<Vec<f64>> {
    match n {
        0 => Err(FFTError::ValueError("blackman: n must be positive".into())),
        1 => Ok(vec![1.0]),
        _ => {
            let span = n as f64 - 1.0;
            Ok((0..n)
                .map(|k| 2.0 * PI * k as f64 / span)
                .map(|theta| 0.42 - 0.5 * theta.cos() + 0.08 * (2.0 * theta).cos())
                .collect())
        }
    }
}
/// Symmetric Kaiser window of length `n` with shape parameter `beta`.
///
/// Sample `k` is `I0(beta * sqrt(1 - ((k - a) / a)^2)) / I0(beta)` with
/// `a = (n - 1) / 2`, where `I0` (`bessel_i0`) is the zeroth-order modified
/// Bessel function of the first kind. `beta = 0` yields a rectangular window;
/// larger `beta` tapers the edges more strongly. `n == 1` returns `[1.0]`.
///
/// NOTE: `I0(beta)` grows like `e^beta / sqrt(2π beta)` and exceeds the `f64`
/// range for `beta` of roughly 713 and above; the window is not meaningful
/// there.
///
/// # Errors
///
/// `FFTError::ValueError` if `n == 0`, `beta` is negative, or `beta` is NaN or
/// infinite. Those non-finite values would otherwise pass the sign check and
/// silently produce a NaN window.
pub fn kaiser(n: usize, beta: f64) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Err(FFTError::ValueError("kaiser: n must be positive".into()));
    }
    if beta < 0.0 {
        return Err(FFTError::ValueError("kaiser: beta must be non-negative".into()));
    }
    // NaN fails every comparison and +inf passes `beta < 0.0`; reject both.
    if !beta.is_finite() {
        return Err(FFTError::ValueError("kaiser: beta must be finite".into()));
    }
    if n == 1 {
        return Ok(vec![1.0]);
    }
    let alpha = (n as f64 - 1.0) / 2.0;
    let i0_beta = bessel_i0(beta);
    let w: Vec<f64> = (0..n)
        .map(|k| {
            let x = ((k as f64 - alpha) / alpha).powi(2);
            // Clamp guards against `1 - x` rounding slightly below zero at the ends.
            let arg = beta * (1.0 - x).max(0.0).sqrt();
            bessel_i0(arg) / i0_beta
        })
        .collect();
    Ok(w)
}
/// Symmetric flat-top window of length `n` (five-term cosine sum).
///
/// With `θ = 2πk / (n - 1)`, sample `k` is
/// `c0 - c1 cos θ + c2 cos 2θ - c3 cos 3θ + c4 cos 4θ`, using the same
/// coefficients as SciPy's `flattop`. The end points are slightly negative
/// (about -4.2e-4), which is inherent to this design. A single-sample window
/// is `[1.0]`.
///
/// # Errors
///
/// `FFTError::ValueError` when `n == 0`.
pub fn flattop(n: usize) -> FFTResult<Vec<f64>> {
    const COEFFS: [f64; 5] = [
        0.215_578_95,
        0.416_631_58,
        0.277_263_158,
        0.083_578_947,
        0.006_947_368,
    ];
    match n {
        0 => Err(FFTError::ValueError("flattop: n must be positive".into())),
        1 => Ok(vec![1.0]),
        _ => {
            let span = n as f64 - 1.0;
            let window = (0..n)
                .map(|k| {
                    let theta = 2.0 * PI * k as f64 / span;
                    // Odd harmonics are subtracted, even ones added, left to right.
                    COEFFS
                        .iter()
                        .enumerate()
                        .skip(1)
                        .fold(COEFFS[0], |acc, (h, &c)| {
                            let term = c * (h as f64 * theta).cos();
                            if h % 2 == 1 {
                                acc - term
                            } else {
                                acc + term
                            }
                        })
                })
                .collect();
            Ok(window)
        }
    }
}
/// Symmetric Tukey (tapered-cosine) window of length `n`.
///
/// `alpha` is the fraction of the window covered by the two cosine tapers:
/// `alpha = 0` gives a rectangular window and `alpha = 1` a Hann window.
///
/// For `0 < alpha < 1`, let `width = floor(alpha * (n - 1) / 2)`.
/// - The first `width + 1` samples follow the rising half-cosine
///   `0.5 * (1 + cos(π * (2i / (alpha * (n - 1)) - 1)))`, which is 0 at `i = 0`.
/// - The last `width + 1` samples mirror them.
/// - Everything in between is 1.
///
/// This matches SciPy's `scipy.signal.windows.tukey(n, alpha, sym=True)`.
/// `n == 1` returns `[1.0]`.
///
/// # Errors
///
/// `FFTError::ValueError` if `n == 0` or `alpha` is outside `[0, 1]` (NaN
/// included).
pub fn tukey(n: usize, alpha: f64) -> FFTResult<Vec<f64>> {
    if n == 0 {
        return Err(FFTError::ValueError("tukey: n must be positive".into()));
    }
    if !(0.0..=1.0).contains(&alpha) {
        return Err(FFTError::ValueError(
            "tukey: alpha must be in [0, 1]".into(),
        ));
    }
    if n == 1 {
        return Ok(vec![1.0]);
    }
    if alpha == 0.0 {
        return Ok(vec![1.0; n]);
    }
    if alpha == 1.0 {
        return hann(n);
    }
    let span = n as f64 - 1.0;
    // width <= (n - 1) / 2 because alpha < 1, so `n - 1 - i` never underflows.
    let width = (alpha * span / 2.0).floor() as usize;
    let mut w = vec![1.0_f64; n];
    for i in 0..=width {
        // Phase runs from -π (value 0 at the edge) up towards 0 (value 1),
        // so the taper rises into the flat section instead of falling.
        let phase = PI * (2.0 * i as f64 / (alpha * span) - 1.0);
        let val = 0.5 * (1.0 + phase.cos());
        w[i] = val;
        w[n - 1 - i] = val;
    }
    Ok(w)
}
/// Discrete prolate spheroidal (Slepian) sequences for multitaper analysis.
///
/// Returns the `k` leading DPSS tapers of length `n` for the
/// time-half-bandwidth product `nw` (half-bandwidth `W = nw / n` cycles per
/// sample). Taper `m` is the eigenvector of the `m`-th largest eigenvalue of
/// Slepian's symmetric tridiagonal matrix (Percival & Walden 1993, §8.3):
///
/// * diagonal:     `((n - 1) / 2 - i)^2 * cos(2πW)` for `i = 0..n`
/// * off-diagonal: `i * (n - i) / 2`                for `i = 1..n`
///
/// Eigenvalues are located by Sturm-sequence bisection and eigenvectors by
/// inverse iteration (the approach of LAPACK `dstebz`/`dstein`). Each
/// requested eigenpair is therefore targeted directly, with no dependence on
/// power-iteration convergence.
///
/// Every taper has unit energy and the tapers are mutually orthogonal. Signs
/// follow SciPy's convention (Percival & Walden p. 379):
/// - even-order tapers have a positive sum;
/// - odd-order tapers start with a positive lobe, i.e. the first sample whose
///   square exceeds `max(1e-7, 1/n)` is positive.
///
/// `n == 1` returns `[[1.0]]`.
///
/// # Errors
///
/// `FFTError::ValueError` in any of these cases:
/// - `n == 0`;
/// - `nw` is not positive (NaN included);
/// - `k == 0`, `k > floor(2*nw - 1)`, or `k > n`;
/// - for `n > 1`, `nw >= n / 2` (i.e. `W` would reach the Nyquist limit 1/2).
pub fn dpss(n: usize, nw: f64, k: usize) -> FFTResult<Vec<Vec<f64>>> {
    if n == 0 {
        return Err(FFTError::ValueError("dpss: n must be positive".into()));
    }
    // `NaN <= 0.0` is false, so NaN needs an explicit check.
    if nw.is_nan() || nw <= 0.0 {
        return Err(FFTError::ValueError(
            "dpss: time-bandwidth product nw must be positive".into(),
        ));
    }
    if k == 0 {
        return Err(FFTError::ValueError("dpss: k must be at least 1".into()));
    }
    // Float-to-int casts saturate: nw < 0.5 gives 0, rejecting every k.
    let k_max = (2.0 * nw - 1.0).floor() as usize;
    if k > k_max {
        return Err(FFTError::ValueError(format!(
            "dpss: k={k} exceeds floor(2*nw-1)={k_max}"
        )));
    }
    // An n x n matrix has only n eigenvectors.
    if k > n {
        return Err(FFTError::ValueError(format!("dpss: k={k} exceeds n={n}")));
    }
    if n == 1 {
        return Ok(vec![vec![1.0]]);
    }
    let half_n = n as f64 / 2.0;
    if nw >= half_n {
        return Err(FFTError::ValueError(format!(
            "dpss: nw={nw} must be less than n/2={half_n}"
        )));
    }

    let cos_term = (2.0 * PI * nw / n as f64).cos();
    let centre = (n as f64 - 1.0) / 2.0;
    let diag: Vec<f64> = (0..n)
        .map(|i| {
            let t = centre - i as f64;
            t * t * cos_term
        })
        .collect();
    let off: Vec<f64> = (1..n).map(|i| i as f64 * (n - i) as f64 / 2.0).collect();

    let (g_lo, g_hi) = dpss_gershgorin_bounds(&diag, &off);
    let scale = g_lo.abs().max(g_hi.abs());
    // Pad the Gershgorin interval so its ends strictly bracket the spectrum.
    let pad = 1e-3 * scale;
    let (lower, upper) = (g_lo - pad, g_hi + pad);
    // LAPACK-style pivot floor keeps the Sturm recurrence finite on zero pivots.
    let pivmin = f64::MIN_POSITIVE * off.iter().fold(1.0_f64, |acc, &e| acc.max(e * e));
    let tol = 2.0 * f64::EPSILON * scale;

    let mut tapers: Vec<Vec<f64>> = Vec::with_capacity(k);
    for order in 0..k {
        // The order-th largest eigenvalue is the (n-1-order)-th smallest.
        let eigenvalue =
            dpss_bisect_eigenvalue(&diag, &off, n - 1 - order, lower, upper, pivmin, tol);
        let mut taper = dpss_inverse_iteration(&diag, &off, eigenvalue, &tapers, scale);
        dpss_fix_sign(&mut taper, order);
        tapers.push(taper);
    }
    Ok(tapers)
}

/// Gershgorin interval `(lo, hi)` enclosing every eigenvalue of the symmetric
/// tridiagonal matrix with diagonal `diag` and off-diagonal `off`.
fn dpss_gershgorin_bounds(diag: &[f64], off: &[f64]) -> (f64, f64) {
    let n = diag.len();
    (0..n).fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), i| {
        let left = if i > 0 { off[i - 1].abs() } else { 0.0 };
        let right = if i + 1 < n { off[i].abs() } else { 0.0 };
        let radius = left + right;
        (lo.min(diag[i] - radius), hi.max(diag[i] + radius))
    })
}

/// Number of eigenvalues below `x`, counted as the negative pivots of the
/// LDLᵀ factorisation of `T - xI` (Sylvester's law of inertia).
fn dpss_sturm_count(diag: &[f64], off: &[f64], x: f64, pivmin: f64) -> usize {
    let mut count = 0;
    let mut q = 1.0_f64;
    for (i, &d) in diag.iter().enumerate() {
        let coupling = if i == 0 { 0.0 } else { off[i - 1] * off[i - 1] / q };
        q = d - x - coupling;
        if q.abs() < pivmin {
            q = -pivmin;
        }
        if q < 0.0 {
            count += 1;
        }
    }
    count
}

/// `index`-th smallest eigenvalue (0-based), found by bisection on
/// `[lower, upper]` to absolute width `tol`.
fn dpss_bisect_eigenvalue(
    diag: &[f64],
    off: &[f64],
    index: usize,
    mut lower: f64,
    mut upper: f64,
    pivmin: f64,
    tol: f64,
) -> f64 {
    while upper - lower > tol {
        let mid = 0.5 * (lower + upper);
        // Stop once the interval can no longer be split in f64.
        if mid <= lower || mid >= upper {
            break;
        }
        if dpss_sturm_count(diag, off, mid, pivmin) > index {
            upper = mid;
        } else {
            lower = mid;
        }
    }
    0.5 * (lower + upper)
}

/// Solves `(T - shift·I) x = rhs` for symmetric tridiagonal `T` by Gaussian
/// elimination with partial pivoting. Pivots smaller than `pivot_floor` are
/// replaced by it, as inverse iteration intentionally uses a near-singular shift.
fn dpss_shifted_solve(
    diag: &[f64],
    off: &[f64],
    shift: f64,
    rhs: &[f64],
    pivot_floor: f64,
) -> Vec<f64> {
    let n = diag.len();
    if n == 0 {
        return Vec::new();
    }
    // Upper-triangular factor: u0 = diagonal, u1/u2 = first/second super-diagonal.
    let mut u0 = vec![0.0_f64; n];
    let mut u1 = vec![0.0_f64; n];
    let mut u2 = vec![0.0_f64; n];
    let mut y = vec![0.0_f64; n];
    // Pending row: coefficient in its own column (a), the next column (b), rhs (f).
    let mut a = diag[0] - shift;
    let mut b = off.first().copied().unwrap_or(0.0);
    let mut f = rhs[0];
    for i in 0..n - 1 {
        let sub = off[i];
        let next_diag = diag[i + 1] - shift;
        let next_super = off.get(i + 1).copied().unwrap_or(0.0);
        let next_rhs = rhs[i + 1];
        if a.abs() >= sub.abs() {
            // Pending row is the pivot row; eliminate the next row's sub-diagonal.
            let m = if a != 0.0 { sub / a } else { 0.0 };
            u0[i] = a;
            u1[i] = b;
            u2[i] = 0.0;
            y[i] = f;
            a = next_diag - m * b;
            b = next_super;
            f = next_rhs - m * f;
        } else {
            // Swap: the next row becomes the pivot row.
            let m = a / sub;
            u0[i] = sub;
            u1[i] = next_diag;
            u2[i] = next_super;
            y[i] = next_rhs;
            a = b - m * next_diag;
            b = -m * next_super;
            f -= m * next_rhs;
        }
    }
    u0[n - 1] = a;
    y[n - 1] = f;

    let mut x = vec![0.0_f64; n];
    for i in (0..n).rev() {
        let mut s = y[i];
        if i + 1 < n {
            s -= u1[i] * x[i + 1];
        }
        if i + 2 < n {
            s -= u2[i] * x[i + 2];
        }
        let pivot = if u0[i].abs() < pivot_floor {
            pivot_floor.copysign(u0[i])
        } else {
            u0[i]
        };
        x[i] = s / pivot;
    }
    x
}

/// Unit eigenvector for `eigenvalue` by inverse iteration, kept orthogonal to
/// the already-computed tapers in `previous`.
fn dpss_inverse_iteration(
    diag: &[f64],
    off: &[f64],
    eigenvalue: f64,
    previous: &[Vec<f64>],
    scale: f64,
) -> Vec<f64> {
    const MAX_SWEEPS: usize = 6;
    let n = diag.len();
    let pivot_floor = f64::EPSILON * scale;
    let residual_tol = 8.0 * (n as f64).sqrt() * f64::EPSILON * scale;

    // Start vector with no symmetry about the centre, so it overlaps both
    // symmetric (even-order) and antisymmetric (odd-order) eigenvectors.
    let mut v: Vec<f64> = (1..=n).map(|i| (i as f64).sqrt()).collect();
    let start_norm = v.iter().map(|e| e * e).sum::<f64>().sqrt();
    for e in &mut v {
        *e /= start_norm;
    }

    for sweep in 0..MAX_SWEEPS {
        let mut x = dpss_shifted_solve(diag, off, eigenvalue, &v, pivot_floor);
        for p in previous {
            let proj: f64 = x.iter().zip(p).map(|(a, b)| a * b).sum();
            for (xi, &pi) in x.iter_mut().zip(p) {
                *xi -= proj * pi;
            }
        }
        // Rescale by the peak first so the sum of squares cannot overflow.
        let peak = x.iter().fold(0.0_f64, |m, &e| m.max(e.abs()));
        if !peak.is_finite() || peak == 0.0 {
            break;
        }
        for e in &mut x {
            *e /= peak;
        }
        let norm = x.iter().map(|e| e * e).sum::<f64>().sqrt();
        for e in &mut x {
            *e /= norm;
        }
        v = x;

        // Stop early once ||T v - ρ v|| (ρ = Rayleigh quotient) is at round-off level.
        let tv = tridiag_mul(diag, off, &v);
        let rayleigh: f64 = v.iter().zip(&tv).map(|(a, b)| a * b).sum();
        let residual = tv
            .iter()
            .zip(&v)
            .map(|(&t, &e)| (t - rayleigh * e).powi(2))
            .sum::<f64>()
            .sqrt();
        if sweep > 0 && residual <= residual_tol {
            break;
        }
    }
    v
}

/// Applies SciPy's DPSS sign convention in place: even orders get a positive
/// sum, odd orders a positive first sample with square above `max(1e-7, 1/n)`.
fn dpss_fix_sign(taper: &mut [f64], order: usize) {
    let flip = if order % 2 == 0 {
        taper.iter().sum::<f64>() < 0.0
    } else {
        let threshold = (1.0 / taper.len() as f64).max(1e-7);
        matches!(taper.iter().find(|&&x| x * x > threshold), Some(&x) if x < 0.0)
    };
    if flip {
        for x in taper.iter_mut() {
            *x = -*x;
        }
    }
}
/// Multiplies `signal` by `window` sample by sample.
///
/// # Errors
///
/// - `FFTError::ValueError` if `signal` is empty. This is checked first.
/// - `FFTError::DimensionError` if the two slices differ in length.
pub fn apply_window(signal: &[f64], window: &[f64]) -> FFTResult<Vec<f64>> {
    match (signal.is_empty(), signal.len() == window.len()) {
        (true, _) => Err(FFTError::ValueError("apply_window: signal is empty".into())),
        (false, false) => Err(FFTError::DimensionError(format!(
            "apply_window: signal.len()={} != window.len()={}",
            signal.len(),
            window.len()
        ))),
        (false, true) => Ok(signal
            .iter()
            .zip(window)
            .map(|(sample, gain)| sample * gain)
            .collect()),
    }
}
/// Modified Bessel function of the first kind, order zero: `I0(x)`.
///
/// Evaluated from the power series `I0(x) = Σ_{j≥0} ((x/2)^2)^j / (j!)^2`.
/// Every term is non-negative, so the sum has no cancellation and is accurate
/// to near `f64` precision for all finite `x`. The Abramowitz–Stegun polynomial
/// fits (9.8.1/9.8.2) are only good to about 2e-7 relative.
///
/// Consecutive terms satisfy `t_j = t_{j-1} * (x/2)^2 / j^2`. Summation stops
/// once a term no longer changes the sum at `f64` precision.
///
/// Behaviour at the edges:
/// - `I0` is even in `x`.
/// - It grows like `e^|x| / sqrt(2π|x|)` and overflows to `+inf` for `|x|`
///   above roughly 713; `±inf` also returns `+inf`.
/// - NaN returns NaN.
fn bessel_i0(x: f64) -> f64 {
    // Defensive cap: convergence (or overflow to +inf) needs ~500 terms at most.
    const MAX_TERMS: usize = 1000;
    if x.is_nan() {
        return x;
    }
    let quarter_sq = 0.25 * x * x;
    let mut term = 1.0_f64;
    let mut sum = 1.0_f64;
    let mut j = 1.0_f64;
    for _ in 0..MAX_TERMS {
        term *= quarter_sq / (j * j);
        sum += term;
        // Also true once both overflow to +inf, ending the loop.
        if term <= sum * f64::EPSILON {
            break;
        }
        j += 1.0;
    }
    sum
}
/// Product `T v` of a symmetric tridiagonal matrix `T` with vector `v`.
///
/// `d` is the main diagonal (length `n`). `o` is the off-diagonal (length
/// `n - 1`), used both above and below the diagonal. Each row sums its terms
/// left-to-right: sub-diagonal, diagonal, super-diagonal. Empty `d` yields an
/// empty vector.
fn tridiag_mul(d: &[f64], o: &[f64], v: &[f64]) -> Vec<f64> {
    let n = d.len();
    (0..n)
        .map(|row| {
            let centre = d[row] * v[row];
            let has_left = row > 0;
            let has_right = row + 1 < n;
            match (has_left, has_right) {
                (false, false) => centre,
                (false, true) => centre + o[row] * v[row + 1],
                (true, true) => o[row - 1] * v[row - 1] + centre + o[row] * v[row + 1],
                (true, false) => o[row - 1] * v[row - 1] + centre,
            }
        })
        .collect()
}
#[cfg(test)]
mod tests {
    // Unit tests for the window constructors, `apply_window`, and `dpss`.
    use super::*;

    /// Asserts `w` has length `n` and every sample lies in `[0, 1]` within 1e-10.
    fn check_window(w: &[f64], n: usize) {
        assert_eq!(w.len(), n);
        for &v in w {
            assert!(
                (-1e-10..=1.0 + 1e-10).contains(&v),
                "window value out of range: {v}"
            );
        }
    }

    #[test]
    fn test_hann_zeros_at_endpoints() {
        let w = hann(8).expect("hann");
        check_window(&w, 8);
        assert!(w[0].abs() < 1e-12, "hann[0]={}", w[0]);
        assert!(w[7].abs() < 1e-12, "hann[7]={}", w[7]);
    }

    #[test]
    fn test_hann_peak_at_centre() {
        let n = 9;
        let w = hann(n).expect("hann");
        let centre = n / 2;
        assert!((w[centre] - 1.0).abs() < 1e-10, "hann peak={}", w[centre]);
    }

    #[test]
    fn test_hamming_non_zero_endpoints() {
        let w = hamming(8).expect("hamming");
        check_window(&w, 8);
        assert!(w[0] > 0.05 && w[0] < 0.15, "hamming[0]={}", w[0]);
    }

    #[test]
    fn test_blackman_length_and_range() {
        let w = blackman(12).expect("blackman");
        check_window(&w, 12);
    }

    #[test]
    fn test_kaiser_length_and_range() {
        let w = kaiser(10, 8.6).expect("kaiser");
        assert_eq!(w.len(), 10);
        for &v in &w {
            assert!(
                (0.0..=1.0 + 1e-10).contains(&v),
                "kaiser value out of range: {v}"
            );
        }
        assert!(w[0] < 0.05, "kaiser beta=8.6 endpoint too large: {}", w[0]);
    }

    #[test]
    fn test_kaiser_beta_zero_is_rectangular() {
        let w = kaiser(8, 0.0).expect("kaiser_rect");
        for &v in &w {
            assert!((v - 1.0).abs() < 1e-10, "kaiser β=0 should be rectangular: {v}");
        }
    }

    #[test]
    fn test_flattop_length() {
        let w = flattop(16).expect("flattop");
        assert_eq!(w.len(), 16);
    }

    #[test]
    fn test_tukey_alpha_zero_is_rectangular() {
        let w = tukey(10, 0.0).expect("tukey_rect");
        for &v in &w {
            assert!((v - 1.0).abs() < 1e-10);
        }
    }

    #[test]
    fn test_tukey_alpha_one_is_hann() {
        let n = 16;
        let tw = tukey(n, 1.0).expect("tukey_hann");
        let hw = hann(n).expect("hann");
        for (t, h) in tw.iter().zip(hw.iter()) {
            assert!((t - h).abs() < 1e-10, "tukey(1)≠hann: {t} vs {h}");
        }
    }

    #[test]
    fn test_tukey_flat_centre() {
        let n = 20;
        let w = tukey(n, 0.5).expect("tukey");
        assert_eq!(w.len(), n);
        assert!((w[10] - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_apply_window_length() {
        let signal = [1.0_f64; 8];
        let win = hann(8).expect("hann");
        let out = apply_window(&signal, &win).expect("apply");
        assert_eq!(out.len(), 8);
    }

    #[test]
    fn test_apply_window_values() {
        let signal = [1.0_f64; 8];
        let win = hann(8).expect("hann");
        let out = apply_window(&signal, &win).expect("apply");
        for (o, w) in out.iter().zip(win.iter()) {
            assert!((o - w).abs() < 1e-14);
        }
    }

    #[test]
    fn test_dpss_count_and_length() {
        let tapers = dpss(64, 4.0, 7).expect("dpss");
        assert_eq!(tapers.len(), 7);
        for t in &tapers {
            assert_eq!(t.len(), 64);
        }
    }

    #[test]
    fn test_dpss_unit_energy() {
        let tapers = dpss(64, 4.0, 7).expect("dpss");
        for (i, taper) in tapers.iter().enumerate() {
            let energy: f64 = taper.iter().map(|x| x * x).sum();
            assert!((energy - 1.0).abs() < 1e-5, "taper {i}: energy={energy}");
        }
    }

    #[test]
    fn test_dpss_orthogonality() {
        let tapers = dpss(128, 4.0, 6).expect("dpss");
        for (i, a) in tapers.iter().enumerate() {
            for (j, b) in tapers.iter().enumerate().skip(i + 1) {
                let dot: f64 = a.iter().zip(b.iter()).map(|(&x, &y)| x * y).sum();
                assert!(
                    dot.abs() < 0.1,
                    "tapers {i} and {j} are not orthogonal: dot={dot}"
                );
            }
        }
    }

    #[test]
    fn test_dpss_k_exceeds_limit_fails() {
        let result = dpss(64, 2.0, 10);
        assert!(result.is_err(), "expected error for k > 2*nw-1");
    }

    #[test]
    fn test_apply_window_length_mismatch_error() {
        let signal = [1.0_f64; 8];
        let win = [1.0_f64; 5];
        let result = apply_window(&signal, &win);
        assert!(result.is_err());
    }
}