1#![allow(deprecated)] pub trait Embedder: Send + Sync {
16 fn dim(&self) -> usize;
18
19 fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError>;
25
26 fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
29 audios.iter().map(|a| self.embed(a)).collect()
30 }
31}
32
33#[derive(Debug, thiserror::Error)]
35pub enum EmbedderError {
36 #[error("audio too short for this embedder: {actual_secs:.3}s < {min_secs:.3}s")]
37 AudioTooShort { actual_secs: f32, min_secs: f32 },
38
39 #[error("ONNX inference failed: {detail}")]
40 InferenceFailed { detail: String },
41
42 #[error("expected embedding dim {expected}, got {actual}")]
43 DimMismatch { expected: usize, actual: usize },
44
45 #[error("model file io error on {path}: {detail}")]
46 ModelIo {
47 path: std::path::PathBuf,
48 detail: String,
49 },
50
51 #[error("legacy adapter error: {0}")]
52 Legacy(String),
53}
54
/// Returns a copy of `audio` with every sample inside an overlap region set to `0.0`.
///
/// Each region is a `(start, end)` pair in seconds. It is converted to sample indices with
/// `sample_rate`: the start is rounded down and the end rounded up, negative positions are
/// clamped to the first sample, and the end is clamped to the buffer length.
///
/// Regions are skipped when either bound is non-finite, when `end <= start`, or when they
/// fall entirely outside the buffer. An empty `audio` yields an empty vector.
pub fn apply_overlap_mask(
    audio: &[f32],
    overlap_regions: &[(f32, f32)],
    sample_rate: u32,
) -> Vec<f32> {
    let mut masked = audio.to_vec();
    if masked.is_empty() {
        return masked;
    }

    let total = masked.len();
    let rate = sample_rate as f32;

    for &(start_s, end_s) in overlap_regions {
        let well_formed = start_s.is_finite() && end_s.is_finite() && end_s > start_s;
        if !well_formed {
            continue;
        }

        // Float-to-usize `as` casts saturate, so oversized positions simply clamp below.
        let first = (start_s * rate).max(0.0).floor() as usize;
        let past_last = ((end_s * rate).max(0.0).ceil() as usize).min(total);

        // `past_last <= total`, so a non-empty range is always inside the buffer.
        if first < past_last {
            masked[first..past_last].fill(0.0);
        }
    }

    masked
}
92
93use crossbeam_queue::ArrayQueue;
94use std::sync::Arc;
95
96pub struct EmbedderPool<E: Embedder> {
102 queue: Arc<ArrayQueue<E>>,
103 dim: usize,
104 capacity: usize,
105}
106
107impl<E: Embedder> EmbedderPool<E> {
108 pub fn new(embedders: Vec<E>) -> Result<Self, EmbedderError> {
115 let dim = embedders.first().map(|e| e.dim()).unwrap_or(0);
116 for e in embedders.iter().skip(1) {
117 let actual = e.dim();
118 if actual != dim {
119 return Err(EmbedderError::DimMismatch {
120 expected: dim,
121 actual,
122 });
123 }
124 }
125 let capacity = embedders.len().max(1);
126 let queue = Arc::new(ArrayQueue::new(capacity));
127 for e in embedders {
128 let _ = queue.push(e);
130 }
131 Ok(Self {
132 queue,
133 dim,
134 capacity,
135 })
136 }
137
138 pub fn dim(&self) -> usize {
142 self.dim
143 }
144 pub fn capacity(&self) -> usize {
148 self.capacity
149 }
150
151 pub fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
157 if self.dim == 0 {
158 return Err(EmbedderError::Legacy("empty pool".to_owned()));
160 }
161 let embedder = loop {
164 if let Some(e) = self.queue.pop() {
165 break e;
166 }
167 std::hint::spin_loop();
168 };
169 let result = embedder.embed(audio);
170 let _ = self.queue.push(embedder);
172 result
173 }
174}
175
/// Embeds `audios` on several scoped threads that share one `embedder`.
///
/// The inputs are split into contiguous chunks, one per worker thread (at most
/// `available_parallelism()`, falling back to 4 when that is unknown, and never more than
/// the number of inputs). Results are returned in input order.
///
/// # Errors
///
/// * [`EmbedderError::Legacy`] with the message `"embed_batch thread panicked"` when any
///   worker panicked. This takes precedence over errors returned by `embed`.
/// * Otherwise the first (in input order) error returned by [`Embedder::embed`].
#[cfg(feature = "onnx")]
fn parallel_embed_batch<E: Embedder>(
    embedder: &E,
    audios: &[&[f32]],
) -> Result<Vec<Vec<f32>>, EmbedderError> {
    let n = audios.len();
    if n == 0 {
        return Ok(Vec::new());
    }
    let num_threads = std::thread::available_parallelism()
        .map(|p| p.get())
        .unwrap_or(4)
        .min(n);
    let chunk_size = n.div_ceil(num_threads);

    std::thread::scope(|s| {
        let handles: Vec<_> = audios
            .chunks(chunk_size)
            .map(|chunk| {
                s.spawn(move || {
                    chunk
                        .iter()
                        .map(|audio| embedder.embed(audio))
                        .collect::<Vec<_>>()
                })
            })
            .collect();

        // Join every worker before looking at any outcome. `thread::scope` panics if a
        // worker that was never joined by hand has panicked, so bailing out on the first
        // failed join would turn a second panicking worker into a panic of this function
        // instead of the error reported below.
        let outcomes: Vec<_> = handles.into_iter().map(|h| h.join()).collect();

        let mut all_results = Vec::with_capacity(n);
        for outcome in outcomes {
            let chunk_results = outcome
                .map_err(|_| EmbedderError::Legacy("embed_batch thread panicked".to_string()))?;
            all_results.extend(chunk_results);
        }
        all_results.into_iter().collect::<Result<Vec<_>, _>>()
    })
}
219
/// [`Embedder`] adapters over [`FbankOnnxExtractor`], available with the `onnx` and
/// `embedder` features.
#[cfg(all(feature = "onnx", feature = "embedder"))]
mod onnx_adapters {
    use super::*;
    use crate::ecapa::FbankOnnxExtractor;
    use crate::embedding::EmbeddingExtractor;
    use std::path::Path;

    /// Embedding dimension that [`ResNet34Adapter::new`] passes to the extractor.
    const RESNET34_DIM: usize = 256;

    /// Creates the extractor for `model_path`, reporting any failure as
    /// [`EmbedderError::ModelIo`] tagged with that path.
    fn open_extractor(
        model_path: &Path,
        dim: usize,
        pool_size: usize,
    ) -> Result<FbankOnnxExtractor, EmbedderError> {
        match FbankOnnxExtractor::new(model_path, dim, pool_size) {
            Ok(extractor) => Ok(extractor),
            Err(e) => Err(EmbedderError::ModelIo {
                path: model_path.to_path_buf(),
                detail: e.to_string(),
            }),
        }
    }

    /// Runs `extractor` on `audio` with a default `DiarizationConfig`, wrapping any failure
    /// in [`EmbedderError::Legacy`].
    fn extract_with_defaults(
        extractor: &FbankOnnxExtractor,
        audio: &[f32],
    ) -> Result<Vec<f32>, EmbedderError> {
        let config = crate::types::DiarizationConfig::default();
        extractor
            .extract(audio, &config)
            .map_err(|e| EmbedderError::Legacy(e.to_string()))
    }

    /// [`Embedder`] backed by an [`FbankOnnxExtractor`] that is always created with a
    /// 256-dimensional output.
    pub struct ResNet34Adapter {
        inner: FbankOnnxExtractor,
        dim: usize,
    }

    impl ResNet34Adapter {
        /// Loads the model at `path`; `pool_size` is forwarded to the extractor unchanged.
        ///
        /// # Errors
        ///
        /// Returns [`EmbedderError::ModelIo`] when the extractor cannot be created.
        pub fn new(path: impl AsRef<Path>, pool_size: usize) -> Result<Self, EmbedderError> {
            let inner = open_extractor(path.as_ref(), RESNET34_DIM, pool_size)?;
            Ok(Self {
                inner,
                dim: RESNET34_DIM,
            })
        }
    }

    impl Embedder for ResNet34Adapter {
        fn dim(&self) -> usize {
            self.dim
        }

        fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            extract_with_defaults(&self.inner, audio)
        }

        fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
            parallel_embed_batch(self, audios)
        }
    }

    /// [`Embedder`] backed by an [`FbankOnnxExtractor`] whose output dimension is chosen by
    /// the caller.
    pub struct CamPlusPlusExtractor {
        inner: FbankOnnxExtractor,
        dim: usize,
    }

    impl CamPlusPlusExtractor {
        /// Loads the model at `path` with the given embedding `dim`; `pool_size` is
        /// forwarded to the extractor unchanged.
        ///
        /// # Errors
        ///
        /// Returns [`EmbedderError::ModelIo`] when the extractor cannot be created.
        pub fn new(
            path: impl AsRef<Path>,
            dim: usize,
            pool_size: usize,
        ) -> Result<Self, EmbedderError> {
            let inner = open_extractor(path.as_ref(), dim, pool_size)?;
            Ok(Self { inner, dim })
        }
    }

    impl Embedder for CamPlusPlusExtractor {
        fn dim(&self) -> usize {
            self.dim
        }

        fn embed(&self, audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            extract_with_defaults(&self.inner, audio)
        }

        fn embed_batch(&self, audios: &[&[f32]]) -> Result<Vec<Vec<f32>>, EmbedderError> {
            parallel_embed_batch(self, audios)
        }
    }
}
320
321#[cfg(all(feature = "onnx", feature = "embedder"))]
322pub use onnx_adapters::{CamPlusPlusExtractor, ResNet34Adapter};
323
/// Unit tests for `apply_overlap_mask`.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod overlap_mask_tests {
    use super::*;
    use std::ops::Range;

    const SAMPLE_RATE: u32 = 16_000;

    /// One second of all-ones audio at `SAMPLE_RATE`.
    fn one_second_of_ones() -> Vec<f32> {
        vec![1.0_f32; 16_000]
    }

    /// Checks that exactly the samples inside `zero_ranges` are zero and all others are one.
    fn assert_zeroed_exactly(masked: &[f32], zero_ranges: &[Range<usize>]) {
        for (i, &v) in masked.iter().enumerate() {
            if zero_ranges.iter().any(|r| r.contains(&i)) {
                assert_eq!(v, 0.0, "sample {i} should be zeroed");
            } else {
                assert_eq!(v, 1.0, "sample {i} should pass through");
            }
        }
    }

    #[test]
    fn no_overlap_regions_pass_through() {
        let audio = one_second_of_ones();
        assert_eq!(apply_overlap_mask(&audio, &[], SAMPLE_RATE), audio);
    }

    #[test]
    fn single_overlap_region_is_zeroed() {
        let masked = apply_overlap_mask(&one_second_of_ones(), &[(0.5, 0.7)], SAMPLE_RATE);
        assert_zeroed_exactly(&masked, &[8000..11200]);
    }

    #[test]
    fn empty_input_returns_empty() {
        assert!(apply_overlap_mask(&[], &[(0.0, 1.0)], SAMPLE_RATE).is_empty());
    }

    #[test]
    fn out_of_bounds_overlap_is_clamped() {
        // Only 100 samples: a region starting at 0.5 s lies wholly past the end.
        let audio = vec![1.0_f32; 100];
        let masked = apply_overlap_mask(&audio, &[(0.5, 1.0)], SAMPLE_RATE);
        assert_eq!(masked, audio, "out-of-bounds overlap is a no-op");
    }

    #[test]
    fn negative_overlap_start_is_clamped_to_zero() {
        let masked = apply_overlap_mask(&one_second_of_ones(), &[(-1.0, 0.5)], SAMPLE_RATE);
        let (zeroed, untouched) = masked.split_at(8000);
        for &v in zeroed {
            assert_eq!(v, 0.0);
        }
        for &v in untouched {
            assert_eq!(v, 1.0);
        }
    }

    #[test]
    fn multiple_overlap_regions_all_zeroed() {
        let regions = [(0.1, 0.2), (0.5, 0.6), (0.9, 1.0)];
        let masked = apply_overlap_mask(&one_second_of_ones(), &regions, SAMPLE_RATE);
        assert_zeroed_exactly(&masked, &[1600..3200, 8000..9600, 14_400..16_000]);
    }

    #[test]
    fn invalid_overlap_with_end_before_start_is_no_op() {
        let audio = one_second_of_ones();
        let masked = apply_overlap_mask(&audio, &[(0.7, 0.5)], SAMPLE_RATE);
        assert_eq!(masked, audio, "end<start is silently skipped");
    }
}
396
/// Unit tests for the `Embedder` trait's default behavior and `EmbedderError` formatting.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod trait_tests {
    use super::*;

    /// Embedder that ignores its input and always returns a copy of `values`.
    struct ConstantEmbedder {
        values: Vec<f32>,
    }

    impl ConstantEmbedder {
        /// Embedder whose output is `len` copies of `value`.
        fn filled(value: f32, len: usize) -> Self {
            Self {
                values: vec![value; len],
            }
        }
    }

    impl Embedder for ConstantEmbedder {
        fn dim(&self) -> usize {
            self.values.len()
        }

        fn embed(&self, _audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            Ok(self.values.clone())
        }
    }

    #[test]
    fn embedder_trait_object_is_dyn_compatible() {
        // Compiles only if `Embedder` can be used as a trait object.
        let _boxed: Box<dyn Embedder> = Box::new(ConstantEmbedder {
            values: vec![0.1, 0.2, 0.3],
        });
    }

    #[test]
    fn embedder_default_batch_is_serial() {
        let embedder = ConstantEmbedder::filled(0.5, 4);
        let silence: &[f32] = &[];
        let inputs = [silence; 3];
        let out = embedder.embed_batch(&inputs).unwrap();
        assert_eq!(out.len(), 3);
        assert!(out.iter().all(|v| v.len() == 4 && v.first() == Some(&0.5)));
    }

    #[test]
    fn embedder_dim_matches_output() {
        let embedder = ConstantEmbedder::filled(1.0, 192);
        assert_eq!(embedder.dim(), 192);
        assert_eq!(embedder.embed(&[]).unwrap().len(), 192);
    }

    #[test]
    fn embedder_error_audio_too_short_displays() {
        let msg = EmbedderError::AudioTooShort {
            actual_secs: 0.05,
            min_secs: 0.25,
        }
        .to_string();
        assert!(msg.contains("0.05"));
        assert!(msg.contains("0.25"));
    }
}
455
/// Unit tests for `EmbedderPool`.
#[allow(clippy::unwrap_used)]
#[cfg(test)]
mod pool_tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Embedder that counts its `embed` calls in a shared counter and returns zeros.
    struct CountingEmbedder {
        counter: Arc<AtomicUsize>,
        dim: usize,
    }

    impl CountingEmbedder {
        fn new(counter: &Arc<AtomicUsize>, dim: usize) -> Self {
            Self {
                counter: Arc::clone(counter),
                dim,
            }
        }
    }

    impl Embedder for CountingEmbedder {
        fn dim(&self) -> usize {
            self.dim
        }

        fn embed(&self, _audio: &[f32]) -> Result<Vec<f32>, EmbedderError> {
            self.counter.fetch_add(1, Ordering::SeqCst);
            Ok(vec![0.0; self.dim])
        }
    }

    /// Builds a pool of `n` 192-dimensional counting embedders sharing one call counter.
    fn make_pool(n: usize) -> (EmbedderPool<CountingEmbedder>, Arc<AtomicUsize>) {
        let counter = Arc::new(AtomicUsize::new(0));
        let embedders: Vec<_> = (0..n)
            .map(|_| CountingEmbedder::new(&counter, 192))
            .collect();
        (EmbedderPool::new(embedders).unwrap(), counter)
    }

    #[test]
    fn pool_with_single_embedder_round_trip() {
        let (pool, counter) = make_pool(1);
        let embedding = pool.embed(&[0.0_f32; 100]).unwrap();
        assert_eq!(embedding.len(), 192);
        assert_eq!(counter.load(Ordering::SeqCst), 1);
    }

    #[test]
    fn pool_dim_is_consistent() {
        let (pool, _counter) = make_pool(4);
        assert_eq!(pool.dim(), 192);
    }

    #[test]
    fn pool_serial_embed_increments_counter_per_call() {
        let (pool, counter) = make_pool(2);
        (0..5).for_each(|_| {
            pool.embed(&[0.0_f32; 100]).unwrap();
        });
        assert_eq!(counter.load(Ordering::SeqCst), 5);
    }

    #[test]
    fn pool_with_zero_embedders_errors() {
        let pool = EmbedderPool::<CountingEmbedder>::new(Vec::new()).unwrap();
        let err = pool
            .embed(&[0.0_f32; 100])
            .expect_err("empty pool must fail");
        assert!(matches!(err, EmbedderError::Legacy(_)));
    }

    #[test]
    fn pool_rejects_mismatched_embedder_dims() {
        let counter = Arc::new(AtomicUsize::new(0));
        let embedders = vec![
            CountingEmbedder::new(&counter, 192),
            CountingEmbedder::new(&counter, 256),
        ];
        let Err(err) = EmbedderPool::new(embedders) else {
            panic!("mismatched dims must fail");
        };
        assert!(
            matches!(
                err,
                EmbedderError::DimMismatch {
                    expected: 192,
                    actual: 256
                }
            ),
            "expected DimMismatch(192, 256), got {err}"
        );
    }
}