#![allow(clippy::too_many_arguments)]
use crate::DType;
mod istft;
pub use istft::istft_impl;
use super::helpers::complex_magnitude_pow_impl;
use super::padding::pad_1d_reflect_impl;
use crate::signal::{stft_num_frames, validate_signal_dtype, validate_stft_params};
use crate::window::WindowFunctions;
use numr::algorithm::fft::{FftAlgorithms, FftNormalization};
use numr::error::{Error, Result};
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Short-time Fourier transform along the last axis of `signal`.
///
/// Frames of `n_fft` samples are taken every `hop` samples, multiplied by the
/// window, and transformed with `client.rfft`, giving `n_fft / 2 + 1`
/// frequency bins per frame.
///
/// # Arguments
/// * `client` - runtime client providing FFT, window and tensor operations.
/// * `signal` - input with at least one dimension; the last dimension is time.
///   Any leading dimensions are flattened into a single batch dimension.
/// * `n_fft` - frame length / FFT size.
/// * `hop_length` - frame step; `None` means `n_fft / 4`.
/// * `window` - optional window that must have shape `[n_fft]`; `None` uses
///   `client.hann_window(n_fft, ..)` created with the signal's dtype.
/// * `center` - when `true`, the signal is reflect-padded by `n_fft / 2`
///   samples on each side (via `pad_1d_reflect_impl`) before framing.
/// * `_normalized` - currently ignored: the FFT is always computed with
///   `FftNormalization::None`.
///
/// # Returns
/// `[n_frames, n_fft / 2 + 1]` for a 1-D signal, or
/// `[batch, n_frames, n_fft / 2 + 1]` for an N-D signal, where `batch` is the
/// product of the leading dimensions (so a `[1, L]` input yields
/// `[1, n_frames, n_fft / 2 + 1]`).
///
/// # Errors
/// Returns an error for anything rejected by `validate_signal_dtype` /
/// `validate_stft_params`, a 0-D signal, a window whose shape is not
/// `[n_fft]`, a signal too short to produce a frame, an empty batch
/// (a leading dimension of size 0), or any failure from the underlying
/// padding / tensor / FFT operations.
pub fn stft_impl<R, C>(
    client: &C,
    signal: &Tensor<R>,
    n_fft: usize,
    hop_length: Option<usize>,
    window: Option<&Tensor<R>>,
    center: bool,
    _normalized: bool,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: FftAlgorithms<R> + WindowFunctions<R> + TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let dtype = signal.dtype();
    validate_signal_dtype(dtype, "stft")?;
    let hop = hop_length.unwrap_or(n_fft / 4);
    validate_stft_params(n_fft, hop, "stft")?;

    let signal_contig = signal.contiguous()?;
    let ndim = signal_contig.ndim();
    if ndim == 0 {
        return Err(Error::InvalidArgument {
            arg: "signal",
            reason: "stft requires at least 1D signal".to_string(),
        });
    }
    let signal_len = signal_contig.shape()[ndim - 1];
    // The product over an empty slice is 1, so a 1-D signal has batch_size == 1.
    let batch_size: usize = signal_contig.shape()[..ndim - 1].iter().product();

    // `default_window` outlives the borrow taken in the `else` arm below.
    let default_window;
    let win = if let Some(w) = window {
        if w.shape() != [n_fft] {
            return Err(Error::InvalidArgument {
                arg: "window",
                reason: format!("window must have shape [{n_fft}], got {:?}", w.shape()),
            });
        }
        w
    } else {
        default_window = client.hann_window(n_fft, dtype, client.device())?;
        &default_window
    };

    let n_frames = stft_num_frames(signal_len, n_fft, hop, center);
    if n_frames == 0 {
        return Err(Error::InvalidArgument {
            arg: "signal",
            reason: format!("signal too short for STFT: length={signal_len}, n_fft={n_fft}"),
        });
    }
    // An empty batch would otherwise reach a division by zero in the batched path.
    if batch_size == 0 {
        return Err(Error::InvalidArgument {
            arg: "signal",
            reason: format!(
                "stft requires non-empty batch dimensions, got shape {:?}",
                signal_contig.shape()
            ),
        });
    }

    let padded_signal = if center {
        let half = n_fft / 2;
        pad_1d_reflect_impl(client, &signal_contig, half, half)?
    } else {
        signal_contig
    };
    let freq_bins = n_fft / 2 + 1;

    // Dispatch on rank rather than batch size. `stft_single` frames along
    // axis 0, so an N-D input whose leading dims multiply to 1 (e.g. [1, L])
    // must take the batched path, which flattens to [batch, time].
    if ndim == 1 {
        stft_single(client, &padded_signal, win, n_fft, hop, n_frames, freq_bins)
    } else {
        stft_batched(
            client,
            &padded_signal,
            win,
            n_fft,
            hop,
            n_frames,
            batch_size,
            freq_bins,
        )
    }
}
/// Computes the STFT of a single 1-D (already padded, if centering) signal.
///
/// Frame `k` starts at sample `k * hop`. A frame that would run past the end
/// of the signal keeps the remaining samples and is zero-padded on the right
/// to `n_fft`. A frame starting at or beyond the end is all zeros. Each frame
/// is multiplied by `window`, transformed with `client.rfft` (no
/// normalization) and reshaped to `[1, freq_bins]`. The rows are concatenated
/// along axis 0, giving `[n_frames, freq_bins]`.
///
/// Assumes `signal` is 1-D: its time length is read from `shape()[0]`.
fn stft_single<R, C>(
    client: &C,
    signal: &Tensor<R>,
    window: &Tensor<R>,
    n_fft: usize,
    hop: usize,
    n_frames: usize,
    freq_bins: usize,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: FftAlgorithms<R> + TensorOps<R> + RuntimeClient<R>,
{
    let total = signal.shape()[0];

    let rows = (0..n_frames)
        .map(|idx| -> Result<Tensor<R>> {
            let start = idx * hop;
            let remaining = total.saturating_sub(start);

            let frame = if start + n_fft <= total {
                // Fully inside the signal.
                signal.narrow(0, start, n_fft)?.contiguous()?
            } else if remaining == 0 {
                // Starts at or past the end: nothing left to take.
                Tensor::<R>::zeros(&[n_fft], signal.dtype(), client.device())
            } else {
                // Tail frame: keep what is left and zero-fill the rest.
                let head = signal.narrow(0, start, remaining)?.contiguous()?;
                client.pad(&head, &[0, n_fft - remaining], 0.0)?
            };

            let tapered = client.mul(&frame, window)?;
            let spectrum = client.rfft(&tapered, FftNormalization::None)?;
            let row = spectrum.reshape(&[1, freq_bins])?;
            Ok(row)
        })
        .collect::<Result<Vec<_>>>()?;

    let borrowed: Vec<&Tensor<R>> = rows.iter().collect();
    client.cat(&borrowed, 0)
}
fn stft_batched<R, C>(
client: &C,
signal: &Tensor<R>,
window: &Tensor<R>,
n_fft: usize,
hop: usize,
n_frames: usize,
batch_size: usize,
freq_bins: usize,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: FftAlgorithms<R> + TensorOps<R> + RuntimeClient<R>,
{
let signal_len = signal.numel() / batch_size;
let signal_2d = signal.reshape(&[batch_size, signal_len])?;
let window_2d = window.reshape(&[1, n_fft])?;
let mut frame_spectra: Vec<Tensor<R>> = Vec::with_capacity(n_frames);
for f in 0..n_frames {
let frame_start = f * hop;
let available = signal_len.saturating_sub(frame_start);
let frame_len = n_fft.min(available);
let frames = if frame_len == n_fft && frame_start + n_fft <= signal_len {
signal_2d.narrow(1, frame_start, n_fft)?.contiguous()?
} else {
if frame_len > 0 {
let partial = signal_2d.narrow(1, frame_start, frame_len)?.contiguous()?;
let pad_amount = n_fft - frame_len;
client.pad(&partial, &[0, pad_amount], 0.0)?
} else {
Tensor::<R>::zeros(&[batch_size, n_fft], signal.dtype(), client.device())
}
};
let windowed = client.mul(&frames, &window_2d)?;
let spectrum = client.rfft(&windowed, FftNormalization::None)?;
let spectrum_3d = spectrum.reshape(&[batch_size, 1, freq_bins])?;
frame_spectra.push(spectrum_3d);
}
let refs: Vec<&Tensor<R>> = frame_spectra.iter().collect();
client.cat(&refs, 1)
}
/// Spectrogram of `signal`: a centered, unnormalized STFT followed by
/// `complex_magnitude_pow_impl` with the given `power`.
///
/// The output dtype passed to `complex_magnitude_pow_impl` is the input
/// signal's dtype. `n_fft`, `hop_length` and `window` are forwarded unchanged
/// to `stft_impl`.
///
/// # Errors
/// Propagates any error from `stft_impl` or `complex_magnitude_pow_impl`.
pub fn spectrogram_impl<R, C>(
    client: &C,
    signal: &Tensor<R>,
    n_fft: usize,
    hop_length: Option<usize>,
    window: Option<&Tensor<R>>,
    power: f64,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: FftAlgorithms<R> + WindowFunctions<R> + TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    // Spectrograms always use centered frames and no FFT normalization.
    let (center, normalized) = (true, false);
    let spectrum = stft_impl(client, signal, n_fft, hop_length, window, center, normalized)?;
    complex_magnitude_pow_impl(client, &spectrum, power, signal.dtype())
}