use anyhow::{Context, Result};
use clap::Parser;
use polyvoice::der::compute_der;
use polyvoice::models::ModelRegistry;
use polyvoice::pipeline::Pipeline;
use polyvoice::rttm::{group_by_file, parse_rttm_file, to_speaker_turns};
use polyvoice::types::{ClusterConfig, DiarizationConfig, Profile, SampleRate};
use polyvoice::vad::VadConfig;
use polyvoice::wav::read_wav;
use polyvoice::{FbankOnnxExtractor, SileroVad};
use serde::Serialize;
use std::collections::HashSet;
use std::path::PathBuf;
use std::time::Instant;
// Command-line arguments for the `polyvoice-bench` binary.
//
// Plain `//` comments are used on purpose: clap's derive turns `///` doc
// comments into help text, which would change the program's `--help` output.
#[derive(Parser, Debug)]
#[command(name = "polyvoice-bench", about = "Run DER on a {audio,rttm} dataset")]
struct Args {
// Dataset root (positional). `main` reads `<dataset>/audio/*.wav` and looks
// for the matching reference at `<dataset>/rttm/<stem>.rttm`.
dataset: PathBuf,
// Profile name; `parse_profile` accepts exactly "mobile" or "balanced".
#[arg(long, default_value = "balanced")]
profile: String,
// Path the JSON report is written to; when absent the report is printed to stdout.
#[arg(long)]
output: Option<PathBuf>,
// Collar passed to `compute_der` for the "collar" DER figures.
#[arg(long, default_value = "0.25")]
collar: f64,
// NOTE(review): this flag is parsed but `main` in this file never reads it,
// so it currently has no effect on scoring — confirm whether that is intended.
#[arg(long, default_value = "false")]
skip_overlap: bool,
// When set, only the first N wav files (after sorting by path) are processed.
#[arg(long)]
max_files: Option<usize>,
// Forwarded to `ClusterConfig::threshold` when building the diarization config.
#[arg(long, default_value = "0.45")]
threshold: f32,
}
/// One model used by the benchmarked profile, recorded in the report so a
/// result can be tied to the exact model file.
#[derive(Serialize)]
struct ModelHash {
/// Manifest identifier of the model (the profile's segmenter or embedder id).
model_id: String,
/// The `sha256` value copied from the model's manifest entry.
sha256: String,
}
/// Scores and timing for a single processed audio file.
///
/// The error fields hold the corresponding `compute_der` outputs multiplied
/// by 100 (`main` prints the same values with a `%` suffix).
#[derive(Serialize)]
struct PerFileResult {
/// File stem of the wav file (no directory, no extension).
filename: String,
/// DER computed with the `--collar` value, times 100.
der_collar: f64,
/// DER computed with a collar of 0.0, times 100.
der_no_collar: f64,
/// Miss rate from the collared computation, times 100.
miss_rate: f64,
/// False-alarm rate from the collared computation, times 100.
false_alarm_rate: f64,
/// Confusion rate from the collared computation, times 100.
confusion_rate: f64,
/// Audio seconds divided by pipeline wall-clock seconds (denominator floored at 1e-6).
rt_factor: f64,
/// Number of distinct speaker ids in the reference turns.
ref_speakers: usize,
/// Number of distinct speaker ids in the pipeline's output turns.
hyp_speakers: usize,
/// Number of turns the pipeline produced.
num_turns: usize,
/// Sample count divided by the sample rate reported by `read_wav`.
audio_duration_secs: f64,
/// Wall-clock time spent inside `pipeline.run` for this file.
runtime_secs: f64,
}
/// Histogram of how far the hypothesised speaker count was from the reference
/// speaker count, counted in files. Each processed file lands in exactly one bucket.
#[derive(Serialize)]
struct SpeakerCountDiagnostics {
/// Files where the two speaker counts were equal.
exact: usize,
/// Files where the counts differed by exactly one.
plus_minus_1: usize,
/// Files where the counts differed by two or more.
off_by_2_or_more: usize,
}
/// Top-level JSON document produced by a benchmark run.
///
/// The aggregate error fields are unweighted per-file means (sum of per-file
/// values divided by the number of processed files), multiplied by 100.
#[derive(Serialize)]
struct BenchReport {
/// Report format tag; `main` sets it to "polyvoice-bench-v0.6".
schema: &'static str,
/// `CARGO_PKG_VERSION` captured at compile time.
crate_version: &'static str,
/// Output of `git rev-parse HEAD`, or "unknown" when it cannot be obtained.
git_sha: String,
/// `std::env::consts::ARCH` of the running binary.
host_arch: String,
/// `std::env::consts::OS` of the running binary.
host_os: String,
/// Process arguments joined with single spaces (no quoting is applied).
command_line: String,
/// Final path component of the dataset directory, or "unknown".
dataset_name: String,
/// Profile name exactly as given on the command line.
profile: String,
/// Number of files that were scored.
files_processed: usize,
/// Number of wav files skipped because no matching rttm file existed.
files_skipped: usize,
/// Mean DER with the `--collar` value, times 100.
der_collar: f64,
/// Mean DER with a collar of 0.0, times 100.
der_no_collar: f64,
/// Mean miss rate (collared), times 100.
miss: f64,
/// Mean false-alarm rate (collared), times 100.
false_alarm: f64,
/// Mean confusion rate (collared), times 100.
confusion: f64,
/// Total audio seconds divided by total pipeline seconds (denominator floored at 1e-6).
rt_factor_avg: f64,
/// Speaker-count agreement buckets across processed files.
speaker_count: SpeakerCountDiagnostics,
/// Hashes of the models referenced by the selected profile.
model_hashes: Vec<ModelHash>,
/// One entry per processed file, in processing (sorted-path) order.
per_file: Vec<PerFileResult>,
}
/// Maps a profile name from the command line onto a [`Profile`].
///
/// Matching is exact and case-sensitive: only `"mobile"` and `"balanced"` are
/// recognised.
///
/// # Errors
///
/// Returns an error reading `invalid profile: <name>` for any other input.
fn parse_profile(name: &str) -> Result<Profile> {
    let selected = match name {
        "balanced" => Profile::Balanced,
        "mobile" => Profile::Mobile,
        other => anyhow::bail!("invalid profile: {other}"),
    };
    Ok(selected)
}
/// Returns the output of `git rev-parse HEAD` with surrounding whitespace
/// trimmed.
///
/// Falls back to the literal `"unknown"` when `git` cannot be spawned, exits
/// unsuccessfully, or prints bytes that are not valid UTF-8. Never fails.
fn git_sha() -> String {
    const FALLBACK: &str = "unknown";

    let spawned = std::process::Command::new("git")
        .arg("rev-parse")
        .arg("HEAD")
        .output();

    // Treat "could not run git" and "git reported failure" the same way.
    let stdout = match spawned {
        Ok(finished) if finished.status.success() => finished.stdout,
        _ => return FALLBACK.to_owned(),
    };

    match String::from_utf8(stdout) {
        Ok(text) => text.trim().to_owned(),
        Err(_) => FALLBACK.to_owned(),
    }
}
/// Collects the manifest hashes of the models a profile refers to.
///
/// Looks the profile up in the registry's manifest and, for its segmenter and
/// embedder ids (in that order), copies the id and `sha256` of each model that
/// has a manifest entry. Ids without an entry are silently left out, and an
/// unknown profile yields an empty list.
fn model_hashes(registry: &ModelRegistry, profile: Profile) -> Vec<ModelHash> {
    let manifest = registry.manifest();
    let Some(prof) = manifest.profile(profile.manifest_id()) else {
        return Vec::new();
    };

    let mut hashes = Vec::new();
    for id in [&prof.segmenter, &prof.embedder] {
        let Some(entry) = manifest.model(id) else {
            continue;
        };
        hashes.push(ModelHash {
            model_id: id.clone(),
            sha256: entry.sha256.clone(),
        });
    }
    hashes
}
fn main() -> Result<()> {
let args = Args::parse();
let profile = parse_profile(&args.profile)?;
let registry = ModelRegistry::default().context("registry")?;
let models = registry
.ensure_for_profile(profile)
.context("ensure models")?;
let embedding_dim = profile.embedding_dim();
let extractor = FbankOnnxExtractor::new(&models.embedder_path, embedding_dim, 1)
.context("load embedder")?;
let mut vad = SileroVad::new(&models.segmenter_path, 512).context("load vad")?;
let config = DiarizationConfig {
cluster: ClusterConfig {
threshold: args.threshold,
..Default::default()
},
..DiarizationConfig::default()
};
let vad_config = VadConfig::default();
let pipeline = Pipeline::new(config, vad_config);
let audio_dir = args.dataset.join("audio");
let rttm_dir = args.dataset.join("rttm");
let mut wavs: Vec<PathBuf> = std::fs::read_dir(&audio_dir)
.with_context(|| format!("read_dir {}", audio_dir.display()))?
.filter_map(|e| e.ok())
.filter(|e| e.path().extension().is_some_and(|x| x == "wav"))
.map(|e| e.path())
.collect();
wavs.sort();
if let Some(n) = args.max_files {
wavs.truncate(n);
}
let dataset_name = args
.dataset
.file_name()
.and_then(|s| s.to_str())
.unwrap_or("unknown")
.to_owned();
let mut totals = Aggregate::default();
let mut total_audio_secs = 0.0_f64;
let mut total_runtime_secs = 0.0_f64;
let mut speaker_count_exact = 0_usize;
let mut speaker_count_pm1 = 0_usize;
let mut speaker_count_off = 0_usize;
let mut files_skipped = 0_usize;
let mut per_file = Vec::with_capacity(wavs.len());
for wav in &wavs {
let stem = wav.file_stem().and_then(|s| s.to_str()).unwrap_or("");
let rttm = rttm_dir.join(format!("{stem}.rttm"));
if !rttm.is_file() {
eprintln!("[SKIP] {stem}: no rttm");
files_skipped += 1;
continue;
}
let (samples, sr_hz) = read_wav(wav)?;
let _sr = SampleRate::new(sr_hz)
.ok_or_else(|| anyhow::anyhow!("invalid sample rate: {sr_hz}"))?;
let audio_secs = samples.len() as f64 / sr_hz as f64;
let t0 = Instant::now();
let result = pipeline.run(&samples, &extractor, &mut vad)?;
let runtime_secs = t0.elapsed().as_secs_f64();
let ref_turns = {
let raw = parse_rttm_file(&rttm).context("parse rttm")?;
let grouped = group_by_file(&raw);
let segs: Vec<_> = grouped
.get(stem)
.or_else(|| stem.split('.').next().and_then(|s| grouped.get(s)))
.map(|v| v.iter().map(|s| (*s).clone()).collect())
.unwrap_or_default();
let (turns, _map) = to_speaker_turns(&segs);
turns
};
let der = compute_der(&ref_turns, &result.turns, args.collar);
let der_no_collar = compute_der(&ref_turns, &result.turns, 0.0);
let ref_speakers: HashSet<_> = ref_turns.iter().map(|t| t.speaker.0).collect();
let hyp_speakers: HashSet<_> = result.turns.iter().map(|t| t.speaker.0).collect();
let ref_count = ref_speakers.len();
let hyp_count = hyp_speakers.len();
let diff = ref_count.abs_diff(hyp_count);
match diff {
0 => speaker_count_exact += 1,
1 => speaker_count_pm1 += 1,
_ => speaker_count_off += 1,
}
totals.der_total += der.der;
totals.der_no_collar_total += der_no_collar.der;
totals.miss += der.miss_rate;
totals.false_alarm += der.false_alarm_rate;
totals.confusion += der.confusion_rate;
totals.count += 1;
total_audio_secs += audio_secs;
total_runtime_secs += runtime_secs;
let rt_factor = audio_secs / runtime_secs.max(1e-6);
println!(
"{stem}\t DER={:.3}%\t miss={:.3}%\t fa={:.3}%\t conf={:.3}%\t rt={:.1}x\t spk={}\t turns={}",
der.der * 100.0,
der.miss_rate * 100.0,
der.false_alarm_rate * 100.0,
der.confusion_rate * 100.0,
rt_factor,
result.num_speakers,
result.turns.len(),
);
per_file.push(PerFileResult {
filename: stem.to_owned(),
der_collar: der.der * 100.0,
der_no_collar: der_no_collar.der * 100.0,
miss_rate: der.miss_rate * 100.0,
false_alarm_rate: der.false_alarm_rate * 100.0,
confusion_rate: der.confusion_rate * 100.0,
rt_factor,
ref_speakers: ref_count,
hyp_speakers: hyp_count,
num_turns: result.turns.len(),
audio_duration_secs: audio_secs,
runtime_secs,
});
}
let n = totals.count.max(1) as f64;
let report = BenchReport {
schema: "polyvoice-bench-v0.6",
crate_version: env!("CARGO_PKG_VERSION"),
git_sha: git_sha(),
host_arch: std::env::consts::ARCH.to_owned(),
host_os: std::env::consts::OS.to_owned(),
command_line: std::env::args().collect::<Vec<_>>().join(" "),
dataset_name,
profile: args.profile.clone(),
files_processed: totals.count,
files_skipped,
der_collar: (totals.der_total / n) * 100.0,
der_no_collar: (totals.der_no_collar_total / n) * 100.0,
miss: (totals.miss / n) * 100.0,
false_alarm: (totals.false_alarm / n) * 100.0,
confusion: (totals.confusion / n) * 100.0,
rt_factor_avg: total_audio_secs / total_runtime_secs.max(1e-6),
speaker_count: SpeakerCountDiagnostics {
exact: speaker_count_exact,
plus_minus_1: speaker_count_pm1,
off_by_2_or_more: speaker_count_off,
},
model_hashes: model_hashes(®istry, profile),
per_file,
};
let json = serde_json::to_string_pretty(&report)?;
match args.output {
Some(p) => std::fs::write(&p, json)?,
None => println!("{json}"),
}
Ok(())
}
/// Running sums accumulated by `main` across processed files; each sum is
/// divided by `count` at the end to produce the report's mean values.
#[derive(Default)]
struct Aggregate {
/// Sum of per-file DER computed with the `--collar` value.
der_total: f64,
/// Sum of per-file DER computed with a collar of 0.0.
der_no_collar_total: f64,
/// Sum of per-file miss rates (collared computation).
miss: f64,
/// Sum of per-file false-alarm rates (collared computation).
false_alarm: f64,
/// Sum of per-file confusion rates (collared computation).
confusion: f64,
/// Number of files that contributed to the sums.
count: usize,
}
#[cfg(test)]
mod prop_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        // Every in-range combination of the supported flags must be accepted
        // by the clap parser.
        #[test]
        fn bench_args_parses_with_valid_args(
            profile in "(mobile|balanced)",
            collar in 0.0f64..1.0f64,
            threshold in 0.0f32..1.0f32,
            max_files in 0usize..100usize,
        ) {
            let collar_arg = collar.to_string();
            let threshold_arg = threshold.to_string();
            let max_files_arg = max_files.to_string();
            let argv = [
                "polyvoice-bench",
                "/tmp/dataset",
                "--profile", profile.as_str(),
                "--collar", collar_arg.as_str(),
                "--threshold", threshold_arg.as_str(),
                "--max-files", max_files_arg.as_str(),
            ];
            prop_assert!(Args::try_parse_from(argv).is_ok());
        }

        // `parse_profile` succeeds for exactly the two known names and fails
        // for every other identifier-like string.
        #[test]
        fn parse_profile_accepts_only_valid(s in "[a-zA-Z0-9_-]{1,20}") {
            let should_parse = matches!(s.as_str(), "mobile" | "balanced");
            prop_assert_eq!(parse_profile(&s).is_ok(), should_parse);
        }
    }
}