use smol_str::format_smolstr;
use std::f32::consts::PI;
use crate::{
Array, Error, Result,
error::{
AllocFailurePayload, ArithmeticOverflowPayload, CapExceededPayload, EmptyInputPayload,
InvariantViolationPayload, OutOfRangePayload, RankMismatchPayload,
},
ops::{
self,
fft::{self, FftNorm},
},
};
/// Multiplicative factor of the Kaldi mel mapping used by `mel_scale_kaldi`
/// and `inverse_mel_scale_kaldi` (`mel = SCALE * ln(1 + hz / BREAK)`).
const KALDI_MEL_SCALE: f32 = 1127.0;
/// Frequency (Hz) that divides `hz` inside the logarithm of the Kaldi mel mapping.
const KALDI_MEL_HZ_BREAK: f32 = 700.0;
/// Lower clamp applied to the mel energies before the logarithm in `compute_fbank_kaldi`.
const KALDI_FBANK_LOG_FLOOR: f32 = 1e-8;
/// Upper bound (64 Mi) on the element counts this module checks before building buffers.
const MAX_FBANK_WORK: usize = 64 * 1024 * 1024;
/// Largest `win_length` accepted by `compute_deltas_kaldi`.
const MAX_DELTA_WIN_LENGTH: usize = 1024;
/// Upper bound (512 Mi) on `total * (win_length - 1)` in `compute_deltas_kaldi`.
const MAX_DELTA_WORK: usize = 512 * 1024 * 1024;
/// Maps a frequency in hertz onto the Kaldi mel scale.
///
/// Evaluates `KALDI_MEL_SCALE * ln(1 + hz / KALDI_MEL_HZ_BREAK)`.
#[inline]
#[must_use]
pub fn mel_scale_kaldi(hz: f32) -> f32 {
    let warped = 1.0 + hz / KALDI_MEL_HZ_BREAK;
    KALDI_MEL_SCALE * warped.ln()
}
/// Maps a Kaldi mel value back to a frequency in hertz.
///
/// Evaluates `KALDI_MEL_HZ_BREAK * (exp(mel / KALDI_MEL_SCALE) - 1)`, the
/// algebraic inverse of `mel_scale_kaldi`.
#[inline]
#[must_use]
pub fn inverse_mel_scale_kaldi(mel: f32) -> f32 {
    let exponent = mel / KALDI_MEL_SCALE;
    KALDI_MEL_HZ_BREAK * (exponent.exp() - 1.0)
}
/// Smallest power of two that is `>= x`, with `0` mapped to `1`.
#[inline]
fn next_power_of_2(x: usize) -> usize {
    match x {
        0 => 1,
        nonzero => nonzero.next_power_of_two(),
    }
}
pub fn get_mel_banks_kaldi(
num_bins: usize,
n_fft_padded: usize,
sample_freq: f32,
low_freq: f32,
high_freq: f32,
) -> Result<(Array, Array)> {
if num_bins <= 3 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: num_bins",
"must be > 3",
format!("{num_bins}"),
)));
}
if n_fft_padded == 0 || !n_fft_padded.is_multiple_of(2) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: n_fft_padded",
"must be a positive even number",
format!("{n_fft_padded}"),
)));
}
if !(sample_freq.is_finite() && sample_freq > 0.0) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: sample_freq",
"must be a finite value > 0.0",
format!("{sample_freq}"),
)));
}
let nyquist = 0.5 * sample_freq;
let high_freq = if high_freq <= 0.0 {
high_freq + nyquist
} else {
high_freq
};
if !(low_freq >= 0.0 && low_freq < nyquist) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: low_freq",
"must satisfy 0 <= low_freq < nyquist",
format_smolstr!("low_freq={low_freq}, nyquist={nyquist}"),
)));
}
if !(high_freq > 0.0 && high_freq <= nyquist) {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: high_freq",
"must satisfy 0 < high_freq <= nyquist",
format_smolstr!("high_freq={high_freq}, nyquist={nyquist}"),
)));
}
if low_freq >= high_freq {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: low_freq",
"must be < high_freq",
format_smolstr!("low_freq={low_freq}, high_freq={high_freq}"),
)));
}
let num_fft_bins = n_fft_padded / 2; let bank_len = num_bins.checked_mul(num_fft_bins).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"get_mel_banks_kaldi: num_bins * num_fft_bins",
"usize",
[
("num_bins", num_bins as u64),
("num_fft_bins", num_fft_bins as u64),
],
))
})?;
if bank_len > MAX_FBANK_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"get_mel_banks_kaldi: bank_len (= num_bins * num_fft_bins) exceeds work cap",
"MAX_FBANK_WORK",
MAX_FBANK_WORK as u64,
bank_len as u64,
)));
}
let num_bins_i32 = i32::try_from(num_bins).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: num_bins",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{num_bins}"),
))
})?;
let num_fft_bins_i32 = i32::try_from(num_fft_bins).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"get_mel_banks_kaldi: num_fft_bins",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{num_fft_bins}"),
))
})?;
let fft_bin_width = sample_freq / n_fft_padded as f32;
let mel_low = mel_scale_kaldi(low_freq);
let mel_high = mel_scale_kaldi(high_freq);
let mel_delta = (mel_high - mel_low) / (num_bins as f32 + 1.0);
let mut bank: Vec<f32> = Vec::new();
bank.try_reserve_exact(bank_len).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"get_mel_banks_kaldi: bank reservation",
"f32 elements",
bank_len as u64,
e,
))
})?;
let mut centers: Vec<f32> = Vec::new();
centers.try_reserve_exact(num_bins).map_err(|e| {
Error::AllocFailure(AllocFailurePayload::new(
"get_mel_banks_kaldi: centers reservation",
"f32 center freqs",
num_bins as u64,
e,
))
})?;
for m in 0..num_bins {
let center_mel = mel_low + ((m + 1) as f32) * mel_delta;
centers.push(inverse_mel_scale_kaldi(center_mel));
}
let spare = bank.spare_capacity_mut();
crate::simd::audio::kaldi_mel::get_mel_banks_kaldi_rows(
&mut spare[..bank_len],
num_bins,
num_fft_bins,
fft_bin_width,
mel_low,
mel_delta,
)?;
unsafe { bank.set_len(bank_len) };
let bins = Array::from_slice::<f32>(&bank, &[num_bins_i32, num_fft_bins_i32])?;
let center_freqs = Array::from_slice::<f32>(¢ers, &[num_bins_i32])?;
Ok((bins, center_freqs))
}
/// Window function selector for `compute_fbank_kaldi`.
///
/// `Display` prints the lowercase name returned by `as_str`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Default, derive_more::Display, derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
pub enum KaldiWindow {
/// Default variant; samples come from `crate::simd::audio::window::kaldi_window`.
#[default]
Hamming,
/// Samples come from `crate::simd::audio::window::kaldi_window`.
Hanning,
/// Computed in `build_kaldi_window` as `(0.5 - 0.5 * cos(2*pi*n / (N - 1)))^0.85`.
Povey,
/// Samples come from `crate::simd::audio::window::kaldi_window`.
Rectangular,
}
impl KaldiWindow {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Hamming => "hamming",
Self::Hanning => "hanning",
Self::Povey => "povey",
Self::Rectangular => "rectangular",
}
}
}
/// Materialises the `win_size`-sample analysis window for `win_type` as a 1-D `Array`.
///
/// Hamming, Hanning and Rectangular windows are produced by
/// `crate::simd::audio::window::kaldi_window`; the Povey window is computed
/// here as `(0.5 - 0.5 * cos(2*pi*n / (win_size - 1)))^0.85`.
///
/// # Errors
/// `Error::OutOfRange` when `win_size < 2` or it does not fit in `i32`,
/// `Error::AllocFailure` when the Povey buffer cannot be reserved, plus any
/// error from the SIMD helper or `Array::from_slice`.
fn build_kaldi_window(win_type: KaldiWindow, win_size: usize) -> Result<Array> {
    if win_size < 2 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "build_kaldi_window: win_size",
            "must be >= 2",
            format!("{win_size}"),
        )));
    }
    let win_i32 = i32::try_from(win_size).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "build_kaldi_window: win_size",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{win_size}"),
        ))
    })?;

    // `Some(kind)` -> delegate to the SIMD helper; `None` -> Povey, computed inline.
    let simd_kind = match win_type {
        KaldiWindow::Hamming => Some(crate::simd::audio::window::KaldiWindowKind::Hamming),
        KaldiWindow::Hanning => Some(crate::simd::audio::window::KaldiWindowKind::Hanning),
        KaldiWindow::Rectangular => {
            Some(crate::simd::audio::window::KaldiWindowKind::Rectangular)
        }
        KaldiWindow::Povey => None,
    };

    let samples: Vec<f32> = match simd_kind {
        Some(kind) => crate::simd::audio::window::kaldi_window(kind, win_size)?,
        None => {
            let mut povey: Vec<f32> = Vec::new();
            povey.try_reserve_exact(win_size).map_err(|e| {
                Error::AllocFailure(AllocFailurePayload::new(
                    "build_kaldi_window: Povey reservation",
                    "f32 elements",
                    win_size as u64,
                    e,
                ))
            })?;
            // `win_size >= 2` was checked above, so the divisor is non-zero.
            let span = (win_size - 1) as f32;
            povey.extend((0..win_size).map(|n| {
                let phase = 2.0 * PI * (n as f32) / span;
                (0.5 - 0.5 * phase.cos()).powf(0.85)
            }));
            povey
        }
    };
    Array::from_slice::<f32>(&samples, &[win_i32])
}
fn strided_frames_snip_edges(
waveform: &Array,
win_size: usize,
win_inc: usize,
num_frames: usize,
) -> Result<Array> {
let last_index = (num_frames - 1)
.checked_mul(win_inc)
.and_then(|v| v.checked_add(win_size))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"strided_frames_snip_edges: reachable element range \
((num_frames - 1) * win_inc + win_size)",
"usize",
[
("num_frames", num_frames as u64),
("win_inc", win_inc as u64),
("win_size", win_size as u64),
],
))
})?;
let waveform_len = waveform.shape()[0];
if last_index > waveform_len {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_snip_edges: derived frame reach \
(internal invariant violated)",
"must be <= waveform.len()",
format_smolstr!(
"last_index={last_index}, waveform_len={waveform_len}, num_frames={num_frames}, \
win_inc={win_inc}, win_size={win_size}"
),
)));
}
let num_frames_i32 = i32::try_from(num_frames).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_snip_edges: num_frames",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{num_frames}"),
))
})?;
let win_size_i32 = i32::try_from(win_size).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_snip_edges: win_size",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{win_size}"),
))
})?;
let win_inc_i64 = i64::try_from(win_inc).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_snip_edges: win_inc",
"must fit in i64 (i64::MAX = 9223372036854775807)",
format_smolstr!("{win_inc}"),
))
})?;
let shape: &[i32] = &[num_frames_i32, win_size_i32];
unsafe { ops::shape::as_strided(waveform, &shape, &[win_inc_i64, 1], 0) }
}
/// Returns the elements of the 1-D array `a` in reverse order.
///
/// Implemented as a single negative-step slice from index `len - 1` down to
/// the sentinel `-(len + 1)`.
///
/// # Errors
/// `Error::RankMismatch` if `a` is not 1-D, `Error::EmptyInput` if it has no
/// elements, `Error::OutOfRange` if its length does not fit in `i32`, plus any
/// error from `ops::indexing::slice`.
fn reverse_1d(a: &Array) -> Result<Array> {
    let dims = a.shape();
    let rank = dims.len();
    if rank != 1 {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "reverse_1d: expected 1-D input",
            rank as u32,
            dims,
        )));
    }
    let count = dims[0];
    if count == 0 {
        return Err(Error::EmptyInput(EmptyInputPayload::new(
            "reverse_1d: array",
        )));
    }
    let count_i32 = i32::try_from(count).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "reverse_1d: len",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{count}"),
        ))
    })?;

    // Computed in i64 so that `count + 1` cannot overflow before the negation.
    let sentinel_i64 = -(i64::from(count_i32) + 1);
    let stop = i32::try_from(sentinel_i64).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "reverse_1d: reverse sentinel -(len + 1)",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{sentinel_i64}"),
        ))
    })?;
    let start = count_i32 - 1;
    ops::indexing::slice(a, &[start], &[stop], &[-1])
}
/// Builds a `[num_frames, win_size]` strided view for `snip_edges = false`
/// framing, over a padded copy of `waveform`.
///
/// With `pad = win_size / 2 - win_inc / 2`:
/// * `pad > 0`: the buffer is `reverse(waveform[1..pad + 1])`, then `waveform`,
///   then a reversed tail slice of `waveform`.
/// * `pad <= 0`: the buffer is `waveform[|pad|..]` followed by the reversed
///   full waveform.
///
/// Frame `i` of the result starts at element `i * win_inc` of that buffer.
/// When `num_frames == 0` an empty `[0, 0]` f32 array is returned.
///
/// # Errors
/// * `Error::RankMismatch` if `waveform` is not 1-D.
/// * `Error::OutOfRange` if the waveform is too short for the requested
///   padding or framing, or if a dimension does not fit its integer type.
/// * `Error::ArithmeticOverflow` / `Error::CapExceeded` from the size checks.
fn strided_frames_no_snip_edges(
waveform: &Array,
win_size: usize,
win_inc: usize,
num_frames: usize,
) -> Result<Array> {
let shape = waveform.shape();
if shape.len() != 1 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"strided_frames_no_snip_edges: expected 1-D waveform",
shape.len() as u32,
shape,
)));
}
let n = waveform.shape()[0];
let n_i32 = i32::try_from(n).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: waveform len",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{n}"),
))
})?;
if num_frames == 0 {
return Array::zeros::<f32>(&[0_i32, 0_i32]);
}
// Signed because a hop larger than the window makes the pad negative.
let pad_i64 = (win_size as i64) / 2 - (win_inc as i64) / 2;
/// Sums the lengths of the segments about to be concatenated and rejects
/// the request if the sum overflows or exceeds `MAX_FBANK_WORK`.
/// `n` and `pad` are only used to annotate the overflow error.
fn cap_reflected_len(seg_lens: &[usize], n: usize, pad: i64) -> Result<()> {
let mut reflected_len: usize = 0;
for &seg in seg_lens {
reflected_len = reflected_len.checked_add(seg).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"strided_frames_no_snip_edges: reflect-padded length \
(sum of concatenated segment lengths)",
"usize",
[("n", n as u64), ("pad", pad as u64), ("seg", seg as u64)],
))
})?;
}
if reflected_len > MAX_FBANK_WORK {
return Err(Error::CapExceeded(CapExceededPayload::new(
"strided_frames_no_snip_edges: reflect-padded buffer length exceeds work cap \
(snip_edges=false reflect bookends would more than double the waveform's memory)",
"MAX_FBANK_WORK",
MAX_FBANK_WORK as u64,
reflected_len as u64,
)));
}
Ok(())
}
let padded = if pad_i64 > 0 {
let pad = pad_i64 as usize;
// The left slice reads indices 1..=pad, so at least pad + 1 samples are needed.
if n < pad + 1 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: waveform len for reflect-pad \
(win_size/win_inc imply more reflection than the signal supports)",
"must be >= pad + 1",
format_smolstr!("n={n}, pad={pad}"),
)));
}
let pad_i32 = i32::try_from(pad).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: pad",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{pad}"),
))
})?;
// Left bookend: waveform[1..pad + 1], reversed below.
let left_lo = 1_i32;
let left_hi = pad_i32 + 1;
// Right bookend: the last `pad` samples when pad > 1; for pad == 1 the slice
// is waveform[1..n] instead.
// NOTE(review): the pad == 1 branch presumably mirrors a quirk of the
// reference implementation - confirm before changing.
let left_len = (left_hi - left_lo) as usize; let (right_lo, right_hi) = if pad > 1 {
(n_i32 - pad_i32, n_i32)
} else {
(1_i32, n_i32)
};
// Size check runs before any slice/concatenate op is issued.
let right_len = (right_hi - right_lo) as usize; cap_reflected_len(&[left_len, n, right_len], n, pad_i64)?;
let left_seg = ops::indexing::slice(waveform, &[left_lo], &[left_hi], &[1_i32])?;
let pad_left = reverse_1d(&left_seg)?;
let right_seg = ops::indexing::slice(waveform, &[right_lo], &[right_hi], &[1_i32])?;
let pad_right = reverse_1d(&right_seg)?;
ops::shape::concatenate(&[&pad_left, waveform, &pad_right], 0)?
} else {
// pad <= 0: drop the first |pad| samples and append the reversed waveform.
let abs_pad = (-pad_i64) as usize;
if abs_pad > n {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: |pad| for snip_edges=false buffer \
(win_inc too large relative to win_size)",
"must be <= waveform len",
format_smolstr!("abs_pad={abs_pad}, n={n}"),
)));
}
let abs_pad_i32 = i32::try_from(abs_pad).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: |pad|",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{abs_pad}"),
))
})?;
let head_len = n - abs_pad;
cap_reflected_len(&[head_len, n], n, pad_i64)?;
let head = ops::indexing::slice(waveform, &[abs_pad_i32], &[n_i32], &[1_i32])?;
let rev = reverse_1d(waveform)?;
ops::shape::concatenate(&[&head, &rev], 0)?
};
let padded_len = padded.shape()[0];
// One past the highest buffer index the strided view can touch.
// `num_frames >= 1` here because the zero case returned early above.
let last_index = (num_frames - 1)
.checked_mul(win_inc)
.and_then(|v| v.checked_add(win_size))
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"strided_frames_no_snip_edges: reachable element range \
((num_frames - 1) * win_inc + win_size)",
"usize",
[
("num_frames", num_frames as u64),
("win_inc", win_inc as u64),
("win_size", win_size as u64),
],
))
})?;
if last_index > padded_len {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: strided read end \
(win_size too large relative to signal length for centered snip_edges=false framing; \
reference would read out of bounds)",
"must be <= reflect-padded length",
format_smolstr!(
"last_index={last_index}, padded_len={padded_len}, num_frames={num_frames}, \
win_inc={win_inc}, win_size={win_size}, waveform_len={n}"
),
)));
}
let num_frames_i32 = i32::try_from(num_frames).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: num_frames",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{num_frames}"),
))
})?;
let win_size_i32 = i32::try_from(win_size).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: win_size",
"must fit in i32 (i32::MAX = 2147483647)",
format_smolstr!("{win_size}"),
))
})?;
let win_inc_i64 = i64::try_from(win_inc).map_err(|_| {
Error::OutOfRange(OutOfRangePayload::new(
"strided_frames_no_snip_edges: win_inc",
"must fit in i64 (i64::MAX = 9223372036854775807)",
format_smolstr!("{win_inc}"),
))
})?;
let view_shape: &[i32] = &[num_frames_i32, win_size_i32];
// SAFETY: the `last_index <= padded_len` check above bounds every element
// offset the view can address.
// NOTE(review): assumes the concatenated `padded` buffer has unit stride -
// confirm against `as_strided`'s documented contract.
unsafe { ops::shape::as_strided(&padded, &view_shape, &[win_inc_i64, 1], 0) }
}
/// Computes Kaldi-style log mel filterbank features for a 1-D waveform.
///
/// Processing steps, in order:
/// 1. split the waveform into frames (`snip_edges` selects the framing helper),
/// 2. add Gaussian dither with standard deviation `dither` (when `dither > 0`),
/// 3. subtract each frame's mean,
/// 4. apply pre-emphasis (when `preemphasis > 0`): the first sample of a frame
///    is scaled by `1 - preemphasis`, every other sample becomes
///    `x[t] - preemphasis * x[t - 1]`,
/// 5. multiply by the analysis window,
/// 6. zero-pad each frame to the next power of two and take the real FFT,
/// 7. square the magnitudes and project them through the mel filterbank,
/// 8. clamp to `KALDI_FBANK_LOG_FLOOR` and take the natural logarithm.
///
/// # Parameters
/// * `waveform` - 1-D signal with at most `MAX_DECODED_SAMPLES` samples.
/// * `sample_rate` - sampling rate in Hz; must be `> 0`.
/// * `win_len` - frame length in samples; must be `>= 2`.
/// * `win_inc` - hop length in samples; must be `> 0`.
/// * `num_mels` - number of mel bins (validated by `get_mel_banks_kaldi`).
/// * `win_type` - analysis window.
/// * `preemphasis` - finite coefficient in `[0.0, 1.0]`.
/// * `dither` - finite noise scale `>= 0.0`; `0.0` disables dithering.
/// * `snip_edges` - `true`: `1 + (len - win_len) / win_inc` frames, none when
///   the signal is shorter than `win_len`; `false`:
///   `(len + win_inc / 2) / win_inc` frames over a padded copy of the signal.
/// * `low_freq`, `high_freq` - mel band edges, see `get_mel_banks_kaldi`.
/// * `dither_key` - random key; required whenever `dither != 0.0`.
///
/// # Returns
/// The log mel energies, one row per frame and one column per mel bin. When
/// no frame fits, an empty `[0, num_mels]` f32 array is returned before the
/// mel parameters are validated.
///
/// # Errors
/// `Error::RankMismatch`, `Error::InvariantViolation`, `Error::OutOfRange`,
/// `Error::CapExceeded` and `Error::ArithmeticOverflow` from the argument and
/// size checks, plus any error propagated from the array operations.
#[allow(clippy::too_many_arguments)]
pub fn compute_fbank_kaldi(
    waveform: &Array,
    sample_rate: u32,
    win_len: usize,
    win_inc: usize,
    num_mels: usize,
    win_type: KaldiWindow,
    preemphasis: f32,
    dither: f32,
    snip_edges: bool,
    low_freq: f32,
    high_freq: f32,
    dither_key: Option<&Array>,
) -> Result<Array> {
    let shape = waveform.shape();
    if shape.len() != 1 {
        let rank = shape.len() as u32;
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "compute_fbank_kaldi: expected 1-D waveform",
            rank,
            shape,
        )));
    }
    if sample_rate == 0 {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "compute_fbank_kaldi: sample_rate",
            "must be > 0",
        )));
    }
    if win_len < 2 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: win_len",
            "must be >= 2",
            format!("{win_len}"),
        )));
    }
    if win_inc == 0 {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "compute_fbank_kaldi: win_inc",
            "must be > 0",
        )));
    }
    if win_len > crate::audio::io::MAX_DECODED_SAMPLES {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: win_len exceeds cap",
            "MAX_DECODED_SAMPLES",
            crate::audio::io::MAX_DECODED_SAMPLES as u64,
            win_len as u64,
        )));
    }
    if !dither.is_finite() || dither < 0.0 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: dither",
            "must be finite and >= 0.0",
            format!("{dither}"),
        )));
    }
    if !(preemphasis.is_finite() && (0.0..=1.0).contains(&preemphasis)) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: preemphasis",
            "must be a finite float in [0.0, 1.0]",
            format!("{preemphasis}"),
        )));
    }
    if dither != 0.0 && dither_key.is_none() {
        return Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "compute_fbank_kaldi: dither_key when dither != 0.0 \
             (use crate::ops::random::key(seed) or pass dither=0.0 to disable; \
             the Python reference's implicit-default-key behavior is deliberately not mirrored — \
             explicit keys make dithered features reproducible)",
            "must be Some(_) when dither != 0.0",
        )));
    }
    let samples_len = shape[0];
    if samples_len > crate::audio::io::MAX_DECODED_SAMPLES {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: samples_len exceeds cap \
             (rejecting BEFORE `contiguous` would materialize the logical extent — \
             a broadcasted-view input could otherwise drive a multi-GB allocation at eval time)",
            "MAX_DECODED_SAMPLES",
            crate::audio::io::MAX_DECODED_SAMPLES as u64,
            samples_len as u64,
        )));
    }
    let num_mels_i32 = i32::try_from(num_mels).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: num_mels",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{num_mels}"),
        ))
    })?;

    // Frame count; an input too short for a single frame yields an empty result.
    let num_frames = if snip_edges {
        if samples_len < win_len {
            return Array::zeros::<f32>(&[0_i32, num_mels_i32]);
        }
        1 + (samples_len - win_len) / win_inc
    } else {
        let m = (samples_len + win_inc / 2) / win_inc;
        if m == 0 {
            return Array::zeros::<f32>(&[0_i32, num_mels_i32]);
        }
        m
    };

    // Every intermediate buffer size is checked against the work cap before
    // any array operation is issued.
    let n_fft_padded = next_power_of_2(win_len);
    let frame_work = num_frames.checked_mul(n_fft_padded).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_fbank_kaldi: frame work num_frames * n_fft_padded",
            "usize",
            [
                ("num_frames", num_frames as u64),
                ("n_fft_padded", n_fft_padded as u64),
            ],
        ))
    })?;
    if frame_work > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: frame work (= num_frames * n_fft_padded) exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            frame_work as u64,
        )));
    }
    let out_elems = num_frames
        .checked_mul(n_fft_padded / 2 + 1)
        .ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "compute_fbank_kaldi: rfft output element count num_frames * (n_fft_padded/2 + 1)",
                "usize",
                [
                    ("num_frames", num_frames as u64),
                    ("n_fft_padded", n_fft_padded as u64),
                ],
            ))
        })?;
    if out_elems > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: rfft output element count exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            out_elems as u64,
        )));
    }
    let output_elems = num_frames.checked_mul(num_mels).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_fbank_kaldi: output element count num_frames * num_mels",
            "usize",
            [
                ("num_frames", num_frames as u64),
                ("num_mels", num_mels as u64),
            ],
        ))
    })?;
    if output_elems > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: output element count (= num_frames * num_mels) exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            output_elems as u64,
        )));
    }
    let mel_padded_elems = num_mels.checked_mul(n_fft_padded / 2 + 1).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_fbank_kaldi: padded mel-bank element count num_mels * (n_fft_padded/2 + 1)",
            "usize",
            [
                ("num_mels", num_mels as u64),
                ("n_fft_padded", n_fft_padded as u64),
            ],
        ))
    })?;
    if mel_padded_elems > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_fbank_kaldi: padded mel-bank element count \
             (= num_mels * (n_fft_padded/2 + 1)) exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            mel_padded_elems as u64,
        )));
    }
    let n_fft_padded_i32 = i32::try_from(n_fft_padded).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: n_fft_padded",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{n_fft_padded}"),
        ))
    })?;
    let win_len_i32 = i32::try_from(win_len).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: win_len",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{win_len}"),
        ))
    })?;
    let num_frames_i32 = i32::try_from(num_frames).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_fbank_kaldi: num_frames",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{num_frames}"),
        ))
    })?;

    // Framing.
    let waveform_contig = ops::shape::contiguous(waveform, false)?;
    let strided = if snip_edges {
        strided_frames_snip_edges(&waveform_contig, win_len, win_inc, num_frames)?
    } else {
        strided_frames_no_snip_edges(&waveform_contig, win_len, win_inc, num_frames)?
    };

    // Optional dither.
    let dithered = if dither > 0.0 {
        let key = dither_key.expect("dither != 0.0 was checked to require a key above");
        let shape: &[i32] = &[num_frames_i32, win_len_i32];
        let noise = ops::random::normal(&shape, crate::Dtype::F32, 0.0, dither, key)?;
        ops::arithmetic::add(&strided, &noise)?
    } else {
        strided
    };

    // Per-frame mean removal.
    let row_means = ops::reduction::mean_axes(&dithered, &[1], true)?;
    let centered = ops::arithmetic::subtract(&dithered, &row_means)?;

    // Pre-emphasis: column 0 is scaled by (1 - p); column t > 0 becomes
    // x[t] - p * x[t - 1].
    let preemphasized = if preemphasis > 0.0 {
        let first_col = ops::indexing::slice(
            &centered,
            &[0_i32, 0_i32],
            &[num_frames_i32, 1_i32],
            &[1_i32, 1_i32],
        )?;
        let rest = ops::indexing::slice(
            &centered,
            &[0_i32, 1_i32],
            &[num_frames_i32, win_len_i32],
            &[1_i32, 1_i32],
        )?;
        let prev = ops::indexing::slice(
            &centered,
            &[0_i32, 0_i32],
            &[num_frames_i32, win_len_i32 - 1],
            &[1_i32, 1_i32],
        )?;
        let p_arr = Array::full::<f32>(&[0_i32; 0], preemphasis)?;
        let scaled_prev = ops::arithmetic::multiply(&prev, &p_arr)?;
        let other_cols = ops::arithmetic::subtract(&rest, &scaled_prev)?;
        let one_minus_p = Array::full::<f32>(&[0_i32; 0], 1.0 - preemphasis)?;
        let first_col_scaled = ops::arithmetic::multiply(&first_col, &one_minus_p)?;
        ops::shape::concatenate(&[&first_col_scaled, &other_cols], 1)?
    } else {
        centered
    };

    // Windowing, then zero-padding on the right up to the FFT length.
    let window = build_kaldi_window(win_type, win_len)?;
    let windowed = ops::arithmetic::multiply(&preemphasized, &window)?;
    let padded = if n_fft_padded != win_len {
        let pad_extent = i32::try_from(n_fft_padded - win_len).map_err(|_| {
            Error::OutOfRange(OutOfRangePayload::new(
                "compute_fbank_kaldi: pad extent (n_fft_padded - win_len)",
                "must fit in i32 (i32::MAX = 2147483647)",
                format_smolstr!("{}", n_fft_padded - win_len),
            ))
        })?;
        let pad_value = Array::zeros::<f32>(&[0_i32; 0])?;
        ops::shape::pad(
            &windowed,
            &[1_i32],
            &[0_i32],
            &[pad_extent],
            &pad_value,
            c"constant",
        )?
    } else {
        windowed
    };

    // Power spectrum.
    let spectrum = fft::rfft(&padded, n_fft_padded_i32, 1, FftNorm::Backward)?;
    let power = spectrum.abs()?.square()?;

    // The mel bank has `n_fft_padded / 2` columns; one zero column is appended
    // so that it lines up with the `n_fft_padded / 2 + 1` bins counted by the
    // rfft size check above.
    let (mel_bank, _centers) = get_mel_banks_kaldi(
        num_mels,
        n_fft_padded,
        sample_rate as f32,
        low_freq,
        high_freq,
    )?;
    let pad_value = Array::zeros::<f32>(&[0_i32; 0])?;
    let mel_padded = ops::shape::pad(
        &mel_bank,
        &[1_i32],
        &[0_i32],
        &[1_i32],
        &pad_value,
        c"constant",
    )?;
    let mel_t = mel_padded.transpose()?;
    let mel_features = ops::linalg_basic::matmul(&power, &mel_t)?;

    // Floor before the logarithm so that zero energy does not produce -inf.
    let floor = Array::full::<f32>(&[0_i32; 0], KALDI_FBANK_LOG_FLOOR)?;
    let floored = ops::arithmetic::maximum(&mel_features, &floor)?;
    floored.log()
}
/// How `compute_deltas_kaldi` extends the time axis before differencing.
///
/// `Display` prints the lowercase name returned by `as_str`.
#[derive(
Debug, Clone, Copy, PartialEq, Eq, Default, derive_more::Display, derive_more::IsVariant,
)]
#[display("{}", self.as_str())]
pub enum DeltaPadMode {
/// Default variant: repeat the first and last time column on each side.
#[default]
Edge,
/// Pad each side with zeros.
Constant,
}
impl DeltaPadMode {
pub const fn as_str(&self) -> &'static str {
match self {
Self::Edge => "edge",
Self::Constant => "constant",
}
}
}
/// Computes delta (regression) coefficients along the last axis of `specgram`.
///
/// With `n = (win_length - 1) / 2`, each output sample is
/// `sum_{k=-n..n} k * x[t + k] / (n * (n + 1) * (2n + 1) / 3)`, where `x` is
/// the input padded by `n` columns on both sides according to `mode`.
///
/// # Parameters
/// * `specgram` - array of rank `>= 1`; the last axis is treated as time and
///   all leading axes are flattened together for the computation.
/// * `win_length` - odd window length in `3..=MAX_DELTA_WIN_LENGTH`.
/// * `mode` - padding applied to the time axis.
///
/// # Returns
/// An array with the same shape as `specgram`. An input with zero elements
/// yields an f32 zeros array of that shape.
///
/// # Errors
/// `Error::OutOfRange`, `Error::CapExceeded`, `Error::RankMismatch` and
/// `Error::ArithmeticOverflow` from the argument and size checks, plus any
/// error propagated from the array operations.
pub fn compute_deltas_kaldi(
    specgram: &Array,
    win_length: usize,
    mode: DeltaPadMode,
) -> Result<Array> {
    // Window-length validation.
    if win_length < 3 {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: win_length",
            "must be >= 3",
            format!("{win_length}"),
        )));
    }
    if win_length.is_multiple_of(2) {
        return Err(Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: win_length",
            "must be odd (an even win_length would silently truncate to the next-lower odd window)",
            format!("{win_length}"),
        )));
    }
    if win_length > MAX_DELTA_WIN_LENGTH {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_deltas_kaldi: win_length exceeds supported maximum \
             (delta windows are tiny — the default is 5)",
            "MAX_DELTA_WIN_LENGTH",
            MAX_DELTA_WIN_LENGTH as u64,
            win_length as u64,
        )));
    }

    // Input geometry.
    let dims = specgram.shape();
    if dims.is_empty() {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "compute_deltas_kaldi: specgram must have rank >= 1 (a time axis)",
            0,
            Vec::new(),
        )));
    }
    let frames = dims[dims.len() - 1];
    let elem_count = specgram.size();
    if elem_count == 0 {
        return Array::zeros::<f32>(&dims.as_slice());
    }
    if elem_count > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_deltas_kaldi: element count exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            elem_count as u64,
        )));
    }
    // `frames` is non-zero here because the array has at least one element.
    let rows = elem_count / frames;
    let half = (win_length - 1) / 2;

    // Size caps on the padded buffer and on the accumulation passes.
    let padded_frames = frames.checked_add(2 * half).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_deltas_kaldi: padded time (time + 2n)",
            "usize",
            [("time", frames as u64), ("n", half as u64)],
        ))
    })?;
    let padded_elems = rows.checked_mul(padded_frames).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_deltas_kaldi: padded work num_features * (time + 2n)",
            "usize",
            [
                ("num_features", rows as u64),
                ("padded_time", padded_frames as u64),
            ],
        ))
    })?;
    if padded_elems > MAX_FBANK_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_deltas_kaldi: padded element count (= num_features * (time + 2n)) exceeds work cap",
            "MAX_FBANK_WORK",
            MAX_FBANK_WORK as u64,
            padded_elems as u64,
        )));
    }
    let pass_work = elem_count.checked_mul(win_length - 1).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "compute_deltas_kaldi: accumulation work total * (win_length - 1)",
            "usize",
            [("total", elem_count as u64), ("win_length", win_length as u64)],
        ))
    })?;
    if pass_work > MAX_DELTA_WORK {
        return Err(Error::CapExceeded(CapExceededPayload::new(
            "compute_deltas_kaldi: accumulation work (= total * (win_length - 1)) exceeds work cap \
             (the delta loop runs win_length - 1 full-width passes over the spectrogram)",
            "MAX_DELTA_WORK",
            MAX_DELTA_WORK as u64,
            pass_work as u64,
        )));
    }

    // i32 conversions. The padded width is converted only to prove that every
    // slice bound computed in the loop below fits in i32.
    let _padded_frames_i32 = i32::try_from(padded_frames).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: padded_time",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{padded_frames}"),
        ))
    })?;
    let rows_i32 = i32::try_from(rows).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: num_features",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{rows}"),
        ))
    })?;
    let frames_i32 = i32::try_from(frames).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: time",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{frames}"),
        ))
    })?;

    // Collapse all leading axes into one row axis.
    let matrix = ops::shape::reshape(specgram, &(rows, frames))?;
    let half_i32 = i32::try_from(half).map_err(|_| {
        Error::OutOfRange(OutOfRangePayload::new(
            "compute_deltas_kaldi: pad extent n",
            "must fit in i32 (i32::MAX = 2147483647)",
            format_smolstr!("{half}"),
        ))
    })?;

    // Extend the time axis by `half` columns on each side.
    let padded = match mode {
        DeltaPadMode::Edge => {
            let leading = ops::indexing::slice(
                &matrix,
                &[0_i32, 0_i32],
                &[rows_i32, 1_i32],
                &[1_i32, 1_i32],
            )?;
            let trailing = ops::indexing::slice(
                &matrix,
                &[0_i32, frames_i32 - 1],
                &[rows_i32, frames_i32],
                &[1_i32, 1_i32],
            )?;
            let left_fill = ops::shape::broadcast_to(&leading, &(rows, half))?;
            let right_fill = ops::shape::broadcast_to(&trailing, &(rows, half))?;
            ops::shape::concatenate(&[&left_fill, &matrix, &right_fill], 1)?
        }
        DeltaPadMode::Constant => {
            let zero = Array::zeros::<f32>(&[0_i32; 0])?;
            ops::shape::pad(
                &matrix,
                &[1_i32],
                &[half_i32],
                &[half_i32],
                &zero,
                c"constant",
            )?
        }
    };

    // Weighted sum over the shifted views. A shift of `s` columns into the
    // padded buffer corresponds to the offset `k = s - half`; the zero offset
    // contributes nothing and is skipped.
    let mut running = Array::zeros::<f32>(&[rows_i32, frames_i32])?;
    for shift in (0..=2 * half).filter(|&s| s != half) {
        let offset = shift as isize - half as isize;
        let lo = shift as i32;
        let hi = lo + frames_i32;
        let view = ops::indexing::slice(
            &padded,
            &[0_i32, lo],
            &[rows_i32, hi],
            &[1_i32, 1_i32],
        )?;
        let coeff = Array::full::<f32>(&[0_i32; 0], offset as f32)?;
        let term = ops::arithmetic::multiply(&view, &coeff)?;
        running = ops::arithmetic::add(&running, &term)?;
    }

    // Normaliser: sum of k^2 for k in -half..=half, evaluated in f64 and then
    // narrowed to f32.
    let norm = (half as f64 * (half + 1) as f64 * (2 * half + 1) as f64) / 3.0;
    let norm_arr = Array::full::<f32>(&[0_i32; 0], norm as f32)?;
    let scaled = ops::arithmetic::divide(&running, &norm_arr)?;
    ops::shape::reshape(&scaled, &dims.as_slice())
}
#[cfg(test)]
mod tests;