1use anyhow::Result;
2use sha2::{Digest, Sha256};
3
4use crate::dsp::{compute_audio_featurizer_features, compute_w2v_bert_features};
5use crate::util::{hex_prefix, py_bool, py_float};
6
/// Settings for the audio featurizer front end.
///
/// Every field is written into the text produced by
/// `FrontendConfig::config_repr`, so all of them feed the front-end hash.
///
/// NOTE(review): units are not stated in this file. `sample_rate` is
/// presumably in Hz and the FFT/window/hop sizes presumably in samples —
/// confirm against the `dsp` module.
#[derive(Debug, Clone, PartialEq)]
pub struct AudioFrontendConfig {
    /// Sample rate this front end is configured for.
    pub sample_rate: u32,
    /// FFT size.
    pub n_fft: usize,
    /// Analysis window length.
    pub win_length: usize,
    /// Hop length between frames.
    pub hop_length: usize,
    /// Number of mel bins; `FrontendConfig::feature_dim` reports this value
    /// for the audio front end.
    pub n_mels: usize,
    /// Pre-emphasis coefficient (the presets in this file use `0.97` and `0.0`).
    pub preemphasis: f32,
    /// Flag forwarded to the featurizer as `normalize_signal`.
    pub normalize_signal: bool,
    /// Flag forwarded to the featurizer as `normalize_feature`.
    pub normalize_feature: bool,
    /// Flag forwarded to the featurizer as `normalize_per_frame`.
    pub normalize_per_frame: bool,
}
19
/// Settings for the W2v-BERT front end.
///
/// Every field is written into the text produced by
/// `FrontendConfig::config_repr`, so all of them feed the front-end hash.
///
/// NOTE(review): the exact meaning of `feature_size`, `stride` and
/// `padding_value` is defined by the `dsp` module, which is not visible
/// here — confirm before relying on the short descriptions below.
#[derive(Debug, Clone, PartialEq)]
pub struct W2vBertFrontendConfig {
    /// Model identifier string (the builder in this file defaults it to
    /// `"facebook/w2v-bert-2.0"`).
    pub model_source: String,
    /// Sample rate this front end is configured for.
    pub sample_rate: u32,
    /// Per-frame feature size before stacking (presumably mel bins — confirm).
    pub feature_size: usize,
    /// Stride; the builder in this file never produces a value below 1.
    pub stride: usize,
    /// Feature dimension reported by `FrontendConfig::feature_dim` for this
    /// front end.
    pub feature_dim: usize,
    /// Padding value forwarded to the feature computation.
    pub padding_value: f32,
}
29
30#[derive(Debug, Clone)]
31pub(crate) enum FrontendConfig {
32 Audio(AudioFrontendConfig),
33 W2vBert(W2vBertFrontendConfig),
34}
35
/// Feature output returned by the extraction functions in this file.
///
/// NOTE(review): this file never constructs or indexes a `FeatureMatrix`.
/// `values` presumably holds `rows * cols` entries in row-major order —
/// confirm against the `dsp` module before relying on that layout.
#[derive(Debug, Clone, PartialEq)]
pub struct FeatureMatrix {
    /// Number of rows.
    pub rows: usize,
    /// Number of columns.
    pub cols: usize,
    /// Flat buffer holding the matrix entries.
    pub values: Vec<f32>,
}
42
43pub fn squeezeformer_frontend_config() -> AudioFrontendConfig {
44 AudioFrontendConfig {
45 sample_rate: 16_000,
46 n_fft: 400,
47 win_length: 400,
48 hop_length: 160,
49 n_mels: 80,
50 preemphasis: 0.97,
51 normalize_signal: true,
52 normalize_feature: true,
53 normalize_per_frame: false,
54 }
55}
56
57pub fn zipformer_frontend_config() -> AudioFrontendConfig {
58 AudioFrontendConfig {
59 sample_rate: 16_000,
60 n_fft: 400,
61 win_length: 400,
62 hop_length: 160,
63 n_mels: 80,
64 preemphasis: 0.0,
65 normalize_signal: false,
66 normalize_feature: false,
67 normalize_per_frame: false,
68 }
69}
70
71pub fn w2v_bert_frontend_config(
72 model_source: Option<String>,
73 sample_rate: Option<u32>,
74 feature_size: Option<usize>,
75 stride: Option<usize>,
76 feature_dim: Option<usize>,
77 padding_value: Option<f32>,
78) -> W2vBertFrontendConfig {
79 let feature_size = feature_size.unwrap_or(80);
80 let stride = stride.unwrap_or(2).max(1);
81 W2vBertFrontendConfig {
82 model_source: model_source.unwrap_or_else(|| "facebook/w2v-bert-2.0".to_string()),
83 sample_rate: sample_rate.unwrap_or(16_000),
84 feature_size,
85 stride,
86 feature_dim: feature_dim.unwrap_or(feature_size * stride),
87 padding_value: padding_value.unwrap_or(1.0),
88 }
89}
90
91pub fn extract_audio_features_from_samples(
92 waveform: &[f32],
93 sample_rate: u32,
94 config: &AudioFrontendConfig,
95) -> Result<FeatureMatrix> {
96 let mut waveform = waveform.to_vec();
97 compute_audio_featurizer_features(&mut waveform, sample_rate, config)
98}
99
100pub fn extract_w2v_bert_features_from_samples(
101 waveform: &[f32],
102 sample_rate: u32,
103 config: &W2vBertFrontendConfig,
104) -> Result<FeatureMatrix> {
105 let mut waveform = waveform.to_vec();
106 compute_w2v_bert_features(&mut waveform, sample_rate, config)
107}
108
109pub(crate) fn compute_features(
110 mut waveform: Vec<f32>,
111 sample_rate: u32,
112 frontend: &FrontendConfig,
113) -> Result<FeatureMatrix> {
114 match frontend {
115 FrontendConfig::Audio(config) => {
116 compute_audio_featurizer_features(&mut waveform, sample_rate, config)
117 }
118 FrontendConfig::W2vBert(config) => {
119 compute_w2v_bert_features(&mut waveform, sample_rate, config)
120 }
121 }
122}
123
124impl FrontendConfig {
125 pub(crate) fn feature_dim(&self) -> usize {
126 match self {
127 Self::Audio(config) => config.n_mels,
128 Self::W2vBert(config) => config.feature_dim,
129 }
130 }
131
132 pub(crate) fn sample_rate(&self) -> u32 {
133 match self {
134 Self::Audio(config) => config.sample_rate,
135 Self::W2vBert(config) => config.sample_rate,
136 }
137 }
138
139 pub(crate) fn config_repr(&self) -> String {
140 match self {
141 Self::Audio(config) => format!(
142 "{{'featurizer': {{'sample_rate': {}, 'n_fft': {}, 'win_length': {}, 'n_mels': {}, 'backend': 'torchaudio', 'preemphasis': {}, 'normalize_signal': {}, 'normalize_feature': {}, 'normalize_per_frame': {}, 'hop_length': {}}}}}",
143 config.sample_rate,
144 config.n_fft,
145 config.win_length,
146 config.n_mels,
147 py_float(config.preemphasis),
148 py_bool(config.normalize_signal),
149 py_bool(config.normalize_feature),
150 py_bool(config.normalize_per_frame),
151 config.hop_length,
152 ),
153 Self::W2vBert(config) => format!(
154 "{{'featurizer': {{'type': 'w2v_bert', 'model_source': '{}', 'sample_rate': {}, 'feature_size': {}, 'stride': {}, 'feature_dim': {}, 'padding_value': {}}}}}",
155 config.model_source.replace('\'', "\\'"),
156 config.sample_rate,
157 config.feature_size,
158 config.stride,
159 config.feature_dim,
160 py_float(config.padding_value),
161 ),
162 }
163 }
164
165 pub(crate) fn frontend_hash(&self) -> String {
166 let digest = Sha256::digest(self.config_repr().as_bytes());
167 hex_prefix(&digest, 12)
168 }
169}