use anyhow::Context;
use ort::session::Session;
use ort::value::TensorRef;
use std::path::Path;
use std::sync::Mutex;
/// Length of the speaker embedding returned by [`SpeakerEncoder::extract_embedding`]
/// and stored per cluster in [`SpeakerCluster`].
pub const EMBEDDING_DIM: usize = 256;
/// Number of samples fed to the speaker model per call; input to
/// [`SpeakerEncoder::extract_embedding`] is truncated or zero-padded to this length.
/// NOTE(review): presumably 1.5 s of audio at 16 kHz — confirm against the model.
pub const SEGMENT_SAMPLES: usize = 24000;
/// Speaker-embedding extractor backed by the `wespeaker_resnet34.onnx` model.
pub struct SpeakerEncoder {
    /// ONNX Runtime session; guarded by a mutex because `Session::run` needs
    /// `&mut` access while `extract_embedding` only takes `&self`.
    session: Mutex<Session>,
}
impl SpeakerEncoder {
    /// Loads the speaker model from `model_dir/wespeaker_resnet34.onnx`.
    ///
    /// # Errors
    ///
    /// Returns an error if the model file does not exist, if the ONNX Runtime
    /// session builder cannot be created, or if the model fails to load.
    pub fn load(model_dir: &Path) -> anyhow::Result<Self> {
        let path = model_dir.join("wespeaker_resnet34.onnx");
        if !path.exists() {
            anyhow::bail!(
                "wespeaker_resnet34.onnx not found in {}",
                model_dir.display()
            );
        }
        let session = Session::builder()
            .context("Failed to create ONNX session builder")?
            .commit_from_file(&path)
            .context("Failed to load speaker encoder model")?;
        Ok(Self {
            session: Mutex::new(session),
        })
    }

    /// Runs the speaker model on one segment of samples and returns its embedding.
    ///
    /// The model always receives exactly [`SEGMENT_SAMPLES`] values as a
    /// `[1, SEGMENT_SAMPLES]` tensor:
    /// - longer input is truncated to its first `SEGMENT_SAMPLES` samples;
    /// - shorter input (including empty input) is zero-padded at the end.
    ///
    /// The first [`EMBEDDING_DIM`] values of the model's first output are
    /// scaled to unit L2 norm. If that norm is ≤ 1e-8 they are returned unscaled.
    ///
    /// Concurrent calls on the same encoder are serialised by the session
    /// mutex. A poisoned mutex is recovered with a warning rather than failing.
    ///
    /// # Errors
    ///
    /// Returns an error if the input tensor cannot be built, if inference
    /// fails, if the output cannot be read as `f32`, or if the output holds
    /// fewer than `EMBEDDING_DIM` values.
    pub fn extract_embedding(&self, samples: &[f32]) -> anyhow::Result<[f32; EMBEDDING_DIM]> {
        use std::borrow::Cow;

        // Borrow the caller's samples when they already cover a full segment.
        // Only the short-input case needs a padded copy, and that copy goes on
        // the heap instead of a 96 KB stack array per call.
        let input: Cow<'_, [f32]> = if samples.len() >= SEGMENT_SAMPLES {
            Cow::Borrowed(&samples[..SEGMENT_SAMPLES])
        } else {
            let mut padded = samples.to_vec();
            padded.resize(SEGMENT_SAMPLES, 0.0);
            Cow::Owned(padded)
        };
        let input_tensor = TensorRef::from_array_view(([1_usize, SEGMENT_SAMPLES], &*input))?;

        let mut session = self.session.lock().unwrap_or_else(|e| {
            tracing::warn!("SpeakerEncoder session mutex was poisoned, recovering");
            e.into_inner()
        });
        let outputs = session
            .run(ort::inputs![input_tensor])
            .context("SpeakerEncoder inference failed")?;
        let (_shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .context("Failed to extract speaker embedding tensor")?;
        anyhow::ensure!(
            data.len() >= EMBEDDING_DIM,
            "Expected embedding dim >= {}, got {}",
            EMBEDDING_DIM,
            data.len()
        );

        let mut embedding = [0.0f32; EMBEDDING_DIM];
        embedding.copy_from_slice(&data[..EMBEDDING_DIM]);
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        // Skip normalisation for (near-)zero vectors to avoid dividing by ~0.
        if norm > 1e-8 {
            for v in &mut embedding {
                *v /= norm;
            }
        }
        Ok(embedding)
    }
}
/// Default threshold used by [`SpeakerCluster::new`]. An embedding joins an
/// existing speaker only if its cosine similarity strictly exceeds this value.
const COSINE_THRESHOLD: f32 = 0.5;
/// Upper bound on distinct speaker clusters. Once reached,
/// [`SpeakerCluster::assign`] merges every embedding into its most similar
/// existing cluster, regardless of the threshold.
const MAX_SPEAKERS: usize = 64;
/// Greedy, incremental speaker clustering over embeddings.
///
/// Each call to `assign` either merges an embedding into its most similar
/// cluster or opens a new one. Cluster IDs are indices into `centroids`.
#[derive(Debug, Clone)]
pub struct SpeakerCluster {
    /// One centroid per speaker; the index is the speaker ID.
    centroids: Vec<[f32; EMBEDDING_DIM]>,
    /// Number of embeddings merged into each centroid (parallel to `centroids`).
    counts: Vec<usize>,
    /// Cosine similarity an embedding must strictly exceed to join an existing speaker.
    threshold: f32,
}
impl SpeakerCluster {
    /// Creates an empty clusterer using the default threshold ([`COSINE_THRESHOLD`]).
    pub fn new() -> Self {
        Self::with_threshold(COSINE_THRESHOLD)
    }

    /// Creates an empty clusterer with a custom similarity threshold.
    ///
    /// An embedding joins an existing speaker only when its cosine similarity
    /// to that speaker's centroid is strictly greater than `threshold`.
    pub fn with_threshold(threshold: f32) -> Self {
        Self {
            centroids: Vec::new(),
            counts: Vec::new(),
            threshold,
        }
    }

    /// Assigns `embedding` to a speaker and returns that speaker's ID.
    ///
    /// `embedding` is compared to every centroid by cosine similarity, and the
    /// first most-similar centroid is selected. Then:
    /// - if [`MAX_SPEAKERS`] clusters already exist, the embedding is merged
    ///   into the selected cluster, ignoring the threshold;
    /// - otherwise, if the best similarity is strictly above the threshold, the
    ///   embedding is merged into that cluster;
    /// - otherwise, a new cluster is created with `embedding` as its centroid
    ///   and the next sequential ID.
    pub fn assign(&mut self, embedding: &[f32; EMBEDDING_DIM]) -> u32 {
        let mut best_id: Option<usize> = None;
        let mut best_sim = f32::NEG_INFINITY;
        for (i, centroid) in self.centroids.iter().enumerate() {
            let sim = cosine_similarity(embedding, centroid);
            if sim > best_sim {
                best_sim = sim;
                best_id = Some(i);
            }
        }

        if self.centroids.len() >= MAX_SPEAKERS {
            // At capacity: always merge. `best_id` is None here only if no
            // similarity compared greater than NEG_INFINITY (e.g. all NaN);
            // fall back to cluster 0 in that case.
            let id = best_id.unwrap_or(0);
            self.merge_into(id, embedding);
            // Lossless: id < MAX_SPEAKERS (64).
            return id as u32;
        }

        if let Some(id) = best_id
            && best_sim > self.threshold
        {
            self.merge_into(id, embedding);
            return id as u32;
        }

        self.centroids.push(*embedding);
        self.counts.push(1);
        // Lossless: the cluster count never exceeds MAX_SPEAKERS (64).
        (self.centroids.len() - 1) as u32
    }

    /// Returns the number of distinct speaker clusters created so far.
    pub fn num_speakers(&self) -> usize {
        self.centroids.len()
    }

    /// Folds `embedding` into cluster `id`.
    ///
    /// The centroid becomes the count-weighted running mean
    /// `(c * n + e) / (n + 1)` and is then rescaled to unit L2 norm (skipped
    /// when the norm is ≤ 1e-8). Finally the cluster's count is incremented.
    fn merge_into(&mut self, id: usize, embedding: &[f32; EMBEDDING_DIM]) {
        let n = self.counts[id] as f32;
        let centroid = &mut self.centroids[id];
        for (c, &e) in centroid.iter_mut().zip(embedding.iter()) {
            *c = (*c * n + e) / (n + 1.0);
        }
        let norm: f32 = centroid.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-8 {
            for x in centroid.iter_mut() {
                *x /= norm;
            }
        }
        self.counts[id] += 1;
    }
}
impl Default for SpeakerCluster {
fn default() -> Self {
Self::new()
}
}
/// Cosine similarity between two embeddings, in `[-1, 1]` for finite input.
///
/// Returns `0.0` when either vector's L2 norm is below `1e-8`, so a zero
/// vector is treated as "unrelated" instead of producing a division by ~0.
pub fn cosine_similarity(a: &[f32; EMBEDDING_DIM], b: &[f32; EMBEDDING_DIM]) -> f32 {
    let inner = |u: &[f32; EMBEDDING_DIM], v: &[f32; EMBEDDING_DIM]| -> f32 {
        u.iter().zip(v.iter()).map(|(&p, &q)| p * q).sum()
    };
    let (len_a, len_b) = (inner(a, a).sqrt(), inner(b, b).sqrt());
    // Written with `<` so NaN norms fall through to the division, as before.
    let degenerate = len_a < 1e-8 || len_b < 1e-8;
    if degenerate {
        0.0
    } else {
        inner(a, b) / (len_a * len_b)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding with every component set to `value`.
    fn make_embedding(value: f32) -> [f32; EMBEDDING_DIM] {
        [value; EMBEDDING_DIM]
    }

    /// One-hot embedding; distinct indices are mutually orthogonal.
    fn make_unit_embedding(index: usize) -> [f32; EMBEDDING_DIM] {
        let mut emb = [0.0f32; EMBEDDING_DIM];
        emb[index] = 1.0;
        emb
    }

    #[test]
    fn test_cosine_similarity_identical() {
        let a = make_embedding(1.0);
        let sim = cosine_similarity(&a, &a);
        assert!((sim - 1.0).abs() < 1e-5, "expected ~1.0, got {sim}");
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let a = make_unit_embedding(0);
        let b = make_unit_embedding(1);
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-5, "expected ~0.0, got {sim}");
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let a = make_embedding(1.0);
        let b = make_embedding(-1.0);
        let sim = cosine_similarity(&a, &b);
        assert!((sim + 1.0).abs() < 1e-5, "expected ~-1.0, got {sim}");
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        // The norm guard must return 0.0 rather than dividing by ~0.
        let zero = make_embedding(0.0);
        let one = make_embedding(1.0);
        assert_eq!(cosine_similarity(&zero, &one), 0.0);
        assert_eq!(cosine_similarity(&zero, &zero), 0.0);
    }

    #[test]
    fn test_cluster_new_speaker() {
        let mut cluster = SpeakerCluster::new();
        let emb = make_embedding(1.0);
        let id = cluster.assign(&emb);
        assert_eq!(id, 0, "first speaker should be ID 0");
        assert_eq!(cluster.num_speakers(), 1);
    }

    #[test]
    fn test_cluster_default_is_empty() {
        assert_eq!(SpeakerCluster::default().num_speakers(), 0);
    }

    #[test]
    fn test_cluster_same_speaker() {
        let mut cluster = SpeakerCluster::new();
        let emb_a = make_embedding(1.0);
        let mut emb_b = make_embedding(1.0);
        emb_b[0] = 1.001;
        let id1 = cluster.assign(&emb_a);
        let id2 = cluster.assign(&emb_b);
        assert_eq!(
            id1, id2,
            "similar embeddings should map to the same speaker"
        );
        assert_eq!(cluster.num_speakers(), 1);
    }

    #[test]
    fn test_cluster_different_speakers() {
        let mut cluster = SpeakerCluster::new();
        let emb_a = make_unit_embedding(0);
        let emb_b = make_unit_embedding(1);
        let id1 = cluster.assign(&emb_a);
        let id2 = cluster.assign(&emb_b);
        assert_ne!(
            id1, id2,
            "orthogonal embeddings should be different speakers"
        );
        assert_eq!(cluster.num_speakers(), 2);
    }

    #[test]
    fn test_cluster_three_speakers() {
        let mut cluster = SpeakerCluster::new();
        let emb_a = make_unit_embedding(0);
        let emb_b = make_unit_embedding(1);
        let emb_c = make_unit_embedding(2);
        let id_a1 = cluster.assign(&emb_a);
        let id_b = cluster.assign(&emb_b);
        let id_c = cluster.assign(&emb_c);
        let id_a2 = cluster.assign(&emb_a);
        assert_eq!(id_a1, 0);
        assert_eq!(id_b, 1);
        assert_eq!(id_c, 2);
        assert_eq!(
            id_a2, id_a1,
            "returning to speaker A should yield the same ID"
        );
        assert_eq!(cluster.num_speakers(), 3);
    }

    #[test]
    fn test_cluster_custom_threshold() {
        // cos(a, b) = 1/sqrt(2) ≈ 0.707
        let a = make_unit_embedding(0);
        let mut b = make_unit_embedding(0);
        b[1] = 1.0;

        let mut strict = SpeakerCluster::with_threshold(0.99);
        let strict_a = strict.assign(&a);
        let strict_b = strict.assign(&b);
        assert_ne!(strict_a, strict_b, "0.707 should not pass a 0.99 threshold");

        let mut loose = SpeakerCluster::with_threshold(0.5);
        let loose_a = loose.assign(&a);
        let loose_b = loose.assign(&b);
        assert_eq!(loose_a, loose_b, "0.707 should pass a 0.5 threshold");
    }

    #[test]
    fn test_cluster_caps_at_max_speakers() {
        let mut cluster = SpeakerCluster::new();
        for i in 0..MAX_SPEAKERS {
            assert_eq!(cluster.assign(&make_unit_embedding(i)), i as u32);
        }
        assert_eq!(cluster.num_speakers(), MAX_SPEAKERS);

        // An embedding orthogonal to every centroid cannot open a new cluster
        // once the cap is reached; it is merged into an existing one instead.
        let id = cluster.assign(&make_unit_embedding(MAX_SPEAKERS));
        assert!((id as usize) < MAX_SPEAKERS);
        assert_eq!(cluster.num_speakers(), MAX_SPEAKERS);
    }

    #[test]
    fn test_embedding_dim_constant() {
        assert_eq!(EMBEDDING_DIM, 256);
    }

    #[test]
    fn test_segment_samples_constant() {
        assert_eq!(SEGMENT_SAMPLES, 24000);
    }

    #[test]
    fn test_load_returns_error_for_missing_file() {
        let result = SpeakerEncoder::load(Path::new("/nonexistent/path"));
        assert!(result.is_err());
        let err = result.err().unwrap();
        let msg = format!("{err}");
        assert!(msg.contains("wespeaker_resnet34.onnx"));
    }
}