rlx_wav2vec2_bert/
preprocess.rs1use anyhow::{Result, anyhow, bail};
22use serde::{Deserialize, Serialize};
23use std::fs;
24use std::path::Path;
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
27pub struct Wav2Vec2BertPreprocessConfig {
28 #[serde(default = "default_sample_rate")]
29 pub sampling_rate: usize,
30 #[serde(default = "default_num_mels")]
31 pub num_mel_bins: usize,
32 #[serde(default = "default_num_frames")]
34 pub num_frames: usize,
35}
36
/// Fallback for `sampling_rate` when the key is missing from a config.
fn default_sample_rate() -> usize {
    const SAMPLING_RATE: usize = 16_000;
    SAMPLING_RATE
}
/// Fallback for `num_mel_bins` when the key is missing from a config.
fn default_num_mels() -> usize {
    const NUM_MEL_BINS: usize = 80;
    NUM_MEL_BINS
}
/// Fallback for `num_frames` when the key is missing from a config.
fn default_num_frames() -> usize {
    const NUM_FRAMES: usize = 3_000;
    NUM_FRAMES
}
46
47impl Default for Wav2Vec2BertPreprocessConfig {
48 fn default() -> Self {
49 Self {
50 sampling_rate: default_sample_rate(),
51 num_mel_bins: default_num_mels(),
52 num_frames: default_num_frames(),
53 }
54 }
55}
56
57impl Wav2Vec2BertPreprocessConfig {
58 pub fn from_file(path: &Path) -> Result<Self> {
59 let txt = fs::read_to_string(path).map_err(|e| anyhow!("read {path:?}: {e}"))?;
60 let cfg: Self = serde_json::from_str(&txt).map_err(|e| anyhow!("parse {path:?}: {e}"))?;
61 Ok(cfg)
62 }
63
64 pub fn w2v_bert_2_0() -> Self {
65 Self::default()
66 }
67
68 pub fn feature_dim(&self) -> usize {
69 self.num_mel_bins
70 }
71}
72
/// Feature buffer plus per-frame attention mask produced by the extractor.
#[derive(Clone, Debug)]
pub struct LogMelFeatures {
    /// Number of mel bins per frame.
    pub num_mel_bins: usize,
    /// Number of frames described by `features` and `attention_mask`.
    pub num_frames: usize,
    /// Flat feature values; the extractor allocates
    /// `num_frames * num_mel_bins` entries.
    pub features: Vec<f32>,
    /// One value per frame; the extractor and padding code in this file write
    /// `1.0` for copied frames and `0.0` for padding.
    pub attention_mask: Vec<f32>,
}
82
83#[derive(Debug, Clone)]
84pub struct LogMelExtractor {
85 cfg: Wav2Vec2BertPreprocessConfig,
86}
87
88impl LogMelExtractor {
89 pub fn new(cfg: Wav2Vec2BertPreprocessConfig) -> Self {
90 Self { cfg }
91 }
92
93 pub fn config(&self) -> &Wav2Vec2BertPreprocessConfig {
94 &self.cfg
95 }
96
97 pub fn extract(&self, _pcm: &[f32]) -> LogMelFeatures {
98 let m = self.cfg.num_mel_bins;
100 let t = self.cfg.num_frames;
101 LogMelFeatures {
102 num_mel_bins: m,
103 num_frames: t,
104 features: vec![0.0f32; t * m],
105 attention_mask: vec![1.0f32; t],
106 }
107 }
108
109 pub fn pad_to_seq(&self, mut feats: LogMelFeatures, seq: usize) -> LogMelFeatures {
110 if feats.num_frames == seq {
111 return feats;
112 }
113 let m = feats.num_mel_bins;
114 let mut out = vec![0.0f32; seq * m];
115 let mut mask = vec![0.0f32; seq];
116 let copy_t = feats.num_frames.min(seq);
117 out[..copy_t * m].copy_from_slice(&feats.features[..copy_t * m]);
118 for i in 0..copy_t {
119 mask[i] = 1.0;
120 }
121 feats.num_frames = seq;
122 feats.features = out;
123 feats.attention_mask = mask;
124 feats
125 }
126}
127
128pub fn load_wav_mono_f32(path: &Path) -> Result<(Vec<f32>, usize)> {
129 let bytes = fs::read(path).map_err(|e| anyhow!("read wav {path:?}: {e}"))?;
130 parse_wav_mono_f32(&bytes)
131}
132
133pub fn parse_wav_mono_f32(bytes: &[u8]) -> Result<(Vec<f32>, usize)> {
134 if bytes.len() < 44 {
135 bail!("wav too small");
136 }
137 if &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
138 bail!("not a RIFF/WAVE file");
139 }
140 let mut off = 12usize;
141 let mut fmt: Option<(u16, u16, u32, u16)> = None; let mut data_chunk: Option<&[u8]> = None;
143 while off + 8 <= bytes.len() {
144 let tag = &bytes[off..off + 4];
145 let len = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap()) as usize;
146 off += 8;
147 if off + len > bytes.len() {
148 break;
149 }
150 match tag {
151 b"fmt " => {
152 if len < 16 {
153 bail!("wav fmt chunk too small");
154 }
155 let audio_format = u16::from_le_bytes(bytes[off..off + 2].try_into().unwrap());
156 let channels = u16::from_le_bytes(bytes[off + 2..off + 4].try_into().unwrap());
157 let sample_rate = u32::from_le_bytes(bytes[off + 4..off + 8].try_into().unwrap());
158 let bits_per_sample =
159 u16::from_le_bytes(bytes[off + 14..off + 16].try_into().unwrap());
160 fmt = Some((audio_format, channels, sample_rate, bits_per_sample));
161 }
162 b"data" => data_chunk = Some(&bytes[off..off + len]),
163 _ => {}
164 }
165 off += (len + 1) & !1;
166 if fmt.is_some() && data_chunk.is_some() {
167 break;
168 }
169 }
170 let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
171 if audio_format != 1 {
172 bail!("wav: only PCM supported (format={audio_format})");
173 }
174 if channels != 1 {
175 bail!("wav: expected mono, got {channels} channels");
176 }
177 if bps != 16 {
178 bail!("wav: expected 16-bit PCM, got {bps}");
179 }
180 let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
181 if data.len() % 2 != 0 {
182 bail!("wav data chunk not aligned");
183 }
184 let mut out = Vec::with_capacity(data.len() / 2);
185 for i in (0..data.len()).step_by(2) {
186 let s = i16::from_le_bytes([data[i], data[i + 1]]) as f32 / 32768.0;
187 out.push(s);
188 }
189 Ok((out, sr as usize))
190}