#![cfg(feature = "diarization")]
use crate::enrollment::TranscriptRole;
use ndarray::{Array1, Array2, Array3};
use ort::session::Session;
use ort::value::Tensor;
use std::path::PathBuf;
use std::sync::{Arc, Mutex};
/// Tunables for the online speaker-diarization pipeline.
#[derive(Debug, Clone)]
pub struct DiarizationConfig {
// Minimum cosine similarity between a segment embedding and a cluster
// centroid for the segment to be merged into that cluster.
pub merge_threshold: f32,
// Segments shorter than this are classified as `Unknown` without ever
// running the embedding model (too little audio for a stable vector).
pub min_segment_ms: u32,
// Hard cap on distinct clusters; once reached, new segments are forced
// into the nearest existing cluster regardless of similarity.
pub max_speakers: usize,
// Explicit ONNX model path; `None` falls back to `default_model_path()`.
pub model_path: Option<PathBuf>,
}
impl Default for DiarizationConfig {
fn default() -> Self {
Self {
merge_threshold: 0.70,
min_segment_ms: 250,
max_speakers: 8,
model_path: None,
}
}
}
/// Errors raised while setting up or running diarization.
#[derive(Debug, thiserror::Error)]
pub enum DiarizationError {
#[error("model file not found at {0}")]
ModelMissing(PathBuf),
// Also used for model-download failures (see `download_model`), which are
// reported as strings under this variant.
#[error("ONNX runtime error: {0}")]
Onnx(String),
#[error("audio segment too short: {0} ms (min {1} ms)")]
SegmentTooShort(u32, u32),
#[error("io: {0}")]
Io(#[from] std::io::Error),
}
/// One speaker hypothesis: a unit-norm centroid embedding plus the number
/// of segments merged into it so far.
#[derive(Debug, Clone)]
struct SpeakerCluster {
// Stable local id used in "speaker_{id}" labels.
id: usize,
// Running (re-normalized after each merge) mean of member embeddings.
centroid: Array1<f32>,
// How many embeddings have been merged; weights the running mean.
count: usize,
}
impl SpeakerCluster {
    /// Folds `embedding` into this cluster's count-weighted running mean,
    /// then re-normalizes the centroid to unit length so cosine comparisons
    /// against it stay well-scaled.
    fn merge(&mut self, embedding: &Array1<f32>) {
        let prev = self.count as f32;
        self.centroid
            .zip_mut_with(embedding, |c, &e| *c = (prev * *c + e) / (prev + 1.0));
        self.count += 1;
        l2_normalize_in_place(&mut self.centroid);
    }
}
/// Online speaker diarizer: embeds audio segments with an ONNX speaker
/// model and greedily clusters the embeddings by cosine similarity.
pub struct SpeakerDiarizer {
// `Session::run` takes `&mut self`, so the session sits behind a mutex.
session: Mutex<Session>,
config: DiarizationConfig,
// Clusters discovered so far, one per distinct speaker.
clusters: Mutex<Vec<SpeakerCluster>>,
// Monotonic counter handing out stable per-session speaker ids.
next_id: Mutex<usize>,
}
impl SpeakerDiarizer {
    /// Creates a diarizer from `config`.
    ///
    /// Resolves the ONNX model path (downloading the model on first use when
    /// it is not already on disk) and builds an ORT inference session.
    ///
    /// # Errors
    /// Fails when the model path cannot be determined, the download fails,
    /// or the ONNX session cannot be constructed.
    pub fn new(config: DiarizationConfig) -> std::result::Result<Self, DiarizationError> {
        let path = match &config.model_path {
            Some(p) => p.clone(),
            None => default_model_path()?,
        };
        if !path.exists() {
            download_model(&path)?;
        }
        let session = Session::builder()
            .map_err(|e| DiarizationError::Onnx(e.to_string()))?
            .commit_from_file(&path)
            .map_err(|e| DiarizationError::Onnx(e.to_string()))?;
        Ok(Self {
            session: Mutex::new(session),
            config,
            clusters: Mutex::new(Vec::new()),
            next_id: Mutex::new(0),
        })
    }

    /// Classifies one audio segment against the speaker clusters seen so far.
    ///
    /// Segments shorter than `config.min_segment_ms` — or with a zero
    /// `sample_rate` — are reported as [`TranscriptRole::Unknown`] without
    /// running the embedding model.
    ///
    /// # Errors
    /// Propagates ONNX inference failures from `embed`.
    pub fn classify(
        &self,
        samples: &[i16],
        sample_rate: u32,
    ) -> std::result::Result<TranscriptRole, DiarizationError> {
        // Guard the duration division below: a zero rate would otherwise
        // panic with an integer divide-by-zero instead of degrading.
        if sample_rate == 0 {
            return Ok(TranscriptRole::Unknown);
        }
        let duration_ms = (samples.len() as u64 * 1000 / sample_rate as u64) as u32;
        if duration_ms < self.config.min_segment_ms {
            return Ok(TranscriptRole::Unknown);
        }
        let embedding = self.embed(samples, sample_rate)?;
        Ok(self.assign(embedding))
    }

    /// Number of distinct speaker clusters discovered so far.
    pub fn speaker_count(&self) -> usize {
        self.clusters.lock().unwrap().len()
    }

    /// Runs the embedding model on one segment: resample to 16 kHz, compute
    /// a log-mel spectrogram, execute the ONNX session, and L2-normalize
    /// the resulting speaker vector.
    fn embed(
        &self,
        samples: &[i16],
        sample_rate: u32,
    ) -> std::result::Result<Array1<f32>, DiarizationError> {
        let pcm_16k = if sample_rate == 16_000 {
            i16_to_f32(samples)
        } else {
            resample_to_16k(samples, sample_rate)
        };
        let mel = log_mel_spectrogram(&pcm_16k);
        let input = mel_to_input_tensor(&mel);
        // `Session::run` needs `&mut Session`; serialize inference here.
        let mut session = self
            .session
            .lock()
            .map_err(|e| DiarizationError::Onnx(format!("session lock: {}", e)))?;
        let outputs = session
            .run(ort::inputs![input])
            .map_err(|e| DiarizationError::Onnx(e.to_string()))?;
        let (_shape, data) = outputs[0]
            .try_extract_tensor::<f32>()
            .map_err(|e| DiarizationError::Onnx(e.to_string()))?;
        let mut embedding = Array1::from_iter(data.iter().copied());
        // Unit-norm embeddings keep cosine similarity well-behaved against
        // the unit-norm cluster centroids.
        l2_normalize_in_place(&mut embedding);
        Ok(embedding)
    }

    /// Greedy online clustering: merge into the most similar cluster when
    /// similarity clears `merge_threshold`; once `max_speakers` clusters
    /// exist, force-merge into the nearest cluster instead of growing.
    fn assign(&self, embedding: Array1<f32>) -> TranscriptRole {
        let mut clusters = self.clusters.lock().unwrap();
        let mut next_id = self.next_id.lock().unwrap();
        let best = clusters
            .iter()
            .enumerate()
            .map(|(i, c)| (i, cosine_similarity(&embedding, &c.centroid)))
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        match best {
            // Confident match: fold into the existing cluster.
            Some((idx, sim)) if sim >= self.config.merge_threshold => {
                clusters[idx].merge(&embedding);
                TranscriptRole::OtherSpeaker {
                    local_id: format!("speaker_{}", clusters[idx].id),
                }
            }
            // At capacity: merge into the nearest cluster even though it
            // missed the threshold, rather than grow without bound.
            Some((idx, _sim)) if clusters.len() >= self.config.max_speakers => {
                clusters[idx].merge(&embedding);
                TranscriptRole::OtherSpeaker {
                    local_id: format!("speaker_{}", clusters[idx].id),
                }
            }
            // No clusters yet, or a sufficiently novel voice: open a new one.
            _ => {
                let id = *next_id;
                *next_id += 1;
                clusters.push(SpeakerCluster {
                    id,
                    centroid: embedding,
                    count: 1,
                });
                TranscriptRole::OtherSpeaker {
                    local_id: format!("speaker_{}", id),
                }
            }
        }
    }
}
/// All model input is resampled to this rate before feature extraction.
const SAMPLE_RATE_HZ: usize = 16_000;
/// FFT window length in samples (32 ms at 16 kHz).
const N_FFT: usize = 512;
/// Hop between successive frames (10 ms at 16 kHz).
const HOP_LENGTH: usize = 160;
/// Number of mel filterbank channels fed to the embedder.
const N_MELS: usize = 80;
/// Computes an (N_MELS x frames) natural-log mel spectrogram from 16 kHz
/// mono f32 PCM.
///
/// Frames are N_FFT samples with a HOP_LENGTH hop; a too-short tail (or a
/// too-short input) is zero-padded. Energies are floored at 1e-10 before
/// `ln` to avoid -inf.
fn log_mel_spectrogram(pcm: &[f32]) -> Array2<f32> {
// `saturating_sub` means inputs shorter than one window still produce one
// (zero-padded) frame, so `frames` is in practice always >= 1.
let frames = (pcm.len().saturating_sub(N_FFT)) / HOP_LENGTH + 1;
if frames == 0 {
return Array2::zeros((N_MELS, 1));
}
let mel_filters = mel_filterbank();
let window = hann_window(N_FFT);
let mut mel = Array2::<f32>::zeros((N_MELS, frames));
for f in 0..frames {
let start = f * HOP_LENGTH;
let end = (start + N_FFT).min(pcm.len());
// Copy into a fixed-size buffer; samples past `end` remain zero.
let mut frame = vec![0f32; N_FFT];
let len = end - start;
frame[..len].copy_from_slice(&pcm[start..end]);
// Apply the Hann window (the zero padding is unaffected).
for (i, w) in window.iter().enumerate() {
frame[i] *= w;
}
let spec = power_spectrum(&frame);
// Project the power spectrum through each triangular mel filter.
for m in 0..N_MELS {
let energy: f32 = spec
.iter()
.zip(mel_filters[m].iter())
.map(|(s, w)| s * w)
.sum();
mel[(m, f)] = (energy + 1e-10).ln();
}
}
mel
}
/// Power spectrum (|X[k]|^2) over the n/2 + 1 non-redundant bins of a real
/// input frame, via a naive O(n*k) DFT.
///
/// Deliberately dependency-free; for N_FFT = 512 frames of short audio
/// segments this is fast enough.
fn power_spectrum(frame: &[f32]) -> Vec<f32> {
    let n = frame.len();
    (0..n / 2 + 1)
        .map(|k| {
            let mut re = 0f32;
            let mut im = 0f32;
            for (i, &x) in frame.iter().enumerate() {
                let theta = -2.0 * std::f32::consts::PI * (k * i) as f32 / n as f32;
                re += x * theta.cos();
                im += x * theta.sin();
            }
            re * re + im * im
        })
        .collect()
}
/// Symmetric Hann window of length `n`.
///
/// Guards the `n <= 1` degenerate cases: the symmetric form divides by
/// `n - 1`, which for `n == 1` is 0/0 and would yield a NaN window sample.
/// A one-sample Hann window is conventionally `[1.0]`.
fn hann_window(n: usize) -> Vec<f32> {
    if n <= 1 {
        return vec![1.0; n];
    }
    (0..n)
        .map(|i| 0.5 - 0.5 * (2.0 * std::f32::consts::PI * i as f32 / (n - 1) as f32).cos())
        .collect()
}
/// Builds N_MELS triangular mel filters over the N_FFT/2 + 1 FFT bins.
///
/// Filter edges are N_MELS + 2 points evenly spaced on the mel scale from
/// 0 Hz to Nyquist, mapped back to Hz and then to (fractional) bin indices.
fn mel_filterbank() -> Vec<Vec<f32>> {
let n_fft_bins = N_FFT / 2 + 1;
let mel_low = hz_to_mel(0.0);
let mel_high = hz_to_mel(SAMPLE_RATE_HZ as f32 / 2.0);
let mel_pts: Vec<f32> = (0..N_MELS + 2)
.map(|i| mel_low + (mel_high - mel_low) * i as f32 / (N_MELS + 1) as f32)
.collect();
let hz_pts: Vec<f32> = mel_pts.iter().map(|m| mel_to_hz(*m)).collect();
let bin_pts: Vec<f32> = hz_pts
.iter()
.map(|h| h * (N_FFT as f32) / (SAMPLE_RATE_HZ as f32))
.collect();
let mut filters = vec![vec![0f32; n_fft_bins]; N_MELS];
for m in 0..N_MELS {
// Triangle m rises from `left` to `center`, falls back to `right`.
let left = bin_pts[m];
let center = bin_pts[m + 1];
let right = bin_pts[m + 2];
for k in 0..n_fft_bins {
let kf = k as f32;
let v = if kf < left || kf > right {
0.0
} else if kf <= center {
// `.max(1e-10)` guards zero-width edges against division by zero.
(kf - left) / (center - left).max(1e-10)
} else {
(right - kf) / (right - center).max(1e-10)
};
filters[m][k] = v.max(0.0);
}
}
filters
}
/// Hz -> mel, HTK-style formula: 2595 * log10(1 + f / 700).
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = 1.0 + hz / 700.0;
    2595.0 * ratio.log10()
}
/// Mel -> Hz, the inverse of [`hz_to_mel`].
fn mel_to_hz(mel: f32) -> f32 {
    let scaled = mel / 2595.0;
    700.0 * (10f32.powf(scaled) - 1.0)
}
/// Packs a (mels, frames) spectrogram into the (batch=1, frames, mels)
/// layout the embedding model expects (mel axis trailing).
fn mel_to_input_tensor(mel: &Array2<f32>) -> Tensor<f32> {
let (n_mels, n_frames) = mel.dim();
// Transpose while copying: output element (1, t, m) reads input (m, t).
let arr = Array3::from_shape_fn((1, n_frames, n_mels), |(_, t, m)| mel[(m, t)]);
Tensor::from_array(arr).expect("mel array shape always builds a valid tensor")
}
/// Converts signed 16-bit PCM to f32 by dividing by `i16::MAX`, so values
/// land in roughly [-1.0, 1.0] (`i16::MIN` maps slightly below -1.0).
fn i16_to_f32(samples: &[i16]) -> Vec<f32> {
    samples
        .iter()
        .map(|&s| f32::from(s) / f32::from(i16::MAX))
        .collect()
}
/// Linearly resamples `samples` at `from_rate` Hz to 16 kHz f32 PCM.
///
/// Uses simple linear interpolation between neighbouring input samples —
/// adequate for feature extraction, not for playback quality.
fn resample_to_16k(samples: &[i16], from_rate: u32) -> Vec<f32> {
    // A zero input rate would make `ratio` infinite and the computed output
    // length saturate to usize::MAX (a pathological allocation); treat it,
    // like empty input, as "no usable audio".
    if from_rate == 0 || samples.is_empty() {
        return Vec::new();
    }
    let f32_in = i16_to_f32(samples);
    let ratio = SAMPLE_RATE_HZ as f64 / from_rate as f64;
    let out_len = (f32_in.len() as f64 * ratio) as usize;
    let mut out = Vec::with_capacity(out_len);
    for i in 0..out_len {
        // Map the output index back to a fractional source position and
        // blend the two surrounding samples (clamped at the final one).
        let src = i as f64 / ratio;
        let lo = src.floor() as usize;
        let hi = (lo + 1).min(f32_in.len() - 1);
        let frac = (src - lo as f64) as f32;
        out.push(f32_in[lo] * (1.0 - frac) + f32_in[hi] * frac);
    }
    out
}
/// Cosine similarity between two vectors, defined as 0.0 whenever either
/// vector has (near-)zero norm so we never divide by ~0.
fn cosine_similarity(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    let mut dot = 0f32;
    let mut sq_a = 0f32;
    let mut sq_b = 0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();
    if norm_a <= f32::EPSILON || norm_b <= f32::EPSILON {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Scales `v` in place to unit L2 norm; near-zero vectors are left
/// untouched to avoid dividing by (almost) zero.
fn l2_normalize_in_place(v: &mut Array1<f32>) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm <= f32::EPSILON {
        return;
    }
    v.mapv_inplace(|x| x / norm);
}
fn default_model_path() -> std::result::Result<PathBuf, DiarizationError> {
let home = std::env::var_os("HOME").map(PathBuf::from).ok_or_else(|| {
DiarizationError::Io(std::io::Error::new(
std::io::ErrorKind::NotFound,
"HOME not set",
))
})?;
Ok(home
.join(".car")
.join("models")
.join("wespeaker-resnet34.onnx"))
}
// ONNX export of the WeSpeaker VoxCeleb ResNet34 model, hosted on the
// onnx-community Hugging Face mirror; the unit tests pin both the host
// prefix and the `model.onnx` filename of this URL.
const MODEL_URL: &str =
"https://huggingface.co/onnx-community/wespeaker-voxceleb-resnet34-LM/resolve/main/onnx/model.onnx";
/// Downloads the speaker-embedding model to `target`, creating parent
/// directories as needed.
///
/// # Errors
/// Returns `DiarizationError::Onnx` when the HTTP request fails or the
/// server answers with a non-success status, and `DiarizationError::Io`
/// for filesystem failures.
fn download_model(target: &PathBuf) -> std::result::Result<(), DiarizationError> {
    if let Some(parent) = target.parent() {
        std::fs::create_dir_all(parent)?;
    }
    tracing::info!(
        "[diarization] downloading speaker embedder (~28 MB) to {}",
        target.display()
    );
    // `error_for_status` is essential: without it a 404/5xx response body
    // (typically an HTML error page) would be silently written to disk as
    // the "model" and fail later, confusingly, at session load.
    let bytes = reqwest::blocking::get(MODEL_URL)
        .and_then(|r| r.error_for_status())
        .and_then(|r| r.bytes())
        .map_err(|e| DiarizationError::Onnx(format!("download failed: {}", e)))?;
    std::fs::write(target, &bytes)?;
    tracing::info!("[diarization] wrote {} bytes", bytes.len());
    Ok(())
}
/// Renders a `TranscriptRole` in the session-log string format:
/// `enrolled_user`, `unknown`, or `other:<local_id>`.
pub fn role_to_str(role: &TranscriptRole) -> String {
    match role {
        TranscriptRole::OtherSpeaker { local_id } => format!("other:{}", local_id),
        TranscriptRole::EnrolledUser => String::from("enrolled_user"),
        TranscriptRole::Unknown => String::from("unknown"),
    }
}
// Short alias for consumers of this module.
pub type Diarizer = SpeakerDiarizer;
// Thread-shareable handle to a diarizer.
pub type SharedDiarizer = Arc<Diarizer>;
#[cfg(test)]
mod tests {
use super::*;
// Pin the exact mirror host and filename so a silently changed URL
// fails fast in CI instead of at first download.
#[test]
fn diarizer_model_url_points_at_onnx_community_mirror() {
assert!(
MODEL_URL.starts_with("https://huggingface.co/onnx-community/"),
"diarizer URL must point at the onnx-community mirror that \
actually publishes the ONNX export; got {MODEL_URL}"
);
assert!(
MODEL_URL.ends_with("/model.onnx"),
"diarizer URL must point at the unquantized model.onnx file \
(~28 MB) for parity with the docstring; got {MODEL_URL}"
);
}
// The embedder expects (batch, time, mel); guard against axis swaps.
#[test]
fn mel_to_input_tensor_has_trailing_mel_axis() {
let n_frames = 4;
let mel = Array2::from_shape_fn((N_MELS, n_frames), |(m, t)| (m * 100 + t) as f32);
let tensor = mel_to_input_tensor(&mel);
let (shape, _data) = tensor
.try_extract_tensor::<f32>()
.expect("tensor should extract back to ndarray");
let dims: Vec<i64> = shape.to_vec();
assert_eq!(
dims,
vec![1_i64, n_frames as i64, N_MELS as i64],
"expected (B=1, T={n_frames}, 80); got {dims:?}"
);
}
// Merging two orthogonal unit vectors must yield the re-normalized mean
// (1/sqrt(2) in each contributing dimension) and bump the count.
#[test]
fn cluster_merge_updates_running_mean() {
let mut cluster = SpeakerCluster {
id: 0,
centroid: Array1::from_vec(vec![1.0, 0.0, 0.0]),
count: 1,
};
let new = Array1::from_vec(vec![0.0, 1.0, 0.0]);
cluster.merge(&new);
let expected = 1.0 / 2f32.sqrt();
assert!((cluster.centroid[0] - expected).abs() < 1e-4);
assert!((cluster.centroid[1] - expected).abs() < 1e-4);
assert!(cluster.centroid[2].abs() < 1e-4);
assert_eq!(cluster.count, 2);
}
// Identical vectors -> 1.0; orthogonal vectors -> 0.0.
#[test]
fn cosine_similarity_matches_definition() {
let a = Array1::from_vec(vec![1.0, 0.0, 0.0]);
let b = Array1::from_vec(vec![1.0, 0.0, 0.0]);
let c = Array1::from_vec(vec![0.0, 1.0, 0.0]);
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 1e-6);
assert!(cosine_similarity(&a, &c).abs() < 1e-6);
}
// Zero-norm input must not divide by zero; defined to return 0.0.
#[test]
fn cosine_similarity_handles_zero_vectors() {
let a = Array1::from_vec(vec![0.0, 0.0, 0.0]);
let b = Array1::from_vec(vec![1.0, 0.0, 0.0]);
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
// A 3-4-0 vector should come out with unit L2 norm.
#[test]
fn l2_normalize_unit_norm() {
let mut v = Array1::from_vec(vec![3.0, 4.0, 0.0]);
l2_normalize_in_place(&mut v);
let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-6);
}
// hz -> mel -> hz should be (nearly) the identity.
#[test]
fn hz_mel_round_trip() {
let hz = 1000.0_f32;
let mel = hz_to_mel(hz);
let back = mel_to_hz(mel);
assert!((hz - back).abs() < 0.5);
}
// Shape check plus non-negativity of every filter weight.
#[test]
fn mel_filterbank_has_expected_shape() {
let bank = mel_filterbank();
assert_eq!(bank.len(), N_MELS);
let n_fft_bins = N_FFT / 2 + 1;
for filt in &bank {
assert_eq!(filt.len(), n_fft_bins);
assert!(filt.iter().all(|v| *v >= 0.0));
}
}
// These labels feed the session log; pin the exact string formats.
#[test]
fn role_to_str_matches_session_format() {
assert_eq!(role_to_str(&TranscriptRole::EnrolledUser), "enrolled_user");
assert_eq!(role_to_str(&TranscriptRole::Unknown), "unknown");
assert_eq!(
role_to_str(&TranscriptRole::OtherSpeaker {
local_id: "speaker_0".into()
}),
"other:speaker_0"
);
}
// Downsampling 2:1 should halve the sample count (+/-1 for rounding).
#[test]
fn resample_halves_sample_count_at_2x_rate() {
let input = vec![0i16; 32_000]; let out = resample_to_16k(&input, 32_000);
assert!(((out.len() as i32) - 16_000).abs() <= 1);
}
// Input already at 16 kHz keeps its length exactly.
#[test]
fn resample_passthrough_at_16k() {
let input = vec![0i16; 16_000];
let out = resample_to_16k(&input, 16_000);
assert_eq!(out.len(), 16_000);
}
// Range sanity checks on the Default impl.
#[test]
fn config_defaults_are_sane() {
let c = DiarizationConfig::default();
assert!(c.merge_threshold > 0.0 && c.merge_threshold < 1.0);
assert!(c.min_segment_ms > 0);
assert!(c.max_speakers > 0);
}
// Re-implements the assignment loop inline (no ONNX session needed) and
// checks that three distinct directions become three clusters while the
// near-duplicates of the first direction merge back into cluster 0.
#[test]
fn synthetic_clustering_assigns_three_speakers() {
let config = DiarizationConfig::default();
let mut clusters: Vec<SpeakerCluster> = Vec::new();
let mut next_id = 0;
let inputs = [
Array1::from_vec(vec![1.0, 0.0, 0.0]),
Array1::from_vec(vec![0.0, 1.0, 0.0]),
Array1::from_vec(vec![0.0, 0.0, 1.0]),
Array1::from_vec(vec![0.95, 0.05, 0.0]),
Array1::from_vec(vec![0.99, 0.01, 0.0]),
];
let mut labels = Vec::new();
for emb in &inputs {
let best = clusters
.iter()
.enumerate()
.map(|(i, c)| (i, cosine_similarity(emb, &c.centroid)))
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
match best {
Some((idx, sim)) if sim >= config.merge_threshold => {
clusters[idx].merge(emb);
labels.push(clusters[idx].id);
}
_ => {
let id = next_id;
next_id += 1;
clusters.push(SpeakerCluster {
id,
centroid: emb.clone(),
count: 1,
});
labels.push(id);
}
}
}
assert_eq!(labels.len(), 5);
assert_eq!(labels[0], 0);
assert_eq!(labels[1], 1);
assert_eq!(labels[2], 2);
assert_eq!(labels[3], 0);
assert_eq!(labels[4], 0);
assert_eq!(clusters.len(), 3);
}
}