use crate::error::{FFTError, FFTResult};
use scirs2_core::ndarray::{Array1, ArrayBase, Data, Ix1};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::f64::consts::PI;
use std::fmt::Debug;
use std::str::FromStr;
/// Window functions understood by [`get_window`].
///
/// Parameterless variants can also be obtained by name through `FromStr`; the
/// parameterised variants carry their shape parameter directly.
#[derive(Clone, Debug, PartialEq)]
pub enum Window {
    /// All-ones (boxcar) window.
    Rectangular,
    /// Hann window: two-term cosine sum with `a0 = a1 = 0.5`.
    Hann,
    /// Hamming window: two-term cosine sum with `a0 = 0.54`, `a1 = 0.46`.
    Hamming,
    /// Three-term Blackman cosine sum.
    Blackman,
    /// Triangular window that is zero at the first sample.
    Bartlett,
    /// Five-term flat-top cosine sum.
    FlatTop,
    /// Piecewise-cubic Parzen window.
    Parzen,
    /// Bohman window.
    Bohman,
    /// Four-term Blackman-Harris cosine sum.
    BlackmanHarris,
    /// Four-term Nuttall cosine sum.
    Nuttall,
    /// Bartlett-Hann window.
    Barthann,
    /// Sine-shaped ("cosine") window.
    Cosine,
    /// Exponential window; [`get_window`] builds it with `tau = 1.0`.
    Exponential,
    /// Tukey (tapered cosine) window; the value is the taper fraction `alpha` in `[0, 1]`.
    Tukey(f64),
    /// Kaiser window; the value is the non-negative shape parameter `beta`.
    Kaiser(f64),
    /// Gaussian window; the value is the standard deviation relative to half the window span.
    Gaussian(f64),
    /// Generalised cosine-sum window with coefficients `a0, a1, ...`.
    GeneralCosine(Vec<f64>),
}
impl FromStr for Window {
    type Err = FFTError;

    /// Parses a window name case-insensitively (several aliases are accepted).
    ///
    /// Only windows without parameters can be named; `Tukey`, `Kaiser`, `Gaussian`
    /// and `GeneralCosine` must be constructed directly.
    ///
    /// # Errors
    /// Returns `FFTError::ValueError` for any unrecognised name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let key = s.to_lowercase();
        let window = match key.as_str() {
            "rectangular" | "boxcar" | "rect" => Window::Rectangular,
            "hann" | "hanning" => Window::Hann,
            "hamming" => Window::Hamming,
            "blackman" => Window::Blackman,
            "bartlett" | "triangular" | "triangle" => Window::Bartlett,
            "flattop" | "flat" => Window::FlatTop,
            "parzen" => Window::Parzen,
            "bohman" => Window::Bohman,
            "blackmanharris" | "blackman-harris" => Window::BlackmanHarris,
            "nuttall" => Window::Nuttall,
            "barthann" => Window::Barthann,
            "cosine" | "cos" => Window::Cosine,
            "exponential" | "exp" => Window::Exponential,
            // The error echoes the caller's original spelling, not the lowercased key.
            _ => return Err(FFTError::ValueError(format!("Unknown window type: {s}"))),
        };
        Ok(window)
    }
}
#[allow(dead_code)]
pub fn get_window<T>(window: T, n: usize, sym: bool) -> FFTResult<Array1<f64>>
where
T: Into<WindowParam>,
{
if n == 0 {
return Err(FFTError::ValueError(
"Window length must be positive".to_string(),
));
}
let window_param = window.into();
let window_type = match window_param {
WindowParam::Type(wt) => wt,
WindowParam::Name(s) => Window::from_str(&s)?,
};
match window_type {
Window::Rectangular => rectangular(n),
Window::Hann => hann(n, sym),
Window::Hamming => hamming(n, sym),
Window::Blackman => blackman(n, sym),
Window::Bartlett => bartlett(n, sym),
Window::FlatTop => flattop(n, sym),
Window::Parzen => parzen(n, sym),
Window::Bohman => bohman(n),
Window::BlackmanHarris => blackmanharris(n, sym),
Window::Nuttall => nuttall(n, sym),
Window::Barthann => barthann(n, sym),
Window::Cosine => cosine(n, sym),
Window::Exponential => exponential(n, sym, 1.0),
Window::Tukey(alpha) => tukey(n, sym, alpha),
Window::Kaiser(beta) => kaiser(n, sym, beta),
Window::Gaussian(std) => gaussian(n, sym, std),
Window::GeneralCosine(coeffs) => general_cosine(n, sym, &coeffs),
}
}
/// Argument accepted by `get_window`: either a ready-made [`Window`] or a window name
/// that is parsed through `Window`'s `FromStr` implementation.
#[derive(Debug)]
pub enum WindowParam {
    /// An already-constructed window type.
    Type(Window),
    /// A window name, parsed case-insensitively when the window is built.
    Name(String),
}

impl From<Window> for WindowParam {
    fn from(window: Window) -> Self {
        Self::Type(window)
    }
}

impl From<&str> for WindowParam {
    fn from(name: &str) -> Self {
        Self::Name(name.to_owned())
    }
}

impl From<String> for WindowParam {
    fn from(name: String) -> Self {
        Self::Name(name)
    }
}
/// Rectangular (boxcar) window: `n` samples, all equal to `1.0`.
///
/// Symmetric and periodic forms coincide, so no `sym` flag is taken.
#[allow(dead_code)]
fn rectangular(n: usize) -> FFTResult<Array1<f64>> {
    Ok(Array1::from_elem(n, 1.0))
}
/// Hann window: the two-term cosine sum with `a0 = a1 = 0.5`.
#[allow(dead_code)]
fn hann(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    let coefficients = [0.5, 0.5];
    general_cosine(n, sym, &coefficients)
}
/// Hamming window: the two-term cosine sum with `a0 = 0.54`, `a1 = 0.46`.
///
/// The symmetric form has endpoints `0.54 - 0.46 = 0.08`.
#[allow(dead_code)]
fn hamming(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    let coefficients = [0.54, 0.46];
    general_cosine(n, sym, &coefficients)
}
/// Blackman window: the three-term cosine sum `[0.42, 0.5, 0.08]`.
#[allow(dead_code)]
fn blackman(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    let coefficients = [0.42, 0.5, 0.08];
    general_cosine(n, sym, &coefficients)
}
/// Bartlett (triangular) window.
///
/// The triangle rises linearly from 0 at index 0 and falls again past the midpoint,
/// using the divisor `M - 1`. `M = n` for a symmetric window. For a periodic window
/// (`sym == false`), `M = n + 1` and only the first `n` samples are kept.
/// A length-1 window is `[1.0]`.
#[allow(dead_code)]
fn bartlett(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let len = if sym { n } else { n + 1 };
    let half = (len as f64) / 2.0;
    let span = len as f64 - 1.0;
    // Only the retained samples are computed; each sample depends on its index alone.
    let w = (0..n)
        .map(|i| {
            let x = i as f64;
            if x < half {
                2.0 * x / span
            } else {
                2.0 - 2.0 * x / span
            }
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Flat-top window: a five-term cosine sum.
#[allow(dead_code)]
fn flattop(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    const COEFFICIENTS: [f64; 5] = [
        0.215_578_95,
        0.416_631_58,
        0.277_263_158,
        0.083_578_947,
        0.006_947_368,
    ];
    general_cosine(n, sym, &COEFFICIENTS)
}
/// Parzen window: piecewise cubic in the normalised distance from the centre.
///
/// With `M = n` (symmetric) or `M = n + 1` (periodic, truncated to the first `n` samples),
/// let `x = |i - (M - 1) / 2| / (M / 2)`. Each sample is then:
/// - `1 - 6x^2(1 - x)` for `x <= 0.5`;
/// - `2(1 - x)^3` for `0.5 < x <= 1`;
/// - `0` otherwise.
///
/// A length-1 window is `[1.0]`.
#[allow(dead_code)]
fn parzen(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let len = if sym { n } else { n + 1 };
    let half = (len as f64) / 2.0;
    let w = (0..n)
        .map(|i| {
            // `i - M/2 + 0.5` is the offset from the centre `(M - 1) / 2`.
            let x = (i as f64 - half + 0.5).abs() / half;
            if x <= 0.5 {
                1.0 - 6.0 * x.powi(2) * (1.0 - x)
            } else if x <= 1.0 {
                2.0 * (1.0 - x).powi(3)
            } else {
                0.0
            }
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Symmetric Bohman window of length `n`.
///
/// Each sample is `(1 - x) cos(πx) + sin(πx) / π`, where `x` is the distance from the
/// centre normalised to `[0, 1]`. A length-1 window is `[1.0]`.
#[allow(dead_code)]
fn bohman(n: usize) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let mid = (n as f64 - 1.0) / 2.0;
    let w = (0..n)
        .map(|i| {
            let r = ((i as f64) - mid).abs() / mid;
            if r <= 1.0 {
                (1.0 - r) * (PI * r).cos() + (PI * r).sin() / PI
            } else {
                0.0
            }
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Blackman-Harris window: a four-term cosine sum.
#[allow(dead_code)]
fn blackmanharris(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    const COEFFICIENTS: [f64; 4] = [0.35875, 0.48829, 0.14128, 0.01168];
    general_cosine(n, sym, &COEFFICIENTS)
}
/// Nuttall window: a four-term cosine sum.
#[allow(dead_code)]
fn nuttall(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    const COEFFICIENTS: [f64; 4] = [0.363_581_9, 0.489_177_5, 0.136_599_5, 0.010_641_1];
    general_cosine(n, sym, &COEFFICIENTS)
}
/// Bartlett-Hann window.
///
/// With `M = n` (symmetric) or `M = n + 1` (periodic, truncated to the first `n` samples),
/// let `fac = |i / (M - 1) - 0.5|`. Each sample is
/// `0.62 - 0.48 * fac + 0.38 * cos(2π * fac)`.
/// The endpoints evaluate to `0.62 - 0.24 - 0.38 = 0` and the centre to `0.62 + 0.38 = 1`.
/// A length-1 window is `[1.0]`.
#[allow(dead_code)]
fn barthann(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let len = if sym { n } else { n + 1 };
    let span = len as f64 - 1.0;
    let w = (0..n)
        .map(|i| {
            // Distance from the window centre on the normalised [0, 1] axis, in [0, 0.5].
            let fac = (i as f64 / span - 0.5).abs();
            0.62 - 0.48 * fac + 0.38 * (2.0 * PI * fac).cos()
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Sine-shaped ("cosine") window `sin(π i / D)`.
///
/// `D = n - 1` for a symmetric window, which makes both endpoints (numerically) zero.
/// `D = n` for a periodic window. A length-1 window is `[1.0]`.
#[allow(dead_code)]
fn cosine(n: usize, sym: bool) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let denom = if sym { n as f64 - 1.0 } else { n as f64 };
    let w = (0..n)
        .map(|i| (PI * (i as f64 / denom)).sin())
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Exponential window `exp(-|i - c| / (tau * n))`.
///
/// `tau` is relative to the window length `n`. The value of `c` depends on `sym`:
/// - `sym == true`: `c = (n - 1) / 2`, so the window peaks at its centre;
/// - `sym == false`: `c = 0`, giving a one-sided decay from the first sample.
///
/// A length-1 window is `[1.0]`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `tau <= 0`.
#[allow(dead_code)]
fn exponential(n: usize, sym: bool, tau: f64) -> FFTResult<Array1<f64>> {
    if tau <= 0.0 {
        return Err(FFTError::ValueError("tau must be positive".to_string()));
    }
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    let origin = if sym { (n as f64 - 1.0) / 2.0 } else { 0.0 };
    let scale = tau * (n as f64);
    let w = (0..n)
        .map(|i| {
            let dist = (i as f64 - origin).abs() / scale;
            (-dist).exp()
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Tukey (tapered cosine) window.
///
/// `alpha` is the fraction of the window occupied by the cosine tapers. `0.0` yields a
/// rectangular window and `1.0` a Hann window. For other values, let `M = n` (symmetric)
/// or `M = n + 1` (periodic, truncated to the first `n` samples), and let
/// `width = floor(alpha * (M - 1) / 2)`. Then:
/// - samples `0..=width` rise along a half cosine starting at 0;
/// - samples `M - width - 1..M` mirror that rise back down to 0;
/// - every sample in between is 1.
///
/// A length-1 window is `[1.0]`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `alpha` is outside `[0, 1]` (including NaN).
#[allow(dead_code)]
fn tukey(n: usize, sym: bool, alpha: f64) -> FFTResult<Array1<f64>> {
    if !(0.0..=1.0).contains(&alpha) {
        return Err(FFTError::ValueError(
            "alpha must be between 0 and 1".to_string(),
        ));
    }
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    if alpha == 0.0 {
        return rectangular(n);
    }
    if alpha == 1.0 {
        return hann(n, sym);
    }
    // Periodic windows are the first `n` samples of a symmetric window of length `n + 1`.
    let len = if sym { n } else { n + 1 };
    let span = len as f64 - 1.0;
    // Last index of the rising taper. Since alpha < 1, width < (len - 1) / 2, so the two
    // tapers never overlap and `len - width - 1` cannot underflow.
    let width = (alpha * span / 2.0).floor() as usize;
    let w = (0..n)
        .map(|i| {
            let x = i as f64;
            if i <= width {
                // cos(π(-1 + ...)) starts at cos(-π) = -1, so the taper starts at 0.
                0.5 * (1.0 + (PI * (-1.0 + 2.0 * x / (alpha * span))).cos())
            } else if i >= len - width - 1 {
                // Mirror image of the rising taper; reaches cos(π) = -1 at i = len - 1.
                0.5 * (1.0 + (PI * (-2.0 / alpha + 1.0 + 2.0 * x / (alpha * span))).cos())
            } else {
                1.0
            }
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Kaiser window `I0(beta * sqrt(1 - ((i - a) / a)^2)) / I0(beta)`.
///
/// Here `a = (M - 1) / 2`. `M = n` for a symmetric window; `M = n + 1` for a periodic one,
/// truncated to the first `n` samples.
///
/// A length-1 request returns `[1.0]` before `beta` is validated.
///
/// # Errors
/// Returns `FFTError::ValueError` if `beta < 0`.
#[allow(dead_code)]
fn kaiser(n: usize, sym: bool, beta: f64) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    if beta < 0.0 {
        return Err(FFTError::ValueError(
            "beta must be non-negative".to_string(),
        ));
    }
    let len = if sym { n } else { n + 1 };
    let alpha = 0.5 * (len as f64 - 1.0);
    let norm = bessel_i0(beta);
    let w = (0..n)
        .map(|i| {
            let ratio = (i as f64 - alpha) / alpha;
            bessel_i0(beta * (1.0 - ratio.powi(2)).sqrt()) / norm
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
/// Gaussian window.
///
/// `std` is expressed relative to half the window span. Each sample is
/// `exp(-0.5 * ((i - c) / (std * (M - 1) / 2))^2)` with `c = (M - 1) / 2`, where `M = n`
/// for a symmetric window and `M = n + 1` for a periodic one (truncated to the first `n`
/// samples). In both cases the peak sits at the window centre. A length-1 window is `[1.0]`.
///
/// # Errors
/// Returns `FFTError::ValueError` if `std` is not a positive number.
#[allow(dead_code)]
fn gaussian(n: usize, sym: bool, std: f64) -> FFTResult<Array1<f64>> {
    if n == 1 {
        return Ok(Array1::ones(1));
    }
    // `std <= 0.0` is false for NaN, so NaN is rejected explicitly.
    if std.is_nan() || std <= 0.0 {
        return Err(FFTError::ValueError("std must be positive".to_string()));
    }
    // Periodic windows are the first `n` samples of a symmetric window of length `n + 1`,
    // matching the convention used by the other windows in this module.
    let len = if sym { n } else { n + 1 };
    let center = (len as f64 - 1.0) / 2.0;
    let scale = std * (len as f64 - 1.0) / 2.0;
    let w = (0..n)
        .map(|i| {
            let x = (i as f64 - center) / scale;
            (-0.5 * x.powi(2)).exp()
        })
        .collect::<Array1<f64>>();
    Ok(w)
}
#[allow(dead_code)]
fn general_cosine(n: usize, sym: bool, a: &[f64]) -> FFTResult<Array1<f64>> {
if n == 1 {
return Ok(Array1::ones(1));
}
let mut w = Array1::zeros(n);
let fac = if sym {
2.0 * PI / (n as f64 - 1.0)
} else {
2.0 * PI / n as f64
};
for i in 0..n {
let mut win_val = a[0];
for (k, &coef) in a.iter().enumerate().skip(1) {
let sign = if k % 2 == 1 { -1.0 } else { 1.0 };
win_val += sign * coef * ((k as f64) * fac * (i as f64)).cos();
}
w[i] = win_val;
}
Ok(w)
}
/// Modified Bessel function of the first kind, order zero, `I0(x)`.
///
/// Evaluated with two polynomial approximations:
/// - for `|x| < 3.75`, a polynomial in `t = (x / 3.75)^2`;
/// - otherwise, `exp(|x|) / sqrt(|x|)` times a polynomial in `t = 3.75 / |x|`.
///
/// Both branches depend only on `|x|`.
#[allow(dead_code)]
fn bessel_i0(x: f64) -> f64 {
    // Coefficients of t^1, t^2, ... for the small-argument branch (constant term is 1).
    const SMALL: [f64; 6] = [
        3.515_622_9,
        3.089_942_4,
        1.206_749_2,
        0.265_973_2,
        0.036_076_8,
        0.004_581_3,
    ];
    // Coefficients of t^1, t^2, ... for the large-argument branch (constant term 0.398_942_28).
    const LARGE: [f64; 8] = [
        0.013_285_92,
        0.002_253_19,
        -0.001_575_65,
        0.009_162_81,
        -0.020_577_06,
        0.026_355_37,
        -0.016_476_33,
        0.003_923_77,
    ];

    /// Nested (Horner) evaluation `c0 + t*(c1 + t*(c2 + ...))`, innermost term first.
    fn horner(coeffs: &[f64], t: f64) -> f64 {
        match coeffs.split_last() {
            Some((&last, rest)) => rest.iter().rfold(last, |acc, &c| c + t * acc),
            None => 0.0,
        }
    }

    let ax = x.abs();
    if ax < 3.75 {
        let t = (x / 3.75).powi(2);
        t.mul_add(horner(&SMALL, t), 1.0)
    } else {
        let t = 3.75 / ax;
        let envelope = ax.exp() / ax.sqrt();
        envelope * t.mul_add(horner(&LARGE, t), 0.398_942_28)
    }
}
#[allow(dead_code)]
pub fn apply_window<F, S>(x: &ArrayBase<S, Ix1>, window: Window) -> FFTResult<Array1<F>>
where
S: Data<Elem = F>,
F: Float + FromPrimitive + Debug,
{
let n = x.len();
let win = get_window(window, n, true)?;
let mut result = Array1::zeros(n);
for i in 0..n {
result[i] = x[i] * F::from_f64(win[i]).expect("Operation failed");
}
Ok(result)
}
/// Equivalent noise bandwidth, in bins, of the symmetric window of length `n`.
///
/// Computed as `n * Σw² / (Σw)²`.
///
/// # Errors
/// Propagates any error from `get_window` (e.g. `n == 0`).
#[allow(dead_code)]
pub fn enbw(window: Window, n: usize) -> FFTResult<f64> {
    let coeffs = get_window(window, n, true)?;
    let energy: f64 = coeffs.iter().map(|c| c.powi(2)).sum();
    let amplitude: f64 = coeffs.iter().sum();
    Ok((n as f64) * energy / amplitude.powi(2))
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Compares `actual` with `expected` pairwise; only the zipped prefix is checked.
    fn assert_close(actual: &Array1<f64>, expected: &[f64], epsilon: f64) {
        for (&got, &want) in actual.iter().zip(expected.iter()) {
            assert_relative_eq!(got, want, epsilon = epsilon);
        }
    }

    #[test]
    fn test_rectangular() {
        let win = rectangular(5).expect("Operation failed");
        assert_close(&win, &[1.0; 5], 1e-10);
    }

    #[test]
    fn test_hann() {
        let win = hann(5, true).expect("Operation failed");
        assert_close(&win, &[0.0, 0.5, 1.0, 0.5, 0.0], 1e-10);
    }

    #[test]
    fn test_hamming() {
        let win = hamming(5, true).expect("Operation failed");
        assert_close(&win, &[0.08, 0.54, 1.0, 0.54, 0.08], 1e-10);
    }

    #[test]
    fn test_blackman() {
        let win = blackman(5, true).expect("Operation failed");
        assert_close(&win, &[0.0, 0.34, 1.0, 0.34, 0.0], 0.01);
    }

    #[test]
    fn test_from_str() {
        let cases = [
            ("hann", Window::Hann),
            ("hamming", Window::Hamming),
            ("blackman", Window::Blackman),
            ("rectangular", Window::Rectangular),
        ];
        for (name, expected) in cases.iter() {
            assert_eq!(Window::from_str(name).expect("Operation failed"), *expected);
        }
        assert!(Window::from_str("invalid").is_err());
    }

    #[test]
    fn test_get_window() {
        let by_type = get_window(Window::Hann, 5, true).expect("Operation failed");
        let by_name = get_window("hann", 5, true).expect("Operation failed");
        assert_close(&by_type, &by_name.to_vec(), 1e-10);
    }

    #[test]
    fn test_apply_window() {
        let signal = Array1::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
        let windowed = apply_window(&signal.view(), Window::Hann).expect("Operation failed");
        assert_close(&windowed, &[0.0, 1.0, 3.0, 2.0, 0.0], 1e-10);
    }

    #[test]
    fn test_enbw() {
        let cases: [(Window, f64, f64); 3] = [
            (Window::Rectangular, 1.0, 1e-10),
            (Window::Hann, 1.5, 0.01),
            (Window::Hamming, 1.36, 0.01),
        ];
        for (window, expected, epsilon) in cases.iter() {
            let value = enbw(window.clone(), 1024).expect("Operation failed");
            assert_relative_eq!(value, *expected, epsilon = *epsilon);
        }
    }
}