speakrs 0.4.0

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
use std::error::Error;
use std::ffi::OsStr;
use std::fs;
use std::path::Path;

use speakrs::inference::{EmbeddingModel, SegmentationModel};
use speakrs::pipeline::DiarizationPipeline;

pub type ExampleResult<T> = Result<T, Box<dyn Error + Send + Sync>>;

#[allow(dead_code)]
pub fn init_tracing() {
    let _ = tracing_subscriber::fmt()
        .with_env_filter(
            tracing_subscriber::EnvFilter::try_from_default_env()
                .unwrap_or_else(|_| "speakrs=info".parse().unwrap()),
        )
        .with_writer(std::io::stderr)
        .try_init();
}

#[allow(dead_code)]
pub fn file_id_from_path(path: &Path) -> String {
    path.file_stem()
        .and_then(OsStr::to_str)
        .unwrap_or("file1")
        .to_owned()
}

pub fn load_models(path: &Path) -> ExampleResult<(SegmentationModel, EmbeddingModel)> {
    let segmentation = SegmentationModel::new(
        path.join("segmentation-3.0.onnx").to_str().ok_or_else(|| {
            format!(
                "segmentation model path is not valid UTF-8: {}",
                path.join("segmentation-3.0.onnx").display()
            )
        })?,
        DiarizationPipeline::default_segmentation_step(),
    )?;
    let embedding = EmbeddingModel::new(
        path.join("wespeaker-voxceleb-resnet34.onnx")
            .to_str()
            .ok_or_else(|| {
                format!(
                    "embedding model path is not valid UTF-8: {}",
                    path.join("wespeaker-voxceleb-resnet34.onnx").display()
                )
            })?,
    )?;
    Ok((segmentation, embedding))
}

pub fn load_wav_samples(path: &Path) -> ExampleResult<Vec<f32>> {
    let data = fs::read(path)?;

    let channels = u16::from_le_bytes(data[22..24].try_into()?);
    let sample_rate = u32::from_le_bytes(data[24..28].try_into()?);
    let bits_per_sample = u16::from_le_bytes(data[34..36].try_into()?);

    if channels != 1 {
        return Err(format!("expected mono WAV, got {channels} channels").into());
    }
    if sample_rate != 16_000 {
        return Err(format!("expected 16kHz WAV, got {sample_rate}Hz").into());
    }
    if bits_per_sample != 16 {
        return Err(format!("expected 16-bit PCM WAV, got {bits_per_sample}-bit").into());
    }

    let mut pos = 12usize;
    while pos + 8 < data.len() {
        let chunk_id = &data[pos..pos + 4];
        let chunk_size = u32::from_le_bytes(data[pos + 4..pos + 8].try_into()?) as usize;

        if chunk_id == b"data" {
            let samples = data[pos + 8..pos + 8 + chunk_size]
                .chunks_exact(2)
                .map(|bytes| i16::from_le_bytes([bytes[0], bytes[1]]) as f32 / 32768.0)
                .collect();
            return Ok(samples);
        }

        pos += 8 + chunk_size;
    }

    Err("no data chunk found in WAV".into())
}