use crate::error::{Error, Result};
use numr::runtime::cpu::{CpuClient, CpuRuntime};
use numr::tensor::Tensor;
/// Options controlling how [`istft`] overlap-adds and post-processes frames.
#[derive(Debug, Clone, Copy)]
pub struct IStftOptions {
/// Number of samples between the starts of consecutive frames during
/// overlap-add. `istft` rejects a value of zero.
pub hop_length: usize,
/// When `true`, `istft` trims `n_fft / 2` samples from each end of the
/// overlap-added signal before returning it.
pub center: bool,
/// Threshold on the overlap-added squared-window envelope: output samples
/// whose envelope is not greater than this value are set to zero instead of
/// being divided by the envelope.
pub eps: f32,
}
impl Default for IStftOptions {
fn default() -> Self {
Self {
hop_length: 256,
center: true,
eps: 1e-8,
}
}
}
pub fn istft(
client: &CpuClient,
mag: &Tensor<CpuRuntime>,
phase: &Tensor<CpuRuntime>,
window: &Tensor<CpuRuntime>,
opts: IStftOptions,
) -> Result<Tensor<CpuRuntime>> {
if mag.shape() != phase.shape() {
return Err(Error::InvalidArgument {
arg: "phase",
reason: format!(
"shape must match mag ({:?}), got {:?}",
mag.shape(),
phase.shape()
),
});
}
let m_shape = mag.shape();
if m_shape.len() != 3 {
return Err(Error::InvalidArgument {
arg: "mag",
reason: format!("expected [B, F, T], got {m_shape:?}"),
});
}
let (b, f, t_frames) = (m_shape[0], m_shape[1], m_shape[2]);
let w_shape = window.shape();
if w_shape.len() != 1 {
return Err(Error::InvalidArgument {
arg: "window",
reason: format!("expected 1D window [n_fft], got {w_shape:?}"),
});
}
let n_fft = w_shape[0];
if f != n_fft / 2 + 1 {
return Err(Error::InvalidArgument {
arg: "mag",
reason: format!("F ({f}) must equal n_fft/2+1 ({})", n_fft / 2 + 1),
});
}
if opts.hop_length == 0 {
return Err(Error::InvalidArgument {
arg: "hop_length",
reason: "must be > 0".into(),
});
}
if t_frames == 0 {
return Err(Error::InvalidArgument {
arg: "mag",
reason: "T_frames must be > 0".into(),
});
}
let _ = client; let mag_flat: Vec<f32> = mag.contiguous()?.to_vec();
let phase_flat: Vec<f32> = phase.contiguous()?.to_vec();
let window_samples: Vec<f32> = window.contiguous()?.to_vec();
let n_fft_f = n_fft as f32;
let f_bins = n_fft / 2 + 1;
let mut cos_table = vec![0.0f32; f_bins * n_fft];
let mut sin_table = vec![0.0f32; f_bins * n_fft];
for k in 0..f_bins {
for n in 0..n_fft {
let theta = 2.0 * std::f32::consts::PI * (k as f32) * (n as f32) / n_fft_f;
cos_table[k * n_fft + n] = theta.cos();
sin_table[k * n_fft + n] = theta.sin();
}
}
let mut windowed_flat = vec![0.0f32; b * t_frames * n_fft];
let inv_n = 1.0f32 / n_fft_f;
for b_idx in 0..b {
for t_idx in 0..t_frames {
for (n, &w) in window_samples.iter().take(n_fft).enumerate() {
let mut acc = 0.0f32;
for k in 0..f_bins {
let src = (b_idx * f_bins + k) * t_frames + t_idx;
let mag_k = mag_flat[src];
let ph_k = phase_flat[src];
let theta =
ph_k + 2.0 * std::f32::consts::PI * (k as f32) * (n as f32) / n_fft_f;
let term = mag_k * theta.cos();
let mirror = k != 0 && !(n_fft.is_multiple_of(2) && k == n_fft / 2);
acc += term * if mirror { 2.0 } else { 1.0 };
}
let dst = (b_idx * t_frames + t_idx) * n_fft + n;
windowed_flat[dst] = acc * inv_n * w;
}
}
}
drop(cos_table);
drop(sin_table);
let raw_len = (t_frames - 1) * opts.hop_length + n_fft;
let mut waveform = vec![0.0f32; b * raw_len];
let mut norm = vec![0.0f32; raw_len];
let window_sq: Vec<f32> = window_samples.iter().map(|w| w * w).collect();
for b_idx in 0..b {
for t_idx in 0..t_frames {
let frame_base = (b_idx * t_frames + t_idx) * n_fft;
let wave_base = b_idx * raw_len + t_idx * opts.hop_length;
for n in 0..n_fft {
waveform[wave_base + n] += windowed_flat[frame_base + n];
}
}
}
for t_idx in 0..t_frames {
let base = t_idx * opts.hop_length;
for n in 0..n_fft {
norm[base + n] += window_sq[n];
}
}
for b_idx in 0..b {
for n in 0..raw_len {
let nrm = norm[n];
if nrm > opts.eps {
waveform[b_idx * raw_len + n] /= nrm;
} else {
waveform[b_idx * raw_len + n] = 0.0;
}
}
}
let (out_len, output) = if opts.center {
let half = n_fft / 2;
if raw_len < 2 * half {
return Err(Error::InvalidArgument {
arg: "mag",
reason: "signal too short to remove center padding".into(),
});
}
let out_len = raw_len - 2 * half;
let mut cropped = vec![0.0f32; b * out_len];
for b_idx in 0..b {
let src = &waveform[b_idx * raw_len + half..b_idx * raw_len + half + out_len];
cropped[b_idx * out_len..(b_idx + 1) * out_len].copy_from_slice(src);
}
(out_len, cropped)
} else {
(raw_len, waveform)
};
let device = mag.device();
Ok(Tensor::<CpuRuntime>::from_slice(
&output,
&[b, out_len],
device,
))
}
#[cfg(test)]
// Fixtures are built with `&vec![..]` for uniformity, including the cases
// whose length is a literal and could be an array.
#[allow(clippy::useless_vec)]
mod tests {
    use super::*;
    use crate::test_utils::cpu_setup;
    use numr::runtime::Runtime;

    /// Wraps `data` in a CPU tensor of the given `shape`.
    fn make_tensor(
        data: &[f32],
        shape: &[usize],
        device: &<CpuRuntime as Runtime>::Device,
    ) -> Tensor<CpuRuntime> {
        Tensor::<CpuRuntime>::from_slice(data, shape, device)
    }

    #[test]
    fn output_shape_matches_overlap_add_formula() {
        let (client, device) = cpu_setup();
        let n_fft = 8;
        let hop = 4;
        let t_frames = 5;
        let f = n_fft / 2 + 1;
        let mag = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let phase = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let window = make_tensor(&vec![1.0f32; n_fft], &[n_fft], &device);
        let opts = IStftOptions {
            hop_length: hop,
            center: false,
            eps: 1e-8,
        };
        let out = istft(&client, &mag, &phase, &window, opts).unwrap();
        let expected_len = (t_frames - 1) * hop + n_fft;
        assert_eq!(out.shape(), &[1, expected_len]);
    }

    #[test]
    fn zero_spectrogram_yields_zero_waveform() {
        let (client, device) = cpu_setup();
        let n_fft = 8;
        let t_frames = 4;
        let f = n_fft / 2 + 1;
        let mag = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let phase = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let window = make_tensor(&vec![0.5f32; n_fft], &[n_fft], &device);
        let out = istft(&client, &mag, &phase, &window, IStftOptions::default()).unwrap();
        for v in out.to_vec::<f32>() {
            assert_eq!(v, 0.0);
        }
    }

    #[test]
    fn dc_only_spectrogram_reconstructs_constant() {
        let (client, device) = cpu_setup();
        let n_fft = 4;
        let hop = 2;
        let t_frames = 3;
        let f = n_fft / 2 + 1;
        let c = 4.0f32;
        // Bin 0 (DC) occupies the first `t_frames` entries of the [1, F, T] layout.
        let mut mag_data = vec![0.0f32; f * t_frames];
        mag_data.iter_mut().take(t_frames).for_each(|v| *v = c);
        let mag = make_tensor(&mag_data, &[1, f, t_frames], &device);
        let phase = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let window = make_tensor(&vec![1.0f32; n_fft], &[n_fft], &device);
        let opts = IStftOptions {
            hop_length: hop,
            center: false,
            eps: 1e-8,
        };
        let out = istft(&client, &mag, &phase, &window, opts).unwrap();
        let samples: Vec<f32> = out.to_vec();
        let expected = c / n_fft as f32;
        assert_eq!(samples.len(), (t_frames - 1) * hop + n_fft);
        // With a rectangular window each sample is the sum of the overlapping
        // frames divided by the overlap count, so the edges match the interior
        // and every sample can be checked.
        for (i, sample) in samples.iter().enumerate() {
            assert!(
                (sample - expected).abs() < 1e-4,
                "sample {i}: {sample} vs expected {expected}"
            );
        }
    }

    #[test]
    fn single_mirrored_bin_reconstructs_cosine() {
        let (client, device) = cpu_setup();
        let n_fft = 4;
        let f = n_fft / 2 + 1;
        // One frame with magnitude 2 in bin 1 only: the bin is counted twice
        // for its mirror, giving (2 * 2 / 4) * cos(2*pi*n/4) = [1, 0, -1, 0].
        let mag = make_tensor(&vec![0.0f32, 2.0, 0.0], &[1, f, 1], &device);
        let phase = make_tensor(&vec![0.0f32; f], &[1, f, 1], &device);
        let window = make_tensor(&vec![1.0f32; n_fft], &[n_fft], &device);
        let opts = IStftOptions {
            hop_length: 2,
            center: false,
            eps: 1e-8,
        };
        let out = istft(&client, &mag, &phase, &window, opts).unwrap();
        let samples: Vec<f32> = out.to_vec();
        let expected = [1.0f32, 0.0, -1.0, 0.0];
        assert_eq!(samples.len(), expected.len());
        for (i, (sample, want)) in samples.iter().zip(expected.iter()).enumerate() {
            assert!(
                (sample - want).abs() < 1e-4,
                "sample {i}: {sample} vs expected {want}"
            );
        }
    }

    #[test]
    fn center_trim_removes_n_fft_over_2_from_each_end() {
        let (client, device) = cpu_setup();
        let n_fft = 8;
        let hop = 4;
        let t_frames = 6;
        let f = n_fft / 2 + 1;
        let mag = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let phase = make_tensor(&vec![0.0f32; f * t_frames], &[1, f, t_frames], &device);
        let window = make_tensor(&vec![1.0f32; n_fft], &[n_fft], &device);
        let with_center = istft(
            &client,
            &mag,
            &phase,
            &window,
            IStftOptions {
                hop_length: hop,
                center: true,
                eps: 1e-8,
            },
        )
        .unwrap();
        let without_center = istft(
            &client,
            &mag,
            &phase,
            &window,
            IStftOptions {
                hop_length: hop,
                center: false,
                eps: 1e-8,
            },
        )
        .unwrap();
        assert_eq!(without_center.shape()[1] - with_center.shape()[1], n_fft);
    }

    #[test]
    fn rejects_mismatched_mag_phase_shapes() {
        let (client, device) = cpu_setup();
        let mag = make_tensor(&vec![0.0f32; 9], &[1, 3, 3], &device);
        let phase = make_tensor(&vec![0.0f32; 6], &[1, 3, 2], &device);
        let window = make_tensor(&vec![1.0f32; 4], &[4], &device);
        assert!(istft(&client, &mag, &phase, &window, IStftOptions::default()).is_err());
    }

    #[test]
    fn rejects_bad_window_size() {
        let (client, device) = cpu_setup();
        let mag = make_tensor(&vec![0.0f32; 15], &[1, 5, 3], &device);
        let phase = make_tensor(&vec![0.0f32; 15], &[1, 5, 3], &device);
        let window = make_tensor(&vec![1.0f32; 4], &[4], &device);
        assert!(istft(&client, &mag, &phase, &window, IStftOptions::default()).is_err());
    }

    #[test]
    fn rejects_zero_hop_length() {
        let (client, device) = cpu_setup();
        let mag = make_tensor(&vec![0.0f32; 9], &[1, 3, 3], &device);
        let phase = make_tensor(&vec![0.0f32; 9], &[1, 3, 3], &device);
        let window = make_tensor(&vec![1.0f32; 4], &[4], &device);
        let opts = IStftOptions {
            hop_length: 0,
            center: false,
            eps: 1e-8,
        };
        assert!(istft(&client, &mag, &phase, &window, opts).is_err());
    }
}