use super::{VoiceError, VoiceResult};
/// Configuration shared by the speaker-embedding extractors in this module.
///
/// Start from one of the presets ([`EmbeddingConfig::ecapa_tdnn`],
/// [`EmbeddingConfig::x_vector`], [`EmbeddingConfig::resnet`]) or from
/// [`Default`], override fields with struct-update syntax, and call
/// [`EmbeddingConfig::validate`] before relying on the values.
#[derive(Debug, Clone)]
pub struct EmbeddingConfig {
    /// Dimensionality of the embedding vectors.
    pub embedding_dim: usize,
    /// Expected audio sample rate.
    pub sample_rate: u32,
    /// Analysis frame length in milliseconds.
    pub frame_length_ms: u32,
    /// Shift between consecutive frames in milliseconds.
    pub frame_shift_ms: u32,
    /// Number of mel feature channels.
    pub n_mels: usize,
    /// Whether normalized embeddings are requested. Not read by the
    /// extractors defined in this module.
    pub normalize: bool,
}

impl Default for EmbeddingConfig {
    /// Returns the following values:
    /// - `embedding_dim`: 192
    /// - `sample_rate`: 16000
    /// - `frame_length_ms`: 25
    /// - `frame_shift_ms`: 10
    /// - `n_mels`: 80
    /// - `normalize`: `true`
    fn default() -> Self {
        Self {
            embedding_dim: 192,
            sample_rate: 16000,
            frame_length_ms: 25,
            frame_shift_ms: 10,
            n_mels: 80,
            normalize: true,
        }
    }
}

impl EmbeddingConfig {
    /// ECAPA-TDNN preset: 192-dim embeddings from 80 mel channels.
    /// All other fields come from [`Default`].
    #[must_use]
    pub fn ecapa_tdnn() -> Self {
        Self {
            embedding_dim: 192,
            n_mels: 80,
            ..Self::default()
        }
    }

    /// X-Vector preset: 512-dim embeddings from 30 mel channels.
    /// All other fields come from [`Default`].
    #[must_use]
    pub fn x_vector() -> Self {
        Self {
            embedding_dim: 512,
            n_mels: 30,
            ..Self::default()
        }
    }

    /// ResNet preset: 256-dim embeddings from 64 mel channels.
    /// All other fields come from [`Default`].
    #[must_use]
    pub fn resnet() -> Self {
        Self {
            embedding_dim: 256,
            n_mels: 64,
            ..Self::default()
        }
    }

    /// Checks that every size, rate and frame-timing field is non-zero.
    ///
    /// # Errors
    ///
    /// Returns `VoiceError::InvalidConfig("<field> must be > 0")` for the
    /// first zero field. Fields are checked in this order: `embedding_dim`,
    /// `sample_rate`, `frame_length_ms`, `frame_shift_ms`, `n_mels`.
    pub fn validate(&self) -> VoiceResult<()> {
        // A zero frame shift would never advance through the audio, so it is
        // rejected like the other zero-valued sizes.
        let zero_checks: [(&str, bool); 5] = [
            ("embedding_dim", self.embedding_dim == 0),
            ("sample_rate", self.sample_rate == 0),
            ("frame_length_ms", self.frame_length_ms == 0),
            ("frame_shift_ms", self.frame_shift_ms == 0),
            ("n_mels", self.n_mels == 0),
        ];
        if let Some(&(name, _)) = zero_checks.iter().find(|(_, is_zero)| *is_zero) {
            return Err(VoiceError::InvalidConfig(format!("{name} must be > 0")));
        }
        Ok(())
    }
}
/// A speaker embedding: a vector of `f32` components plus a flag that
/// records whether it has been normalized.
#[derive(Debug, Clone)]
pub struct SpeakerEmbedding {
    vector: Vec<f32>,
    // Set by `normalize`. Cleared by `as_mut_slice`, because mutable access
    // can break the unit-norm property.
    normalized: bool,
}

impl SpeakerEmbedding {
    /// Wraps an existing vector. The result is not marked as normalized.
    #[must_use]
    pub fn from_vec(vector: Vec<f32>) -> Self {
        Self {
            vector,
            normalized: false,
        }
    }

    /// Creates an all-zero embedding with `dim` components, not marked as
    /// normalized.
    #[must_use]
    pub fn zeros(dim: usize) -> Self {
        Self {
            vector: vec![0.0; dim],
            normalized: false,
        }
    }

    /// Number of components.
    #[must_use]
    pub fn dim(&self) -> usize {
        self.vector.len()
    }

    /// Borrows the components.
    #[must_use]
    pub fn as_slice(&self) -> &[f32] {
        &self.vector
    }

    /// Mutably borrows the components.
    ///
    /// The caller may change any value, so this clears the normalized flag.
    /// Call [`SpeakerEmbedding::normalize`] again if unit norm is required.
    #[must_use]
    pub fn as_mut_slice(&mut self) -> &mut [f32] {
        self.normalized = false;
        &mut self.vector
    }

    /// Consumes the embedding and returns its components.
    #[must_use]
    pub fn into_vec(self) -> Vec<f32> {
        self.vector
    }

    /// Returns `true` if `normalize` has run and the components have not
    /// been mutably borrowed since then.
    #[must_use]
    pub fn is_normalized(&self) -> bool {
        self.normalized
    }

    /// Scales the components in place to unit L2 norm and sets the
    /// normalized flag.
    ///
    /// The flag is always set, but the components are left unchanged in two
    /// cases: when the L2 norm is not greater than `f32::EPSILON` (for
    /// example, an all-zero vector), and when the norm is NaN. Leaving them
    /// unchanged avoids dividing by (nearly) zero.
    pub fn normalize(&mut self) {
        let norm = self.l2_norm();
        if norm > f32::EPSILON {
            for x in &mut self.vector {
                *x /= norm;
            }
        }
        self.normalized = true;
    }

    /// Euclidean (L2) norm of the components.
    #[must_use]
    pub fn l2_norm(&self) -> f32 {
        self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    /// Dot product with `other`.
    ///
    /// # Errors
    ///
    /// Returns `VoiceError::DimensionMismatch { expected: self.dim(), got: other.dim() }`
    /// when the lengths differ.
    pub fn dot(&self, other: &Self) -> VoiceResult<f32> {
        self.check_same_dim(other)?;
        Ok(self
            .vector
            .iter()
            .zip(other.vector.iter())
            .map(|(a, b)| a * b)
            .sum())
    }

    /// Euclidean distance to `other`, computed by
    /// `crate::nn::functional::euclidean_distance`.
    ///
    /// # Errors
    ///
    /// Returns `VoiceError::DimensionMismatch { expected: self.dim(), got: other.dim() }`
    /// when the lengths differ.
    pub fn euclidean_distance(&self, other: &Self) -> VoiceResult<f32> {
        self.check_same_dim(other)?;
        Ok(crate::nn::functional::euclidean_distance(
            &self.vector,
            &other.vector,
        ))
    }

    /// Shared length guard for the binary operations above.
    fn check_same_dim(&self, other: &Self) -> VoiceResult<()> {
        if self.dim() == other.dim() {
            Ok(())
        } else {
            Err(VoiceError::DimensionMismatch {
                expected: self.dim(),
                got: other.dim(),
            })
        }
    }
}
/// Common interface for models that turn an `f32` audio buffer into a
/// [`SpeakerEmbedding`].
pub trait EmbeddingExtractor {
    /// Computes an embedding for `audio`.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The implementations in this module return
    /// `InvalidAudio` for empty input and `NotImplemented` otherwise.
    fn extract(&self, audio: &[f32]) -> VoiceResult<SpeakerEmbedding>;
    /// Embedding dimensionality this extractor is configured for.
    fn embedding_dim(&self) -> usize;
    /// Sample rate this extractor is configured for.
    fn sample_rate(&self) -> u32;
}
#[derive(Debug)]
pub struct EcapaTdnn {
config: EmbeddingConfig,
}
impl EcapaTdnn {
#[must_use]
pub fn new(config: EmbeddingConfig) -> Self {
Self { config }
}
#[must_use]
pub fn default_config() -> Self {
Self::new(EmbeddingConfig::ecapa_tdnn())
}
}
impl EmbeddingExtractor for EcapaTdnn {
fn extract(&self, audio: &[f32]) -> VoiceResult<SpeakerEmbedding> {
if audio.is_empty() {
return Err(VoiceError::InvalidAudio("empty audio".to_string()));
}
Err(VoiceError::NotImplemented(
"ECAPA-TDNN requires model weights (use from_apr to load)".to_string(),
))
}
fn embedding_dim(&self) -> usize {
self.config.embedding_dim
}
fn sample_rate(&self) -> u32 {
self.config.sample_rate
}
}
/// X-Vector speaker-embedding extractor.
///
/// The only state it holds is its [`EmbeddingConfig`]. As a result,
/// `extract` cannot produce embeddings and reports an error for every input.
#[derive(Debug)]
pub struct XVector {
    config: EmbeddingConfig,
}

impl XVector {
    /// Builds an extractor from an explicit configuration. The
    /// configuration is stored as-is and not validated here.
    #[must_use]
    pub fn new(config: EmbeddingConfig) -> Self {
        XVector { config }
    }

    /// Builds an extractor from the [`EmbeddingConfig::x_vector`] preset.
    #[must_use]
    pub fn default_config() -> Self {
        Self::new(EmbeddingConfig::x_vector())
    }
}

impl EmbeddingExtractor for XVector {
    /// Returns `InvalidAudio` for empty input and `NotImplemented` for any
    /// other input, because this type carries no model weights.
    fn extract(&self, audio: &[f32]) -> VoiceResult<SpeakerEmbedding> {
        match audio {
            [] => Err(VoiceError::InvalidAudio("empty audio".to_string())),
            _ => Err(VoiceError::NotImplemented(
                "X-Vector requires model weights (use from_apr to load)".to_string(),
            )),
        }
    }

    /// Reads `embedding_dim` from the stored configuration.
    fn embedding_dim(&self) -> usize {
        let cfg = &self.config;
        cfg.embedding_dim
    }

    /// Reads `sample_rate` from the stored configuration.
    fn sample_rate(&self) -> u32 {
        let cfg = &self.config;
        cfg.sample_rate
    }
}
/// Cosine similarity between two embeddings.
///
/// Returns `0.0` when the dimensions differ or are zero. Otherwise it
/// delegates to `crate::nn::functional::cosine_similarity_slice`.
#[must_use]
pub fn cosine_similarity(a: &SpeakerEmbedding, b: &SpeakerEmbedding) -> f32 {
    let comparable = a.dim() == b.dim() && a.dim() > 0;
    if comparable {
        crate::nn::functional::cosine_similarity_slice(a.as_slice(), b.as_slice())
    } else {
        0.0
    }
}
#[must_use]
pub fn normalize_embedding(embedding: &SpeakerEmbedding) -> SpeakerEmbedding {
let mut normalized = embedding.clone();
normalized.normalize();
normalized
}
/// Element-wise mean of a non-empty slice of equal-length embeddings.
///
/// Each component is built as a running sum of `value / count`, taken over
/// the embeddings in order. The result is not marked as normalized.
///
/// # Errors
///
/// - `VoiceError::InvalidConfig` if `embeddings` is empty.
/// - `VoiceError::DimensionMismatch` for the first embedding whose length
///   differs from the first embedding's length (`expected` is that length).
pub fn average_embeddings(embeddings: &[SpeakerEmbedding]) -> VoiceResult<SpeakerEmbedding> {
    let first = match embeddings.first() {
        Some(emb) => emb,
        None => {
            return Err(VoiceError::InvalidConfig(
                "cannot average empty list".to_string(),
            ))
        }
    };
    let dim = first.dim();

    if let Some(odd) = embeddings[1..].iter().find(|emb| emb.dim() != dim) {
        return Err(VoiceError::DimensionMismatch {
            expected: dim,
            got: odd.dim(),
        });
    }

    let count = embeddings.len() as f32;
    let mut mean = vec![0.0_f32; dim];
    for emb in embeddings {
        for (acc, &value) in mean.iter_mut().zip(emb.as_slice()) {
            *acc += value / count;
        }
    }
    Ok(SpeakerEmbedding::from_vec(mean))
}
/// Symmetric `n x n` matrix of pairwise [`cosine_similarity`] values.
///
/// Only the upper triangle is computed; each value is mirrored into the
/// lower triangle. The diagonal is set to `1.0` without calling
/// `cosine_similarity`. It stays `1.0` even where `cosine_similarity(e, e)`
/// would return something else, such as `0.0` for zero-dimensional
/// embeddings. An empty input yields an empty matrix.
#[must_use]
pub fn similarity_matrix(embeddings: &[SpeakerEmbedding]) -> Vec<Vec<f32>> {
    let n = embeddings.len();
    let mut matrix = vec![vec![0.0_f32; n]; n];
    for (i, left) in embeddings.iter().enumerate() {
        matrix[i][i] = 1.0;
        for (j, right) in embeddings.iter().enumerate().skip(i + 1) {
            let sim = cosine_similarity(left, right);
            matrix[i][j] = sim;
            matrix[j][i] = sim;
        }
    }
    matrix
}
// Unit tests for this module, loaded from `embedding_tests.rs` via `#[path]`
// and compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "embedding_tests.rs"]
mod tests;
// Contract tests, loaded from `tests_embedding_contract.rs` via `#[path]`
// and compiled only under `cfg(test)`.
#[cfg(test)]
#[path = "tests_embedding_contract.rs"]
mod tests_embedding_contract;