speakrs 0.3.2

Fast Rust speaker diarization with pyannote-level accuracy and native CoreML/CUDA acceleration
mod support;

use std::fs;
use std::path::Path;

use speakrs::pipeline::{DiarizationPipeline, FRAME_DURATION_SECONDS, FRAME_STEP_SECONDS};
use speakrs::segment::Segment;

use support::{ExampleResult, load_models, load_wav_samples};

/// One row of the input transcript: a time span and the text spoken in it.
///
/// Times are in seconds, the same unit as the diarization segments the span is
/// matched against.
#[derive(Debug, Clone, PartialEq)]
struct TranscriptRow {
    /// Start of the utterance, in seconds.
    start: f64,
    /// End of the utterance, in seconds.
    end: f64,
    /// Transcript text; may itself contain tab characters.
    text: String,
}

/// Diarizes `<audio.wav>` and labels every row of `<transcript.tsv>` with the
/// speaker whose diarization segments overlap that row the most.
///
/// Usage: `cargo run --example assign_transcript_speakers -- <models-dir> <audio.wav> <transcript.tsv>`
///
/// Writes a TSV table (`start`, `end`, `speaker`, `text`) to stdout. Rows that
/// no segment overlaps are labelled `UNKNOWN`. Exits with status 1 and a usage
/// message if the argument count is wrong.
///
/// # Errors
/// Propagates failures from reading/parsing the transcript, loading models or
/// audio, running the diarization pipeline, and writing to stdout.
fn main() -> ExampleResult<()> {
    use std::io::Write;

    support::init_tracing();
    let args: Vec<String> = std::env::args().collect();
    if args.len() != 4 {
        eprintln!(
            "Usage: cargo run --example assign_transcript_speakers -- <models-dir> <audio.wav> <transcript.tsv>"
        );
        std::process::exit(1);
    }

    let models_dir = Path::new(&args[1]);
    let audio_path = Path::new(&args[2]);
    let transcript_path = Path::new(&args[3]);

    // Parse the transcript first. It is cheap, so a malformed file is reported
    // before any time is spent on model loading and diarization.
    let transcript = load_transcript(transcript_path)?;

    let mut models = load_models(models_dir)?;
    let audio = load_wav_samples(audio_path)?;
    let mut pipeline = DiarizationPipeline::new(&mut models.0, &mut models.1, models_dir)?;
    let result = pipeline.run(&audio)?;

    // Work on a copy so `result` is left untouched. `make_exclusive` presumably
    // resolves overlapping speech to a single speaker per frame (TODO confirm
    // against speakrs docs) before the frames are turned into timed segments.
    let mut exclusive = result.discrete_diarization.clone();
    exclusive.make_exclusive();
    let segments = exclusive.to_segments(FRAME_STEP_SECONDS, FRAME_DURATION_SECONDS);

    // Lock stdout once for the whole table. `writeln!` also surfaces write errors
    // (e.g. a closed pipe) as an `Err` instead of panicking like `println!`.
    let mut out = std::io::stdout().lock();
    writeln!(out, "start\tend\tspeaker\ttext")?;
    for row in &transcript {
        let speaker = dominant_speaker(&segments, row.start, row.end).unwrap_or("UNKNOWN");
        writeln!(
            out,
            "{:.3}\t{:.3}\t{}\t{}",
            row.start, row.end, speaker, row.text
        )?;
    }
    out.flush()?;

    Ok(())
}

/// Reads and parses a tab-separated transcript file.
///
/// See `parse_transcript` for the accepted format.
///
/// # Errors
/// Returns an error in two cases:
/// - the file cannot be read as UTF-8 text;
/// - any non-blank line is malformed. Such messages include the 1-based line number.
fn load_transcript(path: &Path) -> ExampleResult<Vec<TranscriptRow>> {
    let content = fs::read_to_string(path)?;
    parse_transcript(&content)
}

/// Parses transcript text where every non-blank line is `start<TAB>end<TAB>text`.
///
/// Format rules:
/// - Blank (whitespace-only) lines are skipped.
/// - A leading UTF-8 byte-order mark is ignored. Some editors write one, and it
///   would otherwise make the first start time unparseable.
/// - Each line is split on its first two tabs only, so the text column may
///   itself contain tabs. The text is kept verbatim.
/// - Whitespace around the two time fields is ignored.
/// - Times are in seconds, must be finite, and `end` must not precede `start`.
///
/// # Errors
/// Returns an error naming the 1-based line number for:
/// - a missing field;
/// - an unparseable or non-finite time;
/// - a reversed span.
fn parse_transcript(content: &str) -> ExampleResult<Vec<TranscriptRow>> {
    let content = content.strip_prefix('\u{feff}').unwrap_or(content);
    let mut rows = Vec::new();

    for (line_idx, line) in content.lines().enumerate() {
        if line.trim().is_empty() {
            continue;
        }
        let line_no = line_idx + 1;

        let mut fields = line.splitn(3, '\t');
        let start_raw = fields
            .next()
            .ok_or_else(|| format!("line {} is missing start time", line_no))?;
        let start = parse_seconds(start_raw, line_no, "start")?;
        let end_raw = fields
            .next()
            .ok_or_else(|| format!("line {} is missing end time", line_no))?;
        let end = parse_seconds(end_raw, line_no, "end")?;
        let text = fields
            .next()
            .ok_or_else(|| format!("line {} is missing transcript text", line_no))?
            .to_owned();

        if end < start {
            return Err(format!(
                "line {}: end time {} is before start time {}",
                line_no, end, start
            )
            .into());
        }

        rows.push(TranscriptRow { start, end, text });
    }

    Ok(rows)
}

/// Parses one time field (seconds) from line `line_no`; `what` names the field in errors.
///
/// `f64::from_str` accepts `inf` and `NaN`. Both are rejected here because such
/// values cannot be matched against diarization segments.
fn parse_seconds(raw: &str, line_no: usize, what: &str) -> ExampleResult<f64> {
    let value = raw.trim().parse::<f64>().map_err(|err| {
        format!(
            "line {}: invalid {} time {:?}: {}",
            line_no, what, raw, err
        )
    })?;
    if !value.is_finite() {
        return Err(format!(
            "line {}: {} time must be finite, got {:?}",
            line_no, what, raw
        )
        .into());
    }
    Ok(value)
}

/// Returns the speaker who talks the most during `[start, end]` (seconds).
///
/// Overlap is summed per speaker across all `segments`. A speaker whose turn is
/// interrupted, and so split into several segments, is still credited with the
/// whole time. Only positive overlaps count. Ties go to the speaker whose first
/// overlapping segment appears earliest in `segments`. Returns `None` when no
/// segment overlaps the span, including when `end <= start`.
fn dominant_speaker(segments: &[Segment], start: f64, end: f64) -> Option<&str> {
    // Totals in order of first appearance. A Vec (not a HashMap) keeps the
    // tie-break deterministic; the number of speakers is small.
    let mut totals: Vec<(&str, f64)> = Vec::new();

    for segment in segments {
        let overlap = overlap_seconds(segment.start, segment.end, start, end);
        if overlap <= 0.0 {
            continue;
        }
        let speaker = segment.speaker.as_str();
        match totals.iter_mut().find(|(name, _)| *name == speaker) {
            Some((_, total)) => *total += overlap,
            None => totals.push((speaker, overlap)),
        }
    }

    totals
        .into_iter()
        .fold(None, |best: Option<(&str, f64)>, (speaker, total)| match best {
            // Keep the earlier speaker on ties.
            Some((_, best_total)) if best_total >= total => best,
            _ => Some((speaker, total)),
        })
        .map(|(speaker, _)| speaker)
}

/// Length of the intersection of the intervals `[lhs_start, lhs_end]` and
/// `[rhs_start, rhs_end]`, clamped to `0.0` when they are disjoint or only touch.
fn overlap_seconds(lhs_start: f64, lhs_end: f64, rhs_start: f64, rhs_end: f64) -> f64 {
    let earliest_end = lhs_end.min(rhs_end);
    let latest_start = lhs_start.max(rhs_start);
    f64::max(earliest_end - latest_start, 0.0)
}