use crate::types::{SpeakerId, SpeakerTurn, TimeRange};
/// Tolerance, in seconds, within which two segment start (or end) times are treated as equal
/// when pairing overlap segments.
#[cfg(feature = "segmentation")]
const TIME_RANGE_EPS_SECS: f64 = 1e-6;
/// Pluggable step that derives a list of speaker turns from a set of [`ResegmentInputs`].
///
/// Implementors must be `Send + Sync`.
pub trait Resegmenter: Send + Sync {
/// Builds the output turns for `inputs`.
///
/// # Errors
///
/// Returns a [`ResegmentError`] if the implementation rejects `inputs`.
fn resegment(&self, inputs: ResegmentInputs<'_>) -> Result<Vec<SpeakerTurn>, ResegmentError>;
}
/// Borrowed data passed to [`Resegmenter::resegment`].
#[derive(Debug, Clone)]
pub struct ResegmentInputs<'a> {
/// Existing turns; [`OverlapResegmenter`] copies these into its output.
pub primary_turns: &'a [SpeakerTurn],
/// Reference embedding for each speaker that an overlap region can be matched against.
pub speaker_centroids: &'a [SpeakerCentroid],
/// Regions to examine for an additional speaker.
pub overlap_regions: &'a [OverlapRegionInput],
}
/// A speaker paired with a representative embedding vector.
///
/// [`compute_centroids`] builds these from the L2-normalized mean of the embeddings that share
/// a label.
#[derive(Debug, Clone, PartialEq)]
pub struct SpeakerCentroid {
/// Speaker this centroid stands for.
pub speaker: SpeakerId,
/// Embedding vector representing the speaker.
pub embedding: Vec<f32>,
}
/// One overlap region to be examined by a [`Resegmenter`].
#[derive(Debug, Clone, PartialEq)]
pub struct OverlapRegionInput {
/// Time span covered by the region.
pub time: TimeRange,
/// The region's primary speaker; [`OverlapResegmenter`] never selects this speaker as the
/// additional one and requires a centroid for it to be present.
pub primary_speaker: SpeakerId,
/// Embedding compared against the speaker centroids; [`OverlapResegmenter`] requires it to
/// have the same length as the centroid embeddings.
pub embedding: Vec<f32>,
}
/// Reasons a [`Resegmenter`] can reject its inputs.
#[derive(Debug, thiserror::Error)]
pub enum ResegmentError {
/// The centroid at `index` has an embedding of length `actual` where `expected` was required.
#[error("centroid dim mismatch at index {index}: expected {expected}, got {actual}")]
CentroidDimMismatch {
index: usize,
expected: usize,
actual: usize,
},
/// The overlap region at `index` has an embedding of length `actual` where `expected` was
/// required.
#[error("overlap embedding dim mismatch at index {index}: expected {expected}, got {actual}")]
OverlapDimMismatch {
index: usize,
expected: usize,
actual: usize,
},
/// The overlap region at `index` names a primary speaker that has no centroid.
#[error("primary speaker {primary} for overlap region {index} not present in centroids")]
MissingPrimaryCentroid { index: usize, primary: SpeakerId },
}
/// Averages the embeddings that share a label into one L2-normalized centroid per label.
///
/// `embeddings[i]` belongs to the cluster `labels[i]`. An empty `Vec` is returned when the two
/// slices differ in length or when there are no embeddings. A label for which
/// `crate::utils::mean_vector` yields `None` contributes no centroid. Each label becomes a
/// [`SpeakerId`] through an `as u32` cast, and the result is ordered by that id.
pub fn compute_centroids(embeddings: &[Vec<f32>], labels: &[usize]) -> Vec<SpeakerCentroid> {
    if embeddings.is_empty() || embeddings.len() != labels.len() {
        return Vec::new();
    }

    // Group borrowed embeddings by label; nothing is cloned until a group is averaged.
    let mut by_label = std::collections::BTreeMap::<usize, Vec<&Vec<f32>>>::new();
    for (embedding, label) in embeddings.iter().zip(labels) {
        by_label.entry(*label).or_default().push(embedding);
    }

    let mut centroids: Vec<SpeakerCentroid> = by_label
        .into_iter()
        .filter_map(|(label, group)| {
            let vectors: Vec<Vec<f32>> = group.into_iter().cloned().collect();
            let mut centroid = crate::utils::mean_vector(&vectors)?;
            crate::utils::l2_normalize(&mut centroid);
            Some(SpeakerCentroid {
                speaker: SpeakerId(label as u32),
                embedding: centroid,
            })
        })
        .collect();

    // The cast above can reorder ids relative to the map's key order, so sort explicitly.
    centroids.sort_by_key(|entry| entry.speaker.0);
    centroids
}
/// Collects pairs of overlap-flagged segments that cover the same time span.
///
/// Two segments form a pair when both have `is_overlap` set, their `local_speaker_idx` values
/// differ, and neither their start times nor their end times are more than
/// `TIME_RANGE_EPS_SECS` apart. Each entry holds the time range of whichever segment of the
/// pair comes first in `segments`, then the smaller and the larger local speaker index.
/// Entries are emitted in the order the pairs are met while scanning `segments` front to back.
#[cfg(feature = "segmentation")]
pub fn extract_overlap_time_ranges(
    segments: &[crate::segmentation::RawSegment],
) -> Vec<(TimeRange, u8, u8)> {
    // Kept as a `>` test (not `<=` negated) so NaN endpoints behave exactly as before.
    let apart = |x: f64, y: f64| (x - y).abs() > TIME_RANGE_EPS_SECS;

    // Only overlap-flagged segments can take part in a pair; relative order is preserved.
    let flagged: Vec<_> = segments.iter().filter(|seg| seg.is_overlap).collect();

    let mut found = Vec::new();
    for (pos, first) in flagged.iter().enumerate() {
        for second in &flagged[pos + 1..] {
            let (idx_a, idx_b) = (first.local_speaker_idx, second.local_speaker_idx);
            if idx_a == idx_b {
                continue;
            }
            if apart(first.time.start, second.time.start)
                || apart(first.time.end, second.time.end)
            {
                continue;
            }
            found.push((first.time, idx_a.min(idx_b), idx_a.max(idx_b)));
        }
    }
    found
}
/// Resegmenter configuration: a similarity threshold and a minimum overlap duration.
#[derive(Debug, Clone, Copy)]
pub struct OverlapResegmenter {
    /// Similarity threshold as supplied to [`OverlapResegmenter::new`].
    threshold: f32,
    /// Minimum overlap duration in seconds; never negative.
    min_overlap_secs: f32,
}

impl OverlapResegmenter {
    /// Creates a resegmenter with the given `threshold` and `min_overlap_secs`.
    ///
    /// `min_overlap_secs` is combined with `0.0` through `f32::max`, so negative values are
    /// stored as `0.0`.
    pub fn new(threshold: f32, min_overlap_secs: f32) -> Self {
        let min_overlap_secs = f32::max(min_overlap_secs, 0.0);
        Self {
            threshold,
            min_overlap_secs,
        }
    }

    /// Returns the configured similarity threshold.
    pub fn threshold(&self) -> f32 {
        self.threshold
    }

    /// Returns the configured minimum overlap duration in seconds.
    pub fn min_overlap_secs(&self) -> f32 {
        self.min_overlap_secs
    }
}

impl Default for OverlapResegmenter {
    /// Uses a threshold of `0.0` and a minimum overlap of `0.1` seconds.
    fn default() -> Self {
        let (threshold, min_overlap_secs) = (0.0, 0.1);
        Self::new(threshold, min_overlap_secs)
    }
}
impl Resegmenter for OverlapResegmenter {
fn resegment(&self, inputs: ResegmentInputs<'_>) -> Result<Vec<SpeakerTurn>, ResegmentError> {
let mut out: Vec<SpeakerTurn> = inputs.primary_turns.to_vec();
if inputs.speaker_centroids.len() < 2 || inputs.overlap_regions.is_empty() {
out.sort_by(|a, b| a.time.start.total_cmp(&b.time.start));
return Ok(out);
}
let expected_dim = inputs.speaker_centroids[0].embedding.len();
for (i, c) in inputs.speaker_centroids.iter().enumerate() {
if c.embedding.len() != expected_dim {
return Err(ResegmentError::CentroidDimMismatch {
index: i,
expected: expected_dim,
actual: c.embedding.len(),
});
}
}
for (i, region) in inputs.overlap_regions.iter().enumerate() {
if region.embedding.len() != expected_dim {
return Err(ResegmentError::OverlapDimMismatch {
index: i,
expected: expected_dim,
actual: region.embedding.len(),
});
}
if !inputs
.speaker_centroids
.iter()
.any(|c| c.speaker == region.primary_speaker)
{
return Err(ResegmentError::MissingPrimaryCentroid {
index: i,
primary: region.primary_speaker,
});
}
if region.time.duration() < f64::from(self.min_overlap_secs) {
continue;
}
let mut best: Option<(SpeakerId, f32)> = None;
for c in inputs.speaker_centroids.iter() {
if c.speaker == region.primary_speaker {
continue;
}
let s = crate::utils::cosine_similarity(®ion.embedding, &c.embedding);
let take = match best {
None => true,
Some((_, b)) => s > b,
};
if take {
best = Some((c.speaker, s));
}
}
if let Some((id, score)) = best
&& score > self.threshold
{
out.push(SpeakerTurn {
speaker: id,
time: region.time,
text: None,
});
}
}
out.sort_by(|a, b| a.time.start.total_cmp(&b.time.start));
Ok(out)
}
}
#[cfg(test)]
mod trait_tests {
    use super::*;

    /// Test double that ignores its inputs and always yields the same turns.
    struct ConstantResegmenter {
        out: Vec<SpeakerTurn>,
    }

    impl Resegmenter for ConstantResegmenter {
        fn resegment(
            &self,
            _inputs: ResegmentInputs<'_>,
        ) -> Result<Vec<SpeakerTurn>, ResegmentError> {
            Ok(self.out.to_vec())
        }
    }

    /// Builds a text-less turn for speaker `spk` spanning `start..end`.
    fn turn(start: f64, end: f64, spk: u32) -> SpeakerTurn {
        let time = TimeRange { start, end };
        SpeakerTurn {
            speaker: SpeakerId(spk),
            time,
            text: None,
        }
    }

    // Compiles only if `Resegmenter` can be used as a trait object.
    #[test]
    fn resegmenter_trait_object_is_dyn_compatible() {
        let fixed = ConstantResegmenter {
            out: vec![turn(0.0, 1.0, 0)],
        };
        let _boxed: Box<dyn Resegmenter> = Box::new(fixed);
    }

    #[test]
    fn resegmenter_returns_owned_turns() {
        let fixed = ConstantResegmenter {
            out: vec![turn(0.0, 1.0, 0), turn(1.0, 2.0, 1)],
        };
        let turns = fixed
            .resegment(ResegmentInputs {
                primary_turns: &[],
                speaker_centroids: &[],
                overlap_regions: &[],
            })
            .unwrap();
        assert_eq!(turns.len(), 2);
        assert_eq!(turns[0].speaker, SpeakerId(0));
    }

    #[test]
    fn error_centroid_dim_mismatch_displays() {
        let rendered = ResegmentError::CentroidDimMismatch {
            index: 1,
            expected: 192,
            actual: 256,
        }
        .to_string();
        assert!(rendered.contains("192"));
        assert!(rendered.contains("256"));
        assert!(rendered.contains("index 1"));
    }

    #[test]
    fn error_overlap_dim_mismatch_displays() {
        let rendered = ResegmentError::OverlapDimMismatch {
            index: 0,
            expected: 192,
            actual: 64,
        }
        .to_string();
        assert!(rendered.contains("192"));
        assert!(rendered.contains("64"));
    }

    #[test]
    fn error_missing_primary_centroid_displays() {
        let rendered = ResegmentError::MissingPrimaryCentroid {
            index: 2,
            primary: SpeakerId(7),
        }
        .to_string();
        assert!(rendered.contains('2'));
        assert!(rendered.contains('7'));
    }
}
#[cfg(test)]
mod centroid_tests {
    use super::*;

    /// Returns a `dim`-length vector that is `1.0` at `axis` and `0.0` elsewhere.
    fn unit(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[axis] = 1.0;
        v
    }

    #[test]
    fn compute_centroids_l2_normalized() {
        let embeddings = vec![unit(3, 0), unit(3, 0), unit(3, 1), unit(3, 1)];
        let labels = vec![0, 0, 1, 1];
        let centroids = compute_centroids(&embeddings, &labels);
        assert_eq!(centroids.len(), 2);
        for c in &centroids {
            let n: f32 = c.embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!(
                (n - 1.0).abs() < 1e-3,
                "centroid not L2-normalized: norm={n}"
            );
        }
    }

    // Label 1 has no members, so only labels 0 and 2 produce centroids.
    #[test]
    fn compute_centroids_drops_empty_clusters() {
        let embeddings = vec![unit(3, 0), unit(3, 1), unit(3, 1)];
        let labels = vec![0, 2, 2];
        let centroids = compute_centroids(&embeddings, &labels);
        assert_eq!(centroids.len(), 2);
        let speakers: Vec<u32> = centroids.iter().map(|c| c.speaker.0).collect();
        assert_eq!(speakers, vec![0, 2]);
    }

    #[test]
    fn compute_centroids_sorted_by_speaker_id() {
        let embeddings = vec![unit(3, 0), unit(3, 1), unit(3, 2)];
        let labels = vec![5, 1, 3];
        let centroids = compute_centroids(&embeddings, &labels);
        let speakers: Vec<u32> = centroids.iter().map(|c| c.speaker.0).collect();
        assert_eq!(speakers, vec![1, 3, 5]);
    }

    #[test]
    fn compute_centroids_empty_input_returns_empty() {
        let centroids = compute_centroids(&[], &[]);
        assert!(centroids.is_empty());
    }

    #[test]
    fn compute_centroids_label_mismatch_returns_empty() {
        let centroids = compute_centroids(&[unit(3, 0)], &[0, 1]);
        assert!(centroids.is_empty());
    }
}
#[cfg(all(test, feature = "segmentation"))]
mod overlap_extract_tests {
    use super::*;
    use crate::segmentation::RawSegment;
    use crate::types::Confidence;

    /// Builds a raw segment for local speaker `spk` with a fixed confidence of 0.9.
    fn raw(start: f64, end: f64, spk: u8, overlap: bool) -> RawSegment {
        let time = TimeRange { start, end };
        RawSegment {
            time,
            local_speaker_idx: spk,
            is_overlap: overlap,
            confidence: Confidence::new(0.9).unwrap(),
        }
    }

    #[test]
    fn extract_returns_pairs_for_simultaneous_overlap_segments() {
        let found = extract_overlap_time_ranges(&[raw(0.0, 1.0, 0, true), raw(0.0, 1.0, 1, true)]);
        assert_eq!(found.len(), 1);
        let (time, lo, hi) = found[0];
        assert!((time.start - 0.0).abs() < 1e-6);
        assert!((time.end - 1.0).abs() < 1e-6);
        assert_eq!(lo, 0);
        assert_eq!(hi, 1);
    }

    #[test]
    fn extract_ignores_non_overlap_segments() {
        let found =
            extract_overlap_time_ranges(&[raw(0.0, 1.0, 0, false), raw(0.0, 1.0, 1, false)]);
        assert!(found.is_empty());
    }

    #[test]
    fn extract_ignores_overlap_flag_without_pair() {
        let found = extract_overlap_time_ranges(&[raw(0.0, 1.0, 0, true)]);
        assert!(found.is_empty());
    }

    #[test]
    fn extract_handles_multiple_overlap_regions() {
        let found = extract_overlap_time_ranges(&[
            raw(0.0, 1.0, 0, true),
            raw(0.0, 1.0, 1, true),
            raw(2.0, 3.0, 1, true),
            raw(2.0, 3.0, 2, true),
        ]);
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].1, 0);
        assert_eq!(found[0].2, 1);
        assert_eq!(found[1].1, 1);
        assert_eq!(found[1].2, 2);
    }

    #[test]
    fn extract_three_way_overlap_emits_all_three_pairs() {
        let found = extract_overlap_time_ranges(&[
            raw(0.0, 1.0, 0, true),
            raw(0.0, 1.0, 1, true),
            raw(0.0, 1.0, 2, true),
        ]);
        assert_eq!(found.len(), 3);
        let seen: std::collections::HashSet<(u8, u8)> =
            found.iter().map(|&(_, lo, hi)| (lo, hi)).collect();
        for wanted in [(0, 1), (0, 2), (1, 2)] {
            assert!(seen.contains(&wanted));
        }
    }
}
#[cfg(test)]
mod resegmenter_tests {
    use super::*;
    use crate::types::{SpeakerId, SpeakerTurn, TimeRange};

    /// Returns a `dim`-length vector that is `1.0` at `axis` and `0.0` elsewhere.
    fn unit(dim: usize, axis: usize) -> Vec<f32> {
        let mut v = vec![0.0_f32; dim];
        v[axis] = 1.0;
        v
    }

    /// Builds a text-less turn for speaker `spk` spanning `start..end`.
    fn turn(start: f64, end: f64, spk: u32) -> SpeakerTurn {
        SpeakerTurn {
            speaker: SpeakerId(spk),
            time: TimeRange { start, end },
            text: None,
        }
    }

    /// Builds a centroid for speaker `spk` pointing along `axis`.
    fn centroid(spk: u32, dim: usize, axis: usize) -> SpeakerCentroid {
        SpeakerCentroid {
            speaker: SpeakerId(spk),
            embedding: unit(dim, axis),
        }
    }

    /// Builds an overlap region whose embedding points along `axis`.
    fn region(start: f64, end: f64, primary: u32, dim: usize, axis: usize) -> OverlapRegionInput {
        OverlapRegionInput {
            time: TimeRange { start, end },
            primary_speaker: SpeakerId(primary),
            embedding: unit(dim, axis),
        }
    }

    /// Bundles the three borrowed slices into a `ResegmentInputs`.
    fn inputs<'a>(
        primary: &'a [SpeakerTurn],
        centroids: &'a [SpeakerCentroid],
        regions: &'a [OverlapRegionInput],
    ) -> ResegmentInputs<'a> {
        ResegmentInputs {
            primary_turns: primary,
            speaker_centroids: centroids,
            overlap_regions: regions,
        }
    }

    #[test]
    fn no_overlap_passes_primary_through() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0), turn(2.0, 3.0, 1)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1)];
        let out = r.resegment(inputs(&primary, &centroids, &[])).unwrap();
        assert_eq!(out, primary);
    }

    #[test]
    fn single_cluster_passes_through() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0)];
        let regions = vec![region(0.5, 0.9, 0, 3, 0)];
        let out = r.resegment(inputs(&primary, &centroids, &regions)).unwrap();
        assert_eq!(out, primary);
    }

    // The region's embedding matches speaker 1, so speaker 1 (not 2) is added.
    #[test]
    fn picks_secondary_excluding_primary() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1), centroid(2, 3, 2)];
        let regions = vec![region(0.0, 1.0, 0, 3, 1)];
        let out = r.resegment(inputs(&primary, &centroids, &regions)).unwrap();
        assert_eq!(out.len(), 2);
        let speakers: Vec<u32> = out.iter().map(|t| t.speaker.0).collect();
        assert!(speakers.contains(&0));
        assert!(speakers.contains(&1));
        assert!(!speakers.contains(&2));
    }

    // The only non-primary centroid is orthogonal to the region, far below the 0.99 threshold.
    #[test]
    fn threshold_blocks_low_cosine() {
        let r = OverlapResegmenter::new(0.99, 0.0);
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1)];
        let regions = vec![region(0.0, 1.0, 0, 3, 0)];
        let out = r.resegment(inputs(&primary, &centroids, &regions)).unwrap();
        assert_eq!(out, primary, "no secondary should be appended");
    }

    // A 0.05 s region is shorter than the default 0.1 s minimum.
    #[test]
    fn min_duration_blocks_short_region() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1)];
        let regions = vec![region(0.10, 0.15, 0, 3, 1)];
        let out = r.resegment(inputs(&primary, &centroids, &regions)).unwrap();
        assert_eq!(out, primary);
    }

    #[test]
    fn output_is_sorted_by_start() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(2.0, 3.0, 0), turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1)];
        let regions = vec![region(2.0, 3.0, 0, 3, 1)];
        let out = r.resegment(inputs(&primary, &centroids, &regions)).unwrap();
        for w in out.windows(2) {
            assert!(w[0].time.start <= w[1].time.start);
        }
    }

    #[test]
    fn missing_primary_centroid_errors() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(1, 3, 1), centroid(2, 3, 2)];
        let regions = vec![region(0.0, 1.0, 0, 3, 1)];
        let err = r
            .resegment(inputs(&primary, &centroids, &regions))
            .expect_err("missing primary must error");
        assert!(matches!(
            err,
            ResegmentError::MissingPrimaryCentroid {
                primary: SpeakerId(0),
                ..
            }
        ));
    }

    #[test]
    fn centroid_dim_mismatch_errors() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![
            centroid(0, 3, 0),
            SpeakerCentroid {
                speaker: SpeakerId(1),
                embedding: vec![1.0, 0.0],
            },
        ];
        let regions = vec![region(0.0, 1.0, 0, 3, 1)];
        let err = r
            .resegment(inputs(&primary, &centroids, &regions))
            .expect_err("dim mismatch must error");
        assert!(matches!(err, ResegmentError::CentroidDimMismatch { .. }));
    }

    #[test]
    fn overlap_dim_mismatch_errors() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let centroids = vec![centroid(0, 3, 0), centroid(1, 3, 1)];
        let regions = vec![OverlapRegionInput {
            time: TimeRange {
                start: 0.0,
                end: 1.0,
            },
            primary_speaker: SpeakerId(0),
            embedding: vec![1.0, 0.0],
        }];
        let err = r
            .resegment(inputs(&primary, &centroids, &regions))
            .expect_err("dim mismatch must error");
        assert!(matches!(err, ResegmentError::OverlapDimMismatch { .. }));
    }

    #[test]
    fn empty_centroids_passes_through() {
        let r = OverlapResegmenter::default();
        let primary = vec![turn(0.0, 1.0, 0)];
        let out = r.resegment(inputs(&primary, &[], &[])).unwrap();
        assert_eq!(out, primary);
    }
}