#![allow(clippy::unwrap_used)]
#![cfg(all(
feature = "onnx",
feature = "segmentation",
feature = "embedder",
feature = "clusterer",
feature = "resegmentation",
feature = "download",
))]
use polyvoice::clusterer::{AhcClusterer, Clusterer};
use polyvoice::embedder::{Embedder, ResNet34Adapter};
use polyvoice::models::ModelRegistry;
use polyvoice::segmentation::{PowersetSegmenter, Segmenter};
use polyvoice::utils::cosine_similarity;
use polyvoice::wav::read_wav;
use std::path::Path;
#[ignore = "requires cached ONNX bundle + wav file in data/voxconverse-test/"]
#[test]
fn debug_v2_segments_on_aepyx() {
let wav_path = Path::new("data/voxconverse-test/audio/aepyx.wav");
let _rttm_path = Path::new("data/voxconverse-test/rttm/aepyx.rttm");
if !wav_path.is_file() {
println!("WAV not found — skipping");
return;
}
let (samples, sr) = read_wav(wav_path).expect("read wav");
assert_eq!(sr, 16000);
let registry = ModelRegistry::default().expect("registry");
let models = registry
.ensure_for_profile(polyvoice::types::Profile::Balanced)
.expect("models");
let segmenter = PowersetSegmenter::new(&models.segmenter_path).expect("segmenter");
let embedder = ResNet34Adapter::new(&models.embedder_path, 1).expect("embedder");
println!("=== SEGMENTATION ===");
let segments = segmenter.segment(&samples).expect("segment");
println!("Total segments: {}", segments.len());
let durations: Vec<f64> = segments.iter().map(|s| s.time.duration()).collect();
let avg_dur = durations.iter().sum::<f64>() / durations.len() as f64;
let min_dur = durations.iter().copied().fold(f64::INFINITY, f64::min);
let max_dur = durations.iter().copied().fold(f64::NEG_INFINITY, f64::max);
println!(
"Duration: avg={:.2}s min={:.2}s max={:.2}s",
avg_dur, min_dur, max_dur
);
let mut short = 0;
let mut medium = 0;
let mut long = 0;
for d in &durations {
if *d < 0.5 {
short += 1;
} else if *d < 2.0 {
medium += 1;
} else {
long += 1;
}
}
println!(
"Distribution: <0.5s={} 0.5-2s={} >2s={}",
short, medium, long
);
let primary: Vec<_> = segments.iter().filter(|s| !s.is_overlap).cloned().collect();
println!("Primary (non-overlap) segments: {}", primary.len());
for (max_gap, min_dur) in [(0.5, 0.0), (2.0, 0.0), (2.0, 0.5)] {
let mut aggregated: Vec<polyvoice::segmentation::RawSegment> = Vec::new();
for seg in &primary {
if seg.time.duration() < min_dur && !aggregated.is_empty() {
continue;
}
if let Some(last) = aggregated.last_mut()
&& last.local_speaker_idx == seg.local_speaker_idx
&& seg.time.start - last.time.end < max_gap
{
last.time.end = seg.time.end.max(last.time.end);
continue;
}
aggregated.push(seg.clone());
}
let agg_durations: Vec<f64> = aggregated.iter().map(|s| s.time.duration()).collect();
let agg_avg = agg_durations.iter().sum::<f64>() / agg_durations.len() as f64;
let agg_short = agg_durations.iter().filter(|&&d| d < 0.5).count();
println!(
"Agg (gap<{:.1}s, min>{:.1}s): {} segs, avg={:.2}s, short={}",
max_gap,
min_dur,
aggregated.len(),
agg_avg,
agg_short
);
}
println!("\n=== LEGACY SILEROVAD SEGMENTATION ===");
let silero_path = Path::new("models/silero_vad.onnx");
let mut vad = if silero_path.is_file() {
polyvoice::SileroVad::new(silero_path, 512).expect("vad")
} else {
println!(" silero_vad.onnx not found — skipping comparison");
return;
};
let speech_regions = polyvoice::vad::segment_speech(
&mut vad,
&samples,
&polyvoice::types::DiarizationConfig::default(),
&polyvoice::vad::VadConfig::default(),
)
.expect("segment_speech");
println!("SileroVAD speech regions: {}", speech_regions.len());
let vad_durations: Vec<f64> = speech_regions
.iter()
.map(|(s, e)| (*e - *s) as f64 / sr as f64)
.collect();
let vad_avg = vad_durations.iter().sum::<f64>() / vad_durations.len() as f64;
println!(" avg region duration: {:.2}s", vad_avg);
println!("\n=== DER USING LOCAL_SPEAKER_IDX DIRECTLY ===");
{
let ref_turns = {
let raw = polyvoice::rttm::parse_rttm_file(_rttm_path).expect("parse rttm");
let grouped = polyvoice::rttm::group_by_file(&raw);
let segs: Vec<_> = grouped
.get("aepyx")
.map(|v| v.iter().map(|s| (*s).clone()).collect())
.unwrap_or_default();
let (turns, _map) = polyvoice::rttm::to_speaker_turns(&segs);
turns
};
let turns: Vec<polyvoice::types::SpeakerTurn> = primary
.iter()
.map(|s| polyvoice::types::SpeakerTurn {
speaker: polyvoice::types::SpeakerId(s.local_speaker_idx as u32),
time: s.time,
text: None,
})
.collect();
let der = polyvoice::der::compute_der(&ref_turns, &turns, 0.25);
println!(
" DER = {:.2}% (speakers={})",
der.der * 100.0,
primary
.iter()
.map(|s| s.local_speaker_idx)
.collect::<std::collections::HashSet<_>>()
.len()
);
}
println!("\n=== EMBEDDINGS ===");
let sr_f64 = sr as f64;
let mut embeddings: Vec<Vec<f32>> = Vec::with_capacity(primary.len());
for seg in &primary {
let start = (seg.time.start * sr_f64) as usize;
let end = ((seg.time.end * sr_f64) as usize).min(samples.len());
if end <= start {
continue;
}
let chunk = &samples[start..end];
let emb = embedder.embed(chunk).expect("embed");
embeddings.push(emb);
}
println!("Valid embeddings: {}", embeddings.len());
if embeddings.len() >= 2 {
let mut sims: Vec<f32> = Vec::new();
for i in 0..embeddings.len().min(50) {
for j in (i + 1)..embeddings.len().min(50) {
sims.push(cosine_similarity(&embeddings[i], &embeddings[j]));
}
}
sims.sort_by(|a, b| a.total_cmp(b));
println!("Pairwise cosine similarity (first 50 embeddings):");
println!(
" min={:.3} p5={:.3} p25={:.3} median={:.3} p75={:.3} p95={:.3} max={:.3}",
sims.first().unwrap(),
sims[sims.len() * 5 / 100],
sims[sims.len() * 25 / 100],
sims[sims.len() * 50 / 100],
sims[sims.len() * 75 / 100],
sims[sims.len() * 95 / 100],
sims.last().unwrap(),
);
}
println!("\n=== CLUSTERING (AHC auto) ===");
let clusterer = AhcClusterer::new(20);
let labels = clusterer.cluster(&embeddings).expect("cluster");
let num_clusters = labels
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
println!(
"Inferred clusters: {} (max_clusters={})",
num_clusters,
clusterer.max_clusters()
);
println!("\n=== CLUSTERING (AHC fixed threshold 0.45) ===");
let clusterer_fixed = AhcClusterer::with_threshold(20, 0.45);
let labels_fixed = clusterer_fixed.cluster(&embeddings).expect("cluster");
let num_fixed = labels_fixed
.iter()
.copied()
.collect::<std::collections::HashSet<_>>()
.len();
println!("Inferred clusters: {} (threshold=0.45)", num_fixed);
println!("\n=== SEGMENT → CLUSTER MAP (first 20) ===");
for (i, seg) in primary.iter().take(20).enumerate() {
println!(
" {:2}: {:.2}s - {:.2}s cluster={}",
i, seg.time.start, seg.time.end, labels[i]
);
}
}