//! `asr_features/frontend.rs`: audio / w2v-BERT frontend configuration presets and
//! feature-extraction entry points.

1use anyhow::Result;
2use sha2::{Digest, Sha256};
3
4use crate::dsp::{compute_audio_featurizer_features, compute_w2v_bert_features};
5use crate::util::{hex_prefix, py_bool, py_float};
6
/// Parameters for the audio featurizer frontend.
///
/// Consumed by `compute_audio_featurizer_features` (in `crate::dsp`); the presets
/// `squeezeformer_frontend_config` and `zipformer_frontend_config` build instances of it.
/// `PartialEq` is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioFrontendConfig {
    /// Sample rate this frontend is configured for.
    pub sample_rate: u32,
    /// FFT size (presumably in samples; interpreted by `crate::dsp`).
    pub n_fft: usize,
    /// Analysis window length (presumably in samples; interpreted by `crate::dsp`).
    pub win_length: usize,
    /// Hop between successive frames (presumably in samples; interpreted by `crate::dsp`).
    pub hop_length: usize,
    /// Number of mel bins; also reported as the frontend's feature dimension.
    pub n_mels: usize,
    /// Pre-emphasis coefficient passed to `crate::dsp`.
    pub preemphasis: f32,
    /// Signal-normalization flag (semantics defined in `crate::dsp`).
    pub normalize_signal: bool,
    /// Feature-normalization flag (semantics defined in `crate::dsp`).
    pub normalize_feature: bool,
    /// Per-frame normalization flag (semantics defined in `crate::dsp`).
    pub normalize_per_frame: bool,
}
19
/// Parameters for the w2v-BERT feature frontend.
///
/// Consumed by `compute_w2v_bert_features` (in `crate::dsp`). Use
/// `w2v_bert_frontend_config` to build one with defaults filled in.
/// `PartialEq` is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq)]
pub struct W2vBertFrontendConfig {
    /// Model identifier recorded in the config (builder default: `"facebook/w2v-bert-2.0"`).
    pub model_source: String,
    /// Sample rate this frontend is configured for (builder default: `16_000`).
    pub sample_rate: u32,
    /// Base feature size (builder default: `80`).
    pub feature_size: usize,
    /// Stride factor (builder default: `2`; the builder clamps it to at least `1`).
    pub stride: usize,
    /// Reported feature dimension (builder default: `feature_size * stride`).
    pub feature_dim: usize,
    /// Padding value passed to `crate::dsp` (builder default: `1.0`).
    pub padding_value: f32,
}
29
30#[derive(Debug, Clone)]
31pub(crate) enum FrontendConfig {
32    Audio(AudioFrontendConfig),
33    W2vBert(W2vBertFrontendConfig),
34}
35
/// Dense feature matrix produced by the extraction functions in this module.
///
/// NOTE(review): `values` presumably holds `rows * cols` entries in row-major order
/// (one row per frame); confirm against `crate::dsp`, which builds these values.
/// `Clone` and `PartialEq` are derived so callers can copy and compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct FeatureMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Flat storage for the matrix entries.
    pub values: Vec<f32>,
}
42
43pub fn squeezeformer_frontend_config() -> AudioFrontendConfig {
44    AudioFrontendConfig {
45        sample_rate: 16_000,
46        n_fft: 400,
47        win_length: 400,
48        hop_length: 160,
49        n_mels: 80,
50        preemphasis: 0.97,
51        normalize_signal: true,
52        normalize_feature: true,
53        normalize_per_frame: false,
54    }
55}
56
57pub fn zipformer_frontend_config() -> AudioFrontendConfig {
58    AudioFrontendConfig {
59        sample_rate: 16_000,
60        n_fft: 400,
61        win_length: 400,
62        hop_length: 160,
63        n_mels: 80,
64        preemphasis: 0.0,
65        normalize_signal: false,
66        normalize_feature: false,
67        normalize_per_frame: false,
68    }
69}
70
71pub fn w2v_bert_frontend_config(
72    model_source: Option<String>,
73    sample_rate: Option<u32>,
74    feature_size: Option<usize>,
75    stride: Option<usize>,
76    feature_dim: Option<usize>,
77    padding_value: Option<f32>,
78) -> W2vBertFrontendConfig {
79    let feature_size = feature_size.unwrap_or(80);
80    let stride = stride.unwrap_or(2).max(1);
81    W2vBertFrontendConfig {
82        model_source: model_source.unwrap_or_else(|| "facebook/w2v-bert-2.0".to_string()),
83        sample_rate: sample_rate.unwrap_or(16_000),
84        feature_size,
85        stride,
86        feature_dim: feature_dim.unwrap_or(feature_size * stride),
87        padding_value: padding_value.unwrap_or(1.0),
88    }
89}
90
91pub fn extract_audio_features_from_samples(
92    waveform: &[f32],
93    sample_rate: u32,
94    config: &AudioFrontendConfig,
95) -> Result<FeatureMatrix> {
96    let mut waveform = waveform.to_vec();
97    compute_audio_featurizer_features(&mut waveform, sample_rate, config)
98}
99
100pub fn extract_w2v_bert_features_from_samples(
101    waveform: &[f32],
102    sample_rate: u32,
103    config: &W2vBertFrontendConfig,
104) -> Result<FeatureMatrix> {
105    let mut waveform = waveform.to_vec();
106    compute_w2v_bert_features(&mut waveform, sample_rate, config)
107}
108
109pub(crate) fn compute_features(
110    mut waveform: Vec<f32>,
111    sample_rate: u32,
112    frontend: &FrontendConfig,
113) -> Result<FeatureMatrix> {
114    match frontend {
115        FrontendConfig::Audio(config) => {
116            compute_audio_featurizer_features(&mut waveform, sample_rate, config)
117        }
118        FrontendConfig::W2vBert(config) => {
119            compute_w2v_bert_features(&mut waveform, sample_rate, config)
120        }
121    }
122}
123
124impl FrontendConfig {
125    pub(crate) fn feature_dim(&self) -> usize {
126        match self {
127            Self::Audio(config) => config.n_mels,
128            Self::W2vBert(config) => config.feature_dim,
129        }
130    }
131
132    pub(crate) fn sample_rate(&self) -> u32 {
133        match self {
134            Self::Audio(config) => config.sample_rate,
135            Self::W2vBert(config) => config.sample_rate,
136        }
137    }
138
139    pub(crate) fn config_repr(&self) -> String {
140        match self {
141            Self::Audio(config) => format!(
142                "{{'featurizer': {{'sample_rate': {}, 'n_fft': {}, 'win_length': {}, 'n_mels': {}, 'backend': 'torchaudio', 'preemphasis': {}, 'normalize_signal': {}, 'normalize_feature': {}, 'normalize_per_frame': {}, 'hop_length': {}}}}}",
143                config.sample_rate,
144                config.n_fft,
145                config.win_length,
146                config.n_mels,
147                py_float(config.preemphasis),
148                py_bool(config.normalize_signal),
149                py_bool(config.normalize_feature),
150                py_bool(config.normalize_per_frame),
151                config.hop_length,
152            ),
153            Self::W2vBert(config) => format!(
154                "{{'featurizer': {{'type': 'w2v_bert', 'model_source': '{}', 'sample_rate': {}, 'feature_size': {}, 'stride': {}, 'feature_dim': {}, 'padding_value': {}}}}}",
155                config.model_source.replace('\'', "\\'"),
156                config.sample_rate,
157                config.feature_size,
158                config.stride,
159                config.feature_dim,
160                py_float(config.padding_value),
161            ),
162        }
163    }
164
165    pub(crate) fn frontend_hash(&self) -> String {
166        let digest = Sha256::digest(self.config_repr().as_bytes());
167        hex_prefix(&digest, 12)
168    }
169}