infinite/
infinite.rs

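// Example: print every speech segment of a WAV file together with a speaker label.
// Models are loaded from the relative paths src/nn/*/model.bpk, so run this from the crate root.
// wget https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/6_speakers.wav
// cargo run --example infinite 6_speakers.wav
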
use anyhow::{Context, Result};
use pyannote_rs::{EmbeddingExtractor, EmbeddingManager, Segmenter};

8fn process_segment(
9    segment: pyannote_rs::Segment,
10    embedding_extractor: &EmbeddingExtractor,
11    embedding_manager: &mut EmbeddingManager,
12    search_threshold: f32,
13    sample_rate: u32,
14) -> Result<()> {
15    let embedding = embedding_extractor.extract(&segment.samples, sample_rate)?;
16
17    let speaker = embedding_manager
18        .upsert(&embedding, search_threshold)
19        .or_else(|| embedding_manager.best_match(&embedding))
20        .map(|r| r.to_string())
21        .unwrap_or("?".into());
22
23    println!(
24        "start = {:.2}, end = {:.2}, speaker = {}",
25        segment.start, segment.end, speaker
26    );
27
28    Ok(())
29}
30
31fn main() -> Result<()> {
32    let audio_path = std::env::args().nth(1).expect("Please specify audio file");
33    let search_threshold = 0.5;
34
35    let embedding_model_path = "src/nn/speaker_identification/model.bpk";
36    let segmentation_model_path = "src/nn/segmentation/model.bpk";
37
38    let (samples, sample_rate) = pyannote_rs::read_wav(&audio_path)?;
39    let embedding_extractor = EmbeddingExtractor::new(embedding_model_path)?;
40    let mut embedding_manager = EmbeddingManager::new(usize::MAX);
41    let segmenter = Segmenter::new(segmentation_model_path)?;
42
43    for segment in segmenter.iter_segments(&samples, sample_rate)? {
44        match segment {
45            Ok(segment) => {
46                if let Err(error) = process_segment(
47                    segment,
48                    &embedding_extractor,
49                    &mut embedding_manager,
50                    search_threshold,
51                    sample_rate,
52                ) {
53                    eprintln!("Error processing segment: {:?}", error);
54                }
55            }
56            Err(error) => eprintln!("Failed to process segment: {:?}", error),
57        }
58    }
59
60    Ok(())
61}