//! asr-features 0.1.1
//!
//! ASR feature extraction and cache warming utilities.
use numpy::ndarray::Array2;
use numpy::{IntoPyArray, PyArray2, PyReadonlyArrayDyn};
use pyo3::exceptions::PyValueError;
use pyo3::prelude::*;
use pyo3::wrap_pyfunction;

use crate::feature_loader::RustParquetFeatureCacheReader;
use crate::{
    extract_audio_features_from_samples, extract_w2v_bert_features_from_samples,
    squeezeformer_frontend_config, w2v_bert_frontend_config, zipformer_frontend_config,
    FeatureMatrix,
};

fn py_error(error: anyhow::Error) -> PyErr {
    PyValueError::new_err(error.to_string())
}

/// Converts a float32 NumPy waveform into a mono sample buffer.
///
/// A 1-D array (`[time]`) is copied as-is. A 2-D array (`[channels, time]`) is
/// down-mixed by averaging all channels at each time step. A 2-D array with zero
/// channels or zero frames yields an empty buffer.
///
/// # Errors
///
/// Returns a `ValueError` when the array has any rank other than 1 or 2.
fn mono_waveform(waveform: PyReadonlyArrayDyn<'_, f32>) -> PyResult<Vec<f32>> {
    let array = waveform.as_array();
    match array.ndim() {
        1 => Ok(array.iter().copied().collect()),
        2 => {
            let shape = array.shape();
            let (channels, frames) = (shape[0], shape[1]);
            if channels == 0 || frames == 0 {
                return Ok(Vec::new());
            }
            // Accumulate one channel row at a time. For a C-contiguous array this
            // reads memory sequentially instead of jumping across rows for every
            // frame. Each frame still sums channels 0..channels in order starting
            // from 0.0, so the result matches a per-frame loop exactly.
            let mut mono = vec![0.0f32; frames];
            for channel in array.outer_iter() {
                for (acc, &sample) in mono.iter_mut().zip(channel.iter()) {
                    *acc += sample;
                }
            }
            let scale = channels as f32;
            mono.iter_mut().for_each(|acc| *acc /= scale);
            Ok(mono)
        }
        _ => Err(PyValueError::new_err(format!(
            "expected waveform with shape [time] or [channels, time], got {:?}",
            array.shape()
        ))),
    }
}

fn feature_matrix_to_pyarray<'py>(
    py: Python<'py>,
    features: FeatureMatrix,
) -> PyResult<Bound<'py, PyArray2<f32>>> {
    let array = Array2::from_shape_vec((features.rows, features.cols), features.values)
        .map_err(|error| PyValueError::new_err(error.to_string()))?;
    Ok(array.into_pyarray(py))
}

/// Extracts audio features for a waveform using the Squeezeformer frontend.
///
/// The waveform may be shaped [time] or [channels, time]; multi-channel input is
/// averaged to mono. `sample_rate` describes the input waveform. Every other
/// keyword overrides the matching field of the default Squeezeformer frontend
/// config, and `target_sample_rate` becomes the frontend's sample rate.
/// `win_length` falls back to `n_fft` when omitted.
///
/// Raises ValueError for unsupported waveform ranks or any extraction failure.
#[pyfunction]
// Every frontend knob is exposed to Python as its own keyword argument.
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
    waveform,
    sample_rate,
    *,
    target_sample_rate=16_000,
    n_fft=400,
    win_length=None,
    hop_length=160,
    n_mels=80,
    preemphasis=0.97,
    normalize_signal=true,
    normalize_feature=true,
    normalize_per_frame=false
))]
fn extract_squeezeformer<'py>(
    py: Python<'py>,
    waveform: PyReadonlyArrayDyn<'py, f32>,
    sample_rate: u32,
    target_sample_rate: u32,
    n_fft: usize,
    win_length: Option<usize>,
    hop_length: usize,
    n_mels: usize,
    preemphasis: f32,
    normalize_signal: bool,
    normalize_feature: bool,
    normalize_per_frame: bool,
) -> PyResult<Bound<'py, PyArray2<f32>>> {
    let samples = mono_waveform(waveform)?;
    let config = {
        let mut frontend = squeezeformer_frontend_config();
        frontend.sample_rate = target_sample_rate;
        frontend.n_fft = n_fft;
        frontend.win_length = win_length.unwrap_or(n_fft);
        frontend.hop_length = hop_length;
        frontend.n_mels = n_mels;
        frontend.preemphasis = preemphasis;
        frontend.normalize_signal = normalize_signal;
        frontend.normalize_feature = normalize_feature;
        frontend.normalize_per_frame = normalize_per_frame;
        frontend
    };
    extract_audio_features_from_samples(&samples, sample_rate, &config)
        .map_err(py_error)
        .and_then(|matrix| feature_matrix_to_pyarray(py, matrix))
}

/// Extracts audio features for a waveform using the Zipformer frontend.
///
/// The waveform may be shaped [time] or [channels, time]; multi-channel input is
/// averaged to mono. `sample_rate` describes the input waveform. The keyword
/// arguments override the default Zipformer frontend config; by default they
/// disable pre-emphasis and all normalization. `win_length` falls back to
/// `n_fft` when omitted.
///
/// Raises ValueError for unsupported waveform ranks or any extraction failure.
#[pyfunction]
// Every frontend knob is exposed to Python as its own keyword argument.
#[allow(clippy::too_many_arguments)]
#[pyo3(signature = (
    waveform,
    sample_rate,
    *,
    target_sample_rate=16_000,
    n_fft=400,
    win_length=None,
    hop_length=160,
    n_mels=80,
    preemphasis=0.0,
    normalize_signal=false,
    normalize_feature=false,
    normalize_per_frame=false
))]
fn extract_zipformer<'py>(
    py: Python<'py>,
    waveform: PyReadonlyArrayDyn<'py, f32>,
    sample_rate: u32,
    target_sample_rate: u32,
    n_fft: usize,
    win_length: Option<usize>,
    hop_length: usize,
    n_mels: usize,
    preemphasis: f32,
    normalize_signal: bool,
    normalize_feature: bool,
    normalize_per_frame: bool,
) -> PyResult<Bound<'py, PyArray2<f32>>> {
    let mono = mono_waveform(waveform)?;
    let window = win_length.unwrap_or(n_fft);

    let mut frontend = zipformer_frontend_config();
    // Frontend sample rate; distinct from the waveform's own `sample_rate`.
    frontend.sample_rate = target_sample_rate;
    // Framing / filterbank geometry.
    frontend.n_fft = n_fft;
    frontend.win_length = window;
    frontend.hop_length = hop_length;
    frontend.n_mels = n_mels;
    // Signal conditioning and normalization switches.
    frontend.preemphasis = preemphasis;
    frontend.normalize_signal = normalize_signal;
    frontend.normalize_feature = normalize_feature;
    frontend.normalize_per_frame = normalize_per_frame;

    let matrix = extract_audio_features_from_samples(&mono, sample_rate, &frontend)
        .map_err(py_error)?;
    feature_matrix_to_pyarray(py, matrix)
}

#[pyfunction]
#[pyo3(signature = (
    waveform,
    sample_rate,
    *,
    target_sample_rate=16_000,
    feature_size=80,
    stride=2,
    padding_value=1.0
))]
fn extract_w2v_bert<'py>(
    py: Python<'py>,
    waveform: PyReadonlyArrayDyn<'py, f32>,
    sample_rate: u32,
    target_sample_rate: u32,
    feature_size: usize,
    stride: usize,
    padding_value: f32,
) -> PyResult<Bound<'py, PyArray2<f32>>> {
    let waveform = mono_waveform(waveform)?;
    let config = w2v_bert_frontend_config(
        None,
        Some(target_sample_rate),
        Some(feature_size),
        Some(stride),
        Some(feature_size * stride.max(1)),
        Some(padding_value),
    );
    let features = extract_w2v_bert_features_from_samples(&waveform, sample_rate, &config)
        .map_err(py_error)?;
    feature_matrix_to_pyarray(py, features)
}

/// Python extension module `asr_features`.
///
/// Registers the Squeezeformer, Zipformer and w2v-BERT feature extractors and
/// the `RustParquetFeatureCacheReader` class.
#[pymodule]
fn asr_features(module: &Bound<'_, PyModule>) -> PyResult<()> {
    let squeezeformer = wrap_pyfunction!(extract_squeezeformer, module)?;
    module.add_function(squeezeformer)?;

    let zipformer = wrap_pyfunction!(extract_zipformer, module)?;
    module.add_function(zipformer)?;

    let w2v_bert = wrap_pyfunction!(extract_w2v_bert, module)?;
    module.add_function(w2v_bert)?;

    module.add_class::<RustParquetFeatureCacheReader>()?;
    Ok(())
}