use anyhow::{Result, anyhow};
use audioadapter_buffers::direct::SequentialSliceOfVecs;
use burn::tensor::{Tensor, backend::Backend};
use rubato::{Async, FixedAsync, Resampler, SincInterpolationParameters, WindowFunction};
use rustfft::{FftPlanner, num_complex::Complex};
use std::f32::consts::PI;
#[cfg(feature = "file-io")]
use std::path::Path;
#[cfg(feature = "file-io")]
use symphonia::core::codecs::audio::CODEC_ID_NULL_AUDIO;
#[cfg(feature = "file-io")]
use symphonia::core::{
codecs::audio::AudioDecoderOptions,
formats::{FormatOptions, probe::Hint},
io::MediaSourceStream,
meta::MetadataOptions,
};
/// Sample rate, in Hz, that every mel helper in this module resamples to.
pub const WHISPER_SAMPLE_RATE: u32 = 16_000;
/// STFT window / FFT length, in samples.
pub const WHISPER_N_FFT: usize = 400;
/// Distance between consecutive STFT frame starts, in samples.
pub const WHISPER_HOP_LENGTH: usize = 160;
/// Number of mel bands in the produced log-mel spectrogram.
pub const WHISPER_N_MELS: usize = 80;
/// Length, in seconds, of the fixed audio window the mel helpers pad or trim to
/// (`30 * WHISPER_SAMPLE_RATE` samples).
pub const WHISPER_CHUNK_LENGTH: usize = 30;
/// A buffer of PCM audio held as `f32` samples.
///
/// For multi-channel audio the samples are interleaved frame by frame
/// (`[l0, r0, l1, r1, …]` for stereo); `to_mono`, `resample` and
/// `load_audio_file` all use that layout.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioData {
    /// Sample values, interleaved across `channels`.
    pub samples: Vec<f32>,
    /// Sampling rate in Hz.
    pub sample_rate: u32,
    /// Number of interleaved channels in `samples`.
    pub channels: u16,
}
impl AudioData {
pub fn new(samples: Vec<f32>, sample_rate: u32, channels: u16) -> Self {
Self {
samples,
sample_rate,
channels,
}
}
pub fn duration(&self) -> f32 {
self.samples.len() as f32 / (self.sample_rate as f32 * self.channels as f32)
}
pub fn to_mono(&self) -> AudioData {
if self.channels == 1 {
return self.clone();
}
let mono_samples: Vec<f32> = self
.samples
.chunks(self.channels as usize)
.map(|chunk| chunk.iter().sum::<f32>() / self.channels as f32)
.collect();
AudioData {
samples: mono_samples,
sample_rate: self.sample_rate,
channels: 1,
}
}
pub fn resample(&self, target_sample_rate: u32) -> Result<AudioData> {
if self.sample_rate == target_sample_rate {
return Ok(self.clone());
}
let f_ratio = target_sample_rate as f64 / self.sample_rate as f64;
let params = SincInterpolationParameters {
sinc_len: 256,
f_cutoff: 0.95,
interpolation: rubato::SincInterpolationType::Cubic,
oversampling_factor: 256,
window: WindowFunction::BlackmanHarris2,
};
let chunk_size = 1024;
let mut resampler = Async::<f32>::new_sinc(
f_ratio,
2.0,
¶ms,
chunk_size,
self.channels as usize,
FixedAsync::Input,
)
.map_err(|e| anyhow!("Failed to create resampler: {}", e))?;
let frames_per_channel = self.samples.len() / self.channels as usize;
let mut input_channels: Vec<Vec<f32>> =
vec![Vec::with_capacity(frames_per_channel); self.channels as usize];
for (i, &sample) in self.samples.iter().enumerate() {
let channel = i % self.channels as usize;
input_channels[channel].push(sample);
}
let input_adapter =
SequentialSliceOfVecs::new(&input_channels, self.channels as usize, frames_per_channel)
.map_err(|e| anyhow!("Failed to create input adapter: {}", e))?;
let estimated_output_frames = (frames_per_channel as f64 * f_ratio) as usize;
let mut output_channels: Vec<Vec<f32>> =
vec![vec![0.0f32; estimated_output_frames]; self.channels as usize];
let mut output_adapter = SequentialSliceOfVecs::new_mut(
&mut output_channels,
self.channels as usize,
estimated_output_frames,
)
.map_err(|e| anyhow!("Failed to create output adapter: {}", e))?;
let mut indexing = rubato::Indexing {
input_offset: 0,
output_offset: 0,
active_channels_mask: None,
partial_len: None,
};
let mut input_frames_left = frames_per_channel;
let mut input_frames_next = resampler.input_frames_next();
while input_frames_left >= input_frames_next {
let (frames_read, frames_written) = resampler
.process_into_buffer(&input_adapter, &mut output_adapter, Some(&indexing))
.map_err(|e| anyhow!("Resampling failed: {}", e))?;
indexing.input_offset += frames_read;
indexing.output_offset += frames_written;
input_frames_left -= frames_read;
input_frames_next = resampler.input_frames_next();
}
let actual_output_frames = indexing.output_offset;
let mut output_samples = Vec::with_capacity(actual_output_frames * self.channels as usize);
for frame in 0..actual_output_frames {
for ch in &output_channels {
output_samples.push(ch[frame]);
}
}
Ok(AudioData {
samples: output_samples,
sample_rate: target_sample_rate,
channels: self.channels,
})
}
pub fn to_16khz_mono(&self) -> Result<AudioData> {
let mono = self.to_mono();
mono.resample(16000)
}
}
/// Decodes the first audio track of the file at `path` into interleaved `f32`
/// samples.
///
/// The file extension, if any, is passed to symphonia's probe as a format hint.
/// Packets that fail with an I/O or decode error are skipped.
///
/// # Errors
/// Returns an error if:
/// - the file cannot be opened;
/// - the format is unsupported;
/// - there is no audio track, or its sample rate or channel count is unknown;
/// - the decoder cannot be created, or a non-recoverable read/decode error occurs;
/// - packets were found but none of them could be decoded.
#[cfg(feature = "file-io")]
pub fn load_audio_file<P: AsRef<Path>>(path: P) -> Result<AudioData> {
    let path = path.as_ref();
    let file = std::fs::File::open(path)
        .map_err(|e| anyhow!("Failed to open audio file '{}': {}", path.display(), e))?;
    let mss = MediaSourceStream::new(Box::new(file), Default::default());
    let mut hint = Hint::new();
    if let Some(ext) = path.extension().and_then(|e| e.to_str()) {
        hint.with_extension(ext);
    }
    let mut format = symphonia::default::get_probe()
        .probe(
            &hint,
            mss,
            FormatOptions::default(),
            MetadataOptions::default(),
        )
        .map_err(|e| anyhow!("Unsupported audio format '{}': {}", path.display(), e))?;
    // First track that carries real (non-null) audio codec parameters.
    let track = format
        .tracks()
        .iter()
        .find(|t| {
            t.codec_params
                .as_ref()
                .and_then(|cp| cp.audio())
                .map(|ap| ap.codec != CODEC_ID_NULL_AUDIO)
                .unwrap_or(false)
        })
        .ok_or_else(|| anyhow!("No audio tracks found in '{}'", path.display()))?;
    let track_id = track.id;
    let codec_params = track
        .codec_params
        .as_ref()
        .and_then(|cp| cp.audio())
        .ok_or_else(|| anyhow!("Missing codec parameters in '{}'", path.display()))?;
    let sample_rate = codec_params
        .sample_rate
        .ok_or_else(|| anyhow!("Unknown sample rate in '{}'", path.display()))?;
    let channels = codec_params
        .channels
        .as_ref()
        .ok_or_else(|| anyhow!("Unknown channel count in '{}'", path.display()))?
        .count() as u16;
    let mut decoder = symphonia::default::get_codecs()
        .make_audio_decoder(codec_params, &AudioDecoderOptions::default())
        .map_err(|e| anyhow!("Failed to create decoder for '{}': {}", path.display(), e))?;

    let mut samples: Vec<f32> = Vec::new();
    // Per-packet scratch buffer, reused across packets.
    // NOTE(review): whether `copy_to_vec_interleaved` appends to or resizes and
    // overwrites `dst` is not visible here. Clearing this buffer before each copy
    // and extending `samples` is correct under either behaviour; copying straight
    // into `samples` would keep only the last packet if it overwrites.
    let mut packet_samples: Vec<f32> = Vec::new();
    let mut skipped_packets = 0usize;
    loop {
        let packet = match format.next_packet() {
            Ok(Some(p)) => p,
            Ok(None) => break,
            Err(symphonia::core::errors::Error::ResetRequired) => continue,
            Err(e) => {
                return Err(anyhow!("Error reading '{}': {}", path.display(), e));
            }
        };
        if packet.track_id != track_id {
            continue;
        }
        let decoded = match decoder.decode(&packet) {
            Ok(d) => d,
            // Recoverable per-packet failures: skip the packet and keep going.
            Err(symphonia::core::errors::Error::IoError(_))
            | Err(symphonia::core::errors::Error::DecodeError(_)) => {
                skipped_packets += 1;
                continue;
            }
            Err(e) => return Err(anyhow!("Decode error in '{}': {}", path.display(), e)),
        };
        packet_samples.clear();
        decoded.copy_to_vec_interleaved::<f32>(&mut packet_samples);
        samples.extend_from_slice(&packet_samples);
    }
    // Skipping corrupt packets is best-effort; if *every* packet failed, report it
    // instead of silently returning an empty clip.
    anyhow::ensure!(
        !samples.is_empty() || skipped_packets == 0,
        "No decodable audio in '{}' ({} packets failed to decode)",
        path.display(),
        skipped_packets
    );
    Ok(AudioData {
        samples,
        sample_rate,
        channels,
    })
}
/// Reflect-pads `samples` by `n_fft / 2` on each side, so STFT frames are centred
/// on the original samples.
///
/// Reflection mirrors about the first and last samples without repeating them.
/// For example, `[1, 2, 3, 4, 5]` padded by 2 becomes
/// `[3, 2, 1, 2, 3, 4, 5, 4, 3]`.
///
/// Inputs shorter than the pad are reflected back and forth periodically instead
/// of panicking. A single sample is repeated. An empty input yields
/// `2 * (n_fft / 2)` zeros. The output always has `samples.len() + 2 * (n_fft / 2)`
/// elements.
pub fn prepare_centered_samples_raw(samples: &[f32], n_fft: usize) -> Vec<f32> {
    /// Maps a possibly out-of-range position onto `0..len` by reflection.
    /// The reflection has period `2 * (len - 1)`.
    fn reflect_index(pos: isize, len: usize) -> usize {
        if len == 1 {
            return 0;
        }
        let len = len as isize;
        let period = 2 * (len - 1);
        let m = pos.rem_euclid(period);
        (if m >= len { period - m } else { m }) as usize
    }

    let pad_len = n_fft / 2;
    let n = samples.len();
    if n == 0 {
        // Nothing to reflect; fall back to zero padding so the length contract holds.
        return vec![0.0; 2 * pad_len];
    }
    let pad = pad_len as isize;
    let len = n as isize;
    let mut centered = Vec::with_capacity(n + 2 * pad_len);
    centered.extend((-pad..0).map(|p| samples[reflect_index(p, n)]));
    centered.extend_from_slice(samples);
    centered.extend((len..len + pad).map(|p| samples[reflect_index(p, n)]));
    centered
}
fn prepare_centered_samples(audio: &AudioData, n_fft: usize) -> Result<Vec<f32>> {
let audio = if audio.channels != 1 || audio.sample_rate != WHISPER_SAMPLE_RATE {
audio.to_16khz_mono()?
} else {
audio.clone()
};
let target_samples = 30 * WHISPER_SAMPLE_RATE as usize;
let mut padded = audio.samples;
if padded.len() > target_samples {
padded.truncate(target_samples);
} else {
padded.resize(target_samples, 0.0);
}
Ok(prepare_centered_samples_raw(&padded, n_fft))
}
/// Projects a power spectrogram onto the mel filter bank and applies the
/// log-compression used for Whisper input.
///
/// `ps` is a `[n_freqs, n_frames]` tensor. The steps are:
/// 1. Take `log10` of the mel energies, clamped to at least `1e-10`.
/// 2. Floor every value at `max - 8`.
/// 3. Rescale with `(x + 4) / 4`.
///
/// The result is reshaped to `[1, n_mels, n_frames]`.
fn mel_compress<B: Backend>(
    ps: Tensor<B, 2>,
    n_mels: usize,
    n_fft: usize,
    device: &B::Device,
) -> Tensor<B, 3> {
    let [n_freqs, n_frames] = ps.dims();
    let weights: Vec<f32> = create_mel_filter_bank(n_fft, n_mels, WHISPER_SAMPLE_RATE as f32)
        .into_iter()
        .flatten()
        .collect();
    let filter_bank: Tensor<B, 2> =
        Tensor::<B, 1>::from_floats(weights.as_slice(), device).reshape([n_mels, n_freqs]);
    let mel_power: Tensor<B, 2> = filter_bank.matmul(ps);

    // log10(x) computed as ln(x) * log10(e).
    let log10_mel: Tensor<B, 2> = mel_power
        .clamp_min(1e-10_f32)
        .log()
        .mul_scalar(std::f32::consts::LOG10_E);

    // Global peak as a [1, 1] tensor so it broadcasts over the whole spectrogram.
    let peak: Tensor<B, 2> = log10_mel.clone().max().reshape([1, 1]);

    // max(x, peak - 8), written as ((x - peak) clamped at -8) + peak.
    let floored = (log10_mel - peak.clone()).clamp_min(-8.0_f32).add(peak);
    let scaled = floored.add_scalar(4.0_f32).div_scalar(4.0_f32);
    scaled.reshape([1, n_mels, n_frames])
}
/// Computes the `[1, n_mels, frames]` log-mel spectrogram of `audio` on the CPU.
///
/// `audio` is converted to 16 kHz mono if needed and padded or trimmed to
/// 30 seconds. It is then reflect-padded and transformed with a Hann-windowed STFT.
/// The final STFT frame is discarded before the mel projection.
///
/// # Errors
/// Propagates resampling errors from the 16 kHz mono conversion.
pub fn compute_mel_spectrogram<B: Backend>(
    audio: &AudioData,
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &B::Device,
) -> Result<Tensor<B, 3>> {
    let padded = prepare_centered_samples(audio, n_fft)?;
    let power = compute_stft_magnitudes(&padded, n_fft, hop_length);
    let freq_bins = n_fft / 2 + 1;
    // Keep every frame except the last one.
    let kept_frames = power[0].len().saturating_sub(1);

    // Flatten row-major as [freq][frame] to match the [freq_bins, kept_frames] reshape.
    let mut flat: Vec<f32> = Vec::with_capacity(freq_bins * kept_frames);
    for row in &power[..freq_bins] {
        flat.extend_from_slice(&row[..kept_frames]);
    }
    let spectrum: Tensor<B, 2> =
        Tensor::<B, 1>::from_floats(flat.as_slice(), device).reshape([freq_bins, kept_frames]);
    Ok(mel_compress(spectrum, n_mels, n_fft, device))
}
pub fn compute_mel_from_samples<B: Backend>(
samples: &[f32],
n_fft: usize,
hop_length: usize,
n_mels: usize,
device: &B::Device,
) -> Result<Tensor<B, 3>> {
let expected = 30 * WHISPER_SAMPLE_RATE as usize;
anyhow::ensure!(
samples.len() == expected,
"compute_mel_from_samples: expected {} samples, got {}",
expected,
samples.len()
);
let centered = prepare_centered_samples_raw(samples, n_fft);
let magnitudes = compute_stft_magnitudes(¢ered, n_fft, hop_length);
let n_freqs = n_fft / 2 + 1;
let n_frames = magnitudes[0].len().saturating_sub(1);
let ps_flat: Vec<f32> = (0..n_freqs)
.flat_map(|f| magnitudes[f][..n_frames].iter().copied())
.collect();
let ps_tensor: Tensor<B, 2> =
Tensor::<B, 1>::from_floats(ps_flat.as_slice(), device).reshape([n_freqs, n_frames]);
Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
}
/// GPU variant of [`compute_mel_spectrogram`]: the STFT power spectrum is
/// computed by `crate::stft_gpu::compute_stft_power_gpu` on the wgpu device.
///
/// # Errors
/// Propagates resampling errors from the 16 kHz mono conversion.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_spectrogram_wgpu(
    audio: &AudioData,
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    use cubecl::prelude::Runtime;

    let padded = prepare_centered_samples(audio, n_fft)?;
    let freq_bins = n_fft / 2 + 1;
    let total_frames = (padded.len() - n_fft) / hop_length + 1;
    // The kernel computes every frame; the last one is dropped below.
    let kept_frames = total_frames.saturating_sub(1);

    let client = burn_wgpu::WgpuRuntime::client(device);
    let power = crate::stft_gpu::compute_stft_power_gpu(
        &client,
        &padded,
        n_fft,
        hop_length,
        total_frames,
    );
    // The reshape treats the kernel output as frame-major ([frame][freq]); the
    // transpose gives mel_compress the [freq, frame] layout it expects.
    let kept = &power[..kept_frames * freq_bins];
    let spectrum: Tensor<WgpuBackend, 2> = Tensor::<WgpuBackend, 1>::from_floats(kept, device)
        .reshape([kept_frames, freq_bins])
        .transpose();
    Ok(mel_compress(spectrum, n_mels, n_fft, device))
}
/// GPU variant of [`compute_mel_from_samples`].
///
/// The input must be exactly `WHISPER_CHUNK_LENGTH * WHISPER_SAMPLE_RATE`
/// samples of 16 kHz mono audio. The STFT power spectrum is computed by
/// `crate::stft_gpu::compute_stft_power_gpu`.
///
/// # Errors
/// Returns an error if `samples.len()` differs from that expected length.
#[cfg(feature = "cubecl-stft")]
pub fn compute_mel_from_samples_wgpu(
    samples: &[f32],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    use cubecl::prelude::Runtime;
    let expected = WHISPER_CHUNK_LENGTH * WHISPER_SAMPLE_RATE as usize;
    anyhow::ensure!(
        samples.len() == expected,
        "compute_mel_from_samples_wgpu: expected {} samples, got {}",
        expected,
        samples.len()
    );
    let centered = prepare_centered_samples_raw(samples, n_fft);
    let n_freqs = n_fft / 2 + 1;
    let n_frames_total = (centered.len() - n_fft) / hop_length + 1;
    // Drop the final STFT frame (mirrors the CPU path).
    let n_frames = n_frames_total.saturating_sub(1);
    let client = burn_wgpu::WgpuRuntime::client(device);
    let gpu_out = crate::stft_gpu::compute_stft_power_gpu(
        &client,
        &centered,
        n_fft,
        hop_length,
        n_frames_total,
    );
    // Output is read as frame-major [frame][freq]; transpose to [freq, frame].
    let ps_tensor: Tensor<WgpuBackend, 2> =
        Tensor::<WgpuBackend, 1>::from_floats(&gpu_out[..n_frames * n_freqs], device)
            .reshape([n_frames, n_freqs])
            .transpose();
    Ok(mel_compress(ps_tensor, n_mels, n_fft, device))
}
#[cfg(feature = "cubecl-stft")]
/// Burn backend used by the `*_wgpu` mel helpers in this module.
///
/// This is burn_wgpu's `CubeBackend` on the `WgpuRuntime`, with `f32`, `i32` and
/// `u32` as its element type parameters.
pub type WgpuBackend = burn_wgpu::CubeBackend<burn_wgpu::WgpuRuntime, f32, i32, u32>;
/// Computes a log-mel spectrogram for each chunk with
/// [`compute_mel_spectrogram`] and stacks them along dimension 0.
///
/// # Errors
/// Fails if `chunks` is empty or if any chunk fails to convert. The first
/// failure is returned.
pub fn batch_mel_spectrograms<B: Backend>(
    chunks: &[AudioData],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &B::Device,
) -> Result<Tensor<B, 3>> {
    anyhow::ensure!(!chunks.is_empty(), "batch_mel_spectrograms: no chunks");
    let mut batch: Vec<Tensor<B, 3>> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        batch.push(compute_mel_spectrogram::<B>(
            chunk, n_fft, hop_length, n_mels, device,
        )?);
    }
    Ok(Tensor::cat(batch, 0))
}
/// GPU counterpart of [`batch_mel_spectrograms`].
///
/// Each chunk goes through [`compute_mel_spectrogram_wgpu`], and the results are
/// concatenated along dimension 0.
///
/// # Errors
/// Fails if `chunks` is empty or if any chunk fails to convert. The first
/// failure is returned.
#[cfg(feature = "cubecl-stft")]
pub fn batch_mel_spectrograms_wgpu(
    chunks: &[AudioData],
    n_fft: usize,
    hop_length: usize,
    n_mels: usize,
    device: &burn_wgpu::WgpuDevice,
) -> Result<Tensor<WgpuBackend, 3>> {
    anyhow::ensure!(!chunks.is_empty(), "batch_mel_spectrograms_wgpu: no chunks");
    let mut batch: Vec<Tensor<WgpuBackend, 3>> = Vec::with_capacity(chunks.len());
    for chunk in chunks {
        batch.push(compute_mel_spectrogram_wgpu(
            chunk, n_fft, hop_length, n_mels, device,
        )?);
    }
    Ok(Tensor::cat(batch, 0))
}
/// Computes the power spectrum `|X[k]|^2` of a Hann-windowed STFT.
///
/// The result is indexed `[freq][frame]` and has `n_fft / 2 + 1` frequency rows.
/// A frame starts every `hop_length` samples, and only frames that fit entirely
/// inside `samples` are produced; there is no implicit padding. If not even one
/// frame fits, each row holds a single `0.0`, so callers can still read `[0]`.
///
/// # Panics
/// Panics if `hop_length` is zero while `samples.len() >= n_fft`.
fn compute_stft_magnitudes(samples: &[f32], n_fft: usize, hop_length: usize) -> Vec<Vec<f32>> {
    let n_freqs = n_fft / 2 + 1;
    let n_frames = if samples.len() >= n_fft {
        (samples.len() - n_fft) / hop_length + 1
    } else {
        0
    };
    if n_frames == 0 {
        return vec![vec![0.0]; n_freqs];
    }
    // Periodic Hann window (denominator `n_fft`, not `n_fft - 1`).
    let window: Vec<f32> = (0..n_fft)
        .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / n_fft as f32).cos()))
        .collect();
    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(n_fft);
    // `Fft::process` allocates a scratch buffer on every call. Allocate it once and
    // reuse it for all frames.
    let mut scratch: Vec<Complex<f32>> =
        vec![Complex::new(0.0, 0.0); fft.get_inplace_scratch_len()];
    let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); n_fft];
    let mut magnitudes = vec![vec![0.0f32; n_frames]; n_freqs];
    for frame_idx in 0..n_frames {
        let start = frame_idx * hop_length;
        // By construction of `n_frames`, `start + n_fft <= samples.len()`.
        let frame = &samples[start..start + n_fft];
        for ((slot, &sample), &w) in buffer.iter_mut().zip(frame).zip(&window) {
            *slot = Complex::new(sample * w, 0.0);
        }
        fft.process_with_scratch(&mut buffer, &mut scratch);
        // Only the first `n_freqs` bins are kept; the rest mirror them for real input.
        for (row, bin) in magnitudes.iter_mut().zip(&buffer) {
            row[frame_idx] = bin.norm_sqr();
        }
    }
    magnitudes
}
/// Builds a Slaney-style mel filter bank, returned as `n_mels` rows of
/// `n_fft / 2 + 1` weights.
///
/// The mel scale is linear below 1 kHz (`200/3` Hz per mel) and logarithmic
/// above it. The `n_mels + 2` band edges are spaced evenly in mel between 0 Hz
/// and `sample_rate / 2`.
///
/// Each filter is a triangle over the FFT bin frequencies `k * sample_rate / n_fft`.
/// It is scaled by `2 / (upper - lower)` so every band has equal area.
fn create_mel_filter_bank(n_fft: usize, n_mels: usize, sample_rate: f32) -> Vec<Vec<f32>> {
    const F_SP: f32 = 200.0 / 3.0;
    const MIN_LOG_HZ: f32 = 1000.0;
    const MIN_LOG_MEL: f32 = MIN_LOG_HZ / F_SP;
    let logstep: f32 = 6.4f32.ln() / 27.0;

    let hz_to_mel = |hz: f32| -> f32 {
        if hz < MIN_LOG_HZ {
            hz / F_SP
        } else {
            MIN_LOG_MEL + (hz / MIN_LOG_HZ).ln() / logstep
        }
    };
    let mel_to_hz = |mel: f32| -> f32 {
        if mel < MIN_LOG_MEL {
            F_SP * mel
        } else {
            MIN_LOG_HZ * ((mel - MIN_LOG_MEL) * logstep).exp()
        }
    };

    let mel_lo = hz_to_mel(0.0);
    let mel_hi = hz_to_mel(sample_rate / 2.0);
    let intervals = n_mels + 1;
    let band_edges: Vec<f32> = (0..=intervals)
        .map(|i| mel_to_hz(mel_lo + (mel_hi - mel_lo) * i as f32 / intervals as f32))
        .collect();
    let bin_hz: Vec<f32> = (0..n_fft / 2 + 1)
        .map(|k| k as f32 * sample_rate / n_fft as f32)
        .collect();

    // Each consecutive (lower, center, upper) triple of edges defines one filter.
    band_edges
        .windows(3)
        .map(|edges| {
            let (lower, center, upper) = (edges[0], edges[1], edges[2]);
            let enorm = 2.0 / (upper - lower).max(1e-8);
            bin_hz
                .iter()
                .map(|&freq| {
                    let rising = if center > lower {
                        ((freq - lower) / (center - lower)).max(0.0)
                    } else {
                        0.0
                    };
                    let falling = if upper > center {
                        ((upper - freq) / (upper - center)).max(0.0)
                    } else {
                        0.0
                    };
                    rising.min(falling) * enorm
                })
                .collect()
        })
        .collect()
}
/// Returns a copy of `audio` whose `samples` are exactly `length_samples` long.
///
/// Longer input is truncated and shorter input is zero-padded at the end.
/// `length_samples` counts raw (interleaved) samples, not frames. For
/// multi-channel audio a length that is not a multiple of `channels` therefore
/// cuts or pads mid-frame.
pub fn pad_or_trim_audio(audio: &AudioData, length_samples: usize) -> AudioData {
    // Copy only the prefix that survives instead of cloning the whole buffer and
    // then truncating it.
    let keep = audio.samples.len().min(length_samples);
    let mut samples = Vec::with_capacity(length_samples);
    samples.extend_from_slice(&audio.samples[..keep]);
    samples.resize(length_samples, 0.0);
    AudioData {
        samples,
        sample_rate: audio.sample_rate,
        channels: audio.channels,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn_flex::FlexDevice;

    /// CPU backend used by every tensor-producing test.
    type TestBackend = burn_flex::Flex<f32>;

    /// Returns `n` samples of `amp * sin(2π · freq · i / rate)`.
    fn sine(n: usize, freq: f32, rate: f32, amp: f32) -> Vec<f32> {
        (0..n)
            .map(|i| amp * (2.0 * PI * freq * i as f32 / rate).sin())
            .collect()
    }

    /// Runs the CPU mel pipeline with Whisper's default parameters.
    fn whisper_mel(audio: &AudioData) -> Result<Tensor<TestBackend, 3>> {
        compute_mel_spectrogram::<TestBackend>(
            audio,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
    }

    #[test]
    fn test_audio_data_creation() {
        let stereo = AudioData::new(vec![0.0, 0.5, -0.5, 1.0], 44100, 2);
        assert_eq!(stereo.duration(), 2.0 / 44100.0);
        assert_eq!(stereo.channels, 2);
    }

    #[test]
    fn test_mono_conversion() {
        let mono = AudioData::new(vec![1.0, 2.0, 3.0, 4.0], 44100, 2).to_mono();
        assert_eq!(mono.samples, vec![1.5, 3.5]);
        assert_eq!(mono.channels, 1);
    }

    #[test]
    fn test_hann_window() {
        let len = 400;
        let hann: Vec<f32> = (0..len)
            .map(|i| 0.5 * (1.0 - (2.0 * PI * i as f32 / len as f32).cos()))
            .collect();
        assert!(hann[0] < 0.01);
        assert!(hann[len - 1] < 0.01);
        assert!(hann[len / 2] > 0.99);
    }

    #[test]
    fn test_mel_filter_bank() {
        let bank = create_mel_filter_bank(400, 80, 16000.0);
        assert_eq!(bank.len(), 80);
        assert_eq!(bank[0].len(), 201);
        for &weight in bank.iter().flatten() {
            assert!(weight >= 0.0, "Filter values should be non-negative");
        }
        for row in &bank {
            let peak = row.iter().copied().fold(0.0f32, f32::max);
            assert!(peak > 0.0, "Filter should have at least one non-zero bin");
            assert!(peak < 1.0, "Filter peak should be less than 1.0 (Slaney norm)");
        }
    }

    #[test]
    fn test_stft_magnitudes() {
        // 0.1 s of a 440 Hz tone at 16 kHz.
        let tone = sine(1600, 440.0, 16000.0, 1.0);
        let power = compute_stft_magnitudes(&tone, 400, 160);
        assert_eq!(power.len(), 201);
        assert!(!power[0].is_empty(), "Should have at least one frame");
        for &value in power.iter().flatten() {
            assert!(value >= 0.0, "Magnitudes should be non-negative");
        }
    }

    #[test]
    fn test_mel_spectrogram() {
        let silence = AudioData::new(vec![0.0; 16000], 16000, 1);
        let result = whisper_mel(&silence);
        assert!(result.is_ok());
        let [batch, mels, frames] = result.unwrap().dims();
        assert_eq!(batch, 1, "Batch size should be 1");
        assert_eq!(mels, WHISPER_N_MELS, "Should have {} mel bins", WHISPER_N_MELS);
        assert_eq!(frames, 3000, "Should always return 3000 mel frames");
    }

    #[test]
    fn test_mel_spectrogram_with_sine() {
        let rate: u32 = 16000;
        let audio = AudioData::new(sine(16000, 440.0, rate as f32, 0.5), rate, 1);
        let result = whisper_mel(&audio);
        assert!(result.is_ok());
        let values: Vec<f32> = result.unwrap().to_data().to_vec().unwrap();
        for &val in &values {
            assert!(val.is_finite(), "All values should be finite");
            assert!(
                (-2.0..=2.0).contains(&val),
                "Values should be in reasonable range, got {}",
                val
            );
        }
    }

    #[test]
    fn test_pad_or_trim() {
        let audio = AudioData::new(vec![1.0, 2.0, 3.0], 16000, 1);
        assert_eq!(pad_or_trim_audio(&audio, 2).samples, vec![1.0, 2.0]);
        assert_eq!(
            pad_or_trim_audio(&audio, 5).samples,
            vec![1.0, 2.0, 3.0, 0.0, 0.0]
        );
    }

    #[test]
    fn test_compute_mel_from_samples_matches_audio_data() {
        let samples = sine(480_000, 440.0, 16000.0, 0.5);
        let audio = AudioData::new(samples.clone(), WHISPER_SAMPLE_RATE, 1);
        let via_audio: Vec<f32> = whisper_mel(&audio).unwrap().to_data().to_vec().unwrap();
        let via_raw: Vec<f32> = compute_mel_from_samples::<TestBackend>(
            &samples,
            WHISPER_N_FFT,
            WHISPER_HOP_LENGTH,
            WHISPER_N_MELS,
            &FlexDevice,
        )
        .unwrap()
        .to_data()
        .to_vec()
        .unwrap();
        assert_eq!(via_audio.len(), via_raw.len(), "shape mismatch");
        for (i, (x, y)) in via_audio.iter().zip(&via_raw).enumerate() {
            assert_eq!(x, y, "mismatch at index {i}");
        }
    }
}