use crate::alignment::CaptionBlock;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
/// Thread-safe interner that hands out shared `Arc<str>` handles for speaker labels.
///
/// Interning an equal label more than once returns clones of the same allocation, so
/// repeated labels share storage. If the internal lock is poisoned by a panicking holder,
/// the guard is recovered and the table keeps being used.
#[derive(Default, Debug)]
pub struct SpeakerLabelPool {
    inner: Mutex<HashMap<String, Arc<str>>>,
}

impl SpeakerLabelPool {
    /// Creates an empty pool.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the shared handle for `label`, allocating and storing it on first use.
    pub fn intern(&self, label: &str) -> Arc<str> {
        let mut table = self.table();
        match table.get(label).cloned() {
            Some(shared) => shared,
            None => {
                let shared: Arc<str> = label.into();
                table.insert(label.to_owned(), Arc::clone(&shared));
                shared
            }
        }
    }

    /// Number of distinct labels interned so far.
    pub fn len(&self) -> usize {
        self.table().len()
    }

    /// `true` when no label has been interned yet.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Locks the label table, recovering the guard if a previous holder panicked.
    fn table(&self) -> std::sync::MutexGuard<'_, HashMap<String, Arc<str>>> {
        self.inner
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }
}
/// Gender attribute attached to a [`Speaker`].
///
/// The enum is fieldless, so it is `Copy` and `Hash` and can be passed by value or used as
/// a map/set key. `Speaker::gender` is itself an `Option`; how `None` should differ from
/// `Unknown` is not defined in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SpeakerGender {
    Male,
    Female,
    Other,
    Unknown,
}
/// Metadata describing one diarized speaker.
#[derive(Clone, Debug, PartialEq)]
pub struct Speaker {
    /// Numeric speaker id; `format_speaker_label` falls back to `"Speaker <id>"`.
    pub id: u8,
    /// Optional human-readable name, preferred over the id when labelling.
    pub name: Option<String>,
    /// Optional gender attribute.
    pub gender: Option<SpeakerGender>,
    /// Optional language tag; its format is not validated here.
    pub language: Option<String>,
}
/// A time interval, in milliseconds, attributed to a single speaker.
///
/// `overlaps_with` treats the interval as half-open `[start_ms, end_ms)`: turns that only
/// touch (`a.end_ms == b.start_ms`) do not overlap.
#[derive(Clone, Debug, PartialEq)]
pub struct SpeakerTurn {
    /// Id of the speaker this turn belongs to.
    pub speaker_id: u8,
    /// Start of the turn in milliseconds.
    pub start_ms: u64,
    /// End of the turn in milliseconds.
    pub end_ms: u64,
}

impl SpeakerTurn {
    /// Length of the turn in milliseconds; `0` when `end_ms` does not exceed `start_ms`.
    pub fn duration_ms(&self) -> u64 {
        self.end_ms.saturating_sub(self.start_ms)
    }

    /// Returns `true` when each turn starts strictly before the other one ends.
    ///
    /// Degenerate turns (`end_ms <= start_ms`) are not special-cased.
    pub fn overlaps_with(&self, other: &SpeakerTurn) -> bool {
        let disjoint = self.end_ms <= other.start_ms || other.end_ms <= self.start_ms;
        !disjoint
    }
}
#[derive(Debug, Clone)]
pub struct DiarizationResult {
pub speakers: HashMap<u8, Speaker>,
pub turns: Vec<SpeakerTurn>,
}
impl DiarizationResult {
pub fn new() -> Self {
Self {
speakers: HashMap::new(),
turns: Vec::new(),
}
}
pub fn total_speech_ms(&self) -> u64 {
self.turns.iter().map(|t| t.duration_ms()).sum()
}
}
/// Default maximum silence, in milliseconds, bridged by [`merge_consecutive_turns`].
pub const DEFAULT_MERGE_GAP_MS: u64 = 500;

/// Merges consecutive same-speaker turns using [`DEFAULT_MERGE_GAP_MS`] as the gap limit.
///
/// Equivalent to `merge_consecutive_turns_with_gap(result, DEFAULT_MERGE_GAP_MS)`.
pub fn merge_consecutive_turns(result: &DiarizationResult) -> Vec<SpeakerTurn> {
    merge_consecutive_turns_with_gap(result, DEFAULT_MERGE_GAP_MS)
}
/// Merges runs of consecutive same-speaker turns separated by little or no silence.
///
/// Turns are first sorted by `start_ms` (a stable sort, so equal starts keep input
/// order). Each turn is folded into the previously emitted turn when both belong to the
/// same speaker and either:
///
/// * the turn starts at or before the previous turn's end (touching or overlapping), or
/// * the silence between them is strictly shorter than `max_gap_ms`.
///
/// A merged turn keeps the earlier start and the later of the two ends. A turn from a
/// different speaker breaks the run; only neighbours in start order are compared.
/// Returns an empty vector when `result` has no turns.
pub fn merge_consecutive_turns_with_gap(
    result: &DiarizationResult,
    max_gap_ms: u64,
) -> Vec<SpeakerTurn> {
    let mut sorted = result.turns.clone();
    sorted.sort_by_key(|t| t.start_ms);
    let mut merged: Vec<SpeakerTurn> = Vec::with_capacity(sorted.len());
    for turn in sorted {
        if let Some(last) = merged.last_mut() {
            // `saturating_sub` reports overlapping turns as a gap of 0, which the strict
            // `< max_gap_ms` test alone would reject when `max_gap_ms == 0`. Touching or
            // overlapping same-speaker turns are always merged.
            let touches = turn.start_ms <= last.end_ms;
            let gap = turn.start_ms.saturating_sub(last.end_ms);
            if last.speaker_id == turn.speaker_id && (touches || gap < max_gap_ms) {
                last.end_ms = last.end_ms.max(turn.end_ms);
                continue;
            }
        }
        merged.push(turn);
    }
    merged
}
/// Per-speaker aggregate produced by `speaker_stats`.
///
/// All fields are integers, so the type is `Copy` and `Eq`; `Default` is all zeros.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct SpeakerStats {
    /// Sum of the speaker's turn durations in milliseconds (overlaps counted in full).
    pub total_time_ms: u64,
    /// Number of turns attributed to the speaker.
    pub turn_count: u32,
    /// `total_time_ms / turn_count`, truncated toward zero; `0` when there are no turns.
    pub avg_turn_ms: u64,
}
/// Aggregates per-speaker totals from `result.turns`.
///
/// For each speaker id that appears in at least one turn, reports the summed turn
/// duration, the number of turns, and the truncating integer mean turn length. Speakers
/// present in `result.speakers` but without any turn are not included.
pub fn speaker_stats(result: &DiarizationResult) -> HashMap<u8, SpeakerStats> {
    let mut accum: HashMap<u8, (u64, u32)> = HashMap::new();
    for turn in &result.turns {
        let (sum_ms, count) = accum.entry(turn.speaker_id).or_default();
        *sum_ms += turn.duration_ms();
        *count += 1;
    }

    accum
        .into_iter()
        .map(|(speaker_id, (sum_ms, count))| {
            let stats = SpeakerStats {
                total_time_ms: sum_ms,
                turn_count: count,
                // Every accumulated entry has count >= 1; checked_div keeps the 0 fallback.
                avg_turn_ms: sum_ms.checked_div(u64::from(count)).unwrap_or(0),
            };
            (speaker_id, stats)
        })
        .collect()
}
/// Returns the id of the speaker with the largest total speaking time, or `None` when
/// `result` has no turns.
///
/// Ties are broken deterministically in favour of the smallest speaker id. The per-speaker
/// statistics arrive in a `HashMap`, whose iteration order is randomised per process, so
/// a plain `max_by_key` would pick an arbitrary winner among tied speakers.
pub fn dominant_speaker(result: &DiarizationResult) -> Option<u8> {
    speaker_stats(result)
        .into_iter()
        .max_by_key(|(id, stats)| (stats.total_time_ms, std::cmp::Reverse(*id)))
        .map(|(id, _)| id)
}
/// Sets each caption block's `speaker_id` to the speaker who talks longest during it.
///
/// For every block, the overlap with each diarized turn is measured and summed per
/// speaker. This lets a speaker with several turns inside the block outweigh a single
/// longer turn by someone else. The speaker with the largest total wins. On a tie, the
/// tied speaker whose first overlapping turn appears latest in `diarization.turns` wins.
/// Blocks that overlap no turn keep their existing `speaker_id`.
///
/// Accepts any mutable slice; `&mut Vec<CaptionBlock>` arguments coerce automatically.
pub fn assign_speakers_to_blocks(blocks: &mut [CaptionBlock], diarization: &DiarizationResult) {
    // (speaker_id, summed overlap) in first-encounter order. Reused across blocks; the
    // number of speakers overlapping one block is small, so a linear scan is adequate.
    let mut per_speaker: Vec<(u8, u64)> = Vec::new();
    for block in blocks.iter_mut() {
        per_speaker.clear();
        for turn in &diarization.turns {
            let overlap_start = block.start_ms.max(turn.start_ms);
            let overlap_end = block.end_ms.min(turn.end_ms);
            if overlap_end <= overlap_start {
                continue;
            }
            let overlap = overlap_end - overlap_start;
            match per_speaker
                .iter()
                .position(|&(id, _)| id == turn.speaker_id)
            {
                Some(idx) => per_speaker[idx].1 += overlap,
                None => per_speaker.push((turn.speaker_id, overlap)),
            }
        }
        // `max_by_key` returns the last maximal element, giving the tie rule documented above.
        if let Some(&(speaker_id, _)) = per_speaker.iter().max_by_key(|&&(_, total)| total) {
            block.speaker_id = Some(speaker_id);
        }
    }
}
/// Returns the display label for `speaker`: its `name` when present, otherwise
/// `"Speaker <id>"`.
pub fn format_speaker_label(speaker: &Speaker) -> String {
    speaker
        .name
        .clone()
        .unwrap_or_else(|| format!("Speaker {}", speaker.id))
}
/// Reports pairs of diarized turns whose time ranges overlap.
///
/// Every unordered pair of turns in a [`DiarizationResult`] is examined. A pair is
/// reported when the turns share a positive overlap (`max(starts) < min(ends)`). If
/// `min_overlap_fraction` is positive, the overlap must also be at least that fraction
/// of the shorter turn's duration. Speakers are not compared, so overlapping turns of the
/// same speaker are reported too.
///
/// [`with_overlap_tolerance`](Self::with_overlap_tolerance) clamps the threshold to
/// `[0.0, 1.0]`. Because the field is public, a value assigned directly is used as-is.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct CrosstalkDetector {
    /// Minimum overlap as a fraction of the shorter turn; `0.0` reports any overlap.
    pub min_overlap_fraction: f32,
}

impl CrosstalkDetector {
    /// Creates a detector that reports every positive overlap (threshold `0.0`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a detector whose threshold is `min_overlap_fraction` clamped to `[0.0, 1.0]`.
    pub fn with_overlap_tolerance(min_overlap_fraction: f32) -> Self {
        Self {
            min_overlap_fraction: min_overlap_fraction.clamp(0.0, 1.0),
        }
    }

    /// Convenience wrapper for `CrosstalkDetector::new().detect(result)`.
    pub fn find_overlapping_turns(result: &DiarizationResult) -> Vec<(SpeakerTurn, SpeakerTurn)> {
        Self::new().detect(result)
    }

    /// Returns every overlapping pair of turns that passes the threshold.
    ///
    /// Pairs are emitted by ascending index of the first turn, then of the second, in
    /// `result.turns`. Within a pair, the turn that starts first comes first; on equal
    /// starts, the turn that appears earlier in `result.turns` comes first. All pairs are
    /// compared, so the cost is quadratic in the number of turns.
    pub fn detect(&self, result: &DiarizationResult) -> Vec<(SpeakerTurn, SpeakerTurn)> {
        let turns = &result.turns;
        let mut overlapping: Vec<(SpeakerTurn, SpeakerTurn)> = Vec::new();
        for (i, first) in turns.iter().enumerate() {
            for second in &turns[i + 1..] {
                let overlap_start = first.start_ms.max(second.start_ms);
                let overlap_end = first.end_ms.min(second.end_ms);
                // A positive intersection implies `overlaps_with`. It also rejects
                // degenerate turns (end <= start), which `overlaps_with` alone accepts.
                if overlap_end <= overlap_start {
                    continue;
                }
                if !self.passes_tolerance(first, second, overlap_end - overlap_start) {
                    continue;
                }
                let pair = if first.start_ms <= second.start_ms {
                    (first.clone(), second.clone())
                } else {
                    (second.clone(), first.clone())
                };
                overlapping.push(pair);
            }
        }
        overlapping
    }

    /// Applies the fractional threshold to two turns that overlap by `overlap_ms`.
    fn passes_tolerance(&self, a: &SpeakerTurn, b: &SpeakerTurn, overlap_ms: u64) -> bool {
        if self.min_overlap_fraction > 0.0 {
            let shorter_ms = a.duration_ms().min(b.duration_ms());
            if shorter_ms == 0 {
                return false;
            }
            let fraction = overlap_ms as f32 / shorter_ms as f32;
            fraction >= self.min_overlap_fraction
        } else {
            true
        }
    }
}
/// Fraction of `total_duration_ms` covered by at least one turn, capped at `1.0`.
///
/// Overlapping turns are unioned, so simultaneous speech counts once. Turns whose end is
/// not after their start contribute nothing. The cap applies, for example, when turns
/// extend past `total_duration_ms`. Returns `0.0` when `total_duration_ms` is 0.
pub fn voice_activity_ratio(result: &DiarizationResult, total_duration_ms: u64) -> f32 {
    if total_duration_ms == 0 {
        return 0.0;
    }

    let mut spans: Vec<(u64, u64)> = result
        .turns
        .iter()
        .map(|turn| (turn.start_ms, turn.end_ms))
        .collect();
    // Ordering ties on `start` by `end` as well is harmless: the sweep below yields the
    // same union for any order of intervals that share a start.
    spans.sort_unstable();

    // Sweep left to right, counting only time beyond the furthest end seen so far.
    let (covered_ms, _frontier) =
        spans
            .into_iter()
            .fold((0u64, 0u64), |(covered, frontier), (start, end)| {
                let from = start.max(frontier);
                if end > from {
                    (covered + (end - from), end)
                } else {
                    (covered, frontier)
                }
            });

    let ratio = covered_ms as f32 / total_duration_ms as f32;
    ratio.min(1.0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::alignment::CaptionPosition;

    fn make_speaker(id: u8, name: Option<&str>) -> Speaker {
        Speaker {
            id,
            name: name.map(|s| s.to_string()),
            gender: None,
            language: None,
        }
    }

    fn make_turn(speaker_id: u8, start_ms: u64, end_ms: u64) -> SpeakerTurn {
        SpeakerTurn {
            speaker_id,
            start_ms,
            end_ms,
        }
    }

    fn make_block(id: u32, start_ms: u64, end_ms: u64) -> CaptionBlock {
        CaptionBlock {
            id,
            start_ms,
            end_ms,
            lines: vec!["text".to_string()],
            speaker_id: None,
            position: CaptionPosition::Bottom,
        }
    }

    /// Fixture: speaker 1 ("Alice") talks 0–3000 ms and 6500–9000 ms (5500 ms over two
    /// turns); unnamed speaker 2 talks 3000–6000 ms (3000 ms). Total speech is 8500 ms.
    fn simple_result() -> DiarizationResult {
        let mut r = DiarizationResult::new();
        r.speakers.insert(1, make_speaker(1, Some("Alice")));
        r.speakers.insert(2, make_speaker(2, None));
        r.turns = vec![
            make_turn(1, 0, 3000),
            make_turn(2, 3000, 6000),
            make_turn(1, 6500, 9000),
        ];
        r
    }

    #[test]
    fn speaker_turn_duration() {
        let t = make_turn(1, 1000, 4000);
        assert_eq!(t.duration_ms(), 3000);
    }

    #[test]
    fn speaker_turn_overlap_true() {
        let a = make_turn(1, 0, 2000);
        let b = make_turn(2, 1000, 3000);
        assert!(a.overlaps_with(&b));
    }

    #[test]
    fn speaker_turn_overlap_false_adjacent() {
        let a = make_turn(1, 0, 1000);
        let b = make_turn(2, 1000, 2000);
        assert!(!a.overlaps_with(&b));
    }

    #[test]
    fn speaker_turn_overlap_false_separate() {
        let a = make_turn(1, 0, 1000);
        let b = make_turn(2, 2000, 3000);
        assert!(!a.overlaps_with(&b));
    }

    #[test]
    fn merge_consecutive_empty() {
        let r = DiarizationResult::new();
        assert!(merge_consecutive_turns(&r).is_empty());
    }

    #[test]
    fn merge_consecutive_same_speaker_small_gap() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(1, 1200, 2000)];
        let result = merge_consecutive_turns(&r);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].start_ms, 0);
        assert_eq!(result[0].end_ms, 2000);
    }

    #[test]
    fn merge_consecutive_different_speakers_not_merged() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(2, 1200, 2000)];
        let result = merge_consecutive_turns(&r);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn merge_consecutive_large_gap_not_merged() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(1, 2000, 3000)];
        let result = merge_consecutive_turns(&r);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn merge_consecutive_sorts_before_merge() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 1200, 2000), make_turn(1, 0, 1000)];
        let result = merge_consecutive_turns(&r);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].start_ms, 0);
    }

    #[test]
    fn speaker_stats_basic() {
        let r = simple_result();
        let stats = speaker_stats(&r);
        let s1 = stats.get(&1).expect("get should succeed");
        assert_eq!(s1.turn_count, 2);
        assert_eq!(s1.total_time_ms, 3000 + 2500);
        let s2 = stats.get(&2).expect("get should succeed");
        assert_eq!(s2.turn_count, 1);
        assert_eq!(s2.total_time_ms, 3000);
    }

    #[test]
    fn speaker_stats_avg_turn() {
        let r = simple_result();
        let stats = speaker_stats(&r);
        let s1 = stats.get(&1).expect("get should succeed");
        assert_eq!(s1.avg_turn_ms, 2750);
    }

    #[test]
    fn dominant_speaker_basic() {
        let r = simple_result();
        assert_eq!(dominant_speaker(&r), Some(1));
    }

    #[test]
    fn dominant_speaker_empty() {
        let r = DiarizationResult::new();
        assert_eq!(dominant_speaker(&r), None);
    }

    #[test]
    fn assign_speakers_assigns_overlapping_speaker() {
        let r = simple_result();
        let mut blocks = vec![make_block(1, 500, 2000), make_block(2, 3500, 5000)];
        assign_speakers_to_blocks(&mut blocks, &r);
        assert_eq!(blocks[0].speaker_id, Some(1));
        assert_eq!(blocks[1].speaker_id, Some(2));
    }

    #[test]
    fn assign_speakers_no_overlap_unchanged() {
        let r = simple_result();
        let mut blocks = vec![make_block(1, 100_000, 101_000)];
        assign_speakers_to_blocks(&mut blocks, &r);
        assert_eq!(blocks[0].speaker_id, None);
    }

    #[test]
    fn format_speaker_label_with_name() {
        let s = make_speaker(1, Some("Dr. Smith"));
        assert_eq!(format_speaker_label(&s), "Dr. Smith");
    }

    #[test]
    fn format_speaker_label_without_name() {
        let s = make_speaker(5, None);
        assert_eq!(format_speaker_label(&s), "Speaker 5");
    }

    #[test]
    fn find_overlapping_turns_none() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(2, 1000, 2000)];
        let overlaps = CrosstalkDetector::find_overlapping_turns(&r);
        assert!(overlaps.is_empty());
    }

    #[test]
    fn find_overlapping_turns_detects_overlap() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 2000), make_turn(2, 1000, 3000)];
        let overlaps = CrosstalkDetector::find_overlapping_turns(&r);
        assert_eq!(overlaps.len(), 1);
        assert_eq!(overlaps[0].0.speaker_id, 1);
        assert_eq!(overlaps[0].1.speaker_id, 2);
    }

    #[test]
    fn find_overlapping_turns_multiple_overlaps() {
        let mut r = DiarizationResult::new();
        r.turns = vec![
            make_turn(1, 0, 3000),
            make_turn(2, 1000, 4000),
            make_turn(3, 2000, 5000),
        ];
        let overlaps = CrosstalkDetector::find_overlapping_turns(&r);
        assert_eq!(overlaps.len(), 3);
    }

    #[test]
    fn voice_activity_ratio_zero_duration() {
        let r = simple_result();
        assert_eq!(voice_activity_ratio(&r, 0), 0.0);
    }

    #[test]
    fn voice_activity_ratio_full_coverage() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 10000)];
        let ratio = voice_activity_ratio(&r, 10000);
        assert!((ratio - 1.0).abs() < 1e-5);
    }

    #[test]
    fn voice_activity_ratio_half_coverage() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 5000)];
        let ratio = voice_activity_ratio(&r, 10000);
        assert!((ratio - 0.5).abs() < 1e-5);
    }

    #[test]
    fn voice_activity_ratio_overlapping_turns_not_double_counted() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 5000), make_turn(2, 0, 5000)];
        let ratio = voice_activity_ratio(&r, 10000);
        assert!((ratio - 0.5).abs() < 1e-5);
    }

    #[test]
    fn diarization_result_total_speech_ms() {
        let r = simple_result();
        assert_eq!(r.total_speech_ms(), 8500);
    }

    #[test]
    fn speaker_gender_variants_accessible() {
        let g = SpeakerGender::Female;
        assert_eq!(g, SpeakerGender::Female);
        let g2 = SpeakerGender::Unknown;
        assert_ne!(g, g2);
    }

    #[test]
    fn merge_with_gap_zero_does_not_merge() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(1, 1200, 2000)];
        let result = merge_consecutive_turns_with_gap(&r, 0);
        assert_eq!(result.len(), 2);
    }

    #[test]
    fn merge_with_large_gap_merges_far_turns() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 1000), make_turn(1, 3000, 5000)];
        let result = merge_consecutive_turns_with_gap(&r, 3000);
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].start_ms, 0);
        assert_eq!(result[0].end_ms, 5000);
    }

    #[test]
    fn crosstalk_detector_with_tolerance_filters_small_overlap() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 2000), make_turn(2, 1900, 3000)];
        let detector = CrosstalkDetector::with_overlap_tolerance(0.10);
        let overlaps = detector.detect(&r);
        assert!(
            overlaps.is_empty(),
            "small overlap should be filtered by tolerance"
        );
    }

    #[test]
    fn crosstalk_detector_with_tolerance_keeps_large_overlap() {
        let mut r = DiarizationResult::new();
        r.turns = vec![make_turn(1, 0, 2000), make_turn(2, 500, 3000)];
        let detector = CrosstalkDetector::with_overlap_tolerance(0.10);
        let overlaps = detector.detect(&r);
        assert_eq!(overlaps.len(), 1);
    }

    #[test]
    fn assign_speakers_with_five_simultaneous_speakers() {
        let mut r = DiarizationResult::new();
        r.turns = vec![
            make_turn(1, 0, 100),
            make_turn(2, 0, 200),
            make_turn(3, 0, 1000),
            make_turn(4, 0, 150),
            make_turn(5, 0, 50),
        ];
        let mut blocks = vec![make_block(1, 0, 1500)];
        assign_speakers_to_blocks(&mut blocks, &r);
        assert_eq!(blocks[0].speaker_id, Some(3));
    }
}