Skip to main content

polyvoice/segmentation/
powerset.rs

//! `PowersetSegmenter` — ONNX-backed `Segmenter` wrapping
//! `sherpa-onnx-pyannote-segmentation-3-0`.
//!
//! With the default configuration, slides a 10-second window across the audio
//! with a 1-second hop (90% overlap), runs ONNX inference on each window, and
//! feeds the per-window outputs into `Aggregator` for stitching.
6
7use crate::segmentation::aggregator::{AggregationConfig, Aggregator, WindowOutput};
8use crate::segmentation::{MIN_AUDIO_SAMPLES, RawSegment, SegmentationError, Segmenter};
9use ort::session::Session;
10use std::path::{Path, PathBuf};
11use std::sync::Mutex;
12
13/// Tunable parameters for `PowersetSegmenter`.
14#[derive(Debug, Clone)]
15pub struct PowersetConfig {
16    /// Window duration in seconds.
17    pub window_secs: f32,
18    /// Hop size between windows in seconds.
19    pub hop_secs: f32,
20    /// Sample rate the model expects (16000 for sherpa-onnx-pyannote-segmentation-3-0).
21    pub sample_rate: u32,
22    /// Forwarded to the inner `Aggregator`.
23    pub aggregation: AggregationConfig,
24}
25
26impl Default for PowersetConfig {
27    fn default() -> Self {
28        Self {
29            window_secs: 10.0,
30            hop_secs: 1.0,
31            sample_rate: 16000,
32            aggregation: AggregationConfig::default(),
33        }
34    }
35}
36
37/// ONNX-backed powerset speaker segmenter.
38pub struct PowersetSegmenter {
39    session: Mutex<Session>,
40    input_name: String,
41    config: PowersetConfig,
42    model_path: PathBuf,
43}
44
45impl PowersetSegmenter {
46    /// { true }
47    /// `pub fn new(model_path: impl AsRef<Path>) -> Result<Self, SegmentationError>`
48    /// { true }
49    /// Load the ONNX model from `model_path`.
50    pub fn new(model_path: impl AsRef<Path>) -> Result<Self, SegmentationError> {
51        Self::with_config(model_path, PowersetConfig::default())
52    }
53
54    /// { true }
55    /// `pub fn with_config( model_path: impl AsRef<Path>, config: PowersetConfig, ) -> Result<Self, SegmentationError>`
56    /// { true }
57    /// Load with explicit configuration.
58    pub fn with_config(
59        model_path: impl AsRef<Path>,
60        config: PowersetConfig,
61    ) -> Result<Self, SegmentationError> {
62        let path = model_path.as_ref().to_path_buf();
63        crate::onnx::validate_onnx_header(&path).map_err(|e| SegmentationError::ModelIo {
64            path: path.clone(),
65            detail: e.to_string(),
66        })?;
67        let mut builder = Session::builder().map_err(|e| SegmentationError::ModelIo {
68            path: path.clone(),
69            detail: format!("session builder failed: {e}"),
70        })?;
71        #[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
72        {
73            let coreml = ort::execution_providers::CoreMLExecutionProvider::default();
74            builder = builder
75                .with_execution_providers([coreml.build()])
76                .map_err(|e| SegmentationError::ModelIo {
77                    path: path.clone(),
78                    detail: format!("coreml ep: {e}"),
79                })?;
80        }
81        let session = builder
82            .commit_from_file(&path)
83            .map_err(|e| SegmentationError::ModelIo {
84                path: path.clone(),
85                detail: format!("commit_from_file failed: {e}"),
86            })?;
87        let input_name = session
88            .inputs()
89            .first()
90            .map(|i| i.name().to_owned())
91            .unwrap_or_else(|| "waveform".to_owned());
92        Ok(Self {
93            session: Mutex::new(session),
94            input_name,
95            config,
96            model_path: path,
97        })
98    }
99
100    /// { true }
101    /// pub fn config(&self) -> &PowersetConfig
102    /// { ret == &self.config }
103    pub fn config(&self) -> &PowersetConfig {
104        &self.config
105    }
106
107    /// { true }
108    /// pub fn model_path(&self) -> &Path
109    /// { ret == self.model_path.as_path() }
110    pub fn model_path(&self) -> &Path {
111        &self.model_path
112    }
113
114    fn window_samples(&self) -> usize {
115        (self.config.window_secs * self.config.sample_rate as f32) as usize
116    }
117
118    fn hop_samples(&self) -> usize {
119        (self.config.hop_secs * self.config.sample_rate as f32) as usize
120    }
121
122    /// Run inference on a single 10-second window.
123    /// Returns (logits_flat_row_major, num_frames).
124    fn infer_window(
125        &self,
126        window: &[f32],
127        window_idx: usize,
128    ) -> Result<(Vec<f32>, usize), SegmentationError> {
129        let win_samples = self.window_samples();
130        // Zero-pad short audio to the full window length.
131        let mut buf = vec![0.0_f32; win_samples];
132        let n = window.len().min(win_samples);
133        buf[..n].copy_from_slice(&window[..n]);
134
135        // Build input tensor with shape [1, 1, win_samples] matching the model's
136        // "waveform" input. Uses the same TensorRef::from_array_view pattern as
137        // silero_vad.rs (flat slice with explicit shape tuple).
138        let input_tensor = ort::value::TensorRef::from_array_view((
139            [1_usize, 1_usize, win_samples],
140            buf.as_slice(),
141        ))
142        .map_err(|e| SegmentationError::InferenceFailed {
143            window_idx,
144            detail: format!("input tensor: {e}"),
145        })?;
146
147        let mut guard = self.session.lock().unwrap_or_else(|e| e.into_inner());
148        let outputs = guard
149            .run(ort::inputs![self.input_name.as_str() => input_tensor])
150            .map_err(|e| SegmentationError::InferenceFailed {
151                window_idx,
152                detail: format!("session.run: {e}"),
153            })?;
154
155        // Extract first output by index (robust to any output name).
156        // try_extract_tensor returns (shape_slice, data_slice) matching ecapa.rs pattern.
157        let (shape, data) = outputs[0].try_extract_tensor::<f32>().map_err(|e| {
158            SegmentationError::InferenceFailed {
159                window_idx,
160                detail: format!("try_extract_tensor: {e}"),
161            }
162        })?;
163
164        // Expected shape: [1, num_frames, 7].
165        let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
166        if shape_vec.len() != 3 || shape_vec[0] != 1 || shape_vec[2] != 7 {
167            return Err(SegmentationError::InvalidOutputShape {
168                actual_shape: shape_vec,
169            });
170        }
171        let num_frames = shape_vec[1];
172        Ok((data.to_vec(), num_frames))
173    }
174}
175
176impl Segmenter for PowersetSegmenter {
177    fn segment(&self, audio: &[f32]) -> Result<Vec<RawSegment>, SegmentationError> {
178        if audio.len() < MIN_AUDIO_SAMPLES {
179            return Err(SegmentationError::AudioTooShort {
180                actual_secs: audio.len() as f32 / self.config.sample_rate as f32,
181                min_secs: MIN_AUDIO_SAMPLES as f32 / self.config.sample_rate as f32,
182            });
183        }
184
185        let win_samples = self.window_samples();
186        let hop_samples = self.hop_samples();
187        let mut windows: Vec<WindowOutput> = Vec::new();
188        for (window_idx, (start_sample, _end_sample)) in
189            crate::window::WindowIter::new(audio.len(), win_samples, hop_samples)
190                .include_partial()
191                .enumerate()
192        {
193            let slice = &audio[start_sample..(start_sample + win_samples).min(audio.len())];
194            let (logits, num_frames) = self.infer_window(slice, window_idx)?;
195            let start_t = start_sample as f32 / self.config.sample_rate as f32;
196            let end_t = (start_sample + win_samples) as f32 / self.config.sample_rate as f32;
197            let w = WindowOutput::new(start_t, end_t, logits, num_frames)?;
198            windows.push(w);
199        }
200
201        let agg = Aggregator::new(self.config.aggregation.clone());
202        agg.stitch(&windows)
203    }
204
205    fn max_local_speakers(&self) -> usize {
206        3
207    }
208
209    fn supports_overlap(&self) -> bool {
210        true
211    }
212}