use burn_backend::{
Backend,
tensor::{Float, Int},
};
use crate::{Tensor, TensorCreationOptions, check, check::TensorCheck};
/// Creates a one-dimensional Blackman window with `size` samples.
///
/// # Arguments
///
/// * `size` - Number of samples. `0` produces an empty tensor; `1` produces a single `1.0`
///   whatever the value of `periodic`.
/// * `periodic` - Selects the denominator `N` used in the formula below: `size` when `true`,
///   `size - 1` when `false`.
/// * `options` - Tensor creation options (device and dtype); the output dtype is resolved as a
///   float dtype.
///
/// # Panics
///
/// Panics if `size` cannot be represented as an `i64`, and may panic if the creation-op shape
/// check rejects the shape `[size]`.
///
/// # Formula
///
/// For `n` in `0..size`:
#[cfg_attr(
    doc,
    doc = r#"
$$w_n = 0.42 - 0.5 \cos\left(\frac{2\pi n}{N}\right) + 0.08 \cos\left(\frac{4\pi n}{N}\right)$$
where $N$ = `size` when `periodic` is `true`, or $N$ = `size - 1` when `periodic` is `false`.
"#
)]
#[cfg_attr(
    not(doc),
    doc = "`w_n = 0.42 - 0.5 * cos(2πn / N) + 0.08 * cos(4πn / N)` where N = size (periodic) or N = size-1 (symmetric)"
)]
pub fn blackman_window<B: Backend>(
    size: usize,
    periodic: bool,
    options: impl Into<TensorCreationOptions<B>>,
) -> Tensor<B, 1> {
    let opt = options.into();
    let dtype = opt.resolve_dtype::<Float>();
    let shape = [size];
    check!(TensorCheck::creation_ops::<1>("BlackmanWindow", &shape));

    // Degenerate lengths are special-cased: a symmetric single-sample window would
    // otherwise divide by zero (N = size - 1 = 0).
    match size {
        0 => return Tensor::<B, 1>::empty(shape, opt).cast(dtype),
        1 => return Tensor::<B, 1>::ones(shape, opt).cast(dtype),
        _ => {}
    }

    let end = i64::try_from(size).expect("BlackmanWindow size doesn't fit in i64 range.");
    // Periodic windows divide by `size`; symmetric ones by `size - 1` (size >= 2 here).
    let period = size - usize::from(!periodic);
    // TAU is bit-identical to 2.0 * PI in f64.
    let step = core::f64::consts::TAU / period as f64;

    // With cos(2x) = 2cos^2(x) - 1, the 0.08*cos(4πn/N) term equals 0.16*cos^2(2πn/N) - 0.08,
    // so a single cosine suffices:
    //   w_n = 0.34 - 0.5*c + 0.16*c^2,  where c = cos(2πn/N).
    let c = Tensor::<B, 1, Int>::arange(0..end, &opt.device)
        .float()
        .mul_scalar(step)
        .cos();
    let linear = c.clone().mul_scalar(-0.5);
    let quadratic = c.powi_scalar(2).mul_scalar(0.16);
    let window = linear.add(quadratic).add_scalar(0.34);
    window.cast(dtype)
}