use crate::error::Result;
use pmetal_bridge::compat::{Array, fft, ops};
/// Framing and windowing parameters read by [`stft`] and [`istft`].
///
/// `istft` re-derives the frame length, hop, window and centre trimming from these fields, so a
/// round trip is expected to use the same config for both calls.
#[derive(Debug, Clone)]
pub struct StftConfig {
    /// FFT size, which is also the length of every analysis frame in samples.
    pub n_fft: i32,
    /// Distance in samples between the starts of consecutive frames.
    pub hop_length: i32,
    /// Hann window length; `None` means `n_fft`. A shorter window is zero-padded (centred) up to
    /// `n_fft`. NOTE(review): a value larger than `n_fft` would not match the frame length.
    pub win_length: Option<i32>,
    /// When true, `stft` pads `n_fft / 2` samples on each side of the signal (using `pad_mode`)
    /// and `istft` trims the same amount from its output.
    pub center: bool,
    /// Boundary handling used for the `center` padding.
    pub pad_mode: PadMode,
}
/// Boundary handling used to pad a signal before framing when `center` is enabled.
///
/// Equality and hashing are derived so callers can compare or key on the mode directly.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub enum PadMode {
    /// Mirror about the edge sample without repeating it, e.g. `[x2, x1 | x0, x1, x2, ...]`.
    #[default]
    Reflect,
    /// Pad with zeros.
    Zeros,
    /// Repeat the edge sample.
    Replicate,
}
impl Default for StftConfig {
    /// Standard analysis settings: a 1024-point FFT hopping by a quarter frame (256 samples),
    /// a full-length window, and reflect-padded centring.
    fn default() -> Self {
        let n_fft = 1024;
        Self {
            hop_length: n_fft / 4,
            n_fft,
            win_length: None,
            center: true,
            pad_mode: PadMode::Reflect,
        }
    }
}
/// Builds a symmetric Hann window of `size` samples.
///
/// Sample `n` is `0.5 * (1 - cos(2πn / (size - 1)))` for `n` in `0..size`, so the first and
/// last samples are zero. Degenerate sizes are special-cased:
///
/// * `size == 1` returns `[1.0]`. The general formula would divide by `size - 1 == 0` and
///   produce `0 * inf = NaN`.
/// * `size <= 0` returns an empty array.
///
/// NOTE(review): this is the symmetric variant; many STFT libraries default to the periodic
/// one (denominator `size`). `istft` normalises by the overlap-added squared window, so the
/// round trip works with either.
pub fn hann_window(size: i32) -> Result<Array> {
    if size <= 0 {
        return Ok(Array::zeros(&[0], 10));
    }
    if size == 1 {
        return Ok(Array::from_f32_slice(&[1.0_f32], &[1]));
    }
    let pi = std::f32::consts::PI;
    let n = Array::arange(size, 10);
    let scale = Array::from_f32(2.0 * pi / (size - 1) as f32);
    let cos_term = n.multiply(&scale).cos();
    let half = Array::from_f32(0.5);
    let one = Array::from_f32(1.0);
    Ok(half.multiply(&one.subtract(&cos_term)))
}
/// Computes the short-time Fourier transform of `signal` with a Hann window.
///
/// `signal` is `(samples,)` or `(batch, samples)`. When `config.center` is set, the signal is
/// first padded by `n_fft / 2` samples on each side using `config.pad_mode`. Frames of `n_fft`
/// samples are taken every `hop_length` samples; a trailing partial frame is dropped. Each
/// frame is multiplied by the Hann window, which is zero-padded and centred to `n_fft` when
/// `win_length` is shorter. A real FFT is then applied.
///
/// The result is `(freq, frames)` for 1-D input or `(batch, freq, frames)` for batched input,
/// where `freq` is the real-FFT bin count for `n_fft`.
///
/// # Panics
///
/// Panics in any of these cases:
///
/// * `n_fft` or `hop_length` is not positive.
/// * The effective window length is not in `1..=n_fft`.
/// * The (possibly padded) signal is shorter than `n_fft`, so not even one frame fits.
pub fn stft(signal: &Array, config: &StftConfig) -> Result<Array> {
    let n_fft = config.n_fft;
    let hop_length = config.hop_length;
    let win_length = config.win_length.unwrap_or(n_fft);
    assert!(n_fft > 0, "n_fft must be positive (got {n_fft})");
    assert!(hop_length > 0, "hop_length must be positive (got {hop_length})");
    assert!(
        (1..=n_fft).contains(&win_length),
        "win_length must be in 1..=n_fft (got {win_length}, n_fft = {n_fft})"
    );

    // Centre the (possibly shorter) window inside an n_fft-long frame.
    let window = hann_window(win_length)?;
    let window = if win_length < n_fft {
        let pad_left = (n_fft - win_length) / 2;
        let pad_right = n_fft - win_length - pad_left;
        let zeros_left = Array::zeros(&[pad_left], 10);
        let zeros_right = Array::zeros(&[pad_right], 10);
        ops::concatenate_axis(&[&zeros_left, &window, &zeros_right], 0)
    } else {
        window
    };

    // Work on (batch, samples) internally; remember whether to drop the batch axis again.
    let (signal, was_1d) = if signal.ndim() == 1 {
        (signal.reshape(&[1, -1]), true)
    } else {
        (signal.clone(), false)
    };
    let signal = if config.center {
        pad_signal(&signal, n_fft / 2, config.pad_mode)?
    } else {
        signal
    };

    let padded_length = signal.dim(1);
    // Check before dividing: truncating division would otherwise turn a too-short signal into
    // a negative frame count (capacity-overflow panic) or a spurious single short frame.
    assert!(
        padded_length >= n_fft,
        "signal too short for STFT: {padded_length} samples after padding, need at least n_fft = {n_fft}"
    );
    let num_frames = (padded_length - n_fft) / hop_length + 1;
    let batch = signal.dim(0);
    let frames: Vec<Array> = (0..num_frames)
        .map(|i| {
            let start = i * hop_length;
            signal.slice(&[0, start], &[batch, start + n_fft])
        })
        .collect();

    // (batch, frames, n_fft) -> windowed -> rfft over samples -> (batch, freq, frames).
    let framed = ops::stack_axis(&frames, 1);
    let windowed = framed.multiply(&window);
    let spectrum = fft::rfft(&windowed, Some(n_fft), -1);
    let spectrum = spectrum.transpose_axes(&[0, 2, 1]);

    if was_1d {
        // Drop only the synthetic batch axis. Squeezing every unit axis would also collapse a
        // single-frame spectrum to 1-D.
        Ok(spectrum.reshape(&[spectrum.dim(1), spectrum.dim(2)]))
    } else {
        Ok(spectrum)
    }
}
/// Inverts [`stft`] by windowed overlap-add.
///
/// `stft_matrix` is `(freq, frames)` or `(batch, freq, frames)`. The output is `(samples,)` or
/// `(batch, samples)` respectively. The reconstruction proceeds as follows:
///
/// 1. Each frame is inverse-real-FFT'd to `n_fft` samples.
/// 2. It is multiplied by the same zero-padded Hann window the forward pass used.
/// 3. Frames are overlap-added at `hop_length` intervals.
/// 4. The sum is divided by the overlap-added squared window, clamped below at `1e-8` so
///    silent edges do not divide by zero.
/// 5. With `center` set, the `n_fft / 2` samples [`stft`] padded on each side are trimmed.
///
/// The reconstructed length is `n_fft + (frames - 1) * hop_length`, minus `2 * (n_fft / 2)`
/// when centred. This can be shorter than the original signal if its padded length was not a
/// whole number of hops past `n_fft`. An input with zero frames yields an empty signal.
///
/// # Panics
///
/// Panics if `n_fft` or `hop_length` is not positive, or if the effective window length is not
/// in `1..=n_fft`.
pub fn istft(stft_matrix: &Array, config: &StftConfig) -> Result<Array> {
    let n_fft = config.n_fft;
    let hop_length = config.hop_length;
    let win_length = config.win_length.unwrap_or(n_fft);
    assert!(n_fft > 0, "n_fft must be positive (got {n_fft})");
    assert!(hop_length > 0, "hop_length must be positive (got {hop_length})");
    assert!(
        (1..=n_fft).contains(&win_length),
        "win_length must be in 1..=n_fft (got {win_length}, n_fft = {n_fft})"
    );

    // Same centred, zero-padded window as the forward transform.
    let window = hann_window(win_length)?;
    let window = if win_length < n_fft {
        let pad_left = (n_fft - win_length) / 2;
        let pad_right = n_fft - win_length - pad_left;
        let zeros_left = Array::zeros(&[pad_left], 10);
        let zeros_right = Array::zeros(&[pad_right], 10);
        ops::concatenate_axis(&[&zeros_left, &window, &zeros_right], 0)
    } else {
        window
    };

    let (stft_matrix, was_2d) = if stft_matrix.ndim() == 2 {
        (
            stft_matrix.reshape(&[1, stft_matrix.dim(0), stft_matrix.dim(1)]),
            true,
        )
    } else {
        (stft_matrix.clone(), false)
    };
    let batch_size = stft_matrix.dim(0);
    let num_frames = stft_matrix.dim(2);
    if num_frames == 0 {
        // Nothing to overlap-add. The general path would otherwise emit an all-zero signal of
        // length `n_fft - hop_length`.
        return Ok(if was_2d {
            Array::zeros(&[0], 10)
        } else {
            Array::zeros(&[batch_size, 0], 10)
        });
    }

    // (batch, freq, frames) -> (batch, frames, freq) -> irfft -> (batch, frames, n_fft).
    let stft_transposed = stft_matrix.transpose_axes(&[0, 2, 1]);
    let ifft_frames = fft::irfft(&stft_transposed, Some(n_fft), -1);
    let windowed_frames = ifft_frames.multiply(&window);
    let window_sq = window.multiply(&window);

    let output_length = n_fft + (num_frames - 1) * hop_length;
    let mut output_sum = Array::zeros(&[batch_size, output_length], 10);
    let mut norm_sum = Array::zeros(&[output_length], 10);
    for i in 0..num_frames {
        let offset = i * hop_length;
        // Non-negative because the last frame ends exactly at output_length.
        let pad_after = output_length - offset - n_fft;
        let frame = windowed_frames
            .slice(&[0, i, 0], &[batch_size, i + 1, n_fft])
            .reshape(&[batch_size, n_fft]);
        output_sum = output_sum.add(&frame.pad_constant(&[0, 0, offset, pad_after], 0.0));
        norm_sum = norm_sum.add(&window_sq.pad_constant(&[offset, pad_after], 0.0));
    }

    let eps = Array::from_f32(1e-8_f32);
    let norm_safe = ops::maximum(&norm_sum, &eps);
    let norm_broadcast = ops::broadcast_to(&norm_safe, &[batch_size, output_length]);
    let output = output_sum.divide(&norm_broadcast);

    let output = if config.center {
        let trim = n_fft / 2;
        let trimmed_length = output_length - 2 * trim;
        output.slice(&[0, trim], &[batch_size, trim + trimmed_length])
    } else {
        output
    };

    if was_2d {
        // Drop only the synthetic batch axis; squeeze_all would turn a length-1 output into a
        // 0-d array.
        Ok(output.reshape(&[output.dim(1)]))
    } else {
        Ok(output)
    }
}
/// Pads both ends of axis 1 of `signal` by `pad_amount` samples using `mode`.
///
/// `signal` is indexed as `(batch, samples)` throughout. The modes behave as follows:
///
/// * `Zeros` pads with zeros.
/// * `Reflect` mirrors about the edge sample without repeating it, e.g. `[x2, x1 | x0, x1, x2,
///   ...]` for a pad of 2.
/// * `Replicate` repeats the edge sample.
///
/// A `pad_amount` of zero returns the signal unchanged.
///
/// # Panics
///
/// Panics in any of these cases, since each leaves no valid source samples to copy:
///
/// * `pad_amount` is negative.
/// * `Reflect` is asked for a pad that is not strictly shorter than the signal.
/// * `Replicate` is applied to an empty signal.
fn pad_signal(signal: &Array, pad_amount: i32, mode: PadMode) -> Result<Array> {
    assert!(pad_amount >= 0, "pad_amount must be non-negative (got {pad_amount})");
    if pad_amount == 0 {
        return Ok(signal.clone());
    }
    let batch_size = signal.dim(0);
    let length = signal.dim(1);

    let (left_pad, right_pad) = match mode {
        PadMode::Zeros => (
            Array::zeros(&[batch_size, pad_amount], 10),
            Array::zeros(&[batch_size, pad_amount], 10),
        ),
        PadMode::Reflect => {
            // Without this check the gather indices below would run past either end of the
            // signal.
            assert!(
                pad_amount < length,
                "reflect padding of {pad_amount} needs a longer signal (got length {length})"
            );
            // Left: x[pad], ..., x[1]. Right: x[len-2], ..., x[len-1-pad].
            let left_indices: Vec<i32> = (1..=pad_amount).rev().collect();
            let right_indices: Vec<i32> =
                ((length - pad_amount - 1)..(length - 1)).rev().collect();
            (
                signal.take_axis(&Array::from_i32_slice(&left_indices), 1),
                signal.take_axis(&Array::from_i32_slice(&right_indices), 1),
            )
        }
        PadMode::Replicate => {
            assert!(length > 0, "replicate padding requires a non-empty signal");
            let first = signal.slice(&[0, 0], &[batch_size, 1]);
            let last = signal.slice(&[0, length - 1], &[batch_size, length]);
            (
                ops::broadcast_to(&first, &[batch_size, pad_amount]),
                ops::broadcast_to(&last, &[batch_size, pad_amount]),
            )
        }
    };
    Ok(ops::concatenate_axis(&[&left_pad, signal, &right_pad], 1))
}
/// Returns the element-wise magnitude `|X|` of an STFT matrix, with the same shape as the
/// input.
pub fn stft_magnitude(stft_matrix: &Array) -> Result<Array> {
    let magnitude = stft_matrix.abs_val();
    Ok(magnitude)
}
/// Returns the element-wise power `|X|^2` of an STFT matrix, with the same shape as the input.
pub fn stft_power(stft_matrix: &Array) -> Result<Array> {
    let magnitude = stft_matrix.abs_val();
    let power = magnitude.multiply(&magnitude);
    Ok(power)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config with the given framing parameters and a full-length (`None`) window.
    fn make_config(n_fft: i32, hop_length: i32, center: bool, pad_mode: PadMode) -> StftConfig {
        StftConfig {
            n_fft,
            hop_length,
            win_length: None,
            center,
            pad_mode,
        }
    }

    #[test]
    fn test_hann_window() {
        let window = hann_window(4).unwrap();
        window.eval();
        assert_eq!(window.shape(), &[4]);
    }

    #[test]
    fn test_stft_config() {
        let StftConfig {
            n_fft, hop_length, ..
        } = StftConfig::default();
        assert_eq!(n_fft, 1024);
        assert_eq!(hop_length, 256);
    }

    #[test]
    fn test_stft_istft_roundtrip() {
        let n_fft = 64;
        let num_samples = 512;
        let config = make_config(n_fft, n_fft / 4, true, PadMode::Reflect);

        // A 440 Hz tone sampled at 16 kHz.
        let two_pi = 2.0 * std::f32::consts::PI;
        let samples: Vec<f32> = (0..num_samples)
            .map(|n| (two_pi * 440.0 * n as f32 / 16000.0).sin())
            .collect();
        let signal = Array::from_f32_slice(&samples, &[num_samples]);

        let spectrum = stft(&signal, &config).unwrap();
        let reconstructed = istft(&spectrum, &config).unwrap();
        reconstructed.eval();
        assert_eq!(reconstructed.shape(), &[num_samples]);

        let max_err_arr = reconstructed.subtract(&signal).abs_val().max(None);
        max_err_arr.eval();
        let max_err: f32 = max_err_arr.item_f32();
        assert!(
            max_err < 1e-3,
            "STFT round-trip error too large: max |error| = {max_err}"
        );
    }

    #[test]
    fn test_istft_batched_shape() {
        let num_samples = 128;
        let batch_size = 3;
        let config = make_config(32, 8, false, PadMode::Zeros);

        let samples: Vec<f32> = (0..batch_size * num_samples)
            .map(|i| (i as f32 / num_samples as f32).sin())
            .collect();
        let signal = Array::from_f32_slice(&samples, &[batch_size, num_samples]);

        let spectrum = stft(&signal, &config).unwrap();
        let reconstructed = istft(&spectrum, &config).unwrap();
        reconstructed.eval();
        assert_eq!(reconstructed.dim(0), batch_size);
    }
}