use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Bartlett window with `m` samples.
///
/// The samples are computed by `ferray_window::bartlett` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn bartlett(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::bartlett(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Blackman window with `m` samples.
///
/// The samples are computed by `ferray_window::blackman` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn blackman(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::blackman(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Hamming window with `m` samples.
///
/// The samples are computed by `ferray_window::hamming` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU. The result matches
/// [`general_hamming`] with `alpha = 0.54`.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn hamming(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::hamming(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Hann window with `m` samples.
///
/// The samples are computed by `ferray_window::hanning` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn hann(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::hanning(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
#[inline]
pub fn hanning(m: usize) -> FerrotorchResult<Tensor<f64>> {
hann(m)
}
/// Kaiser window with `m` samples and shape parameter `beta`.
///
/// The samples are computed by `ferray_window::kaiser` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU. A `beta` of `0.0`
/// yields an all-ones (rectangular) window.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn kaiser(m: usize, beta: f64) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::kaiser(m, beta)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Cosine window with `m` samples.
///
/// The samples are computed by `ferray_window::cosine` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn cosine(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::cosine(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Exponential window with `m` samples.
///
/// `center` and `tau` are forwarded to `ferray_window::exponential`. Passing
/// `None` for `center` gives a window that is symmetric about its midpoint.
/// A `tau` of `0.0` or a negative `tau` such as `-1.0` is rejected. The
/// result is a 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn exponential(m: usize, center: Option<f64>, tau: f64) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::exponential(m, center, tau)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Gaussian window with `m` samples and standard deviation `std`, in samples.
///
/// The samples are computed by `ferray_window::gaussian`. For example,
/// `gaussian(7, 1.0)` holds `exp(-0.5)` at index 4, one sample from the
/// centre. A `std` of `0.0` or a negative `std` is rejected. The result is a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn gaussian(m: usize, std: f64) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::gaussian(m, std)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Generalized cosine window with `m` samples, defined by `coeffs`.
///
/// The samples are computed by `ferray_window::general_cosine`. The
/// coefficients `[0.5, 0.5]` reproduce [`hann`], and `[0.42, 0.5, 0.08]`
/// reproduce [`blackman`]. An empty `coeffs` slice is rejected. The result is
/// a 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn general_cosine(m: usize, coeffs: &[f64]) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::general_cosine(m, coeffs)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Generalized Hamming window with `m` samples and weight `alpha`.
///
/// The samples are computed by `ferray_window::general_hamming`. An `alpha`
/// of `0.5` reproduces [`hann`], and `0.54` reproduces [`hamming`]. The
/// result is a 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn general_hamming(m: usize, alpha: f64) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::general_hamming(m, alpha)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Nuttall window with `m` samples.
///
/// The samples are computed by `ferray_window::nuttall` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn nuttall(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::nuttall(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Parzen window with `m` samples.
///
/// The samples are computed by `ferray_window::parzen` and returned as a
/// 1-D `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the request.
/// Also propagates any error raised by `Tensor::from_storage` while building
/// the tensor.
pub fn parzen(m: usize) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::parzen(m)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Taylor window with `m` samples.
///
/// `nbar`, `sll` and `norm` are forwarded unchanged to `ferray_window::taylor`.
/// NOTE(review): presumably `nbar` is the number of near-constant sidelobes
/// and `sll` the sidelobe level in dB; confirm against the `ferray_window`
/// docs. With `norm = true`, the centre sample of an odd-length window is
/// `1.0`. `nbar = 0` and a NaN `sll` are rejected. The result is a 1-D `f64`
/// tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn taylor(m: usize, nbar: usize, sll: f64, norm: bool) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::taylor(m, nbar, sll, norm)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Tukey (tapered-cosine) window with `m` samples and taper fraction `alpha`.
///
/// The samples are computed by `ferray_window::tukey`. An `alpha` of `0.0`
/// gives an all-ones window, and `1.0` reproduces [`hann`]. Values outside
/// `[0, 1]` (e.g. `-0.1`, `1.1`) and NaN are rejected. The result is a 1-D
/// `f64` tensor of shape `[m]` stored on the CPU.
///
/// # Errors
///
/// Returns [`FerrotorchError::Ferray`] if `ferray_window` rejects the
/// arguments. Also propagates any error raised by `Tensor::from_storage`
/// while building the tensor.
pub fn tukey(m: usize, alpha: f64) -> FerrotorchResult<Tensor<f64>> {
    ferray_window::tukey(m, alpha)
        .map_err(FerrotorchError::Ferray)
        .and_then(|window| array_to_tensor(window, m))
}
/// Moves the samples of a 1-D `ferray_core` array into a CPU `Tensor<f64>`
/// of shape `[m]`.
///
/// `m` is the window length the caller requested from `ferray_window` and is
/// used verbatim as the tensor shape. NOTE(review): this presumes
/// `Tensor::from_storage` reports a shape/data-length mismatch as an error;
/// confirm.
fn array_to_tensor(
    arr: ferray_core::Array<f64, ferray_core::Ix1>,
    m: usize,
) -> FerrotorchResult<Tensor<f64>> {
    let shape = vec![m];
    let samples = arr.into_iter().collect::<Vec<f64>>();
    let storage = TensorStorage::cpu(samples);
    // NOTE(review): the trailing `false` is passed straight to
    // `from_storage`; presumably it is the `requires_grad` flag. Confirm.
    Tensor::from_storage(storage, shape, false)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the window constructors: output shape and device,
    //! symmetry, known sample values, equivalences between related windows,
    //! and argument validation.

    use super::*;

    /// Absolute-tolerance float comparison: `|a - b| < tol`.
    fn close(a: f64, b: f64, tol: f64) -> bool {
        (a - b).abs() < tol
    }

    /// Asserts that `a` and `b` have the same length and agree element-wise
    /// within `tol`.
    ///
    /// The length check is explicit because `Iterator::zip` stops at the
    /// shorter input. A zip-only comparison would silently accept a truncated
    /// window.
    fn assert_all_close(a: &[f64], b: &[f64], tol: f64) {
        assert_eq!(a.len(), b.len(), "window lengths differ");
        for (i, (&x, &y)) in a.iter().zip(b).enumerate() {
            assert!(close(x, y, tol), "sample {i} differs: {x} vs {y}");
        }
    }

    #[test]
    fn bartlett_length() {
        let w = bartlett(8).unwrap();
        assert_eq!(w.shape(), &[8]);
        assert_eq!(w.device(), crate::Device::Cpu);
    }

    #[test]
    fn blackman_length() {
        let w = blackman(16).unwrap();
        assert_eq!(w.shape(), &[16]);
    }

    #[test]
    fn hann_length() {
        let w = hann(32).unwrap();
        assert_eq!(w.shape(), &[32]);
    }

    #[test]
    fn hamming_length() {
        let w = hamming(64).unwrap();
        assert_eq!(w.shape(), &[64]);
    }

    #[test]
    fn kaiser_length() {
        let w = kaiser(128, 14.0).unwrap();
        assert_eq!(w.shape(), &[128]);
    }

    #[test]
    fn bartlett_endpoints_are_zero() {
        let w = bartlett(8).unwrap();
        let d = w.data().unwrap();
        assert!(close(d[0], 0.0, 1e-12));
        assert!(close(d[d.len() - 1], 0.0, 1e-12));
    }

    #[test]
    fn bartlett_is_symmetric() {
        let w = bartlett(11).unwrap();
        let d = w.data().unwrap();
        let n = d.len();
        for i in 0..n {
            assert!(
                close(d[i], d[n - 1 - i], 1e-12),
                "bartlett not symmetric at {i}: {} vs {}",
                d[i],
                d[n - 1 - i]
            );
        }
    }

    #[test]
    fn hann_endpoints_are_zero() {
        let w = hann(16).unwrap();
        let d = w.data().unwrap();
        assert!(close(d[0], 0.0, 1e-12));
        assert!(close(d[d.len() - 1], 0.0, 1e-12));
    }

    #[test]
    fn hann_peak_is_one() {
        let w = hann(11).unwrap();
        let d = w.data().unwrap();
        assert!(close(d[d.len() / 2], 1.0, 1e-12));
    }

    #[test]
    fn hamming_endpoints_match_alpha_minus_beta() {
        // Hamming is general_hamming with alpha = 0.54 (beta = 1 - alpha = 0.46),
        // so both endpoints equal alpha - beta = 0.08 up to rounding. A loose
        // tolerance here would not actually pin that value.
        let w = hamming(8).unwrap();
        let d = w.data().unwrap();
        let last = d[d.len() - 1];
        assert!(close(d[0], 0.08, 1e-12), "first sample {}", d[0]);
        assert!(close(last, 0.08, 1e-12), "last sample {last}");
    }

    #[test]
    fn blackman_is_symmetric() {
        let w = blackman(15).unwrap();
        let d = w.data().unwrap();
        let n = d.len();
        for i in 0..n {
            assert!(close(d[i], d[n - 1 - i], 1e-12));
        }
    }

    #[test]
    fn kaiser_beta_zero_is_rectangular() {
        let w = kaiser(16, 0.0).unwrap();
        let d = w.data().unwrap();
        for &v in d {
            assert!(close(v, 1.0, 1e-12), "expected 1.0, got {v}");
        }
    }

    #[test]
    fn kaiser_peak_centre() {
        let w = kaiser(11, 8.6).unwrap();
        let d = w.data().unwrap();
        let mid = d[d.len() / 2];
        for (i, &v) in d.iter().enumerate() {
            assert!(
                v <= mid + 1e-12,
                "kaiser sample {i}={v} exceeds centre {mid}",
            );
        }
    }

    #[test]
    // Exact equality is intentional: `hanning` only forwards to `hann`, so the
    // samples must be bit-for-bit identical, not merely close.
    #[allow(clippy::float_cmp)]
    fn hanning_is_alias_for_hann() {
        let a = hann(13).unwrap();
        let b = hanning(13).unwrap();
        let ad = a.data().unwrap();
        let bd = b.data().unwrap();
        assert_eq!(ad.len(), bd.len());
        for (x, y) in ad.iter().zip(bd.iter()) {
            assert_eq!(x, y);
        }
    }

    #[test]
    fn output_lives_on_cpu() {
        for w in [
            bartlett(4).unwrap(),
            blackman(4).unwrap(),
            hamming(4).unwrap(),
            hann(4).unwrap(),
            hanning(4).unwrap(),
            kaiser(4, 5.0).unwrap(),
            cosine(4).unwrap(),
            exponential(4, None, 1.0).unwrap(),
            gaussian(4, 1.0).unwrap(),
            general_cosine(4, &[0.5, 0.5]).unwrap(),
            general_hamming(4, 0.54).unwrap(),
            nuttall(4).unwrap(),
            parzen(4).unwrap(),
            taylor(8, 4, 30.0, true).unwrap(),
            tukey(4, 0.5).unwrap(),
        ] {
            assert_eq!(w.device(), crate::Device::Cpu);
        }
    }

    #[test]
    fn cosine_length_and_symmetry() {
        let w = cosine(8).unwrap();
        assert_eq!(w.shape(), &[8]);
        let d = w.data().unwrap();
        for i in 0..4 {
            assert!(close(d[i], d[7 - i], 1e-14));
        }
    }

    #[test]
    fn exponential_default_centre_is_symmetric() {
        let w = exponential(8, None, 1.0).unwrap();
        let d = w.data().unwrap();
        for i in 0..4 {
            assert!(close(d[i], d[7 - i], 1e-14));
        }
    }

    #[test]
    fn exponential_rejects_invalid_tau() {
        assert!(exponential(8, None, 0.0).is_err());
        assert!(exponential(8, None, -1.0).is_err());
    }

    #[test]
    fn gaussian_centre_is_one_for_odd_m() {
        let w = gaussian(11, 2.0).unwrap();
        let d = w.data().unwrap();
        assert!(close(d[5], 1.0, 1e-14));
    }

    #[test]
    fn gaussian_known_value() {
        // Index 4 of a 7-sample window is one sample from the centre (index 3),
        // so with std = 1 the value is exp(-1^2 / 2).
        let w = gaussian(7, 1.0).unwrap();
        assert!(close(w.data().unwrap()[4], (-0.5_f64).exp(), 1e-14));
    }

    #[test]
    fn gaussian_rejects_nonpositive_std() {
        assert!(gaussian(8, 0.0).is_err());
        assert!(gaussian(8, -1.0).is_err());
    }

    #[test]
    fn general_cosine_with_hann_coeffs_matches_hann() {
        let m = 9;
        let gc = general_cosine(m, &[0.5, 0.5]).unwrap();
        let hn = hann(m).unwrap();
        assert_all_close(gc.data().unwrap(), hn.data().unwrap(), 1e-14);
    }

    #[test]
    fn general_cosine_with_blackman_coeffs_matches_blackman() {
        let m = 9;
        let gc = general_cosine(m, &[0.42, 0.5, 0.08]).unwrap();
        let bk = blackman(m).unwrap();
        assert_all_close(gc.data().unwrap(), bk.data().unwrap(), 1e-12);
    }

    #[test]
    fn general_cosine_rejects_empty_coeffs() {
        assert!(general_cosine(8, &[]).is_err());
    }

    #[test]
    fn general_hamming_alpha_half_matches_hann() {
        let m = 9;
        let gh = general_hamming(m, 0.5).unwrap();
        let hn = hann(m).unwrap();
        assert_all_close(gh.data().unwrap(), hn.data().unwrap(), 1e-14);
    }

    #[test]
    fn general_hamming_alpha_054_matches_hamming() {
        let m = 9;
        let gh = general_hamming(m, 0.54).unwrap();
        let hm = hamming(m).unwrap();
        assert_all_close(gh.data().unwrap(), hm.data().unwrap(), 1e-14);
    }

    #[test]
    fn nuttall_length_and_symmetry() {
        let m = 33;
        let w = nuttall(m).unwrap();
        assert_eq!(w.shape(), &[m]);
        let d = w.data().unwrap();
        for i in 0..m / 2 {
            assert!(close(d[i], d[m - 1 - i], 1e-14));
        }
    }

    #[test]
    fn nuttall_endpoints_are_small() {
        let w = nuttall(64).unwrap();
        let d = w.data().unwrap();
        assert!(d[0].abs() < 1e-2);
        assert!(d[d.len() - 1].abs() < 1e-2);
    }

    #[test]
    fn parzen_centre_is_one() {
        let w = parzen(13).unwrap();
        assert!(close(w.data().unwrap()[6], 1.0, 1e-14));
    }

    #[test]
    fn parzen_is_symmetric() {
        let m = 21;
        let w = parzen(m).unwrap();
        let d = w.data().unwrap();
        for i in 0..m / 2 {
            assert!(close(d[i], d[m - 1 - i], 1e-14));
        }
    }

    #[test]
    fn taylor_normalised_centre_is_one() {
        let w = taylor(33, 4, 30.0, true).unwrap();
        assert!(close(w.data().unwrap()[16], 1.0, 1e-12));
    }

    #[test]
    fn taylor_is_symmetric() {
        let m = 33;
        let w = taylor(m, 4, 30.0, true).unwrap();
        let d = w.data().unwrap();
        for i in 0..m / 2 {
            assert!(close(d[i], d[m - 1 - i], 1e-12));
        }
    }

    #[test]
    fn taylor_rejects_invalid_args() {
        assert!(taylor(8, 0, 30.0, true).is_err());
        assert!(taylor(8, 4, f64::NAN, true).is_err());
    }

    #[test]
    fn tukey_alpha_zero_is_rectangular() {
        let w = tukey(8, 0.0).unwrap();
        for &v in w.data().unwrap() {
            assert!(close(v, 1.0, 1e-14));
        }
    }

    #[test]
    fn tukey_alpha_one_matches_hann() {
        let m = 9;
        let tk = tukey(m, 1.0).unwrap();
        let hn = hann(m).unwrap();
        assert_all_close(tk.data().unwrap(), hn.data().unwrap(), 1e-12);
    }

    #[test]
    fn tukey_centre_is_one() {
        let m = 21;
        let w = tukey(m, 0.5).unwrap();
        assert!(close(w.data().unwrap()[m / 2], 1.0, 1e-14));
    }

    #[test]
    fn tukey_rejects_invalid_alpha() {
        assert!(tukey(8, -0.1).is_err());
        assert!(tukey(8, 1.1).is_err());
        assert!(tukey(8, f64::NAN).is_err());
    }
}