use rustfft::num_complex::Complex;
use rustfft::num_traits::Zero;
use rustfft::{Fft, FftPlanner};
use std::sync::Arc;
/// Length of each analysis frame and of the FFT applied to it, in samples.
const N_FFT: usize = 400;

/// Stride between the starts of consecutive frames, in samples.
const HOP_LENGTH: usize = 160;

/// Number of mel bands in the output.
const N_MELS: usize = 80;

/// Number of FFT bins (DC through Nyquist) fed into the mel filterbank.
const NUM_FREQ_BINS: usize = N_FFT / 2 + 1;

/// Exact number of input samples accepted per extraction.
///
/// Written as `16_000 * 8` (presumably 8 s of 16 kHz audio — confirm with callers).
pub(crate) const N_SAMPLES: usize = 16_000 * 8;

/// Number of mirrored samples added on each side of the input before framing.
const PAD_SIZE: usize = N_FFT / 2;

/// Length of the padded signal.
const PADDED_LENGTH: usize = N_SAMPLES + N_FFT;

/// Number of frames obtained by sliding an `N_FFT` window over the padded signal.
const NUM_FRAMES: usize = 1 + (PADDED_LENGTH - N_FFT) / HOP_LENGTH;

/// Number of frames written to the output (the final computed frame is dropped).
const OUTPUT_FRAMES: usize = NUM_FRAMES - 1;
/// Raw mel filterbank weights embedded into the binary at compile time.
///
/// `build_sparse_mel_f32` decodes these bytes as little-endian `f32` values and
/// asserts that there are exactly `NUM_FREQ_BINS * N_MELS` of them.
const MEL_FILTERS_BYTES: &[u8] = include_bytes!("mel_filters_80x201_f32.bin");
/// Floating-point width used by the internal feature-extraction pipeline.
///
/// The extracted features are returned as `f32` in both cases; this only
/// selects which internal state (`F32State` or `F64State`) does the work.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Precision {
    /// Run windowing, FFT and mel accumulation in single precision.
    F32,
    /// Run windowing, FFT and mel accumulation in double precision.
    F64,
}
/// One non-zero weight of the mel filterbank, stored per frequency bin.
struct SparseMelEntry<F> {
    /// Index of the mel band this weight contributes to.
    mel: u8,
    /// Filter weight applied to the power of the owning frequency bin.
    weight: F,
}
/// Mel filterbank keeping only its non-zero weights, grouped by frequency bin.
///
/// The entries belonging to frequency bin `f` are
/// `entries[offsets[f] as usize..offsets[f + 1] as usize]`.
struct SparseMelFilter<F> {
    /// Start index into `entries` for each frequency bin, plus one trailing
    /// element holding the total entry count.
    offsets: [u16; NUM_FREQ_BINS + 1],
    /// All non-zero weights, ordered by frequency bin and then by mel band.
    entries: Vec<SparseMelEntry<F>>,
}
/// Decodes the embedded filterbank and converts it to its sparse per-bin form.
///
/// The bytes are read as little-endian `f32` values and interpreted as
/// `NUM_FREQ_BINS` consecutive rows of `N_MELS` weights each. Weights equal to
/// zero are dropped.
///
/// NOTE(review): the embedded file is named `mel_filters_80x201_f32.bin`, which
/// reads like a mel-major (80 rows x 201 columns) layout, whereas the indexing
/// here is frequency-major (201 rows x 80 columns) — confirm the file's layout.
///
/// # Panics
///
/// Panics if the embedded data does not decode to exactly
/// `NUM_FREQ_BINS * N_MELS` values.
fn build_sparse_mel_f32() -> SparseMelFilter<f32> {
    let mut dense = Vec::with_capacity(MEL_FILTERS_BYTES.len() / 4);
    for quad in MEL_FILTERS_BYTES.chunks_exact(4) {
        dense.push(f32::from_le_bytes([quad[0], quad[1], quad[2], quad[3]]));
    }
    assert_eq!(dense.len(), NUM_FREQ_BINS * N_MELS);

    let mut offsets = [0u16; NUM_FREQ_BINS + 1];
    let mut entries = Vec::with_capacity(NUM_FREQ_BINS * 3);
    // After the assert above, `dense` splits into exactly NUM_FREQ_BINS rows.
    for (freq, row) in dense.chunks_exact(N_MELS).enumerate() {
        offsets[freq] = entries.len() as u16;
        entries.extend(
            row.iter()
                .enumerate()
                .filter(|&(_, &w)| w != 0.0)
                .map(|(mel, &weight)| SparseMelEntry {
                    mel: mel as u8,
                    weight,
                }),
        );
    }
    // Trailing sentinel so that `offsets[f + 1]` is valid for the last bin.
    offsets[NUM_FREQ_BINS] = entries.len() as u16;

    SparseMelFilter { offsets, entries }
}
/// Builds the sparse filterbank in double precision.
///
/// The weights come from `build_sparse_mel_f32` and are widened to `f64`;
/// offsets and mel indices are carried over unchanged.
fn build_sparse_mel_f64() -> SparseMelFilter<f64> {
    let SparseMelFilter { offsets, entries } = build_sparse_mel_f32();
    let entries = entries
        .into_iter()
        .map(|SparseMelEntry { mel, weight }| SparseMelEntry {
            mel,
            weight: f64::from(weight),
        })
        .collect();
    SparseMelFilter { offsets, entries }
}
/// Fast approximate base-10 logarithm for positive, normal `f32` values.
///
/// The input is split into its binary exponent and a mantissa in `[1, 2)`;
/// the natural log of the mantissa is approximated with a cubic polynomial and
/// both parts are converted to base 10:
///
/// `log10(x) = exponent * log10(2) + ln(mantissa) * log10(e)`
///
/// The cubic approximates the *natural* log of the mantissa (it evaluates to
/// about `0.6936 ~= ln 2` at `mantissa = 2`), so it must be scaled by
/// `log10(e)`, not added directly to the base-2 exponent. The absolute error
/// of the result is roughly `3e-4` or better.
///
/// Zero, negative, subnormal, infinite and NaN inputs are not handled; the
/// result for them is meaningless. Callers are expected to clamp the input to
/// a positive floor first.
#[inline(always)]
fn fast_log10_f32(x: f32) -> f32 {
    let bits = x.to_bits();
    // Unbiased binary exponent: x = mantissa * 2^exponent.
    let exponent = ((bits >> 23) & 0xFF) as f32 - 127.0;
    // Keep the fraction bits and force the exponent field to the bias (127),
    // which yields a value in [1, 2).
    let mantissa = f32::from_bits((bits & 0x007F_FFFF) | 0x3F80_0000);
    // Cubic fit of ln(mantissa) on [1, 2), evaluated in Horner form.
    let ln_mantissa =
        -1.492_78 + mantissa * (2.112_64 + mantissa * (-0.729_104 + mantissa * 0.109_69));
    exponent * std::f32::consts::LOG10_2 + ln_mantissa * std::f32::consts::LOG10_E
}
/// Feature extractor that turns a fixed-length audio buffer into mel features.
///
/// The numeric precision of the internal pipeline is chosen at construction
/// time; the scratch buffers are owned by the selected state and reused across
/// calls.
pub(crate) struct WhisperFeatureExtractor {
    /// Precision-specific pipeline state.
    inner: InnerState,
}
impl WhisperFeatureExtractor {
    /// Creates an extractor whose internal pipeline runs at `precision`.
    pub fn new(precision: Precision) -> Self {
        let inner = match precision {
            Precision::F32 => InnerState::F32(F32State::new()),
            Precision::F64 => InnerState::F64(F64State::new()),
        };
        Self { inner }
    }

    /// Computes the features for `audio` and returns them as a borrowed slice.
    ///
    /// The returned slice holds `N_MELS * OUTPUT_FRAMES` values laid out one
    /// mel band after another. It points into an internal buffer that is
    /// overwritten by the next call.
    ///
    /// # Panics
    ///
    /// Panics if `audio.len() != N_SAMPLES`.
    pub fn extract(&mut self, audio: &[f32]) -> &[f32] {
        assert_eq!(audio.len(), N_SAMPLES);
        match self.inner {
            InnerState::F32(ref mut state) => state.extract(audio),
            InnerState::F64(ref mut state) => state.extract(audio),
        }
    }

    /// Returns the precision this extractor was constructed with.
    pub fn precision(&self) -> Precision {
        match self.inner {
            InnerState::F32(_) => Precision::F32,
            InnerState::F64(_) => Precision::F64,
        }
    }
}
/// Precision-specific pipeline state held by `WhisperFeatureExtractor`.
enum InnerState {
    /// Single-precision pipeline.
    F32(F32State),
    /// Double-precision pipeline.
    F64(F64State),
}
/// Precomputed tables and reusable scratch buffers for the `f32` pipeline.
struct F32State {
    /// Hann window of length `N_FFT`, multiplied into every frame.
    hann_window: [f32; N_FFT],
    /// Sparse mel filterbank applied to each frame's power spectrum.
    sparse_mel: SparseMelFilter<f32>,
    /// Forward FFT plan of size `N_FFT`.
    fft: Arc<dyn Fft<f32>>,
    /// Normalized input with mirrored edges; `PADDED_LENGTH` samples.
    padded: Vec<f32>,
    /// In-place FFT working buffer; `N_FFT` complex values.
    fft_buffer: Vec<Complex<f32>>,
    /// Mel energies, one row of `NUM_FRAMES` values per mel band.
    mel_spec: Vec<f32>,
    /// Final features, one row of `OUTPUT_FRAMES` values per mel band.
    output: Vec<f32>,
}
impl F32State {
    /// Precomputes everything that does not depend on the input audio: the
    /// Hann window, the FFT plan, the sparse mel filterbank and all scratch
    /// buffers (allocated once and reused by every `extract` call).
    fn new() -> Self {
        let mut hann_window = [0.0f32; N_FFT];
        for (i, w) in hann_window.iter_mut().enumerate() {
            // Evaluate in f64 and round to f32 once.
            *w = (0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / N_FFT as f64).cos()))
                as f32;
        }
        let mut planner = FftPlanner::<f32>::new();
        let fft = planner.plan_fft_forward(N_FFT);
        Self {
            hann_window,
            sparse_mel: build_sparse_mel_f32(),
            fft,
            padded: vec![0.0; PADDED_LENGTH],
            fft_buffer: vec![Complex::zero(); N_FFT],
            mel_spec: vec![0.0; N_MELS * NUM_FRAMES],
            output: vec![0.0; N_MELS * OUTPUT_FRAMES],
        }
    }

    /// Runs the full pipeline in single precision.
    ///
    /// Steps: normalize the audio to zero mean / unit standard deviation,
    /// mirror `PAD_SIZE` samples at each end, compute a Hann-windowed power
    /// spectrum per frame, accumulate it through the sparse mel filterbank,
    /// take `log10` (approximate), clamp to within 8 of the global maximum and
    /// rescale with `(x + 4) / 4`.
    ///
    /// Returns `N_MELS * OUTPUT_FRAMES` values, one row per mel band, borrowed
    /// from an internal buffer that the next call overwrites.
    ///
    /// # Panics
    ///
    /// Panics if `audio.len() != N_SAMPLES`.
    fn extract(&mut self, audio: &[f32]) -> &[f32] {
        // The raw-pointer loops below index `audio` up to N_SAMPLES - 1 without
        // bounds checks, so the length must be enforced here rather than only
        // in the caller.
        assert_eq!(audio.len(), N_SAMPLES);

        // Mean and variance are accumulated in f64 to limit rounding error.
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;
        for &x in audio {
            let v = x as f64;
            sum += v;
            sum_sq += v * v;
        }
        let n = audio.len() as f64;
        let mean_f64 = sum / n;
        let mean = mean_f64 as f32;
        // E[x^2] - mean^2 can round to a tiny negative number for constant or
        // near-constant input; without the clamp `sqrt` would return NaN and
        // every output value would become NaN.
        let variance = (((sum_sq / n) - mean_f64 * mean_f64) as f32).max(0.0);
        let inv_std = 1.0f32 / (variance.sqrt() + 1e-7);

        // SAFETY: `audio.len() == N_SAMPLES` (asserted above) and
        // `self.padded.len() == PADDED_LENGTH == N_SAMPLES + 2 * PAD_SIZE`
        // (allocated in `new`, never resized). The loops read
        // `audio[1..=PAD_SIZE]`, `audio[0..N_SAMPLES]` and
        // `audio[N_SAMPLES - 1 - PAD_SIZE..=N_SAMPLES - 2]`, and write
        // `padded[0..PADDED_LENGTH]`; all of these are in bounds.
        unsafe {
            let p = self.padded.as_mut_ptr();
            let a = audio.as_ptr();
            // Left edge: mirror around the first sample without repeating it.
            for i in 0..PAD_SIZE {
                *p.add(i) = (*a.add(PAD_SIZE - i) - mean) * inv_std;
            }
            for i in 0..N_SAMPLES {
                *p.add(PAD_SIZE + i) = (*a.add(i) - mean) * inv_std;
            }
            // Right edge: mirror around the last sample without repeating it.
            for i in 0..PAD_SIZE {
                *p.add(PAD_SIZE + N_SAMPLES + i) = (*a.add(N_SAMPLES - 2 - i) - mean) * inv_std;
            }
        }

        self.mel_spec.fill(0.0);
        let offsets = &self.sparse_mel.offsets;
        let entries = &self.sparse_mel.entries;
        for frame_idx in 0..NUM_FRAMES {
            let start = frame_idx * HOP_LENGTH;
            // SAFETY: `frame_idx <= NUM_FRAMES - 1 == (PADDED_LENGTH - N_FFT) / HOP_LENGTH`,
            // hence `start + N_FFT <= PADDED_LENGTH == self.padded.len()`.
            // `hann_window` and `fft_buffer` both hold exactly N_FFT elements.
            unsafe {
                let pp = self.padded.as_ptr().add(start);
                let hp = self.hann_window.as_ptr();
                let fp = self.fft_buffer.as_mut_ptr();
                for i in 0..N_FFT {
                    *fp.add(i) = Complex::new(*pp.add(i) * *hp.add(i), 0.0);
                }
            }
            self.fft.process(&mut self.fft_buffer);
            // SAFETY: `freq < NUM_FREQ_BINS <= N_FFT == fft_buffer.len()`, and
            // `offsets` has NUM_FREQ_BINS + 1 elements so `freq + 1` is valid.
            // `build_sparse_mel_f32` fills `offsets` with non-decreasing values
            // no larger than `entries.len()` and only stores `mel < N_MELS`, so
            // `ei` indexes `entries` and `mel * NUM_FRAMES + frame_idx` indexes
            // `mel_spec` (length N_MELS * NUM_FRAMES).
            unsafe {
                let fb = self.fft_buffer.as_ptr();
                let ms = self.mel_spec.as_mut_ptr();
                let ep = entries.as_ptr();
                for freq in 0..NUM_FREQ_BINS {
                    let c = &*fb.add(freq);
                    let power = c.re * c.re + c.im * c.im;
                    let s = *offsets.get_unchecked(freq) as usize;
                    let e = *offsets.get_unchecked(freq + 1) as usize;
                    for ei in s..e {
                        let entry = &*ep.add(ei);
                        let mel = entry.mel as usize;
                        *ms.add(mel * NUM_FRAMES + frame_idx) += entry.weight * power;
                    }
                }
            }
        }

        // Log-compress the first OUTPUT_FRAMES frames of every band (the last
        // computed frame is not copied) and track the global maximum.
        let mel_floor: f32 = 1e-10;
        let mut global_max = f32::NEG_INFINITY;
        for mel in 0..N_MELS {
            let src_base = mel * NUM_FRAMES;
            let dst_base = mel * OUTPUT_FRAMES;
            for frame in 0..OUTPUT_FRAMES {
                // SAFETY: `mel < N_MELS` and `frame < OUTPUT_FRAMES < NUM_FRAMES`,
                // so `src_base + frame < N_MELS * NUM_FRAMES == mel_spec.len()`
                // and `dst_base + frame < N_MELS * OUTPUT_FRAMES == output.len()`.
                unsafe {
                    let mut v = *self.mel_spec.get_unchecked(src_base + frame);
                    // Flooring also keeps the input of the fast log positive
                    // and normal.
                    if v < mel_floor {
                        v = mel_floor;
                    }
                    let log_v = fast_log10_f32(v);
                    *self.output.get_unchecked_mut(dst_base + frame) = log_v;
                    if log_v > global_max {
                        global_max = log_v;
                    }
                }
            }
        }

        // Limit the range to 8 below the maximum, then map with (x + 4) / 4.
        let floor = global_max - 8.0;
        let scale_inv = 0.25;
        for v in self.output.iter_mut() {
            let clamped = if *v < floor { floor } else { *v };
            *v = (clamped + 4.0) * scale_inv;
        }
        &self.output
    }
}
/// Precomputed tables and reusable scratch buffers for the `f64` pipeline.
struct F64State {
    /// Hann window of length `N_FFT`, multiplied into every frame.
    hann_window: [f64; N_FFT],
    /// Sparse mel filterbank applied to each frame's power spectrum.
    sparse_mel: SparseMelFilter<f64>,
    /// Forward FFT plan of size `N_FFT`.
    fft: Arc<dyn Fft<f64>>,
    /// Normalized input with mirrored edges; `PADDED_LENGTH` samples.
    padded: Vec<f64>,
    /// In-place FFT working buffer; `N_FFT` complex values.
    fft_buffer: Vec<Complex<f64>>,
    /// Mel energies, one row of `NUM_FRAMES` values per mel band.
    mel_spec: Vec<f64>,
    /// Final features (narrowed to `f32`), one row of `OUTPUT_FRAMES` values
    /// per mel band.
    output: Vec<f32>,
}
impl F64State {
    /// Precomputes everything that does not depend on the input audio: the
    /// Hann window, the FFT plan, the sparse mel filterbank and all scratch
    /// buffers (allocated once and reused by every `extract` call).
    fn new() -> Self {
        let mut hann_window = [0.0f64; N_FFT];
        for (i, w) in hann_window.iter_mut().enumerate() {
            *w = 0.5 * (1.0 - (2.0 * std::f64::consts::PI * i as f64 / N_FFT as f64).cos());
        }
        let mut planner = FftPlanner::<f64>::new();
        let fft = planner.plan_fft_forward(N_FFT);
        Self {
            hann_window,
            sparse_mel: build_sparse_mel_f64(),
            fft,
            padded: vec![0.0; PADDED_LENGTH],
            fft_buffer: vec![Complex::zero(); N_FFT],
            mel_spec: vec![0.0; N_MELS * NUM_FRAMES],
            output: vec![0.0; N_MELS * OUTPUT_FRAMES],
        }
    }

    /// Runs the full pipeline in double precision.
    ///
    /// Steps: normalize the audio to zero mean / unit standard deviation,
    /// mirror `PAD_SIZE` samples at each end, compute a Hann-windowed power
    /// spectrum per frame, accumulate it through the sparse mel filterbank,
    /// take `log10`, clamp to within 8 of the global maximum and rescale with
    /// `(x + 4) / 4`.
    ///
    /// Returns `N_MELS * OUTPUT_FRAMES` values, one row per mel band, borrowed
    /// from an internal buffer that the next call overwrites.
    ///
    /// `audio` is expected to hold exactly `N_SAMPLES` samples; a shorter
    /// slice panics on indexing.
    fn extract(&mut self, audio: &[f32]) -> &[f32] {
        debug_assert_eq!(audio.len(), N_SAMPLES);

        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;
        for &x in audio {
            let v = x as f64;
            sum += v;
            sum_sq += v * v;
        }
        let n = audio.len() as f64;
        let mean = sum / n;
        // E[x^2] - mean^2 can round to a tiny negative number for constant or
        // near-constant input; without the clamp `sqrt` would return NaN and
        // every output value would become NaN.
        let variance = ((sum_sq / n) - mean * mean).max(0.0);
        let inv_std = 1.0 / (variance.sqrt() + 1e-7);

        // Left edge: mirror around the first sample without repeating it.
        for i in 0..PAD_SIZE {
            self.padded[i] = (audio[PAD_SIZE - i] as f64 - mean) * inv_std;
        }
        for i in 0..N_SAMPLES {
            self.padded[PAD_SIZE + i] = (audio[i] as f64 - mean) * inv_std;
        }
        // Right edge: mirror around the last sample without repeating it.
        for i in 0..PAD_SIZE {
            self.padded[PAD_SIZE + N_SAMPLES + i] =
                (audio[N_SAMPLES - 2 - i] as f64 - mean) * inv_std;
        }

        self.mel_spec.fill(0.0);
        let offsets = &self.sparse_mel.offsets;
        let entries = &self.sparse_mel.entries;
        for frame_idx in 0..NUM_FRAMES {
            let start = frame_idx * HOP_LENGTH;
            for i in 0..N_FFT {
                self.fft_buffer[i] =
                    Complex::new(self.padded[start + i] * self.hann_window[i], 0.0);
            }
            self.fft.process(&mut self.fft_buffer);
            for freq in 0..NUM_FREQ_BINS {
                let re = self.fft_buffer[freq].re;
                let im = self.fft_buffer[freq].im;
                let power = re * re + im * im;
                // Entries for this bin live in entries[s..e].
                let s = offsets[freq] as usize;
                let e = offsets[freq + 1] as usize;
                for ei in s..e {
                    let entry = &entries[ei];
                    self.mel_spec[entry.mel as usize * NUM_FRAMES + frame_idx] +=
                        entry.weight * power;
                }
            }
        }

        // Log-compress the first OUTPUT_FRAMES frames of every band (the last
        // computed frame is not copied) and track the global maximum.
        let mel_floor: f64 = 1e-10;
        let mut global_max = f32::NEG_INFINITY;
        for mel in 0..N_MELS {
            let src = mel * NUM_FRAMES;
            let dst = mel * OUTPUT_FRAMES;
            for frame in 0..OUTPUT_FRAMES {
                let mut v = self.mel_spec[src + frame];
                if v < mel_floor {
                    v = mel_floor;
                }
                let log_v = v.log10() as f32;
                self.output[dst + frame] = log_v;
                if log_v > global_max {
                    global_max = log_v;
                }
            }
        }

        // Limit the range to 8 below the maximum, then map with (x + 4) / 4.
        let floor = global_max - 8.0;
        for v in self.output.iter_mut() {
            *v = (v.max(floor) + 4.0) * 0.25;
        }
        &self.output
    }
}