use anyhow::Result;
use burn::tensor::{Tensor, backend::Backend};
/// Derives a fixed-size speaker embedding from an encoder output by mean-pooling
/// over the frame axis and then L2-normalising the pooled vector.
///
/// `encoder_output` is read as `[batch, n_frames, d_model]`. Only the first batch
/// element contributes to the embedding (assuming the flattened tensor data is
/// row-major); any further batch elements are ignored.
///
/// Returns a vector of `d_model` values. It has unit L2 norm unless the norm of the
/// pooled mean is at most `1e-8` (e.g. an all-zero input), in which case the mean is
/// returned without normalisation to avoid dividing by a near-zero value.
///
/// # Errors
///
/// Returns an error if any dimension of `encoder_output` is zero, or if the tensor
/// data cannot be read back as `f32`.
pub fn extract_speaker_embedding<B: Backend>(encoder_output: Tensor<B, 3>) -> Result<Vec<f32>> {
    let [batch, n_frames, d_model] = encoder_output.dims();

    // Reject empty shapes before reading the tensor: zero frames would make the
    // mean 0/0 = NaN, and an empty batch would leave nothing to index into.
    anyhow::ensure!(
        batch > 0 && n_frames > 0 && d_model > 0,
        "extract_speaker_embedding: encoder output has an empty dimension: [{batch}, {n_frames}, {d_model}]"
    );

    let flat: Vec<f32> = encoder_output
        .into_data()
        .to_vec::<f32>()
        .map_err(|e| anyhow::anyhow!("extract_speaker_embedding: tensor read failed: {:?}", e))?;

    // The first `n_frames * d_model` values are the frames of the first batch element.
    let first_item = flat.get(..n_frames * d_model).ok_or_else(|| {
        anyhow::anyhow!(
            "extract_speaker_embedding: tensor data holds {} values, expected at least {}",
            flat.len(),
            n_frames * d_model
        )
    })?;

    // Sum every frame into the accumulator, one `d_model`-wide row at a time.
    let mut embedding = vec![0.0f32; d_model];
    for frame in first_item.chunks_exact(d_model) {
        for (acc, &value) in embedding.iter_mut().zip(frame) {
            *acc += value;
        }
    }

    // Turn the per-dimension sums into means.
    let scale = n_frames as f32;
    for v in &mut embedding {
        *v /= scale;
    }

    // L2-normalise, skipping (near-)zero vectors so we never divide by ~0.
    let norm: f32 = embedding.iter().map(|v| v * v).sum::<f32>().sqrt();
    if norm > 1e-8 {
        for v in &mut embedding {
            *v /= norm;
        }
    }

    Ok(embedding)
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::tensor::TensorData;
    use burn_flex::Flex;
    use burn_flex::FlexDevice;

    /// A constant non-zero input must yield an embedding whose L2 norm is one.
    #[test]
    fn test_embedding_is_unit_norm() {
        let device = FlexDevice;
        let ones = TensorData::new(vec![1.0f32; 8 * 4], [1, 8, 4]);
        let encoder_output: Tensor<Flex<f32>, 3> = Tensor::from_data(ones, &device);

        let embedding = extract_speaker_embedding(encoder_output).unwrap();

        let sum_of_squares = embedding.iter().map(|x| x * x).sum::<f32>();
        let norm: f32 = sum_of_squares.sqrt();
        assert!((norm - 1.0).abs() < 1e-5, "embedding norm={norm}");
    }

    /// The embedding has exactly as many components as the last tensor dimension.
    #[test]
    fn test_embedding_dimension_matches_d_model() {
        let device = FlexDevice;
        let d_model = 384usize;
        let encoder_output: Tensor<Flex<f32>, 3> = Tensor::zeros([1, 1500, d_model], &device);

        let embedding = extract_speaker_embedding(encoder_output).unwrap();

        assert_eq!(embedding.len(), d_model);
    }

    /// Feeding the same encoder output twice must give the same embedding.
    #[test]
    fn test_identical_encoder_outputs_produce_identical_embeddings() {
        const N_FRAMES: usize = 1500;
        const D_MODEL: usize = 384;

        let device = FlexDevice;
        // Every frame is the ramp 0, 1, ..., D_MODEL - 1.
        let values: Vec<f32> = (0..N_FRAMES * D_MODEL)
            .map(|i| (i % D_MODEL) as f32)
            .collect();
        let first: Tensor<Flex<f32>, 3> =
            Tensor::from_data(TensorData::new(values, [1, N_FRAMES, D_MODEL]), &device);
        let second: Tensor<Flex<f32>, 3> = first.clone();

        let embedding_a = extract_speaker_embedding(first).unwrap();
        let embedding_b = extract_speaker_embedding(second).unwrap();

        // Largest absolute component-wise difference between the two embeddings.
        let max_diff = embedding_a
            .iter()
            .zip(embedding_b.iter())
            .map(|(lhs, rhs)| (lhs - rhs).abs())
            .fold(0.0f32, |worst, diff| worst.max(diff));
        assert!(max_diff < 1e-6);
    }
}