#![allow(clippy::too_many_arguments)]
use crate::signal::validate_stft_params;
use crate::window::WindowFunctions;
use numr::algorithm::fft::{FftAlgorithms, FftNormalization};
use numr::dtype::DType;
use numr::error::{Error, Result};
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
pub fn istft_impl<R, C>(
client: &C,
stft_matrix: &Tensor<R>,
hop_length: Option<usize>,
window: Option<&Tensor<R>>,
center: bool,
length: Option<usize>,
normalized: bool,
) -> Result<Tensor<R>>
where
R: Runtime<DType = DType>,
C: FftAlgorithms<R> + WindowFunctions<R> + TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
let dtype = stft_matrix.dtype();
if !dtype.is_complex() {
return Err(Error::UnsupportedDType { dtype, op: "istft" });
}
let real_dtype = match dtype {
DType::Complex64 => DType::F32,
DType::Complex128 => DType::F64,
_ => unreachable!(),
};
let stft_contig = stft_matrix.contiguous()?;
let ndim = stft_contig.ndim();
if ndim < 2 {
return Err(Error::InvalidArgument {
arg: "stft_matrix",
reason: "istft requires at least 2D input [n_frames, freq_bins]".to_string(),
});
}
let n_frames = stft_contig.shape()[ndim - 2];
let freq_bins = stft_contig.shape()[ndim - 1];
let n_fft = (freq_bins - 1) * 2;
let hop = hop_length.unwrap_or(n_fft / 4);
validate_stft_params(n_fft, hop, "istft")?;
let default_window;
let win = if let Some(w) = window {
if w.shape() != [n_fft] {
return Err(Error::InvalidArgument {
arg: "window",
reason: format!("window must have shape [{n_fft}], got {:?}", w.shape()),
});
}
w
} else {
default_window = client.hann_window(n_fft, real_dtype, client.device())?;
&default_window
};
let batch_size: usize = if ndim > 2 {
stft_contig.shape()[..ndim - 2].iter().product()
} else {
1
};
let full_len = n_fft + (n_frames - 1) * hop;
let pad_left = if center { n_fft / 2 } else { 0 };
let output_len = if center {
full_len - n_fft } else {
full_len
};
let final_len = length.unwrap_or(output_len);
let norm = if normalized {
FftNormalization::Ortho
} else {
FftNormalization::Backward
};
let window_sq = client.mul(win, win)?;
if batch_size == 1 {
istft_single(
client,
&stft_contig,
win,
&window_sq,
real_dtype,
n_fft,
hop,
n_frames,
freq_bins,
full_len,
pad_left,
final_len,
norm,
)
} else {
istft_batched(
client,
&stft_contig,
win,
&window_sq,
real_dtype,
n_fft,
hop,
n_frames,
batch_size,
freq_bins,
full_len,
pad_left,
final_len,
norm,
)
}
}
/// Overlap-add reconstruction for a single (unbatched) STFT matrix.
///
/// Each frame is inverse-transformed with `irfft`, multiplied by `window`,
/// zero-padded to `full_len` at offset `f * hop`, and accumulated. The
/// accumulated signal is divided by the accumulated `window_sq` envelope
/// (clamped from below at `1e-8`). Finally `final_len` samples are taken
/// starting at `pad_left`.
///
/// # Arguments
///
/// * `stft_matrix` - Complex tensor holding `n_frames * freq_bins` elements; it
///   is viewed as `[n_frames, freq_bins]`.
/// * `window` / `window_sq` - Synthesis window and its element-wise square.
/// * `real_dtype` - Dtype of the accumulation buffers and the clamp tensor.
/// * `n_fft`, `hop`, `n_frames`, `freq_bins` - STFT geometry.
/// * `full_len` - Length of the untrimmed overlap-add buffer.
/// * `pad_left` - Number of leading samples to skip.
/// * `final_len` - Exact number of samples to return.
/// * `norm` - Normalisation mode passed to `irfft`.
///
/// # Returns
///
/// A 1-D tensor of exactly `final_len` samples. If fewer samples are available
/// after skipping `pad_left`, the tail is zero-padded.
///
/// # Errors
///
/// Propagates any error from the underlying tensor and FFT operations.
fn istft_single<R, C>(
    client: &C,
    stft_matrix: &Tensor<R>,
    window: &Tensor<R>,
    window_sq: &Tensor<R>,
    real_dtype: DType,
    n_fft: usize,
    hop: usize,
    n_frames: usize,
    freq_bins: usize,
    full_len: usize,
    pad_left: usize,
    final_len: usize,
    norm: FftNormalization,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: FftAlgorithms<R> + TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let device = client.device();

    // View the input as exactly [n_frames, freq_bins] so that axis 0 is always
    // the frame axis, even if the caller passed leading unit dimensions.
    let frames_2d = stft_matrix.reshape(&[n_frames, freq_bins])?;

    let mut output = Tensor::<R>::zeros(&[full_len], real_dtype, device);
    let mut window_sum = Tensor::<R>::zeros(&[full_len], real_dtype, device);

    for f in 0..n_frames {
        let spectrum = frames_2d
            .narrow(0, f, 1)?
            .contiguous()?
            .reshape(&[freq_bins])?;
        let frame = client.irfft(&spectrum, Some(n_fft), norm)?;
        let windowed_frame = client.mul(&frame, window)?;

        // Place the frame at its hop offset by zero-padding both sides up to
        // `full_len`, then accumulate.
        let frame_start = f * hop;
        let right_pad = full_len.saturating_sub(frame_start + n_fft);
        let padded_frame = client.pad(&windowed_frame, &[frame_start, right_pad], 0.0)?;
        output = client.add(&output, &padded_frame)?;

        // Accumulate the squared window at the same offset for normalisation.
        let padded_window_sq = client.pad(window_sq, &[frame_start, right_pad], 0.0)?;
        window_sum = client.add(&window_sum, &padded_window_sq)?;
    }

    // Clamp the envelope from below so the division never hits zero.
    let eps = Tensor::<R>::full_scalar(&[full_len], real_dtype, 1e-8, device);
    let safe_window_sum = client.maximum(&window_sum, &eps)?;
    let normalized_output = client.div(&output, &safe_window_sum)?;

    if pad_left == 0 && final_len == full_len {
        return Ok(normalized_output);
    }

    // Samples that exist after skipping `pad_left`.
    let available = full_len - pad_left;
    let extracted = normalized_output
        .narrow(0, pad_left, final_len.min(available))?
        .contiguous()?;

    if final_len > available {
        // The caller asked for more samples than were reconstructed: zero-pad
        // the tail so the result has exactly `final_len` samples.
        let padded = client.pad(&extracted, &[0, final_len - available], 0.0)?;
        Ok(padded)
    } else {
        Ok(extracted)
    }
}
/// Overlap-add reconstruction for a batch of STFT matrices.
///
/// The input is viewed as `[batch_size, n_frames, freq_bins]`. For each frame
/// index, all batch entries are inverse-transformed together with `irfft`,
/// multiplied by the window, zero-padded along the last axis to `full_len` at
/// offset `f * hop`, and accumulated. The accumulated signal is divided by the
/// accumulated `window_sq` envelope (clamped from below at `1e-8`). Finally
/// `final_len` samples are taken along axis 1 starting at `pad_left`.
///
/// # Arguments
///
/// * `stft_matrix` - Complex tensor holding
///   `batch_size * n_frames * freq_bins` elements.
/// * `window` / `window_sq` - Synthesis window of shape `[n_fft]` and its
///   element-wise square.
/// * `real_dtype` - Dtype of the accumulation buffers and the clamp tensor.
/// * `n_fft`, `hop`, `n_frames`, `batch_size`, `freq_bins` - STFT geometry.
/// * `full_len` - Length of the untrimmed overlap-add buffer.
/// * `pad_left` - Number of leading samples to skip.
/// * `final_len` - Exact number of samples to return per batch entry.
/// * `norm` - Normalisation mode passed to `irfft`.
///
/// # Returns
///
/// A tensor of shape `[batch_size, final_len]`. If fewer samples are available
/// after skipping `pad_left`, the tail is zero-padded.
///
/// # Errors
///
/// Propagates any error from the underlying tensor and FFT operations.
fn istft_batched<R, C>(
    client: &C,
    stft_matrix: &Tensor<R>,
    window: &Tensor<R>,
    window_sq: &Tensor<R>,
    real_dtype: DType,
    n_fft: usize,
    hop: usize,
    n_frames: usize,
    batch_size: usize,
    freq_bins: usize,
    full_len: usize,
    pad_left: usize,
    final_len: usize,
    norm: FftNormalization,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: FftAlgorithms<R> + TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let device = client.device();
    let stft_batched = stft_matrix.reshape(&[batch_size, n_frames, freq_bins])?;

    // Row-shaped views of the window so they broadcast across the batch axis.
    // They do not depend on the frame index, so build them once.
    let window_row = window.reshape(&[1, n_fft])?;
    let window_sq_row = window_sq.reshape(&[1, n_fft])?;

    let mut output = Tensor::<R>::zeros(&[batch_size, full_len], real_dtype, device);
    // The window envelope is the same for every batch entry, so it is kept as
    // a single row and broadcast at the final division.
    let mut window_sum = Tensor::<R>::zeros(&[1, full_len], real_dtype, device);

    for f in 0..n_frames {
        let spectrum = stft_batched
            .narrow(1, f, 1)?
            .contiguous()?
            .reshape(&[batch_size, freq_bins])?;
        let frames = client.irfft(&spectrum, Some(n_fft), norm)?;
        let windowed_frames = client.mul(&frames, &window_row)?;

        // Place the frames at their hop offset by zero-padding up to
        // `full_len`, then accumulate.
        let frame_start = f * hop;
        let right_pad = full_len.saturating_sub(frame_start + n_fft);
        let padded_frames = client.pad(&windowed_frames, &[frame_start, right_pad], 0.0)?;
        output = client.add(&output, &padded_frames)?;

        // Accumulate the squared window at the same offset for normalisation.
        let padded_window_sq = client.pad(&window_sq_row, &[frame_start, right_pad], 0.0)?;
        window_sum = client.add(&window_sum, &padded_window_sq)?;
    }

    // Clamp the envelope from below so the division never hits zero.
    let eps = Tensor::<R>::full_scalar(&[1, full_len], real_dtype, 1e-8, device);
    let safe_window_sum = client.maximum(&window_sum, &eps)?;
    let normalized_output = client.div(&output, &safe_window_sum)?;

    if pad_left == 0 && final_len == full_len {
        return Ok(normalized_output);
    }

    // Samples that exist per batch entry after skipping `pad_left`.
    let available = full_len - pad_left;
    let extracted = normalized_output
        .narrow(1, pad_left, final_len.min(available))?
        .contiguous()?;

    if final_len > available {
        // The caller asked for more samples than were reconstructed: zero-pad
        // the tail so every row has exactly `final_len` samples.
        let padded = client.pad(&extracted, &[0, final_len - available], 0.0)?;
        Ok(padded)
    } else {
        Ok(extracted)
    }
}