use serde::{Deserialize, Serialize};
use std::fmt;
/// Numeric identifier of a speaker.
///
/// Its `Display` impl renders it as `SPEAKER_` plus the number zero-padded to two digits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub struct SpeakerId(pub u32);
/// Lookup table that maps old speaker ids to new ones.
#[derive(Debug, Clone, PartialEq)]
pub struct SpeakerIdRemap {
    /// `(old, new)` pairs kept in the order they were supplied.
    mapping: Vec<(SpeakerId, SpeakerId)>,
}

impl SpeakerIdRemap {
    /// Builds a remap from `(old, new)` pairs, stored exactly as given (no dedup, no validation).
    pub fn from_mapping(mapping: Vec<(SpeakerId, SpeakerId)>) -> Self {
        SpeakerIdRemap { mapping }
    }

    /// Returns the replacement for `id`, or `id` unchanged when no pair has it as `old`.
    ///
    /// This is a single linear scan that never chains: with pairs `(0, 1)` and `(1, 2)`, id `0`
    /// becomes `1`, not `2`. When `id` occurs as `old` more than once, the earliest pair wins.
    pub fn remap(&self, id: SpeakerId) -> SpeakerId {
        for &(from, to) in &self.mapping {
            if from == id {
                return to;
            }
        }
        id
    }

    /// `true` when no pairs are stored.
    pub fn is_empty(&self) -> bool {
        self.mapping.is_empty()
    }

    /// Number of stored pairs, duplicates included.
    pub fn len(&self) -> usize {
        self.mapping.len()
    }
}
/// Rewrites the speaker of every segment in place through `remap`.
///
/// Segments without a speaker (`None`) stay unassigned.
pub fn remap_segments(segments: &mut [Segment], remap: &SpeakerIdRemap) {
    segments
        .iter_mut()
        .for_each(|segment| segment.speaker = segment.speaker.map(|id| remap.remap(id)));
}
pub fn remap_turns(turns: &mut [SpeakerTurn], remap: &SpeakerIdRemap) {
for turn in turns.iter_mut() {
turn.speaker = remap.remap(turn.speaker);
}
}
/// Formats as `SPEAKER_` followed by the id, zero-padded to at least two digits
/// (`SPEAKER_03`, `SPEAKER_42`, `SPEAKER_123`).
impl fmt::Display for SpeakerId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_fmt(format_args!("SPEAKER_{:02}", self.0))
    }
}
/// Named configuration profile.
///
/// Each variant fixes an embedding dimension, a default threshold and a manifest id (see the
/// inherent `impl Profile`). Strings parse into it via `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum Profile {
    /// Manifest id `"mobile"`.
    Mobile,
    /// Manifest id `"balanced"`.
    Balanced,
    /// Manifest id `"custom"`. The embedding dimension is not fixed by the profile.
    Custom,
}
impl Profile {
    /// Per-profile constants as `(embedding_dim, default_threshold, manifest_id)`.
    ///
    /// This is the one place that lists them, so the public getters cannot drift apart.
    const fn static_params(self) -> (usize, f32, &'static str) {
        match self {
            Profile::Mobile => (512, 0.55, "mobile"),
            Profile::Balanced => (256, 0.45, "balanced"),
            Profile::Custom => (0, 0.5, "custom"),
        }
    }

    /// Embedding vector length for this profile.
    ///
    /// `Custom` reports `0`, meaning the dimension is not known from the profile alone.
    /// Presumably it is resolved elsewhere — TODO confirm.
    pub const fn embedding_dim(self) -> usize {
        self.static_params().0
    }

    /// Default threshold for this profile: 0.55 mobile, 0.45 balanced, 0.5 custom.
    ///
    /// NOTE(review): presumably feeds `ClusterConfig::threshold`; confirm at the call site.
    pub const fn default_threshold(self) -> f32 {
        self.static_params().1
    }

    /// Lowercase identifier for this profile. These are the same names `FromStr` accepts.
    pub const fn manifest_id(self) -> &'static str {
        self.static_params().2
    }
}
impl std::str::FromStr for Profile {
    type Err = ProfileParseError;

    /// Parses a profile name, ignoring ASCII case (`"mobile"`, `"Mobile"` and `"MOBILE"` all work).
    ///
    /// # Errors
    /// Returns [`ProfileParseError`] carrying the input exactly as given. It does not carry the
    /// lowercased copy used for matching, so error messages echo what the caller actually supplied.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s.to_ascii_lowercase().as_str() {
            "mobile" => Ok(Profile::Mobile),
            "balanced" => Ok(Profile::Balanced),
            "custom" => Ok(Profile::Custom),
            // Report the original spelling, not the normalized one.
            _ => Err(ProfileParseError(s.to_owned())),
        }
    }
}
/// Error produced when a string does not name a known `Profile`.
///
/// The payload is the offending text handed over by the `FromStr` implementation for `Profile`.
#[derive(Debug, Clone)]
pub struct ProfileParseError(pub String);

impl std::fmt::Display for ProfileParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_fmt(format_args!(
            "unknown profile '{}': expected mobile|balanced|custom",
            self.0
        ))
    }
}

// Lets the error be boxed as `dyn Error` or propagated with `?`; `Display` supplies the message.
impl std::error::Error for ProfileParseError {}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SampleRate(u32);
impl SampleRate {
pub fn new(rate: u32) -> Option<Self> {
(8000..=192000).contains(&rate).then_some(Self(rate))
}
pub fn get(&self) -> u32 {
self.0
}
}
impl Default for SampleRate {
fn default() -> Self {
Self(16000)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Confidence(f32);
impl Confidence {
pub fn new(v: f32) -> Option<Self> {
(0.0..=1.0).contains(&v).then_some(Self(v))
}
pub fn get(&self) -> f32 {
self.0
}
}
impl Default for Confidence {
fn default() -> Self {
Self(1.0)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Seconds(f32);
impl Seconds {
pub fn new(v: f32) -> Option<Self> {
(v >= 0.0).then_some(Self(v))
}
pub fn get(&self) -> f32 {
self.0
}
}
impl Default for Seconds {
fn default() -> Self {
Self(0.0)
}
}
/// A time interval bounded by `start` and `end`.
///
/// Intended invariant: `end >= start`. The only check is a `debug_assert!` in `duration`. The
/// fields are public, so a reversed range can still be constructed directly.
#[derive(Debug, Clone, Copy, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct TimeRange {
    /// Beginning of the interval.
    pub start: f64,
    /// End of the interval; expected to be `>= start`.
    pub end: f64,
}
impl TimeRange {
    /// Builds a range, returning `None` unless `start <= end`.
    ///
    /// A NaN in either bound makes the comparison false, so it is rejected too. This mirrors the
    /// validating `new` constructors of `SampleRate`, `Confidence` and `Seconds`. It lets callers
    /// establish the invariant that [`TimeRange::duration`] relies on.
    pub fn new(start: f64, end: f64) -> Option<Self> {
        (end >= start).then_some(Self { start, end })
    }

    /// Length of the range (`end - start`), in the same units as the bounds.
    ///
    /// # Panics
    /// In debug builds, panics if `end < start` or either bound is NaN. Release builds skip the
    /// check and return the raw difference, which is then negative or NaN.
    pub fn duration(&self) -> f64 {
        debug_assert!(
            self.end >= self.start,
            "TimeRange invariant violated: end < start"
        );
        self.end - self.start
    }
}
/// A span of the timeline, optionally attributed to a speaker.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct Segment {
    /// Where the segment lies on the timeline.
    pub time: TimeRange,
    /// Attributed speaker, or `None` when no speaker is assigned.
    pub speaker: Option<SpeakerId>,
    /// Optional confidence score.
    ///
    /// NOTE(review): stored as a bare `f32`, so the `0.0..=1.0` range that `Confidence` enforces
    /// is not checked here.
    pub confidence: Option<f32>,
}
/// A time range attributed to exactly one speaker, with optional text.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct SpeakerTurn {
    /// The speaker of this turn.
    pub speaker: SpeakerId,
    /// When the turn takes place.
    pub time: TimeRange,
    /// Optional text for the turn; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub text: Option<String>,
}
/// A single word placed on the timeline, with an optional speaker and a confidence value.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct WordAlignment {
    /// The word text.
    pub word: String,
    /// Where the word lies on the timeline.
    pub time: TimeRange,
    /// Speaker the word is attributed to, if any.
    pub speaker: Option<SpeakerId>,
    /// Confidence value.
    ///
    /// NOTE(review): a bare `f32`, so its range is not enforced by this type.
    pub confidence: f32,
}
/// Diarization output: segments, speaker turns and a speaker count.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub struct DiarizationResult {
    /// Timeline segments with optional speaker labels.
    pub segments: Vec<Segment>,
    /// Speaker turns.
    pub turns: Vec<SpeakerTurn>,
    /// Number of speakers.
    ///
    /// NOTE(review): this type does not derive it from `segments`/`turns`; producers must keep it
    /// consistent.
    pub num_speakers: usize,
}
/// Clustering parameters.
#[derive(Debug, Clone, Copy)]
pub struct ClusterConfig {
    /// Clustering threshold.
    ///
    /// NOTE(review): whether values merge above or below it is decided by the clustering code,
    /// not visible here.
    pub threshold: f32,
    /// Speaker limit; presumably caps the number of clusters — confirm in the clustering code.
    pub max_speakers: usize,
}

impl Default for ClusterConfig {
    /// `threshold = 0.45` (the value `Profile::Balanced.default_threshold()` returns) and
    /// `max_speakers = 64`.
    fn default() -> Self {
        let max_speakers = 64;
        let threshold = 0.45;
        ClusterConfig {
            threshold,
            max_speakers,
        }
    }
}
#[derive(Debug, Clone, Copy)]
pub struct WindowConfig {
pub window_secs: f32,
pub hop_secs: f32,
pub sample_rate: SampleRate,
}
impl Default for WindowConfig {
fn default() -> Self {
Self {
window_secs: 1.5,
hop_secs: 0.75,
sample_rate: SampleRate(16000),
}
}
}
impl WindowConfig {
pub fn window_samples(&self) -> usize {
(self.window_secs * self.sample_rate.get() as f32) as usize
}
pub fn hop_samples(&self) -> usize {
(self.hop_secs * self.sample_rate.get() as f32) as usize
}
}
/// Duration thresholds for filtering speech regions.
#[derive(Debug, Clone, Copy)]
pub struct SpeechFilterConfig {
    /// Minimum speech duration in seconds. Presumably shorter regions are dropped — confirm in
    /// the filter code.
    pub min_speech_secs: f32,
    /// Maximum gap in seconds. Presumably gaps up to this length are bridged — confirm in the
    /// filter code.
    pub max_gap_secs: f32,
}

impl Default for SpeechFilterConfig {
    /// 0.25 s minimum speech and 0.5 s maximum gap.
    fn default() -> Self {
        SpeechFilterConfig {
            max_gap_secs: 0.5,
            min_speech_secs: 0.25,
        }
    }
}
/// Top-level diarization settings, grouping the clustering, windowing and speech-filter configs.
#[derive(Debug, Clone, Copy)]
pub struct DiarizationConfig {
    /// Clustering parameters.
    pub cluster: ClusterConfig,
    /// Windowing parameters. The sample-count helpers below delegate here.
    pub window: WindowConfig,
    /// Speech-filter thresholds.
    pub speech_filter: SpeechFilterConfig,
    /// Maximum input duration in seconds.
    ///
    /// NOTE(review): enforcement happens elsewhere, not in this type.
    pub max_duration_secs: f32,
}

impl Default for DiarizationConfig {
    /// Every sub-config at its own `Default`, plus a `3600.0` s duration cap.
    fn default() -> Self {
        Self {
            max_duration_secs: 3600.0,
            cluster: Default::default(),
            window: Default::default(),
            speech_filter: Default::default(),
        }
    }
}

impl DiarizationConfig {
    /// Same as `self.window.window_samples()`.
    pub fn window_samples(&self) -> usize {
        WindowConfig::window_samples(&self.window)
    }

    /// Same as `self.window.hop_samples()`.
    pub fn hop_samples(&self) -> usize {
        WindowConfig::hop_samples(&self.window)
    }
}
/// Unit tests for `Profile` constants and string parsing.
#[cfg(test)]
mod profile_tests {
    use super::*;

    #[test]
    fn mobile_profile_uses_cam_pp_dim() {
        assert_eq!(Profile::Mobile.embedding_dim(), 512);
    }

    #[test]
    fn balanced_profile_uses_resnet34_dim() {
        assert_eq!(Profile::Balanced.embedding_dim(), 256);
    }

    #[test]
    fn custom_profile_dim_is_unresolved() {
        assert_eq!(Profile::Custom.embedding_dim(), 0);
    }

    #[test]
    fn default_thresholds_match_spec() {
        assert!((Profile::Mobile.default_threshold() - 0.55).abs() < 1e-6);
        assert!((Profile::Balanced.default_threshold() - 0.45).abs() < 1e-6);
        assert!((Profile::Custom.default_threshold() - 0.5).abs() < 1e-6);
    }

    #[test]
    fn manifest_id_for_each_variant() {
        assert_eq!(Profile::Mobile.manifest_id(), "mobile");
        assert_eq!(Profile::Balanced.manifest_id(), "balanced");
        assert_eq!(Profile::Custom.manifest_id(), "custom");
    }

    // Formerly named `from_str_parses_kebab_and_lowercase`. No kebab-case names exist; what
    // `FromStr` actually does is ASCII case-insensitive matching.
    #[test]
    fn from_str_is_ascii_case_insensitive() {
        assert_eq!("mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("Mobile".parse::<Profile>().unwrap(), Profile::Mobile);
        assert_eq!("balanced".parse::<Profile>().unwrap(), Profile::Balanced);
        assert_eq!("BALANCED".parse::<Profile>().unwrap(), Profile::Balanced);
        assert_eq!("Custom".parse::<Profile>().unwrap(), Profile::Custom);
    }

    #[test]
    fn from_str_rejects_unknown_names() {
        assert!("nope".parse::<Profile>().is_err());
        assert!("".parse::<Profile>().is_err());
        assert!("mobile-lite".parse::<Profile>().is_err());
    }

    #[test]
    fn manifest_id_round_trips_through_from_str() {
        for p in [Profile::Mobile, Profile::Balanced, Profile::Custom] {
            assert_eq!(p.manifest_id().parse::<Profile>().unwrap(), p);
        }
    }
}