use crate::segmentation::aggregator::{AggregationConfig, Aggregator, WindowOutput};
use crate::segmentation::{MIN_AUDIO_SAMPLES, RawSegment, SegmentationError, Segmenter};
use ort::session::Session;
use std::path::{Path, PathBuf};
use std::sync::Mutex;
/// Tuning parameters for [`PowersetSegmenter`].
///
/// The two durations are expressed in seconds; the segmenter converts them to
/// sample counts by multiplying with `sample_rate`. The `Default` impl uses a
/// 10.0 s window, a 0.5 s hop and a sample rate of 16000.
#[derive(Debug, Clone)]
pub struct PowersetConfig {
/// Length, in seconds, of each window of audio handed to the model.
pub window_secs: f32,
/// Hop, in seconds, given (after conversion to samples) to the window
/// iterator together with the window length.
pub hop_secs: f32,
/// Sample rate used to convert between seconds and sample counts
/// (presumably in Hz — confirm against the model's expected input).
pub sample_rate: u32,
/// Settings forwarded to the `Aggregator` that stitches per-window outputs.
pub aggregation: AggregationConfig,
}
impl Default for PowersetConfig {
fn default() -> Self {
Self {
window_secs: 10.0,
hop_secs: 0.5,
sample_rate: 16000,
aggregation: AggregationConfig::default(),
}
}
}
/// Segmenter that runs an ONNX model over windows of audio and stitches the
/// per-window outputs into segments.
pub struct PowersetSegmenter {
/// ONNX Runtime session. Kept behind a `Mutex` and locked for every
/// inference call, since inference is driven through `&self`.
session: Mutex<Session>,
/// Name under which the input tensor is fed to the model: the model's first
/// declared input, or `"waveform"` when the model declares none.
input_name: String,
/// Windowing and aggregation settings this segmenter was built with.
config: PowersetConfig,
/// Path the model was loaded from.
model_path: PathBuf,
}
impl PowersetSegmenter {
    /// Size the last axis of the model output must have for a window to be
    /// accepted (presumably the number of powerset classes the model emits —
    /// confirm against the model card).
    const OUTPUT_CLASSES: usize = 7;

    /// Loads the ONNX model at `model_path` with [`PowersetConfig::default`].
    ///
    /// # Errors
    /// Returns [`SegmentationError::ModelIo`] under the same conditions as
    /// [`PowersetSegmenter::with_config`].
    pub fn new(model_path: impl AsRef<Path>) -> Result<Self, SegmentationError> {
        Self::with_config(model_path, PowersetConfig::default())
    }

    /// Loads the ONNX model at `model_path` and pairs it with `config`.
    ///
    /// The file header is validated first, then an ONNX Runtime session is
    /// created from the file. The name of the model's first declared input is
    /// remembered for inference; if the model declares no inputs, the name
    /// `"waveform"` is used instead.
    ///
    /// # Errors
    /// Returns [`SegmentationError::ModelIo`] (carrying the model path) when
    /// header validation fails, when the session builder cannot be created,
    /// or when the model cannot be committed from the file.
    pub fn with_config(
        model_path: impl AsRef<Path>,
        config: PowersetConfig,
    ) -> Result<Self, SegmentationError> {
        let path = model_path.as_ref().to_path_buf();
        // Every failure in this constructor is reported the same way; only
        // the detail text differs.
        let model_io = |detail: String| SegmentationError::ModelIo {
            path: path.clone(),
            detail,
        };

        crate::onnx::validate_onnx_header(&path).map_err(|e| model_io(e.to_string()))?;

        let session = Session::builder()
            .map_err(|e| model_io(format!("session builder failed: {e}")))?
            .commit_from_file(&path)
            .map_err(|e| model_io(format!("commit_from_file failed: {e}")))?;

        let input_name = session
            .inputs()
            .first()
            .map(|i| i.name().to_owned())
            .unwrap_or_else(|| "waveform".to_owned());

        Ok(Self {
            session: Mutex::new(session),
            input_name,
            config,
            model_path: path,
        })
    }

    /// Returns the configuration this segmenter was built with.
    pub fn config(&self) -> &PowersetConfig {
        &self.config
    }

    /// Returns the path the model was loaded from.
    pub fn model_path(&self) -> &Path {
        &self.model_path
    }

    /// Converts a duration in seconds to a sample count at the configured
    /// sample rate.
    ///
    /// The product is rounded to the nearest sample rather than truncated:
    /// many decimal durations are not exactly representable in `f32`, and
    /// truncation can then land one sample short (e.g. `0.1255 * 16000`
    /// evaluates to just under 2008 and would truncate to 2007).
    ///
    /// Negative or NaN durations saturate to 0 in the final cast.
    fn secs_to_samples(&self, secs: f32) -> usize {
        (secs * self.config.sample_rate as f32).round() as usize
    }

    /// Number of samples in one model window.
    fn window_samples(&self) -> usize {
        self.secs_to_samples(self.config.window_secs)
    }

    /// Number of samples between consecutive windows.
    ///
    /// NOTE(review): a zero, negative or NaN `hop_secs` yields 0 here —
    /// confirm that the window iterator rejects or tolerates a zero hop.
    fn hop_samples(&self) -> usize {
        self.secs_to_samples(self.config.hop_secs)
    }

    /// Runs the model on a single window of audio.
    ///
    /// `window` is fed as a `[1, 1, window_samples]` tensor. A slice shorter
    /// than the configured window is zero-padded at the end; a longer one is
    /// cut to the window length. `window_idx` is only used to label errors.
    ///
    /// # Returns
    /// The first model output copied into a flat `Vec<f32>`, together with the
    /// size of its middle axis (the number of frames).
    ///
    /// # Errors
    /// * [`SegmentationError::InferenceFailed`] when the input tensor cannot be
    ///   built, the session run fails, or the output cannot be extracted as
    ///   `f32`.
    /// * [`SegmentationError::InvalidOutputShape`] when the output is not of
    ///   shape `[1, frames, 7]`.
    fn infer_window(
        &self,
        window: &[f32],
        window_idx: usize,
    ) -> Result<(Vec<f32>, usize), SegmentationError> {
        let win_samples = self.window_samples();

        // Full-length windows are borrowed as-is; only a short (partial)
        // window needs a freshly allocated, zero-padded copy.
        let samples: std::borrow::Cow<'_, [f32]> = if window.len() >= win_samples {
            std::borrow::Cow::Borrowed(&window[..win_samples])
        } else {
            let mut padded = vec![0.0_f32; win_samples];
            padded[..window.len()].copy_from_slice(window);
            std::borrow::Cow::Owned(padded)
        };
        let view: &[f32] = &samples;

        let input_tensor =
            ort::value::TensorRef::from_array_view(([1_usize, 1_usize, win_samples], view))
                .map_err(|e| SegmentationError::InferenceFailed {
                    window_idx,
                    detail: format!("input tensor: {e}"),
                })?;

        // A poisoned mutex is recovered instead of panicking: the guard is
        // taken out of the poison error and inference proceeds.
        let mut guard = self.session.lock().unwrap_or_else(|e| e.into_inner());
        let outputs = guard
            .run(ort::inputs![self.input_name.as_str() => input_tensor])
            .map_err(|e| SegmentationError::InferenceFailed {
                window_idx,
                detail: format!("session.run: {e}"),
            })?;

        let (shape, data) = outputs[0].try_extract_tensor::<f32>().map_err(|e| {
            SegmentationError::InferenceFailed {
                window_idx,
                detail: format!("try_extract_tensor: {e}"),
            }
        })?;

        let shape_vec: Vec<usize> = shape.iter().map(|&d| d as usize).collect();
        let shape_ok =
            matches!(shape_vec.as_slice(), [1, _, classes] if *classes == Self::OUTPUT_CLASSES);
        if !shape_ok {
            return Err(SegmentationError::InvalidOutputShape {
                actual_shape: shape_vec,
            });
        }

        let num_frames = shape_vec[1];
        // Copy the output into an owned buffer so it can be returned.
        Ok((data.to_vec(), num_frames))
    }
}
impl Segmenter for PowersetSegmenter {
    /// Segments `audio` by running the model over windows of the signal and
    /// stitching the per-window outputs with the configured aggregator.
    ///
    /// Windows are produced by the window iterator with partial windows
    /// included. Each window is reported to the aggregator with a start and
    /// end time in seconds derived from the configured sample rate; the end
    /// time always spans a full window length, even when the underlying slice
    /// was cut short by the end of the audio.
    ///
    /// # Errors
    /// * [`SegmentationError::AudioTooShort`] when `audio` has fewer than
    ///   `MIN_AUDIO_SAMPLES` samples.
    /// * Any error raised while running the model on a window, building a
    ///   `WindowOutput`, or stitching the windows.
    fn segment(&self, audio: &[f32]) -> Result<Vec<RawSegment>, SegmentationError> {
        let rate = self.config.sample_rate as f32;
        let total = audio.len();

        if total < MIN_AUDIO_SAMPLES {
            return Err(SegmentationError::AudioTooShort {
                actual_secs: total as f32 / rate,
                min_secs: MIN_AUDIO_SAMPLES as f32 / rate,
            });
        }

        let window_len = self.window_samples();
        let step = self.hop_samples();

        // Stops at the first failing window, exactly like an explicit loop
        // with `?` would.
        let per_window = crate::window::WindowIter::new(total, window_len, step)
            .include_partial()
            .enumerate()
            .map(
                |(idx, (begin, _))| -> Result<WindowOutput, SegmentationError> {
                    let nominal_end = begin + window_len;
                    let chunk = &audio[begin..nominal_end.min(total)];
                    let (logits, frames) = self.infer_window(chunk, idx)?;
                    let output = WindowOutput::new(
                        begin as f32 / rate,
                        nominal_end as f32 / rate,
                        logits,
                        frames,
                    )?;
                    Ok(output)
                },
            )
            .collect::<Result<Vec<_>, _>>()?;

        Aggregator::new(self.config.aggregation.clone()).stitch(&per_window)
    }

    /// Reports the maximum number of local speakers as 3.
    fn max_local_speakers(&self) -> usize {
        3
    }

    /// Reports that overlapping speech is supported.
    fn supports_overlap(&self) -> bool {
        true
    }
}