// asr-features 0.1.0
//
// ASR feature extraction and cache warming utilities.
use anyhow::Result;
use sha2::{Digest, Sha256};

use crate::dsp::{compute_audio_featurizer_features, compute_w2v_bert_features};
use crate::util::{hex_prefix, py_bool, py_float};

/// Settings for the audio featurizer front end.
///
/// Passed by reference to `compute_audio_featurizer_features` and serialised
/// field-by-field by `FrontendConfig::config_repr`.
#[derive(Debug, Clone)]
pub struct AudioFrontendConfig {
    /// Sample rate the front end is configured for (the presets in this file
    /// use 16_000; presumably Hz — confirm against `crate::dsp`).
    pub sample_rate: u32,
    /// FFT size (presumably in samples — confirm against `crate::dsp`).
    pub n_fft: usize,
    /// Analysis window length (presumably in samples — confirm).
    pub win_length: usize,
    /// Hop between successive frames (presumably in samples — confirm).
    pub hop_length: usize,
    /// Number of mel bins; `FrontendConfig::feature_dim` reports this value
    /// for the `Audio` variant.
    pub n_mels: usize,
    /// Pre-emphasis coefficient; the presets in this file use 0.97 and 0.0
    /// (NOTE(review): 0.0 presumably disables pre-emphasis — confirm).
    pub preemphasis: f32,
    /// Signal-level normalization flag (exact semantics live in `crate::dsp`).
    pub normalize_signal: bool,
    /// Feature-level normalization flag (exact semantics live in `crate::dsp`).
    pub normalize_feature: bool,
    /// Per-frame normalization flag (exact semantics live in `crate::dsp`).
    pub normalize_per_frame: bool,
}

/// Settings for the w2v-BERT feature front end.
///
/// Passed by reference to `compute_w2v_bert_features` and serialised
/// field-by-field by `FrontendConfig::config_repr`.
#[derive(Debug, Clone)]
pub struct W2vBertFrontendConfig {
    /// Model identifier; `w2v_bert_frontend_config` defaults it to
    /// "facebook/w2v-bert-2.0".
    pub model_source: String,
    /// Sample rate the front end is configured for (presumably Hz — confirm).
    pub sample_rate: u32,
    /// Per-frame feature size before stacking (assumption based on the
    /// `feature_size * stride` default for `feature_dim` — confirm).
    pub feature_size: usize,
    /// Stride; `w2v_bert_frontend_config` never produces a value below 1.
    pub stride: usize,
    /// Output feature dimension; `FrontendConfig::feature_dim` reports this
    /// value for the `W2vBert` variant.
    pub feature_dim: usize,
    /// Padding value handed to the feature computation
    /// (how it is applied is defined in `crate::dsp`).
    pub padding_value: f32,
}

/// Crate-internal selector between the two supported front-end configurations.
///
/// `compute_features` dispatches on the variant to pick the matching feature
/// computation.
#[derive(Debug, Clone)]
pub(crate) enum FrontendConfig {
    /// Audio featurizer front end.
    Audio(AudioFrontendConfig),
    /// w2v-BERT front end.
    W2vBert(W2vBertFrontendConfig),
}

/// Feature output returned by the extraction functions in this file.
///
/// NOTE(review): `values` presumably holds `rows * cols` entries in row-major
/// order — the layout is decided in `crate::dsp`; confirm there.
#[derive(Debug)]
pub struct FeatureMatrix {
    /// Number of rows (presumably frames — confirm).
    pub rows: usize,
    /// Number of columns (presumably the feature dimension — confirm).
    pub cols: usize,
    /// Flat storage for the matrix entries.
    pub values: Vec<f32>,
}

/// Returns the fixed front-end preset used for Squeezeformer (per the function name).
///
/// The preset uses a 16_000 sample rate, a 400-point FFT with a 400-long window,
/// a hop of 160, 80 mel bins, pre-emphasis 0.97, and enables signal and feature
/// normalization while leaving per-frame normalization off.
pub fn squeezeformer_frontend_config() -> AudioFrontendConfig {
    // The FFT size and the window length share one value in this preset.
    let frame_len: usize = 400;

    AudioFrontendConfig {
        normalize_per_frame: false,
        normalize_feature: true,
        normalize_signal: true,
        preemphasis: 0.97,
        n_mels: 80,
        hop_length: 160,
        win_length: frame_len,
        n_fft: frame_len,
        sample_rate: 16_000,
    }
}

/// Returns the fixed front-end preset used for Zipformer (per the function name).
///
/// The preset uses a 16_000 sample rate, a 400-point FFT with a 400-long window,
/// a hop of 160 and 80 mel bins. Pre-emphasis is 0.0 and all three normalization
/// flags are off.
pub fn zipformer_frontend_config() -> AudioFrontendConfig {
    // Every normalization switch carries the same value in this preset.
    let normalization = false;
    let frame_len: usize = 400;

    AudioFrontendConfig {
        sample_rate: 16_000,
        n_mels: 80,
        hop_length: 160,
        n_fft: frame_len,
        win_length: frame_len,
        preemphasis: 0.0,
        normalize_signal: normalization,
        normalize_feature: normalization,
        normalize_per_frame: normalization,
    }
}

/// Builds a [`W2vBertFrontendConfig`], filling every omitted argument with a default.
///
/// Defaults applied when an argument is `None`:
///
/// * `model_source` — `"facebook/w2v-bert-2.0"`
/// * `sample_rate` — `16_000`
/// * `feature_size` — `80`
/// * `stride` — `2`
/// * `feature_dim` — `feature_size * stride` (using the resolved values)
/// * `padding_value` — `1.0`
///
/// A `stride` of `0` is raised to `1`.
///
/// # Panics
///
/// In builds with overflow checks enabled, panics if `feature_dim` is `None`
/// and `feature_size * stride` overflows `usize`.
pub fn w2v_bert_frontend_config(
    model_source: Option<String>,
    sample_rate: Option<u32>,
    feature_size: Option<usize>,
    stride: Option<usize>,
    feature_dim: Option<usize>,
    padding_value: Option<f32>,
) -> W2vBertFrontendConfig {
    let feature_size = feature_size.unwrap_or(80);
    let stride = stride.unwrap_or(2).max(1);
    // Derive the default lazily: an explicit `feature_dim` must not be rejected
    // just because the unused fallback product would overflow.
    let feature_dim = feature_dim.unwrap_or_else(|| feature_size * stride);

    W2vBertFrontendConfig {
        model_source: model_source.unwrap_or_else(|| "facebook/w2v-bert-2.0".to_string()),
        sample_rate: sample_rate.unwrap_or(16_000),
        feature_size,
        stride,
        feature_dim,
        padding_value: padding_value.unwrap_or(1.0),
    }
}

/// Computes audio-featurizer features for a borrowed waveform.
///
/// The samples are copied into an owned buffer first, because
/// `compute_audio_featurizer_features` takes its input by mutable reference;
/// the caller's slice is never modified.
///
/// # Errors
///
/// Returns whatever error `compute_audio_featurizer_features` reports.
pub fn extract_audio_features_from_samples(
    waveform: &[f32],
    sample_rate: u32,
    config: &AudioFrontendConfig,
) -> Result<FeatureMatrix> {
    let mut scratch = Vec::from(waveform);
    compute_audio_featurizer_features(&mut scratch, sample_rate, config)
}

/// Computes w2v-BERT features for a borrowed waveform.
///
/// The samples are copied into an owned buffer first, because
/// `compute_w2v_bert_features` takes its input by mutable reference; the
/// caller's slice is never modified.
///
/// # Errors
///
/// Returns whatever error `compute_w2v_bert_features` reports.
pub fn extract_w2v_bert_features_from_samples(
    waveform: &[f32],
    sample_rate: u32,
    config: &W2vBertFrontendConfig,
) -> Result<FeatureMatrix> {
    let mut scratch = Vec::from(waveform);
    compute_w2v_bert_features(&mut scratch, sample_rate, config)
}

/// Runs the feature computation selected by `frontend` on an owned waveform.
///
/// The waveform is taken by value so the chosen routine can be handed a
/// mutable reference to it without an extra copy.
///
/// # Errors
///
/// Returns whatever error the selected feature computation reports.
pub(crate) fn compute_features(
    waveform: Vec<f32>,
    sample_rate: u32,
    frontend: &FrontendConfig,
) -> Result<FeatureMatrix> {
    let mut samples = waveform;
    match frontend {
        FrontendConfig::W2vBert(w2v) => compute_w2v_bert_features(&mut samples, sample_rate, w2v),
        FrontendConfig::Audio(audio) => {
            compute_audio_featurizer_features(&mut samples, sample_rate, audio)
        }
    }
}

impl FrontendConfig {
    /// Feature dimension of the selected front end: `n_mels` for `Audio`,
    /// `feature_dim` for `W2vBert`.
    pub(crate) fn feature_dim(&self) -> usize {
        match self {
            Self::W2vBert(w2v) => w2v.feature_dim,
            Self::Audio(audio) => audio.n_mels,
        }
    }

    /// Sample rate stored in the wrapped configuration.
    pub(crate) fn sample_rate(&self) -> u32 {
        match self {
            Self::W2vBert(w2v) => w2v.sample_rate,
            Self::Audio(audio) => audio.sample_rate,
        }
    }

    /// Renders the configuration as a single-quoted, dict-style string.
    ///
    /// This text is the exact input hashed by [`Self::frontend_hash`], so the
    /// key order, spacing and value formatting must stay stable.
    pub(crate) fn config_repr(&self) -> String {
        match self {
            Self::W2vBert(w2v) => {
                // NOTE(review): only single quotes are escaped; backslashes in
                // `model_source` pass through unchanged.
                let quoted_source = w2v.model_source.replace('\'', "\\'");
                let padding = py_float(w2v.padding_value);
                format!(
                    "{{'featurizer': {{'type': 'w2v_bert', 'model_source': '{}', 'sample_rate': {}, 'feature_size': {}, 'stride': {}, 'feature_dim': {}, 'padding_value': {}}}}}",
                    quoted_source,
                    w2v.sample_rate,
                    w2v.feature_size,
                    w2v.stride,
                    w2v.feature_dim,
                    padding,
                )
            }
            Self::Audio(audio) => {
                let preemphasis = py_float(audio.preemphasis);
                let signal_norm = py_bool(audio.normalize_signal);
                let feature_norm = py_bool(audio.normalize_feature);
                let frame_norm = py_bool(audio.normalize_per_frame);
                // `hop_length` is deliberately emitted last, after the flags.
                format!(
                    "{{'featurizer': {{'sample_rate': {}, 'n_fft': {}, 'win_length': {}, 'n_mels': {}, 'backend': 'torchaudio', 'preemphasis': {}, 'normalize_signal': {}, 'normalize_feature': {}, 'normalize_per_frame': {}, 'hop_length': {}}}}}",
                    audio.sample_rate,
                    audio.n_fft,
                    audio.win_length,
                    audio.n_mels,
                    preemphasis,
                    signal_norm,
                    feature_norm,
                    frame_norm,
                    audio.hop_length,
                )
            }
        }
    }

    /// Short identifier for this configuration: the SHA-256 digest of
    /// [`Self::config_repr`], shortened by `hex_prefix(&digest, 12)`.
    pub(crate) fn frontend_hash(&self) -> String {
        let repr = self.config_repr();
        let digest = Sha256::digest(repr.as_bytes());
        hex_prefix(&digest, 12)
    }
}