polyvoice/segmentation/
powerset.rs1use crate::segmentation::aggregator::{AggregationConfig, Aggregator, WindowOutput};
8use crate::segmentation::{MIN_AUDIO_SAMPLES, RawSegment, SegmentationError, Segmenter};
9use ort::session::Session;
10use std::path::{Path, PathBuf};
11use std::sync::Mutex;
12
13#[derive(Debug, Clone)]
15pub struct PowersetConfig {
16 pub window_secs: f32,
18 pub hop_secs: f32,
20 pub sample_rate: u32,
22 pub aggregation: AggregationConfig,
24}
25
26impl Default for PowersetConfig {
27 fn default() -> Self {
28 Self {
29 window_secs: 10.0,
30 hop_secs: 1.0,
31 sample_rate: 16000,
32 aggregation: AggregationConfig::default(),
33 }
34 }
35}
36
37pub struct PowersetSegmenter {
39 session: Mutex<Session>,
40 input_name: String,
41 config: PowersetConfig,
42 model_path: PathBuf,
43}
44
45impl PowersetSegmenter {
46 pub fn new(model_path: impl AsRef<Path>) -> Result<Self, SegmentationError> {
51 Self::with_config(model_path, PowersetConfig::default())
52 }
53
54 pub fn with_config(
59 model_path: impl AsRef<Path>,
60 config: PowersetConfig,
61 ) -> Result<Self, SegmentationError> {
62 let path = model_path.as_ref().to_path_buf();
63 crate::onnx::validate_onnx_header(&path).map_err(|e| SegmentationError::ModelIo {
64 path: path.clone(),
65 detail: e.to_string(),
66 })?;
67 let mut builder = Session::builder().map_err(|e| SegmentationError::ModelIo {
68 path: path.clone(),
69 detail: format!("session builder failed: {e}"),
70 })?;
71 #[cfg(all(feature = "coreml", target_os = "macos", target_arch = "aarch64"))]
72 {
73 let coreml = ort::execution_providers::CoreMLExecutionProvider::default();
74 builder = builder
75 .with_execution_providers([coreml.build()])
76 .map_err(|e| SegmentationError::ModelIo {
77 path: path.clone(),
78 detail: format!("coreml ep: {e}"),
79 })?;
80 }
81 let session = builder
82 .commit_from_file(&path)
83 .map_err(|e| SegmentationError::ModelIo {
84 path: path.clone(),
85 detail: format!("commit_from_file failed: {e}"),
86 })?;
87 let input_name = session
88 .inputs()
89 .first()
90 .map(|i| i.name().to_owned())
91 .unwrap_or_else(|| "waveform".to_owned());
92 Ok(Self {
93 session: Mutex::new(session),
94 input_name,
95 config,
96 model_path: path,
97 })
98 }
99
100 pub fn config(&self) -> &PowersetConfig {
104 &self.config
105 }
106
107 pub fn model_path(&self) -> &Path {
111 &self.model_path
112 }
113
114 fn window_samples(&self) -> usize {
115 (self.config.window_secs * self.config.sample_rate as f32) as usize
116 }
117
118 fn hop_samples(&self) -> usize {
119 (self.config.hop_secs * self.config.sample_rate as f32) as usize
120 }
121
122 fn infer_window(
125 &self,
126 window: &[f32],
127 window_idx: usize,
128 ) -> Result<(Vec<f32>, usize), SegmentationError> {
129 let win_samples = self.window_samples();
130 let mut buf = vec![0.0_f32; win_samples];
132 let n = window.len().min(win_samples);
133 buf[..n].copy_from_slice(&window[..n]);
134
135 let input_tensor = ort::value::TensorRef::from_array_view((
139 [1_usize, 1_usize, win_samples],
140 buf.as_slice(),
141 ))
142 .map_err(|e| SegmentationError::InferenceFailed {
143 window_idx,
144 detail: format!("input tensor: {e}"),
145 })?;
146
147 let mut guard = self.session.lock().unwrap_or_else(|e| e.into_inner());
148 let outputs = guard
149 .run(ort::inputs![self.input_name.as_str() => input_tensor])
150 .map_err(|e| SegmentationError::InferenceFailed {
151 window_idx,
152 detail: format!("session.run: {e}"),
153 })?;
154
155 let (shape, data) = outputs[0].try_extract_tensor::<f32>().map_err(|e| {
158 SegmentationError::InferenceFailed {
159 window_idx,
160 detail: format!("try_extract_tensor: {e}"),
161 }
162 })?;
163
164 let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
166 if shape_vec.len() != 3 || shape_vec[0] != 1 || shape_vec[2] != 7 {
167 return Err(SegmentationError::InvalidOutputShape {
168 actual_shape: shape_vec,
169 });
170 }
171 let num_frames = shape_vec[1];
172 Ok((data.to_vec(), num_frames))
173 }
174}
175
176impl Segmenter for PowersetSegmenter {
177 fn segment(&self, audio: &[f32]) -> Result<Vec<RawSegment>, SegmentationError> {
178 if audio.len() < MIN_AUDIO_SAMPLES {
179 return Err(SegmentationError::AudioTooShort {
180 actual_secs: audio.len() as f32 / self.config.sample_rate as f32,
181 min_secs: MIN_AUDIO_SAMPLES as f32 / self.config.sample_rate as f32,
182 });
183 }
184
185 let win_samples = self.window_samples();
186 let hop_samples = self.hop_samples();
187 let mut windows: Vec<WindowOutput> = Vec::new();
188 for (window_idx, (start_sample, _end_sample)) in
189 crate::window::WindowIter::new(audio.len(), win_samples, hop_samples)
190 .include_partial()
191 .enumerate()
192 {
193 let slice = &audio[start_sample..(start_sample + win_samples).min(audio.len())];
194 let (logits, num_frames) = self.infer_window(slice, window_idx)?;
195 let start_t = start_sample as f32 / self.config.sample_rate as f32;
196 let end_t = (start_sample + win_samples) as f32 / self.config.sample_rate as f32;
197 let w = WindowOutput::new(start_t, end_t, logits, num_frames)?;
198 windows.push(w);
199 }
200
201 let agg = Aggregator::new(self.config.aggregation.clone());
202 agg.stitch(&windows)
203 }
204
205 fn max_local_speakers(&self) -> usize {
206 3
207 }
208
209 fn supports_overlap(&self) -> bool {
210 true
211 }
212}