use anyhow::{Context, Result};
use clap::Parser;
use polyvoice::der::compute_der;
use polyvoice::models::ModelRegistry;
use polyvoice::pipeline::Pipeline;
use polyvoice::rttm::{group_by_file, parse_rttm_file, to_speaker_turns};
use polyvoice::types::{ClusterConfig, DiarizationConfig, Profile, SampleRate};
use polyvoice::vad::VadConfig;
use polyvoice::wav::read_wav;
use polyvoice::{FbankOnnxExtractor, SileroVad};
use serde::Serialize;
use std::path::PathBuf;
use std::time::Instant;
// Command-line options. The `///` comments on the fields below are rendered
// by clap as the `--help` text for each option.
#[derive(Parser, Debug)]
#[command(name = "polyvoice-bench", about = "Run DER on a {audio,rttm} dataset")]
struct Args {
    /// Dataset root containing `audio/*.wav` and matching `rttm/<stem>.rttm`
    dataset: PathBuf,
    /// Model profile: `mobile` or `balanced`
    #[arg(long, default_value = "balanced")]
    profile: String,
    /// Write the JSON report to this file instead of stdout
    #[arg(long)]
    output: Option<PathBuf>,
    /// Scoring collar passed to the DER computation
    #[arg(long, default_value = "0.25")]
    collar: f64,
    /// Request overlap-excluded scoring (not currently forwarded to the DER computation)
    // A `bool` field is a clap `SetTrue` flag that already defaults to false,
    // so no explicit `default_value` is needed.
    #[arg(long)]
    skip_overlap: bool,
    /// Score at most this many WAV files, taken in sorted path order
    #[arg(long)]
    max_files: Option<usize>,
    /// Clustering threshold (sets `ClusterConfig::threshold`)
    #[arg(long, default_value = "0.45")]
    threshold: f32,
}
/// Aggregate benchmark result, serialized as pretty JSON to `--output` or stdout.
///
/// serde emits struct fields in declaration order, so the field order below is
/// the JSON key order and is effectively part of the `polyvoice-bench-v0.5` schema.
#[derive(Serialize)]
struct BenchReport {
    /// Schema identifier string.
    schema: &'static str,
    /// The `--profile` name as given on the command line.
    profile: String,
    /// Number of files actually scored (skipped files are not counted).
    files: usize,
    /// Mean per-file DER at the `--collar` value, in percent. The key name is
    /// fixed by the schema and does not change when a non-default `--collar` is used.
    /// NOTE(review): `--skip-overlap` is not passed to `compute_der`, so the
    /// "skip_overlap" in the name reflects only `compute_der`'s own behavior.
    der_collar_0_25_skip_overlap: f64,
    /// DER scored with no collar, in percent.
    /// NOTE(review): confirm `main` computes this rather than writing a placeholder.
    der_no_collar: f64,
    /// Mean per-file missed-speech rate, in percent.
    miss: f64,
    /// Mean per-file false-alarm rate, in percent.
    false_alarm: f64,
    /// Mean per-file speaker-confusion rate, in percent.
    confusion: f64,
    /// Total audio duration divided by total pipeline runtime (a ratio of totals,
    /// not a mean of per-file ratios); values above 1 are faster than real time.
    rt_factor_avg: f64,
    /// Crate version (`CARGO_PKG_VERSION`) of the build that produced the report.
    polyvoice_version: &'static str,
}
/// Map a `--profile` name to a model [`Profile`].
///
/// Matching is exact (case-sensitive).
///
/// # Errors
/// Returns an error naming the accepted values when `name` is neither
/// `"mobile"` nor `"balanced"`.
fn parse_profile(name: &str) -> Result<Profile> {
    match name {
        "mobile" => Ok(Profile::Mobile),
        "balanced" => Ok(Profile::Balanced),
        other => anyhow::bail!("invalid profile: {other} (expected one of: mobile, balanced)"),
    }
}
fn main() -> Result<()> {
let args = Args::parse();
let profile = parse_profile(&args.profile)?;
let registry = ModelRegistry::default().context("registry")?;
let models = registry
.ensure_for_profile(profile)
.context("ensure models")?;
let embedding_dim = profile.embedding_dim();
let extractor = FbankOnnxExtractor::new(&models.embedder_path, embedding_dim, 1)
.context("load embedder")?;
let mut vad = SileroVad::new(&models.segmenter_path, 512).context("load vad")?;
let config = DiarizationConfig {
cluster: ClusterConfig {
threshold: args.threshold,
..Default::default()
},
..DiarizationConfig::default()
};
let vad_config = VadConfig::default();
let pipeline = Pipeline::new(config, vad_config);
let audio_dir = args.dataset.join("audio");
let rttm_dir = args.dataset.join("rttm");
let mut wavs: Vec<PathBuf> = std::fs::read_dir(&audio_dir)
.with_context(|| format!("read_dir {}", audio_dir.display()))?
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().is_some_and(|x| x == "wav"))
.map(|e| e.path())
.collect();
wavs.sort();
if let Some(n) = args.max_files {
wavs.truncate(n);
}
let mut totals = Aggregate::default();
let mut total_audio_secs = 0.0_f64;
let mut total_runtime_secs = 0.0_f64;
for wav in &wavs {
let stem = wav.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let rttm = rttm_dir.join(format!("{stem}.rttm"));
if !rttm.is_file() {
eprintln!("[SKIP] {stem}: no rttm");
continue;
}
let (samples, sr_hz) = read_wav(wav)?;
let _sr = SampleRate::new(sr_hz)
.ok_or_else(|| anyhow::anyhow!("invalid sample rate: {sr_hz}"))?;
let audio_secs = samples.len() as f64 / sr_hz as f64;
let t0 = Instant::now();
let result = pipeline.run(&samples, &extractor, &mut vad)?;
let runtime_secs = t0.elapsed().as_secs_f64();
let ref_turns = {
let raw = parse_rttm_file(&rttm).context("parse rttm")?;
let grouped = group_by_file(&raw);
let segs: Vec<_> = grouped
.get(stem)
.map(|v| v.iter().map(|s| (*s).clone()).collect())
.unwrap_or_default();
let (turns, _map) = to_speaker_turns(&segs);
turns
};
let der = compute_der(&ref_turns, &result.turns, args.collar);
totals.der_total += der.der;
totals.miss += der.miss_rate;
totals.false_alarm += der.false_alarm_rate;
totals.confusion += der.confusion_rate;
totals.count += 1;
total_audio_secs += audio_secs;
total_runtime_secs += runtime_secs;
println!(
"{stem}\t DER={:.3}%\t miss={:.3}%\t fa={:.3}%\t conf={:.3}%\t rt={:.1}x\t spk={}\t turns={}",
der.der * 100.0,
der.miss_rate * 100.0,
der.false_alarm_rate * 100.0,
der.confusion_rate * 100.0,
audio_secs / runtime_secs.max(1e-6),
result.num_speakers,
result.turns.len(),
);
}
let n = totals.count.max(1) as f64;
let report = BenchReport {
schema: "polyvoice-bench-v0.5",
profile: args.profile.clone(),
files: totals.count,
der_collar_0_25_skip_overlap: (totals.der_total / n) * 100.0,
der_no_collar: 0.0, miss: (totals.miss / n) * 100.0,
false_alarm: (totals.false_alarm / n) * 100.0,
confusion: (totals.confusion / n) * 100.0,
rt_factor_avg: total_audio_secs / total_runtime_secs.max(1e-6),
polyvoice_version: env!("CARGO_PKG_VERSION"),
};
let json = serde_json::to_string_pretty(&report)?;
match args.output {
Some(p) => std::fs::write(&p, json)?,
None => println!("{json}"),
}
Ok(())
}
/// Running per-file sums that `main` divides by the number of scored files to
/// report unweighted (macro) averages.
///
/// The rate fields accumulate the values returned by `compute_der`. These are
/// presumably fractions, since `main` scales the averages by 100 for the report.
#[derive(Default)]
struct Aggregate {
    /// Number of files that contributed to the sums below.
    count: usize,
    /// Sum of per-file DER at the configured collar.
    der_total: f64,
    /// Sum of per-file missed-speech rates.
    miss: f64,
    /// Sum of per-file false-alarm rates.
    false_alarm: f64,
    /// Sum of per-file speaker-confusion rates.
    confusion: f64,
}