mod aggregator;
mod decoder;
mod hungarian;
#[cfg(all(feature = "onnx", feature = "segmentation"))]
mod powerset;
pub use aggregator::{AggregationConfig, Aggregator, WindowOutput};
pub use decoder::{FrameLabel, PowersetClass, PowersetDecoder};
#[cfg(all(feature = "onnx", feature = "segmentation"))]
pub use powerset::{PowersetConfig, PowersetSegmenter};
use crate::types::{Confidence, TimeRange};
/// One segment as returned (inside a `Vec`) by [`Segmenter::segment`].
///
/// NOTE(review): the `local_` prefix suggests the speaker index is only
/// meaningful relative to a single segmenter call — confirm against the
/// code that consumes these segments.
#[derive(Debug, Clone, PartialEq)]
pub struct RawSegment {
/// Time span covered by this segment (presumably in seconds — TODO confirm
/// against `TimeRange`).
pub time: TimeRange,
/// Index of the speaker this segment is attributed to. Presumably bounded
/// by [`Segmenter::max_local_speakers`] — TODO confirm.
pub local_speaker_idx: u8,
/// Flag marking this segment as overlap (presumably overlapping speech from
/// more than one speaker — TODO confirm).
pub is_overlap: bool,
/// Confidence value attached to this segment.
pub confidence: Confidence,
}
/// Interface for components that turn an audio buffer into a list of
/// [`RawSegment`]s.
///
/// The `Send + Sync` supertraits require every implementor to be safe to move
/// to and share between threads.
pub trait Segmenter: Send + Sync {
/// Segments `audio` and returns the resulting segments as an owned vector.
///
/// NOTE(review): the expected sample rate and channel layout of `audio` are
/// not stated here — confirm with the concrete implementations.
///
/// # Errors
///
/// Returns a [`SegmentationError`] when segmentation fails.
fn segment(&self, audio: &[f32]) -> Result<Vec<RawSegment>, SegmentationError>;
/// Reports the maximum number of local speakers for this segmenter.
/// Presumably an upper bound on distinct `local_speaker_idx` values in its
/// output — TODO confirm.
fn max_local_speakers(&self) -> usize;
/// Reports whether this segmenter supports overlap. Presumably this means it
/// can emit segments with `is_overlap` set — TODO confirm.
fn supports_overlap(&self) -> bool;
}
/// Minimum audio length, expressed as a number of samples.
///
/// NOTE(review): 1600 samples corresponds to 0.1 s at a 16 kHz sample rate;
/// presumably this is the threshold behind `SegmentationError::AudioTooShort`
/// — confirm the sample rate and where the check is enforced.
pub const MIN_AUDIO_SAMPLES: usize = 1600;
/// Error type returned by [`Segmenter::segment`].
///
/// `Display` output for each variant comes from its `#[error(...)]` format
/// string (generated by `thiserror`).
#[derive(Debug, thiserror::Error)]
pub enum SegmentationError {
/// The supplied audio is shorter than the required minimum. Both durations
/// are rendered with three decimal places in the message.
#[error("audio too short: {actual_secs:.3}s < {min_secs:.3}s required")]
AudioTooShort { actual_secs: f32, min_secs: f32 },
/// ONNX inference failed; `window_idx` identifies the window and `detail`
/// carries a free-form description of the failure.
#[error("ONNX inference failed at window {window_idx}: {detail}")]
InferenceFailed { window_idx: usize, detail: String },
/// The output shape did not match the expected `(_, 7)` layout named in the
/// message; `actual_shape` holds the dimensions that were seen instead.
#[error(
"powerset decoder produced invalid output shape: expected (_, 7), got {actual_shape:?}"
)]
InvalidOutputShape { actual_shape: Vec<usize> },
/// Speaker permutation matching between two windows (`prev_idx` and
/// `next_idx`) failed; `detail` carries a free-form description.
#[error("speaker permutation matching failed across windows {prev_idx}->{next_idx}: {detail}")]
PermutationFailed {
prev_idx: usize,
next_idx: usize,
detail: String,
},
/// An I/O problem involving the model file at `path`. The underlying cause
/// is carried as text in `detail` rather than as a source error.
#[error("model file io error on {path}: {detail}")]
ModelIo {
path: std::path::PathBuf,
detail: String,
},
}
#[cfg(test)]
mod trait_tests {
    use super::*;
    use crate::types::{Confidence, TimeRange};

    /// Test double that replays a canned segment list, ignoring the audio it
    /// is handed, and reports fixed capability values.
    struct ConstantSegmenter {
        segments: Vec<RawSegment>,
        max_speakers: usize,
        overlap: bool,
    }

    impl Segmenter for ConstantSegmenter {
        fn segment(&self, _audio: &[f32]) -> Result<Vec<RawSegment>, SegmentationError> {
            Ok(self.segments.to_vec())
        }

        fn max_local_speakers(&self) -> usize {
            self.max_speakers
        }

        fn supports_overlap(&self) -> bool {
            self.overlap
        }
    }

    /// Builds the stub used by the trait tests: up to three speakers, overlap
    /// enabled, replaying `segments` on every call.
    fn stub_with(segments: Vec<RawSegment>) -> ConstantSegmenter {
        ConstantSegmenter {
            segments,
            max_speakers: 3,
            overlap: true,
        }
    }

    /// Fields written into a `RawSegment` can be read back unchanged.
    #[test]
    fn raw_segment_roundtrip() {
        let span = TimeRange {
            start: 0.5,
            end: 1.5,
        };
        let seg = RawSegment {
            time: span,
            local_speaker_idx: 1,
            is_overlap: true,
            confidence: Confidence::new(0.85).unwrap(),
        };

        assert_eq!(seg.local_speaker_idx, 1);
        assert!(seg.is_overlap);
        let delta = (seg.confidence.get() - 0.85).abs();
        assert!(delta < 1e-6);
    }

    /// Compile-time check: `Segmenter` can be used behind `Box<dyn _>`.
    #[test]
    fn segmenter_trait_object_is_dyn_compatible() {
        let _as_object: Box<dyn Segmenter> = Box::new(stub_with(Vec::new()));
    }

    /// `segment` hands back an owned vector holding the canned segment.
    #[test]
    fn segmenter_segment_returns_owned_vec() {
        let only_segment = RawSegment {
            time: TimeRange {
                start: 0.0,
                end: 1.0,
            },
            local_speaker_idx: 0,
            is_overlap: false,
            confidence: Confidence::new(1.0).unwrap(),
        };
        let stub = stub_with(vec![only_segment]);

        let produced = stub.segment(&[]).unwrap();
        assert_eq!(produced.len(), 1);
    }

    /// The `AudioTooShort` message mentions both the actual and the minimum
    /// duration (rendered with three decimals, so substring matches suffice).
    #[test]
    fn error_audio_too_short_displays_required_thresholds() {
        let rendered = SegmentationError::AudioTooShort {
            actual_secs: 0.05,
            min_secs: 0.1,
        }
        .to_string();

        for needle in ["0.05", "0.1"] {
            assert!(rendered.contains(needle));
        }
    }
}