burn_tensor/tensor/signal/
hamming_window.rs1use burn_backend::{
2 Backend,
3 tensor::{Float, Int},
4};
5
6use crate::{Tensor, TensorCreationOptions, check, check::TensorCheck};
7
8#[cfg_attr(
11 doc,
12 doc = r#"
13$$w_n = \alpha - \beta \cos\left(\frac{2\pi n}{N}\right)$$
14
15where $\alpha = 25/46$, $\beta = 1 - \alpha$, and $N$ = `size` when `periodic` is `true`, or $N$ = `size - 1` when `periodic` is `false`.
16"#
17)]
18#[cfg_attr(
19 not(doc),
20 doc = "`w_n = 25/46 - 21/46 * cos(2πn/N)` where N = size (periodic) or N = size-1 (symmetric)"
21)]
22pub fn hamming_window<B: Backend>(
41 size: usize,
42 periodic: bool,
43 options: impl Into<TensorCreationOptions<B>>,
44) -> Tensor<B, 1> {
45 let opt = options.into();
46 let dtype = opt.resolve_dtype::<Float>();
47 let shape = [size];
48 check!(TensorCheck::creation_ops::<1>("HammingWindow", &shape));
49
50 if size == 0 {
51 return Tensor::<B, 1>::empty(shape, opt).cast(dtype);
52 }
53
54 if size == 1 {
55 return Tensor::<B, 1>::ones(shape, opt).cast(dtype);
56 }
57
58 let size_i64 = i64::try_from(size).expect("HammingWindow size doesn't fit in i64 range.");
59 let denominator = if periodic { size } else { size - 1 };
60 let angular_increment = (2.0 * core::f64::consts::PI) / denominator as f64;
61
62 let alpha = 25.0_f64 / 46.0_f64;
63 let beta = 1.0 - alpha;
64
65 Tensor::<B, 1, Int>::arange(0..size_i64, &opt.device)
66 .float()
67 .mul_scalar(angular_increment)
68 .cos()
69 .mul_scalar(-beta)
70 .add_scalar(alpha)
71 .cast(dtype)
72}