use anyhow::{Result, anyhow, bail};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::Path;
/// Front-end settings that determine the shape of the log-mel features.
///
/// Deserialized from JSON by `Wav2Vec2BertPreprocessConfig::from_file`. Any field absent from
/// the JSON falls back to the matching `default_*` function, so a partial file is accepted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Wav2Vec2BertPreprocessConfig {
/// Audio sample rate (default 16_000; presumably Hz).
/// NOTE(review): `LogMelExtractor` does not currently read this field.
#[serde(default = "default_sample_rate")]
pub sampling_rate: usize,
/// Number of mel bins per frame (default 80); also returned by `feature_dim()`.
#[serde(default = "default_num_mels")]
pub num_mel_bins: usize,
/// Number of frames `LogMelExtractor::extract` produces (default 3_000).
#[serde(default = "default_num_frames")]
pub num_frames: usize,
}
/// Fallback `sampling_rate` (16_000, presumably Hz), shared by serde and `Default`.
fn default_sample_rate() -> usize {
    const RATE: usize = 16000;
    RATE
}
/// Fallback `num_mel_bins` (80), shared by serde and `Default`.
fn default_num_mels() -> usize {
    const MEL_BINS: usize = 80;
    MEL_BINS
}
/// Fallback `num_frames` (3_000), shared by serde and `Default`.
fn default_num_frames() -> usize {
    const FRAMES: usize = 3000;
    FRAMES
}
impl Default for Wav2Vec2BertPreprocessConfig {
    /// Builds the config from the same `default_*` helpers serde uses for missing fields,
    /// so `Default` and a deserialized empty JSON object agree.
    fn default() -> Self {
        let sampling_rate = default_sample_rate();
        let num_mel_bins = default_num_mels();
        let num_frames = default_num_frames();
        Self {
            sampling_rate,
            num_mel_bins,
            num_frames,
        }
    }
}
impl Wav2Vec2BertPreprocessConfig {
    /// Reads `path` and deserializes it as a JSON config.
    ///
    /// Fields missing from the JSON take their `default_*` values.
    ///
    /// # Errors
    /// Returns `read {path:?}: …` if the file cannot be read as UTF-8 text, or
    /// `parse {path:?}: …` if the text is not a valid config object.
    pub fn from_file(path: &Path) -> Result<Self> {
        let txt = match fs::read_to_string(path) {
            Ok(contents) => contents,
            Err(e) => return Err(anyhow!("read {path:?}: {e}")),
        };
        serde_json::from_str::<Self>(&txt).map_err(|e| anyhow!("parse {path:?}: {e}"))
    }

    /// Named preset; currently identical to `Default::default()`.
    pub fn w2v_bert_2_0() -> Self {
        Default::default()
    }

    /// Width of one feature frame, i.e. `num_mel_bins`.
    pub fn feature_dim(&self) -> usize {
        let &Self { num_mel_bins, .. } = self;
        num_mel_bins
    }
}
/// A feature matrix plus a per-frame attention mask.
///
/// `features` is frame-major: the truncation in `LogMelExtractor::pad_to_seq` keeps the first
/// `copy_t * num_mel_bins` values to keep the first `copy_t` frames.
#[derive(Debug, Clone)]
pub struct LogMelFeatures {
/// Number of values per frame.
pub num_mel_bins: usize,
/// Number of frames in `features` and entries in `attention_mask`.
pub num_frames: usize,
/// `num_frames * num_mel_bins` values; frame `i` occupies
/// `[i * num_mel_bins, (i + 1) * num_mel_bins)`.
pub features: Vec<f32>,
/// One entry per frame: `1.0` for a real frame, `0.0` for padding.
pub attention_mask: Vec<f32>,
}
/// Produces `LogMelFeatures` shaped according to a `Wav2Vec2BertPreprocessConfig`.
#[derive(Debug, Clone)]
pub struct LogMelExtractor {
/// Config whose `num_mel_bins` / `num_frames` determine the output shape of `extract`.
cfg: Wav2Vec2BertPreprocessConfig,
}
impl LogMelExtractor {
    /// Creates an extractor that takes all output-shape decisions from `cfg`.
    pub fn new(cfg: Wav2Vec2BertPreprocessConfig) -> Self {
        Self { cfg }
    }

    /// Returns the configuration this extractor was built with.
    pub fn config(&self) -> &Wav2Vec2BertPreprocessConfig {
        &self.cfg
    }

    /// Returns a `num_frames x num_mel_bins` feature buffer sized from the config.
    ///
    /// NOTE(review): this is a placeholder. `_pcm` is ignored, every feature value is `0.0`,
    /// and every frame is marked valid (`1.0`) in the attention mask.
    pub fn extract(&self, _pcm: &[f32]) -> LogMelFeatures {
        let m = self.cfg.num_mel_bins;
        let t = self.cfg.num_frames;
        LogMelFeatures {
            num_mel_bins: m,
            num_frames: t,
            features: vec![0.0f32; t * m],
            attention_mask: vec![1.0f32; t],
        }
    }

    /// Zero-pads or truncates `feats` along the frame axis to exactly `seq` frames.
    ///
    /// The first `min(num_frames, seq)` frames are kept and any new frames are zero-filled.
    /// The attention mask of the kept frames is carried over from the input, so frames that
    /// were already padding stay masked out. This matches the `num_frames == seq` fast path,
    /// which returns `feats` unchanged. Newly appended frames get mask `0.0`. Kept frames with
    /// no corresponding mask entry (input mask shorter than `num_frames`) are treated as
    /// valid (`1.0`).
    ///
    /// # Panics
    /// Panics if `feats.features` holds fewer than `min(num_frames, seq) * num_mel_bins` values.
    pub fn pad_to_seq(&self, mut feats: LogMelFeatures, seq: usize) -> LogMelFeatures {
        if feats.num_frames == seq {
            return feats;
        }
        let m = feats.num_mel_bins;
        let copy_t = feats.num_frames.min(seq);

        let mut out = vec![0.0f32; seq * m];
        out[..copy_t * m].copy_from_slice(&feats.features[..copy_t * m]);

        // Preserve the existing mask for kept frames instead of forcing them to 1.0;
        // otherwise re-padding already-padded features would unmask the old padding.
        let mut mask = vec![0.0f32; seq];
        for (i, slot) in mask[..copy_t].iter_mut().enumerate() {
            *slot = feats.attention_mask.get(i).copied().unwrap_or(1.0);
        }

        feats.num_frames = seq;
        feats.features = out;
        feats.attention_mask = mask;
        feats
    }
}
/// Reads the file at `path` and decodes it with `parse_wav_mono_f32`.
///
/// Returns the decoded samples and the sample rate from the file's `fmt ` chunk.
///
/// # Errors
/// Returns `read wav {path:?}: …` if the file cannot be read; otherwise propagates any
/// error reported by `parse_wav_mono_f32`.
pub fn load_wav_mono_f32(path: &Path) -> Result<(Vec<f32>, usize)> {
    match fs::read(path) {
        Ok(raw) => parse_wav_mono_f32(&raw),
        Err(e) => Err(anyhow!("read wav {path:?}: {e}")),
    }
}
/// Decodes an in-memory RIFF/WAVE buffer containing mono, 16-bit, little-endian PCM.
///
/// Returns the samples scaled by `1/32768` (so `i16::MIN` maps to `-1.0`) together with
/// the sample rate read from the `fmt ` chunk.
///
/// Chunks are walked starting at byte 12. Unknown chunks are skipped, and the pad byte that
/// follows an odd-sized chunk is honoured. Scanning stops once both `fmt ` and `data` have
/// been seen, or when a chunk header declares more bytes than the buffer holds.
///
/// Both plain PCM (`wFormatTag == 1`) and `WAVE_FORMAT_EXTENSIBLE` (`0xFFFE`) are accepted.
/// For the extensible form, a `fmt ` chunk of at least 40 bytes whose SubFormat GUID
/// identifies PCM is required.
///
/// # Errors
/// Fails if the buffer is shorter than 44 bytes or lacks the `RIFF`/`WAVE` magic, if the
/// `fmt ` chunk is shorter than 16 bytes or missing, if the format is not PCM, not mono or
/// not 16-bit, or if the `data` chunk is missing or has an odd byte length.
pub fn parse_wav_mono_f32(bytes: &[u8]) -> Result<(Vec<f32>, usize)> {
    const WAVE_FORMAT_PCM: u16 = 1;
    const WAVE_FORMAT_EXTENSIBLE: u16 = 0xFFFE;

    if bytes.len() < 44 {
        bail!("wav too small");
    }
    if &bytes[0..4] != b"RIFF" || &bytes[8..12] != b"WAVE" {
        bail!("not a RIFF/WAVE file");
    }

    // Little-endian readers. Every call site has already bounds-checked `at`.
    let le_u16 = |at: usize| u16::from_le_bytes([bytes[at], bytes[at + 1]]);
    let le_u32 =
        |at: usize| u32::from_le_bytes([bytes[at], bytes[at + 1], bytes[at + 2], bytes[at + 3]]);

    let mut off = 12usize;
    let mut fmt: Option<(u16, u16, u32, u16)> = None;
    let mut data_chunk: Option<&[u8]> = None;
    while off + 8 <= bytes.len() {
        let tag = &bytes[off..off + 4];
        let len = le_u32(off + 4) as usize;
        off += 8;
        // checked_add: on 32-bit targets a huge declared length would otherwise wrap
        // `off + len` past the bounds check and panic on the slice below.
        let end = match off.checked_add(len) {
            Some(end) if end <= bytes.len() => end,
            _ => break,
        };
        match tag {
            b"fmt " => {
                if len < 16 {
                    bail!("wav fmt chunk too small");
                }
                let mut audio_format = le_u16(off);
                // WAVE_FORMAT_EXTENSIBLE carries the real format code in the first bytes
                // of the SubFormat GUID, which starts 24 bytes into the fmt body.
                if audio_format == WAVE_FORMAT_EXTENSIBLE && len >= 40 {
                    audio_format = le_u16(off + 24);
                }
                let channels = le_u16(off + 2);
                let sample_rate = le_u32(off + 4);
                let bits_per_sample = le_u16(off + 14);
                fmt = Some((audio_format, channels, sample_rate, bits_per_sample));
            }
            b"data" => data_chunk = Some(&bytes[off..end]),
            _ => {}
        }
        // RIFF chunks are word-aligned: skip the pad byte after an odd-length body.
        off = end + (len & 1);
        if fmt.is_some() && data_chunk.is_some() {
            break;
        }
    }

    let (audio_format, channels, sr, bps) = fmt.ok_or_else(|| anyhow!("wav missing fmt chunk"))?;
    if audio_format != WAVE_FORMAT_PCM {
        bail!("wav: only PCM supported (format={audio_format})");
    }
    if channels != 1 {
        bail!("wav: expected mono, got {channels} channels");
    }
    if bps != 16 {
        bail!("wav: expected 16-bit PCM, got {bps}");
    }
    let data = data_chunk.ok_or_else(|| anyhow!("wav missing data chunk"))?;
    if data.len() % 2 != 0 {
        bail!("wav data chunk not aligned");
    }

    let samples: Vec<f32> = data
        .chunks_exact(2)
        .map(|pair| f32::from(i16::from_le_bytes([pair[0], pair[1]])) / 32768.0)
        .collect();
    Ok((samples, sr as usize))
}