1use alloc::vec;
2
3use burn_backend::Backend;
4
5use crate::Tensor;
6use crate::ops::PadMode;
7
8use super::{hermitian_extend, irfft, rfft};
9
/// Configuration shared by [`stft`] and [`istft`].
///
/// The values are not validated on construction; `stft` and `istft` validate
/// them (and panic on invalid combinations) when they are called.
///
/// All fields are integral or boolean, so the type is `Eq` and `Hash` in
/// addition to `PartialEq`, which allows it to be used as a map or cache key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct StftOptions {
    /// Transform size, i.e. the number of samples in every analysis frame.
    /// Must be a power of two to pass validation.
    pub n_fft: usize,
    /// Number of samples between the starts of two consecutive frames.
    /// Must be at least 1 and no larger than the effective window length.
    pub hop_length: usize,
    /// Length of the analysis window. `None` means "use `n_fft`". A window
    /// shorter than `n_fft` is zero-padded on both sides up to `n_fft`.
    pub win_length: Option<usize>,
    /// When `true`, `stft` reflect-pads the signal by `n_fft / 2` samples on
    /// both sides before framing, and `istft` trims that padding back off.
    pub center: bool,
    /// When `true`, only the `n_fft / 2 + 1` non-redundant frequency bins are
    /// produced/expected; otherwise all `n_fft` bins are.
    pub onesided: bool,
}
26
27impl StftOptions {
28 pub fn new(n_fft: usize) -> Self {
31 Self {
32 n_fft,
33 hop_length: (n_fft / 4).max(1),
34 win_length: None,
35 center: true,
36 onesided: true,
37 }
38 }
39
40 fn effective_win_length(&self) -> usize {
41 self.win_length.unwrap_or(self.n_fft)
42 }
43
44 fn assert_valid(&self, op: &'static str) {
45 let n_fft = self.n_fft;
46 let hop_length = self.hop_length;
47 assert!(n_fft >= 1, "{op}: n_fft must be >= 1, got {n_fft}");
48 assert!(
49 n_fft.is_power_of_two(),
50 "{op}: n_fft must be a power of two, got {n_fft}. True non-power-of-two \
51 DFT support is tracked as a follow-up (Bluestein's algorithm)."
52 );
53 assert!(
54 hop_length >= 1,
55 "{op}: hop_length must be >= 1, got {hop_length}"
56 );
57 let win_len = self.effective_win_length();
58 assert!(
59 win_len >= 1,
60 "{op}: effective win_length must be >= 1, got {win_len}"
61 );
62 assert!(
63 win_len <= n_fft,
64 "{op}: win_length ({win_len}) must be <= n_fft ({n_fft})"
65 );
66 assert!(
67 hop_length <= win_len,
68 "{op}: hop_length ({hop_length}) must be <= effective win_length ({win_len}) \
69 so the window/hop combination satisfies the COLA/NOLA condition required \
70 for invertibility"
71 );
72 }
73}
74
75impl Default for StftOptions {
76 fn default() -> Self {
77 Self::new(400)
78 }
79}
80
81pub fn stft<B: Backend>(
90 signal: Tensor<B, 2>,
91 window: Option<Tensor<B, 1>>,
92 options: StftOptions,
93) -> Tensor<B, 4> {
94 options.assert_valid("stft");
95 let n_fft = options.n_fft;
96 let hop_length = options.hop_length;
97 let center = options.center;
98 let onesided = options.onesided;
99 let win_len = options.effective_win_length();
100
101 let device = signal.device();
102
103 let window = match window {
104 Some(w) => {
105 assert_eq!(
106 w.dims()[0],
107 win_len,
108 "stft: window length ({}) must match effective win_length ({win_len})",
109 w.dims()[0],
110 );
111 w
112 }
113 None => Tensor::ones([win_len], &device),
114 };
115
116 let window = pad_window_to_n_fft(window, win_len, n_fft);
117
118 let [_, raw_sig_len] = signal.dims();
119 if center {
120 assert!(
121 raw_sig_len > n_fft / 2,
122 "stft: signal length ({raw_sig_len}) must be > n_fft/2 ({}) for reflect pad with center=true",
123 n_fft / 2
124 );
125 }
126
127 let signal = if center {
128 let pad_amount = n_fft / 2;
129 signal.pad([(0, 0), (pad_amount, pad_amount)], PadMode::Reflect)
130 } else {
131 signal
132 };
133
134 let [batch, sig_len] = signal.dims();
135 assert!(
136 sig_len >= n_fft,
137 "stft: signal length ({sig_len}) must be >= n_fft ({n_fft}) after padding"
138 );
139
140 let n_frames = 1 + (sig_len - n_fft) / hop_length;
141
142 let frames: Tensor<B, 3> = signal.unfold(1, n_fft, hop_length);
144
145 let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
147 let windowed = frames.mul(window_3d);
148
149 let flat: Tensor<B, 2> = windowed.reshape([batch * n_frames, n_fft]);
151
152 let (re, im) = rfft(flat, 1, Some(n_fft));
154
155 let (re, im, n_freqs) = if onesided {
156 (re, im, n_fft / 2 + 1)
157 } else {
158 let (re_full, im_full) = hermitian_extend(re, im, 1, n_fft);
159 (re_full, im_full, n_fft)
160 };
161
162 let re: Tensor<B, 3> = re.reshape([batch, n_frames, n_freqs]);
164 let im: Tensor<B, 3> = im.reshape([batch, n_frames, n_freqs]);
165
166 Tensor::stack::<4>(vec![re, im], 3)
167}
168
169fn pad_window_to_n_fft<B: Backend>(
171 window: Tensor<B, 1>,
172 win_len: usize,
173 n_fft: usize,
174) -> Tensor<B, 1> {
175 if win_len < n_fft {
176 let pad_left = (n_fft - win_len) / 2;
177 let pad_right = n_fft - win_len - pad_left;
178 window.pad([(pad_left, pad_right)], PadMode::Constant(0.0))
179 } else {
180 window
181 }
182}
183
184pub fn istft<B: Backend>(
199 stft_matrix: Tensor<B, 4>,
200 window: Option<Tensor<B, 1>>,
201 length: Option<usize>,
202 options: StftOptions,
203) -> Tensor<B, 2> {
204 options.assert_valid("istft");
205 let n_fft = options.n_fft;
206 let hop_length = options.hop_length;
207 let center = options.center;
208 let onesided = options.onesided;
209 let [batch, n_frames, n_freqs_in, two] = stft_matrix.dims();
210 assert_eq!(
211 two, 2,
212 "istft: last dimension of stft_matrix must be 2 (real, imag), got {two}"
213 );
214 assert!(
215 n_frames >= 1,
216 "istft: stft_matrix must contain at least one frame, got n_frames=0"
217 );
218 let expected_n_freqs = if onesided { n_fft / 2 + 1 } else { n_fft };
219 assert_eq!(
220 n_freqs_in, expected_n_freqs,
221 "istft: n_freqs dimension ({n_freqs_in}) does not match expected for \
222 n_fft={n_fft}, onesided={onesided} (expected {expected_n_freqs})"
223 );
224
225 let win_len = options.effective_win_length();
226 let device = stft_matrix.device();
227
228 let window = match window {
229 Some(w) => {
230 assert_eq!(
231 w.dims()[0],
232 win_len,
233 "istft: window length ({}) must match effective win_length ({win_len})",
234 w.dims()[0],
235 );
236 w
237 }
238 None => Tensor::ones([win_len], &device),
239 };
240 let window = pad_window_to_n_fft(window, win_len, n_fft);
241
242 let re: Tensor<B, 3> = stft_matrix.clone().narrow(3, 0, 1).squeeze_dim(3);
244 let im: Tensor<B, 3> = stft_matrix.narrow(3, 1, 1).squeeze_dim(3);
245
246 let re: Tensor<B, 2> = re.reshape([batch * n_frames, n_freqs_in]);
247 let im: Tensor<B, 2> = im.reshape([batch * n_frames, n_freqs_in]);
248
249 let (re_half, im_half) = if onesided {
251 (re, im)
252 } else {
253 let half = n_fft / 2 + 1;
254 (re.narrow(1, 0, half), im.narrow(1, 0, half))
255 };
256
257 let frames = irfft(re_half, im_half, 1, Some(n_fft));
258
259 let frames: Tensor<B, 3> = frames.reshape([batch, n_frames, n_fft]);
261
262 let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
264 let windowed = frames.mul(window_3d.clone());
265
266 let expected_len = n_fft + (n_frames - 1) * hop_length;
267 let mut output = Tensor::<B, 2>::zeros([batch, expected_len], &device);
268 let mut window_sum = Tensor::<B, 2>::zeros([batch, expected_len], &device);
269
270 let window_1d: Tensor<B, 1> = window_3d.clone().reshape([n_fft]);
271 let window_sq_1d: Tensor<B, 1> = window_1d.clone().mul(window_1d);
272
273 for f in 0..n_frames {
274 let start = f * hop_length;
275 let right_pad = expected_len - n_fft - start;
276 let frame: Tensor<B, 2> = windowed.clone().narrow(1, f, 1).squeeze_dim(1);
277 let frame_full = frame.pad([(0, 0), (start, right_pad)], PadMode::Constant(0.0));
278 output = output.add(frame_full);
279
280 let win_full: Tensor<B, 1> = window_sq_1d
281 .clone()
282 .pad([(start, right_pad)], PadMode::Constant(0.0));
283 let win_full: Tensor<B, 2> = win_full.unsqueeze_dim(0);
284 window_sum = window_sum.add(win_full);
285 }
286
287 let window_sum = window_sum.clamp_min(1e-10);
288 output = output.div(window_sum);
289
290 let output = if center {
292 let pad_amount = n_fft / 2;
293 let trimmed_len = expected_len - 2 * pad_amount;
294 output.narrow(1, pad_amount, trimmed_len)
295 } else {
296 output
297 };
298
299 match length {
300 Some(len) => {
301 let current_len = output.dims()[1];
302 if len <= current_len {
303 output.narrow(1, 0, len)
304 } else {
305 output.pad([(0, 0), (0, len - current_len)], PadMode::Constant(0.0))
306 }
307 }
308 None => output,
309 }
310}