use crate::types::{SpeakerTurn, TimeRange};
use std::collections::HashMap;
/// Outcome of a diarization-error-rate (DER) evaluation.
///
/// All `*_rate` fields and `der` are fractions of `total_ref_frames`
/// (1.0 == 100 %), not percentages. The `*_frames` fields are the raw
/// speaker-frame counts the rates were derived from.
///
/// `Default` yields the all-zero result, which is also what the scorer
/// reports for degenerate inputs (e.g. an empty reference).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DerResult {
    /// `(missed + false alarm + confusion) / total_ref_frames`.
    pub der: f64,
    /// `missed_frames / total_ref_frames`.
    pub miss_rate: f64,
    /// `false_alarm_frames / total_ref_frames`.
    pub false_alarm_rate: f64,
    /// `confusion_frames / total_ref_frames`.
    pub confusion_rate: f64,
    /// Scored reference speech: `total_ref_frames` times the 0.01 frame
    /// width (rendered with an `s` suffix by the `Display` impl).
    pub total_speech: f64,
    /// Number of scored reference speaker-frames (a frame with two active
    /// reference speakers counts twice).
    pub total_ref_frames: u64,
    /// Reference speaker-frames with no hypothesis speaker to pair with.
    pub missed_frames: u64,
    /// Hypothesis speaker-frames in excess of the reference speaker count.
    pub false_alarm_frames: u64,
    /// Paired speaker-frames whose mapped hypothesis speaker is wrong.
    pub confusion_frames: u64,
}
/// One-line human-readable summary: every rate as a percentage with one
/// decimal, followed by the amount of scored reference speech.
impl std::fmt::Display for DerResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Rates are stored as fractions; scale them for display only.
        let as_percent = |fraction: f64| fraction * 100.0;
        let der = as_percent(self.der);
        let miss = as_percent(self.miss_rate);
        let false_alarm = as_percent(self.false_alarm_rate);
        let confusion = as_percent(self.confusion_rate);
        let speech = self.total_speech;
        write!(
            f,
            "DER={:.1}% (miss={:.1}%, fa={:.1}%, conf={:.1}%, speech={:.1}s)",
            der, miss, false_alarm, confusion, speech,
        )
    }
}
/// Computes the diarization error rate of `hypothesis` against `reference`
/// over every frame of the recording (no region restriction).
///
/// `collar` is the half-width of the window around each reference turn
/// boundary that is excluded from scoring; it is expressed in the same unit
/// as the turn timestamps (presumably seconds — confirm against the
/// `TimeRange` definition). A `collar` of `0.0` scores every frame.
///
/// Returns an all-zero [`DerResult`] when `reference` is empty or `collar`
/// is negative or non-finite.
pub fn compute_der(
reference: &[SpeakerTurn],
hypothesis: &[SpeakerTurn],
collar: f64,
) -> DerResult {
der_core(reference, hypothesis, collar, Region::All)
}
/// Computes the diarization error rate restricted to frames where the
/// reference has fewer than two active speakers.
///
/// Frames in which two or more reference speakers overlap are ignored in
/// addition to the frames removed by `collar`, so errors made inside
/// overlapped speech do not contribute to the result.
///
/// Returns an all-zero [`DerResult`] when `reference` is empty or `collar`
/// is negative or non-finite.
pub fn compute_der_single_speaker_regions(
reference: &[SpeakerTurn],
hypothesis: &[SpeakerTurn],
collar: f64,
) -> DerResult {
der_core(reference, hypothesis, collar, Region::SingleSpeaker)
}
/// Selects which reference frames `der_core` scores, on top of the collar.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Region {
/// Score every frame that is not inside the collar.
All,
/// Additionally ignore frames with two or more active reference speakers.
SingleSpeaker,
/// Additionally ignore frames with fewer than two active reference speakers.
Overlap,
}
/// Frame-level DER scorer shared by the public entry points of this module.
///
/// Both turn lists are rasterised onto a grid of 0.01-wide frames. Frames
/// inside the collar around reference turn boundaries, and frames excluded
/// by `region`, are ignored. Hypothesis speakers are mapped onto reference
/// speakers with `optimal_speaker_mapping`, and the remaining frames are
/// tallied into missed, false-alarm and confusion speaker-frames.
///
/// Returns an all-zero result when `reference` is empty, `collar` is
/// negative or non-finite, some turn ends at +infinity, or no reference
/// speaker-frame is left to score. The timeline is capped at `MAX_FRAMES`
/// frames; anything past the cap is not scored.
fn der_core(
    reference: &[SpeakerTurn],
    hypothesis: &[SpeakerTurn],
    collar: f64,
    region: Region,
) -> DerResult {
    // Frame width, in the unit of the turn timestamps.
    const RESOLUTION: f64 = 0.01;
    // Upper bound on the rasterised timeline (24 * 3600 * 100 frames) so a
    // bogus end time cannot trigger an enormous allocation.
    const MAX_FRAMES: usize = 24 * 3600 * 100;
    // Reported for every degenerate input instead of dividing by zero.
    const EMPTY: DerResult = DerResult {
        der: 0.0,
        miss_rate: 0.0,
        false_alarm_rate: 0.0,
        confusion_rate: 0.0,
        total_speech: 0.0,
        total_ref_frames: 0,
        missed_frames: 0,
        false_alarm_frames: 0,
        confusion_frames: 0,
    };

    if reference.is_empty() || !collar.is_finite() || collar < 0.0 {
        return EMPTY;
    }

    // The fold starts at 0.0 and `f64::max` skips NaN operands, so `max_time`
    // can be neither negative nor NaN; +infinity is the only value to reject.
    let max_time = reference
        .iter()
        .chain(hypothesis.iter())
        .map(|t| t.time.end)
        .fold(0.0f64, f64::max);
    if !max_time.is_finite() {
        return EMPTY;
    }

    // The float-to-usize cast saturates at `usize::MAX` for huge (but finite)
    // end times, so the extra frame must be added with saturation as well;
    // a plain `+ 1` would overflow before the cap is applied.
    let n_frames = ((max_time / RESOLUTION).ceil() as usize)
        .saturating_add(1)
        .min(MAX_FRAMES);

    let mut ignore_mask = build_collar_mask(reference, collar, RESOLUTION, n_frames);
    let ref_frames = build_speaker_frames(reference, RESOLUTION, n_frames);
    let hyp_frames = build_speaker_frames(hypothesis, RESOLUTION, n_frames);

    // Fold the region restriction into the ignore mask so that the speaker
    // mapping and the tally below see exactly the same set of frames.
    let excluded = |n_ref_speakers: usize| match region {
        Region::All => false,
        Region::SingleSpeaker => n_ref_speakers >= 2,
        Region::Overlap => n_ref_speakers < 2,
    };
    for (ignored, speakers) in ignore_mask.iter_mut().zip(&ref_frames) {
        if excluded(speakers.len()) {
            *ignored = true;
        }
    }

    let mapping = optimal_speaker_mapping(&ref_frames, &hyp_frames, &ignore_mask);

    let mut total_ref = 0u64;
    let mut missed = 0u64;
    let mut false_alarm = 0u64;
    let mut confusion = 0u64;
    for ((ref_spk, hyp_spk), &ignored) in ref_frames.iter().zip(&hyp_frames).zip(&ignore_mask) {
        if ignored {
            continue;
        }
        let n_ref = ref_spk.len() as u64;
        let n_hyp = hyp_spk.len() as u64;
        total_ref += n_ref;

        // Hypothesis speakers whose mapped reference speaker is active here.
        let matched = hyp_spk
            .iter()
            .filter_map(|h| mapping.get(h))
            .filter(|&mapped_ref| ref_spk.contains(mapped_ref))
            .count() as u64;
        let n_correct = matched.min(n_ref);

        missed += n_ref.saturating_sub(n_hyp);
        false_alarm += n_hyp.saturating_sub(n_ref);
        // `n_correct` never exceeds either count, so this cannot underflow.
        confusion += n_ref.min(n_hyp) - n_correct;
    }

    if total_ref == 0 {
        return EMPTY;
    }

    let total_ref_f = total_ref as f64;
    DerResult {
        der: (missed + false_alarm + confusion) as f64 / total_ref_f,
        miss_rate: missed as f64 / total_ref_f,
        false_alarm_rate: false_alarm as f64 / total_ref_f,
        confusion_rate: confusion as f64 / total_ref_f,
        total_speech: total_ref_f * RESOLUTION,
        total_ref_frames: total_ref,
        missed_frames: missed,
        false_alarm_frames: false_alarm,
        confusion_frames: confusion,
    }
}
/// Builds the per-frame "ignore" mask for the scoring collar.
///
/// Returns `n_frames` flags; a flag is `true` when its frame lies within
/// `collar` of the start or end of any reference turn. With a non-positive
/// `collar` no frame is masked.
fn build_collar_mask(
    reference: &[SpeakerTurn],
    collar: f64,
    resolution: f64,
    n_frames: usize,
) -> Vec<bool> {
    let mut mask = vec![false; n_frames];
    if collar > 0.0 {
        let edges = reference
            .iter()
            .flat_map(|turn| [turn.time.start, turn.time.end]);
        for edge in edges {
            // Clamp the window to the timeline on both sides.
            let first = ((edge - collar).max(0.0) / resolution) as usize;
            let last_exclusive = (((edge + collar) / resolution).ceil() as usize).min(n_frames);
            // `get_mut` yields `None` when the window is empty-or-inverted
            // beyond the timeline, in which case nothing is masked.
            if let Some(window) = mask.get_mut(first..last_exclusive) {
                window.fill(true);
            }
        }
    }
    mask
}
/// Rasterises `turns` onto a grid of `n_frames` frames of width `resolution`.
///
/// Entry `i` of the result lists the ids of the speakers active in frame
/// `i`, each id at most once and in order of first insertion. Parts of a
/// turn that fall beyond `n_frames` are dropped.
fn build_speaker_frames(turns: &[SpeakerTurn], resolution: f64, n_frames: usize) -> Vec<Vec<u32>> {
    let mut frames = vec![Vec::<u32>::new(); n_frames];
    for turn in turns {
        let id = turn.speaker.0;
        let first = (turn.time.start / resolution) as usize;
        let last_exclusive = ((turn.time.end / resolution).ceil() as usize).min(n_frames);
        // An inverted or out-of-range span selects no frames at all.
        if let Some(span) = frames.get_mut(first..last_exclusive) {
            for active in span {
                if !active.contains(&id) {
                    active.push(id);
                }
            }
        }
    }
    frames
}
/// Derives a one-to-one hypothesis-speaker -> reference-speaker mapping.
///
/// For every frame not flagged in `collar_mask`, each (hypothesis, reference)
/// speaker pair active in that frame is counted. The counts are negated into
/// a square cost matrix (zero-padded when the two sides have different
/// numbers of speakers) and handed to `crate::hungarian::solve` —
/// presumably a minimum-cost assignment solver, so that the chosen pairing
/// maximises total co-occurrence; confirm against that module.
///
/// Pairs that never co-occur are left out of the returned map. An empty map
/// is returned when nothing co-occurs or when the solver yields `None`.
fn optimal_speaker_mapping(
    ref_frames: &[Vec<u32>],
    hyp_frames: &[Vec<u32>],
    collar_mask: &[bool],
) -> HashMap<u32, u32> {
    // Keyed as (hypothesis id, reference id).
    let mut cooccurrence: HashMap<(u32, u32), u64> = HashMap::new();
    for (i, (ref_spk, hyp_spk)) in ref_frames.iter().zip(hyp_frames).enumerate() {
        if collar_mask[i] {
            continue;
        }
        for &r in ref_spk {
            for &h in hyp_spk {
                *cooccurrence.entry((h, r)).or_default() += 1;
            }
        }
    }
    if cooccurrence.is_empty() {
        return HashMap::new();
    }

    // Sorted, de-duplicated id lists double as the row/column index of the
    // cost matrix and allow binary search below.
    let sorted_unique = |mut ids: Vec<u32>| {
        ids.sort_unstable();
        ids.dedup();
        ids
    };
    let hyp_ids = sorted_unique(cooccurrence.keys().map(|&(h, _)| h).collect());
    let ref_ids = sorted_unique(cooccurrence.keys().map(|&(_, r)| r).collect());

    let side = hyp_ids.len().max(ref_ids.len());
    let mut cost = vec![vec![0.0_f32; side]; side];
    for (&(h, r), &count) in &cooccurrence {
        if let (Ok(row), Ok(col)) = (hyp_ids.binary_search(&h), ref_ids.binary_search(&r)) {
            cost[row][col] = -(count as f32);
        }
    }

    let mut mapping: HashMap<u32, u32> = HashMap::new();
    if let Some(assignment) = crate::hungarian::solve(&cost) {
        for (row, &col) in assignment.iter().enumerate() {
            // Rows/columns past the id lists are padding and carry no speaker.
            if let Some((&h, &r)) = hyp_ids.get(row).zip(ref_ids.get(col)) {
                let shared_frames = cooccurrence.get(&(h, r)).copied().unwrap_or(0);
                if shared_frames > 0 {
                    mapping.insert(h, r);
                }
            }
        }
    }
    mapping
}
/// Scores `hypothesis` against a reference given as `(start, end, label)`
/// tuples instead of `SpeakerTurn`s.
///
/// Every distinct label is assigned a numeric speaker id, starting at 1000
/// and increasing in order of first appearance. The converted turns (with
/// `text` left as `None`) are then scored with [`compute_der`] using the
/// given `collar`.
pub fn compute_der_from_rttm(
    reference: &[(f64, f64, &str)],
    hypothesis: &[SpeakerTurn],
    collar: f64,
) -> DerResult {
    let mut label_ids: HashMap<&str, u32> = HashMap::new();
    let mut next_id = 1000u32;
    let mut ref_turns: Vec<SpeakerTurn> = Vec::with_capacity(reference.len());

    for &(start, end, label) in reference {
        let id = match label_ids.get(label).copied() {
            Some(known) => known,
            None => {
                let fresh = next_id;
                next_id += 1;
                label_ids.insert(label, fresh);
                fresh
            }
        };
        ref_turns.push(SpeakerTurn {
            speaker: crate::types::SpeakerId(id),
            time: TimeRange { start, end },
            text: None,
        });
    }

    compute_der(&ref_turns, hypothesis, collar)
}
/// Frame-level recall for a single reference speaker.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SpeakerRecall {
/// Reference speaker id.
pub speaker: u32,
/// Frames outside the collar in which this speaker is active in the reference.
pub ref_frames: u64,
/// Subset of `ref_frames` in which the hypothesis speaker mapped to this
/// reference speaker is active as well.
pub recalled_frames: u64,
/// `recalled_frames / ref_frames`.
pub recall: f64,
}
/// DER broken down by region, plus per-speaker recall, as produced by
/// `compute_der_decomposition`.
#[derive(Debug, Clone)]
pub struct DerDecomposition {
/// DER over all frames outside the collar.
pub total: DerResult,
/// DER over frames with fewer than two active reference speakers.
pub single_speaker: DerResult,
/// DER over frames with two or more active reference speakers.
pub overlap: DerResult,
/// Recall of each reference speaker, sorted by speaker id.
pub per_speaker_recall: Vec<SpeakerRecall>,
}
/// Scores `hypothesis` against `reference` four ways with the same `collar`:
/// overall DER, DER restricted to single-speaker frames, DER restricted to
/// overlapped frames, and per-reference-speaker recall.
///
/// Each component is computed independently from the raw turns, so the
/// speaker mapping is re-derived for each region rather than shared.
/// The per-speaker recall applies only the collar, not a region restriction.
pub fn compute_der_decomposition(
reference: &[SpeakerTurn],
hypothesis: &[SpeakerTurn],
collar: f64,
) -> DerDecomposition {
DerDecomposition {
total: der_core(reference, hypothesis, collar, Region::All),
single_speaker: der_core(reference, hypothesis, collar, Region::SingleSpeaker),
overlap: der_core(reference, hypothesis, collar, Region::Overlap),
per_speaker_recall: compute_per_speaker_recall(reference, hypothesis, collar),
}
}
/// Computes, for every reference speaker, the fraction of its scored frames
/// in which the hypothesis speaker mapped to it is also active.
///
/// Frames inside the collar around reference turn boundaries are skipped.
/// A reference speaker with no mapped hypothesis speaker gets a recall of
/// `0.0`. The result is sorted by speaker id.
///
/// Returns an empty list when `reference` is empty, `collar` is negative or
/// non-finite, or some turn ends at +infinity. The timeline is capped at
/// `MAX_FRAMES` frames.
fn compute_per_speaker_recall(
    reference: &[SpeakerTurn],
    hypothesis: &[SpeakerTurn],
    collar: f64,
) -> Vec<SpeakerRecall> {
    // Frame width, in the unit of the turn timestamps.
    const RESOLUTION: f64 = 0.01;
    // Upper bound on the rasterised timeline (24 * 3600 * 100 frames).
    const MAX_FRAMES: usize = 24 * 3600 * 100;

    if reference.is_empty() || !collar.is_finite() || collar < 0.0 {
        return Vec::new();
    }

    // The fold starts at 0.0 and `f64::max` skips NaN operands, so only
    // +infinity remains to be rejected.
    let max_time = reference
        .iter()
        .chain(hypothesis.iter())
        .map(|t| t.time.end)
        .fold(0.0f64, f64::max);
    if !max_time.is_finite() {
        return Vec::new();
    }

    // The float-to-usize cast saturates at `usize::MAX` for huge end times;
    // add the extra frame with saturation so it cannot overflow before the
    // cap is applied.
    let n_frames = ((max_time / RESOLUTION).ceil() as usize)
        .saturating_add(1)
        .min(MAX_FRAMES);

    let collar_mask = build_collar_mask(reference, collar, RESOLUTION, n_frames);
    let ref_frames = build_speaker_frames(reference, RESOLUTION, n_frames);
    let hyp_frames = build_speaker_frames(hypothesis, RESOLUTION, n_frames);
    let mapping = optimal_speaker_mapping(&ref_frames, &hyp_frames, &collar_mask);

    // The mapping is keyed by hypothesis id; recall needs the inverse view.
    let ref_to_hyp: HashMap<u32, u32> = mapping.iter().map(|(&h, &r)| (r, h)).collect();

    // Per reference speaker: (scored frames, frames recalled by the mapped
    // hypothesis speaker).
    let mut tallies: HashMap<u32, (u64, u64)> = HashMap::new();
    for ((ref_spk, hyp_spk), &masked) in ref_frames.iter().zip(&hyp_frames).zip(&collar_mask) {
        if masked {
            continue;
        }
        for &r in ref_spk {
            let tally = tallies.entry(r).or_insert((0, 0));
            tally.0 += 1;
            if ref_to_hyp.get(&r).map_or(false, |h| hyp_spk.contains(h)) {
                tally.1 += 1;
            }
        }
    }

    let mut out: Vec<SpeakerRecall> = tallies
        .into_iter()
        .map(|(speaker, (ref_frames, recalled_frames))| SpeakerRecall {
            speaker,
            ref_frames,
            recalled_frames,
            // An entry only exists once a frame was counted, so the
            // denominator is never zero.
            recall: recalled_frames as f64 / ref_frames as f64,
        })
        .collect();
    out.sort_unstable_by_key(|s| s.speaker);
    out
}
// Unit tests for the DER scorer. All fixtures use explicit start/end
// timestamps and, unless stated otherwise, a collar of 0.0.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
use super::*;
use crate::types::SpeakerId;
// Builds a turn without text for the given speaker id and time span.
fn turn(speaker: u32, start: f64, end: f64) -> SpeakerTurn {
SpeakerTurn {
speaker: SpeakerId(speaker),
time: TimeRange { start, end },
text: None,
}
}
// Identical reference and hypothesis must score (near) zero DER.
#[test]
fn perfect_match() {
let reference = vec![turn(0, 0.0, 3.0), turn(1, 3.5, 6.0), turn(0, 6.5, 10.0)];
let hypothesis = vec![turn(0, 0.0, 3.0), turn(1, 3.5, 6.0), turn(0, 6.5, 10.0)];
let result = compute_der(&reference, &hypothesis, 0.0);
assert!(
result.der < 0.01,
"perfect match DER should be ~0, got {}",
result.der
);
}
// Hypothesis speaker ids are arbitrary labels: the same timing under
// different ids must still be mapped and score (near) zero DER.
#[test]
fn swapped_ids_still_maps() {
let reference = vec![turn(0, 0.0, 3.0), turn(1, 3.5, 6.0)];
let hypothesis = vec![turn(5, 0.0, 3.0), turn(9, 3.5, 6.0)];
let result = compute_der(&reference, &hypothesis, 0.0);
assert!(
result.der < 0.01,
"swapped IDs should map correctly, got DER={}",
result.der
);
}
// An empty hypothesis misses all reference speech.
#[test]
fn full_miss() {
let reference = vec![turn(0, 0.0, 5.0)];
let hypothesis = vec![];
let result = compute_der(&reference, &hypothesis, 0.0);
assert!((result.miss_rate - 1.0).abs() < 0.01);
assert!((result.der - 1.0).abs() < 0.01);
}
// A second hypothesis speaker spanning the whole reference turn is counted
// as false alarm.
#[test]
fn full_false_alarm() {
let reference = vec![turn(0, 0.0, 5.0)];
let hypothesis = vec![turn(0, 0.0, 5.0), turn(1, 0.0, 5.0)];
let result = compute_der(&reference, &hypothesis, 0.0);
assert!(result.false_alarm_rate > 0.5);
}
// One hypothesis speaker covering two reference speakers can match only one
// of them; the other half is confusion.
#[test]
fn speaker_confusion() {
let reference = vec![turn(0, 0.0, 3.0), turn(1, 3.0, 6.0)];
let hypothesis = vec![turn(0, 0.0, 6.0)];
let result = compute_der(&reference, &hypothesis, 0.0);
assert!(
result.confusion_rate > 0.3,
"should have confusion, got {}",
result
);
}
// A boundary shifted by 0.2 falls inside a 0.25 collar, so the collar must
// lower the DER compared with no collar.
#[test]
fn collar_reduces_error() {
let reference = vec![turn(0, 0.0, 5.0), turn(1, 5.0, 10.0)];
let hypothesis = vec![turn(0, 0.0, 5.2), turn(1, 5.2, 10.0)];
let no_collar = compute_der(&reference, &hypothesis, 0.0);
let with_collar = compute_der(&reference, &hypothesis, 0.25);
assert!(with_collar.der < no_collar.der, "collar should reduce DER");
}
// An empty reference yields the all-zero result rather than a division by zero.
#[test]
fn empty_reference() {
let result = compute_der(&[], &[turn(0, 0.0, 5.0)], 0.0);
assert_eq!(result.der, 0.0);
}
// NaN and negative-infinite collars are rejected with the all-zero result.
#[test]
fn non_finite_collar_returns_zero() {
let reference = vec![turn(0, 0.0, 5.0)];
let hypothesis = vec![turn(0, 0.0, 5.0)];
let result = compute_der(&reference, &hypothesis, f64::NAN);
assert_eq!(result.der, 0.0);
let result = compute_der(&reference, &hypothesis, f64::NEG_INFINITY);
assert_eq!(result.der, 0.0);
}
// A very large end time must be handled via the frame cap; with identical
// reference and hypothesis the DER stays zero.
#[test]
fn huge_max_time_is_capped() {
let reference = vec![turn(0, 0.0, 1e12)];
let hypothesis = vec![turn(0, 0.0, 1e12)];
let result = compute_der(&reference, &hypothesis, 0.0);
assert_eq!(result.der, 0.0);
}
// The reported rates must be derivable from the raw frame counts, and
// total_speech must equal the frame count times the 0.01 frame width.
#[test]
fn der_result_frame_counts_are_consistent() {
let reference = vec![turn(0, 0.0, 3.0), turn(1, 3.0, 6.0)];
let hypothesis = vec![turn(0, 0.0, 3.0)];
let r = compute_der(&reference, &hypothesis, 0.0);
assert!(
r.total_ref_frames > 0,
"expected non-empty reference frames"
);
let expected = (r.missed_frames + r.false_alarm_frames + r.confusion_frames) as f64
/ r.total_ref_frames as f64;
assert!(
(r.der - expected).abs() < 1e-9,
"der {} != error-frames/ref-frames {expected}",
r.der
);
assert!(
(r.total_ref_frames as f64 * 0.01 - r.total_speech).abs() < 1e-9,
"frame count * 0.01 ({}) != total_speech ({})",
r.total_ref_frames as f64 * 0.01,
r.total_speech
);
}
// The reference speakers overlap on 2.0..4.0; only the single-speaker spans
// 0.0..2.0 and 4.0..6.0 (about 400 frames) may be scored, all of them missed.
#[test]
fn single_speaker_der_excludes_overlap_frames() {
let reference = vec![turn(0, 0.0, 4.0), turn(1, 2.0, 6.0)];
let hypothesis: Vec<SpeakerTurn> = vec![];
let full = compute_der(&reference, &hypothesis, 0.0);
let single = compute_der_single_speaker_regions(&reference, &hypothesis, 0.0);
assert!(
single.total_ref_frames < full.total_ref_frames,
"overlap frames must be excluded: single={} full={}",
single.total_ref_frames,
full.total_ref_frames
);
assert!(
(380..=420).contains(&single.total_ref_frames),
"expected ~400 single-speaker frames, got {}",
single.total_ref_frames
);
assert!(
(single.miss_rate - 1.0).abs() < 1e-9,
"miss={}",
single.miss_rate
);
}
// The hypothesis only misses the second speaker inside the overlap, so the
// single-speaker DER must stay (near) zero.
#[test]
fn single_speaker_der_ignores_overlap_mismatch() {
let reference = vec![turn(0, 0.0, 6.0), turn(1, 4.0, 6.0)];
let hypothesis = vec![turn(0, 0.0, 6.0)];
let single = compute_der_single_speaker_regions(&reference, &hypothesis, 0.0);
assert!(
single.der < 0.01,
"single-speaker DER must ignore the overlap-region mismatch, got {single}"
);
}
// Speaker 1 (3.0..6.0, fully overlapped) is never hypothesised: about a
// third of all reference speaker-frames and half of the overlap
// speaker-frames are missed, while speaker 0 is fully recalled.
#[test]
fn decomposition_splits_overlap_and_recall() {
let reference = vec![turn(0, 0.0, 6.0), turn(1, 3.0, 6.0)];
let hypothesis = vec![turn(0, 0.0, 6.0)];
let d = compute_der_decomposition(&reference, &hypothesis, 0.0);
assert!((d.total.der - 1.0 / 3.0).abs() < 0.02, "total {}", d.total);
assert!(d.single_speaker.der < 0.02, "single {}", d.single_speaker);
assert!((d.overlap.der - 0.5).abs() < 0.02, "overlap {}", d.overlap);
let r0 = d
.per_speaker_recall
.iter()
.find(|s| s.speaker == 0)
.expect("spk0 recall");
let r1 = d
.per_speaker_recall
.iter()
.find(|s| s.speaker == 1)
.expect("spk1 recall");
assert!((r0.recall - 1.0).abs() < 0.02, "spk0 recall {}", r0.recall);
assert!(r1.recall < 0.02, "spk1 recall {}", r1.recall);
}
// Co-occurrence counts are (hyp 0, ref 0) = 10, (hyp 0, ref 1) = 9 and
// (hyp 1, ref 0) = 8. Picking the largest pair first gives 10 in total,
// whereas the crossed assignment gives 9 + 8 = 17 and must be chosen.
#[test]
fn optimal_mapping_beats_greedy_on_counterexample() {
let mut ref_frames: Vec<Vec<u32>> = Vec::new();
let mut hyp_frames: Vec<Vec<u32>> = Vec::new();
for _ in 0..10 {
ref_frames.push(vec![0]);
hyp_frames.push(vec![0]);
}
for _ in 0..9 {
ref_frames.push(vec![1]);
hyp_frames.push(vec![0]);
}
for _ in 0..8 {
ref_frames.push(vec![0]);
hyp_frames.push(vec![1]);
}
let collar_mask = vec![false; ref_frames.len()];
let mapping = optimal_speaker_mapping(&ref_frames, &hyp_frames, &collar_mask);
assert_eq!(
mapping.get(&0),
Some(&1),
"hyp 0 must map to ref 1 (optimal), not ref 0 (greedy)"
);
assert_eq!(mapping.get(&1), Some(&0), "hyp 1 must map to ref 0");
}
}