polyvoice/types/mod.rs
1//! Core types for speaker diarization.
2//!
3//! These types are shared across the offline pipeline, online diarizer, and
4//! evaluation code. Start with [`DiarizationResult`] and [`SpeakerId`].
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashSet;
8use std::fmt;
9
/// Opaque identifier for a speaker cluster.
///
/// The wrapped `u32` is public, so IDs can be constructed directly. When
/// clusters are merged, IDs stored elsewhere may need rewriting through a
/// [`SpeakerIdRemap`]. Formats for display as `SPEAKER_NN` (zero-padded to at
/// least two digits).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SpeakerId(pub u32);
13
14/// A remapping table produced by [`SpeakerCluster::merge`](crate::cluster::SpeakerCluster::merge).
15///
16/// When two speaker centroids are merged, all indices after the removed one shift
17/// left by one. This struct captures the old → new mapping so that callers can
18/// update any stored [`SpeakerId`]s (e.g. in [`Segment`]s or [`SpeakerTurn`]s).
19#[derive(Debug, Clone, PartialEq)]
20pub struct SpeakerIdRemap {
21 /// Mapping from old SpeakerId to new SpeakerId.
22 mapping: Vec<(SpeakerId, SpeakerId)>,
23}
24
25impl SpeakerIdRemap {
26 /// Create a remap from a raw vector of (old, new) pairs.
27 ///
28 /// { true }
29 /// `fn from_mapping(mapping: Vec<(SpeakerId, SpeakerId)>) -> Option<Self>`
30 /// { ret.is_some() == (mapping.iter().map(|(old, _)| old).collect::<HashSet<_>>().len() == mapping.len()) }
31 pub fn from_mapping(mapping: Vec<(SpeakerId, SpeakerId)>) -> Option<Self> {
32 let mut seen = HashSet::with_capacity(mapping.len());
33 for (old, _) in &mapping {
34 if !seen.insert(old) {
35 return None;
36 }
37 }
38 Some(Self { mapping })
39 }
40
41 /// { true }
42 /// pub fn remap(&self, id: SpeakerId) -> SpeakerId
43 /// { ret == self.mapping.iter().find(|(old, _)| *old == id).map(|(_, new)| *new).unwrap_or(id) }
44 /// Apply the remap to a single [`SpeakerId`].
45 ///
46 /// Returns the new ID if the old ID was remapped, otherwise returns `id` unchanged.
47 pub fn remap(&self, id: SpeakerId) -> SpeakerId {
48 self.mapping
49 .iter()
50 .find(|(old, _)| *old == id)
51 .map(|(_, new)| *new)
52 .unwrap_or(id)
53 }
54
55 /// { true }
56 /// pub fn is_empty(&self) -> bool
57 /// { ret == (self.mapping.len() == 0) }
58 /// Returns true if no IDs were changed.
59 pub fn is_empty(&self) -> bool {
60 self.mapping.is_empty()
61 }
62
63 /// { true }
64 /// pub fn len(&self) -> usize
65 /// { ret == self.mapping.len() }
66 /// Returns the number of remapped IDs.
67 pub fn len(&self) -> usize {
68 self.mapping.len()
69 }
70}
71
72/// Remap speaker IDs in a slice of [`Segment`]s in-place.
73///
74/// { true }
75/// `fn remap_segments(segments: &mut [Segment], remap: &SpeakerIdRemap)`
76/// { segments.iter().all(|s| s.speaker.map_or(true, |spk| remap.remap(spk) == s.speaker.unwrap())) || !remap.is_empty() }
77pub fn remap_segments(segments: &mut [Segment], remap: &SpeakerIdRemap) {
78 for seg in segments.iter_mut() {
79 if let Some(spk) = seg.speaker {
80 seg.speaker = Some(remap.remap(spk));
81 }
82 }
83}
84
85/// Remap speaker IDs in a slice of [`SpeakerTurn`]s in-place.
86///
87/// { true }
88/// `fn remap_turns(turns: &mut [SpeakerTurn], remap: &SpeakerIdRemap)`
89/// { turns.iter().all(|t| remap.remap(t.speaker) == t.speaker) || !remap.is_empty() }
90pub fn remap_turns(turns: &mut [SpeakerTurn], remap: &SpeakerIdRemap) {
91 for turn in turns.iter_mut() {
92 turn.speaker = remap.remap(turn.speaker);
93 }
94}
95
96impl fmt::Display for SpeakerId {
97 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
98 write!(f, "SPEAKER_{:02}", self.0)
99 }
100}
101
/// Pre-configured model bundles trading off accuracy and footprint.
///
/// - `Mobile` targets weak/embedded ARM CPUs (≤10 MB total models, ≤200 MB peak RAM).
/// - `Balanced` targets modern phone/laptop ARM CPUs (≤35 MB total models, ≤400 MB peak RAM).
/// - `Custom` defers all model selection to the caller and is used by `PipelineBuilder`
///   when individual `Segmenter`/`Embedder`/`Clusterer` instances are supplied directly.
///
/// Added in v0.6 (M0). See §5.1 of the v1.0 design spec for the full motivation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Profile {
    /// Smallest bundle, for weak/embedded ARM CPUs.
    Mobile,
    /// Larger bundle, for modern phone/laptop ARM CPUs.
    Balanced,
    /// No bundled models; the caller supplies every pipeline component.
    Custom,
}
117
118impl Profile {
119 /// Embedding dimension produced by the embedder for this profile.
120 /// Returns 0 for `Custom` (caller must resolve dimension explicitly).
121 pub const fn embedding_dim(self) -> usize {
122 match self {
123 Profile::Mobile => 512, // CAM++ output dim (voxceleb_CAM++.onnx)
124 Profile::Balanced => 256, // WeSpeaker ResNet34 output dim
125 Profile::Custom => 0,
126 }
127 }
128
129 /// Default cosine similarity threshold tuned to the embedding space of this profile.
130 pub const fn default_threshold(self) -> f32 {
131 match self {
132 Profile::Mobile => 0.55,
133 Profile::Balanced => 0.45,
134 Profile::Custom => 0.5,
135 }
136 }
137
138 /// Stable identifier used in the manifest TOML and CLI flags.
139 pub const fn manifest_id(self) -> &'static str {
140 match self {
141 Profile::Mobile => "mobile",
142 Profile::Balanced => "balanced",
143 Profile::Custom => "custom",
144 }
145 }
146}
147
148impl std::str::FromStr for Profile {
149 type Err = ProfileParseError;
150
151 fn from_str(s: &str) -> Result<Self, Self::Err> {
152 match s.to_ascii_lowercase().as_str() {
153 "mobile" => Ok(Profile::Mobile),
154 "balanced" => Ok(Profile::Balanced),
155 "custom" => Ok(Profile::Custom),
156 other => Err(ProfileParseError(other.to_owned())),
157 }
158 }
159}
160
/// Error produced by `Profile::from_str` when the input names no known profile.
///
/// The wrapped string is the unrecognized input, as reported by the parser.
#[derive(Debug, Clone)]
pub struct ProfileParseError(pub String);

impl std::fmt::Display for ProfileParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(input) = self;
        write!(f, "unknown profile '{input}': expected mobile|balanced|custom")
    }
}

impl std::error::Error for ProfileParseError {}
176
177/// A validated sample rate (8000–192000 Hz).
178///
179/// Invariant: 8000 <= inner <= 192000.
180#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
181pub struct SampleRate(u32);
182
183impl SampleRate {
184 /// { true }
185 /// `pub fn new(rate: u32) -> Option<Self>`
186 /// { ret.is_some() == (8000..=192000).contains(&rate) }
187 /// Create a validated sample rate.
188 ///
189 /// Returns `None` if the rate is outside the supported range (8000–192000 Hz).
190 ///
191 /// ```rust
192 /// use polyvoice::SampleRate;
193 /// let sr = SampleRate::new(16000).expect("valid rate");
194 /// assert_eq!(sr.get(), 16000);
195 /// assert!(SampleRate::new(7000).is_none());
196 /// ```
197 pub fn new(rate: u32) -> Option<Self> {
198 (8000..=192000).contains(&rate).then_some(Self(rate))
199 }
200
201 /// { true }
202 /// pub fn get(&self) -> u32
203 /// { ret == self.0 && 8000 <= ret && ret <= 192000 }
204 /// Return the raw sample rate value in Hz.
205 ///
206 /// ```rust
207 /// use polyvoice::SampleRate;
208 /// let sr = SampleRate::new(44100).unwrap();
209 /// assert_eq!(sr.get(), 44100);
210 /// ```
211 pub fn get(&self) -> u32 {
212 self.0
213 }
214}
215
216impl Default for SampleRate {
217 fn default() -> Self {
218 Self(16000)
219 }
220}
221
222/// A validated confidence score in [0.0, 1.0].
223///
224/// Invariant: 0.0 <= inner <= 1.0.
225#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
226pub struct Confidence(f32);
227
228impl Confidence {
229 /// { true }
230 /// `pub fn new(v: f32) -> Option<Self>`
231 /// { ret.is_some() == (0.0..=1.0).contains(&v) }
232 /// Create a validated confidence score.
233 ///
234 /// Returns `None` if `v` is outside `[0.0, 1.0]`.
235 ///
236 /// ```rust
237 /// use polyvoice::Confidence;
238 /// assert!(Confidence::new(0.75).is_some());
239 /// assert!(Confidence::new(1.5).is_none());
240 /// ```
241 pub fn new(v: f32) -> Option<Self> {
242 (0.0..=1.0).contains(&v).then_some(Self(v))
243 }
244
245 /// { true }
246 /// pub fn get(&self) -> f32
247 /// { ret == self.0 && 0.0 <= ret && ret <= 1.0 }
248 /// Return the raw confidence value.
249 ///
250 /// ```rust
251 /// use polyvoice::Confidence;
252 /// let c = Confidence::new(0.9).unwrap();
253 /// assert_eq!(c.get(), 0.9);
254 /// ```
255 pub fn get(&self) -> f32 {
256 self.0
257 }
258}
259
260impl Default for Confidence {
261 fn default() -> Self {
262 Self(1.0)
263 }
264}
265
266/// A non-negative duration in seconds.
267///
268/// Invariant: inner >= 0.0.
269#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
270pub struct Seconds(f32);
271
272impl Seconds {
273 /// { true }
274 /// `pub fn new(v: f32) -> Option<Self>`
275 /// { ret.is_some() == (v >= 0.0) }
276 /// Create a validated non-negative duration in seconds.
277 ///
278 /// Returns `None` if `v` is negative.
279 ///
280 /// ```rust
281 /// use polyvoice::Seconds;
282 /// assert!(Seconds::new(3.5).is_some());
283 /// assert!(Seconds::new(-1.0).is_none());
284 /// ```
285 pub fn new(v: f32) -> Option<Self> {
286 (v >= 0.0).then_some(Self(v))
287 }
288
289 /// { true }
290 /// pub fn get(&self) -> f32
291 /// { ret == self.0 && ret >= 0.0 }
292 /// Return the raw duration value in seconds.
293 ///
294 /// ```rust
295 /// use polyvoice::Seconds;
296 /// let s = Seconds::new(2.0).unwrap();
297 /// assert_eq!(s.get(), 2.0);
298 /// ```
299 pub fn get(&self) -> f32 {
300 self.0
301 }
302}
303
304impl Default for Seconds {
305 fn default() -> Self {
306 Self(0.0)
307 }
308}
309
310/// A time interval in seconds.
311#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
312pub struct TimeRange {
313 /// Start time in seconds.
314 pub start: f64,
315 /// End time in seconds.
316 pub end: f64,
317}
318
319impl TimeRange {
320 /// { true }
321 /// pub fn duration(&self) -> f64
322 /// { ret >= 0.0 }
323 /// Return the duration of this time range in seconds.
324 ///
325 /// ```rust
326 /// use polyvoice::TimeRange;
327 /// let tr = TimeRange { start: 1.0, end: 3.5 };
328 /// assert_eq!(tr.duration(), 2.5);
329 /// ```
330 pub fn duration(&self) -> f64 {
331 (self.end - self.start).max(0.0)
332 }
333}
334
/// A speech segment with a speaker label.
///
/// Speaker IDs stored here can be rewritten in bulk with [`remap_segments`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Segment {
    /// Time range of the segment.
    pub time: TimeRange,
    /// Assigned speaker (None if not yet clustered).
    pub speaker: Option<SpeakerId>,
    /// Confidence of the speaker assignment (cosine similarity or posterior).
    ///
    /// Stored as a raw `f32` rather than a validated [`Confidence`], so no
    /// range is enforced here.
    pub confidence: Option<f32>,
}
345
/// A speaker turn: continuous stretch of speech by one speaker.
///
/// Speaker IDs stored here can be rewritten in bulk with [`remap_turns`].
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct SpeakerTurn {
    /// Speaker who holds the turn.
    pub speaker: SpeakerId,
    /// Time range covered by the turn.
    pub time: TimeRange,
    /// Transcript text, if available from an ASR downstream.
    ///
    /// Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}
355
/// Alignment of a single word to a speaker and time range.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct WordAlignment {
    /// The aligned word.
    pub word: String,
    /// Time range the word spans.
    pub time: TimeRange,
    /// Speaker the word is attributed to, or `None` if unattributed.
    pub speaker: Option<SpeakerId>,
    /// Confidence of the alignment.
    // NOTE(review): raw `f32`, range not enforced by this type; presumably in
    // [0.0, 1.0] — confirm against the producer.
    pub confidence: f32,
}
364
/// Result of offline diarization.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct DiarizationResult {
    /// Speech segments with their speaker labels.
    pub segments: Vec<Segment>,
    /// Speaker turns: continuous stretches of speech by one speaker.
    pub turns: Vec<SpeakerTurn>,
    /// Number of speakers found.
    // NOTE(review): presumably the count of distinct `SpeakerId`s across
    // `segments`/`turns`; this type does not enforce it — confirm with the producer.
    pub num_speakers: usize,
}
372
/// Settings that control speaker clustering.
#[derive(Debug, Clone, Copy)]
pub struct ClusterConfig {
    /// Cosine-similarity threshold used when deciding whether to assign an
    /// embedding to an existing speaker.
    pub threshold: f32,
    /// Upper bound on the number of speakers tracked.
    pub max_speakers: usize,
}

impl Default for ClusterConfig {
    /// Threshold `0.45` (the same value as `Profile::Balanced.default_threshold()`)
    /// and up to 64 speakers.
    fn default() -> Self {
        let threshold = 0.45;
        let max_speakers = 64;
        ClusterConfig {
            max_speakers,
            threshold,
        }
    }
}
390
391/// Configuration for sliding-window embedding extraction.
392#[derive(Debug, Clone, Copy)]
393pub struct WindowConfig {
394 /// Window size for embedding extraction, in seconds.
395 pub window_secs: f32,
396 /// Hop length between consecutive windows, in seconds.
397 pub hop_secs: f32,
398 /// Sample rate expected by the embedding model (usually 16000).
399 pub sample_rate: SampleRate,
400}
401
402impl Default for WindowConfig {
403 fn default() -> Self {
404 Self {
405 window_secs: 1.5,
406 hop_secs: 0.75,
407 sample_rate: SampleRate(16000),
408 }
409 }
410}
411
412impl WindowConfig {
413 /// { self.window_secs >= 0.0 }
414 /// `fn window_samples(&self) -> usize`
415 /// { ret == (self.window_secs * self.sample_rate.get() as f32) as usize }
416 pub fn window_samples(&self) -> usize {
417 (self.window_secs * self.sample_rate.get() as f32) as usize
418 }
419
420 /// { self.hop_secs >= 0.0 }
421 /// `fn hop_samples(&self) -> usize`
422 /// { ret == (self.hop_secs * self.sample_rate.get() as f32) as usize }
423 pub fn hop_samples(&self) -> usize {
424 (self.hop_secs * self.sample_rate.get() as f32) as usize
425 }
426}
427
/// Settings for filtering speech after clustering.
#[derive(Debug, Clone, Copy)]
pub struct SpeechFilterConfig {
    /// Shortest speech duration considered for clustering, in seconds.
    pub min_speech_secs: f32,
    /// Largest gap between same-speaker segments that still gets merged, in seconds.
    pub max_gap_secs: f32,
}

impl Default for SpeechFilterConfig {
    /// 0.25 s minimum speech and 0.5 s maximum merge gap.
    fn default() -> Self {
        let min_speech_secs = 0.25;
        let max_gap_secs = 0.5;
        SpeechFilterConfig {
            max_gap_secs,
            min_speech_secs,
        }
    }
}
445
446/// Configuration shared between online and offline diarizers.
447#[derive(Debug, Clone, Copy)]
448pub struct DiarizationConfig {
449 pub cluster: ClusterConfig,
450 pub window: WindowConfig,
451 pub speech_filter: SpeechFilterConfig,
452 /// Maximum allowed audio duration in seconds (DoS guard).
453 pub max_duration_secs: f32,
454}
455
456impl Default for DiarizationConfig {
457 fn default() -> Self {
458 Self {
459 cluster: ClusterConfig::default(),
460 window: WindowConfig::default(),
461 speech_filter: SpeechFilterConfig::default(),
462 max_duration_secs: 3600.0,
463 }
464 }
465}
466
467impl DiarizationConfig {
468 /// { self.window.window_secs >= 0.0 }
469 /// `fn window_samples(&self) -> usize`
470 /// { ret == self.window.window_samples() }
471 pub fn window_samples(&self) -> usize {
472 self.window.window_samples()
473 }
474
475 /// { self.window.hop_secs >= 0.0 }
476 /// `fn hop_samples(&self) -> usize`
477 /// { ret == self.window.hop_samples() }
478 pub fn hop_samples(&self) -> usize {
479 self.window.hop_samples()
480 }
481}
482
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod speaker_id_remap_tests {
    use super::*;

    /// Build a one-second segment with the given (optional) speaker index.
    fn segment(speaker: Option<u32>) -> Segment {
        Segment {
            time: TimeRange { start: 0.0, end: 1.0 },
            speaker: speaker.map(SpeakerId),
            confidence: None,
        }
    }

    #[test]
    fn from_mapping_accepts_unique_old_ids() {
        let mapping = vec![
            (SpeakerId(0), SpeakerId(0)),
            (SpeakerId(1), SpeakerId(0)),
            (SpeakerId(2), SpeakerId(1)),
        ];
        let remap = SpeakerIdRemap::from_mapping(mapping).unwrap();
        assert_eq!(remap.len(), 3);
        assert_eq!(remap.remap(SpeakerId(0)), SpeakerId(0));
        assert_eq!(remap.remap(SpeakerId(1)), SpeakerId(0));
        assert_eq!(remap.remap(SpeakerId(2)), SpeakerId(1));
        assert_eq!(remap.remap(SpeakerId(99)), SpeakerId(99));
    }

    #[test]
    fn from_mapping_rejects_duplicate_old_ids() {
        let mapping = vec![(SpeakerId(0), SpeakerId(1)), (SpeakerId(0), SpeakerId(2))];
        assert!(SpeakerIdRemap::from_mapping(mapping).is_none());
    }

    #[test]
    fn empty_mapping_is_identity() {
        let remap = SpeakerIdRemap::from_mapping(Vec::new()).unwrap();
        assert!(remap.is_empty());
        assert_eq!(remap.len(), 0);
        assert_eq!(remap.remap(SpeakerId(7)), SpeakerId(7));
    }

    #[test]
    fn remap_segments_rewrites_assigned_speakers_only() {
        let remap = SpeakerIdRemap::from_mapping(vec![(SpeakerId(2), SpeakerId(1))]).unwrap();
        let mut segments = [segment(Some(2)), segment(None), segment(Some(0))];
        remap_segments(&mut segments, &remap);
        let speakers: Vec<_> = segments.iter().map(|s| s.speaker).collect();
        assert_eq!(speakers, [Some(SpeakerId(1)), None, Some(SpeakerId(0))]);
    }

    #[test]
    fn remap_turns_rewrites_every_turn() {
        let remap = SpeakerIdRemap::from_mapping(vec![
            (SpeakerId(1), SpeakerId(0)),
            (SpeakerId(3), SpeakerId(2)),
        ])
        .unwrap();
        let mut turns: Vec<SpeakerTurn> = [1, 3, 5]
            .iter()
            .map(|&id| SpeakerTurn {
                speaker: SpeakerId(id),
                time: TimeRange { start: 0.0, end: 1.0 },
                text: None,
            })
            .collect();
        remap_turns(&mut turns, &remap);
        let speakers: Vec<_> = turns.iter().map(|t| t.speaker).collect();
        assert_eq!(speakers, [SpeakerId(0), SpeakerId(2), SpeakerId(5)]);
    }
}
509
// Tests unwrap freely: a panic is exactly the failure signal wanted here.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod profile_tests {
    use super::*;

    #[test]
    fn mobile_profile_uses_cam_pp_dim() {
        assert_eq!(Profile::Mobile.embedding_dim(), 512);
    }

    #[test]
    fn balanced_profile_uses_resnet34_dim() {
        assert_eq!(Profile::Balanced.embedding_dim(), 256);
    }

    #[test]
    fn custom_profile_dim_is_unresolved() {
        assert_eq!(Profile::Custom.embedding_dim(), 0);
    }

    #[test]
    fn default_thresholds_match_spec() {
        // §5.1 of v1.0 design spec
        assert!((Profile::Mobile.default_threshold() - 0.55).abs() < 1e-6);
        assert!((Profile::Balanced.default_threshold() - 0.45).abs() < 1e-6);
        assert!((Profile::Custom.default_threshold() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn manifest_id_for_each_variant() {
        assert_eq!(Profile::Mobile.manifest_id(), "mobile");
        assert_eq!(Profile::Balanced.manifest_id(), "balanced");
        assert_eq!(Profile::Custom.manifest_id(), "custom");
    }

    #[test]
    fn from_str_is_ascii_case_insensitive() {
        assert_eq!("mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("Mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("balanced".parse::<Profile>().unwrap(), Profile::Balanced);
        assert_eq!("CUSTOM".parse::<Profile>().unwrap(), Profile::Custom);
    }

    #[test]
    fn from_str_rejects_unknown_names() {
        assert!("nope".parse::<Profile>().is_err());
        assert!("".parse::<Profile>().is_err());
    }

    #[test]
    fn manifest_id_round_trips_through_from_str() {
        for profile in [Profile::Mobile, Profile::Balanced, Profile::Custom] {
            assert_eq!(profile.manifest_id().parse::<Profile>().unwrap(), profile);
        }
    }
}
553
// Kani verification harnesses. The module is compiled only when the `kani`
// cfg is set (as `cargo kani` does), so ordinary builds and tests skip it.
#[cfg(kani)]
mod kani_proofs;