use polyvoice::{
detect_overlaps, DiarizationConfig, DummyExtractor, OfflineDiarizer, OnlineDiarizer,
SpeakerCluster, TimeRange, WordAlignment,
};
/// Generates a pure sine tone with peak amplitude 0.5.
///
/// Produces `duration_secs * sample_rate` samples (cast to an integer count),
/// where sample `i` is `0.5 * sin(2π · freq · i / sample_rate)`.
fn sine_wave(freq: f32, duration_secs: f32, sample_rate: u32) -> Vec<f32> {
    let rate = sample_rate as f32;
    let len = (duration_secs * rate) as usize;
    // Angular frequency in radians per second; kept in the same evaluation
    // order as `2π · freq · t` so every sample is bit-identical.
    let omega = 2.0 * std::f32::consts::PI * freq;
    let mut out = Vec::with_capacity(len);
    for idx in 0..len {
        let secs = idx as f32 / rate;
        out.push((omega * secs).sin() * 0.5);
    }
    out
}
fn two_speaker_audio(sample_rate: u32) -> Vec<f32> {
let mut audio = Vec::new();
audio.extend_from_slice(&sine_wave(200.0, 3.0, sample_rate));
audio.extend_from_slice(&sine_wave(400.0, 3.0, sample_rate));
audio.extend_from_slice(&sine_wave(200.0, 4.0, sample_rate));
audio
}
/// Offline diarization over the synthetic two-part signal yields at least one
/// segment and one turn, and every turn has a positive duration.
#[test]
fn test_offline_two_speakers_dummy() {
    let samples = two_speaker_audio(16000);
    let extractor = DummyExtractor::new(256);
    let diarizer = OfflineDiarizer::new(DiarizationConfig {
        window_secs: 1.0,
        hop_secs: 0.5,
        ..Default::default()
    });

    let output = diarizer.run(&samples, &extractor).unwrap();

    assert!(!output.segments.is_empty());
    assert!(!output.turns.is_empty());
    assert!(output.turns.iter().all(|turn| turn.time.duration() > 0.0));
}
/// Streaming diarization fed one-second chunks, followed by a final flush,
/// emits at least one segment, and segment start times never go backwards.
#[test]
fn test_online_streaming_basic() {
    let sample_rate: u32 = 16000;
    let mut diarizer = OnlineDiarizer::new(DiarizationConfig {
        window_secs: 1.0,
        hop_secs: 0.5,
        ..Default::default()
    });
    let extractor = DummyExtractor::new(256);
    let samples = two_speaker_audio(sample_rate);

    let mut emitted = Vec::new();
    // Feed one second of audio at a time to mimic a live stream.
    for chunk in samples.chunks(sample_rate as usize) {
        emitted.extend(diarizer.feed(chunk, &extractor).unwrap());
    }
    // `flush` yields an `Option`, so this appends at most one trailing segment.
    emitted.extend(diarizer.flush(&extractor).unwrap());

    assert!(!emitted.is_empty());
    assert!(emitted
        .windows(2)
        .all(|pair| pair[0].time.start <= pair[1].time.start));
}
/// After feeding 3 s of a single 300 Hz tone, `align_words` assigns a speaker
/// to every word whose time range falls inside the processed audio.
#[test]
fn test_word_alignment() {
    let sample_rate = 16000;
    let mut diarizer = OnlineDiarizer::new(DiarizationConfig {
        window_secs: 1.0,
        hop_secs: 0.5,
        ..Default::default()
    });
    let extractor = DummyExtractor::new(256);
    let tone = sine_wave(300.0, 3.0, sample_rate);
    // Only the diarizer's internal state matters here; emitted segments are discarded.
    let _ = diarizer.feed(&tone, &extractor).unwrap();

    // Builds a word with no speaker yet, spanning `start`..`end` seconds.
    let make_word = |text: &str, start, end, confidence| WordAlignment {
        word: text.into(),
        time: TimeRange { start, end },
        speaker: None,
        confidence,
    };
    let mut words = vec![
        make_word("hello", 0.5, 1.0, 0.9),
        make_word("world", 1.5, 2.0, 0.8),
    ];

    diarizer.align_words(&mut words);

    assert!(words.iter().all(|w| w.speaker.is_some()));
}
/// Two segments from different speakers, spanning 0–3 s and 1–4 s, produce
/// exactly one overlap region covering 1–3 s.
#[test]
fn test_overlap_detection() {
    // Builds a segment attributed to speaker `id` with no confidence score.
    let seg = |start, end, id| polyvoice::Segment {
        time: TimeRange { start, end },
        speaker: Some(polyvoice::SpeakerId(id)),
        confidence: None,
    };
    let segments = vec![seg(0.0, 3.0, 0), seg(1.0, 4.0, 1)];

    let overlaps = detect_overlaps(&segments);

    assert_eq!(overlaps.len(), 1);
    let region = &overlaps[0].time;
    assert!((region.start - 1.0).abs() < 1e-5);
    assert!((region.end - 3.0).abs() < 1e-5);
}
/// Orthogonal one-hot embeddings must land in two separate clusters, and
/// re-submitting an embedding must map back to the same speaker ID.
#[test]
fn test_cluster_two_distinct_speakers() {
    // A 256-dim embedding with a single 1.0 at index `hot`, zeros elsewhere.
    let one_hot = |hot: usize| -> Vec<f32> {
        let mut emb = vec![0.0f32; 256];
        emb[hot] = 1.0;
        emb
    };
    let emb_a = one_hot(0);
    let emb_b = one_hot(1);

    let mut cluster = SpeakerCluster::new(DiarizationConfig::default());
    // Interleave the two speakers: A, B, A, B.
    let (first_a, _) = cluster.assign(&emb_a);
    let (first_b, _) = cluster.assign(&emb_b);
    let (second_a, _) = cluster.assign(&emb_a);
    let (second_b, _) = cluster.assign(&emb_b);

    assert_ne!(first_a, first_b, "different speakers should have different IDs");
    assert_eq!(first_a, second_a, "same speaker should have same ID");
    assert_eq!(first_b, second_b, "same speaker should have same ID");
    assert_eq!(cluster.num_speakers(), 2);
}
/// Offline diarization over a 2 s 200 Hz tone followed by a 2 s 800 Hz tone,
/// with `threshold: 0.3`, returns at least one segment.
///
/// NOTE(review): despite its name, this test does not assert that more than
/// one speaker is produced. Presumably `DummyExtractor` embeddings are not
/// guaranteed to separate the two tones. Tighten the assertion (e.g. count
/// distinct `speaker` values) once that is confirmed.
#[test]
fn test_offline_produces_multiple_speakers() {
    let sample_rate = 16000;
    let config = DiarizationConfig {
        window_secs: 0.5,
        hop_secs: 0.25,
        threshold: 0.3,
        ..Default::default()
    };
    let diarizer = OfflineDiarizer::new(config);
    let extractor = DummyExtractor::new(256);
    let mut audio = sine_wave(200.0, 2.0, sample_rate);
    audio.extend_from_slice(&sine_wave(800.0, 2.0, sample_rate));
    let result = diarizer.run(&audio, &extractor).unwrap();
    assert!(!result.segments.is_empty());
}
/// On a single steady 4 s tone, any two consecutive turns attributed to the
/// same speaker must be separated by more than 0.5 s. Shorter gaps are
/// expected to have been merged by the offline diarizer.
///
/// NOTE(review): the 0.5 s figure presumably mirrors a default gap-merge
/// setting in `DiarizationConfig`; confirm. If fewer than two turns come
/// back, or no consecutive pair shares a speaker, the check below is vacuous.
#[test]
fn test_offline_merges_small_gaps() {
    let extractor = DummyExtractor::new(256);
    let diarizer = OfflineDiarizer::new(DiarizationConfig {
        window_secs: 0.5,
        hop_secs: 0.25,
        ..Default::default()
    });
    let tone = sine_wave(300.0, 4.0, 16000);
    let result = diarizer.run(&tone, &extractor).unwrap();

    let same_speaker_pairs = result
        .turns
        .windows(2)
        .filter(|pair| pair[0].speaker == pair[1].speaker);
    for pair in same_speaker_pairs {
        let (prev, next) = (&pair[0], &pair[1]);
        assert!(
            next.time.start - prev.time.end > 0.5,
            "gaps <= 0.5s should be merged for same speaker"
        );
    }
}