Skip to main content

rlx_whisper/
audio.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Audio I/O, mel spectrograms, and VAD segmentation for Whisper.
17
18use crate::config::WhisperConfig;
19use crate::mel::pcm_to_log_mel;
20use crate::vad::{VadConfig, segments_by_vad};
21use anyhow::{Result, anyhow, bail};
22use std::fs;
23use std::path::Path;
24
/// Sampling rate, in Hz, that PCM input handled by this module must use.
pub const SAMPLE_RATE: usize = 16_000;
/// Length in samples of one fixed-size audio window: 30 seconds at [`SAMPLE_RATE`].
pub const N_SAMPLES: usize = SAMPLE_RATE * 30;
/// Frame count requested from the mel front end for one [`N_SAMPLES`] window.
pub const N_FRAMES: usize = 3000;
28
/// A mel spectrogram stored as a flat buffer together with its dimensions.
#[derive(Clone, Debug)]
pub struct MelSpectrogram {
    /// Number of mel bins (the row count of the layout described on `data`).
    pub n_mels: usize,
    /// Number of time frames (the length of each row).
    pub n_frames: usize,
    /// Flat `f32` storage, row-major with logical shape `[1, n_mels, n_frames]`.
    pub data: Vec<f32>,
}
36
/// One span of detected speech, as produced by the VAD helpers in this module.
///
/// NOTE(review): `start`/`end` are presumably sample offsets into the PCM
/// buffer handed to the VAD, with `end` exclusive — confirm against `crate::vad`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct SpeechSegment {
    /// Offset at which the segment begins.
    pub start: usize,
    /// Offset at which the segment ends.
    pub end: usize,
}
42
/// Settings for the energy-based voice-activity detector.
///
/// The derived `Default` yields a threshold of `0.0` and a minimum length of `0`.
#[derive(Clone, Debug, Default)]
pub struct EnergyVad {
    /// Detection threshold, forwarded unchanged as `VadConfig::threshold`.
    pub threshold: f32,
    /// Requested minimum speech length in samples; `to_vad_config` raises it
    /// to at least `SAMPLE_RATE / 10`.
    pub min_len_samples: usize,
}
48
49impl EnergyVad {
50    pub fn to_vad_config(&self) -> VadConfig {
51        VadConfig {
52            kind: crate::vad::VadKind::Energy,
53            threshold: self.threshold,
54            min_speech_samples: self.min_len_samples.max(SAMPLE_RATE / 10),
55            ..VadConfig::default()
56        }
57    }
58}
59
60pub fn pcm_segments_by_vad(vad: &EnergyVad, pcm: &[f32]) -> Vec<SpeechSegment> {
61    segments_by_vad(&vad.to_vad_config(), pcm)
62}
63
64pub fn pcm_segments_by_vad_config(cfg: &VadConfig, pcm: &[f32]) -> Vec<SpeechSegment> {
65    segments_by_vad(cfg, pcm)
66}
67
68/// Pad with zeros at the end or truncate to [`N_SAMPLES`] (OpenAI `pad_or_trim`).
69pub fn pad_or_trim_pcm(pcm: &[f32]) -> Vec<f32> {
70    if pcm.len() >= N_SAMPLES {
71        pcm[..N_SAMPLES].to_vec()
72    } else {
73        let mut out = vec![0.0f32; N_SAMPLES];
74        out[..pcm.len()].copy_from_slice(pcm);
75        out
76    }
77}
78
79pub fn pcm_to_mel(cfg: &WhisperConfig, pcm: &[f32]) -> MelSpectrogram {
80    let pcm = pad_or_trim_pcm(pcm);
81    pcm_to_log_mel(&pcm, cfg.num_mel_bins, N_FRAMES)
82}
83
84pub fn load_wav_mono_f32(path: &Path) -> Result<Vec<f32>> {
85    let bytes = fs::read(path).map_err(|e| anyhow!("read wav {path:?}: {e}"))?;
86    parse_wav_mono_f32(&bytes)
87}
88
89pub fn parse_wav_mono_f32(bytes: &[u8]) -> Result<Vec<f32>> {
90    // Very small RIFF/WAVE PCM parser (16-bit mono).
91    if bytes.len() < 44 {
92        bail!("wav too small");
93    }
94    if &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
95        bail!("not a RIFF/WAVE file");
96    }
97    let mut off = 12usize;
98    let mut fmt: Option<(u16, u16, u32, u16)> = None; // (audio_format, channels, sample_rate, bits_per_sample)
99    let mut data_chunk: Option<&[u8]> = None;
100    while off + 8 <= bytes.len() {
101        let tag = &bytes[off..off + 4];
102        let len = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap()) as usize;
103        off += 8;
104        if off + len > bytes.len() {
105            break;
106        }
107        match tag {
108            b"fmt " => {
109                if len < 16 {
110                    bail!("wav fmt chunk too small");
111                }
112                let audio_format = u16::from_le_bytes(bytes[off..off + 2].try_into().unwrap());
113                let channels = u16::from_le_bytes(bytes[off + 2..off + 4].try_into().unwrap());
114                let sample_rate = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap());
115                let bits_per_sample =
116                    u16::from_le_bytes(bytes[off + 14..off + 16].try_into().unwrap());
117                fmt = Some((audio_format, channels, sample_rate, bits_per_sample));
118            }
119            b"data" => {
120                data_chunk = Some(&bytes[off..off + len]);
121            }
122            _ => {}
123        }
124        off += (len + 1) & !1; // word-align
125        if fmt.is_some() && data_chunk.is_some() {
126            break;
127        }
128    }
129    let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
130    if audio_format != 1 {
131        bail!("wav: only PCM supported (format={audio_format})");
132    }
133    if channels != 1 {
134        bail!("wav: expected mono, got {channels} channels");
135    }
136    if sr as usize != SAMPLE_RATE {
137        bail!("wav: expected {SAMPLE_RATE} Hz, got {sr}");
138    }
139    if bps != 16 {
140        bail!("wav: expected 16-bit PCM, got {bps}");
141    }
142    let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
143    if data.len() % 2 != 0 {
144        bail!("wav data chunk not aligned");
145    }
146    let mut out = Vec::with_capacity(data.len() / 2);
147    for i in (0..data.len()).step_by(2) {
148        let s = i16::from_le_bytes([data[i], data[i + 1]]) as f32 / 32768.0;
149        out.push(s);
150    }
151    Ok(out)
152}