Skip to main content

polyvoice/embedder/
mod.rs

1#![allow(deprecated)] // legacy embedding API (F09); see polyvoice::embedder
2//! v1.0 `Embedder` trait + concrete extractors (CAM++, ResNet34) + pool +
3//! overlap-mask helper.
4//!
5//! Added in v0.6 (M2).
6
7/// Speaker embedding extractor — turns a slice of 16 kHz mono audio into a
8/// fixed-dimension embedding vector. Implementations are expected to L2-normalize
9/// their output so cosine similarity is a meaningful metric downstream.
10///
11/// In v1.0 (M2) the polyvoice crate introduces `Embedder` as the canonical
12/// trait. The legacy `EmbeddingExtractor` trait and its implementations
13/// (`FbankOnnxExtractor`, `OnnxEmbeddingExtractor`, `DummyExtractor`) remain
14/// available unchanged — M6 will deprecate them.
15pub trait Embedder: Send + Sync {
16    /// Output dimension of this embedder. Constant per instance.
17    fn dim(&self) -> usize;
18
19    /// Compute an embedding for one audio segment.
20    ///
21    /// **Requires:** `audio` is 16 kHz mono PCM.
22    /// **Guarantees on Ok:** `result.len() == self.dim()` and the vector is L2-normalized
23    /// (`|sum(x²)¹ᐟ² − 1.0| < 1e-3`).
24    fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError>;
25
26    /// Compute embeddings for a batch of audio segments. Default implementation
27    /// is sequential; impls may override with a true batched ONNX call.
28    fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
29        audios.iter().map(|a| self.embed(a)).collect()
30    }
31}
32
/// Errors from `Embedder` implementations.
///
/// The `Display` text of each variant comes from its `#[error(...)]`
/// attribute.
#[derive(Debug, thiserror::Error)]
pub enum EmbedderError {
    /// The input segment is shorter than the embedder accepts. Both fields
    /// are durations in seconds and are printed with three decimals.
    #[error("audio too short for this embedder: {actual_secs:.3}s < {min_secs:.3}s")]
    AudioTooShort { actual_secs: f32, min_secs: f32 },

    /// ONNX inference failed; `detail` carries the failure description.
    #[error("ONNX inference failed: {detail}")]
    InferenceFailed { detail: String },

    /// An embedding dimension differed from the expected one.
    /// `EmbedderPool::new` returns this when the pooled embedders disagree
    /// on `dim()`.
    #[error("expected embedding dim {expected}, got {actual}")]
    DimMismatch { expected: usize, actual: usize },

    /// A model file at `path` could not be loaded. The ONNX adapters in this
    /// module return this when constructing the underlying extractor fails;
    /// `detail` is that error rendered as text.
    #[error("model file io error on {path}: {detail}")]
    ModelIo {
        path: std::path::PathBuf,
        detail: String,
    },

    /// Free-form message. In this module it carries errors from the legacy
    /// extractor's `extract`, the empty-pool case of `EmbedderPool::embed`,
    /// and a panicked `embed_batch` worker thread.
    #[error("legacy adapter error: {0}")]
    Legacy(String),
}
54
/// Silence the parts of `audio` where the segmenter flagged a 2-speaker
/// overlap.
///
/// Returns a copy of `audio` in which every sample that falls inside one of
/// the `(start_secs, end_secs)` ranges of `overlap_regions` is set to `0.0`.
/// Region bounds are converted to sample indices with `sample_rate`; the
/// start is rounded down and the end rounded up, so a partially covered
/// sample is silenced.
///
/// No precondition. Postcondition: the result has the same length as
/// `audio`.
///
/// Regions with a non-finite bound, or with `end <= start`, are skipped.
/// Negative starts are clamped to the first sample and ends past the buffer
/// are clamped to its length; a region lying wholly outside the buffer has
/// no effect. The function never panics.
///
/// **Pure Rust, no allocations beyond the output Vec, wasm32-clean.**
pub fn apply_overlap_mask(
    audio: &[f32],
    overlap_regions: &[(f32, f32)],
    sample_rate: u32,
) -> Vec<f32> {
    let mut masked = audio.to_vec();
    let total = masked.len();
    if total == 0 {
        return masked;
    }

    let rate = sample_rate as f32;
    let spans = overlap_regions
        .iter()
        .copied()
        // Drop NaN/infinite bounds and empty or inverted regions.
        .filter(|&(from_s, to_s)| from_s.is_finite() && to_s.is_finite() && to_s > from_s)
        // Seconds -> sample indices. The float-to-usize casts saturate, so an
        // oversized product cannot wrap; the end is then clamped to the buffer.
        .map(|(from_s, to_s)| {
            let first = (from_s * rate).max(0.0).floor() as usize;
            let last = ((to_s * rate).max(0.0).ceil() as usize).min(total);
            (first, last)
        })
        // `last <= total`, so `first < last` also implies `first < total`.
        .filter(|&(first, last)| first < last);

    for (first, last) in spans {
        masked[first..last].fill(0.0);
    }
    masked
}
92
93use crossbeam_queue::ArrayQueue;
94use std::sync::Arc;
95
96/// Lock-free pool of `Embedder` instances for concurrent extraction.
97///
98/// Generic over `E: Embedder` so the same pool implementation works for
99/// `CamPlusPlusExtractor`, `ResNet34Adapter`, or any user-provided embedder.
100/// All embedders in a pool must share the same output dimension.
101pub struct EmbedderPool<E: Embedder> {
102    queue: Arc<ArrayQueue<E>>,
103    dim: usize,
104    capacity: usize,
105}
106
107impl<E: Embedder> EmbedderPool<E> {
108    /// { true }
109    /// `pub fn new(embedders: Vec<E>) -> Result<Self, EmbedderError>`
110    /// { ret.is_ok() => ret.as_ref().unwrap().dim() == embedders.first().map_or(0, |e| e.dim()) }
111    /// Build a pool from a list of embedders. All must share the same `dim()`.
112    /// An empty list constructs a pool that fails on every call (returns
113    /// `EmbedderError::Legacy("empty pool")`).
114    pub fn new(embedders: Vec<E>) -> Result<Self, EmbedderError> {
115        let dim = embedders.first().map(|e| e.dim()).unwrap_or(0);
116        for e in embedders.iter().skip(1) {
117            let actual = e.dim();
118            if actual != dim {
119                return Err(EmbedderError::DimMismatch {
120                    expected: dim,
121                    actual,
122                });
123            }
124        }
125        let capacity = embedders.len().max(1);
126        let queue = Arc::new(ArrayQueue::new(capacity));
127        for e in embedders {
128            // ArrayQueue::push only fails if full; capacity == count, so push always succeeds.
129            let _ = queue.push(e);
130        }
131        Ok(Self {
132            queue,
133            dim,
134            capacity,
135        })
136    }
137
138    /// { true }
139    /// pub fn dim(&self) -> usize
140    /// { ret == self.dim }
141    pub fn dim(&self) -> usize {
142        self.dim
143    }
144    /// { true }
145    /// pub fn capacity(&self) -> usize
146    /// { ret == self.capacity }
147    pub fn capacity(&self) -> usize {
148        self.capacity
149    }
150
151    /// { true }
152    /// `pub fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError>`
153    /// { ret.as_ref().map_or(true, |v| v.len() == self.dim) }
154    /// Extract a single embedding using the next-available pooled embedder.
155    /// Blocks (busy-spins) until one is free.
156    pub fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
157        if self.dim == 0 {
158            // Empty-construction case.
159            return Err(EmbedderError::Legacy("empty pool".to_owned()));
160        }
161        // Acquire (busy-wait fallback for simplicity; real Pipeline use is
162        // through `rayon::par_iter` which already throttles concurrency).
163        let embedder = loop {
164            if let Some(e) = self.queue.pop() {
165                break e;
166            }
167            std::hint::spin_loop();
168        };
169        let result = embedder.embed(audio);
170        // Always return the embedder.
171        let _ = self.queue.push(embedder);
172        result
173    }
174}
175
/// Embed a batch in parallel on scoped threads (`std::thread::scope`).
///
/// The batch is cut into contiguous chunks, one worker thread per chunk,
/// using at most `available_parallelism` threads (4 if that cannot be
/// determined) and never more threads than inputs. Each worker calls
/// `embedder.embed()` on its chunk in order.
///
/// Returns the embeddings in input order, or the first error in input order.
/// A panicking worker is reported as
/// `EmbedderError::Legacy("embed_batch thread panicked")`.
#[cfg(feature = "onnx")]
fn parallel_embed_batch<E: Embedder>(
    embedder: &E,
    audios: &[&[f32]],
) -> Result<Vec<Vec<f32>>, EmbedderError> {
    if audios.is_empty() {
        return Ok(Vec::new());
    }
    let total = audios.len();
    let workers = std::thread::available_parallelism()
        .map_or(4, |cores| cores.get())
        .min(total);
    let per_worker = total.div_ceil(workers);

    std::thread::scope(|scope| {
        // Start every worker before joining any of them.
        let mut pending = Vec::new();
        for part in audios.chunks(per_worker) {
            pending.push(scope.spawn(move || {
                let mut embedded = Vec::with_capacity(part.len());
                for audio in part {
                    embedded.push(embedder.embed(audio));
                }
                embedded
            }));
        }

        // Join in spawn order so results line up with the inputs.
        let mut gathered = Vec::with_capacity(total);
        for worker in pending {
            match worker.join() {
                Ok(part_results) => gathered.extend(part_results),
                Err(_) => {
                    return Err(EmbedderError::Legacy(
                        "embed_batch thread panicked".to_string(),
                    ));
                }
            }
        }
        gathered.into_iter().collect::<Result<Vec<_>, _>>()
    })
}
219
220#[cfg(all(feature = "onnx", feature = "embedder"))]
221mod onnx_adapters {
222    use super::*;
223    use crate::ecapa::FbankOnnxExtractor;
224    use crate::embedding::EmbeddingExtractor;
225    use std::path::Path;
226
227    /// New-trait adapter for the existing `FbankOnnxExtractor` (WeSpeaker ResNet34, 256-d).
228    ///
229    /// The legacy `FbankOnnxExtractor` already implements the v0.5.x
230    /// `EmbeddingExtractor`; this adapter exposes the same model through the
231    /// v1.0 `Embedder` trait. M6 will fold this into a unified type.
232    pub struct ResNet34Adapter {
233        inner: FbankOnnxExtractor,
234        dim: usize,
235    }
236
237    impl ResNet34Adapter {
238        /// { true }
239        /// `pub fn new(path: impl AsRef<Path>, pool_size: usize) -> Result<Self, EmbedderError>`
240        /// { ret.as_ref().map_or(true, |e| e.dim() == 256) }
241        /// Load the WeSpeaker ResNet34 ONNX model.
242        pub fn new(path: impl AsRef<Path>, pool_size: usize) -> Result<Self, EmbedderError> {
243            let inner = FbankOnnxExtractor::new(path.as_ref(), 256, pool_size).map_err(|e| {
244                EmbedderError::ModelIo {
245                    path: path.as_ref().to_path_buf(),
246                    detail: format!("{e}"),
247                }
248            })?;
249            Ok(Self { inner, dim: 256 })
250        }
251    }
252
253    impl Embedder for ResNet34Adapter {
254        fn dim(&self) -> usize {
255            self.dim
256        }
257
258        fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
259            let config = crate::types::DiarizationConfig::default();
260            self.inner
261                .extract(audio, &config)
262                .map_err(|e| EmbedderError::Legacy(format!("{e}")))
263        }
264
265        fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
266            parallel_embed_batch(self, audios)
267        }
268    }
269
270    /// CAM++ embedder (Channel-Attentive Multi-scale Pooling). Dim is supplied
271    /// explicitly because WeSpeaker ships several CAM++ variants:
272    /// `voxceleb_CAM++.onnx` is 512-d; smaller variants exist at 192-d.
273    /// Targets the Mobile profile of v1.0; M5 may swap to INT8 + smaller dim.
274    /// Uses the same 80-bin log-mel fbank pipeline as ResNet34.
275    pub struct CamPlusPlusExtractor {
276        inner: FbankOnnxExtractor,
277        dim: usize,
278    }
279
280    impl CamPlusPlusExtractor {
281        /// { true }
282        /// `pub fn new( path: impl AsRef<Path>, dim: usize, pool_size: usize, ) -> Result<Self, EmbedderError>`
283        /// { ret.as_ref().map_or(true, |e| e.dim() == dim) }
284        /// Load a CAM++ ONNX model. `dim` must match the model's output
285        /// dimension (e.g. 192 or 512 depending on the variant). Pool size
286        /// controls the number of concurrent ONNX sessions held internally
287        /// (canonical: `num_cpus().min(4)`).
288        pub fn new(
289            path: impl AsRef<Path>,
290            dim: usize,
291            pool_size: usize,
292        ) -> Result<Self, EmbedderError> {
293            let inner = FbankOnnxExtractor::new(path.as_ref(), dim, pool_size).map_err(|e| {
294                EmbedderError::ModelIo {
295                    path: path.as_ref().to_path_buf(),
296                    detail: format!("{e}"),
297                }
298            })?;
299            Ok(Self { inner, dim })
300        }
301    }
302
303    impl Embedder for CamPlusPlusExtractor {
304        fn dim(&self) -> usize {
305            self.dim
306        }
307
308        fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
309            let config = crate::types::DiarizationConfig::default();
310            self.inner
311                .extract(audio, &config)
312                .map_err(|e| EmbedderError::Legacy(format!("{e}")))
313        }
314
315        fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
316            parallel_embed_batch(self, audios)
317        }
318    }
319}
320
321#[cfg(all(feature = "onnx", feature = "embedder"))]
322pub use onnx_adapters::{CamPlusPlusExtractor, ResNet34Adapter};
323
/// Unit tests for `apply_overlap_mask`: pass-through, zeroing, clamping of
/// out-of-range bounds, and skipping of invalid regions.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod overlap_mask_tests {
    use super::*;

    #[test]
    fn no_overlap_regions_pass_through() {
        let audio = vec![1.0_f32; 16_000];
        let masked = apply_overlap_mask(&audio, &[], 16_000);
        assert_eq!(masked, audio);
    }

    #[test]
    fn single_overlap_region_is_zeroed() {
        let audio = vec![1.0_f32; 16_000];
        let masked = apply_overlap_mask(&audio, &[(0.5, 0.7)], 16_000);
        for (i, &v) in masked.iter().enumerate() {
            if (8000..11200).contains(&i) {
                assert_eq!(v, 0.0, "sample {i} should be zeroed");
            } else {
                assert_eq!(v, 1.0, "sample {i} should pass through");
            }
        }
    }

    #[test]
    fn empty_input_returns_empty() {
        let masked = apply_overlap_mask(&[], &[(0.0, 1.0)], 16_000);
        assert!(masked.is_empty());
    }

    /// A region lying wholly past the end of the buffer changes nothing.
    #[test]
    fn out_of_bounds_overlap_is_clamped() {
        let audio = vec![1.0_f32; 100];
        let masked = apply_overlap_mask(&audio, &[(0.5, 1.0)], 16_000);
        assert_eq!(masked, audio, "out-of-bounds overlap is a no-op");
    }

    /// A region that starts inside the buffer and ends past it is zeroed up
    /// to the last sample. Bounds are chosen to be exact in `f32`:
    /// 0.5 s * 100 Hz = sample 50, 2.0 s * 100 Hz = sample 200 (> 100).
    #[test]
    fn overlap_end_past_buffer_is_clamped_to_len() {
        let audio = vec![1.0_f32; 100];
        let masked = apply_overlap_mask(&audio, &[(0.5, 2.0)], 100);
        assert_eq!(masked.len(), audio.len());
        for (i, &v) in masked.iter().enumerate() {
            if i >= 50 {
                assert_eq!(v, 0.0, "sample {i} should be zeroed");
            } else {
                assert_eq!(v, 1.0, "sample {i} should pass through");
            }
        }
    }

    #[test]
    fn negative_overlap_start_is_clamped_to_zero() {
        let audio = vec![1.0_f32; 16_000];
        let masked = apply_overlap_mask(&audio, &[(-1.0, 0.5)], 16_000);
        for &v in masked.iter().take(8000) {
            assert_eq!(v, 0.0);
        }
        for &v in masked.iter().skip(8000) {
            assert_eq!(v, 1.0);
        }
    }

    #[test]
    fn multiple_overlap_regions_all_zeroed() {
        let audio = vec![1.0_f32; 16_000];
        let masked = apply_overlap_mask(&audio, &[(0.1, 0.2), (0.5, 0.6), (0.9, 1.0)], 16_000);
        let zero_ranges = [(1600..3200), (8000..9600), (14_400..16_000)];
        for (i, &v) in masked.iter().enumerate() {
            let in_zero = zero_ranges.iter().any(|r| r.contains(&i));
            if in_zero {
                assert_eq!(v, 0.0, "sample {i} should be zeroed");
            } else {
                assert_eq!(v, 1.0, "sample {i} should pass through");
            }
        }
    }

    #[test]
    fn invalid_overlap_with_end_before_start_is_no_op() {
        let audio = vec![1.0_f32; 16_000];
        let masked = apply_overlap_mask(&audio, &[(0.7, 0.5)], 16_000);
        assert_eq!(masked, audio, "end<start is silently skipped");
    }

    /// NaN or infinite bounds on either side cause the region to be skipped.
    #[test]
    fn non_finite_overlap_bounds_are_skipped() {
        let audio = vec![1.0_f32; 16_000];
        let regions = [
            (f32::NAN, 1.0),
            (0.0, f32::NAN),
            (0.0, f32::INFINITY),
            (f32::NEG_INFINITY, 0.5),
        ];
        let masked = apply_overlap_mask(&audio, &regions, 16_000);
        assert_eq!(masked, audio, "non-finite bounds are silently skipped");
    }

    /// With a zero sample rate every region maps to an empty sample range.
    #[test]
    fn zero_sample_rate_is_no_op() {
        let audio = vec![1.0_f32; 100];
        let masked = apply_overlap_mask(&audio, &[(0.0, 1.0)], 0);
        assert_eq!(masked, audio, "zero sample rate selects no samples");
    }
}
396
/// Unit tests for the `Embedder` trait surface and `EmbedderError` display.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod trait_tests {
    use super::*;

    /// Test double whose `embed` ignores the audio and returns a fixed vector.
    struct ConstantEmbedder {
        values: Vec<f32>,
    }

    impl Embedder for ConstantEmbedder {
        fn dim(&self) -> usize {
            self.values.len()
        }
        fn embed(&self, _audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            Ok(self.values.to_vec())
        }
    }

    /// Compiles only if `Embedder` can be used as a trait object.
    #[test]
    fn embedder_trait_object_is_dyn_compatible() {
        let _boxed: Box<dyn Embedder> = Box::new(ConstantEmbedder {
            values: vec![0.1, 0.2, 0.3],
        });
    }

    /// The default `embed_batch` yields one embedding per input.
    #[test]
    fn embedder_default_batch_is_serial() {
        let embedder = ConstantEmbedder {
            values: vec![0.5; 4],
        };
        let silence: &[f32] = &[];
        let batch = embedder.embed_batch(&[silence, silence, silence]).unwrap();
        assert_eq!(batch.len(), 3);
        for embedding in &batch {
            assert_eq!(embedding.len(), 4);
            assert_eq!(embedding[0], 0.5);
        }
    }

    #[test]
    fn embedder_dim_matches_output() {
        let embedder = ConstantEmbedder {
            values: vec![1.0; 192],
        };
        let embedding = embedder.embed(&[]).unwrap();
        assert_eq!(embedder.dim(), 192);
        assert_eq!(embedding.len(), 192);
    }

    /// Both durations appear in the rendered `AudioTooShort` message.
    #[test]
    fn embedder_error_audio_too_short_displays() {
        let rendered = EmbedderError::AudioTooShort {
            actual_secs: 0.05,
            min_secs: 0.25,
        }
        .to_string();
        for needle in ["0.05", "0.25"] {
            assert!(rendered.contains(needle));
        }
    }
}
455
/// Unit tests for `EmbedderPool`: round-trips, dimension checks and the
/// empty-pool error.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod pool_tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double that bumps a shared counter on every `embed` call and
    /// returns a zero vector of length `dim`.
    struct CountingEmbedder {
        counter: Arc<AtomicUsize>,
        dim: usize,
    }

    impl Embedder for CountingEmbedder {
        fn dim(&self) -> usize {
            self.dim
        }
        fn embed(&self, _audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.0; self.dim])
        }
    }

    /// Builds a pool of `n` 192-d counting embedders that share one counter.
    fn make_pool(n: usize) -> (EmbedderPool<CountingEmbedder>, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let embedders: Vec<CountingEmbedder> = (0..n)
            .map(|_| CountingEmbedder {
                counter: Arc::clone(&counter),
                dim: 192,
            })
            .collect();
        (EmbedderPool::new(embedders).unwrap(), counter)
    }

    #[test]
    fn pool_with_single_embedder_round_trip() {
        let (pool, calls) = make_pool(1);
        let embedding = pool.embed(&[0.0_f32; 100]).unwrap();
        assert_eq!(embedding.len(), 192);
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn pool_dim_is_consistent() {
        let (pool, _calls) = make_pool(4);
        assert_eq!(pool.dim(), 192);
    }

    #[test]
    fn pool_serial_embed_increments_counter_per_call() {
        let (pool, calls) = make_pool(2);
        (0..5).for_each(|_| {
            pool.embed(&[0.0_f32; 100]).unwrap();
        });
        assert_eq!(calls.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn pool_with_zero_embedders_errors() {
        let empty: EmbedderPool<CountingEmbedder> = EmbedderPool::new(Vec::new()).unwrap();
        let outcome = empty.embed(&[0.0_f32; 100]);
        let err = outcome.expect_err("empty pool must fail");
        assert!(matches!(err, EmbedderError::Legacy(_)));
    }

    #[test]
    fn pool_rejects_mismatched_embedder_dims() {
        let counter = Arc::new(AtomicUsize::new(0));
        let embedders: Vec<CountingEmbedder> = [192, 256]
            .into_iter()
            .map(|dim| CountingEmbedder {
                counter: Arc::clone(&counter),
                dim,
            })
            .collect();
        // `EmbedderPool` is not `Debug`, so `unwrap_err` is unavailable here.
        let Err(err) = EmbedderPool::new(embedders) else {
            panic!("mismatched dims must fail");
        };
        let is_expected = matches!(
            err,
            EmbedderError::DimMismatch {
                expected: 192,
                actual: 256
            }
        );
        assert!(is_expected, "expected DimMismatch(192, 256), got {err}");
    }
}