use std::collections::HashMap;
use std::sync::Arc;
use ndarray::{Array1, Array2};
use ordered_float::OrderedFloat;
/// How a sequence of frame-level encoder outputs is reduced to one vector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PoolingStrategy {
    /// Element-wise average over all frames.
    Mean,
    /// Element-wise maximum over all frames.
    Max,
    /// The final frame, unchanged.
    Last,
    /// Reserved for attention-weighted pooling; no attention parameters exist
    /// yet, so this currently behaves like [`PoolingStrategy::Mean`].
    Attention,
    /// Mean followed by max, concatenated (twice the frame width).
    MeanMax,
}

impl Default for PoolingStrategy {
    /// Mean pooling.
    fn default() -> Self {
        PoolingStrategy::Mean
    }
}
/// Settings for [`AcousticWordEmbedding`].
#[derive(Clone, Debug)]
pub struct AcousticEmbeddingConfig {
    /// Hidden size of the encoder built by `AcousticWordEmbedding::new`.
    pub embedding_dim: usize,
    /// Number of values per input frame expected by that encoder.
    pub feature_dim: usize,
    /// How frame-level encoder outputs are combined into one vector.
    pub pooling: PoolingStrategy,
    /// When `true`, `encode` scales its output to unit L2 norm
    /// (outputs whose norm is at or below `1e-8` are returned unscaled).
    pub normalize: bool,
    /// When `Some(d)`, the pooled vector is multiplied by a randomly
    /// initialised projection to `d` dimensions.
    pub text_projection_dim: Option<usize>,
}

impl Default for AcousticEmbeddingConfig {
    /// 128 hidden units, 40 input features, mean pooling, unit-norm output,
    /// and no text projection.
    fn default() -> Self {
        let pooling = PoolingStrategy::Mean;
        AcousticEmbeddingConfig {
            text_projection_dim: None,
            normalize: true,
            pooling,
            feature_dim: 40,
            embedding_dim: 128,
        }
    }
}
/// Frame-level acoustic encoder used by [`AcousticWordEmbedding`].
///
/// The `Send + Sync` bound allows one encoder to be shared across threads
/// behind an `Arc`.
pub trait AcousticEncoder: Send + Sync {
    /// Number of values expected in each input frame.
    fn feature_dim(&self) -> usize;

    /// Number of values produced for each encoded frame.
    fn hidden_dim(&self) -> usize;

    /// Encodes every frame in `frames`. Implementations should return one
    /// output vector, of length `hidden_dim()`, per input frame.
    fn encode_frames(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>>;
}
/// A single affine layer (`frame · weights + bias`) used as the default encoder.
#[derive(Debug, Clone)]
pub struct LinearEncoder {
    /// Matrix whose rows index input features and whose columns index hidden units.
    weights: Array2<f32>,
    /// Offset added to every encoded frame.
    bias: Array1<f32>,
}

impl LinearEncoder {
    /// Builds an encoder with random weights and a zero bias.
    ///
    /// Each weight is drawn uniformly from `[-limit, limit)`, where
    /// `limit = sqrt(2 / (feature_dim + hidden_dim))`.
    pub fn new(feature_dim: usize, hidden_dim: usize) -> Self {
        let fan_total = (feature_dim + hidden_dim) as f32;
        let limit = (2.0 / fan_total).sqrt();
        let sample = |_: (usize, usize)| (rand::random::<f32>() - 0.5) * 2.0 * limit;
        Self {
            weights: Array2::from_shape_fn((feature_dim, hidden_dim), sample),
            bias: Array1::zeros(hidden_dim),
        }
    }

    /// Wraps existing parameters without validating their shapes.
    ///
    /// `weights` is read as `(feature_dim, hidden_dim)`: its row count is
    /// reported as the feature size and its column count as the hidden size.
    /// `bias` is added to each encoded frame, so it should have `hidden_dim`
    /// elements.
    pub fn from_weights(weights: Array2<f32>, bias: Array1<f32>) -> Self {
        LinearEncoder { weights, bias }
    }
}
impl AcousticEncoder for LinearEncoder {
fn encode_frames(&self, frames: &[Vec<f32>]) -> Vec<Vec<f32>> {
frames
.iter()
.map(|frame| {
let input = Array1::from_vec(frame.clone());
let output = input.dot(&self.weights) + &self.bias;
output.to_vec()
})
.collect()
}
fn hidden_dim(&self) -> usize {
self.weights.ncols()
}
fn feature_dim(&self) -> usize {
self.weights.nrows()
}
}
/// Pools frame-level encoder outputs into fixed-size word embeddings and
/// keeps a searchable index of labelled embeddings.
pub struct AcousticWordEmbedding {
    /// Frame encoder, shared behind an `Arc`.
    encoder: Arc<dyn AcousticEncoder>,
    /// Settings supplied at construction.
    config: AcousticEmbeddingConfig,
    /// Optional random projection applied after pooling.
    text_projection: Option<Array2<f32>>,
    /// Most recently added embedding for each word.
    word_cache: HashMap<String, Array1<f32>>,
    /// Every added `(word, embedding)` pair, in insertion order; queries scan it linearly.
    word_index: Vec<(String, Array1<f32>)>,
}
impl AcousticWordEmbedding {
pub fn new(config: AcousticEmbeddingConfig) -> Self {
let encoder = Arc::new(LinearEncoder::new(config.feature_dim, config.embedding_dim));
Self::with_encoder(encoder, config)
}
pub fn with_encoder(
encoder: Arc<dyn AcousticEncoder>,
config: AcousticEmbeddingConfig,
) -> Self {
let text_projection = config.text_projection_dim.map(|text_dim| {
let hidden = encoder.hidden_dim();
let scale = (2.0 / (hidden + text_dim) as f32).sqrt();
Array2::from_shape_fn((hidden, text_dim), |_| {
(rand::random::<f32>() - 0.5) * 2.0 * scale
})
});
Self {
encoder,
config,
text_projection,
word_cache: HashMap::new(),
word_index: Vec::new(),
}
}
pub fn config(&self) -> &AcousticEmbeddingConfig {
&self.config
}
pub fn embedding_dim(&self) -> usize {
if self.text_projection.is_some() {
self.config
.text_projection_dim
.unwrap_or(self.encoder.hidden_dim())
} else {
self.encoder.hidden_dim()
}
}
pub fn encode(&self, frames: &[Vec<f32>]) -> Vec<f32> {
if frames.is_empty() {
return vec![0.0; self.embedding_dim()];
}
let encoded = self.encoder.encode_frames(frames);
let pooled = self.apply_pooling(&encoded);
let projected = if let Some(ref proj) = self.text_projection {
pooled.dot(proj)
} else {
pooled
};
if self.config.normalize {
let norm = projected.dot(&projected).sqrt();
if norm > 1e-8 {
(projected / norm).to_vec()
} else {
projected.to_vec()
}
} else {
projected.to_vec()
}
}
fn apply_pooling(&self, frames: &[Vec<f32>]) -> Array1<f32> {
if frames.is_empty() {
return Array1::zeros(self.encoder.hidden_dim());
}
let hidden_dim = frames[0].len();
let num_frames = frames.len();
match self.config.pooling {
PoolingStrategy::Mean => {
let mut sum = Array1::zeros(hidden_dim);
for frame in frames {
sum += &Array1::from_vec(frame.clone());
}
sum / num_frames as f32
}
PoolingStrategy::Max => {
let mut max = Array1::from_vec(frames[0].clone());
for frame in frames.iter().skip(1) {
for (i, &v) in frame.iter().enumerate() {
if v > max[i] {
max[i] = v;
}
}
}
max
}
PoolingStrategy::Last => Array1::from_vec(frames[num_frames - 1].clone()),
PoolingStrategy::Attention => {
let mut sum = Array1::zeros(hidden_dim);
for frame in frames {
sum += &Array1::from_vec(frame.clone());
}
sum / num_frames as f32
}
PoolingStrategy::MeanMax => {
let mut mean = Array1::zeros(hidden_dim);
let mut max = Array1::from_vec(frames[0].clone());
for frame in frames {
let arr = Array1::from_vec(frame.clone());
mean += &arr;
for (i, &v) in frame.iter().enumerate() {
if v > max[i] {
max[i] = v;
}
}
}
mean /= num_frames as f32;
let mut concat = Vec::with_capacity(hidden_dim * 2);
concat.extend(mean.iter().copied());
concat.extend(max.iter().copied());
Array1::from_vec(concat)
}
}
}
pub fn audio_similarity(&self, audio1: &[Vec<f32>], audio2: &[Vec<f32>]) -> f64 {
let emb1 = self.encode(audio1);
let emb2 = self.encode(audio2);
self.cosine_similarity(&emb1, &emb2)
}
fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f64 {
if a.len() != b.len() {
return 0.0;
}
let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm_a < 1e-8 || norm_b < 1e-8 {
0.0
} else {
(dot / (norm_a * norm_b)) as f64
}
}
pub fn add_word(&mut self, word: &str, frames: &[Vec<f32>]) {
let embedding = Array1::from_vec(self.encode(frames));
self.word_cache.insert(word.to_string(), embedding.clone());
self.word_index.push((word.to_string(), embedding));
}
pub fn add_word_embedding(&mut self, word: &str, embedding: Vec<f32>) {
let arr = Array1::from_vec(embedding);
self.word_cache.insert(word.to_string(), arr.clone());
self.word_index.push((word.to_string(), arr));
}
pub fn get_word_embedding(&self, word: &str) -> Option<&Array1<f32>> {
self.word_cache.get(word)
}
pub fn query_by_example(&self, audio: &[Vec<f32>], k: usize) -> Vec<(String, f64)> {
let query_emb = self.encode(audio);
self.query_by_embedding(&query_emb, k)
}
pub fn query_by_embedding(&self, query_emb: &[f32], k: usize) -> Vec<(String, f64)> {
let mut scores: Vec<(String, f64)> = self
.word_index
.iter()
.map(|(word, emb)| {
let sim = self.cosine_similarity(query_emb, emb.as_slice().unwrap());
(word.clone(), sim)
})
.collect();
scores.sort_by(|a, b| OrderedFloat(b.1).cmp(&OrderedFloat(a.1)));
scores.into_iter().take(k).collect()
}
pub fn index_size(&self) -> usize {
self.word_index.len()
}
pub fn clear_index(&mut self) {
self.word_cache.clear();
self.word_index.clear();
}
pub fn all_pairwise_similarities(&self) -> Array2<f32> {
let n = self.word_index.len();
let mut sims = Array2::zeros((n, n));
for i in 0..n {
for j in i..n {
let sim = self.cosine_similarity(
self.word_index[i].1.as_slice().unwrap(),
self.word_index[j].1.as_slice().unwrap(),
) as f32;
sims[[i, j]] = sim;
sims[[j, i]] = sim;
}
}
sims
}
}
/// Summary statistics over an embedder's word index.
#[derive(Debug, Default, Clone)]
pub struct AcousticEmbeddingStats {
    /// Number of index entries.
    pub num_words: usize,
    /// Not currently tracked; reported as `0`.
    pub total_frames: usize,
    /// Mean L2 norm of the indexed embeddings.
    pub avg_norm: f64,
    /// Average pairwise cosine similarity as reported by `compute_stats`.
    pub avg_similarity: f64,
}
impl AcousticWordEmbedding {
pub fn compute_stats(&self) -> AcousticEmbeddingStats {
let num_words = self.word_index.len();
if num_words == 0 {
return AcousticEmbeddingStats::default();
}
let avg_norm: f64 = self
.word_index
.iter()
.map(|(_, emb)| emb.dot(emb).sqrt() as f64)
.sum::<f64>()
/ num_words as f64;
let avg_similarity = if num_words <= 1000 {
let sims = self.all_pairwise_similarities();
let total: f32 = sims.sum();
let count = (num_words * num_words) as f32;
(total / count) as f64
} else {
0.0
};
AcousticEmbeddingStats {
num_words,
total_frames: 0, avg_norm,
avg_similarity,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_encoder() {
        let encoder = LinearEncoder::new(40, 128);
        assert_eq!(encoder.feature_dim(), 40);
        assert_eq!(encoder.hidden_dim(), 128);
        let frames = vec![vec![0.0f32; 40]; 10];
        let encoded = encoder.encode_frames(&frames);
        assert_eq!(encoded.len(), 10);
        assert_eq!(encoded[0].len(), 128);
        // Zero input through a zero-bias encoder must give a zero output.
        assert!(encoded.iter().flatten().all(|&v| v == 0.0));
    }

    #[test]
    fn test_acoustic_word_embedding_encode() {
        let config = AcousticEmbeddingConfig {
            embedding_dim: 64,
            feature_dim: 40,
            pooling: PoolingStrategy::Mean,
            normalize: true,
            text_projection_dim: None,
        };
        let awe = AcousticWordEmbedding::new(config);
        let frames = vec![vec![1.0f32; 40]; 20];
        let embedding = awe.encode(&frames);
        assert_eq!(embedding.len(), 64);
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.01);
    }

    #[test]
    fn test_pooling_strategies() {
        let frames = vec![vec![1.0f32; 128], vec![2.0f32; 128], vec![3.0f32; 128]];
        let pool_with = |pooling: PoolingStrategy| {
            let config = AcousticEmbeddingConfig {
                pooling,
                ..Default::default()
            };
            AcousticWordEmbedding::new(config).apply_pooling(&frames)
        };

        let mean = pool_with(PoolingStrategy::Mean);
        assert!((mean[0] - 2.0).abs() < 0.01);

        let max = pool_with(PoolingStrategy::Max);
        assert!((max[0] - 3.0).abs() < 0.01);

        let last = pool_with(PoolingStrategy::Last);
        assert!((last[0] - 3.0).abs() < 0.01);

        // Attention has no learned weights yet and pools like Mean.
        let attention = pool_with(PoolingStrategy::Attention);
        assert!((attention[0] - 2.0).abs() < 0.01);
    }

    #[test]
    fn test_audio_similarity() {
        let config = AcousticEmbeddingConfig::default();
        let awe = AcousticWordEmbedding::new(config);
        let frames1 = vec![vec![1.0f32; 40]; 10];
        let sim_self = awe.audio_similarity(&frames1, &frames1);
        assert!(sim_self > 0.99);
        // With a zero-bias linear encoder, negated input gives a negated embedding.
        let frames2 = vec![vec![-1.0f32; 40]; 10];
        let sim_diff = awe.audio_similarity(&frames1, &frames2);
        assert!(sim_diff < sim_self);
        assert!(sim_diff < -0.99);
    }

    #[test]
    fn test_query_by_example() {
        let config = AcousticEmbeddingConfig::default();
        let mut awe = AcousticWordEmbedding::new(config);
        // The distractor must not be a scaled copy of "hello": with a
        // zero-bias linear encoder and normalisation, a scaled copy embeds
        // identically, and the ranking would be decided only by tie-breaking.
        let alternating: Vec<f32> = (0..40)
            .map(|i| if i % 2 == 0 { 1.0 } else { -1.0 })
            .collect();
        awe.add_word("hello", &vec![vec![1.0f32; 40]; 10]);
        awe.add_word("world", &vec![alternating; 10]);
        awe.add_word("foo", &vec![vec![-1.0f32; 40]; 10]);
        assert_eq!(awe.index_size(), 3);
        let query = vec![vec![1.0f32; 40]; 10];
        let results = awe.query_by_example(&query, 2);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "hello");
        assert!((results[0].1 - 1.0).abs() < 1e-4);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn test_empty_audio() {
        let config = AcousticEmbeddingConfig::default();
        let awe = AcousticWordEmbedding::new(config);
        let embedding = awe.encode(&[]);
        assert_eq!(embedding.len(), awe.embedding_dim());
        assert!(embedding.iter().all(|&v| v == 0.0));
    }

    #[test]
    fn test_word_embedding_cache() {
        let config = AcousticEmbeddingConfig::default();
        let mut awe = AcousticWordEmbedding::new(config);
        awe.add_word("test", &vec![vec![1.0f32; 40]; 5]);
        let emb = awe.get_word_embedding("test");
        assert!(emb.is_some());
        assert_eq!(emb.map(|e| e.len()), Some(awe.embedding_dim()));
        let emb_none = awe.get_word_embedding("missing");
        assert!(emb_none.is_none());
    }

    #[test]
    fn test_compute_stats() {
        let config = AcousticEmbeddingConfig::default();
        let mut awe = AcousticWordEmbedding::new(config);
        let stats_empty = awe.compute_stats();
        assert_eq!(stats_empty.num_words, 0);
        awe.add_word("a", &vec![vec![1.0f32; 40]; 5]);
        awe.add_word("b", &vec![vec![2.0f32; 40]; 5]);
        let stats = awe.compute_stats();
        assert_eq!(stats.num_words, 2);
        assert!(stats.avg_norm > 0.0);
        // Embeddings are normalised, and "b" is a scaled copy of "a", so both
        // unit vectors coincide.
        assert!((stats.avg_norm - 1.0).abs() < 1e-4);
        assert!((stats.avg_similarity - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_text_projection() {
        let config = AcousticEmbeddingConfig {
            embedding_dim: 64,
            feature_dim: 40,
            text_projection_dim: Some(100),
            ..Default::default()
        };
        let awe = AcousticWordEmbedding::new(config);
        let frames = vec![vec![1.0f32; 40]; 10];
        let embedding = awe.encode(&frames);
        assert_eq!(embedding.len(), 100);
    }
}