1use crate::embedding::Embedding;
2use std::cmp::Ordering;
3use std::collections::HashMap;
4
5#[derive(Debug, Clone)]
6pub struct EmbeddingManager {
7 max_speakers: usize,
8 speakers: HashMap<usize, Embedding>,
9 next_speaker_id: usize,
10}
11
12impl EmbeddingManager {
13 pub fn new(max_speakers: usize) -> Self {
14 Self {
15 max_speakers,
16 speakers: HashMap::new(),
17 next_speaker_id: 1,
18 }
19 }
20
21 fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
22 let (dot, norm_a, norm_b) = a
23 .iter()
24 .zip(b.iter())
25 .fold((0.0, 0.0, 0.0), |(dot, norm_a, norm_b), (x, y)| {
26 (dot + x * y, norm_a + x * x, norm_b + y * y)
27 });
28
29 if norm_a == 0.0 || norm_b == 0.0 {
30 return 0.0;
31 }
32
33 dot / (norm_a.sqrt() * norm_b.sqrt())
34 }
35
36 pub fn upsert(&mut self, embedding: &Embedding, threshold: f32) -> Option<usize> {
39 let (best_speaker_id, _best_similarity) = self.speakers.iter().fold(
40 (None, threshold),
41 |(best_id, best_similarity), (&speaker_id, speaker_embedding)| {
42 let similarity =
43 Self::cosine_similarity(embedding.as_slice(), speaker_embedding.as_slice());
44 if similarity > best_similarity {
45 (Some(speaker_id), similarity)
46 } else {
47 (best_id, best_similarity)
48 }
49 },
50 );
51
52 match best_speaker_id {
53 Some(id) => Some(id),
54 None => self.add_speaker(embedding),
55 }
56 }
57
58 pub fn best_match(&self, embedding: &Embedding) -> Option<usize> {
59 self.speakers
60 .iter()
61 .map(|(&speaker_id, speaker_embedding)| {
62 (
63 speaker_id,
64 Self::cosine_similarity(embedding.as_slice(), speaker_embedding.as_slice()),
65 )
66 })
67 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
68 .map(|(speaker_id, _)| speaker_id)
69 }
70
71 fn add_speaker(&mut self, embedding: &Embedding) -> Option<usize> {
72 if self.is_full() {
73 return None;
74 }
75 let speaker_id = self.next_speaker_id;
76 self.speakers.insert(speaker_id, embedding.clone());
77 self.next_speaker_id += 1;
78 Some(speaker_id)
79 }
80
81 pub fn speaker_count(&self) -> usize {
82 self.speakers.len()
83 }
84
85 pub fn is_full(&self) -> bool {
86 self.speakers.len() >= self.max_speakers
87 }
88
89 pub fn speakers(&self) -> &HashMap<usize, Embedding> {
90 &self.speakers
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_vectors_do_not_produce_nan() {
        let a = Embedding::new(vec![0.0, 0.0]);
        let b = Embedding::new(vec![0.0, 0.0]);
        assert_eq!(
            EmbeddingManager::cosine_similarity(a.as_slice(), b.as_slice()),
            0.0
        );
    }

    #[test]
    fn upsert_adds_until_cap() {
        let mut manager = EmbeddingManager::new(1);
        let first = manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5);
        assert_eq!(first, Some(1));

        // Orthogonal to the only speaker and the manager is full: no match, no room.
        let second = manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5);
        assert!(second.is_none());
    }

    #[test]
    fn upsert_reuses_speaker_above_threshold() {
        let mut manager = EmbeddingManager::new(2);
        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5), Some(1));
        // Similarity is about 0.994, above 0.5, so this is the same speaker.
        assert_eq!(manager.upsert(&Embedding::new(vec![0.9, 0.1]), 0.5), Some(1));
        assert_eq!(manager.speaker_count(), 1);
    }

    #[test]
    fn upsert_registers_new_speaker_below_threshold() {
        let mut manager = EmbeddingManager::new(2);
        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5), Some(1));
        // Orthogonal vectors have similarity 0.0, which does not exceed 0.5.
        assert_eq!(manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5), Some(2));
        assert!(manager.is_full());
    }

    #[test]
    fn best_match_picks_most_similar_speaker() {
        let mut manager = EmbeddingManager::new(2);
        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5), Some(1));
        assert_eq!(manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5), Some(2));
        assert_eq!(
            manager.best_match(&Embedding::new(vec![0.1, 0.9])),
            Some(2)
        );
    }

    #[test]
    fn best_match_on_empty_manager_is_none() {
        let manager = EmbeddingManager::new(3);
        assert_eq!(manager.best_match(&Embedding::new(vec![1.0, 0.0])), None);
    }
}
118}