use crate::types::{ClusterConfig, SpeakerId, SpeakerIdRemap};
use crate::utils::{cosine_similarity, l2_normalize};
/// Running state for one speaker cluster.
#[derive(Clone, Debug)]
struct Centroid {
    /// Current centroid embedding. Re-normalized with `l2_normalize` after every
    /// update or merge; a freshly created centroid holds the raw first embedding.
    vector: Vec<f32>,
    /// Number of embeddings folded into this centroid (always >= 1 once created).
    count: usize,
    /// Sum of the per-assignment confidences; divided by `count` to report an average.
    confidence_sum: f32,
}
/// Online speaker clusterer: maps incoming embeddings to stable speaker ids by
/// cosine similarity against a growing set of centroids.
pub struct SpeakerCluster {
    /// One entry per known speaker; a speaker's id is its index in this vector.
    centroids: Vec<Centroid>,
    /// Similarity threshold and speaker cap that drive `assign`.
    config: ClusterConfig,
}
impl SpeakerCluster {
    /// Creates a clusterer with no known speakers.
    pub fn new(config: ClusterConfig) -> Self {
        Self {
            centroids: Vec::new(),
            config,
        }
    }

    /// Assigns `embedding` to a speaker and returns `(speaker_id, confidence)`.
    ///
    /// Decision order:
    /// 1. If at least one speaker exists and the speaker count has reached
    ///    `config.max_speakers`, the embedding is folded into the most similar
    ///    centroid regardless of the threshold.
    /// 2. Otherwise, if the best cosine similarity is strictly greater than
    ///    `config.threshold`, the embedding is folded into that centroid.
    /// 3. Otherwise a new speaker is created, with confidence `1.0`.
    ///
    /// For cases 1 and 2 the returned confidence is the cosine similarity to the
    /// matched centroid *before* it was updated.
    ///
    /// A `max_speakers` of `0` behaves like `1`: the first embedding still creates
    /// a speaker instead of indexing an empty centroid list.
    pub fn assign(&mut self, embedding: &[f32]) -> (SpeakerId, f32) {
        let mut best_id: Option<usize> = None;
        let mut best_sim = f32::NEG_INFINITY;
        for (i, centroid) in self.centroids.iter().enumerate() {
            let sim = cosine_similarity(embedding, &centroid.vector);
            // Strict `>`: ties keep the lowest index, and NaN never wins.
            if sim > best_sim {
                best_sim = sim;
                best_id = Some(i);
            }
        }

        // At capacity: never create a new speaker. The non-empty guard avoids
        // indexing an empty Vec when `max_speakers == 0`.
        // NOTE(review): if no similarity beat -inf (e.g. all NaN), this falls back
        // to centroid 0 and accumulates -inf confidence, as before.
        if !self.centroids.is_empty() && self.centroids.len() >= self.config.max_speakers {
            let id = best_id.unwrap_or(0);
            self.update_centroid(id, embedding, best_sim);
            return (SpeakerId(id as u32), best_sim);
        }

        if let Some(id) = best_id {
            if best_sim > self.config.threshold {
                self.update_centroid(id, embedding, best_sim);
                return (SpeakerId(id as u32), best_sim);
            }
        }

        let new_id = self.centroids.len();
        self.centroids.push(Centroid {
            vector: embedding.to_vec(),
            count: 1,
            confidence_sum: 1.0,
        });
        (SpeakerId(new_id as u32), 1.0)
    }

    /// Number of distinct speakers currently tracked.
    pub fn num_speakers(&self) -> usize {
        self.centroids.len()
    }

    /// Snapshot of every centroid as `(id, vector, average_confidence)`, where the
    /// average is the accumulated confidence divided by the number of embeddings.
    pub fn centroids(&self) -> Vec<(SpeakerId, &[f32], f32)> {
        self.centroids
            .iter()
            .enumerate()
            .map(|(i, c)| {
                // `count` is at least 1 for every stored centroid, so no division by zero.
                let avg_conf = c.confidence_sum / c.count as f32;
                (SpeakerId(i as u32), c.vector.as_slice(), avg_conf)
            })
            .collect()
    }

    /// Merges speaker `from` into speaker `into`, removing `from`.
    ///
    /// The merged vector is the count-weighted mean of both centroids, re-normalized;
    /// counts and confidence sums are added. Because `from` is removed from the
    /// id-indexed list, every id above `from` shifts down by one.
    ///
    /// Returns `None` without modifying anything when either id is out of range or
    /// both ids are equal. Otherwise returns the result of
    /// `SpeakerIdRemap::from_mapping` for the ids whose value changed (`from` and
    /// every id above it); lower ids are not listed.
    pub fn merge(&mut self, from: SpeakerId, into: SpeakerId) -> Option<SpeakerIdRemap> {
        let from_idx = from.0 as usize;
        let into_idx = into.0 as usize;
        let old_len = self.centroids.len();
        if from_idx >= old_len || into_idx >= old_len || from_idx == into_idx {
            return None;
        }

        // Both indices were validated above, so after removing `from` the shifted
        // `into` index is always within the shrunk vector; no post-removal check
        // (which would leave the state half-mutated) is needed.
        let adjusted_into = if into_idx > from_idx {
            into_idx - 1
        } else {
            into_idx
        };
        let from_c = self.centroids.remove(from_idx);

        let into_c = &mut self.centroids[adjusted_into];
        let total_count = into_c.count + from_c.count;
        let into_weight = into_c.count as f32;
        let from_weight = from_c.count as f32;
        let total_weight = total_count as f32;
        // `zip` (as in `update_centroid`) avoids an index panic on mismatched dimensions.
        for (v, &f) in into_c.vector.iter_mut().zip(&from_c.vector) {
            *v = (*v * into_weight + f * from_weight) / total_weight;
        }
        l2_normalize(&mut into_c.vector);
        into_c.count = total_count;
        into_c.confidence_sum += from_c.confidence_sum;

        // `from` maps onto the merged centroid; every later id shifts down by one.
        let mapping: Vec<(SpeakerId, SpeakerId)> = (from_idx..old_len)
            .map(|old_id| {
                let new_id = if old_id == from_idx {
                    adjusted_into
                } else {
                    old_id - 1
                };
                (SpeakerId(old_id as u32), SpeakerId(new_id as u32))
            })
            .collect();
        SpeakerIdRemap::from_mapping(mapping)
    }

    /// Folds `embedding` into centroid `id` as an incremental mean, re-normalizes
    /// it, and records `sim` as this assignment's confidence.
    ///
    /// Panics if `id` is not a valid centroid index.
    fn update_centroid(&mut self, id: usize, embedding: &[f32], sim: f32) {
        let c = &mut self.centroids[id];
        let n = c.count as f32;
        for (vec, &emb) in c.vector.iter_mut().zip(embedding.iter()) {
            *vec = (*vec * n + emb) / (n + 1.0);
        }
        l2_normalize(&mut c.vector);
        c.count += 1;
        c.confidence_sum += sim;
    }
}
// Blanket lint allowance for test-only code; the tests below currently use
// `expect` rather than `unwrap`.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod tests {
    use super::*;

    const DIM: usize = 256;

    /// A `dim`-length vector filled with `val`, then L2-normalized.
    fn emb(val: f32, dim: usize) -> Vec<f32> {
        let mut out = vec![val; dim];
        l2_normalize(&mut out);
        out
    }

    /// A `DIM`-length vector that is zero except for `1.0` at index `hot`.
    fn one_hot(hot: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; DIM];
        out[hot] = 1.0;
        out
    }

    #[test]
    fn test_new_speaker() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        let first = emb(1.0, DIM);
        let (speaker, confidence) = cluster.assign(&first);
        assert_eq!(speaker.0, 0);
        assert!((confidence - 1.0).abs() < 1e-5);
        assert_eq!(cluster.num_speakers(), 1);
    }

    #[test]
    fn test_same_speaker() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        let base = emb(1.0, DIM);
        let mut nudged = base.clone();
        nudged[0] += 0.001;
        l2_normalize(&mut nudged);
        let (a, _) = cluster.assign(&base);
        let (b, _) = cluster.assign(&nudged);
        assert_eq!(a, b);
        assert_eq!(cluster.num_speakers(), 1);
    }

    #[test]
    fn test_different_speakers() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        let (a, _) = cluster.assign(&one_hot(0));
        let (b, _) = cluster.assign(&one_hot(1));
        assert_ne!(a, b);
        assert_eq!(cluster.num_speakers(), 2);
    }

    #[test]
    fn test_max_speakers_limit() {
        let mut cluster = SpeakerCluster::new(ClusterConfig {
            max_speakers: 2,
            ..Default::default()
        });
        cluster.assign(&emb(1.0, DIM));
        cluster.assign(&emb(-1.0, DIM));
        let (third, _) = cluster.assign(&emb(0.5, DIM));
        assert!(third.0 < 2);
        assert_eq!(cluster.num_speakers(), 2);
    }

    #[test]
    fn test_merge_remaps_speaker_ids() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        let (id0, _) = cluster.assign(&one_hot(0));
        let (id1, _) = cluster.assign(&one_hot(1));
        let (id2, _) = cluster.assign(&one_hot(2));
        assert_eq!(cluster.num_speakers(), 3);

        let remap = cluster.merge(id1, id0).expect("merge should succeed");
        assert_eq!(cluster.num_speakers(), 2);
        assert_eq!(remap.remap(id1), id0);
        assert_eq!(remap.remap(id2), SpeakerId(1));
        assert_eq!(remap.remap(id0), id0);
    }

    #[test]
    fn test_merge_into_higher_index() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        let (id0, _) = cluster.assign(&one_hot(0));
        let (id1, _) = cluster.assign(&one_hot(1));
        let (id2, _) = cluster.assign(&one_hot(2));

        let remap = cluster.merge(id0, id2).expect("merge should succeed");
        assert_eq!(cluster.num_speakers(), 2);
        assert_eq!(remap.remap(id0), SpeakerId(1));
        assert_eq!(remap.remap(id1), SpeakerId(0));
        assert_eq!(remap.remap(id2), SpeakerId(1));
    }

    #[test]
    fn test_merge_invalid_returns_none() {
        let mut cluster = SpeakerCluster::new(ClusterConfig::default());
        cluster.assign(&emb(1.0, DIM));
        assert!(cluster.merge(SpeakerId(0), SpeakerId(0)).is_none());
        assert!(cluster.merge(SpeakerId(5), SpeakerId(0)).is_none());
        assert!(cluster.merge(SpeakerId(0), SpeakerId(5)).is_none());
    }
}