#![allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use anyhow::Result;
use candle_core::{Device, Tensor};
use mistralrs_audio::AudioInput;
use rubato::Resampler;
use rustfft::{num_complex::Complex32, FftPlanner};
use super::config::AudioEncodingArgs;
/// Number of token-lengths of zero samples that `process_audio` prepends to the waveform
/// (multiplied by `samples_per_token()` to get a sample count).
const N_LEFT_PAD_TOKENS: usize = 32;
/// Number of token-lengths of zero samples that `process_audio` appends to the waveform
/// (multiplied by `samples_per_token()` to get a sample count).
const N_RIGHT_PAD_TOKENS: usize = 17;
/// Converts raw audio into normalised log-mel spectrogram features.
///
/// All fields are plain numeric settings, so the type is cheap to clone and safe to print.
#[derive(Debug, Clone)]
pub struct VoxtralAudioProcessor {
    /// Sample rate (Hz) the audio is resampled to before feature extraction.
    sampling_rate: u32,
    /// Divisor applied to `sampling_rate` to obtain the number of samples per token.
    frame_rate: f32,
    /// Number of mel filters, i.e. the width of every output frame.
    num_mel_bins: usize,
    /// Step, in samples, between the starts of consecutive analysis frames.
    hop_length: usize,
    /// Length, in samples, of each analysis frame; also used as the FFT size.
    window_size: usize,
    /// Reference maximum of the log-mel values; features are floored at this value minus 8.
    global_log_mel_max: f32,
}
impl VoxtralAudioProcessor {
/// Builds a processor from the audio-encoding configuration.
///
/// `frame_rate` and `global_log_mel_max` are converted to `f32` with `as`; every other
/// setting is copied from `cfg` unchanged.
pub fn new(cfg: &AudioEncodingArgs) -> Self {
    // Do the two numeric conversions up front so the struct literal is a plain field list.
    let frame_rate = cfg.frame_rate as f32;
    let global_log_mel_max = cfg.global_log_mel_max as f32;
    Self {
        sampling_rate: cfg.sampling_rate,
        frame_rate,
        num_mel_bins: cfg.num_mel_bins,
        hop_length: cfg.hop_length,
        window_size: cfg.window_size,
        global_log_mel_max,
    }
}
/// Creates a new processor carrying the same settings as `other`.
pub fn new_from_processor(other: &Self) -> Self {
    // Every field is a `Copy` scalar, so destructuring through the reference copies them out.
    let Self {
        sampling_rate,
        frame_rate,
        num_mel_bins,
        hop_length,
        window_size,
        global_log_mel_max,
    } = *other;
    Self {
        sampling_rate,
        frame_rate,
        num_mel_bins,
        hop_length,
        window_size,
        global_log_mel_max,
    }
}
/// Returns how many audio samples correspond to one token: `sampling_rate / frame_rate`,
/// truncated to a whole number.
fn samples_per_token(&self) -> usize {
    let samples_per_second = self.sampling_rate as f32;
    let tokens_per_second = self.frame_rate;
    (samples_per_second / tokens_per_second) as usize
}
/// Turns an audio clip into a log-mel feature tensor of shape
/// `(1, num_frames, num_mel_bins)` allocated on `device`.
///
/// The clip is mixed down to mono, resampled when its rate differs from the configured
/// one, surrounded by `N_LEFT_PAD_TOKENS` / `N_RIGHT_PAD_TOKENS` token-lengths of zeros,
/// and then passed through `compute_mel_spectrogram`.
///
/// # Errors
/// Fails if resampling, spectrogram computation or tensor construction fails, or if no
/// mel frame could be produced.
pub fn process_audio(&self, audio: &AudioInput, device: &Device) -> Result<Tensor> {
    let mono = audio.to_mono();
    let waveform = if audio.sample_rate == self.sampling_rate {
        mono
    } else {
        self.resample(&mono, audio.sample_rate, self.sampling_rate)?
    };

    // Zero padding on both sides, measured in whole tokens.
    let samples_per_token = self.samples_per_token();
    let lead = N_LEFT_PAD_TOKENS * samples_per_token;
    let trail = N_RIGHT_PAD_TOKENS * samples_per_token;
    let mut padded = Vec::with_capacity(lead + waveform.len() + trail);
    padded.resize(lead, 0.0f32);
    padded.extend_from_slice(&waveform);
    padded.resize(padded.len() + trail, 0.0f32);

    let frames = self.compute_mel_spectrogram(&padded)?;
    let num_frames = frames.len();
    if num_frames == 0 {
        anyhow::bail!("Audio too short to produce mel frames");
    }

    // Row-major flattening: frame after frame, each `num_mel_bins` wide.
    let flat: Vec<f32> = frames.into_iter().flatten().collect();
    Ok(Tensor::from_vec(
        flat,
        (1, num_frames, self.num_mel_bins),
        device,
    )?)
}
fn resample(&self, samples: &[f32], from_rate: u32, to_rate: u32) -> Result<Vec<f32>> {
if from_rate == to_rate {
return Ok(samples.to_vec());
}
let sinc = rubato::SincInterpolationParameters {
sinc_len: 256,
f_cutoff: 0.95,
interpolation: rubato::SincInterpolationType::Linear,
oversampling_factor: 256,
window: rubato::WindowFunction::BlackmanHarris2,
};
let mut resampler = rubato::SincFixedIn::<f32>::new(
to_rate as f64 / from_rate as f64,
2.0,
sinc,
samples.len(),
1,
)?;
let result = resampler.process(&[samples.to_vec()], None)?;
Ok(result[0].clone())
}
fn compute_mel_spectrogram(&self, samples: &[f32]) -> Result<Vec<Vec<f32>>> {
let n_fft = self.window_size;
let hop = self.hop_length;
let n_freqs = n_fft / 2 + 1;
let pad = n_fft / 2;
if samples.is_empty() {
return Ok(Vec::new());
}
let padded_len = pad + samples.len() + pad;
let mut padded = vec![0.0f32; padded_len];
for (i, p) in padded.iter_mut().enumerate().take(pad) {
let src_idx = (pad - i).min(samples.len() - 1);
*p = samples[src_idx];
}
padded[pad..pad + samples.len()].copy_from_slice(samples);
for i in 0..pad {
let src_idx = samples.len().saturating_sub(2 + i);
padded[pad + samples.len() + i] = samples[src_idx];
}
let total_frames = (padded_len - n_fft) / hop + 1;
let num_frames = total_frames.saturating_sub(1);
let window: Vec<f32> = (0..n_fft)
.map(|n| 0.5 * (1.0 - (2.0 * std::f32::consts::PI * n as f32 / n_fft as f32).cos()))
.collect();
let mel_filters = self.create_mel_filterbank(n_fft)?;
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(n_fft);
let mut mel_features = Vec::with_capacity(num_frames);
let log_mel_floor = self.global_log_mel_max - 8.0;
for frame_idx in 0..num_frames {
let start = frame_idx * hop;
let mut buf: Vec<Complex32> = padded[start..start + n_fft]
.iter()
.zip(window.iter())
.map(|(&s, &w)| Complex32::new(s * w, 0.0))
.collect();
fft.process(&mut buf);
let power: Vec<f32> = buf[..n_freqs].iter().map(|c| c.norm_sqr()).collect();
let mut mel_frame = vec![0.0f32; self.num_mel_bins];
for (mel_idx, filter) in mel_filters.iter().enumerate() {
let mut sum = 0.0f32;
for (freq_idx, &coeff) in filter.iter().enumerate() {
if freq_idx < power.len() {
sum += power[freq_idx] * coeff;
}
}
let log_val = sum.max(1e-10).log10();
let clamped = log_val.max(log_mel_floor);
mel_frame[mel_idx] = (clamped + 4.0) / 4.0;
}
mel_features.push(mel_frame);
}
Ok(mel_features)
}
/// Converts a frequency in hertz to mels.
///
/// The scale is linear below 1000 Hz (`3 * freq / 200`, so 200 Hz maps to 3 mel) and
/// logarithmic from 1000 Hz upwards, where 1000 Hz maps to 15 mel. The constant
/// `1.856_298` approximates `ln(6.4)`, so a factor of 6.4 in frequency spans 27 mel.
fn hertz_to_mel(freq: f32) -> f32 {
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = 15.0;
    const MELS_PER_LN: f32 = 27.0 / 1.856_298;
    match freq {
        f if f >= BREAK_HZ => BREAK_MEL + (f / BREAK_HZ).ln() * MELS_PER_LN,
        f => 3.0 * f / 200.0,
    }
}
/// Converts a mel value back to hertz; the inverse of `hertz_to_mel`.
///
/// Below 15 mel the mapping is linear (`200 * mel / 3`); from 15 mel upwards it is
/// exponential, with 15 mel mapping to 1000 Hz. The constant `1.856_298` approximates
/// `ln(6.4)`, so 27 mel correspond to a factor of 6.4 in frequency.
fn mel_to_hertz(mel: f32) -> f32 {
    const BREAK_HZ: f32 = 1000.0;
    const BREAK_MEL: f32 = 15.0;
    const LN_PER_MEL: f32 = 1.856_298 / 27.0;
    match mel {
        m if m >= BREAK_MEL => BREAK_HZ * (LN_PER_MEL * (m - BREAK_MEL)).exp(),
        m => 200.0 * m / 3.0,
    }
}
/// Builds a bank of `num_mel_bins` triangular filters over the `n_fft / 2 + 1` FFT bins.
///
/// Filter edges are evenly spaced on the mel scale between 0 Hz and half the sampling
/// rate. Each triangle rises from its lower edge to its centre, falls to its upper edge,
/// is clipped at zero, and is scaled by `2 / (upper - lower)`.
///
/// The result is indexed as `filterbank[mel][fft_bin]`. This function currently always
/// returns `Ok`.
fn create_mel_filterbank(&self, n_fft: usize) -> Result<Vec<Vec<f32>>> {
    let n_freqs = n_fft / 2 + 1;
    let n_mels = self.num_mel_bins;
    let nyquist = self.sampling_rate as f32 / 2.0;

    // Centre frequency of every FFT bin, evenly spaced from 0 to Nyquist.
    let last_bin = (n_freqs - 1) as f32;
    let fft_freqs: Vec<f32> = (0..n_freqs)
        .map(|bin| bin as f32 * nyquist / last_bin)
        .collect();

    // `n_mels + 2` edge frequencies, evenly spaced in mel and mapped back to hertz.
    let mel_min = Self::hertz_to_mel(0.0);
    let mel_span = Self::hertz_to_mel(nyquist) - mel_min;
    let steps = (n_mels + 1) as f32;
    let filter_freqs: Vec<f32> = (0..n_mels + 2)
        .map(|i| Self::mel_to_hertz(mel_min + mel_span * i as f32 / steps))
        .collect();

    // Each run of three consecutive edges (lower, centre, upper) defines one filter.
    let filterbank: Vec<Vec<f32>> = filter_freqs
        .windows(3)
        .map(|edges| {
            let (lower, centre, upper) = (edges[0], edges[1], edges[2]);
            let rise = centre - lower;
            let fall = upper - centre;
            let enorm = 2.0 / (upper - lower);
            fft_freqs
                .iter()
                .map(|&fft_f| {
                    let down = (fft_f - lower) / rise;
                    let up = (upper - fft_f) / fall;
                    0.0f32.max(down.min(up)) * enorm
                })
                .collect()
        })
        .collect();
    Ok(filterbank)
}
}