use core::mem::MaybeUninit;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::{
vdupq_n_f32, vld1q_f32, vmaxq_f32, vminq_f32, vmulq_f32, vst1q_f32, vsubq_f32,
};
use crate::error::{Error, Result};
/// Maps a frequency `hz` onto the Kaldi mel scale, `1127 * ln(1 + hz / 700)`.
#[inline]
fn mel_scale_kaldi(hz: f32) -> f32 {
    const MEL_FACTOR: f32 = 1127.0;
    const BREAK_HZ: f32 = 700.0;
    let ratio = 1.0_f32 + hz / BREAK_HZ;
    MEL_FACTOR * ratio.ln()
}
/// Portable reference implementation of the Kaldi triangular mel filter bank.
///
/// Fills `out` row-major: row `m` (of `num_bins`) holds `num_fft_bins` weights for the
/// triangle whose left, centre and right corners sit at `mel_low + m * mel_delta`,
/// `mel_low + (m + 1) * mel_delta` and `mel_low + (m + 2) * mel_delta`. FFT bin `k` is
/// placed at the Kaldi mel value of `k * fft_bin_width`. Its weight is the smaller of the
/// rising and falling edge values, clamped below at `0.0`. A row whose triangle has a
/// non-positive left or right width is written as all zeros. On return every element of
/// `out` has been initialised.
///
/// # Panics
///
/// Panics if `num_bins * num_fft_bins` overflows `usize`, or if `out.len()` is not exactly
/// that product.
#[inline]
#[doc(hidden)]
pub fn get_mel_banks_kaldi_scalar(
    out: &mut [MaybeUninit<f32>],
    num_bins: usize,
    num_fft_bins: usize,
    fft_bin_width: f32,
    mel_low: f32,
    mel_delta: f32,
) {
    let total = match num_bins.checked_mul(num_fft_bins) {
        Some(total) => total,
        None => panic!(
            "get_mel_banks_kaldi_scalar: dimensions {num_bins}x{num_fft_bins} overflow usize"
        ),
    };
    assert_eq!(
        out.len(),
        total,
        "get_mel_banks_kaldi_scalar: out.len() ({}) must equal num_bins * num_fft_bins ({} * {} = {})",
        out.len(),
        num_bins,
        num_fft_bins,
        total,
    );

    // With zero-width rows `out` is empty; `chunks_exact_mut` also rejects a chunk size of 0.
    if num_fft_bins == 0 {
        return;
    }

    // Corner `i` of the evenly spaced mel grid.
    let corner = |i: usize| mel_low + (i as f32) * mel_delta;

    for (m, row) in out.chunks_exact_mut(num_fft_bins).enumerate() {
        let (left_mel, center_mel, right_mel) = (corner(m), corner(m + 1), corner(m + 2));
        let rise = center_mel - left_mel;
        let fall = right_mel - center_mel;

        // Written as `<= 0.0` (not `!(.. > 0.0)`) so NaN widths are not treated as
        // collapsed and fall through to the weight computation below.
        if rise <= 0.0 || fall <= 0.0 {
            row.fill(MaybeUninit::new(0.0));
            continue;
        }

        for (k, cell) in row.iter_mut().enumerate() {
            let mel = mel_scale_kaldi(fft_bin_width * k as f32);
            let rising = (mel - left_mel) / rise;
            let falling = (right_mel - mel) / fall;
            cell.write(rising.min(falling).max(0.0));
        }
    }
}
/// NEON implementation of the Kaldi triangular mel filter bank, producing the same
/// row-major layout as [`get_mel_banks_kaldi_scalar`].
///
/// The Kaldi mel value of every FFT bin is computed once into a scratch buffer and shared
/// by all rows. Each row is processed four bins at a time. The final `num_fft_bins % 4`
/// bins use exactly the scalar formula. Rows whose triangle has a non-positive left or
/// right width are written as zeros. On success every element of `out` is initialised.
///
/// The vector body multiplies by precomputed reciprocals of the triangle widths instead of
/// dividing. Cells can therefore differ from the scalar path by roughly one ulp.
///
/// NOTE(review): `vminq_f32`/`vmaxq_f32` propagate NaN, whereas `f32::min`/`f32::max`
/// return the non-NaN operand. Non-finite inputs can therefore yield different cells than
/// the scalar path.
///
/// # Errors
///
/// Returns [`Error::OutOfMemory`] if the `num_fft_bins`-element scratch buffer cannot be
/// allocated.
///
/// # Panics
///
/// Panics if `num_bins * num_fft_bins` overflows `usize`, or if `out.len()` is not exactly
/// that product.
///
/// # Safety
///
/// The caller must ensure the running CPU supports NEON. This function is compiled with
/// `#[target_feature(enable = "neon")]`.
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
unsafe fn get_mel_banks_kaldi_neon(
    out: &mut [MaybeUninit<f32>],
    num_bins: usize,
    num_fft_bins: usize,
    fft_bin_width: f32,
    mel_low: f32,
    mel_delta: f32,
) -> Result<()> {
    let elements = num_bins.checked_mul(num_fft_bins).unwrap_or_else(|| {
        panic!("get_mel_banks_kaldi_neon: dimensions {num_bins}x{num_fft_bins} overflow usize")
    });
    assert_eq!(
        out.len(),
        elements,
        "get_mel_banks_kaldi_neon: out.len() ({}) must equal num_bins * num_fft_bins ({} * {} = {})",
        out.len(),
        num_bins,
        num_fft_bins,
        elements,
    );

    // The mel value of bin `k` does not depend on the row, so compute each one (one `ln`)
    // only once.
    let mut mel_values: Vec<f32> = Vec::new();
    mel_values
        .try_reserve_exact(num_fft_bins)
        .map_err(|_| Error::OutOfMemory)?;
    mel_values.extend((0..num_fft_bins).map(|k| mel_scale_kaldi(fft_bin_width * k as f32)));

    // Split every row (and the mel table) into a 4-lane body and a scalar tail.
    let body_len = num_fft_bins - num_fft_bins % 4;
    let (mel_body, mel_tail) = mel_values.split_at(body_len);

    for m in 0..num_bins {
        let left_mel = mel_low + (m as f32) * mel_delta;
        let center_mel = mel_low + ((m + 1) as f32) * mel_delta;
        let right_mel = mel_low + ((m + 2) as f32) * mel_delta;
        let lc = center_mel - left_mel;
        let cr = right_mel - center_mel;

        let row_off = m * num_fft_bins;
        let row = &mut out[row_off..row_off + num_fft_bins];

        if lc <= 0.0 || cr <= 0.0 {
            // Collapsed triangle: the whole row is zero, as in the scalar path.
            row.fill(MaybeUninit::new(0.0));
            continue;
        }

        let (row_body, row_tail) = row.split_at_mut(body_len);
        let inv_lc = 1.0_f32 / lc;
        let inv_cr = 1.0_f32 / cr;

        // SAFETY: NEON is enabled for this function and, per the caller contract, supported
        // by the CPU. Every chunk from `chunks_exact{_mut}(4)` has exactly 4 elements, so
        // each 128-bit load and store stays inside its chunk. `MaybeUninit<f32>` has the
        // same size and alignment as `f32`. `vld1q_f32`/`vst1q_f32` require only `f32`
        // alignment.
        unsafe {
            let zero = vdupq_n_f32(0.0);
            let left_v = vdupq_n_f32(left_mel);
            let right_v = vdupq_n_f32(right_mel);
            let inv_lc_v = vdupq_n_f32(inv_lc);
            let inv_cr_v = vdupq_n_f32(inv_cr);
            for (dst, src) in row_body.chunks_exact_mut(4).zip(mel_body.chunks_exact(4)) {
                let mel_v = vld1q_f32(src.as_ptr());
                // Subtract before scaling. The previous form, `mel * inv - left * inv`,
                // cancels two terms of magnitude ~num_bins to produce a value in [0, 1],
                // losing absolute precision as num_bins grows.
                let up = vmulq_f32(vsubq_f32(mel_v, left_v), inv_lc_v);
                let down = vmulq_f32(vsubq_f32(right_v, mel_v), inv_cr_v);
                let v = vmaxq_f32(vminq_f32(up, down), zero);
                vst1q_f32(dst.as_mut_ptr().cast::<f32>(), v);
            }
        }

        // Tail bins: identical arithmetic to `get_mel_banks_kaldi_scalar`.
        for (dst, &mel) in row_tail.iter_mut().zip(mel_tail) {
            let up = (mel - left_mel) / lc;
            let down = (right_mel - mel) / cr;
            dst.write(up.min(down).max(0.0));
        }
    }
    Ok(())
}
/// Fills `out` with a Kaldi-style triangular mel filter bank, picking the best available
/// kernel at runtime.
///
/// `out` is row-major: `num_bins` rows of `num_fft_bins` weights each. Row `m` is the
/// triangle with corners at `mel_low + m * mel_delta`, `mel_low + (m + 1) * mel_delta` and
/// `mel_low + (m + 2) * mel_delta`, evaluated at the Kaldi mel value of `k * fft_bin_width`
/// for each FFT bin `k`. On `Ok(())` every element of `out` has been initialised.
///
/// On `aarch64`, when `crate::simd::is_neon_available()` reports true, the NEON kernel is
/// used. Otherwise [`get_mel_banks_kaldi_scalar`] runs.
///
/// # Errors
///
/// Propagates the NEON kernel's error if its scratch allocation fails. The scalar path
/// always returns `Ok(())`.
///
/// # Panics
///
/// Panics if `num_bins * num_fft_bins` overflows `usize`, or if `out.len()` is not exactly
/// that product.
#[inline]
#[doc(hidden)]
pub fn get_mel_banks_kaldi_rows(
    out: &mut [MaybeUninit<f32>],
    num_bins: usize,
    num_fft_bins: usize,
    fft_bin_width: f32,
    mel_low: f32,
    mel_delta: f32,
) -> Result<()> {
    // Validate up front so the panic message names this public entry point rather than
    // whichever kernel ends up running.
    let expected_len = match num_bins.checked_mul(num_fft_bins) {
        Some(len) => len,
        None => panic!(
            "simd::audio::get_mel_banks_kaldi_rows: dimensions {num_bins}x{num_fft_bins} overflow usize"
        ),
    };
    assert_eq!(
        out.len(),
        expected_len,
        "simd::audio::get_mel_banks_kaldi_rows: out.len() ({}) must equal num_bins * num_fft_bins \
         ({} * {} = {})",
        out.len(),
        num_bins,
        num_fft_bins,
        expected_len,
    );

    #[cfg(target_arch = "aarch64")]
    {
        if crate::simd::is_neon_available() {
            // SAFETY: the NEON kernel requires NEON support on the running CPU. This relies
            // on `is_neon_available()` reporting that (assumed from its name; confirm).
            return unsafe {
                get_mel_banks_kaldi_neon(
                    out,
                    num_bins,
                    num_fft_bins,
                    fft_bin_width,
                    mel_low,
                    mel_delta,
                )
            };
        }
    }

    get_mel_banks_kaldi_scalar(out, num_bins, num_fft_bins, fft_bin_width, mel_low, mel_delta);
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::{get_mel_banks_kaldi_rows, get_mel_banks_kaldi_scalar};
    use core::mem::MaybeUninit;

    /// Kaldi-style parameters for a 512-point FFT at 16 kHz, with the mel range spanning
    /// 20 Hz to 7800 Hz split into `num_bins` triangles.
    ///
    /// Returns `(num_fft_bins, fft_bin_width, mel_low, mel_delta)`.
    fn realistic_params(num_bins: usize) -> (usize, f32, f32, f32) {
        let n_fft_padded = 512usize;
        let num_fft_bins = n_fft_padded / 2;
        let sample_freq = 16_000.0_f32;
        let fft_bin_width = sample_freq / n_fft_padded as f32;
        let mel_low = 1127.0_f32 * (1.0_f32 + 20.0_f32 / 700.0_f32).ln();
        let mel_high = 1127.0_f32 * (1.0_f32 + 7800.0_f32 / 700.0_f32).ln();
        let mel_delta = (mel_high - mel_low) / (num_bins as f32 + 1.0);
        (num_fft_bins, fft_bin_width, mel_low, mel_delta)
    }

    /// Hands the spare capacity of a fresh `len`-element buffer to `fill`, then returns the
    /// buffer with its length set to `len`.
    ///
    /// `fill` must initialise every element it is given. The kernels under test either do
    /// so or panic, and a panic means the length is never set.
    fn filled(len: usize, fill: impl FnOnce(&mut [MaybeUninit<f32>])) -> Vec<f32> {
        let mut out: Vec<f32> = Vec::with_capacity(len);
        fill(&mut out.spare_capacity_mut()[..len]);
        // SAFETY: `len <= capacity`, and `fill` returned normally, so all `len` elements
        // were written.
        unsafe { out.set_len(len) };
        out
    }

    fn bank_via_scalar(num_bins: usize) -> Vec<f32> {
        let (num_fft_bins, w, mel_low, mel_delta) = realistic_params(num_bins);
        filled(num_bins * num_fft_bins, |buf| {
            get_mel_banks_kaldi_scalar(buf, num_bins, num_fft_bins, w, mel_low, mel_delta);
        })
    }

    fn bank_via_dispatch(num_bins: usize) -> Vec<f32> {
        let (num_fft_bins, w, mel_low, mel_delta) = realistic_params(num_bins);
        filled(num_bins * num_fft_bins, |buf| {
            get_mel_banks_kaldi_rows(buf, num_bins, num_fft_bins, w, mel_low, mel_delta)
                .expect("realistic params should not OOM");
        })
    }

    #[test]
    fn kaldi_mel_scalar_matches_dispatcher_tolerance() {
        for &num_bins in &[5usize, 10, 23, 40, 80] {
            let s = bank_via_scalar(num_bins);
            let d = bank_via_dispatch(num_bins);
            assert_eq!(s.len(), d.len(), "shape parity at num_bins={num_bins}");
            for (i, (a, b)) in s.iter().zip(d.iter()).enumerate() {
                let diff = (a - b).abs();
                let tol = 1e-5_f32.max(1e-5_f32 * a.abs());
                assert!(
                    diff <= tol,
                    "Tolerance mismatch at num_bins={num_bins} i={i}: scalar={a} dispatcher={b} \
                     diff={diff} tol={tol}"
                );
            }
        }
    }

    #[test]
    fn kaldi_mel_triangle_shape() {
        let num_bins = 23;
        let bank = bank_via_dispatch(num_bins);
        let (num_fft_bins, _, _, _) = realistic_params(num_bins);
        for m in 0..num_bins {
            let row = &bank[m * num_fft_bins..(m + 1) * num_fft_bins];
            for (i, &v) in row.iter().enumerate() {
                assert!(
                    (0.0..=1.0001).contains(&v),
                    "cell out of [0, 1]: m={m} i={i} v={v}"
                );
            }
        }
    }

    #[test]
    fn kaldi_mel_collapsed_row_is_zero() {
        let num_bins = 4;
        let num_fft_bins = 16;
        let fft_bin_width = 30.0_f32;
        let mel_low = 100.0_f32;
        let mel_delta = 0.0_f32;
        let out = filled(num_bins * num_fft_bins, |buf| {
            get_mel_banks_kaldi_rows(buf, num_bins, num_fft_bins, fft_bin_width, mel_low, mel_delta)
                .expect("small collapsed-row params should not OOM");
        });
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v, 0.0, "collapsed row cell i={i} should be 0.0");
        }
    }

    #[test]
    #[should_panic(
        expected = "simd::audio::get_mel_banks_kaldi_rows: out.len() (3) must equal num_bins"
    )]
    fn kaldi_mel_panics_on_size_mismatch() {
        let mut out: Vec<f32> = Vec::with_capacity(3);
        let spare = out.spare_capacity_mut();
        let _ = get_mel_banks_kaldi_rows(&mut spare[..3], 2, 4, 30.0, 100.0, 50.0);
    }

    #[test]
    fn kaldi_mel_dispatcher_returns_result_for_fallible_allocation() {
        let num_bins = 2usize;
        let num_fft_bins = 4usize;
        let mut out: Vec<f32> = Vec::with_capacity(num_bins * num_fft_bins);
        let spare = out.spare_capacity_mut();
        let r: Result<(), super::Error> = get_mel_banks_kaldi_rows(
            &mut spare[..num_bins * num_fft_bins],
            num_bins,
            num_fft_bins,
            30.0,
            100.0,
            50.0,
        );
        assert!(r.is_ok(), "small input should not OOM");
    }

    #[test]
    #[should_panic(expected = "overflow usize")]
    fn kaldi_mel_panics_on_dimension_overflow() {
        let num_bins = usize::MAX / 4 + 1;
        let num_fft_bins = 4usize;
        let mut out: Vec<f32> = Vec::new();
        let spare = out.spare_capacity_mut();
        let _ = get_mel_banks_kaldi_rows(spare, num_bins, num_fft_bins, 30.0, 100.0, 50.0);
    }

    #[test]
    #[should_panic(expected = "overflow usize")]
    fn kaldi_mel_scalar_panics_on_dimension_overflow() {
        let num_bins = usize::MAX / 4 + 1;
        let num_fft_bins = 4usize;
        let mut out: Vec<f32> = Vec::new();
        let spare = out.spare_capacity_mut();
        get_mel_banks_kaldi_scalar(spare, num_bins, num_fft_bins, 30.0, 100.0, 50.0);
    }

    #[test]
    #[should_panic(expected = "get_mel_banks_kaldi_scalar: out.len() (3) must equal num_bins")]
    fn kaldi_mel_scalar_panics_on_size_mismatch() {
        let mut out: Vec<f32> = Vec::with_capacity(3);
        let spare = out.spare_capacity_mut();
        get_mel_banks_kaldi_scalar(&mut spare[..3], 2, 4, 30.0, 100.0, 50.0);
    }

    #[test]
    fn kaldi_mel_scalar_collapsed_row_is_zero() {
        let num_bins = 4usize;
        let num_fft_bins = 16usize;
        let fft_bin_width = 30.0_f32;
        let mel_low = 100.0_f32;
        let mel_delta = 0.0_f32;
        let out = filled(num_bins * num_fft_bins, |buf| {
            get_mel_banks_kaldi_scalar(
                buf,
                num_bins,
                num_fft_bins,
                fft_bin_width,
                mel_low,
                mel_delta,
            );
        });
        for (i, &v) in out.iter().enumerate() {
            assert_eq!(v, 0.0, "scalar collapsed-row cell i={i} should be 0.0");
        }
    }

    #[test]
    fn kaldi_mel_scalar_collapsed_row_odd_width_is_zero() {
        let num_bins = 3usize;
        let num_fft_bins = 17usize;
        let out = filled(num_bins * num_fft_bins, |buf| {
            // mel_delta = 0.0 collapses every triangle.
            get_mel_banks_kaldi_scalar(buf, num_bins, num_fft_bins, 30.0, 100.0, 0.0);
        });
        assert!(
            out.iter().all(|&v| v == 0.0),
            "odd-width collapsed rows should be all 0.0"
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[should_panic(expected = "overflow usize")]
    fn kaldi_mel_neon_panics_on_dimension_overflow() {
        // The kernel cannot be called without NEON; emit the expected panic so the
        // `should_panic` expectation still holds on such hosts.
        if !crate::simd::is_neon_available() {
            panic!("overflow usize");
        }
        let num_bins = usize::MAX / 4 + 1;
        let num_fft_bins = 4usize;
        let mut out: Vec<f32> = Vec::new();
        let spare = out.spare_capacity_mut();
        // SAFETY: NEON availability was checked above.
        let _ = unsafe {
            super::get_mel_banks_kaldi_neon(spare, num_bins, num_fft_bins, 30.0, 100.0, 50.0)
        };
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[should_panic(expected = "get_mel_banks_kaldi_neon: out.len() (3) must equal num_bins")]
    fn kaldi_mel_neon_panics_on_size_mismatch() {
        // The kernel cannot be called without NEON; emit the expected panic so the
        // `should_panic` expectation still holds on such hosts.
        if !crate::simd::is_neon_available() {
            panic!("get_mel_banks_kaldi_neon: out.len() (3) must equal num_bins");
        }
        let mut out: Vec<f32> = Vec::with_capacity(3);
        let spare = out.spare_capacity_mut();
        // SAFETY: NEON availability was checked above.
        let _ = unsafe { super::get_mel_banks_kaldi_neon(&mut spare[..3], 2, 4, 30.0, 100.0, 50.0) };
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn kaldi_mel_neon_collapsed_row_tail_is_zero() {
        if !crate::simd::is_neon_available() {
            return;
        }
        let num_bins = 3usize;
        // 17 = four 4-lane chunks plus a 1-element scalar tail.
        let num_fft_bins = 17usize;
        let out = filled(num_bins * num_fft_bins, |buf| {
            // SAFETY: NEON availability was checked above.
            let r = unsafe {
                // mel_delta = 0.0 collapses every triangle.
                super::get_mel_banks_kaldi_neon(buf, num_bins, num_fft_bins, 30.0, 100.0, 0.0)
            };
            r.expect("tiny collapsed-row params should not OOM");
        });
        assert!(
            out.iter().all(|&v| v == 0.0),
            "NEON collapsed-row body + tail should be all 0.0"
        );
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn kaldi_mel_neon_odd_width_matches_scalar_tolerance() {
        if !crate::simd::is_neon_available() {
            return;
        }
        let num_bins = 6usize;
        // 17 = four 4-lane chunks plus a 1-element scalar tail.
        let num_fft_bins = 17usize;
        let fft_bin_width = 31.25_f32;
        let mel_low = 1127.0_f32 * (1.0_f32 + 20.0_f32 / 700.0_f32).ln();
        let mel_high = 1127.0_f32 * (1.0_f32 + 7800.0_f32 / 700.0_f32).ln();
        let mel_delta = (mel_high - mel_low) / (num_bins as f32 + 1.0);
        let len = num_bins * num_fft_bins;

        let s_out = filled(len, |buf| {
            get_mel_banks_kaldi_scalar(
                buf,
                num_bins,
                num_fft_bins,
                fft_bin_width,
                mel_low,
                mel_delta,
            );
        });
        let n_out = filled(len, |buf| {
            // SAFETY: NEON availability was checked above.
            let r = unsafe {
                super::get_mel_banks_kaldi_neon(
                    buf,
                    num_bins,
                    num_fft_bins,
                    fft_bin_width,
                    mel_low,
                    mel_delta,
                )
            };
            r.expect("realistic odd-width params should not OOM");
        });

        assert_eq!(s_out.len(), n_out.len(), "odd-width shape parity");
        for (i, (a, b)) in s_out.iter().zip(n_out.iter()).enumerate() {
            let diff = (a - b).abs();
            let tol = 1e-5_f32.max(1e-5_f32 * a.abs());
            assert!(
                diff <= tol,
                "NEON/scalar Tolerance mismatch at i={i}: scalar={a} neon={b} diff={diff} tol={tol}"
            );
        }
    }
}