use alloc::vec;
use burn_backend::Backend;
use crate::Tensor;
use crate::ops::PadMode;
use super::{irfft, rfft};
/// Options shared by [`stft`] and [`istft`].
///
/// The values are checked when a transform is invoked, not when the struct is built.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct StftOptions {
/// FFT size, which is also the length of every analysis frame. The transforms
/// require it to be a power of two.
pub n_fft: usize,
/// Number of samples between the starts of consecutive frames.
pub hop_length: usize,
/// Length of the window before it is zero-padded to `n_fft`; `None` means `n_fft`.
pub win_length: Option<usize>,
/// When `true`, `stft` reflect-pads the signal by `n_fft / 2` on both sides and
/// `istft` trims `n_fft / 2` samples from both ends of the reconstruction.
pub center: bool,
/// When `true`, the frequency axis holds `n_fft / 2 + 1` bins; otherwise it holds
/// all `n_fft` bins.
pub onesided: bool,
}
impl StftOptions {
    /// Builds options for an FFT of size `n_fft`, with every other field at its default.
    ///
    /// `hop_length` becomes a quarter of `n_fft` (but never less than 1), `win_length`
    /// is left unset so it falls back to `n_fft`, and both `center` and `onesided`
    /// are enabled. No validation happens here.
    pub fn new(n_fft: usize) -> Self {
        let hop_length = usize::max(n_fft / 4, 1);
        Self {
            n_fft,
            hop_length,
            win_length: None,
            center: true,
            onesided: true,
        }
    }

    /// Window length actually in effect: the explicit `win_length` if set, else `n_fft`.
    fn effective_win_length(&self) -> usize {
        match self.win_length {
            Some(explicit) => explicit,
            None => self.n_fft,
        }
    }

    /// Panics unless the options describe a transform this module can compute.
    ///
    /// `op` names the calling operation and is used as the prefix of each panic message.
    ///
    /// # Panics
    ///
    /// In this order, when:
    /// - `n_fft` is zero,
    /// - `n_fft` is not a power of two,
    /// - `hop_length` is zero,
    /// - the effective window length is zero,
    /// - the effective window length exceeds `n_fft`,
    /// - `hop_length` exceeds the effective window length.
    fn assert_valid(&self, op: &'static str) {
        let Self {
            n_fft, hop_length, ..
        } = *self;
        let win_len = self.effective_win_length();

        assert!(n_fft != 0, "{op}: n_fft must be >= 1, got {n_fft}");
        assert!(
            n_fft.is_power_of_two(),
            "{op}: n_fft must be a power of two, got {n_fft}. True non-power-of-two \
             DFT support is tracked as a follow-up (Bluestein's algorithm)."
        );
        assert!(
            hop_length > 0,
            "{op}: hop_length must be >= 1, got {hop_length}"
        );
        assert!(
            win_len > 0,
            "{op}: effective win_length must be >= 1, got {win_len}"
        );
        assert!(
            n_fft >= win_len,
            "{op}: win_length ({win_len}) must be <= n_fft ({n_fft})"
        );
        // A hop larger than the window would leave samples that no frame covers.
        assert!(
            win_len >= hop_length,
            "{op}: hop_length ({hop_length}) must be <= effective win_length ({win_len}) \
             so the window/hop combination satisfies the COLA/NOLA condition required \
             for invertibility"
        );
    }
}
/// Default options: the same as [`StftOptions::new`] called with `n_fft = 400`.
///
/// NOTE(review): 400 is not a power of two, so `assert_valid` rejects these defaults
/// unless `n_fft` is overridden. Presumably this is kept pending non-power-of-two FFT
/// support; confirm that is intended.
impl Default for StftOptions {
    fn default() -> Self {
        const DEFAULT_N_FFT: usize = 400;
        Self::new(DEFAULT_N_FFT)
    }
}
pub fn stft<B: Backend>(
signal: Tensor<B, 2>,
window: Option<Tensor<B, 1>>,
options: StftOptions,
) -> Tensor<B, 4> {
options.assert_valid("stft");
let n_fft = options.n_fft;
let hop_length = options.hop_length;
let center = options.center;
let onesided = options.onesided;
let win_len = options.effective_win_length();
let device = signal.device();
let window = match window {
Some(w) => {
assert_eq!(
w.dims()[0],
win_len,
"stft: window length ({}) must match effective win_length ({win_len})",
w.dims()[0],
);
w
}
None => Tensor::ones([win_len], &device),
};
let window = pad_window_to_n_fft(window, win_len, n_fft);
let [_, raw_sig_len] = signal.dims();
if center {
assert!(
raw_sig_len > n_fft / 2,
"stft: signal length ({raw_sig_len}) must be > n_fft/2 ({}) for reflect pad with center=true",
n_fft / 2
);
}
let signal = if center {
let pad_amount = n_fft / 2;
signal.pad([(0, 0), (pad_amount, pad_amount)], PadMode::Reflect)
} else {
signal
};
let [batch, sig_len] = signal.dims();
assert!(
sig_len >= n_fft,
"stft: signal length ({sig_len}) must be >= n_fft ({n_fft}) after padding"
);
let n_frames = 1 + (sig_len - n_fft) / hop_length;
let frames: Tensor<B, 3> = signal.unfold(1, n_fft, hop_length);
let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
let windowed = frames.mul(window_3d);
let flat: Tensor<B, 2> = windowed.reshape([batch * n_frames, n_fft]);
let (re, im) = rfft(flat, 1, Some(n_fft));
let (re, im, n_freqs) = if onesided {
(re, im, n_fft / 2 + 1)
} else {
let (re_full, im_full) = reconstruct_full_spectrum(re, im, n_fft);
(re_full, im_full, n_fft)
};
let re: Tensor<B, 3> = re.reshape([batch, n_frames, n_freqs]);
let im: Tensor<B, 3> = im.reshape([batch, n_frames, n_freqs]);
Tensor::stack::<4>(vec![re, im], 3)
}
/// Expands a one-sided spectrum along axis 1 into the full `n_fft`-bin spectrum.
///
/// The missing bins are the complex conjugates of bins `1..=n_neg` in reverse order,
/// where `n_neg = n_fft - (n_fft / 2 + 1)`: the real parts are mirrored and the
/// imaginary parts are mirrored and negated. When there is nothing to append
/// (`n_fft <= 2`), the inputs are returned unchanged.
fn reconstruct_full_spectrum<B: Backend>(
    re: Tensor<B, 2>,
    im: Tensor<B, 2>,
    n_fft: usize,
) -> (Tensor<B, 2>, Tensor<B, 2>) {
    // Saturating so that n_fft == 0 yields "nothing to append" instead of underflowing.
    let n_neg = n_fft.saturating_sub(n_fft / 2 + 1);
    if n_neg == 0 {
        return (re, im);
    }

    // Bins 1..=n_neg, reversed along the frequency axis.
    let mirrored = |half: &Tensor<B, 2>| half.clone().narrow(1, 1, n_neg).flip([1]);
    let re_tail = mirrored(&re);
    let im_tail = mirrored(&im).neg();

    (
        Tensor::cat(vec![re, re_tail], 1),
        Tensor::cat(vec![im, im_tail], 1),
    )
}
/// Zero-pads `window` on both sides so that its length becomes `n_fft`.
///
/// The padding is split evenly; when the deficit is odd, the extra zero goes on the
/// right. A window that is already at least `n_fft` long is returned untouched.
fn pad_window_to_n_fft<B: Backend>(
    window: Tensor<B, 1>,
    win_len: usize,
    n_fft: usize,
) -> Tensor<B, 1> {
    if win_len >= n_fft {
        return window;
    }
    let deficit = n_fft - win_len;
    let left = deficit / 2;
    let right = deficit - left;
    window.pad([(left, right)], PadMode::Constant(0.0))
}
pub fn istft<B: Backend>(
stft_matrix: Tensor<B, 4>,
window: Option<Tensor<B, 1>>,
length: Option<usize>,
options: StftOptions,
) -> Tensor<B, 2> {
options.assert_valid("istft");
let n_fft = options.n_fft;
let hop_length = options.hop_length;
let center = options.center;
let onesided = options.onesided;
let [batch, n_frames, n_freqs_in, two] = stft_matrix.dims();
assert_eq!(
two, 2,
"istft: last dimension of stft_matrix must be 2 (real, imag), got {two}"
);
assert!(
n_frames >= 1,
"istft: stft_matrix must contain at least one frame, got n_frames=0"
);
let expected_n_freqs = if onesided { n_fft / 2 + 1 } else { n_fft };
assert_eq!(
n_freqs_in, expected_n_freqs,
"istft: n_freqs dimension ({n_freqs_in}) does not match expected for \
n_fft={n_fft}, onesided={onesided} (expected {expected_n_freqs})"
);
let win_len = options.effective_win_length();
let device = stft_matrix.device();
let window = match window {
Some(w) => {
assert_eq!(
w.dims()[0],
win_len,
"istft: window length ({}) must match effective win_length ({win_len})",
w.dims()[0],
);
w
}
None => Tensor::ones([win_len], &device),
};
let window = pad_window_to_n_fft(window, win_len, n_fft);
let re: Tensor<B, 3> = stft_matrix.clone().narrow(3, 0, 1).squeeze_dim(3);
let im: Tensor<B, 3> = stft_matrix.narrow(3, 1, 1).squeeze_dim(3);
let re: Tensor<B, 2> = re.reshape([batch * n_frames, n_freqs_in]);
let im: Tensor<B, 2> = im.reshape([batch * n_frames, n_freqs_in]);
let (re_half, im_half) = if onesided {
(re, im)
} else {
let half = n_fft / 2 + 1;
(re.narrow(1, 0, half), im.narrow(1, 0, half))
};
let frames = irfft(re_half, im_half, 1, Some(n_fft));
let frames: Tensor<B, 3> = frames.reshape([batch, n_frames, n_fft]);
let window_3d: Tensor<B, 3> = window.reshape([1, 1, n_fft]);
let windowed = frames.mul(window_3d.clone());
let expected_len = n_fft + (n_frames - 1) * hop_length;
let mut output = Tensor::<B, 2>::zeros([batch, expected_len], &device);
let mut window_sum = Tensor::<B, 2>::zeros([batch, expected_len], &device);
let window_1d: Tensor<B, 1> = window_3d.clone().reshape([n_fft]);
let window_sq_1d: Tensor<B, 1> = window_1d.clone().mul(window_1d);
for f in 0..n_frames {
let start = f * hop_length;
let right_pad = expected_len - n_fft - start;
let frame: Tensor<B, 2> = windowed.clone().narrow(1, f, 1).squeeze_dim(1);
let frame_full = frame.pad([(0, 0), (start, right_pad)], PadMode::Constant(0.0));
output = output.add(frame_full);
let win_full: Tensor<B, 1> = window_sq_1d
.clone()
.pad([(start, right_pad)], PadMode::Constant(0.0));
let win_full: Tensor<B, 2> = win_full.unsqueeze_dim(0);
window_sum = window_sum.add(win_full);
}
let window_sum = window_sum.clamp_min(1e-10);
output = output.div(window_sum);
let output = if center {
let pad_amount = n_fft / 2;
let trimmed_len = expected_len - 2 * pad_amount;
output.narrow(1, pad_amount, trimmed_len)
} else {
output
};
match length {
Some(len) => {
let current_len = output.dims()[1];
if len <= current_len {
output.narrow(1, 0, len)
} else {
output.pad([(0, 0), (0, len - current_len)], PadMode::Constant(0.0))
}
}
None => output,
}
}