Skip to main content

polyvoice/types/
mod.rs

//! Core types for speaker diarization.
//!
//! These types are shared across the offline pipeline, online diarizer, and
//! evaluation code. Start with [`DiarizationResult`] and [`SpeakerId`].
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::fmt;
9
/// Opaque identifier for a speaker cluster.
///
/// Displayed as `SPEAKER_` followed by the number zero-padded to at least two
/// digits, e.g. `SPEAKER_03` or `SPEAKER_123`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SpeakerId(pub u32);
13
14/// A remapping table produced by [`SpeakerCluster::merge`](crate::cluster::SpeakerCluster::merge).
15///
16/// When two speaker centroids are merged, all indices after the removed one shift
17/// left by one. This struct captures the old → new mapping so that callers can
18/// update any stored [`SpeakerId`]s (e.g. in [`Segment`]s or [`SpeakerTurn`]s).
19#[derive(Debug, Clone, PartialEq)]
20pub struct SpeakerIdRemap {
21    /// Mapping from old SpeakerId to new SpeakerId.
22    mapping: Vec<(SpeakerId, SpeakerId)>,
23}
24
25impl SpeakerIdRemap {
26    /// Create a remap from a raw vector of (old, new) pairs.
27    ///
28    /// { true }
29    /// `fn from_mapping(mapping: Vec<(SpeakerId, SpeakerId)>) -> Option<Self>`
30    /// { ret.is_some() == (mapping.iter().map(|(old, _)| old).collect::<HashSet<_>>().len() == mapping.len()) }
31    pub fn from_mapping(mapping: Vec<(SpeakerId, SpeakerId)>) -> Option<Self> {
32        let mut seen = HashSet::with_capacity(mapping.len());
33        for (old, _) in &mapping {
34            if !seen.insert(old) {
35                return None;
36            }
37        }
38        Some(Self { mapping })
39    }
40
41    /// { true }
42    /// pub fn remap(&self, id: SpeakerId) -> SpeakerId
43    /// { ret == self.mapping.iter().find(|(old, _)| *old == id).map(|(_, new)| *new).unwrap_or(id) }
44    /// Apply the remap to a single [`SpeakerId`].
45    ///
46    /// Returns the new ID if the old ID was remapped, otherwise returns `id` unchanged.
47    pub fn remap(&self, id: SpeakerId) -> SpeakerId {
48        self.mapping
49            .iter()
50            .find(|(old, _)| *old == id)
51            .map(|(_, new)| *new)
52            .unwrap_or(id)
53    }
54
55    /// { true }
56    /// pub fn is_empty(&self) -> bool
57    /// { ret == (self.mapping.len() == 0) }
58    /// Returns true if no IDs were changed.
59    pub fn is_empty(&self) -> bool {
60        self.mapping.is_empty()
61    }
62
63    /// { true }
64    /// pub fn len(&self) -> usize
65    /// { ret == self.mapping.len() }
66    /// Returns the number of remapped IDs.
67    pub fn len(&self) -> usize {
68        self.mapping.len()
69    }
70}
71
72/// Remap speaker IDs in a slice of [`Segment`]s in-place.
73///
74/// { true }
75/// `fn remap_segments(segments: &mut [Segment], remap: &SpeakerIdRemap)`
76/// { segments.iter().all(|s| s.speaker.map_or(true, |spk| remap.remap(spk) == s.speaker.unwrap())) || !remap.is_empty() }
77pub fn remap_segments(segments: &mut [Segment], remap: &SpeakerIdRemap) {
78    for seg in segments.iter_mut() {
79        if let Some(spk) = seg.speaker {
80            seg.speaker = Some(remap.remap(spk));
81        }
82    }
83}
84
85/// Remap speaker IDs in a slice of [`SpeakerTurn`]s in-place.
86///
87/// { true }
88/// `fn remap_turns(turns: &mut [SpeakerTurn], remap: &SpeakerIdRemap)`
89/// { turns.iter().all(|t| remap.remap(t.speaker) == t.speaker) || !remap.is_empty() }
90pub fn remap_turns(turns: &mut [SpeakerTurn], remap: &SpeakerIdRemap) {
91    for turn in turns.iter_mut() {
92        turn.speaker = remap.remap(turn.speaker);
93    }
94}
95
96impl fmt::Display for SpeakerId {
97    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98        write!(f, "SPEAKER_{:02}", self.0)
99    }
100}
101
/// Pre-configured model bundles trading off accuracy and footprint.
///
/// `Mobile` targets weak/embedded ARM CPUs (≤10 MB total models, ≤200 MB peak RAM).
/// `Balanced` targets modern phone/laptop ARM CPUs (≤35 MB total models, ≤400 MB peak RAM).
/// `Custom` defers all model selection to the caller and is used by `PipelineBuilder`
/// when individual `Segmenter`/`Embedder`/`Clusterer` instances are supplied directly.
///
/// Added in v0.6 (M0). See §5.1 of the v1.0 design spec for the full motivation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Profile {
    /// Smallest bundle, aimed at weak/embedded ARM CPUs.
    Mobile,
    /// Larger bundle, aimed at modern phone/laptop ARM CPUs.
    Balanced,
    /// No bundled models; the caller selects every pipeline component.
    Custom,
}
117
118impl Profile {
119    /// Embedding dimension produced by the embedder for this profile.
120    /// Returns 0 for `Custom` (caller must resolve dimension explicitly).
121    pub const fn embedding_dim(self) -> usize {
122        match self {
123            Profile::Mobile => 512,   // CAM++ output dim (voxceleb_CAM++.onnx)
124            Profile::Balanced => 256, // WeSpeaker ResNet34 output dim
125            Profile::Custom => 0,
126        }
127    }
128
129    /// Default cosine similarity threshold tuned to the embedding space of this profile.
130    pub const fn default_threshold(self) -> f32 {
131        match self {
132            Profile::Mobile => 0.55,
133            Profile::Balanced => 0.45,
134            Profile::Custom => 0.5,
135        }
136    }
137
138    /// Stable identifier used in the manifest TOML and CLI flags.
139    pub const fn manifest_id(self) -> &'static str {
140        match self {
141            Profile::Mobile => "mobile",
142            Profile::Balanced => "balanced",
143            Profile::Custom => "custom",
144        }
145    }
146}
147
148impl std::str::FromStr for Profile {
149    type Err = ProfileParseError;
150
151    fn from_str(s: &str) -> Result<Self, Self::Err> {
152        match s.to_ascii_lowercase().as_str() {
153            "mobile" => Ok(Profile::Mobile),
154            "balanced" => Ok(Profile::Balanced),
155            "custom" => Ok(Profile::Custom),
156            other => Err(ProfileParseError(other.to_owned())),
157        }
158    }
159}
160
161/// Returned by `Profile::from_str` when the input doesn't match a known variant.
162#[derive(Debug, Clone)]
163pub struct ProfileParseError(pub String);
164
165impl std::fmt::Display for ProfileParseError {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        write!(
168            f,
169            "unknown profile '{}': expected mobile|balanced|custom",
170            self.0
171        )
172    }
173}
174
175impl std::error::Error for ProfileParseError {}
176
177/// A validated sample rate (8000–192000 Hz).
178///
179/// Invariant: 8000 <= inner <= 192000.
180#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
181pub struct SampleRate(u32);
182
183impl SampleRate {
184    /// { true }
185    /// `pub fn new(rate: u32) -> Option<Self>`
186    /// { ret.is_some() == (8000..=192000).contains(&rate) }
187    /// Create a validated sample rate.
188    ///
189    /// Returns `None` if the rate is outside the supported range (8000–192000 Hz).
190    ///
191    /// ```rust
192    /// use polyvoice::SampleRate;
193    /// let sr = SampleRate::new(16000).expect("valid rate");
194    /// assert_eq!(sr.get(), 16000);
195    /// assert!(SampleRate::new(7000).is_none());
196    /// ```
197    pub fn new(rate: u32) -> Option<Self> {
198        (8000..=192000).contains(&rate).then_some(Self(rate))
199    }
200
201    /// { true }
202    /// pub fn get(&self) -> u32
203    /// { ret == self.0 && 8000 <= ret && ret <= 192000 }
204    /// Return the raw sample rate value in Hz.
205    ///
206    /// ```rust
207    /// use polyvoice::SampleRate;
208    /// let sr = SampleRate::new(44100).unwrap();
209    /// assert_eq!(sr.get(), 44100);
210    /// ```
211    pub fn get(&self) -> u32 {
212        self.0
213    }
214}
215
216impl Default for SampleRate {
217    fn default() -> Self {
218        Self(16000)
219    }
220}
221
222/// A validated confidence score in [0.0, 1.0].
223///
224/// Invariant: 0.0 <= inner <= 1.0.
225#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
226pub struct Confidence(f32);
227
228impl Confidence {
229    /// { true }
230    /// `pub fn new(v: f32) -> Option<Self>`
231    /// { ret.is_some() == (0.0..=1.0).contains(&v) }
232    /// Create a validated confidence score.
233    ///
234    /// Returns `None` if `v` is outside `[0.0, 1.0]`.
235    ///
236    /// ```rust
237    /// use polyvoice::Confidence;
238    /// assert!(Confidence::new(0.75).is_some());
239    /// assert!(Confidence::new(1.5).is_none());
240    /// ```
241    pub fn new(v: f32) -> Option<Self> {
242        (0.0..=1.0).contains(&v).then_some(Self(v))
243    }
244
245    /// { true }
246    /// pub fn get(&self) -> f32
247    /// { ret == self.0 && 0.0 <= ret && ret <= 1.0 }
248    /// Return the raw confidence value.
249    ///
250    /// ```rust
251    /// use polyvoice::Confidence;
252    /// let c = Confidence::new(0.9).unwrap();
253    /// assert_eq!(c.get(), 0.9);
254    /// ```
255    pub fn get(&self) -> f32 {
256        self.0
257    }
258}
259
260impl Default for Confidence {
261    fn default() -> Self {
262        Self(1.0)
263    }
264}
265
266/// A non-negative duration in seconds.
267///
268/// Invariant: inner >= 0.0.
269#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
270pub struct Seconds(f32);
271
272impl Seconds {
273    /// { true }
274    /// `pub fn new(v: f32) -> Option<Self>`
275    /// { ret.is_some() == (v >= 0.0) }
276    /// Create a validated non-negative duration in seconds.
277    ///
278    /// Returns `None` if `v` is negative.
279    ///
280    /// ```rust
281    /// use polyvoice::Seconds;
282    /// assert!(Seconds::new(3.5).is_some());
283    /// assert!(Seconds::new(-1.0).is_none());
284    /// ```
285    pub fn new(v: f32) -> Option<Self> {
286        (v >= 0.0).then_some(Self(v))
287    }
288
289    /// { true }
290    /// pub fn get(&self) -> f32
291    /// { ret == self.0 && ret >= 0.0 }
292    /// Return the raw duration value in seconds.
293    ///
294    /// ```rust
295    /// use polyvoice::Seconds;
296    /// let s = Seconds::new(2.0).unwrap();
297    /// assert_eq!(s.get(), 2.0);
298    /// ```
299    pub fn get(&self) -> f32 {
300        self.0
301    }
302}
303
304impl Default for Seconds {
305    fn default() -> Self {
306        Self(0.0)
307    }
308}
309
/// A time interval in seconds.
///
/// No ordering between `start` and `end` is enforced. [`TimeRange::duration`]
/// reports a reversed range as zero length.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct TimeRange {
    /// Start time in seconds.
    pub start: f64,
    /// End time in seconds.
    pub end: f64,
}
318
319impl TimeRange {
320    /// { true }
321    /// pub fn duration(&self) -> f64
322    /// { ret >= 0.0 }
323    /// Return the duration of this time range in seconds.
324    ///
325    /// ```rust
326    /// use polyvoice::TimeRange;
327    /// let tr = TimeRange { start: 1.0, end: 3.5 };
328    /// assert_eq!(tr.duration(), 2.5);
329    /// ```
330    pub fn duration(&self) -> f64 {
331        (self.end - self.start).max(0.0)
332    }
333}
334
/// A speech segment with a speaker label.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Segment {
    /// Time range of the segment.
    pub time: TimeRange,
    /// Assigned speaker (None if not yet clustered).
    pub speaker: Option<SpeakerId>,
    /// Confidence of the speaker assignment (cosine similarity or posterior).
    ///
    /// Stored as a raw `f32` rather than [`Confidence`], so no range check applies.
    pub confidence: Option<f32>,
}
345
/// A speaker turn: continuous stretch of speech by one speaker.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpeakerTurn {
    /// Speaker holding the turn. Unlike [`Segment::speaker`] this is always set.
    pub speaker: SpeakerId,
    /// Time span of the turn.
    pub time: TimeRange,
    /// Transcript text, if available from an ASR downstream.
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}
355
/// Alignment of a single word to a speaker and time range.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WordAlignment {
    /// The aligned word.
    pub word: String,
    /// Time span of the word.
    pub time: TimeRange,
    /// Speaker the word is attributed to, or `None` if unattributed.
    pub speaker: Option<SpeakerId>,
    /// Alignment confidence as a raw `f32` (not range-checked like [`Confidence`]).
    /// NOTE(review): presumably in [0.0, 1.0]; confirm against the producer.
    pub confidence: f32,
}
364
/// Result of offline diarization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DiarizationResult {
    /// Speech segments with their speaker labels.
    pub segments: Vec<Segment>,
    /// Speaker turns.
    pub turns: Vec<SpeakerTurn>,
    /// Number of speakers found.
    /// NOTE(review): presumably the count of distinct speaker IDs across
    /// `segments`/`turns`; confirm against the pipeline that builds this.
    pub num_speakers: usize,
}
372
/// Settings that govern how embeddings are grouped into speakers.
///
/// The default uses a cosine-similarity threshold of `0.45` and tracks at most
/// 64 speakers.
#[derive(Debug, Clone, Copy)]
pub struct ClusterConfig {
    /// Cosine-similarity threshold used when deciding whether to assign to an
    /// existing speaker.
    pub threshold: f32,
    /// Upper bound on the number of speakers tracked.
    pub max_speakers: usize,
}

impl Default for ClusterConfig {
    fn default() -> Self {
        let threshold = 0.45;
        let max_speakers = 64;
        ClusterConfig {
            threshold,
            max_speakers,
        }
    }
}
390
391/// Configuration for sliding-window embedding extraction.
392#[derive(Debug, Clone, Copy)]
393pub struct WindowConfig {
394    /// Window size for embedding extraction, in seconds.
395    pub window_secs: f32,
396    /// Hop length between consecutive windows, in seconds.
397    pub hop_secs: f32,
398    /// Sample rate expected by the embedding model (usually 16000).
399    pub sample_rate: SampleRate,
400}
401
402impl Default for WindowConfig {
403    fn default() -> Self {
404        Self {
405            window_secs: 1.5,
406            hop_secs: 0.75,
407            sample_rate: SampleRate(16000),
408        }
409    }
410}
411
412impl WindowConfig {
413    /// { self.window_secs >= 0.0 }
414    /// `fn window_samples(&self) -> usize`
415    /// { ret == (self.window_secs * self.sample_rate.get() as f32) as usize }
416    pub fn window_samples(&self) -> usize {
417        (self.window_secs * self.sample_rate.get() as f32) as usize
418    }
419
420    /// { self.hop_secs >= 0.0 }
421    /// `fn hop_samples(&self) -> usize`
422    /// { ret == (self.hop_secs * self.sample_rate.get() as f32) as usize }
423    pub fn hop_samples(&self) -> usize {
424        (self.hop_secs * self.sample_rate.get() as f32) as usize
425    }
426}
427
/// Settings for filtering and merging speech after clustering.
///
/// Defaults: speech shorter than 0.25 s is not considered, and same-speaker
/// segments separated by at most 0.5 s are merged.
#[derive(Debug, Clone, Copy)]
pub struct SpeechFilterConfig {
    /// Minimum speech duration to consider for clustering, in seconds.
    pub min_speech_secs: f32,
    /// Largest gap between same-speaker segments that still gets merged, in seconds.
    pub max_gap_secs: f32,
}

impl Default for SpeechFilterConfig {
    fn default() -> Self {
        let min_speech_secs = 0.25;
        let max_gap_secs = 0.5;
        SpeechFilterConfig {
            min_speech_secs,
            max_gap_secs,
        }
    }
}
445
446/// Configuration shared between online and offline diarizers.
447#[derive(Debug, Clone, Copy)]
448pub struct DiarizationConfig {
449    pub cluster: ClusterConfig,
450    pub window: WindowConfig,
451    pub speech_filter: SpeechFilterConfig,
452    /// Maximum allowed audio duration in seconds (DoS guard).
453    pub max_duration_secs: f32,
454}
455
456impl Default for DiarizationConfig {
457    fn default() -> Self {
458        Self {
459            cluster: ClusterConfig::default(),
460            window: WindowConfig::default(),
461            speech_filter: SpeechFilterConfig::default(),
462            max_duration_secs: 3600.0,
463        }
464    }
465}
466
467impl DiarizationConfig {
468    /// { self.window.window_secs >= 0.0 }
469    /// `fn window_samples(&self) -> usize`
470    /// { ret == self.window.window_samples() }
471    pub fn window_samples(&self) -> usize {
472        self.window.window_samples()
473    }
474
475    /// { self.window.hop_secs >= 0.0 }
476    /// `fn hop_samples(&self) -> usize`
477    /// { ret == self.window.hop_samples() }
478    pub fn hop_samples(&self) -> usize {
479        self.window.hop_samples()
480    }
481}
482
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod speaker_id_remap_tests {
    use super::*;

    #[test]
    fn from_mapping_accepts_unique_old_ids() {
        let mapping = vec![
            (SpeakerId(0), SpeakerId(0)),
            (SpeakerId(1), SpeakerId(0)),
            (SpeakerId(2), SpeakerId(1)),
        ];
        let remap = SpeakerIdRemap::from_mapping(mapping).unwrap();
        assert_eq!(remap.len(), 3);
        assert_eq!(remap.remap(SpeakerId(0)), SpeakerId(0));
        assert_eq!(remap.remap(SpeakerId(1)), SpeakerId(0));
        assert_eq!(remap.remap(SpeakerId(2)), SpeakerId(1));
        assert_eq!(remap.remap(SpeakerId(99)), SpeakerId(99));
    }

    #[test]
    fn from_mapping_rejects_duplicate_old_ids() {
        let mapping = vec![(SpeakerId(0), SpeakerId(1)), (SpeakerId(0), SpeakerId(2))];
        assert!(SpeakerIdRemap::from_mapping(mapping).is_none());
    }

    #[test]
    fn empty_mapping_is_identity() {
        let remap = SpeakerIdRemap::from_mapping(Vec::new()).unwrap();
        assert!(remap.is_empty());
        assert_eq!(remap.len(), 0);
        assert_eq!(remap.remap(SpeakerId(7)), SpeakerId(7));
    }

    #[test]
    fn remap_helpers_rewrite_segments_and_turns() {
        let remap = SpeakerIdRemap::from_mapping(vec![(SpeakerId(2), SpeakerId(1))]).unwrap();
        let time = TimeRange { start: 0.0, end: 1.0 };

        let mut segments = vec![
            Segment { time, speaker: Some(SpeakerId(2)), confidence: None },
            Segment { time, speaker: None, confidence: Some(0.5) },
            Segment { time, speaker: Some(SpeakerId(0)), confidence: None },
        ];
        remap_segments(&mut segments, &remap);
        assert_eq!(segments[0].speaker, Some(SpeakerId(1)));
        assert_eq!(segments[1].speaker, None);
        assert_eq!(segments[2].speaker, Some(SpeakerId(0)));

        let mut turns = vec![SpeakerTurn { speaker: SpeakerId(2), time, text: None }];
        remap_turns(&mut turns, &remap);
        assert_eq!(turns[0].speaker, SpeakerId(1));
    }
}
509
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod profile_tests {
    use super::*;

    #[test]
    fn mobile_profile_uses_cam_pp_dim() {
        assert_eq!(Profile::Mobile.embedding_dim(), 512);
    }

    #[test]
    fn balanced_profile_uses_resnet34_dim() {
        assert_eq!(Profile::Balanced.embedding_dim(), 256);
    }

    #[test]
    fn custom_profile_dim_is_unresolved() {
        assert_eq!(Profile::Custom.embedding_dim(), 0);
    }

    #[test]
    fn default_thresholds_match_spec() {
        // §5.1 of v1.0 design spec
        assert!((Profile::Mobile.default_threshold() - 0.55).abs() < 1e-6);
        assert!((Profile::Balanced.default_threshold() - 0.45).abs() < 1e-6);
        assert!((Profile::Custom.default_threshold() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn manifest_id_for_each_variant() {
        assert_eq!(Profile::Mobile.manifest_id(), "mobile");
        assert_eq!(Profile::Balanced.manifest_id(), "balanced");
        assert_eq!(Profile::Custom.manifest_id(), "custom");
    }

    #[test]
    fn from_str_is_case_insensitive() {
        assert_eq!("mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("Mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("balanced".parse::<Profile>().unwrap(), Profile::Balanced);
        assert_eq!("CUSTOM".parse::<Profile>().unwrap(), Profile::Custom);
        assert!("nope".parse::<Profile>().is_err());
        assert!("".parse::<Profile>().is_err());
    }

    #[test]
    fn from_str_error_lists_expected_values() {
        let err = "nope".parse::<Profile>().unwrap_err();
        assert!(err.to_string().contains("expected mobile|balanced|custom"));
    }

    #[test]
    fn manifest_id_round_trips_through_from_str() {
        for profile in [Profile::Mobile, Profile::Balanced, Profile::Custom] {
            assert_eq!(profile.manifest_id().parse::<Profile>().unwrap(), profile);
        }
    }
}
553
// Formal-verification harnesses; only compiled when the `kani` cfg is set
// (as it is under `cargo kani`).
#[cfg(kani)]
mod kani_proofs;