use crate::DType;
use numr::error::Result;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
/// Short-time Fourier transform (STFT) and its inverse for a given runtime.
///
/// The trait is generic over a [`Runtime`] whose `DType` associated type is
/// this crate's [`DType`]. It only fixes the method signatures; tensor
/// shapes, dtypes, padding behaviour and default values are left to the
/// implementations.
pub trait StftAlgorithms<R>
where
    R: Runtime<DType = DType>,
{
    /// Computes the short-time Fourier transform of `signal`.
    ///
    /// # Arguments
    ///
    /// * `signal` - input tensor. NOTE(review): the accepted rank/dtype
    ///   (e.g. real 1-D vs. batched) is not specified here — confirm against
    ///   the implementations.
    /// * `n_fft` - FFT size; presumably also the length of each analysis
    ///   frame — confirm.
    /// * `hop_length` - step between successive frames; `None` lets the
    ///   implementation choose its default (not specified by this trait).
    /// * `window` - optional window tensor; presumably applied to every frame
    ///   and of length `n_fft` — confirm. `None` uses whatever the
    ///   implementation treats as its default window.
    /// * `center` - whether the signal is padded so that frames are centered;
    ///   the exact padding amount/mode is implementation-defined — confirm.
    /// * `normalized` - whether the result is normalized; the normalization
    ///   factor is implementation-defined — confirm.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation cannot compute the transform;
    /// the specific failure conditions are implementation-defined.
    fn stft(
        &self,
        signal: &Tensor<R>,
        n_fft: usize,
        hop_length: Option<usize>,
        window: Option<&Tensor<R>>,
        center: bool,
        normalized: bool,
    ) -> Result<Tensor<R>>;

    /// Computes the inverse short-time Fourier transform of `stft_matrix`.
    ///
    /// There is no `n_fft` parameter; the frame size is presumably inferred
    /// from the shape of `stft_matrix` — TODO confirm.
    ///
    /// # Arguments
    ///
    /// * `stft_matrix` - STFT coefficients, presumably laid out as produced
    ///   by [`StftAlgorithms::stft`] — confirm.
    /// * `hop_length` - step between frames; `None` selects the
    ///   implementation's default.
    /// * `window` - optional synthesis window; `None` selects the
    ///   implementation's default.
    /// * `center` - whether the forward transform was centered; presumably
    ///   used to strip the padding again — confirm.
    /// * `length` - optional requested output length; `None` lets the
    ///   implementation decide.
    /// * `normalized` - whether the coefficients were normalized in the
    ///   forward transform — confirm the expected pairing.
    ///
    /// For a faithful round trip, these arguments presumably need to match
    /// those passed to [`StftAlgorithms::stft`] — NOTE(review): confirm.
    ///
    /// # Errors
    ///
    /// Returns `Err` when the implementation cannot compute the inverse
    /// transform; the specific failure conditions are implementation-defined.
    fn istft(
        &self,
        stft_matrix: &Tensor<R>,
        hop_length: Option<usize>,
        window: Option<&Tensor<R>>,
        center: bool,
        length: Option<usize>,
        normalized: bool,
    ) -> Result<Tensor<R>>;
}
#[cfg(test)]
mod tests {
    /// Number of STFT frames produced for a signal of `signal_len` samples.
    ///
    /// With `center`, the signal is treated as padded by `n_fft` samples in
    /// total (presumably `n_fft / 2` on each side — TODO confirm against the
    /// implementation). Frames of `n_fft` samples are then taken every
    /// `hop_length` samples; any tail shorter than a hop is dropped by the
    /// integer division. A (padded) signal shorter than one frame yields zero
    /// frames.
    ///
    /// # Panics
    ///
    /// Panics on division by zero if `hop_length` is zero and at least one
    /// frame fits.
    pub(super) fn num_frames(
        signal_len: usize,
        n_fft: usize,
        hop_length: usize,
        center: bool,
    ) -> usize {
        let padded_len = if center { signal_len + n_fft } else { signal_len };
        // `checked_sub` covers the "shorter than one frame" case without a
        // separate comparison and without risking usize underflow.
        padded_len
            .checked_sub(n_fft)
            .map_or(0, |extra| extra / hop_length + 1)
    }

    #[test]
    fn test_stft_num_frames() {
        // 256 samples, n_fft = 256, hop = 64, centered:
        // padded length 512 -> (512 - 256) / 64 + 1 = 5 frames.
        assert_eq!(num_frames(256, 256, 64, true), 5);
    }

    #[test]
    fn test_stft_num_frames_not_centered() {
        // Exactly one frame fits when the signal is one frame long.
        assert_eq!(num_frames(256, 256, 64, false), 1);
        // (1000 - 256) / 64 + 1 = 11 + 1 = 12; the partial tail is dropped.
        assert_eq!(num_frames(1000, 256, 64, false), 12);
    }

    #[test]
    fn test_stft_num_frames_short_signal() {
        // An uncentered signal shorter than one frame produces no frames.
        assert_eq!(num_frames(100, 256, 64, false), 0);
        // Centering pads it to 356 samples: (356 - 256) / 64 + 1 = 2 frames.
        assert_eq!(num_frames(100, 256, 64, true), 2);
        // Empty signal: nothing uncentered, a single frame of padding centered.
        assert_eq!(num_frames(0, 256, 64, false), 0);
        assert_eq!(num_frames(0, 256, 64, true), 1);
    }
}