// pyannote_rs/identify.rs

1use crate::embedding::Embedding;
2use std::cmp::Ordering;
3use std::collections::HashMap;
4
5#[derive(Debug, Clone)]
6pub struct EmbeddingManager {
7    max_speakers: usize,
8    speakers: HashMap<usize, Embedding>,
9    next_speaker_id: usize,
10}
11
12impl EmbeddingManager {
13    pub fn new(max_speakers: usize) -> Self {
14        Self {
15            max_speakers,
16            speakers: HashMap::new(),
17            next_speaker_id: 1,
18        }
19    }
20
21    fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
22        let (dot, norm_a, norm_b) = a
23            .iter()
24            .zip(b.iter())
25            .fold((0.0, 0.0, 0.0), |(dot, norm_a, norm_b), (x, y)| {
26                (dot + x * y, norm_a + x * x, norm_b + y * y)
27            });
28
29        if norm_a == 0.0 || norm_b == 0.0 {
30            return 0.0;
31        }
32
33        dot / (norm_a.sqrt() * norm_b.sqrt())
34    }
35
36    /// Try to match a speaker; if none is found above `threshold`, register a new speaker
37    /// as long as capacity allows.
38    pub fn upsert(&mut self, embedding: &Embedding, threshold: f32) -> Option<usize> {
39        let (best_speaker_id, _best_similarity) = self.speakers.iter().fold(
40            (None, threshold),
41            |(best_id, best_similarity), (&speaker_id, speaker_embedding)| {
42                let similarity =
43                    Self::cosine_similarity(embedding.as_slice(), speaker_embedding.as_slice());
44                if similarity > best_similarity {
45                    (Some(speaker_id), similarity)
46                } else {
47                    (best_id, best_similarity)
48                }
49            },
50        );
51
52        match best_speaker_id {
53            Some(id) => Some(id),
54            None => self.add_speaker(embedding),
55        }
56    }
57
58    pub fn best_match(&self, embedding: &Embedding) -> Option<usize> {
59        self.speakers
60            .iter()
61            .map(|(&speaker_id, speaker_embedding)| {
62                (
63                    speaker_id,
64                    Self::cosine_similarity(embedding.as_slice(), speaker_embedding.as_slice()),
65                )
66            })
67            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal))
68            .map(|(speaker_id, _)| speaker_id)
69    }
70
71    fn add_speaker(&mut self, embedding: &Embedding) -> Option<usize> {
72        if self.is_full() {
73            return None;
74        }
75        let speaker_id = self.next_speaker_id;
76        self.speakers.insert(speaker_id, embedding.clone());
77        self.next_speaker_id += 1;
78        Some(speaker_id)
79    }
80
81    pub fn speaker_count(&self) -> usize {
82        self.speakers.len()
83    }
84
85    pub fn is_full(&self) -> bool {
86        self.speakers.len() >= self.max_speakers
87    }
88
89    pub fn speakers(&self) -> &HashMap<usize, Embedding> {
90        &self.speakers
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zero_vectors_do_not_produce_nan() {
        let a = Embedding::new(vec![0.0, 0.0]);
        let b = Embedding::new(vec![0.0, 0.0]);
        assert_eq!(
            EmbeddingManager::cosine_similarity(a.as_slice(), b.as_slice()),
            0.0
        );
    }

    #[test]
    fn cosine_similarity_of_orthogonal_and_parallel_vectors() {
        assert_eq!(
            EmbeddingManager::cosine_similarity(&[1.0, 0.0], &[0.0, 1.0]),
            0.0
        );
        assert_eq!(
            EmbeddingManager::cosine_similarity(&[1.0, 0.0], &[3.0, 0.0]),
            1.0
        );
    }

    #[test]
    fn upsert_adds_until_cap() {
        let mut manager = EmbeddingManager::new(1);
        let first = manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5);
        assert_eq!(first, Some(1));

        // Second unique embedding should be rejected because max_speakers is 1.
        let second = manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5);
        assert!(second.is_none());
        assert_eq!(manager.speaker_count(), 1);
        assert!(manager.is_full());
    }

    #[test]
    fn upsert_reuses_matching_speaker() {
        let mut manager = EmbeddingManager::new(2);
        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5), Some(1));

        // Similarity is about 0.995 > 0.5, so this is the same speaker.
        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.1]), 0.5), Some(1));
        assert_eq!(manager.speaker_count(), 1);

        // Orthogonal embedding: similarity 0.0 is not above 0.5, so a new speaker.
        assert_eq!(manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5), Some(2));
        assert!(manager.is_full());
    }

    #[test]
    fn best_match_picks_most_similar_speaker() {
        let mut manager = EmbeddingManager::new(4);
        assert_eq!(manager.best_match(&Embedding::new(vec![1.0, 0.0])), None);

        assert_eq!(manager.upsert(&Embedding::new(vec![1.0, 0.0]), 0.5), Some(1));
        assert_eq!(manager.upsert(&Embedding::new(vec![0.0, 1.0]), 0.5), Some(2));

        // Closer to speaker 2; best_match applies no threshold.
        assert_eq!(manager.best_match(&Embedding::new(vec![0.2, 1.0])), Some(2));
        assert_eq!(manager.speaker_count(), 2);
    }
}