use rustfft::{num_complex::Complex, FftPlanner};
use super::config::PaddingMode;
/// Pads `samples` on both edges according to `mode`.
///
/// The same `pad_size` is used for the left and the right edge.
/// `PaddingMode::Reflect` delegates to `pad_reflect`, `PaddingMode::Zero` to
/// `pad_zero`, and `PaddingMode::None` returns an unpadded copy of the input
/// (in which case `pad_size` is ignored).
pub fn apply_padding(samples: &[f32], pad_size: usize, mode: PaddingMode) -> Vec<f32> {
    // Padding is always symmetric here; the helpers accept independent sides.
    let (left, right) = (pad_size, pad_size);
    match mode {
        PaddingMode::None => samples.to_vec(),
        PaddingMode::Zero => pad_zero(samples, left, right),
        PaddingMode::Reflect => pad_reflect(samples, left, right),
    }
}
/// Folds a signed position `i` onto an index in `0..=w` by reflecting about
/// both ends of the range (a triangle wave with period `2 * w`).
///
/// `w` is the highest valid index. Positions inside `0..=w` map to
/// themselves; positions outside are mirrored back in, as many times as
/// needed, without repeating the edge element.
///
/// For `w <= 0` there is at most one valid index, so `0` is returned.
fn calculate_reflect_offset(i: i32, w: i32) -> usize {
    if w <= 0 {
        // A zero period would divide by zero; the only sensible index is 0.
        return 0;
    }
    // Widen so that neither `i + w` nor `2 * w` can overflow.
    let (pos, last) = (i64::from(i), i64::from(w));
    // `rem_euclid` (unlike `%`) is never negative, so after shifting by
    // `last` the value stays in `-last..last` even when `pos + last < 0`.
    // With plain `%` such positions produced an index greater than `w`.
    let folded = (pos + last).rem_euclid(2 * last) - last;
    folded.unsigned_abs() as usize
}
/// Pads `samples` with `pad_left` leading and `pad_right` trailing values
/// obtained by mirroring the signal about its first and last sample.
///
/// * Empty input yields `pad_left + pad_right` zeros.
/// * A single-sample input is extended by repeating that sample.
/// * Otherwise each padded position is looked up through
///   `calculate_reflect_offset`, using the distance from the nearest edge.
fn pad_reflect(samples: &[f32], pad_left: usize, pad_right: usize) -> Vec<f32> {
    match samples {
        [] => vec![0.0f32; pad_left + pad_right],
        // Nothing to mirror: every output position holds the lone sample.
        [only] => vec![*only; pad_left + 1 + pad_right],
        _ => {
            let last = (samples.len() - 1) as i32;
            let mut padded: Vec<f32> =
                Vec::with_capacity(pad_left + samples.len() + pad_right);
            // Left edge: the farthest position comes first in the output,
            // so walk the distances from `pad_left` down to 1.
            padded.extend(
                (1..=pad_left)
                    .rev()
                    .map(|dist| samples[calculate_reflect_offset(dist as i32, last)]),
            );
            padded.extend_from_slice(samples);
            // Right edge: distance `dist` past the last sample mirrors to
            // position `last - dist`.
            padded.extend(
                (1..=pad_right)
                    .map(|dist| samples[calculate_reflect_offset(last - dist as i32, last)]),
            );
            padded
        }
    }
}
/// Returns a copy of `samples` preceded by `pad_left` zeros and followed by
/// `pad_right` zeros.
fn pad_zero(samples: &[f32], pad_left: usize, pad_right: usize) -> Vec<f32> {
    let total = pad_left + samples.len() + pad_right;
    let mut out = Vec::with_capacity(total);
    // Leading zeros, then the signal, then grow with zeros to the full length.
    out.resize(pad_left, 0.0f32);
    out.extend_from_slice(samples);
    out.resize(total, 0.0f32);
    out
}
/// Builds a Hann window of `size` coefficients.
///
/// Coefficient `i` is `0.5 - 0.5 * cos(2π·i / size)`, so the first value is
/// zero and the peak of one sits at `size / 2` for even sizes. A `size` of
/// zero produces an empty vector.
pub fn create_hann_window(size: usize) -> Vec<f64> {
    // Angular step between consecutive coefficients.
    let step = 2.0 * std::f64::consts::PI / size as f64;
    let mut window = Vec::with_capacity(size);
    for i in 0..size {
        let phase = i as f64 * step;
        window.push(0.5 - 0.5 * phase.cos());
    }
    window
}
pub fn compute_stft_power(samples: &[f32], n_fft: usize, hop_length: usize) -> Vec<Vec<f64>> {
let window = create_hann_window(n_fft);
let n_freqs = n_fft / 2 + 1;
let n_frames = (samples.len().saturating_sub(n_fft)) / hop_length + 1;
if n_frames == 0 {
return vec![];
}
let mut planner = FftPlanner::<f64>::new();
let fft = planner.plan_fft_forward(n_fft);
let mut power_spec = Vec::with_capacity(n_frames);
for frame_idx in 0..n_frames {
let start = frame_idx * hop_length;
let end = start + n_fft;
if end > samples.len() {
break;
}
let mut fft_input: Vec<Complex<f64>> = (0..n_fft)
.map(|i| {
let sample = samples[start + i] as f64;
Complex::new(sample * window[i], 0.0)
})
.collect();
fft.process(&mut fft_input);
let frame_power: Vec<f64> = fft_input[..n_freqs].iter().map(|c| c.norm_sqr()).collect();
power_spec.push(frame_power);
}
power_spec
}
/// Converts a spectrogram to a clamped, rescaled log10 representation and
/// flattens it row by row.
///
/// Steps:
/// 1. Every value in `mel_spec` is replaced **in place** by
///    `log10(max(value, 1e-10))`.
/// 2. Values are clamped from below at `max_log - 8.0`, where `max_log` is the
///    largest log value in the whole spectrogram.
/// 3. Each value is mapped to `(value + 4.0) / 4.0` and narrowed to `f32`.
///
/// The number of frames per row is taken from the first row. When
/// `max_frames` is `Some(k)`, every output row has exactly `k` entries: rows
/// are truncated, or padded with the clamp floor, as needed. With `None` the
/// frame count of the first row is used.
///
/// # Panics
///
/// Panics if a row holds fewer values than the number of frames copied from
/// it (the smaller of the first row's length and the output frame count).
pub fn apply_log_normalization(mel_spec: &mut [Vec<f64>], max_frames: Option<usize>) -> Vec<f32> {
    const EPSILON: f64 = 1e-10;

    // Take logs and track the global maximum in the same sweep.
    let mut peak = f64::NEG_INFINITY;
    for bin in mel_spec.iter_mut().flat_map(|row| row.iter_mut()) {
        *bin = (*bin).max(EPSILON).log10();
        peak = peak.max(*bin);
    }

    let floor = peak - 8.0;
    let rescale = |v: f64| ((v.max(floor) + 4.0) / 4.0) as f32;

    let available = mel_spec.first().map_or(0, Vec::len);
    let output_frames = max_frames.unwrap_or(available);
    // Frames actually copied from each row; the rest is filled with the floor.
    let kept = available.min(output_frames);
    let fill = rescale(floor);

    let mut flat: Vec<f32> = Vec::with_capacity(mel_spec.len() * output_frames);
    for row in mel_spec.iter() {
        flat.extend(row[..kept].iter().map(|&v| rescale(v)));
        flat.resize(flat.len() + (output_frames - kept), fill);
    }
    flat
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The window starts at zero and peaks at one in the middle.
    #[test]
    fn test_hann_window() {
        let window = create_hann_window(400);
        assert_eq!(window.len(), 400);
        assert!((window[0] - 0.0).abs() < 0.001);
        assert!((window[200] - 1.0).abs() < 0.001);
    }

    /// Checks the reflected samples themselves, not just the copied interior,
    /// so that this cannot pass with zero padding.
    #[test]
    fn test_pad_reflect_basic() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let padded = pad_reflect(&samples, 2, 2);
        assert_eq!(padded.len(), 9);
        assert_eq!(padded[2], 1.0);
        assert_eq!(padded[6], 5.0);
        assert_eq!(padded, vec![3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0]);
    }

    /// Empty and single-sample inputs take the special-case paths.
    #[test]
    fn test_pad_reflect_degenerate_inputs() {
        assert_eq!(pad_reflect(&[], 2, 1), vec![0.0, 0.0, 0.0]);
        assert_eq!(pad_reflect(&[7.0], 2, 1), vec![7.0, 7.0, 7.0, 7.0]);
    }

    #[test]
    fn test_pad_zero() {
        let samples = vec![1.0, 2.0, 3.0];
        let padded = pad_zero(&samples, 2, 2);
        assert_eq!(padded, vec![0.0, 0.0, 1.0, 2.0, 3.0, 0.0, 0.0]);
    }

    /// One second of a 440 Hz tone at 16 kHz: (16000 - 400) / 160 + 1 = 98
    /// frames, each with 400 / 2 + 1 = 201 frequency bins.
    #[test]
    fn test_stft_power_shape() {
        let samples: Vec<f32> = (0..16000)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 16000.0).sin())
            .collect();
        let power = compute_stft_power(&samples, 400, 160);
        assert!(!power.is_empty());
        assert_eq!(power.len(), 98);
        assert_eq!(power[0].len(), 201);
        assert!(power.iter().all(|frame| frame.len() == 201));
    }
}