Skip to main content

burn_tensor/tensor/signal/
stft.rs

1use alloc::vec;
2
3use burn_backend::Backend;
4
5use crate::Tensor;
6use crate::ops::PadMode;
7
8use super::{hermitian_extend, irfft, rfft};
9
/// Configuration shared by [`stft`] and [`istft`].
///
/// Construction never validates anything; the constraints listed on the fields are
/// checked (by panicking assertions) when the options are handed to [`stft`] or
/// [`istft`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StftOptions {
    /// Size of each FFT frame. Must be a power of two (which also implies `>= 1`).
    pub n_fft: usize,
    /// Stride between successive frames (must be >= 1 and <= effective window length
    /// so overlap-add can reconstruct the signal).
    pub hop_length: usize,
    /// Window length. If `Some(w)`, `w` must satisfy `1 <= w <= n_fft` and the window is
    /// center-padded to `n_fft`. `None` means "use `n_fft`".
    pub win_length: Option<usize>,
    /// If `true`, the signal is reflect-padded by `n_fft / 2` on both sides before framing.
    pub center: bool,
    /// If `true` (typical for real input), output has `n_fft/2 + 1` frequency bins; otherwise
    /// the full `n_fft` bins are returned.
    pub onesided: bool,
}

impl StftOptions {
    /// FFT size used by [`Default`]. It has to be a power of two, otherwise the default
    /// options would be rejected by `assert_valid`.
    const DEFAULT_N_FFT: usize = 512;

    /// Construct default options for the given FFT size, matching PyTorch defaults
    /// (`hop_length = n_fft / 4`, `win_length = None`, `center = true`, `onesided = true`).
    ///
    /// The hop length is clamped to at least 1 so that FFT sizes below 4 still produce
    /// a usable stride.
    pub fn new(n_fft: usize) -> Self {
        Self {
            n_fft,
            hop_length: (n_fft / 4).max(1),
            win_length: None,
            center: true,
            onesided: true,
        }
    }

    /// Window length actually in effect: the explicit `win_length`, or `n_fft` when unset.
    fn effective_win_length(&self) -> usize {
        self.win_length.unwrap_or(self.n_fft)
    }

    /// Panics with a message prefixed by `op` if any field constraint is violated.
    ///
    /// Checks, in order: `n_fft >= 1`, `n_fft` is a power of two, `hop_length >= 1`,
    /// effective window length in `1..=n_fft`, and `hop_length <= ` effective window length.
    fn assert_valid(&self, op: &'static str) {
        let n_fft = self.n_fft;
        let hop_length = self.hop_length;
        // Kept separate from the power-of-two check so `n_fft == 0` gets its own message.
        assert!(n_fft >= 1, "{op}: n_fft must be >= 1, got {n_fft}");
        assert!(
            n_fft.is_power_of_two(),
            "{op}: n_fft must be a power of two, got {n_fft}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
        assert!(
            hop_length >= 1,
            "{op}: hop_length must be >= 1, got {hop_length}"
        );
        let win_len = self.effective_win_length();
        assert!(
            win_len >= 1,
            "{op}: effective win_length must be >= 1, got {win_len}"
        );
        assert!(
            win_len <= n_fft,
            "{op}: win_length ({win_len}) must be <= n_fft ({n_fft})"
        );
        assert!(
            hop_length <= win_len,
            "{op}: hop_length ({hop_length}) must be <= effective win_length ({win_len}) \
             so the window/hop combination satisfies the COLA/NOLA condition required \
             for invertibility"
        );
    }
}

impl Default for StftOptions {
    /// Options for a 512-point FFT (`hop_length = 128`, no explicit window length,
    /// centered, one-sided).
    ///
    /// The FFT size must be a power of two for the options to pass validation, so a
    /// non-power-of-two default such as 400 would panic on first use.
    fn default() -> Self {
        Self::new(Self::DEFAULT_N_FFT)
    }
}
80
81/// Computes the Short-Time Fourier Transform (STFT).
82///
83/// Splits the signal into overlapping windowed frames and computes the FFT on each.
84///
85/// # Returns
86///
87/// A tensor of shape `[batch, n_frames, n_freqs, 2]` where the last dimension holds
88/// `[real, imaginary]`.
89pub fn stft<B: Backend>(
90    signal: Tensor<B, 2>,
91    window: Option<Tensor<B, 1>>,
92    options: StftOptions,
93) -> Tensor<B, 4> {
94    options.assert_valid("stft");
95    let n_fft = options.n_fft;
96    let hop_length = options.hop_length;
97    let center = options.center;
98    let onesided = options.onesided;
99    let win_len = options.effective_win_length();
100
101    let device = signal.device();
102
103    let window = match window {
104        Some(w) => {
105            assert_eq!(
106                w.dims()[0],
107                win_len,
108                "stft: window length ({}) must match effective win_length ({win_len})",
109                w.dims()[0],
110            );
111            w
112        }
113        None => Tensor::ones([win_len], &device),
114    };
115
116    let window = pad_window_to_n_fft(window, win_len, n_fft);
117
118    let [_, raw_sig_len] = signal.dims();
119    if center {
120        assert!(
121            raw_sig_len > n_fft / 2,
122            "stft: signal length ({raw_sig_len}) must be > n_fft/2 ({}) for reflect pad with center=true",
123            n_fft / 2
124        );
125    }
126
127    let signal = if center {
128        let pad_amount = n_fft / 2;
129        signal.pad([(0, 0), (pad_amount, pad_amount)], PadMode::Reflect)
130    } else {
131        signal
132    };
133
134    let [batch, sig_len] = signal.dims();
135    assert!(
136        sig_len >= n_fft,
137        "stft: signal length ({sig_len}) must be >= n_fft ({n_fft}) after padding"
138    );
139
140    let n_frames = 1 + (sig_len - n_fft) / hop_length;
141
142    // Extract overlapping frames: [batch, n_frames, n_fft]
143    let frames: Tensor<B, 3> = signal.unfold(1, n_fft, hop_length);
144
145    // Apply window (broadcast over batch and n_frames)
146    let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
147    let windowed = frames.mul(window_3d);
148
149    // Flatten to [batch * n_frames, n_fft] for rfft
150    let flat: Tensor<B, 2> = windowed.reshape([batch * n_frames, n_fft]);
151
152    // rfft returns n_fft/2 + 1 bins along dim=1 (n_fft is pow2).
153    let (re, im) = rfft(flat, 1, Some(n_fft));
154
155    let (re, im, n_freqs) = if onesided {
156        (re, im, n_fft / 2 + 1)
157    } else {
158        let (re_full, im_full) = hermitian_extend(re, im, 1, n_fft);
159        (re_full, im_full, n_fft)
160    };
161
162    // Reshape to [batch, n_frames, n_freqs] then stack real/imag
163    let re: Tensor<B, 3> = re.reshape([batch, n_frames, n_freqs]);
164    let im: Tensor<B, 3> = im.reshape([batch, n_frames, n_freqs]);
165
166    Tensor::stack::<4>(vec![re, im], 3)
167}
168
169/// Center-pad window from `win_len` to `n_fft` with zeros.
170fn pad_window_to_n_fft<B: Backend>(
171    window: Tensor<B, 1>,
172    win_len: usize,
173    n_fft: usize,
174) -> Tensor<B, 1> {
175    if win_len < n_fft {
176        let pad_left = (n_fft - win_len) / 2;
177        let pad_right = n_fft - win_len - pad_left;
178        window.pad([(pad_left, pad_right)], PadMode::Constant(0.0))
179    } else {
180        window
181    }
182}
183
184/// Computes the inverse Short-Time Fourier Transform (ISTFT).
185///
186/// Reconstructs a time-domain signal from its STFT representation using overlap-add.
187///
188/// # Arguments
189///
190/// * `stft_matrix` - Complex STFT tensor of shape `[batch, n_frames, n_freqs, 2]`.
191/// * `window` - Window tensor used in the forward STFT. Defaults to rectangular.
192/// * `length` - Optional output signal length. If `None`, the length is inferred.
193/// * `options` - STFT configuration (must match the forward STFT).
194///
195/// # Returns
196///
197/// A real-valued tensor of shape `[batch, signal_length]`.
198pub fn istft<B: Backend>(
199    stft_matrix: Tensor<B, 4>,
200    window: Option<Tensor<B, 1>>,
201    length: Option<usize>,
202    options: StftOptions,
203) -> Tensor<B, 2> {
204    options.assert_valid("istft");
205    let n_fft = options.n_fft;
206    let hop_length = options.hop_length;
207    let center = options.center;
208    let onesided = options.onesided;
209    let [batch, n_frames, n_freqs_in, two] = stft_matrix.dims();
210    assert_eq!(
211        two, 2,
212        "istft: last dimension of stft_matrix must be 2 (real, imag), got {two}"
213    );
214    assert!(
215        n_frames >= 1,
216        "istft: stft_matrix must contain at least one frame, got n_frames=0"
217    );
218    let expected_n_freqs = if onesided { n_fft / 2 + 1 } else { n_fft };
219    assert_eq!(
220        n_freqs_in, expected_n_freqs,
221        "istft: n_freqs dimension ({n_freqs_in}) does not match expected for \
222         n_fft={n_fft}, onesided={onesided} (expected {expected_n_freqs})"
223    );
224
225    let win_len = options.effective_win_length();
226    let device = stft_matrix.device();
227
228    let window = match window {
229        Some(w) => {
230            assert_eq!(
231                w.dims()[0],
232                win_len,
233                "istft: window length ({}) must match effective win_length ({win_len})",
234                w.dims()[0],
235            );
236            w
237        }
238        None => Tensor::ones([win_len], &device),
239    };
240    let window = pad_window_to_n_fft(window, win_len, n_fft);
241
242    // Split real and imaginary: each [batch, n_frames, n_freqs]
243    let re: Tensor<B, 3> = stft_matrix.clone().narrow(3, 0, 1).squeeze_dim(3);
244    let im: Tensor<B, 3> = stft_matrix.narrow(3, 1, 1).squeeze_dim(3);
245
246    let re: Tensor<B, 2> = re.reshape([batch * n_frames, n_freqs_in]);
247    let im: Tensor<B, 2> = im.reshape([batch * n_frames, n_freqs_in]);
248
249    // Extract onesided spectrum for irfft
250    let (re_half, im_half) = if onesided {
251        (re, im)
252    } else {
253        let half = n_fft / 2 + 1;
254        (re.narrow(1, 0, half), im.narrow(1, 0, half))
255    };
256
257    let frames = irfft(re_half, im_half, 1, Some(n_fft));
258
259    // Reshape to [batch, n_frames, n_fft]
260    let frames: Tensor<B, 3> = frames.reshape([batch, n_frames, n_fft]);
261
262    // Apply window to each frame
263    let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
264    let windowed = frames.mul(window_3d.clone());
265
266    let expected_len = n_fft + (n_frames - 1) * hop_length;
267    let mut output = Tensor::<B, 2>::zeros([batch, expected_len], &device);
268    let mut window_sum = Tensor::<B, 2>::zeros([batch, expected_len], &device);
269
270    let window_1d: Tensor<B, 1> = window_3d.clone().reshape([n_fft]);
271    let window_sq_1d: Tensor<B, 1> = window_1d.clone().mul(window_1d);
272
273    for f in 0..n_frames {
274        let start = f * hop_length;
275        let right_pad = expected_len - n_fft - start;
276        let frame: Tensor<B, 2> = windowed.clone().narrow(1, f, 1).squeeze_dim(1);
277        let frame_full = frame.pad([(0, 0), (start, right_pad)], PadMode::Constant(0.0));
278        output = output.add(frame_full);
279
280        let win_full: Tensor<B, 1> = window_sq_1d
281            .clone()
282            .pad([(start, right_pad)], PadMode::Constant(0.0));
283        let win_full: Tensor<B, 2> = win_full.unsqueeze_dim(0);
284        window_sum = window_sum.add(win_full);
285    }
286
287    let window_sum = window_sum.clamp_min(1e-10);
288    output = output.div(window_sum);
289
290    // Remove center padding
291    let output = if center {
292        let pad_amount = n_fft / 2;
293        let trimmed_len = expected_len - 2 * pad_amount;
294        output.narrow(1, pad_amount, trimmed_len)
295    } else {
296        output
297    };
298
299    match length {
300        Some(len) => {
301            let current_len = output.dims()[1];
302            if len <= current_len {
303                output.narrow(1, 0, len)
304            } else {
305                output.pad([(0, 0), (0, len - current_len)], PadMode::Constant(0.0))
306            }
307        }
308        None => output,
309    }
310}