use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use crate::error::VoiceError;
/// Fixed sample rate (Hz) assumed by every embedding/verification path in
/// this module; verifier and diarizer reject any other rate.
const EMBED_SAMPLE_RATE: u32 = 16_000;
/// Who a transcript segment is attributed to.
///
/// Internally tagged on serialization (`{"kind": "enrolled_user", ...}`).
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum TranscriptRole {
    /// Not attributed: no enrollment loaded, embedding failed, or the match
    /// score fell into the ambiguous band.
    Unknown,
    /// Matched the enrolled voiceprint.
    EnrolledUser,
    /// A different speaker; `local_id` is a session-local identifier.
    OtherSpeaker { local_id: String },
}
impl Default for TranscriptRole {
fn default() -> Self {
TranscriptRole::Unknown
}
}
/// Turns raw 16 kHz mono PCM into a fixed-length speaker feature vector.
pub trait SpeakerEmbedder: Send + Sync {
    /// Embeds `pcm`; an empty return vector means "no embedding possible"
    /// (e.g. audio empty or shorter than one analysis frame).
    fn embed(&self, pcm: &[i16]) -> Vec<f32>;
    /// Stable identifier of the producing model; embeddings from different
    /// model ids are never compared (see `SpeakerEmbedding::cosine_similarity`).
    fn model_id(&self) -> &str;
}
/// Baseline embedder: per-filter mean and standard deviation of log-mel
/// filterbank energies, L2-normalized. Deterministic and dependency-free —
/// no learned model.
pub struct FilterbankEmbedder {
    // Reported via `SpeakerEmbedder::model_id`.
    model_id: String,
    // Number of triangular mel filters; the embedding is 2x this long
    // (means followed by standard deviations).
    mel_filters: usize,
}
impl FilterbankEmbedder {
pub fn new() -> Self {
Self {
model_id: "fbank-stats-v1".to_string(),
mel_filters: 40,
}
}
}
impl Default for FilterbankEmbedder {
fn default() -> Self {
Self::new()
}
}
impl SpeakerEmbedder for FilterbankEmbedder {
    /// Summarizes the audio as per-filter [mean .. stddev] of its log-mel
    /// frames, then L2-normalizes. Returns an empty vector when the audio is
    /// empty or too short to yield a single frame.
    fn embed(&self, pcm: &[i16]) -> Vec<f32> {
        if pcm.is_empty() {
            return Vec::new();
        }
        // i16 PCM -> [-1, 1) float.
        let scaled: Vec<f32> = pcm.iter().map(|&s| s as f32 / 32768.0).collect();
        let frames = mel_filterbank_frames(&scaled, EMBED_SAMPLE_RATE, self.mel_filters);
        if frames.is_empty() {
            return Vec::new();
        }
        let count = frames.len() as f32;
        // Per-filter mean across all frames.
        let mut mean = vec![0.0f32; self.mel_filters];
        for frame in &frames {
            for (acc, x) in mean.iter_mut().zip(frame) {
                *acc += x;
            }
        }
        for acc in &mut mean {
            *acc /= count;
        }
        // Per-filter (population) standard deviation across all frames.
        let mut std_dev = vec![0.0f32; self.mel_filters];
        for frame in &frames {
            for ((acc, x), mu) in std_dev.iter_mut().zip(frame).zip(&mean) {
                let delta = x - mu;
                *acc += delta * delta;
            }
        }
        for acc in &mut std_dev {
            *acc = (*acc / count).sqrt();
        }
        // Concatenate [means | stddevs] and normalize to unit length so
        // cosine similarity is well-behaved downstream.
        let mut embedding: Vec<f32> = Vec::with_capacity(self.mel_filters * 2);
        embedding.extend_from_slice(&mean);
        embedding.extend_from_slice(&std_dev);
        normalize_in_place(&mut embedding);
        embedding
    }

    fn model_id(&self) -> &str {
        &self.model_id
    }
}
/// A speaker feature vector tagged with the model that produced it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeakerEmbedding {
    /// Feature values; empty means "no embedding".
    #[serde(default)]
    pub values: Vec<f32>,
    /// Producing model id; checked before any similarity computation.
    #[serde(default)]
    pub model: String,
}
impl SpeakerEmbedding {
    /// Wraps a raw feature vector with the id of the model that produced it.
    pub fn new(values: Vec<f32>, model: impl Into<String>) -> Self {
        let model = model.into();
        Self { values, model }
    }

    /// Scales the vector to unit L2 norm (no-op for near-zero vectors).
    pub fn normalize(&mut self) {
        normalize_in_place(&mut self.values);
    }

    /// Cosine similarity between two embeddings.
    ///
    /// # Errors
    /// Returns a config error when the embeddings come from different models
    /// or have different dimensionality; comparing those would be meaningless.
    /// Empty or zero-norm vectors score 0.0 rather than dividing by zero.
    pub fn cosine_similarity(
        a: &SpeakerEmbedding,
        b: &SpeakerEmbedding,
    ) -> Result<f32, VoiceError> {
        if a.model != b.model {
            return Err(VoiceError::Config(format!(
                "speaker embedding model mismatch: {} vs {}",
                a.model, b.model
            )));
        }
        if a.values.len() != b.values.len() {
            return Err(VoiceError::Config(format!(
                "speaker embedding dim mismatch: {} vs {}",
                a.values.len(),
                b.values.len()
            )));
        }
        if a.values.is_empty() {
            return Ok(0.0);
        }
        // Single fused pass: dot product plus both squared norms.
        let mut dot = 0.0f32;
        let mut norm_a = 0.0f32;
        let mut norm_b = 0.0f32;
        for (x, y) in a.values.iter().zip(&b.values) {
            dot += x * y;
            norm_a += x * x;
            norm_b += y * y;
        }
        let denom = norm_a.sqrt() * norm_b.sqrt();
        if denom <= f32::EPSILON {
            return Ok(0.0);
        }
        Ok(dot / denom)
    }
}
impl Default for SpeakerEmbedding {
    /// An empty embedding with no model id.
    fn default() -> Self {
        Self::new(Vec::new(), String::new())
    }
}
/// A persisted voiceprint: a human-readable label plus the reference
/// embedding it was enrolled with. Stored on disk as TOML.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Enrollment {
    pub label: String,
    #[serde(default)]
    pub embedding: SpeakerEmbedding,
}
impl Enrollment {
    /// Reads and parses a TOML voiceprint from `path`.
    ///
    /// # Errors
    /// I/O failures propagate; non-UTF-8 content and TOML parse failures are
    /// reported as config errors.
    pub fn load_from(path: &Path) -> Result<Self, VoiceError> {
        let raw = std::fs::read(path)?;
        let text = std::str::from_utf8(&raw)
            .map_err(|e| VoiceError::Config(format!("enrollment not UTF-8: {e}")))?;
        toml::from_str(text).map_err(|e| VoiceError::Config(format!("enrollment parse: {e}")))
    }

    /// Serializes this enrollment to TOML at `path`, creating parent
    /// directories first when needed.
    pub fn save_to(&self, path: &Path) -> Result<(), VoiceError> {
        if let Some(dir) = path.parent() {
            std::fs::create_dir_all(dir)?;
        }
        let rendered = toml::to_string(self)
            .map_err(|e| VoiceError::Config(format!("enrollment serialize: {e}")))?;
        std::fs::write(path, rendered)?;
        Ok(())
    }
}
/// An armed enrollment request, consumed by the next call to
/// `SpeakerPipeline::capture_enrollment`.
#[derive(Debug, Clone)]
pub struct PendingEnrollment {
    pub label: String,
    pub save_path: PathBuf,
    // NOTE(review): stored but never read anywhere in this file —
    // `capture_enrollment` ignores it and the pipeline keeps using its own
    // `match_threshold`. Confirm whether this should override the pipeline
    // threshold after capture, or be removed.
    pub threshold: f32,
}
impl PendingEnrollment {
pub fn new(label: impl Into<String>, save_path: impl Into<PathBuf>, threshold: f32) -> Self {
Self {
label: label.into(),
save_path: save_path.into(),
threshold,
}
}
}
/// Result of attempting to capture an armed enrollment.
#[derive(Debug, Clone)]
pub enum EnrollmentOutcome {
    /// The voiceprint was embedded, persisted to `save_path`, and activated.
    Captured { label: String, save_path: PathBuf },
    /// Embedding or persistence failed; the pending request is consumed.
    Failed { reason: String },
}
/// Cosine score at or above which audio is attributed to the enrolled user.
const DEFAULT_MATCH_THRESHOLD: f32 = 0.72;
/// Scores more than this margin below the threshold are confidently "someone
/// else"; the band in between stays `Unknown`.
const OTHER_SPEAKER_MARGIN: f32 = 0.15;
/// One contiguous stretch of speech attributed to a single (clustered)
/// speaker, as produced by a `Diarizer`.
#[derive(Debug, Clone)]
pub struct DiarizationTurn {
    pub start_ms: u64,
    pub end_ms: u64,
    /// Session-local cluster id, e.g. "speaker_0"; not stable across calls.
    pub speaker_id: String,
    pub embedding: SpeakerEmbedding,
}
/// Per-turn attribution after comparing a diarized turn against the
/// enrolled voiceprint. Mirrors `TranscriptRole` but is not serialized.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TurnRole {
    EnrolledUser,
    /// Carries the diarizer's session-local speaker id.
    OtherSpeaker(String),
    Unknown,
}
impl TurnRole {
    /// Converts this per-turn role into the serializable transcript role.
    pub fn into_transcript_role(self) -> TranscriptRole {
        match self {
            Self::EnrolledUser => TranscriptRole::EnrolledUser,
            Self::OtherSpeaker(id) => TranscriptRole::OtherSpeaker { local_id: id },
            Self::Unknown => TranscriptRole::Unknown,
        }
    }
}
/// A diarized turn with its role resolved and its audio sliced out of the
/// original buffer.
#[derive(Debug, Clone)]
pub struct TaggedTurn {
    pub start_ms: u64,
    pub end_ms: u64,
    pub role: TurnRole,
    pub embedding: SpeakerEmbedding,
    /// The turn's samples (may be empty if the time range was degenerate).
    pub audio: Vec<f32>,
}
/// Scores audio against an enrolled voiceprint.
pub trait SpeakerVerifier: Send + Sync {
    /// True when `score` reaches `threshold`.
    fn matches(
        &self,
        samples: &[f32],
        sample_rate: u32,
        enrolled: &Enrollment,
        threshold: f32,
    ) -> Result<bool, VoiceError>;
    /// Similarity score between `samples` and the enrollment; higher is a
    /// closer match.
    fn score(
        &self,
        samples: &[f32],
        sample_rate: u32,
        enrolled: &Enrollment,
    ) -> Result<f32, VoiceError>;
}
/// Segments audio into per-speaker turns.
pub trait Diarizer: Send + Sync {
    /// Returns turns in chronological order; may be empty when no speech is
    /// detected.
    fn diarize(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<DiarizationTurn>, VoiceError>;
}
/// Verifier that embeds candidate audio and compares it against the
/// enrollment with cosine similarity.
pub struct CosineVerifier {
    embedder: Arc<dyn SpeakerEmbedder>,
}
impl CosineVerifier {
pub fn new(embedder: Arc<dyn SpeakerEmbedder>) -> Self {
Self { embedder }
}
}
impl SpeakerVerifier for CosineVerifier {
    /// True when the cosine score reaches `threshold`.
    fn matches(
        &self,
        samples: &[f32],
        sample_rate: u32,
        enrolled: &Enrollment,
        threshold: f32,
    ) -> Result<bool, VoiceError> {
        let score = self.score(samples, sample_rate, enrolled)?;
        Ok(score >= threshold)
    }

    /// Embeds `samples` and scores them against the enrolled voiceprint.
    ///
    /// # Errors
    /// Rejects audio that is not at `EMBED_SAMPLE_RATE`, audio the embedder
    /// cannot embed, and model/dimension mismatches from the comparison.
    fn score(
        &self,
        samples: &[f32],
        sample_rate: u32,
        enrolled: &Enrollment,
    ) -> Result<f32, VoiceError> {
        if sample_rate != EMBED_SAMPLE_RATE {
            return Err(VoiceError::Config(format!(
                "CosineVerifier expects {EMBED_SAMPLE_RATE} Hz; got {sample_rate}"
            )));
        }
        let values = self.embedder.embed(&f32_to_i16(samples));
        if values.is_empty() {
            return Err(VoiceError::Config("embedder produced empty vector".into()));
        }
        let candidate = SpeakerEmbedding::new(values, self.embedder.model_id());
        SpeakerEmbedding::cosine_similarity(&candidate, &enrolled.embedding)
    }
}
/// Baseline diarizer: splits on silence, embeds each voiced chunk, and
/// greedily clusters chunks by cosine similarity.
pub struct SilenceSplitDiarizer {
    embedder: Arc<dyn SpeakerEmbedder>,
    // Cosine similarity at or above which a chunk joins an existing cluster.
    merge_threshold: f32,
    // Frames with RMS below this level (in dBFS) count as silence.
    silence_threshold_dbfs: f32,
    // Voiced chunks shorter than this are discarded.
    min_chunk_ms: u64,
}
impl SilenceSplitDiarizer {
pub fn new(embedder: Arc<dyn SpeakerEmbedder>) -> Self {
Self {
embedder,
merge_threshold: 0.75,
silence_threshold_dbfs: -45.0,
min_chunk_ms: 300,
}
}
pub fn with_merge_threshold(mut self, t: f32) -> Self {
self.merge_threshold = t;
self
}
pub fn with_silence_threshold_dbfs(mut self, t: f32) -> Self {
self.silence_threshold_dbfs = t;
self
}
pub fn with_min_chunk_ms(mut self, ms: u64) -> Self {
self.min_chunk_ms = ms;
self
}
}
impl Diarizer for SilenceSplitDiarizer {
    /// Splits audio on silence, embeds each voiced chunk, and greedily
    /// clusters the chunks into per-speaker turns.
    ///
    /// # Errors
    /// Rejects any `sample_rate` other than `EMBED_SAMPLE_RATE`.
    fn diarize(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<DiarizationTurn>, VoiceError> {
        if sample_rate != EMBED_SAMPLE_RATE {
            return Err(VoiceError::Config(format!(
                "SilenceSplitDiarizer expects {EMBED_SAMPLE_RATE} Hz; got {sample_rate}"
            )));
        }
        let chunks = split_on_silence(
            samples,
            sample_rate,
            self.silence_threshold_dbfs,
            self.min_chunk_ms,
        );
        if chunks.is_empty() {
            return Ok(Vec::new());
        }
        // Embed each chunk; chunks the embedder cannot handle are dropped
        // silently (they produce no turn).
        let mut embedded: Vec<(u64, u64, SpeakerEmbedding)> = Vec::new();
        for (start_ms, end_ms, piece) in &chunks {
            let pcm = f32_to_i16(piece);
            let values = self.embedder.embed(&pcm);
            if values.is_empty() {
                continue;
            }
            let emb = SpeakerEmbedding::new(values, self.embedder.model_id());
            embedded.push((*start_ms, *end_ms, emb));
        }
        // Greedy single-pass clustering: assign each chunk to its most
        // similar existing cluster when the score clears merge_threshold,
        // otherwise open a new cluster. A cluster's representative is the
        // FIRST chunk assigned to it (no running centroid), so results are
        // order-dependent by design.
        let mut clusters: Vec<SpeakerEmbedding> = Vec::new();
        let mut assignments: Vec<usize> = Vec::with_capacity(embedded.len());
        for (_, _, emb) in &embedded {
            let mut best: Option<(usize, f32)> = None;
            for (i, c) in clusters.iter().enumerate() {
                // Comparison errors (model/dim mismatch) just skip the pair.
                if let Ok(s) = SpeakerEmbedding::cosine_similarity(c, emb) {
                    if best.map(|(_, bs)| s > bs).unwrap_or(true) {
                        best = Some((i, s));
                    }
                }
            }
            match best {
                Some((i, s)) if s >= self.merge_threshold => assignments.push(i),
                _ => {
                    clusters.push(emb.clone());
                    assignments.push(clusters.len() - 1);
                }
            }
        }
        // Pair every embedded chunk with its cluster id to form the turns.
        Ok(embedded
            .into_iter()
            .zip(assignments.iter())
            .map(
                |((start_ms, end_ms, embedding), cluster_id)| DiarizationTurn {
                    start_ms,
                    end_ms,
                    speaker_id: format!("speaker_{cluster_id}"),
                    embedding,
                },
            )
            .collect())
    }
}
/// End-to-end speaker pipeline: embedding, verification against a single
/// enrolled voiceprint, and diarization. Enrollment state is behind mutexes
/// so a shared pipeline can be updated from another thread.
pub struct SpeakerPipeline {
    embedder: Arc<dyn SpeakerEmbedder>,
    verifier: Arc<dyn SpeakerVerifier>,
    diarizer: Arc<dyn Diarizer>,
    // The active voiceprint, if any.
    enrollment: Mutex<Option<Enrollment>>,
    // Cosine score needed to attribute audio to the enrolled user.
    match_threshold: f32,
    // An armed enrollment request awaiting the next audio capture.
    pending: Mutex<Option<PendingEnrollment>>,
}
impl SpeakerPipeline {
    /// Builds the default pipeline: filterbank embedder, cosine verifier,
    /// silence-split diarizer, no enrollment, `DEFAULT_MATCH_THRESHOLD`.
    pub fn baseline() -> Self {
        let embedder: Arc<dyn SpeakerEmbedder> = Arc::new(FilterbankEmbedder::new());
        let verifier: Arc<dyn SpeakerVerifier> =
            Arc::new(CosineVerifier::new(Arc::clone(&embedder)));
        let diarizer: Arc<dyn Diarizer> =
            Arc::new(SilenceSplitDiarizer::new(Arc::clone(&embedder)));
        Self {
            embedder,
            verifier,
            diarizer,
            enrollment: Mutex::new(None),
            match_threshold: DEFAULT_MATCH_THRESHOLD,
            pending: Mutex::new(None),
        }
    }
    /// Replaces the embedder.
    ///
    /// NOTE(review): this does NOT rebuild `verifier`/`diarizer`, which keep
    /// holding the previous embedder. Mixing model ids makes every later
    /// cosine comparison fail with "model mismatch". Confirm callers also
    /// replace the verifier/diarizer, or rebuild them here.
    pub fn with_embedder(mut self, embedder: Box<dyn SpeakerEmbedder>) -> Self {
        self.embedder = Arc::from(embedder);
        self
    }
    /// Replaces the verifier.
    pub fn with_verifier(mut self, verifier: Arc<dyn SpeakerVerifier>) -> Self {
        self.verifier = verifier;
        self
    }
    /// Replaces the diarizer.
    pub fn with_diarizer(mut self, diarizer: Arc<dyn Diarizer>) -> Self {
        self.diarizer = diarizer;
        self
    }
    /// Installs an enrollment (silently dropped if the lock is poisoned).
    pub fn with_enrollment(self, enrollment: Enrollment) -> Self {
        if let Ok(mut slot) = self.enrollment.lock() {
            *slot = Some(enrollment);
        }
        self
    }
    /// Overrides the cosine score required to match the enrolled user.
    pub fn with_match_threshold(mut self, threshold: f32) -> Self {
        self.match_threshold = threshold;
        self
    }
    /// Clones the current enrollment, if any.
    pub fn enrollment_snapshot(&self) -> Option<Enrollment> {
        self.enrollment.lock().ok().and_then(|s| s.clone())
    }
    /// Classifies a whole utterance against the enrolled voiceprint.
    ///
    /// Returns `Unknown` when there is no (usable) enrollment, the audio
    /// cannot be embedded, or the score falls in the ambiguous band between
    /// `match_threshold - OTHER_SPEAKER_MARGIN` and `match_threshold`.
    /// `pcm` is embedded at the fixed `EMBED_SAMPLE_RATE`; callers must
    /// supply 16 kHz audio (no rate check happens here).
    pub fn classify(&self, pcm: &[i16]) -> TranscriptRole {
        let Some(enrollment) = self.enrollment_snapshot() else {
            return TranscriptRole::Unknown;
        };
        if enrollment.embedding.values.is_empty() {
            return TranscriptRole::Unknown;
        }
        let values = self.embedder.embed(pcm);
        if values.is_empty() {
            return TranscriptRole::Unknown;
        }
        let candidate = SpeakerEmbedding::new(values, self.embedder.model_id());
        let score = match SpeakerEmbedding::cosine_similarity(&candidate, &enrollment.embedding) {
            Ok(s) => s,
            Err(e) => {
                tracing::warn!("[voice] cosine compare rejected: {e}");
                return TranscriptRole::Unknown;
            }
        };
        if score >= self.match_threshold {
            TranscriptRole::EnrolledUser
        } else if score < self.match_threshold - OTHER_SPEAKER_MARGIN {
            TranscriptRole::OtherSpeaker {
                local_id: "overheard".to_string(),
            }
        } else {
            TranscriptRole::Unknown
        }
    }
    /// Diarizes the audio and tags each turn against the enrollment.
    /// Returns an empty list when no enrollment is present.
    pub fn classify_turns(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> Result<Vec<TaggedTurn>, VoiceError> {
        let Some(enrollment) = self.enrollment_snapshot() else {
            return Ok(Vec::new());
        };
        let turns = self.diarizer.diarize(samples, sample_rate)?;
        if turns.is_empty() {
            return Ok(Vec::new());
        }
        let mut out: Vec<TaggedTurn> = Vec::with_capacity(turns.len());
        for turn in turns {
            let role =
                match SpeakerEmbedding::cosine_similarity(&turn.embedding, &enrollment.embedding) {
                    Ok(score) if score >= self.match_threshold => TurnRole::EnrolledUser,
                    Ok(_) => TurnRole::OtherSpeaker(turn.speaker_id.clone()),
                    Err(e) => {
                        tracing::warn!("[voice] cosine compare rejected in classify_turns: {e}");
                        TurnRole::Unknown
                    }
                };
            // Convert the turn's ms range back to sample indices, clamped to
            // the buffer; degenerate ranges produce empty audio.
            let start = ((turn.start_ms * sample_rate as u64) / 1000) as usize;
            let end = ((turn.end_ms * sample_rate as u64) / 1000) as usize;
            let end = end.min(samples.len());
            let audio = if end > start {
                samples[start..end].to_vec()
            } else {
                Vec::new()
            };
            out.push(TaggedTurn {
                start_ms: turn.start_ms,
                end_ms: turn.end_ms,
                role,
                embedding: turn.embedding,
                audio,
            });
        }
        Ok(out)
    }
    /// Keeps only the enrolled user's speech.
    ///
    /// Without an enrollment the audio passes through unchanged. If the whole
    /// utterance matches, it is returned as-is; otherwise matching diarized
    /// turns are concatenated with 100 ms of silence between them. Returns
    /// `None` when nothing matches.
    pub fn filter_to_enrolled_user(
        &self,
        samples: &[f32],
        sample_rate: u32,
    ) -> Result<Option<Vec<f32>>, VoiceError> {
        let Some(enrollment) = self.enrollment_snapshot() else {
            return Ok(Some(samples.to_vec()));
        };
        // Fast path: the whole utterance already matches the enrollment.
        if let Ok(score) = self.verifier.score(samples, sample_rate, &enrollment) {
            if score >= self.match_threshold {
                return Ok(Some(samples.to_vec()));
            }
        }
        let turns = self.diarizer.diarize(samples, sample_rate)?;
        if turns.is_empty() {
            return Ok(None);
        }
        let mut kept: Vec<Vec<f32>> = Vec::new();
        for turn in turns {
            // Comparison errors count as non-matching (score 0).
            let score = SpeakerEmbedding::cosine_similarity(&turn.embedding, &enrollment.embedding)
                .unwrap_or(0.0);
            if score >= self.match_threshold {
                let start = ((turn.start_ms * sample_rate as u64) / 1000) as usize;
                let end = ((turn.end_ms * sample_rate as u64) / 1000) as usize;
                let end = end.min(samples.len());
                if end > start {
                    kept.push(samples[start..end].to_vec());
                }
            }
        }
        if kept.is_empty() {
            return Ok(None);
        }
        // 100 ms of silence between stitched turns.
        let pad_samples = (sample_rate as usize) / 10;
        let pad = vec![0.0f32; pad_samples];
        let mut out: Vec<f32> = Vec::new();
        for (i, chunk) in kept.into_iter().enumerate() {
            if i > 0 {
                out.extend_from_slice(&pad);
            }
            out.extend(chunk);
        }
        Ok(Some(out))
    }
    /// Arms a one-shot enrollment; the next `capture_enrollment` consumes it.
    pub fn arm_enrollment(&self, request: PendingEnrollment) {
        if let Ok(mut guard) = self.pending.lock() {
            *guard = Some(request);
        }
    }
    /// Consumes the pending request (if any), embeds `pcm`, persists the
    /// voiceprint, and activates it. Returns `None` when nothing was armed.
    ///
    /// NOTE(review): `PendingEnrollment::threshold` is never read here; the
    /// pipeline keeps its existing `match_threshold` — confirm intent.
    pub fn capture_enrollment(&self, pcm: &[i16]) -> Option<EnrollmentOutcome> {
        let pending = self.pending.lock().ok().and_then(|mut g| g.take())?;
        let values = self.embedder.embed(pcm);
        if values.is_empty() {
            return Some(EnrollmentOutcome::Failed {
                reason: format!(
                    "embedder '{}' returned empty vector — audio too short or silent",
                    self.embedder.model_id()
                ),
            });
        }
        let enrollment = Enrollment {
            label: pending.label.clone(),
            embedding: SpeakerEmbedding {
                values,
                model: self.embedder.model_id().to_string(),
            },
        };
        if let Err(e) = enrollment.save_to(&pending.save_path) {
            return Some(EnrollmentOutcome::Failed {
                reason: format!("save to {}: {e}", pending.save_path.display()),
            });
        }
        if let Ok(mut slot) = self.enrollment.lock() {
            *slot = Some(enrollment);
        }
        Some(EnrollmentOutcome::Captured {
            label: pending.label,
            save_path: pending.save_path,
        })
    }
    /// True when an enrollment request is armed and unconsumed.
    pub fn has_pending_enrollment(&self) -> bool {
        self.pending.lock().map(|g| g.is_some()).unwrap_or(false)
    }
    /// True when a voiceprint is active.
    pub fn is_enrolled(&self) -> bool {
        self.enrollment.lock().map(|s| s.is_some()).unwrap_or(false)
    }
}
impl Default for SpeakerPipeline {
fn default() -> Self {
Self::baseline()
}
}
/// Scales `values` to unit L2 norm in place; near-zero vectors are left
/// untouched to avoid dividing by (almost) zero.
fn normalize_in_place(values: &mut [f32]) {
    let norm = values.iter().fold(0.0f32, |acc, v| acc + v * v).sqrt();
    if norm > f32::EPSILON {
        values.iter_mut().for_each(|v| *v /= norm);
    }
}
/// Converts float samples to 16-bit PCM, clamping to [-1, 1] first and
/// rounding to the nearest integer.
fn f32_to_i16(samples: &[f32]) -> Vec<i16> {
    let mut out = Vec::with_capacity(samples.len());
    for &s in samples {
        out.push((s.clamp(-1.0, 1.0) * 32767.0).round() as i16);
    }
    out
}
/// Splits `samples` into voiced chunks separated by silence.
///
/// Scans non-overlapping 20 ms frames; a frame whose RMS is below
/// `threshold_dbfs` counts as silence and closes the current chunk. Chunks
/// shorter than `min_chunk_ms` are dropped. Each returned tuple is
/// `(start_ms, end_ms, samples)`.
fn split_on_silence(
    samples: &[f32],
    sample_rate: u32,
    threshold_dbfs: f32,
    min_chunk_ms: u64,
) -> Vec<(u64, u64, Vec<f32>)> {
    if samples.is_empty() {
        return Vec::new();
    }
    let frame_len = ((sample_rate as f32) * 0.020).round() as usize;
    if frame_len == 0 {
        return Vec::new();
    }
    // dBFS -> linear amplitude, computed once.
    let min_amp = 10f32.powf(threshold_dbfs / 20.0);
    let sr = sample_rate as u64;
    let to_ms = |idx: usize| (idx as u64 * 1000) / sr;
    let mut chunks: Vec<(u64, u64, Vec<f32>)> = Vec::new();
    let mut open_at: Option<usize> = None;
    let mut frame_start = 0usize;
    while frame_start + frame_len <= samples.len() {
        let frame = &samples[frame_start..frame_start + frame_len];
        let energy = frame.iter().map(|s| s * s).sum::<f32>() / frame_len as f32;
        if energy.sqrt() >= min_amp {
            // Voiced: open a chunk if one is not already open.
            open_at.get_or_insert(frame_start);
        } else if let Some(begin) = open_at.take() {
            // Silence closes the open chunk; keep it only if long enough.
            if to_ms(frame_start - begin) >= min_chunk_ms {
                chunks.push((
                    to_ms(begin),
                    to_ms(frame_start),
                    samples[begin..frame_start].to_vec(),
                ));
            }
        }
        frame_start += frame_len;
    }
    // A chunk still open at the end of the buffer runs to the last sample.
    if let Some(begin) = open_at {
        if to_ms(samples.len() - begin) >= min_chunk_ms {
            chunks.push((
                to_ms(begin),
                to_ms(samples.len()),
                samples[begin..].to_vec(),
            ));
        }
    }
    chunks
}
/// Computes log-mel filterbank energies: 25 ms Hann-windowed frames with a
/// 10 ms hop, `n_mels` triangular filters spanning 80 Hz to
/// min(Nyquist, 8 kHz). Returns one `n_mels`-length vector per frame; empty
/// when the audio is shorter than a single frame or the rate is degenerate.
fn mel_filterbank_frames(samples: &[f32], sample_rate: u32, n_mels: usize) -> Vec<Vec<f32>> {
    let frame_size = ((sample_rate as f32) * 0.025).round() as usize;
    let hop_size = ((sample_rate as f32) * 0.010).round() as usize;
    // frame_size < 2 would make the Hann denominator `frame_size - 1` zero,
    // poisoning the window (and every frame) with NaN — bail out instead.
    // (The old `frame_size == 0` guard missed the frame_size == 1 case.)
    if frame_size < 2 || hop_size == 0 || samples.len() < frame_size {
        return vec![];
    }
    // Symmetric Hann window.
    let window: Vec<f32> = (0..frame_size)
        .map(|i| {
            0.5 - 0.5 * ((2.0 * std::f32::consts::PI * i as f32) / (frame_size - 1) as f32).cos()
        })
        .collect();
    let mel_min = hz_to_mel(80.0);
    let mel_max = hz_to_mel((sample_rate as f32 / 2.0).min(8000.0));
    // n_mels + 2 equally spaced mel points give the lower/center/upper edges
    // of n_mels overlapping triangular filters.
    let mel_edges: Vec<f32> = (0..n_mels + 2)
        .map(|i| mel_min + (mel_max - mel_min) * (i as f32) / (n_mels + 1) as f32)
        .collect();
    let hz_edges: Vec<f32> = mel_edges.iter().copied().map(mel_to_hz).collect();
    let n_fft = frame_size.next_power_of_two();
    // Filter edges expressed in (fractional) FFT bin indices.
    let bin_edges: Vec<f32> = hz_edges
        .iter()
        .map(|hz| hz * (n_fft as f32) / (sample_rate as f32))
        .collect();
    let mut out: Vec<Vec<f32>> = Vec::new();
    let mut frame_samples = vec![0.0f32; n_fft];
    let mut start = 0usize;
    while start + frame_size <= samples.len() {
        // Window the frame and zero-pad up to the FFT size.
        for i in 0..n_fft {
            frame_samples[i] = if i < frame_size {
                samples[start + i] * window[i]
            } else {
                0.0
            };
        }
        let mags = dft_magnitudes(&frame_samples);
        let mut mels = vec![0.0f32; n_mels];
        for m in 0..n_mels {
            let lo = bin_edges[m];
            let ctr = bin_edges[m + 1];
            let hi = bin_edges[m + 2];
            let mut acc = 0.0f32;
            for (k, mag) in mags.iter().enumerate() {
                let k = k as f32;
                // Triangular weight: rises lo -> ctr, falls ctr -> hi; the
                // max(1e-6) guards against zero-width filter halves.
                let weight = if k <= lo || k >= hi {
                    0.0
                } else if k <= ctr {
                    (k - lo) / (ctr - lo).max(1e-6)
                } else {
                    (hi - k) / (hi - ctr).max(1e-6)
                };
                acc += weight * mag;
            }
            // Log-compress; epsilon keeps ln() finite for silent bands.
            mels[m] = (acc + 1e-6).ln();
        }
        out.push(mels);
        start += hop_size;
    }
    out
}
/// Hz -> mel (O'Shaughnessy formula: 2595 * log10(1 + hz/700)).
fn hz_to_mel(hz: f32) -> f32 {
    let ratio = 1.0 + hz / 700.0;
    2595.0 * ratio.log10()
}
/// Mel -> Hz; exact inverse of `hz_to_mel`.
fn mel_to_hz(mel: f32) -> f32 {
    let power = 10f32.powf(mel / 2595.0);
    700.0 * (power - 1.0)
}
/// Naive O(n^2) real-input DFT, returning magnitudes for the n/2 + 1
/// non-negative frequency bins. Adequate for the short baseline frames here.
fn dft_magnitudes(frame: &[f32]) -> Vec<f32> {
    let n = frame.len();
    let bins = n / 2 + 1;
    (0..bins)
        .map(|k| {
            let step = -2.0 * std::f32::consts::PI * (k as f32) / (n as f32);
            let (re, im) = frame
                .iter()
                .enumerate()
                .fold((0.0f32, 0.0f32), |(re, im), (t, &x)| {
                    let phase = step * (t as f32);
                    (re + x * phase.cos(), im + x * phase.sin())
                });
            (re * re + im * im).sqrt()
        })
        .collect()
}
#[cfg(test)]
mod tests {
    // The tests use pure sine tones at different frequencies as stand-in
    // "speakers": the filterbank statistics differ enough between pitches to
    // exercise matching, clustering, and enrollment end to end.
    use super::*;
    use tempfile::tempdir;
    // 16-bit PCM sine tone at 40% full scale.
    fn tone_i16(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<i16> {
        let total = (duration_s * sample_rate as f32) as usize;
        (0..total)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                let s = (2.0 * std::f32::consts::PI * freq * t).sin() * 0.4;
                (s * 32767.0) as i16
            })
            .collect()
    }
    #[test]
    fn default_role_is_unknown() {
        assert_eq!(TranscriptRole::default(), TranscriptRole::Unknown);
    }
    #[test]
    fn embedder_returns_unit_vector_for_tone() {
        let audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let embedder = FilterbankEmbedder::new();
        let values = embedder.embed(&audio);
        // 40 means + 40 stddevs.
        assert_eq!(values.len(), 80);
        let norm: f32 = values.iter().map(|v| v * v).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "embedding not unit-norm: {norm}");
    }
    #[test]
    fn embedder_returns_empty_for_empty_input() {
        let embedder = FilterbankEmbedder::new();
        assert!(embedder.embed(&[]).is_empty());
    }
    #[test]
    fn same_tone_self_matches() {
        let audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let embedder = FilterbankEmbedder::new();
        let a = SpeakerEmbedding::new(embedder.embed(&audio), embedder.model_id());
        let b = SpeakerEmbedding::new(embedder.embed(&audio), embedder.model_id());
        let score = SpeakerEmbedding::cosine_similarity(&a, &b).unwrap();
        assert!(
            score > 0.99,
            "identical audio should self-match, got {score}"
        );
    }
    #[test]
    fn different_tones_score_lower_than_same_tone() {
        let same_a = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let same_b = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let different = tone_i16(880.0, 1.0, EMBED_SAMPLE_RATE);
        let embedder = FilterbankEmbedder::new();
        let ea = SpeakerEmbedding::new(embedder.embed(&same_a), embedder.model_id());
        let eb = SpeakerEmbedding::new(embedder.embed(&same_b), embedder.model_id());
        let ed = SpeakerEmbedding::new(embedder.embed(&different), embedder.model_id());
        let same_score = SpeakerEmbedding::cosine_similarity(&ea, &eb).unwrap();
        let diff_score = SpeakerEmbedding::cosine_similarity(&ea, &ed).unwrap();
        assert!(
            same_score > diff_score + 0.05,
            "same: {same_score}, different: {diff_score}"
        );
    }
    #[test]
    fn cosine_similarity_refuses_model_mismatch() {
        let a = SpeakerEmbedding::new(vec![1.0, 0.0], "model-a");
        let b = SpeakerEmbedding::new(vec![1.0, 0.0], "model-b");
        assert!(SpeakerEmbedding::cosine_similarity(&a, &b).is_err());
    }
    #[test]
    fn baseline_without_enrollment_is_unknown() {
        let p = SpeakerPipeline::baseline();
        let audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        assert_eq!(p.classify(&audio), TranscriptRole::Unknown);
    }
    #[test]
    fn classify_returns_enrolled_user_for_matching_audio() {
        let embedder = FilterbankEmbedder::new();
        let audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let enrolled = Enrollment {
            label: "matt".into(),
            embedding: SpeakerEmbedding::new(embedder.embed(&audio), embedder.model_id()),
        };
        let pipeline = SpeakerPipeline::baseline().with_enrollment(enrolled);
        assert_eq!(pipeline.classify(&audio), TranscriptRole::EnrolledUser);
    }
    #[test]
    fn classify_returns_other_for_very_different_audio() {
        let embedder = FilterbankEmbedder::new();
        let enrolled_audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let other_audio = tone_i16(2_000.0, 1.0, EMBED_SAMPLE_RATE);
        let enrolled = Enrollment {
            label: "matt".into(),
            embedding: SpeakerEmbedding::new(embedder.embed(&enrolled_audio), embedder.model_id()),
        };
        let pipeline = SpeakerPipeline::baseline().with_enrollment(enrolled);
        // Either OtherSpeaker or Unknown is acceptable — the only wrong
        // answer is matching the enrolled user.
        match pipeline.classify(&other_audio) {
            TranscriptRole::OtherSpeaker { .. } | TranscriptRole::Unknown => {}
            TranscriptRole::EnrolledUser => {
                panic!("2 kHz tone should not match a 220 Hz enrollment")
            }
        }
    }
    #[test]
    fn enrollment_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("voiceprint.toml");
        let e = Enrollment {
            label: "matt".into(),
            embedding: SpeakerEmbedding {
                values: vec![0.1, 0.2, 0.3],
                model: "fbank-stats-v1".into(),
            },
        };
        e.save_to(&path).unwrap();
        let loaded = Enrollment::load_from(&path).unwrap();
        assert_eq!(loaded.label, "matt");
        assert_eq!(loaded.embedding.values.len(), 3);
        assert_eq!(loaded.embedding.model, "fbank-stats-v1");
    }
    #[test]
    fn arm_then_capture_produces_voiceprint() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("voiceprint.toml");
        let pipeline = SpeakerPipeline::baseline();
        assert!(!pipeline.has_pending_enrollment());
        pipeline.arm_enrollment(PendingEnrollment::new("matt", &path, 0.72));
        assert!(pipeline.has_pending_enrollment());
        let audio = tone_i16(220.0, 1.0, EMBED_SAMPLE_RATE);
        let outcome = pipeline.capture_enrollment(&audio).unwrap();
        match outcome {
            EnrollmentOutcome::Captured { label, save_path } => {
                assert_eq!(label, "matt");
                assert_eq!(save_path, path);
            }
            EnrollmentOutcome::Failed { reason } => {
                panic!("expected Captured, got Failed: {reason}")
            }
        }
        assert!(!pipeline.has_pending_enrollment());
        assert!(pipeline.is_enrolled());
        assert_eq!(pipeline.classify(&audio), TranscriptRole::EnrolledUser);
    }
    #[test]
    fn capture_with_empty_audio_fails_cleanly() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("voiceprint.toml");
        let pipeline = SpeakerPipeline::baseline();
        pipeline.arm_enrollment(PendingEnrollment::new("matt", &path, 0.72));
        let outcome = pipeline.capture_enrollment(&[]).unwrap();
        assert!(matches!(outcome, EnrollmentOutcome::Failed { .. }));
        // The pending request is consumed even on failure.
        assert!(!pipeline.has_pending_enrollment());
    }
    // Float sine tone at 40% full scale.
    fn tone_f32(freq: f32, duration_s: f32, sample_rate: u32) -> Vec<f32> {
        let total = (duration_s * sample_rate as f32) as usize;
        (0..total)
            .map(|i| {
                let t = i as f32 / sample_rate as f32;
                (2.0 * std::f32::consts::PI * freq * t).sin() * 0.4
            })
            .collect()
    }
    // All-zero samples of the given duration.
    fn silence_f32(duration_s: f32, sample_rate: u32) -> Vec<f32> {
        vec![0.0f32; (duration_s * sample_rate as f32) as usize]
    }
    #[test]
    fn turn_role_maps_to_transcript_role() {
        assert_eq!(
            TurnRole::EnrolledUser.into_transcript_role(),
            TranscriptRole::EnrolledUser
        );
        assert_eq!(
            TurnRole::OtherSpeaker("speaker_3".into()).into_transcript_role(),
            TranscriptRole::OtherSpeaker {
                local_id: "speaker_3".into()
            }
        );
        assert_eq!(
            TurnRole::Unknown.into_transcript_role(),
            TranscriptRole::Unknown
        );
    }
    #[test]
    fn diarizer_splits_two_tones_separated_by_silence() {
        let sr = EMBED_SAMPLE_RATE;
        let mut audio = tone_f32(220.0, 1.0, sr);
        audio.extend(silence_f32(0.5, sr));
        audio.extend(tone_f32(1_500.0, 1.0, sr));
        let embedder: Arc<dyn SpeakerEmbedder> = Arc::new(FilterbankEmbedder::new());
        let diarizer = SilenceSplitDiarizer::new(embedder);
        let turns = diarizer.diarize(&audio, sr).unwrap();
        assert_eq!(turns.len(), 2, "expected 2 turns, got {}", turns.len());
        assert_ne!(
            turns[0].speaker_id, turns[1].speaker_id,
            "different tones should land in different clusters"
        );
    }
    #[test]
    fn classify_turns_empty_without_enrollment() {
        let pipeline = SpeakerPipeline::baseline();
        let audio = tone_f32(220.0, 1.0, EMBED_SAMPLE_RATE);
        let turns = pipeline.classify_turns(&audio, EMBED_SAMPLE_RATE).unwrap();
        assert!(turns.is_empty());
    }
    #[test]
    fn classify_turns_tags_enrolled_and_other() {
        let sr = EMBED_SAMPLE_RATE;
        let embedder = FilterbankEmbedder::new();
        let enrolled_audio = tone_i16(220.0, 1.0, sr);
        let enrolled = Enrollment {
            label: "matt".into(),
            embedding: SpeakerEmbedding::new(embedder.embed(&enrolled_audio), embedder.model_id()),
        };
        let pipeline = SpeakerPipeline::baseline().with_enrollment(enrolled);
        let mut audio = tone_f32(220.0, 1.0, sr);
        audio.extend(silence_f32(0.5, sr));
        audio.extend(tone_f32(1_500.0, 1.0, sr));
        let turns = pipeline.classify_turns(&audio, sr).unwrap();
        assert_eq!(turns.len(), 2, "expected 2 turns, got {}", turns.len());
        assert_eq!(turns[0].role, TurnRole::EnrolledUser);
        assert!(matches!(turns[1].role, TurnRole::OtherSpeaker(_)));
        assert!(turns[0].audio.len() > sr as usize / 2);
        assert!(turns[1].audio.len() > sr as usize / 2);
    }
    #[test]
    fn filter_to_enrolled_user_passthrough_without_enrollment() {
        let pipeline = SpeakerPipeline::baseline();
        let audio = tone_f32(220.0, 1.0, EMBED_SAMPLE_RATE);
        let filtered = pipeline
            .filter_to_enrolled_user(&audio, EMBED_SAMPLE_RATE)
            .unwrap()
            .unwrap();
        assert_eq!(filtered.len(), audio.len());
    }
    #[test]
    fn filter_to_enrolled_user_drops_non_matching() {
        let sr = EMBED_SAMPLE_RATE;
        let embedder = FilterbankEmbedder::new();
        let enrolled_audio = tone_i16(220.0, 1.0, sr);
        let enrolled = Enrollment {
            label: "matt".into(),
            embedding: SpeakerEmbedding::new(embedder.embed(&enrolled_audio), embedder.model_id()),
        };
        let pipeline = SpeakerPipeline::baseline().with_enrollment(enrolled);
        let other_only = tone_f32(1_800.0, 1.0, sr);
        let filtered = pipeline.filter_to_enrolled_user(&other_only, sr).unwrap();
        assert!(
            filtered.is_none(),
            "non-matching-only segment should return None"
        );
    }
}