use core::mem::MaybeUninit;
#[cfg(target_arch = "aarch64")]
use core::arch::aarch64::{
vdupq_n_f32, vld1q_f32, vmaxq_f32, vminq_f32, vmulq_f32, vst1q_f32, vsubq_f32,
};
/// Scalar reference kernel: writes a row-major `n_mels x all_freqs.len()` bank of triangular
/// filters into `out`.
///
/// Filter `m` has its left foot at `f_pts[m]`, apex at `f_pts[m + 1]` and right foot at
/// `f_pts[m + 2]`. Cell `(m, f)` is `max(0, min(rising, falling))` where
/// `rising = (all_freqs[f] - left) / (center - left)` and
/// `falling = (right - all_freqs[f]) / (right - center)`.
/// A filter whose rising or falling width is `<= 0` yields an all-zero row.
/// On normal return every cell of `out` has been initialised.
///
/// # Panics
/// - if `n_mels * all_freqs.len()` overflows `usize`;
/// - if `out.len()` differs from that product;
/// - if `f_pts.len() != n_mels + 2`.
#[inline]
#[doc(hidden)]
pub fn mel_filter_bank_rows_scalar(
    out: &mut [MaybeUninit<f32>],
    all_freqs: &[f32],
    f_pts: &[f32],
    n_mels: usize,
) {
    let n_freqs = all_freqs.len();
    let elements = match n_mels.checked_mul(n_freqs) {
        Some(total) => total,
        None => {
            panic!("mel_filter_bank_rows_scalar: dimensions {n_mels}x{n_freqs} overflow usize")
        }
    };
    assert_eq!(
        out.len(),
        elements,
        "mel_filter_bank_rows_scalar: out.len() ({}) must equal n_mels * n_freqs ({} * {} = {})",
        out.len(),
        n_mels,
        n_freqs,
        elements,
    );
    assert_eq!(
        f_pts.len(),
        n_mels + 2,
        "mel_filter_bank_rows_scalar: f_pts.len() ({}) must equal n_mels + 2 ({})",
        f_pts.len(),
        n_mels + 2,
    );
    for m in 0..n_mels {
        let (left, center, right) = (f_pts[m], f_pts[m + 1], f_pts[m + 2]);
        let rise = center - left;
        let fall = right - center;
        // `m < n_mels` and the checked product above keep this range inside `out`.
        let row = &mut out[m * n_freqs..(m + 1) * n_freqs];
        if rise <= 0.0 || fall <= 0.0 {
            // Degenerate triangle: the whole row is silent.
            row.fill(MaybeUninit::new(0.0));
            continue;
        }
        for (cell, &freq) in row.iter_mut().zip(all_freqs) {
            let up = (freq - left) / rise;
            let down = (right - freq) / fall;
            cell.write(up.min(down).max(0.0));
        }
    }
}
#[cfg(target_arch = "aarch64")]
#[inline]
#[target_feature(enable = "neon")]
/// NEON kernel for the mel filter bank.
///
/// Produces the same row-major `n_mels x all_freqs.len()` triangular bank as
/// `mel_filter_bank_rows_scalar`, processing four frequency bins per vector step. The last
/// `n_freqs % 4` bins of each row use the scalar (division-based) formula.
///
/// Vector lanes compute `(freq - left) * (1 / (center - left))` and
/// `(right - freq) * (1 / (right - center))`: the foot is subtracted *before* scaling.
/// Scaling first and subtracting a pre-scaled foot (`freq * inv - left * inv`) cancels two
/// large, rounded products. It loses accuracy in proportion to `left / (center - left)`,
/// which is large for high, narrow bands.
///
/// All stores go through a raw pointer derived once from `out`. `out` itself is never
/// re-borrowed while that pointer is in use, which keeps the pointer valid under Rust's
/// aliasing rules.
///
/// # Panics
/// - if `n_mels * all_freqs.len()` overflows `usize`;
/// - if `out.len()` differs from that product;
/// - if `f_pts.len() != n_mels + 2`.
///
/// # Safety
/// The caller must ensure the running CPU supports NEON. This function is compiled with
/// `#[target_feature(enable = "neon")]`.
unsafe fn mel_filter_bank_rows_neon(
    out: &mut [MaybeUninit<f32>],
    all_freqs: &[f32],
    f_pts: &[f32],
    n_mels: usize,
) {
    let n_freqs = all_freqs.len();
    let elements = n_mels.checked_mul(n_freqs).unwrap_or_else(|| {
        panic!("mel_filter_bank_rows_neon: dimensions {n_mels}x{n_freqs} overflow usize")
    });
    assert_eq!(
        out.len(),
        elements,
        "mel_filter_bank_rows_neon: out.len() ({}) must equal n_mels * n_freqs ({} * {} = {})",
        out.len(),
        n_mels,
        n_freqs,
        elements,
    );
    assert_eq!(
        f_pts.len(),
        n_mels + 2,
        "mel_filter_bank_rows_neon: f_pts.len() ({}) must equal n_mels + 2 ({})",
        f_pts.len(),
        n_mels + 2,
    );
    // Largest multiple of 4 that fits in a row: the span covered by full 4-lane vectors.
    let body_len = n_freqs - (n_freqs % 4);
    let zero = vdupq_n_f32(0.0);
    // Every write below goes through `dst_base`. Never index `out` after this point: a fresh
    // `&mut` re-borrow of `out` would invalidate this raw pointer for later rows.
    let dst_base = out.as_mut_ptr().cast::<f32>();
    let freq_base = all_freqs.as_ptr();
    // SAFETY:
    // - `out.len() == n_mels * n_freqs` (asserted, and the product cannot overflow), so for
    //   `m < n_mels` the row `[m * n_freqs, (m + 1) * n_freqs)` is in bounds of `out`.
    // - Every offset used below is `< n_freqs`: vector offsets satisfy `k + 4 <= body_len
    //   <= n_freqs`, and tail offsets satisfy `kk < n_freqs`. The same bounds hold for
    //   `freq_base`, whose length is `n_freqs`.
    // - `MaybeUninit<f32>` has the same layout as `f32`. Writing (never reading) through the
    //   cast pointer is therefore valid even for uninitialised cells.
    // - `vld1q_f32`/`vst1q_f32` have no alignment requirement beyond that of `f32`.
    // - NEON availability is the caller's obligation (see `# Safety`).
    unsafe {
        for m in 0..n_mels {
            let left = f_pts[m];
            let center = f_pts[m + 1];
            let right = f_pts[m + 2];
            let lc = center - left;
            let cr = right - center;
            let row = dst_base.add(m * n_freqs);
            if lc <= 0.0 || cr <= 0.0 {
                // Degenerate triangle: zero the whole row, vector body then scalar tail.
                let mut k = 0usize;
                while k + 4 <= body_len {
                    vst1q_f32(row.add(k), zero);
                    k += 4;
                }
                for kk in body_len..n_freqs {
                    row.add(kk).write(0.0);
                }
                continue;
            }
            let left_v = vdupq_n_f32(left);
            let right_v = vdupq_n_f32(right);
            let inv_lc_v = vdupq_n_f32(1.0_f32 / lc);
            let inv_cr_v = vdupq_n_f32(1.0_f32 / cr);
            let mut k = 0usize;
            while k + 4 <= body_len {
                let f_v = vld1q_f32(freq_base.add(k));
                // Subtract the foot first, then scale. This avoids cancellation between two
                // large rounded products.
                let up = vmulq_f32(vsubq_f32(f_v, left_v), inv_lc_v);
                let down = vmulq_f32(vsubq_f32(right_v, f_v), inv_cr_v);
                let v = vmaxq_f32(vminq_f32(up, down), zero);
                vst1q_f32(row.add(k), v);
                k += 4;
            }
            // Scalar tail: identical formula to `mel_filter_bank_rows_scalar`.
            for kk in body_len..n_freqs {
                let freq = all_freqs[kk];
                let up = (freq - left) / lc;
                let down = (right - freq) / cr;
                let v = up.min(down).max(0.0);
                row.add(kk).write(v);
            }
        }
    }
}
/// Fills `out` with a row-major `n_mels x all_freqs.len()` triangular mel filter bank.
/// It picks a backend at runtime.
///
/// Row `m` is the filter with feet at `f_pts[m]` / `f_pts[m + 2]` and apex at `f_pts[m + 1]`.
/// Cell `(m, f)` is its non-negative weight at `all_freqs[f]`. On aarch64 the NEON kernel runs
/// when `crate::simd::is_neon_available()` returns true; otherwise the scalar kernel
/// `mel_filter_bank_rows_scalar` is used.
///
/// Inputs are validated here, before dispatch. Panic messages therefore carry the
/// `simd::audio::mel_filter_bank_rows` prefix whichever backend is selected.
///
/// # Panics
/// - if `n_mels * all_freqs.len()` overflows `usize`;
/// - if `out.len()` differs from that product;
/// - if `f_pts.len() != n_mels + 2`.
#[inline]
#[doc(hidden)]
pub fn mel_filter_bank_rows(
    out: &mut [MaybeUninit<f32>],
    all_freqs: &[f32],
    f_pts: &[f32],
    n_mels: usize,
) {
    let n_freqs = all_freqs.len();
    let elements = match n_mels.checked_mul(n_freqs) {
        Some(product) => product,
        None => panic!(
            "simd::audio::mel_filter_bank_rows: dimensions {n_mels}x{n_freqs} overflow usize"
        ),
    };
    assert_eq!(
        out.len(),
        elements,
        "simd::audio::mel_filter_bank_rows: out.len() ({}) must equal n_mels * n_freqs ({} * {} = {})",
        out.len(),
        n_mels,
        n_freqs,
        elements,
    );
    // One (left, center, right) triple per mel band needs two extra corner points.
    let expected_pts = n_mels + 2;
    assert_eq!(
        f_pts.len(),
        expected_pts,
        "simd::audio::mel_filter_bank_rows: f_pts.len() ({}) must equal n_mels + 2 ({})",
        f_pts.len(),
        expected_pts,
    );

    #[cfg(target_arch = "aarch64")]
    {
        if crate::simd::is_neon_available() {
            // SAFETY: the NEON kernel is compiled with `#[target_feature(enable = "neon")]`.
            // Its only caller obligation is CPU support, which the runtime check above is
            // assumed to establish.
            unsafe { mel_filter_bank_rows_neon(out, all_freqs, f_pts, n_mels) };
            return;
        }
    }

    mel_filter_bank_rows_scalar(out, all_freqs, f_pts, n_mels);
}
#[cfg(test)]
mod tests {
    //! Tests for the mel filter-bank kernels. They cover:
    //! - scalar/dispatch parity and value range;
    //! - hand-computed triangles;
    //! - degenerate (collapsed) triangles, including the non-multiple-of-4 tail;
    //! - precondition panics for every backend.
    use super::{mel_filter_bank_rows, mel_filter_bank_rows_scalar};
    use core::mem::MaybeUninit;

    /// Synthetic inputs:
    /// - `n_freqs` bins evenly spaced over [0, 8000];
    /// - `n_mels + 2` filter corner frequencies evenly spaced over [50, 7500].
    fn make_inputs(n_freqs: usize, n_mels: usize) -> (Vec<f32>, Vec<f32>) {
        let denom = (n_freqs as f32 - 1.0).max(1.0);
        let all_freqs: Vec<f32> = (0..n_freqs)
            .map(|k| 8000.0 * (k as f32) / denom)
            .collect();
        let n_pts = n_mels + 2;
        let pts_denom = (n_pts as f32 - 1.0).max(1.0);
        let f_pts: Vec<f32> = (0..n_pts)
            .map(|i| 50.0 + (7500.0 - 50.0) * (i as f32) / pts_denom)
            .collect();
        (all_freqs, f_pts)
    }

    /// Allocates a `len`-cell buffer, lets `fill` initialise it, and returns it as a `Vec`.
    fn collect_bank(len: usize, fill: impl FnOnce(&mut [MaybeUninit<f32>])) -> Vec<f32> {
        let mut out: Vec<f32> = Vec::with_capacity(len);
        fill(&mut out.spare_capacity_mut()[..len]);
        // SAFETY: `len <= capacity`. Every kernel under test writes all `len` cells, or
        // panics, in which case this line is never reached.
        unsafe { out.set_len(len) };
        out
    }

    /// Inputs with degenerate triangles:
    /// - `n_freqs` bins at 0, 100, 200, ...;
    /// - corner points `[0, 500, 500, 1500, 2000]`.
    ///
    /// Row 0 has a zero-width falling edge and row 1 a zero-width rising edge, so both rows
    /// must be all zeros.
    fn collapsed_inputs(n_freqs: usize) -> (Vec<f32>, Vec<f32>) {
        let all_freqs: Vec<f32> = (0..n_freqs).map(|k| 100.0 * k as f32).collect();
        let f_pts = vec![0.0_f32, 500.0, 500.0, 1500.0, 2000.0];
        (all_freqs, f_pts)
    }

    /// Asserts that rows 0 and 1 of a bank built from `collapsed_inputs` are all zeros.
    fn assert_collapsed_rows_zero(bank: &[f32], n_freqs: usize, label: &str) {
        for m in 0..2 {
            for k in 0..n_freqs {
                assert_eq!(
                    bank[m * n_freqs + k],
                    0.0,
                    "{label}: collapsed row m={m} cell k={k} should be 0.0"
                );
            }
        }
    }

    /// Asserts `bank` equals the triangle with corners (0, 200, 400) sampled at 0..=500 step 100.
    fn assert_hand_triangle(bank: &[f32], label: &str) {
        let expected = [0.0_f32, 0.5, 1.0, 0.5, 0.0, 0.0];
        assert_eq!(bank.len(), expected.len(), "{label}: length mismatch");
        for (i, (&got, &want)) in bank.iter().zip(expected.iter()).enumerate() {
            assert!(
                (got - want).abs() <= 1e-6,
                "{label} i={i}: got={got} want={want}"
            );
        }
    }

    fn bank_via_scalar(n_freqs: usize, n_mels: usize) -> Vec<f32> {
        let (all_freqs, f_pts) = make_inputs(n_freqs, n_mels);
        collect_bank(n_mels * n_freqs, |buf| {
            mel_filter_bank_rows_scalar(buf, &all_freqs, &f_pts, n_mels)
        })
    }

    fn bank_via_dispatch(n_freqs: usize, n_mels: usize) -> Vec<f32> {
        let (all_freqs, f_pts) = make_inputs(n_freqs, n_mels);
        collect_bank(n_mels * n_freqs, |buf| {
            mel_filter_bank_rows(buf, &all_freqs, &f_pts, n_mels)
        })
    }

    #[test]
    fn mel_filter_bank_scalar_matches_dispatcher_tolerance() {
        let n_mels = 8usize;
        for &n_freqs in &[5usize, 8, 16, 17, 64, 201, 257] {
            let s = bank_via_scalar(n_freqs, n_mels);
            let d = bank_via_dispatch(n_freqs, n_mels);
            assert_eq!(s.len(), d.len(), "shape parity at n_freqs={n_freqs}");
            for (i, (a, b)) in s.iter().zip(d.iter()).enumerate() {
                let diff = (a - b).abs();
                let tol = 1e-6_f32.max(1e-6_f32 * a.abs());
                assert!(
                    diff <= tol,
                    "Tolerance mismatch at n_freqs={n_freqs} i={i}: scalar={a} dispatcher={b} \
                     diff={diff} tol={tol}"
                );
            }
        }
    }

    #[test]
    fn mel_filter_bank_triangle_shape() {
        let n_mels = 4usize;
        let n_freqs = 65;
        let bank = bank_via_dispatch(n_freqs, n_mels);
        for (m, row) in bank.chunks_exact(n_freqs).enumerate() {
            for (i, &v) in row.iter().enumerate() {
                assert!(
                    (0.0..=1.0001).contains(&v),
                    "cell out of [0, 1]: m={m} i={i} v={v}"
                );
            }
        }
    }

    #[test]
    fn mel_filter_bank_collapsed_row_is_zero() {
        let n_mels = 3;
        let n_freqs = 16;
        let (all_freqs, f_pts) = collapsed_inputs(n_freqs);
        let out = collect_bank(n_mels * n_freqs, |buf| {
            mel_filter_bank_rows(buf, &all_freqs, &f_pts, n_mels)
        });
        assert_collapsed_rows_zero(&out, n_freqs, "dispatch");
    }

    #[test]
    #[should_panic(
        expected = "simd::audio::mel_filter_bank_rows: out.len() (3) must equal n_mels * n_freqs"
    )]
    fn mel_filter_bank_panics_on_size_mismatch() {
        let all_freqs = [100.0_f32, 200.0, 300.0, 400.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 3];
        mel_filter_bank_rows(&mut out, &all_freqs, &f_pts, 2);
    }

    #[test]
    #[should_panic(expected = "f_pts.len() (5) must equal n_mels + 2 (4)")]
    fn mel_filter_bank_panics_on_f_pts_size_mismatch() {
        let all_freqs = [100.0_f32, 200.0, 300.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0, 800.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 6];
        mel_filter_bank_rows(&mut out, &all_freqs, &f_pts, 2);
    }

    #[test]
    #[should_panic(expected = "overflow usize")]
    fn mel_filter_bank_panics_on_dimension_overflow() {
        let n_mels = usize::MAX / 4 + 1;
        let all_freqs = [0.0_f32; 4];
        let f_pts = [0.0_f32; 4];
        let mut out: [MaybeUninit<f32>; 0] = [];
        mel_filter_bank_rows(&mut out, &all_freqs, &f_pts, n_mels);
    }

    #[test]
    fn mel_filter_bank_scalar_matches_handcomputed_triangle() {
        let n_mels = 1usize;
        let all_freqs = [0.0_f32, 100.0, 200.0, 300.0, 400.0, 500.0];
        let f_pts = [0.0_f32, 200.0, 400.0];
        let out = collect_bank(n_mels * all_freqs.len(), |buf| {
            mel_filter_bank_rows_scalar(buf, &all_freqs, &f_pts, n_mels)
        });
        assert_hand_triangle(&out, "hand-computed triangle mismatch");
    }

    #[test]
    fn mel_filter_bank_scalar_collapsed_row_is_zero() {
        let n_mels = 3usize;
        let n_freqs = 6usize;
        let (all_freqs, f_pts) = collapsed_inputs(n_freqs);
        let out = collect_bank(n_mels * n_freqs, |buf| {
            mel_filter_bank_rows_scalar(buf, &all_freqs, &f_pts, n_mels)
        });
        assert_collapsed_rows_zero(&out, n_freqs, "scalar");
    }

    #[test]
    #[should_panic(
        expected = "mel_filter_bank_rows_scalar: out.len() (3) must equal n_mels * n_freqs"
    )]
    fn mel_filter_bank_scalar_panics_on_size_mismatch() {
        let all_freqs = [100.0_f32, 200.0, 300.0, 400.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 3];
        mel_filter_bank_rows_scalar(&mut out, &all_freqs, &f_pts, 2);
    }

    #[test]
    #[should_panic(
        expected = "mel_filter_bank_rows_scalar: f_pts.len() (5) must equal n_mels + 2 (4)"
    )]
    fn mel_filter_bank_scalar_panics_on_f_pts_size_mismatch() {
        let all_freqs = [100.0_f32, 200.0, 300.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0, 800.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 6];
        mel_filter_bank_rows_scalar(&mut out, &all_freqs, &f_pts, 2);
    }

    #[test]
    #[should_panic(expected = "overflow usize")]
    fn mel_filter_bank_scalar_panics_on_dimension_overflow() {
        let n_mels = usize::MAX / 4 + 1;
        let all_freqs = [0.0_f32; 4];
        let f_pts = [0.0_f32; 4];
        let mut out: [MaybeUninit<f32>; 0] = [];
        mel_filter_bank_rows_scalar(&mut out, &all_freqs, &f_pts, n_mels);
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn mel_filter_bank_neon_collapsed_row_with_tail_is_zero() {
        if !crate::simd::is_neon_available() {
            return;
        }
        let n_mels = 3usize;
        // 6 bins = one 4-lane vector + a 2-bin scalar tail.
        let n_freqs = 6usize;
        let (all_freqs, f_pts) = collapsed_inputs(n_freqs);
        let out = collect_bank(n_mels * n_freqs, |buf| {
            // SAFETY: NEON availability was checked above.
            unsafe { super::mel_filter_bank_rows_neon(buf, &all_freqs, &f_pts, n_mels) }
        });
        assert_collapsed_rows_zero(&out, n_freqs, "NEON (incl. tail)");
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    fn mel_filter_bank_neon_matches_handcomputed_triangle_with_tail() {
        if !crate::simd::is_neon_available() {
            return;
        }
        let n_mels = 1usize;
        let all_freqs = [0.0_f32, 100.0, 200.0, 300.0, 400.0, 500.0];
        let f_pts = [0.0_f32, 200.0, 400.0];
        let out = collect_bank(n_mels * all_freqs.len(), |buf| {
            // SAFETY: NEON availability was checked above.
            unsafe { super::mel_filter_bank_rows_neon(buf, &all_freqs, &f_pts, n_mels) }
        });
        assert_hand_triangle(&out, "NEON hand-computed triangle mismatch");
    }

    // The three NEON panic tests below cannot be skipped normally under `#[should_panic]`.
    // When NEON is unavailable they panic with a message that still matches `expected`.

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[should_panic(
        expected = "mel_filter_bank_rows_neon: out.len() (3) must equal n_mels * n_freqs"
    )]
    fn mel_filter_bank_neon_panics_on_size_mismatch() {
        if !crate::simd::is_neon_available() {
            panic!(
                "mel_filter_bank_rows_neon: out.len() (3) must equal n_mels * n_freqs (skipped — NEON unavailable)"
            );
        }
        let all_freqs = [100.0_f32, 200.0, 300.0, 400.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 3];
        // SAFETY: NEON availability was checked above.
        unsafe { super::mel_filter_bank_rows_neon(&mut out, &all_freqs, &f_pts, 2) };
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[should_panic(
        expected = "mel_filter_bank_rows_neon: f_pts.len() (5) must equal n_mels + 2 (4)"
    )]
    fn mel_filter_bank_neon_panics_on_f_pts_size_mismatch() {
        if !crate::simd::is_neon_available() {
            panic!(
                "mel_filter_bank_rows_neon: f_pts.len() (5) must equal n_mels + 2 (4) (skipped — NEON unavailable)"
            );
        }
        let all_freqs = [100.0_f32, 200.0, 300.0];
        let f_pts = [0.0_f32, 200.0, 400.0, 600.0, 800.0];
        let mut out = [MaybeUninit::<f32>::uninit(); 6];
        // SAFETY: NEON availability was checked above.
        unsafe { super::mel_filter_bank_rows_neon(&mut out, &all_freqs, &f_pts, 2) };
    }

    #[cfg(target_arch = "aarch64")]
    #[test]
    #[should_panic(expected = "overflow usize")]
    fn mel_filter_bank_neon_panics_on_dimension_overflow() {
        if !crate::simd::is_neon_available() {
            panic!("dimensions overflow usize (skipped — NEON unavailable)");
        }
        let n_mels = usize::MAX / 4 + 1;
        let all_freqs = [0.0_f32; 4];
        let f_pts = [0.0_f32; 4];
        let mut out: [MaybeUninit<f32>; 0] = [];
        // SAFETY: NEON availability was checked above.
        unsafe { super::mel_filter_bank_rows_neon(&mut out, &all_freqs, &f_pts, n_mels) };
    }

    #[test]
    fn mel_filter_bank_dispatch_collapsed_row_with_tail_is_zero() {
        let n_mels = 3usize;
        let n_freqs = 6usize;
        let (all_freqs, f_pts) = collapsed_inputs(n_freqs);
        let out = collect_bank(n_mels * n_freqs, |buf| {
            mel_filter_bank_rows(buf, &all_freqs, &f_pts, n_mels)
        });
        assert_collapsed_rows_zero(&out, n_freqs, "dispatch (incl. tail)");
    }
}