Skip to main content

burn_tensor/tensor/signal/
hamming_window.rs

1use burn_backend::{
2    Backend,
3    tensor::{Float, Int},
4};
5
6use crate::{Tensor, TensorCreationOptions, check, check::TensorCheck};
7
8/// Creates a 1D Hamming window.
9///
10#[cfg_attr(
11    doc,
12    doc = r#"
13$$w_n = \alpha - \beta \cos\left(\frac{2\pi n}{N}\right)$$
14
15where $\alpha = 25/46$, $\beta = 1 - \alpha$, and $N$ = `size` when `periodic` is `true`, or $N$ = `size - 1` when `periodic` is `false`.
16"#
17)]
18#[cfg_attr(
19    not(doc),
20    doc = "`w_n = 25/46 - 21/46 * cos(2πn/N)` where N = size (periodic) or N = size-1 (symmetric)"
21)]
22///
23/// # Notes
24///
25/// - `size == 0` returns an empty tensor.
26/// - `size == 1` returns `[1.0]` regardless of `periodic`.
27///
28/// # Example
29///
30/// ```rust
31/// use burn_tensor::backend::Backend;
32/// use burn_tensor::signal::hamming_window;
33///
34/// fn example<B: Backend>() {
35///     let device = B::Device::default();
36///     let window = hamming_window::<B>(8, true, &device);
37///     println!("{window}");
38/// }
39/// ```
40pub fn hamming_window<B: Backend>(
41    size: usize,
42    periodic: bool,
43    options: impl Into<TensorCreationOptions<B>>,
44) -> Tensor<B, 1> {
45    let opt = options.into();
46    let dtype = opt.resolve_dtype::<Float>();
47    let shape = [size];
48    check!(TensorCheck::creation_ops::<1>("HammingWindow", &shape));
49
50    if size == 0 {
51        return Tensor::<B, 1>::empty(shape, opt).cast(dtype);
52    }
53
54    if size == 1 {
55        return Tensor::<B, 1>::ones(shape, opt).cast(dtype);
56    }
57
58    let size_i64 = i64::try_from(size).expect("HammingWindow size doesn't fit in i64 range.");
59    let denominator = if periodic { size } else { size - 1 };
60    let angular_increment = (2.0 * core::f64::consts::PI) / denominator as f64;
61
62    let alpha = 25.0_f64 / 46.0_f64;
63    let beta = 1.0 - alpha;
64
65    Tensor::<B, 1, Int>::arange(0..size_i64, &opt.device)
66        .float()
67        .mul_scalar(angular_increment)
68        .cos()
69        .mul_scalar(-beta)
70        .add_scalar(alpha)
71        .cast(dtype)
72}