Skip to main content

rlx_wav2vec2_bert/
preprocess.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Minimal audio preprocessing for Wav2Vec2-BERT.
//!
//! The full-quality frontend (STFT + mel filterbank) is not implemented yet. For now,
//! [`LogMelExtractor::extract`] returns correctly shaped all-zero features, and
//! [`load_wav_mono_f32`] decodes 16-bit PCM mono WAV files, so the crate compiles and
//! the CLI can load audio.
20
21use anyhow::{Result, anyhow, bail};
22use serde::{Deserialize, Serialize};
23use std::fs;
24use std::path::Path;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Wav2Vec2BertPreprocessConfig {
28    #[serde(default = "default_sample_rate")]
29    pub sampling_rate: usize,
30    #[serde(default = "default_num_mels")]
31    pub num_mel_bins: usize,
32    /// Target frame count for fixed-shape encoders.
33    #[serde(default = "default_num_frames")]
34    pub num_frames: usize,
35}
36
/// Fallback for `sampling_rate` when the JSON config omits the field.
fn default_sample_rate() -> usize {
    const FALLBACK: usize = 16_000;
    FALLBACK
}
/// Fallback for `num_mel_bins` when the JSON config omits the field.
fn default_num_mels() -> usize {
    const FALLBACK: usize = 80;
    FALLBACK
}
/// Fallback for `num_frames` when the JSON config omits the field.
fn default_num_frames() -> usize {
    const FALLBACK: usize = 3_000;
    FALLBACK
}
46
47impl Default for Wav2Vec2BertPreprocessConfig {
48    fn default() -> Self {
49        Self {
50            sampling_rate: default_sample_rate(),
51            num_mel_bins: default_num_mels(),
52            num_frames: default_num_frames(),
53        }
54    }
55}
56
57impl Wav2Vec2BertPreprocessConfig {
58    pub fn from_file(path: &Path) -> Result<Self> {
59        let txt = fs::read_to_string(path).map_err(|e| anyhow!("read {path:?}: {e}"))?;
60        let cfg: Self = serde_json::from_str(&txt).map_err(|e| anyhow!("parse {path:?}: {e}"))?;
61        Ok(cfg)
62    }
63
64    pub fn w2v_bert_2_0() -> Self {
65        Self::default()
66    }
67
68    pub fn feature_dim(&self) -> usize {
69        self.num_mel_bins
70    }
71}
72
/// Feature tensor for a single utterance, together with its per-frame validity mask.
#[derive(Clone, Debug)]
pub struct LogMelFeatures {
    /// Width of each frame in `features`.
    pub num_mel_bins: usize,
    /// Number of frames in `features` and entries in `attention_mask`.
    pub num_frames: usize,
    /// Row-major `[1, num_frames, num_mel_bins]` (batch = 1).
    pub features: Vec<f32>,
    /// Row-major `[1, num_frames]`: `1.0` marks a valid frame, `0.0` a padded one.
    pub attention_mask: Vec<f32>,
}
82
83#[derive(Debug, Clone)]
84pub struct LogMelExtractor {
85    cfg: Wav2Vec2BertPreprocessConfig,
86}
87
88impl LogMelExtractor {
89    pub fn new(cfg: Wav2Vec2BertPreprocessConfig) -> Self {
90        Self { cfg }
91    }
92
93    pub fn config(&self) -> &Wav2Vec2BertPreprocessConfig {
94        &self.cfg
95    }
96
97    pub fn extract(&self, _pcm: &[f32]) -> LogMelFeatures {
98        // Placeholder: returns zeros with the right shape.
99        let m = self.cfg.num_mel_bins;
100        let t = self.cfg.num_frames;
101        LogMelFeatures {
102            num_mel_bins: m,
103            num_frames: t,
104            features: vec![0.0f32; t * m],
105            attention_mask: vec![1.0f32; t],
106        }
107    }
108
109    pub fn pad_to_seq(&self, mut feats: LogMelFeatures, seq: usize) -> LogMelFeatures {
110        if feats.num_frames == seq {
111            return feats;
112        }
113        let m = feats.num_mel_bins;
114        let mut out = vec![0.0f32; seq * m];
115        let mut mask = vec![0.0f32; seq];
116        let copy_t = feats.num_frames.min(seq);
117        out[..copy_t * m].copy_from_slice(&feats.features[..copy_t * m]);
118        for i in 0..copy_t {
119            mask[i] = 1.0;
120        }
121        feats.num_frames = seq;
122        feats.features = out;
123        feats.attention_mask = mask;
124        feats
125    }
126}
127
128pub fn load_wav_mono_f32(path: &Path) -> Result<(Vec<f32>, usize)> {
129    let bytes = fs::read(path).map_err(|e| anyhow!("read wav {path:?}: {e}"))?;
130    parse_wav_mono_f32(&bytes)
131}
132
133pub fn parse_wav_mono_f32(bytes: &[u8]) -> Result<(Vec<f32>, usize)> {
134    if bytes.len() < 44 {
135        bail!("wav too small");
136    }
137    if &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
138        bail!("not a RIFF/WAVE file");
139    }
140    let mut off = 12usize;
141    let mut fmt: Option<(u16, u16, u32, u16)> = None; // (audio_format, channels, sample_rate, bits_per_sample)
142    let mut data_chunk: Option<&[u8]> = None;
143    while off + 8 <= bytes.len() {
144        let tag = &bytes[off..off + 4];
145        let len = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap()) as usize;
146        off += 8;
147        if off + len > bytes.len() {
148            break;
149        }
150        match tag {
151            b"fmt " => {
152                if len < 16 {
153                    bail!("wav fmt chunk too small");
154                }
155                let audio_format = u16::from_le_bytes(bytes[off..off + 2].try_into().unwrap());
156                let channels = u16::from_le_bytes(bytes[off + 2..off + 4].try_into().unwrap());
157                let sample_rate = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap());
158                let bits_per_sample =
159                    u16::from_le_bytes(bytes[off + 14..off + 16].try_into().unwrap());
160                fmt = Some((audio_format, channels, sample_rate, bits_per_sample));
161            }
162            b"data" => data_chunk = Some(&bytes[off..off + len]),
163            _ => {}
164        }
165        off += (len + 1) & !1;
166        if fmt.is_some() && data_chunk.is_some() {
167            break;
168        }
169    }
170    let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
171    if audio_format != 1 {
172        bail!("wav: only PCM supported (format={audio_format})");
173    }
174    if channels != 1 {
175        bail!("wav: expected mono, got {channels} channels");
176    }
177    if bps != 16 {
178        bail!("wav: expected 16-bit PCM, got {bps}");
179    }
180    let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
181    if data.len() % 2 != 0 {
182        bail!("wav data chunk not aligned");
183    }
184    let mut out = Vec::with_capacity(data.len() / 2);
185    for i in (0..data.len()).step_by(2) {
186        let s = i16::from_le_bytes([data[i], data[i + 1]]) as f32 / 32768.0;
187        out.push(s);
188    }
189    Ok((out, sr as usize))
190}