use crate::config::Config;
use serde::{Deserialize, Serialize};
use std::path::Path;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpeakerSegment {
pub speaker: String,
pub start: f64,
pub end: f64,
}
#[derive(Debug, Clone)]
pub struct DiarizationResult {
pub segments: Vec<SpeakerSegment>,
pub num_speakers: usize,
pub system_dominant_ratio: f32,
pub voice_dominant_ratio: f32,
pub degraded_capture: Option<DegradedCapture>,
pub from_stems: bool,
pub source_aware: bool,
pub speaker_embeddings: std::collections::HashMap<String, Vec<f32>>,
}
impl Default for DiarizationResult {
fn default() -> Self {
Self {
segments: Vec::new(),
num_speakers: 0,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(),
}
}
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
pub struct DegradedCapture {
pub failure_kind: FailureKind,
pub capture_backend: String,
pub capture_source: CaptureSource,
pub voice_active_ratio: Option<f32>,
pub system_active_ratio: Option<f32>,
pub observed_signal: ObservedSignal,
pub diagnostic_confidence: DiagnosticConfidence,
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "kebab-case")]
pub enum FailureKind {
Silent,
Sparse,
Missing,
BackendUnavailable,
StreamError,
SourceStarved,
UnsupportedFormat,
MisconfiguredRoute,
PermissionDenied,
RouteUnavailable,
Other { code: String },
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum CaptureSource {
Voice,
System,
Both,
Backend,
}
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, schemars::JsonSchema)]
pub struct ObservedSignal {
pub frames_captured: usize,
pub max_rms: f32,
pub avg_rms: f32,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum DiagnosticConfidence {
High,
Inferred,
}
type EnergyWindow = (f64, f32);
type StemEnergyWindows = (Vec<EnergyWindow>, Vec<EnergyWindow>);
const STEM_PROBE_SECS: usize = 5;
const STEM_PROBE_RMS_FLOOR: f32 = 0.001;
const PRIMARY_DEGRADED_MIN_DURATION_SECS: f64 = 60.0;
const DEGRADED_ML_FALLBACK_MIN_DURATION_SECS: f64 = 120.0;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum Confidence {
High,
Medium,
Low,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, schemars::JsonSchema)]
#[serde(rename_all = "lowercase")]
pub enum AttributionSource {
Deterministic,
Llm,
Enrollment,
Manual,
#[serde(rename = "ml-bleed-degraded")]
MlBleedDegraded,
#[serde(rename = "stem-recovery")]
StemRecovery,
}
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema)]
pub struct SpeakerAttribution {
pub speaker_label: String,
pub name: String,
pub confidence: Confidence,
pub source: AttributionSource,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DiarizationPurpose {
PrimaryMeeting,
Auxiliary,
}
#[derive(Debug, Clone, Copy)]
pub struct DiarizationContext<'a> {
pub purpose: DiarizationPurpose,
pub transcript_windows: Option<&'a [TranscriptWindow]>,
}
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TranscriptWindow {
pub start_secs: f32,
pub end_secs: f32,
}
#[derive(Debug, Clone)]
pub enum DiarizationOutcome {
Result(DiarizationResult),
Skipped { reason: DegradedCapture },
NotConfigured,
}
impl DiarizationOutcome {
pub fn result_for_auxiliary_use(&self) -> Option<&DiarizationResult> {
match self {
DiarizationOutcome::Result(result) => Some(result),
DiarizationOutcome::Skipped { .. } | DiarizationOutcome::NotConfigured => None,
}
}
}
pub fn apply_confirmed_names(transcript: &str, attributions: &[SpeakerAttribution]) -> String {
let high_map: std::collections::HashMap<&str, &str> = attributions
.iter()
.filter(|a| a.confidence == Confidence::High)
.map(|a| (a.speaker_label.as_str(), a.name.as_str()))
.collect();
if high_map.is_empty() {
return transcript.to_string();
}
let mut output = String::new();
for line in transcript.lines() {
let mut replaced = false;
if let Some(rest) = line.strip_prefix('[') {
if let Some(bracket_end) = rest.find(']') {
let inside = &rest[..bracket_end];
if let Some(space_pos) = inside.find(' ') {
let label = &inside[..space_pos];
let text = rest[bracket_end + 1..].trim();
if let Some(name) = high_map.get(label) {
if !is_non_lexical_event_text(text) {
let after = &rest[bracket_end..];
output.push_str(&format!(
"[{} {}{}\n",
name,
&inside[space_pos + 1..],
after
));
replaced = true;
}
}
}
}
}
if !replaced {
output.push_str(line);
output.push('\n');
}
}
output
}
fn is_non_lexical_event_text(text: &str) -> bool {
let trimmed = text.trim();
trimmed.starts_with('[') && trimmed.ends_with(']')
}
pub const SEGMENTATION_MODEL: &str = "segmentation-3.0.onnx";
pub const SEGMENTATION_MODEL_URL: &str =
"https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/segmentation-3.0.onnx";
pub struct EmbeddingModelInfo {
pub filename: &'static str,
pub url: &'static str,
pub version: &'static str,
}
pub fn embedding_model_info(name: &str) -> Option<&'static EmbeddingModelInfo> {
static CAM_PP: EmbeddingModelInfo = EmbeddingModelInfo {
filename: "wespeaker_en_voxceleb_CAM++.onnx",
url: "https://github.com/thewh1teagle/pyannote-rs/releases/download/v0.1.0/wespeaker_en_voxceleb_CAM++.onnx",
version: "wespeaker_en_voxceleb_CAM++_v0.3",
};
static CAM_PP_LM: EmbeddingModelInfo = EmbeddingModelInfo {
filename: "voxceleb_CAM++_LM.onnx",
url: "https://huggingface.co/Wespeaker/wespeaker-voxceleb-campplus-LM/resolve/main/voxceleb_CAM%2B%2B_LM.onnx",
version: "wespeaker_voxceleb_CAM++_LM_v0.3",
};
match name {
"cam++" => Some(&CAM_PP),
"cam++-lm" => Some(&CAM_PP_LM),
_ => None,
}
}
pub const EMBEDDING_MODEL_NAMES: &[&str] = &["cam++", "cam++-lm"];
pub fn embedding_model_for_config(config: &Config) -> &'static EmbeddingModelInfo {
embedding_model_info(&config.diarization.embedding_model)
.unwrap_or_else(|| embedding_model_info("cam++").unwrap())
}
pub fn models_installed(config: &Config) -> bool {
let dir = &config.diarization.model_path;
let emb = embedding_model_for_config(config);
dir.join(SEGMENTATION_MODEL).exists() && dir.join(emb.filename).exists()
}
fn preprocess_audio(audio_path: &Path) -> (std::path::PathBuf, Option<std::path::PathBuf>) {
let temp_path = std::env::temp_dir().join("minutes-diarize-preprocessed.wav");
match std::process::Command::new("ffmpeg")
.args([
"-y",
"-i",
audio_path.to_str().unwrap_or(""),
"-ar",
"16000",
"-ac",
"1",
"-sample_fmt",
"s16",
temp_path.to_str().unwrap_or(""),
])
.stdout(std::process::Stdio::null())
.stderr(std::process::Stdio::null())
.status()
{
Ok(status) if status.success() => {
tracing::info!("audio preprocessed to 16kHz mono via ffmpeg");
(temp_path.clone(), Some(temp_path))
}
_ => {
tracing::debug!("ffmpeg not available for preprocessing, using original audio");
(audio_path.to_path_buf(), None)
}
}
}
pub fn audio_duration_secs(audio_path: &Path) -> Result<f64, String> {
use symphonia::core::io::MediaSourceStream;
use symphonia::core::probe::Hint;
let file = std::fs::File::open(audio_path)
.map_err(|error| format!("failed to open audio {}: {error}", audio_path.display()))?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = audio_path.extension().and_then(|ext| ext.to_str()) {
hint.with_extension(ext);
}
let probed = symphonia::default::get_probe()
.format(&hint, mss, &Default::default(), &Default::default())
.map_err(|error| format!("failed to probe audio {}: {error}", audio_path.display()))?;
let track = probed
.format
.default_track()
.ok_or_else(|| format!("audio {} has no default track", audio_path.display()))?;
let params = &track.codec_params;
let n_frames = params
.n_frames
.ok_or_else(|| format!("audio {} has no frame count", audio_path.display()))?;
let sample_rate = params
.sample_rate
.ok_or_else(|| format!("audio {} has no sample rate", audio_path.display()))?;
if sample_rate == 0 {
return Err(format!(
"audio {} has zero sample rate",
audio_path.display()
));
}
Ok(n_frames as f64 / sample_rate as f64)
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StemPaths {
pub voice: std::path::PathBuf,
pub system: std::path::PathBuf,
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub(crate) enum SourceAwareDiarizationPlan {
FullStems(StemPaths),
SystemStemOnly(std::path::PathBuf),
SilentSystemStem(StemPaths),
}
pub(crate) fn stem_has_audio(path: &Path) -> bool {
let Ok(reader) = hound::WavReader::open(path) else {
return false;
};
let spec = reader.spec();
if spec.sample_rate == 0 || spec.channels == 0 {
return false;
}
match spec.sample_format {
hound::SampleFormat::Float => probe_stem_samples(
reader.into_samples::<f32>(),
spec.sample_rate,
spec.channels,
|sample| sample,
),
hound::SampleFormat::Int => {
let bits = spec.bits_per_sample.clamp(1, 32);
let max_value = (1_i64 << (bits - 1)) as f32;
probe_stem_samples(
reader.into_samples::<i32>(),
spec.sample_rate,
spec.channels,
move |sample| sample as f32 / max_value,
)
}
}
}
fn stem_probe_observed_signal(path: &Path) -> ObservedSignal {
let Ok(reader) = hound::WavReader::open(path) else {
return ObservedSignal {
frames_captured: 0,
max_rms: 0.0,
avg_rms: 0.0,
};
};
let spec = reader.spec();
if spec.sample_rate == 0 || spec.channels == 0 {
return ObservedSignal {
frames_captured: 0,
max_rms: 0.0,
avg_rms: 0.0,
};
}
match spec.sample_format {
hound::SampleFormat::Float => probe_stem_observed_signal(
reader.into_samples::<f32>(),
spec.sample_rate,
spec.channels,
|sample| sample,
),
hound::SampleFormat::Int => {
let bits = spec.bits_per_sample.clamp(1, 32);
let max_value = (1_i64 << (bits - 1)) as f32;
probe_stem_observed_signal(
reader.into_samples::<i32>(),
spec.sample_rate,
spec.channels,
move |sample| sample as f32 / max_value,
)
}
}
}
fn probe_stem_samples<T>(
samples: impl Iterator<Item = Result<T, hound::Error>>,
sample_rate: u32,
channels: u16,
normalize: impl Fn(T) -> f32,
) -> bool {
let channels = channels as usize;
let window_frames = sample_rate as usize;
if window_frames == 0 || channels == 0 {
return false;
}
let mut channel_index = 0usize;
let mut frame_sum = 0.0_f32;
let mut window_frames_read = 0usize;
let mut window_sum_sq = 0.0_f64;
for sample in samples {
let Ok(sample) = sample else {
continue;
};
frame_sum += normalize(sample);
channel_index += 1;
if channel_index < channels {
continue;
}
let mono = frame_sum / channels as f32;
window_sum_sq += (mono as f64) * (mono as f64);
window_frames_read += 1;
channel_index = 0;
frame_sum = 0.0;
if window_frames_read >= window_frames {
let rms = (window_sum_sq / window_frames_read as f64).sqrt() as f32;
if rms > STEM_PROBE_RMS_FLOOR {
return true;
}
window_frames_read = 0;
window_sum_sq = 0.0;
}
}
if window_frames_read > 0 {
let rms = (window_sum_sq / window_frames_read as f64).sqrt() as f32;
return rms > STEM_PROBE_RMS_FLOOR;
}
false
}
fn probe_stem_observed_signal<T>(
mut samples: impl Iterator<Item = Result<T, hound::Error>>,
sample_rate: u32,
channels: u16,
normalize: impl Fn(T) -> f32,
) -> ObservedSignal {
let channels = channels as usize;
let max_frames = sample_rate as usize * STEM_PROBE_SECS;
let max_samples = max_frames * channels;
if max_frames == 0 || channels == 0 {
return ObservedSignal {
frames_captured: 0,
max_rms: 0.0,
avg_rms: 0.0,
};
}
let mut samples_read = 0usize;
let mut frames_read = 0usize;
let mut channel_index = 0usize;
let mut frame_sum = 0.0_f32;
let mut sum_sq = 0.0_f64;
let mut max_abs = 0.0_f32;
while samples_read < max_samples && frames_read < max_frames {
let Some(sample) = samples.next() else {
break;
};
samples_read += 1;
let Ok(sample) = sample else {
continue;
};
frame_sum += normalize(sample);
channel_index += 1;
if channel_index < channels {
continue;
}
let mono = frame_sum / channels as f32;
max_abs = max_abs.max(mono.abs());
sum_sq += (mono as f64) * (mono as f64);
frames_read += 1;
channel_index = 0;
frame_sum = 0.0;
}
let avg_rms = if frames_read == 0 {
0.0
} else {
(sum_sq / frames_read as f64).sqrt() as f32
};
ObservedSignal {
frames_captured: frames_read,
max_rms: max_abs,
avg_rms,
}
}
pub(crate) fn discover_stem_plan(audio_path: &Path) -> Option<SourceAwareDiarizationPlan> {
let stem = audio_path.file_stem()?.to_str()?;
let dir = audio_path.parent()?;
let voice = dir.join(format!("{}.voice.wav", stem));
let system = dir.join(format!("{}.system.wav", stem));
let voice_ok = stem_has_audio(&voice);
let system_ok = stem_has_audio(&system);
match (voice_ok, system_ok) {
(true, true) => {
tracing::info!(
voice = %voice.display(),
system = %system.display(),
"discovered per-source audio stems"
);
Some(SourceAwareDiarizationPlan::FullStems(StemPaths {
voice,
system,
}))
}
(false, true) => {
tracing::warn!(
system = %system.display(),
voice = %voice.display(),
"voice stem missing or empty; falling back to system-stem-only diarization"
);
Some(SourceAwareDiarizationPlan::SystemStemOnly(system))
}
(true, false) => {
if system.exists() {
tracing::warn!(
voice = %voice.display(),
system = %system.display(),
"system stem exists but has no detected audio"
);
Some(SourceAwareDiarizationPlan::SilentSystemStem(StemPaths {
voice,
system,
}))
} else {
tracing::warn!(
voice = %voice.display(),
system = %system.display(),
"system stem missing; skipping source-aware diarization"
);
None
}
}
(false, false) => None,
}
}
pub fn discover_stems(audio_path: &Path) -> Option<StemPaths> {
match discover_stem_plan(audio_path) {
Some(SourceAwareDiarizationPlan::FullStems(stems)) => Some(stems),
_ => None,
}
}
fn compute_energy_windows(wav_path: &Path, window_secs: f64) -> Result<Vec<(f64, f32)>, String> {
let reader = hound::WavReader::open(wav_path)
.map_err(|e| format!("failed to open stem {}: {}", wav_path.display(), e))?;
let spec = reader.spec();
let sample_rate = spec.sample_rate as f64;
let window_samples = (sample_rate * window_secs) as usize;
if window_samples == 0 {
return Err("window too small".into());
}
let samples: Vec<f32> = match spec.sample_format {
hound::SampleFormat::Float => reader
.into_samples::<f32>()
.filter_map(|s| s.ok())
.collect(),
hound::SampleFormat::Int => {
let bits = spec.bits_per_sample;
let max_val = (1i64 << (bits - 1)) as f32;
reader
.into_samples::<i32>()
.filter_map(|s| s.ok())
.map(|s| s as f32 / max_val)
.collect()
}
};
let channels = spec.channels as usize;
let mono: Vec<f32> = if channels > 1 {
samples
.chunks(channels)
.map(|chunk| chunk.iter().sum::<f32>() / channels as f32)
.collect()
} else {
samples
};
let mut windows = Vec::new();
for (i, chunk) in mono.chunks(window_samples).enumerate() {
let sum_sq: f64 = chunk.iter().map(|&s| (s as f64) * (s as f64)).sum();
let rms = (sum_sq / chunk.len() as f64).sqrt() as f32;
let start = i as f64 * window_secs;
windows.push((start, rms));
}
Ok(windows)
}
fn read_stem_energy_windows(
stems: &StemPaths,
window_secs: f64,
) -> Result<StemEnergyWindows, String> {
let voice_energy = compute_energy_windows(&stems.voice, window_secs)
.map_err(|error| format!("failed to read voice stem: {error}"))?;
let system_energy = compute_energy_windows(&stems.system, window_secs)
.map_err(|error| format!("failed to read system stem: {error}"))?;
Ok((voice_energy, system_energy))
}
fn correlation_coefficient(xs: &[f32], ys: &[f32]) -> Option<f32> {
if xs.len() != ys.len() || xs.len() < 2 {
return None;
}
let n = xs.len() as f64;
let mean_x = xs.iter().map(|&x| x as f64).sum::<f64>() / n;
let mean_y = ys.iter().map(|&y| y as f64).sum::<f64>() / n;
let mut num = 0.0;
let mut den_x = 0.0;
let mut den_y = 0.0;
for (&x, &y) in xs.iter().zip(ys.iter()) {
let dx = x as f64 - mean_x;
let dy = y as f64 - mean_y;
num += dx * dy;
den_x += dx * dx;
den_y += dy * dy;
}
let denom = (den_x * den_y).sqrt();
if denom <= f64::EPSILON {
None
} else {
Some((num / denom) as f32)
}
}
fn merge_or_push_segment(segments: &mut Vec<SpeakerSegment>, speaker: &str, start: f64, end: f64) {
if let Some(last) = segments.last_mut() {
if last.speaker == speaker && (start - last.end).abs() < 0.01 {
last.end = end;
return;
}
}
segments.push(SpeakerSegment {
speaker: speaker.to_string(),
start,
end,
});
}
fn collapse_to_single_speaker_segments(
voice_energy: &[(f64, f32)],
system_energy: &[(f64, f32)],
window_secs: f64,
silence_threshold: f32,
speaker_label: &str,
) -> Vec<SpeakerSegment> {
let mut segments = Vec::new();
let window_count = voice_energy.len().min(system_energy.len());
for i in 0..window_count {
let (start, voice_rms) = voice_energy[i];
let (_, system_rms) = system_energy[i];
let end = start + window_secs;
let voice_active = voice_rms > silence_threshold;
let system_active = system_rms > silence_threshold;
if voice_active || system_active {
merge_or_push_segment(&mut segments, speaker_label, start, end);
}
}
segments
}
fn maybe_relabel_single_call_speaker_to_voice(
segments: &mut [SpeakerSegment],
voice_values: &[f32],
system_values: &[f32],
silence_threshold: f32,
stem_correlation_threshold: f32,
) {
if segments.len() != 1 || segments[0].speaker != "SPEAKER_1" {
return;
}
let active_voice_windows = voice_values
.iter()
.filter(|&&rms| rms > silence_threshold)
.count();
let active_voice_ratio = active_voice_windows as f32 / voice_values.len().max(1) as f32;
let correlated = correlation_coefficient(voice_values, system_values)
.is_some_and(|value| value >= stem_correlation_threshold);
if active_voice_ratio >= 0.6 && correlated {
segments[0].speaker = "SPEAKER_0".into();
}
}
fn stem_dominant_ratios(
voice_values: &[f32],
system_values: &[f32],
silence_threshold: f32,
) -> (f32, f32) {
let mut active = 0usize;
let mut voice_dominant = 0usize;
let mut system_dominant = 0usize;
for (voice_rms, system_rms) in voice_values.iter().zip(system_values.iter()) {
let voice_active = *voice_rms > silence_threshold;
let system_active = *system_rms > silence_threshold;
if !voice_active && !system_active {
continue;
}
active += 1;
if voice_active && !system_active {
voice_dominant += 1;
} else if system_active && !voice_active {
system_dominant += 1;
} else if voice_rms >= system_rms {
voice_dominant += 1;
} else {
system_dominant += 1;
}
}
if active == 0 {
return (0.0, 0.0);
}
(
system_dominant as f32 / active as f32,
voice_dominant as f32 / active as f32,
)
}
fn active_ratio(values: &[f32], silence_threshold: f32) -> Option<f32> {
if values.is_empty() {
return None;
}
let active = values
.iter()
.filter(|&&rms| rms > silence_threshold)
.count();
Some(active as f32 / values.len() as f32)
}
fn observed_signal(values: &[f32]) -> ObservedSignal {
let frames_captured = values.len();
let max_rms = values.iter().copied().fold(0.0_f32, f32::max);
let avg_rms = if values.is_empty() {
0.0
} else {
values.iter().sum::<f32>() / values.len() as f32
};
ObservedSignal {
frames_captured,
max_rms,
avg_rms,
}
}
fn stem_degraded_capture_evidence(
voice_values: &[f32],
system_values: &[f32],
silence_threshold: f32,
system_dominant_ratio: f32,
) -> Option<DegradedCapture> {
let system_signal = observed_signal(system_values);
let system_active_ratio = active_ratio(system_values, silence_threshold);
let failure_kind = if system_signal.max_rms <= STEM_PROBE_RMS_FLOOR {
FailureKind::Silent
} else if system_active_ratio.unwrap_or(0.0) < 0.02 || system_dominant_ratio < 0.02 {
FailureKind::Sparse
} else {
return None;
};
Some(DegradedCapture {
failure_kind,
capture_backend: "cpal".into(),
capture_source: CaptureSource::System,
voice_active_ratio: active_ratio(voice_values, silence_threshold),
system_active_ratio,
observed_signal: system_signal,
diagnostic_confidence: DiagnosticConfidence::Inferred,
})
}
fn diarization_from_energy_windows(
voice_energy: &[(f64, f32)],
system_energy: &[(f64, f32)],
window_secs: f64,
stem_correlation_threshold: f32,
) -> Option<DiarizationResult> {
let silence_threshold = 0.005_f32;
let voice_label = "SPEAKER_0";
let call_label = "SPEAKER_1";
let window_count = voice_energy.len().min(system_energy.len());
let voice_values: Vec<f32> = voice_energy
.iter()
.take(window_count)
.map(|(_, rms)| *rms)
.collect();
let system_values: Vec<f32> = system_energy
.iter()
.take(window_count)
.map(|(_, rms)| *rms)
.collect();
let active_windows = voice_values
.iter()
.zip(system_values.iter())
.filter(|(voice_rms, system_rms)| {
**voice_rms > silence_threshold || **system_rms > silence_threshold
})
.count();
let (system_dominant_ratio, voice_dominant_ratio) =
stem_dominant_ratios(&voice_values, &system_values, silence_threshold);
let degraded_capture = stem_degraded_capture_evidence(
&voice_values,
&system_values,
silence_threshold,
system_dominant_ratio,
);
let correlation = correlation_coefficient(&voice_values, &system_values);
if active_windows >= 3 && correlation.is_some_and(|value| value >= stem_correlation_threshold) {
let segments = collapse_to_single_speaker_segments(
voice_energy,
system_energy,
window_secs,
silence_threshold,
voice_label,
);
if segments.is_empty() {
return None;
}
tracing::info!(
active_windows,
correlation = correlation,
threshold = stem_correlation_threshold,
"stem energies strongly correlated — collapsing to one speaker"
);
return Some(DiarizationResult {
segments,
num_speakers: 1,
system_dominant_ratio,
voice_dominant_ratio,
degraded_capture: degraded_capture.clone(),
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
});
}
let mut segments: Vec<SpeakerSegment> = Vec::new();
for i in 0..window_count {
let (start, voice_rms) = voice_energy[i];
let (_, system_rms) = system_energy[i];
let end = start + window_secs;
let voice_active = voice_rms > silence_threshold;
let system_active = system_rms > silence_threshold;
let speaker = if voice_active && !system_active {
voice_label
} else if system_active && !voice_active {
call_label
} else if voice_active && system_active {
if voice_rms >= system_rms {
voice_label
} else {
call_label
}
} else {
continue;
};
merge_or_push_segment(&mut segments, speaker, start, end);
}
let num_speakers = segments
.iter()
.map(|s| s.speaker.as_str())
.collect::<std::collections::HashSet<_>>()
.len();
if num_speakers == 1 {
maybe_relabel_single_call_speaker_to_voice(
&mut segments,
&voice_values,
&system_values,
silence_threshold,
stem_correlation_threshold,
);
}
if segments.is_empty() {
None
} else {
let num_speakers = segments
.iter()
.map(|s| s.speaker.as_str())
.collect::<std::collections::HashSet<_>>()
.len();
Some(DiarizationResult {
segments,
num_speakers,
system_dominant_ratio,
voice_dominant_ratio,
degraded_capture,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
})
}
}
pub fn diarize_from_stems(stems: &StemPaths, config: &Config) -> Option<DiarizationResult> {
let window_secs = 1.0;
let (voice_energy, system_energy) = match read_stem_energy_windows(stems, window_secs) {
Ok(energies) => energies,
Err(error) => {
tracing::warn!(error = %error, "failed to read source-aware stems, falling back to ML diarization");
return None;
}
};
let stem_correlation_threshold = config.diarization.stem_correlation_threshold;
let Some(result) = diarization_from_energy_windows(
&voice_energy,
&system_energy,
window_secs,
stem_correlation_threshold,
) else {
tracing::warn!("stem-based diarization produced no segments (all silent), falling back");
return None;
};
tracing::info!(
speakers = result.num_speakers,
segments = result.segments.len(),
voice_stem = %stems.voice.display(),
system_stem = %stems.system.display(),
"stem-based diarization complete"
);
Some(result)
}
fn resolve_diarization_engine(config: &Config) -> Option<&str> {
match config.diarization.engine.as_str() {
"none" => None,
"auto" => {
if models_installed(config) {
tracing::info!("diarization models found — auto-enabling pyannote-rs");
Some("pyannote-rs")
} else {
tracing::debug!(
"diarization models not found — skipping (run `minutes setup --diarization` to enable)"
);
None
}
}
other => Some(other),
}
}
fn run_diarization_engine(
audio_path: &Path,
config: &Config,
resolved_engine: &str,
) -> Option<DiarizationResult> {
tracing::info!(
engine = %resolved_engine,
file = %audio_path.display(),
"running diarization"
);
let (effective_path, temp_file) = preprocess_audio(audio_path);
let effective_path_owned = effective_path.clone();
#[allow(unused_variables)] let config_clone = config.clone();
let engine_owned = resolved_engine.to_string();
let handle = std::thread::spawn(move || -> Result<DiarizationResult, String> {
let result = match engine_owned.as_str() {
#[cfg(feature = "diarize")]
"pyannote-rs" => diarize_with_pyannote_rs(&effective_path_owned, &config_clone),
#[cfg(not(feature = "diarize"))]
"pyannote-rs" => {
Err("pyannote-rs engine requires the 'diarize' feature. Rebuild with: cargo build --features diarize".into())
}
"pyannote" => diarize_with_pyannote(&effective_path_owned),
other => Err(format!("unknown diarization engine: {}", other).into()),
};
result.map_err(|e| e.to_string())
});
let result = match handle.join() {
Ok(r) => Some(r),
Err(_) => {
tracing::error!("diarization thread panicked");
None
}
};
if let Some(ref temp) = temp_file {
std::fs::remove_file(temp).ok();
}
match result {
Some(Ok(result)) => {
tracing::info!(
speakers = result.num_speakers,
segments = result.segments.len(),
"diarization complete"
);
Some(result)
}
Some(Err(e)) => {
tracing::error!(error = %e, "diarization failed, continuing without speaker labels");
None
}
None => None,
}
}
fn remap_diarization_labels(
result: &DiarizationResult,
starting_label: usize,
) -> DiarizationResult {
let mut label_map: std::collections::HashMap<String, String> = std::collections::HashMap::new();
let mut next_label = starting_label;
let mut remap_label = |raw: &str| {
label_map
.entry(raw.to_string())
.or_insert_with(|| {
let label = format!("SPEAKER_{}", next_label);
next_label += 1;
label
})
.clone()
};
let segments = result
.segments
.iter()
.map(|segment| SpeakerSegment {
speaker: remap_label(&segment.speaker),
start: segment.start,
end: segment.end,
})
.collect();
let mut embedding_keys: Vec<String> = result.speaker_embeddings.keys().cloned().collect();
embedding_keys.sort();
let mut speaker_embeddings = std::collections::HashMap::new();
for raw_label in embedding_keys {
let remapped_label = remap_label(&raw_label);
if let Some(embedding) = result.speaker_embeddings.get(&raw_label) {
speaker_embeddings.insert(remapped_label, embedding.clone());
}
}
DiarizationResult {
segments,
num_speakers: label_map.len(),
system_dominant_ratio: result.system_dominant_ratio,
voice_dominant_ratio: result.voice_dominant_ratio,
degraded_capture: result.degraded_capture.clone(),
from_stems: result.from_stems,
source_aware: result.source_aware,
speaker_embeddings,
}
}
fn merge_remote_diarization_into_stem_result(
stem_result: &DiarizationResult,
remote_result: &DiarizationResult,
) -> DiarizationResult {
let mut base_segments = stem_result.segments.clone();
base_segments.sort_by(|a, b| {
a.start
.partial_cmp(&b.start)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut remote_segments = remote_result.segments.clone();
remote_segments.sort_by(|a, b| {
a.start
.partial_cmp(&b.start)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut merged = Vec::new();
let mut remote_cursor = 0usize;
for segment in base_segments {
if segment.speaker != "SPEAKER_1" {
merge_or_push_segment(&mut merged, &segment.speaker, segment.start, segment.end);
continue;
}
while remote_cursor < remote_segments.len()
&& remote_segments[remote_cursor].end <= segment.start
{
remote_cursor += 1;
}
let mut idx = remote_cursor;
let mut cursor = segment.start;
while idx < remote_segments.len() && remote_segments[idx].start < segment.end {
let remote = &remote_segments[idx];
let start = segment.start.max(remote.start).max(cursor);
let end = segment.end.min(remote.end);
if start > cursor {
merge_or_push_segment(&mut merged, "SPEAKER_1", cursor, start);
}
if end > start {
merge_or_push_segment(&mut merged, &remote.speaker, start, end);
cursor = end;
}
idx += 1;
}
if cursor < segment.end {
merge_or_push_segment(&mut merged, "SPEAKER_1", cursor, segment.end);
}
}
let present_labels: std::collections::HashSet<String> = merged
.iter()
.map(|segment| segment.speaker.clone())
.collect();
let speaker_embeddings = remote_result
.speaker_embeddings
.iter()
.filter(|(label, _)| present_labels.contains(*label))
.map(|(label, embedding)| (label.clone(), embedding.clone()))
.collect();
DiarizationResult {
num_speakers: present_labels.len(),
segments: merged,
system_dominant_ratio: stem_result.system_dominant_ratio,
voice_dominant_ratio: stem_result.voice_dominant_ratio,
degraded_capture: stem_result.degraded_capture.clone(),
from_stems: false,
source_aware: true,
speaker_embeddings,
}
}
fn meaningful_speaker_count_excluding(result: &DiarizationResult, ignored: &[&str]) -> usize {
let mut speaker_durations: std::collections::HashMap<&str, f64> =
std::collections::HashMap::new();
for segment in &result.segments {
if ignored.contains(&segment.speaker.as_str()) {
continue;
}
let duration = (segment.end - segment.start).max(0.0);
if duration > 0.0 {
*speaker_durations
.entry(segment.speaker.as_str())
.or_insert(0.0) += duration;
}
}
speaker_durations
.values()
.filter(|&&duration| duration >= 0.5)
.count()
}
fn has_meaningful_remote_structure(result: &DiarizationResult) -> bool {
meaningful_speaker_count_excluding(result, &["SPEAKER_0"]) >= 1
}
fn has_meaningful_system_stem_labels(result: &DiarizationResult) -> bool {
meaningful_speaker_count_excluding(result, &["SPEAKER_0", "SPEAKER_1"]) >= 1
}
fn diarize_from_source_aware_stems(
stems: &StemPaths,
config: &Config,
resolved_engine: Option<&str>,
) -> Option<DiarizationResult> {
let window_secs = 1.0;
let (voice_energy, system_energy) = match read_stem_energy_windows(stems, window_secs) {
Ok(energies) => energies,
Err(error) => {
tracing::warn!(error = %error, "failed to read source-aware stems, falling back to ML diarization");
return None;
}
};
let stem_result = diarization_from_energy_windows(
&voice_energy,
&system_energy,
window_secs,
config.diarization.stem_correlation_threshold,
)?;
let local_only_collapse = stem_result.num_speakers == 1
&& !stem_result.segments.is_empty()
&& stem_result
.segments
.iter()
.all(|segment| segment.speaker == "SPEAKER_0");
let non_collapsed_stem_result =
diarization_from_energy_windows(&voice_energy, &system_energy, window_secs, 2.0);
let Some(resolved_engine) = resolved_engine else {
return Some(stem_result);
};
let Some(remote_result) = run_diarization_engine(&stems.system, config, resolved_engine) else {
tracing::warn!(
system_stem = %stems.system.display(),
"system-stem diarization failed, keeping stem-only attribution"
);
return Some(stem_result);
};
let remapped_remote = remap_diarization_labels(&remote_result, 2);
if !has_meaningful_remote_structure(&remapped_remote) {
tracing::info!(
remote_speakers = remapped_remote.num_speakers,
"system-stem diarization did not find stable remote structure, keeping stem-only attribution"
);
return Some(stem_result);
}
let merge_base = if local_only_collapse {
non_collapsed_stem_result.as_ref().unwrap_or(&stem_result)
} else {
&stem_result
};
let merged = merge_remote_diarization_into_stem_result(merge_base, &remapped_remote);
if !has_meaningful_system_stem_labels(&merged) {
tracing::info!(
stem_speakers = stem_result.num_speakers,
merged_speakers = merged.num_speakers,
"system-stem diarization did not contribute stable remote speaker labels, keeping stem-only attribution"
);
return Some(stem_result);
}
tracing::info!(
stem_speakers = stem_result.num_speakers,
merged_speakers = merged.num_speakers,
"hybrid source-aware diarization complete"
);
Some(merged)
}
fn diarize_system_stem_with_full_audio_fallback(
system_stem: &Path,
audio_path: &Path,
config: &Config,
resolved_engine: &str,
mut run_engine: impl FnMut(&Path, &Config, &str) -> Option<DiarizationResult>,
) -> Option<DiarizationResult> {
if let Some(result) = run_engine(system_stem, config, resolved_engine) {
return Some(result);
}
tracing::warn!(
system_stem = %system_stem.display(),
audio = %audio_path.display(),
"system-stem-only diarization failed, falling back to full-audio ML diarization"
);
run_engine(audio_path, config, resolved_engine)
}
#[derive(Debug, Clone, Copy, PartialEq)]
struct RemoteSpeechRun {
start_secs: f32,
end_secs: f32,
}
impl RemoteSpeechRun {
fn duration_secs(self) -> f32 {
(self.end_secs - self.start_secs).max(0.0)
}
}
fn system_dominant_runs(segments: &[SpeakerSegment]) -> Vec<RemoteSpeechRun> {
let mut sorted = segments.to_vec();
sorted.sort_by(|a, b| {
a.start
.partial_cmp(&b.start)
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut runs: Vec<RemoteSpeechRun> = Vec::new();
for segment in sorted {
if segment.speaker == "SPEAKER_0" || segment.end <= segment.start {
continue;
}
let start_secs = segment.start as f32;
let end_secs = segment.end as f32;
if let Some(last) = runs.last_mut() {
if start_secs <= last.end_secs + 0.25 {
last.end_secs = last.end_secs.max(end_secs);
continue;
}
}
runs.push(RemoteSpeechRun {
start_secs,
end_secs,
});
}
runs
}
fn run_overlaps_any_window(run: RemoteSpeechRun, windows: &[TranscriptWindow]) -> bool {
windows.iter().any(|window| {
window.end_secs > run.start_secs - 1.0 && window.start_secs < run.end_secs + 1.0
})
}
pub(crate) fn has_sustained_remote_speech(
result: &DiarizationResult,
transcript_windows: Option<&[TranscriptWindow]>,
) -> bool {
let Some(windows) = transcript_windows else {
return false;
};
if windows.is_empty() {
return false;
}
let speech_runs: Vec<_> = system_dominant_runs(&result.segments)
.into_iter()
.filter(|run| run_overlaps_any_window(*run, windows))
.collect();
let long_runs = speech_runs
.iter()
.filter(|run| run.duration_secs() >= 1.5)
.count();
let total_secs: f32 = speech_runs.iter().map(|run| run.duration_secs()).sum();
long_runs >= 3 || total_secs >= 30.0
}
fn fallback_degraded_capture_from_result(result: &DiarizationResult) -> DegradedCapture {
DegradedCapture {
failure_kind: if result.system_dominant_ratio <= 0.001 {
FailureKind::Silent
} else {
FailureKind::Sparse
},
capture_backend: "cpal".into(),
capture_source: CaptureSource::System,
voice_active_ratio: Some(result.voice_dominant_ratio),
system_active_ratio: Some(result.system_dominant_ratio),
observed_signal: ObservedSignal {
frames_captured: result.segments.len(),
max_rms: 0.0,
avg_rms: 0.0,
},
diagnostic_confidence: DiagnosticConfidence::Inferred,
}
}
fn dominant_ratio_degraded(result: &DiarizationResult) -> bool {
result.from_stems
&& result.system_dominant_ratio < 0.10
&& result.voice_dominant_ratio > result.system_dominant_ratio
}
fn silent_system_stem_degraded_capture(system_stem: &Path) -> DegradedCapture {
DegradedCapture {
failure_kind: FailureKind::Silent,
capture_backend: "cpal".into(),
capture_source: CaptureSource::System,
voice_active_ratio: Some(1.0),
system_active_ratio: Some(0.0),
observed_signal: stem_probe_observed_signal(system_stem),
diagnostic_confidence: DiagnosticConfidence::Inferred,
}
}
fn degraded_capture_for_silent_system_stem(
audio_path: &Path,
system_stem: &Path,
ctx: DiarizationContext<'_>,
) -> Option<DegradedCapture> {
if ctx.purpose != DiarizationPurpose::PrimaryMeeting {
return None;
}
let duration_secs = audio_duration_secs(audio_path).unwrap_or(f64::INFINITY);
if duration_secs <= PRIMARY_DEGRADED_MIN_DURATION_SECS {
return None;
}
Some(silent_system_stem_degraded_capture(system_stem))
}
fn degraded_capture_for_primary_result(
audio_path: &Path,
result: &DiarizationResult,
ctx: DiarizationContext<'_>,
) -> Option<DegradedCapture> {
if ctx.purpose != DiarizationPurpose::PrimaryMeeting || !result.source_aware {
return None;
}
let degraded_reason = result.degraded_capture.clone().or_else(|| {
dominant_ratio_degraded(result).then(|| fallback_degraded_capture_from_result(result))
})?;
let duration_secs = audio_duration_secs(audio_path).unwrap_or(f64::INFINITY);
if duration_secs <= PRIMARY_DEGRADED_MIN_DURATION_SECS {
return None;
}
if has_sustained_remote_speech(result, ctx.transcript_windows) {
return None;
}
Some(degraded_reason)
}
fn should_attempt_degraded_ml_fallback(audio_path: &Path, ctx: DiarizationContext<'_>) -> bool {
if ctx.purpose != DiarizationPurpose::PrimaryMeeting {
return false;
}
let duration_secs = audio_duration_secs(audio_path).unwrap_or(f64::INFINITY);
duration_secs > DEGRADED_ML_FALLBACK_MIN_DURATION_SECS
}
fn degraded_voice_stem_ml_fallback_with_runner(
audio_path: &Path,
voice_stem: &Path,
config: &Config,
resolved_engine: Option<&str>,
reason: &DegradedCapture,
ctx: DiarizationContext<'_>,
mut run_engine: impl FnMut(&Path, &Config, &str) -> Option<DiarizationResult>,
) -> Option<DiarizationResult> {
let resolved_engine = resolved_engine?;
if !should_attempt_degraded_ml_fallback(audio_path, ctx) {
return None;
}
let mut result = run_engine(voice_stem, config, resolved_engine)?;
if result.num_speakers < 2 {
tracing::warn!(
failure_kind = ?reason.failure_kind,
voice_stem = %voice_stem.display(),
speakers = result.num_speakers,
segments = result.segments.len(),
"source-aware diarization degraded; voice-stem ML fallback did not recover multiple speakers"
);
return None;
}
result.degraded_capture = Some(reason.clone());
result.from_stems = false;
result.source_aware = false;
tracing::warn!(
failure_kind = ?reason.failure_kind,
voice_stem = %voice_stem.display(),
speakers = result.num_speakers,
segments = result.segments.len(),
"source-aware diarization degraded; recovered with low-confidence voice-stem ML diarization"
);
Some(result)
}
fn degraded_voice_stem_ml_fallback(
audio_path: &Path,
voice_stem: &Path,
config: &Config,
resolved_engine: Option<&str>,
reason: &DegradedCapture,
ctx: DiarizationContext<'_>,
) -> Option<DiarizationResult> {
degraded_voice_stem_ml_fallback_with_runner(
audio_path,
voice_stem,
config,
resolved_engine,
reason,
ctx,
run_diarization_engine,
)
}
pub fn diarize_with_context(
audio_path: &Path,
config: &Config,
ctx: DiarizationContext<'_>,
) -> DiarizationOutcome {
let engine = &config.diarization.engine;
if engine == "none" {
return DiarizationOutcome::NotConfigured;
}
let resolved_engine = resolve_diarization_engine(config);
if let Some(plan) = discover_stem_plan(audio_path) {
match plan {
SourceAwareDiarizationPlan::FullStems(stems) => {
if let Some(result) =
diarize_from_source_aware_stems(&stems, config, resolved_engine)
{
if let Some(reason) =
degraded_capture_for_primary_result(audio_path, &result, ctx)
{
if let Some(recovered) = degraded_voice_stem_ml_fallback(
audio_path,
&stems.voice,
config,
resolved_engine,
&reason,
ctx,
) {
return DiarizationOutcome::Result(recovered);
}
tracing::warn!(
failure_kind = ?reason.failure_kind,
system_dominant_ratio = result.system_dominant_ratio,
voice_dominant_ratio = result.voice_dominant_ratio,
"source-aware diarization degraded; leaving primary transcript unlabeled"
);
return DiarizationOutcome::Skipped { reason };
}
return DiarizationOutcome::Result(result);
}
if let Some(resolved_engine) = resolved_engine {
tracing::warn!(
system_stem = %stems.system.display(),
"source-aware stem diarization failed, falling back to system-stem ML diarization"
);
if let Some(result) =
run_diarization_engine(&stems.system, config, resolved_engine)
{
return DiarizationOutcome::Result(result);
}
}
tracing::warn!("source-aware stem diarization failed, falling back to ML engine");
}
SourceAwareDiarizationPlan::SystemStemOnly(system_stem) => {
if let Some(resolved_engine) = resolved_engine {
return match diarize_system_stem_with_full_audio_fallback(
&system_stem,
audio_path,
config,
resolved_engine,
run_diarization_engine,
) {
Some(result) => DiarizationOutcome::Result(result),
None => DiarizationOutcome::NotConfigured,
};
}
}
SourceAwareDiarizationPlan::SilentSystemStem(stems) => {
if let Some(reason) =
degraded_capture_for_silent_system_stem(audio_path, &stems.system, ctx)
{
if let Some(recovered) = degraded_voice_stem_ml_fallback(
audio_path,
&stems.voice,
config,
resolved_engine,
&reason,
ctx,
) {
return DiarizationOutcome::Result(recovered);
}
tracing::warn!(
failure_kind = ?reason.failure_kind,
voice = %stems.voice.display(),
system = %stems.system.display(),
"system stem is silent; leaving primary transcript unlabeled"
);
return DiarizationOutcome::Skipped { reason };
}
tracing::warn!(
voice = %stems.voice.display(),
system = %stems.system.display(),
"system stem is silent outside primary guard; skipping source-aware diarization"
);
}
}
}
let Some(resolved_engine) = resolved_engine else {
return DiarizationOutcome::NotConfigured;
};
match run_diarization_engine(audio_path, config, resolved_engine) {
Some(result) => DiarizationOutcome::Result(result),
None => DiarizationOutcome::NotConfigured,
}
}
pub fn diarize(audio_path: &Path, config: &Config) -> Option<DiarizationResult> {
match diarize_with_context(
audio_path,
config,
DiarizationContext {
purpose: DiarizationPurpose::Auxiliary,
transcript_windows: None,
},
) {
DiarizationOutcome::Result(result) => Some(result),
DiarizationOutcome::Skipped { .. } | DiarizationOutcome::NotConfigured => None,
}
}
pub fn apply_speakers(transcript: &str, result: &DiarizationResult) -> String {
let mut sorted_segments = result.segments.clone();
sorted_segments.sort_by(|a, b| {
a.start
.partial_cmp(&b.start)
.unwrap_or(std::cmp::Ordering::Equal)
});
enum OutputLine {
Attributed {
speaker: String,
ts_str: String,
text: String,
},
Raw(String),
}
let mut lines: Vec<OutputLine> = Vec::new();
let mut unknown_count = 0usize;
let mut matched_count = 0usize;
for line in transcript.lines() {
if let Some(rest) = line.strip_prefix('[') {
if let Some(bracket_end) = rest.find(']') {
let ts_str = &rest[..bracket_end];
let text = rest[bracket_end + 1..].trim();
if let Some(secs) = parse_timestamp(ts_str) {
let speaker =
find_speaker(secs, &sorted_segments, result.from_stems).to_string();
if speaker == "UNKNOWN" {
unknown_count += 1;
} else {
matched_count += 1;
}
lines.push(OutputLine::Attributed {
speaker,
ts_str: ts_str.to_string(),
text: text.to_string(),
});
continue;
}
}
}
lines.push(OutputLine::Raw(line.to_string()));
}
let dominant_speaker = dominant_speaker_label(&sorted_segments);
let first_attr = lines
.iter()
.position(|l| matches!(l, OutputLine::Attributed { .. }));
if let Some(first_idx) = first_attr {
let is_unknown = matches!(&lines[first_idx], OutputLine::Attributed { speaker, .. } if speaker == "UNKNOWN");
if is_unknown {
let next_speaker = lines[first_idx + 1..].iter().find_map(|l| match l {
OutputLine::Attributed { speaker, .. } if speaker != "UNKNOWN" => {
Some(speaker.clone())
}
_ => None,
});
if let Some(resolved) = next_speaker {
if let OutputLine::Attributed { speaker, .. } = &mut lines[first_idx] {
*speaker = resolved;
unknown_count = unknown_count.saturating_sub(1);
matched_count += 1;
}
}
}
}
let all_unknown = !lines.is_empty()
&& lines.iter().all(|line| match line {
OutputLine::Attributed { speaker, .. } => speaker == "UNKNOWN",
OutputLine::Raw(_) => true,
});
if all_unknown {
if let Some(dominant) = dominant_speaker {
for line in &mut lines {
if let OutputLine::Attributed { speaker, .. } = line {
if speaker == "UNKNOWN" {
*speaker = dominant.clone();
unknown_count = unknown_count.saturating_sub(1);
matched_count += 1;
}
}
}
}
}
let mut output = String::new();
for line in &lines {
match line {
OutputLine::Attributed {
speaker,
ts_str,
text,
} => {
output.push_str(&format!("[{} {}] {}\n", speaker, ts_str, text));
}
OutputLine::Raw(raw) => {
output.push_str(raw);
output.push('\n');
}
}
}
if unknown_count > 0 {
tracing::warn!(
unknown = unknown_count,
matched = matched_count,
"speaker attribution results — high unknown count may indicate timestamp/segment mismatch"
);
}
output
}
fn dominant_speaker_label(segments: &[SpeakerSegment]) -> Option<String> {
let mut durations: std::collections::HashMap<&str, f64> = std::collections::HashMap::new();
for seg in segments {
let dur = (seg.end - seg.start).max(0.0);
*durations.entry(seg.speaker.as_str()).or_insert(0.0) += dur;
}
let total: f64 = durations.values().sum();
if total <= f64::EPSILON {
return None;
}
let (label, duration) = durations
.into_iter()
.max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))?;
if duration / total >= 0.6 {
Some(label.to_string())
} else {
None
}
}
fn find_speaker(time_secs: f64, segments: &[SpeakerSegment], from_stems: bool) -> &str {
let idx = segments.partition_point(|seg| seg.end <= time_secs);
if idx < segments.len() && time_secs >= segments[idx].start && time_secs < segments[idx].end {
return &segments[idx].speaker;
}
let next_tolerance = if from_stems { 2.0 } else { 0.5 };
let prev_tolerance = 0.5;
if idx < segments.len() {
let gap = segments[idx].start - time_secs;
if gap >= 0.0 && gap <= next_tolerance {
return &segments[idx].speaker;
}
}
if idx > 0 {
let prev = &segments[idx - 1];
let gap = time_secs - prev.end;
if gap >= 0.0 && gap <= prev_tolerance {
return &prev.speaker;
}
}
"UNKNOWN"
}
fn parse_timestamp(ts: &str) -> Option<f64> {
let parts: Vec<&str> = ts.split(':').collect();
match parts.len() {
2 => {
let mins: f64 = parts[0].parse().ok()?;
let secs: f64 = parts[1].parse().ok()?;
Some(mins * 60.0 + secs)
}
3 => {
let hours: f64 = parts[0].parse().ok()?;
let mins: f64 = parts[1].parse().ok()?;
let secs: f64 = parts[2].parse().ok()?;
Some(hours * 3600.0 + mins * 60.0 + secs)
}
_ => None,
}
}
#[cfg(feature = "diarize")]
fn diarize_with_pyannote_rs(
audio_path: &Path,
config: &Config,
) -> Result<DiarizationResult, Box<dyn std::error::Error>> {
use pyannote_rs::EmbeddingExtractor;
let model_dir = &config.diarization.model_path;
let seg_model = model_dir.join(SEGMENTATION_MODEL);
let emb_info = embedding_model_for_config(config);
let emb_model = model_dir.join(emb_info.filename);
if !seg_model.exists() {
return Err(format!(
"Segmentation model not found at {}. Run `minutes setup --diarization` to download.",
seg_model.display()
)
.into());
}
if !emb_model.exists() {
return Err(format!(
"Embedding model not found at {}. Run `minutes setup --diarization` to download.",
emb_model.display()
)
.into());
}
let (mut f32_samples, mut i16_samples, sample_rate) = load_audio(audio_path)?;
tracing::info!(
f32_samples = f32_samples.len(),
i16_samples = i16_samples.len(),
sample_rate = sample_rate,
"audio loaded for diarization"
);
let mut speech_segments = segment_speech(&f32_samples, sample_rate, &seg_model)?;
if speech_segments.is_empty() {
let peak = f32_samples.iter().map(|s| s.abs()).fold(0.0f32, f32::max);
const TARGET_PEAK: f32 = 0.5;
if peak > 0.0 && peak < TARGET_PEAK {
let gain = TARGET_PEAK / peak;
tracing::info!(
peak = format!("{:.6}", peak),
gain = format!("{:.1}x", gain),
"no speech detected — retrying with normalized audio"
);
for s in &mut f32_samples {
*s = (*s * gain).clamp(-1.0, 1.0);
}
i16_samples = f32_samples
.iter()
.map(|&s| (s.clamp(-1.0, 1.0) * 32767.0) as i16)
.collect();
speech_segments = segment_speech(&f32_samples, sample_rate, &seg_model)?;
}
}
tracing::info!(
segments = speech_segments.len(),
"speech segmentation complete"
);
let mut extractor = EmbeddingExtractor::new(&emb_model)?;
let threshold = config.diarization.threshold;
let speech_segments = merge_short_segments(speech_segments, sample_rate);
tracing::info!(
segments = speech_segments.len(),
"speech segments after merge"
);
let mut speaker_templates: Vec<(Vec<f32>, usize)> = Vec::new();
let mut seg_speaker_ids: Vec<usize> = Vec::new();
let min_embed_samples = (sample_rate as f64 * 1.5) as usize;
for seg in &speech_segments {
let seg_i16 = &i16_samples[seg.start_sample..seg.end_sample];
if seg_i16.len() < min_embed_samples {
seg_speaker_ids.push(usize::MAX); continue;
}
let raw_embedding: Vec<f32> = extractor.compute(seg_i16)?.collect();
let embedding = l2_normalize(&raw_embedding);
let mut best_id = None;
let mut best_sim = threshold;
for (id, (template, _)) in speaker_templates.iter().enumerate() {
let sim = crate::voice::cosine_similarity(&embedding, template);
if sim > best_sim {
best_sim = sim;
best_id = Some(id);
}
}
let speaker_id = match best_id {
Some(id) => {
let (ref mut template, ref mut count) = speaker_templates[id];
let old_count = *count as f32;
let new_count = old_count + 1.0;
for (i, val) in embedding.iter().enumerate() {
template[i] = (template[i] * old_count + val) / new_count;
}
*count += 1;
id
}
None => {
let id = speaker_templates.len();
speaker_templates.push((embedding, 1));
id
}
};
seg_speaker_ids.push(speaker_id);
}
let merge_threshold = (threshold - 0.05).max(0.3);
let num_templates = speaker_templates.len();
let mut merge_map: Vec<usize> = (0..num_templates).collect();
for i in 0..num_templates {
for j in (i + 1)..num_templates {
if merge_map[j] != j {
continue; }
let ri = merge_map[i]; let sim =
crate::voice::cosine_similarity(&speaker_templates[ri].0, &speaker_templates[j].0);
if sim > merge_threshold {
tracing::info!(
from = j,
into = ri,
similarity = format!("{:.4}", sim),
"merging speaker clusters"
);
merge_map[j] = ri;
}
}
}
for i in 0..num_templates {
let mut root = merge_map[i];
let mut steps = 0;
while merge_map[root] != root && steps < num_templates {
root = merge_map[root];
steps += 1;
}
merge_map[i] = root;
}
let mut canonical_to_label: std::collections::HashMap<usize, String> =
std::collections::HashMap::new();
let mut next_label = 1usize;
for &canonical in &merge_map {
canonical_to_label.entry(canonical).or_insert_with(|| {
let label = format!("SPEAKER_{}", next_label);
next_label += 1;
label
});
}
let mut segments = Vec::new();
let resolved_labels: Vec<Option<String>> = seg_speaker_ids
.iter()
.map(|&raw_id| {
if raw_id == usize::MAX {
None
} else {
let canonical_id = merge_map[raw_id];
Some(canonical_to_label[&canonical_id].clone())
}
})
.collect();
let mut last_known_label: Option<String> = None;
let mut final_labels: Vec<String> = Vec::with_capacity(resolved_labels.len());
for label in &resolved_labels {
if let Some(l) = label {
last_known_label = Some(l.clone());
}
final_labels.push(last_known_label.clone().unwrap_or_else(|| "UNKNOWN".into()));
}
if let Some(first_known) = resolved_labels.iter().find_map(|l| l.as_ref()) {
for label in &mut final_labels {
if label == "UNKNOWN" {
*label = first_known.clone();
} else {
break;
}
}
}
for (idx, seg) in speech_segments.iter().enumerate() {
segments.push(SpeakerSegment {
speaker: final_labels[idx].clone(),
start: seg.start,
end: seg.end,
});
}
let mut speaker_embeddings: std::collections::HashMap<String, Vec<f32>> =
std::collections::HashMap::new();
let mut speaker_total_counts: std::collections::HashMap<String, usize> =
std::collections::HashMap::new();
for (raw_id, (template, count)) in speaker_templates.iter().enumerate() {
let canonical_id = merge_map[raw_id];
let label = canonical_to_label[&canonical_id].clone();
let entry = speaker_embeddings
.entry(label.clone())
.or_insert_with(|| vec![0.0f32; template.len()]);
for (i, val) in template.iter().enumerate() {
entry[i] += val * (*count as f32);
}
*speaker_total_counts.entry(label).or_insert(0) += count;
}
for (label, embedding) in speaker_embeddings.iter_mut() {
let total = *speaker_total_counts.get(label).unwrap_or(&1) as f32;
for val in embedding.iter_mut() {
*val /= total;
}
}
let num_speakers = speaker_embeddings.len();
tracing::info!(
raw_clusters = num_templates,
merged_speakers = num_speakers,
threshold = threshold,
merge_threshold = format!("{:.3}", merge_threshold),
"speaker clustering complete"
);
Ok(DiarizationResult {
segments,
num_speakers,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings,
})
}
#[cfg(feature = "diarize")]
#[derive(Clone)]
struct SpeechSegment {
start: f64,
end: f64,
start_sample: usize,
end_sample: usize,
}
#[cfg(feature = "diarize")]
fn l2_normalize(v: &[f32]) -> Vec<f32> {
let norm: f32 = v.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm == 0.0 {
return v.to_vec();
}
v.iter().map(|x| x / norm).collect()
}
#[cfg(feature = "diarize")]
fn merge_short_segments(segments: Vec<SpeechSegment>, sample_rate: u32) -> Vec<SpeechSegment> {
if segments.is_empty() {
return segments;
}
let max_gap_samples = (sample_rate as f64 * 0.3) as usize; let min_dur_samples = (sample_rate as f64 * 0.5) as usize;
let max_short_gap_samples = (sample_rate as f64 * 1.0) as usize;
let mut merged: Vec<SpeechSegment> = Vec::new();
let mut current = segments[0].clone();
for seg in segments.iter().skip(1) {
let gap = seg.start_sample.saturating_sub(current.end_sample);
let current_dur = current.end_sample.saturating_sub(current.start_sample);
let should_merge = gap <= max_gap_samples
|| (current_dur < min_dur_samples && gap <= max_short_gap_samples);
if should_merge {
current.end = seg.end;
current.end_sample = seg.end_sample;
} else {
merged.push(current);
current = seg.clone();
}
}
merged.push(current);
tracing::debug!(
before = segments.len(),
after = merged.len(),
"merged adjacent speech segments"
);
merged
}
#[cfg(feature = "diarize")]
fn segment_speech(
samples: &[f32],
sample_rate: u32,
model_path: &Path,
) -> Result<Vec<SpeechSegment>, Box<dyn std::error::Error>> {
use ndarray::{Array1, ArrayViewD, Axis, IxDyn};
use ort::session::builder::GraphOptimizationLevel;
use ort::session::Session;
let mut session = Session::builder()?
.with_optimization_level(GraphOptimizationLevel::Level3)?
.with_intra_threads(1)?
.with_inter_threads(1)?
.commit_from_file(model_path)?;
let frame_size: usize = 270;
let frame_start: usize = 721;
let window_size = (sample_rate as usize) * 10;
let mut padded = samples.to_vec();
let remainder = padded.len() % window_size;
if remainder != 0 {
padded.extend(vec![0.0f32; window_size - remainder]);
}
let mut result = Vec::new();
let mut is_speeching = false;
let mut offset = frame_start;
let mut start_offset = 0usize;
for window_start in (0..padded.len()).step_by(window_size) {
let window_end = (window_start + window_size).min(padded.len());
let window = &padded[window_start..window_end];
let array = Array1::from_iter(window.iter().copied());
let array = array.view().insert_axis(Axis(0)).insert_axis(Axis(1));
let inputs = ort::inputs![ort::value::TensorRef::from_array_view(array.into_dyn())
.map_err(|e| format!("tensor prep: {e:?}"))?];
let ort_outs = session.run(inputs)?;
let ort_out = ort_outs
.get("output")
.ok_or("segmentation model missing 'output' tensor")?;
let ort_out = ort_out
.try_extract_tensor::<f32>()
.map_err(|e| format!("tensor extract: {e:?}"))?;
let (shape, data) = ort_out;
let shape_slice: Vec<usize> = (0..shape.len()).map(|i| shape[i] as usize).collect();
let view = ArrayViewD::<f32>::from_shape(IxDyn(&shape_slice), data)
.map_err(|e| format!("ndarray shape: {e}"))?;
for row in view.outer_iter() {
for sub_row in row.axis_iter(Axis(0)) {
let max_index = sub_row
.iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0);
if max_index != 0 {
if !is_speeching {
start_offset = offset;
is_speeching = true;
}
} else if is_speeching {
let start_secs = start_offset as f64 / sample_rate as f64;
let end_secs = offset as f64 / sample_rate as f64;
let si = start_offset.min(samples.len().saturating_sub(1));
let ei = offset.min(samples.len());
result.push(SpeechSegment {
start: start_secs,
end: end_secs,
start_sample: si,
end_sample: ei,
});
is_speeching = false;
}
offset += frame_size;
}
}
}
if is_speeching {
let start_secs = start_offset as f64 / sample_rate as f64;
let end_secs = offset as f64 / sample_rate as f64;
let si = start_offset.min(samples.len().saturating_sub(1));
let ei = samples.len();
result.push(SpeechSegment {
start: start_secs,
end: end_secs,
start_sample: si,
end_sample: ei,
});
}
Ok(result)
}
#[cfg(feature = "diarize")]
#[allow(clippy::type_complexity)]
fn load_audio(audio_path: &Path) -> Result<(Vec<f32>, Vec<i16>, u32), Box<dyn std::error::Error>> {
use symphonia::core::audio::SampleBuffer;
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
let file = std::fs::File::open(audio_path)?;
let mss = MediaSourceStream::new(Box::new(file), Default::default());
let mut hint = Hint::new();
if let Some(ext) = audio_path.extension().and_then(|e| e.to_str()) {
hint.with_extension(ext);
}
let probed = symphonia::default::get_probe().format(
&hint,
mss,
&FormatOptions::default(),
&MetadataOptions::default(),
)?;
let mut format = probed.format;
let track = format.default_track().ok_or("no audio track found")?;
let track_id = track.id;
let sample_rate = track.codec_params.sample_rate.ok_or("no sample rate")?;
let channels = track.codec_params.channels.map(|c| c.count()).unwrap_or(1);
let mut decoder =
symphonia::default::get_codecs().make(&track.codec_params, &DecoderOptions::default())?;
let mut all_samples: Vec<f32> = Vec::new();
loop {
let packet = match format.next_packet() {
Ok(p) => p,
Err(symphonia::core::errors::Error::IoError(ref e))
if e.kind() == std::io::ErrorKind::UnexpectedEof =>
{
break;
}
Err(e) => return Err(e.into()),
};
if packet.track_id() != track_id {
continue;
}
let decoded = decoder.decode(&packet)?;
let spec = *decoded.spec();
let num_frames = decoded.capacity();
let mut sample_buf = SampleBuffer::<f32>::new(num_frames as u64, spec);
sample_buf.copy_interleaved_ref(decoded);
let samples = sample_buf.samples();
if channels > 1 {
for chunk in samples.chunks(channels) {
let mono: f32 = chunk.iter().sum::<f32>() / channels as f32;
all_samples.push(mono);
}
} else {
all_samples.extend_from_slice(samples);
}
}
let i16_samples: Vec<i16> = all_samples
.iter()
.map(|&s| (s.clamp(-1.0, 1.0) * 32767.0) as i16)
.collect();
Ok((all_samples, i16_samples, sample_rate))
}
fn diarize_with_pyannote(
audio_path: &Path,
) -> Result<DiarizationResult, Box<dyn std::error::Error>> {
let python = find_python()?;
let script = r#"
import json, sys
try:
from pyannote.audio import Pipeline
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization-3.1",
use_auth_token=False)
diarization = pipeline(sys.argv[1])
segments = []
for turn, _, speaker in diarization.itertracks(yield_label=True):
segments.append({"speaker": speaker, "start": turn.start, "end": turn.end})
print(json.dumps(segments))
except ImportError:
print("ERROR: pyannote.audio not installed. Run: pip install pyannote.audio", file=sys.stderr)
sys.exit(1)
except Exception as e:
print(f"ERROR: {e}", file=sys.stderr)
sys.exit(1)
"#;
let output = std::process::Command::new(&python)
.args(["-c", script, audio_path.to_str().unwrap_or("")])
.output()?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
return Err(format!("pyannote failed: {}", stderr).into());
}
let stdout = String::from_utf8_lossy(&output.stdout);
let segments: Vec<SpeakerSegment> = serde_json::from_str(&stdout)?;
let num_speakers = segments
.iter()
.map(|s| s.speaker.as_str())
.collect::<std::collections::HashSet<_>>()
.len();
Ok(DiarizationResult {
segments,
num_speakers,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(), })
}
fn find_python() -> Result<String, Box<dyn std::error::Error>> {
for candidate in &["python3", "python"] {
let result = std::process::Command::new(candidate)
.args(["--version"])
.output();
if let Ok(output) = result {
if output.status.success() {
return Ok(candidate.to_string());
}
}
}
Err("Python not found. Install Python 3 for speaker diarization.".into())
}
#[cfg(test)]
mod tests {
use super::*;
fn write_i16_wav(
path: &Path,
sample_rate: u32,
channels: u16,
frames: usize,
mut sample_for_frame: impl FnMut(usize, u16) -> i16,
) {
let spec = hound::WavSpec {
channels,
sample_rate,
bits_per_sample: 16,
sample_format: hound::SampleFormat::Int,
};
let mut writer = hound::WavWriter::create(path, spec).unwrap();
for frame in 0..frames {
for channel in 0..channels {
writer
.write_sample(sample_for_frame(frame, channel))
.unwrap();
}
}
writer.finalize().unwrap();
}
fn write_active_wav(path: &Path) {
write_i16_wav(path, 1_000, 1, 1_000, |_, _| 3_000);
}
#[test]
fn parse_timestamp_minutes_seconds() {
assert_eq!(parse_timestamp("0:00"), Some(0.0));
assert_eq!(parse_timestamp("1:30"), Some(90.0));
assert_eq!(parse_timestamp("10:05"), Some(605.0));
}
#[test]
fn parse_timestamp_hours() {
assert_eq!(parse_timestamp("1:00:00"), Some(3600.0));
}
#[test]
fn parse_timestamp_invalid() {
assert_eq!(parse_timestamp("abc"), None);
assert_eq!(parse_timestamp(""), None);
}
#[test]
fn stem_has_audio_rejects_large_zero_wav() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("zero.wav");
write_i16_wav(&path, 1_000, 2, 12_000, |_, _| 0);
assert!(!stem_has_audio(&path));
}
#[test]
fn stem_has_audio_accepts_sparse_nonzero_within_probe_window() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("sparse.wav");
write_i16_wav(&path, 1_000, 1, 12_000, |frame, _| {
if (500..1_500).contains(&frame) {
3_000
} else {
0
}
});
assert!(stem_has_audio(&path));
}
#[test]
fn stem_has_audio_detects_speech_after_a_quiet_opening() {
for (sample_rate, channels) in [(1_000, 1), (1_000, 2), (4_410, 1), (4_410, 2)] {
let dir = tempfile::tempdir().unwrap();
let path = dir
.path()
.join(format!("quiet-open-{sample_rate}-{channels}.wav"));
let quiet_frames = sample_rate as usize * (STEM_PROBE_SECS + 3);
let total_frames = quiet_frames + sample_rate as usize * 2;
write_i16_wav(&path, sample_rate, channels, total_frames, |frame, _| {
if frame >= quiet_frames {
12_000
} else {
0
}
});
assert!(
stem_has_audio(&path),
"speech after a quiet opening must be detected for {sample_rate} Hz/{channels} ch (#280)"
);
}
}
#[test]
fn stem_has_audio_detects_a_single_loud_sample() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("single-sample.wav");
let sample_rate = 1_000;
let frames = sample_rate as usize * 2;
write_i16_wav(&path, sample_rate, 1, frames, |frame, _| {
if frame + 1 == frames {
32_000
} else {
0
}
});
assert!(stem_has_audio(&path));
}
#[test]
fn audio_duration_secs_reads_wav_duration() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("duration.wav");
write_i16_wav(&path, 8_000, 1, 12_000, |_, _| 0);
let duration = audio_duration_secs(&path).unwrap();
assert!((duration - 1.5).abs() < 0.01, "duration={duration}");
}
#[test]
fn audio_duration_secs_errors_on_malformed_audio() {
let dir = tempfile::tempdir().unwrap();
let path = dir.path().join("malformed.wav");
std::fs::write(&path, b"not a wav").unwrap();
assert!(audio_duration_secs(&path).is_err());
}
#[test]
fn find_speaker_returns_correct_label() {
let segments = vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 5.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 5.0,
end: 10.0,
},
];
assert_eq!(find_speaker(2.5, &segments, false), "SPEAKER_0");
assert_eq!(find_speaker(7.0, &segments, false), "SPEAKER_1");
assert_eq!(find_speaker(15.0, &segments, false), "UNKNOWN");
}
#[test]
fn find_speaker_gap_fallback_prefers_next_speaker() {
let segments = vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.045,
end: 3.98,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 4.12,
end: 8.5,
},
];
assert_eq!(find_speaker(0.0, &segments, false), "SPEAKER_0");
assert_eq!(find_speaker(4.0, &segments, false), "SPEAKER_1");
assert_eq!(find_speaker(8.6, &segments, false), "SPEAKER_1");
assert_eq!(find_speaker(10.0, &segments, false), "UNKNOWN");
assert_eq!(find_speaker(15.0, &segments, false), "UNKNOWN");
}
#[test]
fn find_speaker_silence_stays_unknown() {
let segments = vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 5.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 10.0,
end: 15.0,
},
];
assert_eq!(find_speaker(7.0, &segments, false), "UNKNOWN");
}
#[test]
fn find_speaker_from_stems_allows_larger_forward_tolerance() {
let segments = vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 5.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 8.8,
end: 10.0,
},
];
assert_eq!(find_speaker(7.0, &segments, false), "UNKNOWN");
assert_eq!(find_speaker(7.0, &segments, true), "SPEAKER_1");
}
#[test]
fn apply_speakers_labels_transcript() {
let transcript = "[0:00] Hello everyone\n[0:05] Thanks for joining\n";
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 3.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 3.0,
end: 10.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(),
};
let labeled = apply_speakers(transcript, &result);
assert!(labeled.contains("[SPEAKER_0 0:00]"));
assert!(labeled.contains("[SPEAKER_1 0:05]"));
}
#[test]
fn apply_speakers_first_unknown_inherits_next_speaker() {
let transcript = "[0:00] Hello\n[0:03] How are you\n[0:07] Good thanks\n";
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 1.5,
end: 5.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 5.0,
end: 10.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(),
};
let labeled = apply_speakers(transcript, &result);
assert!(
labeled.contains("[SPEAKER_0 0:00]"),
"first UNKNOWN should inherit next speaker, got: {labeled}"
);
assert!(labeled.contains("[SPEAKER_0 0:03]"));
assert!(labeled.contains("[SPEAKER_1 0:07]"));
}
#[test]
fn apply_speakers_all_unknown_prefers_dominant_speaker() {
let transcript = "[0:00] Short intro line\n";
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 1.0,
end: 9.0,
},
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 9.0,
end: 10.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.75,
voice_dominant_ratio: 0.25,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let labeled = apply_speakers(transcript, &result);
assert!(labeled.contains("[SPEAKER_1 0:00]"));
}
#[test]
fn dominant_speaker_requires_clear_majority() {
let segments = vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 5.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 5.0,
end: 9.0,
},
];
assert_eq!(dominant_speaker_label(&segments), None);
}
#[test]
fn stem_energy_correlation_collapses_to_single_speaker() {
let voice_energy = vec![(0.0, 0.12), (1.0, 0.20), (2.0, 0.18), (3.0, 0.11)];
let system_energy = vec![(0.0, 0.08), (1.0, 0.14), (2.0, 0.13), (3.0, 0.07)];
let result = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 0.85)
.expect("correlated stems should still produce diarization");
assert_eq!(result.num_speakers, 1);
assert_eq!(result.segments.len(), 1);
assert_eq!(result.segments[0].speaker, "SPEAKER_0");
assert_eq!(result.segments[0].start, 0.0);
assert_eq!(result.segments[0].end, 4.0);
}
#[test]
fn stem_correlation_threshold_of_one_preserves_remote_label_on_open_speaker_bleed() {
let voice_energy = vec![(0.0, 0.08), (1.0, 0.14), (2.0, 0.12), (3.0, 0.06)];
let system_energy = vec![(0.0, 0.20), (1.0, 0.28), (2.0, 0.24), (3.0, 0.12)];
let collapsed = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 0.85)
.expect("default threshold should produce a diarization result");
assert_eq!(collapsed.segments.len(), 1);
assert_eq!(collapsed.segments[0].speaker, "SPEAKER_0");
let preserved = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 1.0)
.expect("threshold=1.0 must not suppress diarization, only the collapse");
assert_eq!(preserved.segments[0].speaker, "SPEAKER_1");
}
#[test]
fn stem_energy_distinguishes_two_sources_when_patterns_diverge() {
let voice_energy = vec![(0.0, 0.16), (1.0, 0.14), (2.0, 0.0), (3.0, 0.0)];
let system_energy = vec![(0.0, 0.0), (1.0, 0.0), (2.0, 0.18), (3.0, 0.15)];
let result = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 0.85)
.expect("distinct stem patterns should produce diarization");
assert_eq!(result.num_speakers, 2);
assert_eq!(result.segments.len(), 2);
assert_eq!(result.segments[0].speaker, "SPEAKER_0");
assert_eq!(result.segments[1].speaker, "SPEAKER_1");
}
#[test]
fn has_sustained_remote_speech_requires_transcript_aligned_runs() {
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 2.0,
end: 4.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 12.0,
end: 14.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 22.0,
end: 24.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.75,
voice_dominant_ratio: 0.25,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let windows = vec![
TranscriptWindow {
start_secs: 1.0,
end_secs: 5.0,
},
TranscriptWindow {
start_secs: 11.0,
end_secs: 15.0,
},
TranscriptWindow {
start_secs: 21.0,
end_secs: 25.0,
},
];
assert!(has_sustained_remote_speech(&result, Some(&windows)));
assert!(!has_sustained_remote_speech(&result, None));
assert!(!has_sustained_remote_speech(&result, Some(&[])));
}
#[test]
fn has_sustained_remote_speech_filters_chimes_before_thresholds() {
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 5.0,
end: 7.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 20.0,
end: 22.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 40.0,
end: 42.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 60.0,
end: 62.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.8,
voice_dominant_ratio: 0.2,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let one_overlapping_window = vec![TranscriptWindow {
start_secs: 5.5,
end_secs: 6.5,
}];
assert!(!has_sustained_remote_speech(
&result,
Some(&one_overlapping_window)
));
}
#[test]
fn has_sustained_remote_speech_accepts_long_transcript_aligned_remote_audio() {
let result = DiarizationResult {
segments: vec![SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 10.0,
end: 42.0,
}],
num_speakers: 2,
system_dominant_ratio: 0.9,
voice_dominant_ratio: 0.1,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let windows = vec![TranscriptWindow {
start_secs: 12.0,
end_secs: 20.0,
}];
assert!(has_sustained_remote_speech(&result, Some(&windows)));
}
#[test]
fn single_system_dominant_speaker_relabels_to_voice_when_mic_is_consistently_active() {
let voice_energy = vec![(0.0, 0.020), (1.0, 0.024), (2.0, 0.018), (3.0, 0.022)];
let system_energy = vec![(0.0, 0.050), (1.0, 0.060), (2.0, 0.045), (3.0, 0.055)];
let result = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 0.85)
.expect("single dominant system speaker should still produce diarization");
assert_eq!(result.num_speakers, 1);
assert_eq!(result.segments.len(), 1);
assert_eq!(result.segments[0].speaker, "SPEAKER_0");
}
#[test]
fn single_system_dominant_speaker_stays_remote_when_mic_noise_is_uncorrelated() {
let voice_energy = vec![(0.0, 0.020), (1.0, 0.006), (2.0, 0.019), (3.0, 0.007)];
let system_energy = vec![(0.0, 0.050), (1.0, 0.048), (2.0, 0.047), (3.0, 0.051)];
let result = diarization_from_energy_windows(&voice_energy, &system_energy, 1.0, 0.85)
.expect("single dominant system speaker should still produce diarization");
assert_eq!(result.num_speakers, 1);
assert_eq!(result.segments.len(), 1);
assert_eq!(result.segments[0].speaker, "SPEAKER_1");
}
#[test]
fn remap_diarization_labels_rebases_remote_namespace() {
let result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "remote-alex".into(),
start: 0.0,
end: 1.0,
},
SpeakerSegment {
speaker: "remote-sam".into(),
start: 1.0,
end: 2.0,
},
SpeakerSegment {
speaker: "remote-alex".into(),
start: 2.0,
end: 3.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::from([
("remote-alex".to_string(), vec![0.1, 0.2]),
("remote-sam".to_string(), vec![0.3, 0.4]),
]),
};
let remapped = remap_diarization_labels(&result, 1);
assert_eq!(remapped.num_speakers, 2);
assert_eq!(remapped.segments[0].speaker, "SPEAKER_1");
assert_eq!(remapped.segments[1].speaker, "SPEAKER_2");
assert_eq!(remapped.segments[2].speaker, "SPEAKER_1");
assert!(remapped.speaker_embeddings.contains_key("SPEAKER_1"));
assert!(remapped.speaker_embeddings.contains_key("SPEAKER_2"));
}
#[test]
fn merge_remote_diarization_into_stem_result_keeps_local_and_splits_remote_windows() {
let stem_result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 2.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 2.0,
end: 6.0,
},
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 6.0,
end: 7.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 7.0,
end: 10.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.3,
voice_dominant_ratio: 0.7,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let remote_result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 2.1,
end: 3.6,
},
SpeakerSegment {
speaker: "SPEAKER_3".into(),
start: 3.6,
end: 5.8,
},
SpeakerSegment {
speaker: "SPEAKER_3".into(),
start: 7.2,
end: 8.4,
},
SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 8.4,
end: 9.9,
},
],
num_speakers: 2,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::from([
("SPEAKER_2".to_string(), vec![0.1]),
("SPEAKER_3".to_string(), vec![0.2]),
]),
};
let merged = merge_remote_diarization_into_stem_result(&stem_result, &remote_result);
assert_eq!(merged.num_speakers, 4);
assert!(!merged.from_stems);
assert!(merged.source_aware);
assert_eq!(
merged
.segments
.iter()
.map(|segment| (segment.speaker.as_str(), segment.start, segment.end))
.collect::<Vec<_>>(),
vec![
("SPEAKER_0", 0.0, 2.0),
("SPEAKER_1", 2.0, 2.1),
("SPEAKER_2", 2.1, 3.6),
("SPEAKER_3", 3.6, 5.8),
("SPEAKER_1", 5.8, 6.0),
("SPEAKER_0", 6.0, 7.0),
("SPEAKER_1", 7.0, 7.2),
("SPEAKER_3", 7.2, 8.4),
("SPEAKER_2", 8.4, 9.9),
("SPEAKER_1", 9.9, 10.0),
]
);
assert!(merged.speaker_embeddings.contains_key("SPEAKER_2"));
assert!(merged.speaker_embeddings.contains_key("SPEAKER_3"));
}
#[test]
fn has_meaningful_remote_structure_rejects_noise_but_accepts_one_remote_speaker() {
let weak_remote = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 2.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 2.0,
end: 2.4,
},
SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 2.4,
end: 2.8,
},
],
num_speakers: 3,
system_dominant_ratio: 0.7,
voice_dominant_ratio: 0.3,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let single_remote = DiarizationResult {
segments: vec![SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 1.0,
end: 2.2,
}],
num_speakers: 1,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(),
};
let strong_remote = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 1.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 1.0,
end: 1.7,
},
SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 1.7,
end: 2.4,
},
],
num_speakers: 3,
system_dominant_ratio: 0.7,
voice_dominant_ratio: 0.3,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
assert!(!has_meaningful_remote_structure(&weak_remote));
assert!(has_meaningful_remote_structure(&single_remote));
assert!(has_meaningful_remote_structure(&strong_remote));
}
#[test]
fn merged_system_stem_label_is_useful_even_without_more_speakers() {
let stem_result = DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 2.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 2.0,
end: 5.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.6,
voice_dominant_ratio: 0.4,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
};
let remote_result = DiarizationResult {
segments: vec![SpeakerSegment {
speaker: "SPEAKER_2".into(),
start: 2.0,
end: 5.0,
}],
num_speakers: 1,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::from([(
"SPEAKER_2".to_string(),
vec![0.2],
)]),
};
let merged = merge_remote_diarization_into_stem_result(&stem_result, &remote_result);
assert_eq!(merged.num_speakers, 2);
assert!(has_meaningful_system_stem_labels(&merged));
assert_eq!(
merged
.segments
.iter()
.map(|segment| (segment.speaker.as_str(), segment.start, segment.end))
.collect::<Vec<_>>(),
vec![("SPEAKER_0", 0.0, 2.0), ("SPEAKER_2", 2.0, 5.0)]
);
}
#[test]
fn diarize_returns_none_when_disabled() {
let config = Config::default(); let result = diarize(Path::new("/fake.wav"), &config);
assert!(result.is_none());
}
#[test]
fn apply_confirmed_names_rewrites_high_confidence() {
let transcript = "[SPEAKER_1 0:00] Hello\n[SPEAKER_2 0:05] Hi there\n";
let attributions = vec![
SpeakerAttribution {
speaker_label: "SPEAKER_1".into(),
name: "Mat".into(),
confidence: Confidence::High,
source: AttributionSource::Manual,
},
SpeakerAttribution {
speaker_label: "SPEAKER_2".into(),
name: "Alex".into(),
confidence: Confidence::Medium,
source: AttributionSource::Deterministic,
},
];
let result = apply_confirmed_names(transcript, &attributions);
assert!(result.contains("[Mat 0:00]"));
assert!(result.contains("[SPEAKER_2 0:05]"));
}
#[test]
fn apply_confirmed_names_no_high_is_noop() {
let transcript = "[SPEAKER_1 0:00] Hello\n";
let result = apply_confirmed_names(
transcript,
&[SpeakerAttribution {
speaker_label: "SPEAKER_1".into(),
name: "Mat".into(),
confidence: Confidence::Medium,
source: AttributionSource::Deterministic,
}],
);
assert_eq!(result, transcript);
}
#[test]
fn apply_confirmed_names_keeps_non_speech_events_anonymous() {
let transcript =
"[SPEAKER_1 0:00] [beep]\n[SPEAKER_1 0:01] Hello there\n[SPEAKER_1 0:02] [typing]\n";
let result = apply_confirmed_names(
transcript,
&[SpeakerAttribution {
speaker_label: "SPEAKER_1".into(),
name: "Mat".into(),
confidence: Confidence::High,
source: AttributionSource::Manual,
}],
);
assert!(result.contains("[SPEAKER_1 0:00] [beep]"));
assert!(result.contains("[Mat 0:01] Hello there"));
assert!(result.contains("[SPEAKER_1 0:02] [typing]"));
}
#[test]
fn speaker_attribution_roundtrips_yaml() {
let attr = SpeakerAttribution {
speaker_label: "SPEAKER_2".into(),
name: "Sarah".into(),
confidence: Confidence::High,
source: AttributionSource::MlBleedDegraded,
};
let yaml = serde_yaml::to_string(&attr).unwrap();
assert!(yaml.contains("ml-bleed-degraded"));
let parsed: SpeakerAttribution = serde_yaml::from_str(&yaml).unwrap();
assert_eq!(parsed.confidence, Confidence::High);
assert_eq!(parsed.source, AttributionSource::MlBleedDegraded);
let recovered: AttributionSource = serde_yaml::from_str("stem-recovery").unwrap();
assert_eq!(recovered, AttributionSource::StemRecovery);
}
#[test]
fn diarize_returns_none_for_unknown_engine() {
let mut config = Config::default();
config.diarization.engine = "nonexistent".into();
let result = diarize(Path::new("/fake.wav"), &config);
assert!(result.is_none());
}
#[test]
fn models_installed_returns_false_when_missing() {
let dir = tempfile::tempdir().unwrap();
let mut config = Config::default();
config.diarization.model_path = dir.path().join("missing-models");
assert!(!models_installed(&config));
}
#[test]
fn discover_stem_plan_prefers_full_stems_when_both_are_present() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.mov");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
std::fs::write(&audio, b"mov").unwrap();
write_active_wav(&voice);
write_active_wav(&system);
let plan = discover_stem_plan(&audio);
assert_eq!(
plan,
Some(SourceAwareDiarizationPlan::FullStems(StemPaths {
voice,
system,
}))
);
}
#[test]
fn discover_stem_plan_uses_system_only_when_voice_stem_is_missing() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.mov");
let system = dir.path().join("call.system.wav");
std::fs::write(&audio, b"mov").unwrap();
write_active_wav(&system);
let plan = discover_stem_plan(&audio);
assert_eq!(
plan,
Some(SourceAwareDiarizationPlan::SystemStemOnly(system))
);
}
#[test]
fn system_stem_only_falls_back_to_full_audio_when_engine_fails() {
let config = Config::default();
let system_stem = Path::new("/tmp/call.system.wav");
let audio = Path::new("/tmp/call.mov");
let full_audio_result = DiarizationResult {
segments: vec![SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 1.0,
}],
num_speakers: 1,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: false,
source_aware: false,
speaker_embeddings: std::collections::HashMap::new(),
};
let mut attempted_paths = Vec::new();
let result = diarize_system_stem_with_full_audio_fallback(
system_stem,
audio,
&config,
"test-engine",
|path, _config, _engine| {
attempted_paths.push(path.to_path_buf());
if path == audio {
Some(full_audio_result.clone())
} else {
None
}
},
);
assert_eq!(
attempted_paths,
vec![system_stem.to_path_buf(), audio.to_path_buf()]
);
assert_eq!(result.unwrap().segments[0].speaker, "SPEAKER_0");
}
#[test]
fn discover_stem_plan_rejects_voice_only_partial_stems() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.mov");
let voice = dir.path().join("call.voice.wav");
std::fs::write(&audio, b"mov").unwrap();
write_active_wav(&voice);
let plan = discover_stem_plan(&audio);
assert_eq!(plan, None);
}
#[test]
fn discover_stem_plan_detects_existing_silent_system_stem() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
std::fs::write(&audio, b"wav").unwrap();
write_active_wav(&voice);
write_i16_wav(&system, 1_000, 1, 1_000, |_, _| 0);
let plan = discover_stem_plan(&audio);
assert_eq!(
plan,
Some(SourceAwareDiarizationPlan::SilentSystemStem(StemPaths {
voice,
system,
}))
);
}
#[test]
fn primary_sparse_stem_result_skips_without_unknown_spam_and_sets_health() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
let sample_rate = 1_000;
let frames = 61_000;
write_i16_wav(&audio, sample_rate, 1, frames, |_, _| 0);
write_i16_wav(&voice, sample_rate, 1, frames, |_, _| 3_000);
write_i16_wav(&system, sample_rate, 1, frames, |frame, _| {
if frame < sample_rate as usize {
3_000
} else {
0
}
});
let config = Config::default();
let transcript = "[0:00] First line\n[0:10] Second line\n";
let windows = vec![
TranscriptWindow {
start_secs: 0.0,
end_secs: 8.0,
},
TranscriptWindow {
start_secs: 10.0,
end_secs: 18.0,
},
];
let outcome = diarize_with_context(
&audio,
&config,
DiarizationContext {
purpose: DiarizationPurpose::PrimaryMeeting,
transcript_windows: Some(&windows),
},
);
let DiarizationOutcome::Skipped { reason } = outcome else {
panic!("expected degraded primary capture to skip");
};
assert_eq!(reason.failure_kind, FailureKind::Sparse);
assert_eq!(reason.capture_source, CaptureSource::System);
let health: crate::markdown::RecordingHealth = reason.into();
assert_eq!(health.capture_warnings.len(), 1);
assert_eq!(health.capture_warnings[0].kind, FailureKind::Sparse);
assert_eq!(health.capture_warnings[0].source, CaptureSource::System);
assert_eq!(
health.diarization_path,
Some(crate::markdown::DiarizationPath::None)
);
assert!(!transcript.contains("[UNKNOWN"));
assert!(!transcript.contains("[SPEAKER_0"));
}
#[test]
fn primary_zero_system_stem_skips_without_unknown_spam_and_sets_health() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
let sample_rate = 1_000;
let frames = 61_000;
write_i16_wav(&audio, sample_rate, 1, frames, |_, _| 0);
write_i16_wav(&voice, sample_rate, 1, frames, |_, _| 3_000);
write_i16_wav(&system, sample_rate, 1, frames, |_, _| 0);
let config = Config::default();
let transcript = "[0:00] First line\n[0:10] Second line\n";
let windows = vec![
TranscriptWindow {
start_secs: 0.0,
end_secs: 8.0,
},
TranscriptWindow {
start_secs: 10.0,
end_secs: 18.0,
},
];
let outcome = diarize_with_context(
&audio,
&config,
DiarizationContext {
purpose: DiarizationPurpose::PrimaryMeeting,
transcript_windows: Some(&windows),
},
);
let DiarizationOutcome::Skipped { reason } = outcome else {
panic!("expected zero-system primary capture to skip");
};
assert_eq!(reason.failure_kind, FailureKind::Silent);
assert_eq!(reason.capture_source, CaptureSource::System);
assert_eq!(reason.system_active_ratio, Some(0.0));
assert_eq!(reason.observed_signal.max_rms, 0.0);
let health: crate::markdown::RecordingHealth = reason.into();
assert_eq!(health.capture_warnings.len(), 1);
assert_eq!(health.capture_warnings[0].kind, FailureKind::Silent);
assert_eq!(health.capture_warnings[0].source, CaptureSource::System);
assert_eq!(
health.diarization_path,
Some(crate::markdown::DiarizationPath::None)
);
assert!(!transcript.contains("[UNKNOWN"));
assert!(!transcript.contains("[SPEAKER_0"));
}
#[test]
fn degraded_voice_stem_ml_fallback_marks_multi_speaker_result_backend_agnostic_and_degraded() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
let sample_rate = 1_000;
let frames = 121_000;
write_i16_wav(&audio, sample_rate, 1, frames, |_, _| 0);
write_i16_wav(&voice, sample_rate, 1, frames, |_, _| 3_000);
write_i16_wav(&system, sample_rate, 1, frames, |_, _| 0);
let reason = silent_system_stem_degraded_capture(&system);
let config = Config::default();
let windows = vec![TranscriptWindow {
start_secs: 0.0,
end_secs: 8.0,
}];
let mut attempted_paths = Vec::new();
let recovered = degraded_voice_stem_ml_fallback_with_runner(
&audio,
&voice,
&config,
Some("test-engine"),
&reason,
DiarizationContext {
purpose: DiarizationPurpose::PrimaryMeeting,
transcript_windows: Some(&windows),
},
|path, _config, _engine| {
attempted_paths.push(path.to_path_buf());
Some(DiarizationResult {
segments: vec![
SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 10.0,
},
SpeakerSegment {
speaker: "SPEAKER_1".into(),
start: 10.0,
end: 20.0,
},
],
num_speakers: 2,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
})
},
)
.expect("expected degraded capture to recover through voice-stem ML");
assert_eq!(attempted_paths, vec![voice]);
assert!(!recovered.from_stems);
assert!(!recovered.source_aware);
assert_eq!(recovered.num_speakers, 2);
assert_eq!(recovered.degraded_capture, Some(reason));
}
#[test]
fn degraded_voice_stem_ml_fallback_rejects_single_speaker_result() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
let sample_rate = 1_000;
let frames = 121_000;
write_i16_wav(&audio, sample_rate, 1, frames, |_, _| 0);
write_i16_wav(&voice, sample_rate, 1, frames, |_, _| 3_000);
write_i16_wav(&system, sample_rate, 1, frames, |_, _| 0);
let reason = silent_system_stem_degraded_capture(&system);
let config = Config::default();
let mut attempted = false;
let recovered = degraded_voice_stem_ml_fallback_with_runner(
&audio,
&voice,
&config,
Some("test-engine"),
&reason,
DiarizationContext {
purpose: DiarizationPurpose::PrimaryMeeting,
transcript_windows: None,
},
|_path, _config, _engine| {
attempted = true;
Some(DiarizationResult {
segments: vec![SpeakerSegment {
speaker: "SPEAKER_0".into(),
start: 0.0,
end: 20.0,
}],
num_speakers: 1,
system_dominant_ratio: 0.0,
voice_dominant_ratio: 0.0,
degraded_capture: None,
from_stems: true,
source_aware: true,
speaker_embeddings: std::collections::HashMap::new(),
})
},
);
assert!(attempted);
assert!(recovered.is_none());
}
#[test]
fn degraded_voice_stem_ml_fallback_respects_two_minute_floor() {
let dir = tempfile::tempdir().unwrap();
let audio = dir.path().join("call.wav");
let voice = dir.path().join("call.voice.wav");
let system = dir.path().join("call.system.wav");
let sample_rate = 1_000;
let frames = 90_000;
write_i16_wav(&audio, sample_rate, 1, frames, |_, _| 0);
write_i16_wav(&voice, sample_rate, 1, frames, |_, _| 3_000);
write_i16_wav(&system, sample_rate, 1, frames, |_, _| 0);
let reason = silent_system_stem_degraded_capture(&system);
let config = Config::default();
let mut attempted = false;
let recovered = degraded_voice_stem_ml_fallback_with_runner(
&audio,
&voice,
&config,
Some("test-engine"),
&reason,
DiarizationContext {
purpose: DiarizationPurpose::PrimaryMeeting,
transcript_windows: None,
},
|_path, _config, _engine| {
attempted = true;
Some(DiarizationResult::default())
},
);
assert!(recovered.is_none());
assert!(!attempted);
}
#[test]
fn config_recognizes_pyannote_rs_engine() {
let mut config = Config::default();
config.diarization.engine = "pyannote-rs".into();
assert_eq!(config.diarization.engine, "pyannote-rs");
assert_eq!(config.diarization.threshold, 0.4);
}
#[cfg(feature = "diarize")]
#[test]
fn l2_normalize_unit_vector() {
let v = vec![3.0f32, 4.0];
let n = l2_normalize(&v);
let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!(
(norm - 1.0).abs() < 1e-6,
"expected unit length, got {}",
norm
);
assert!((n[0] - 0.6).abs() < 1e-6);
assert!((n[1] - 0.8).abs() < 1e-6);
}
#[cfg(feature = "diarize")]
#[test]
fn l2_normalize_zero_vector() {
let v = vec![0.0f32; 5];
let n = l2_normalize(&v);
assert_eq!(n, v, "zero vector should be returned as-is");
}
#[cfg(feature = "diarize")]
#[test]
fn l2_normalize_single_element() {
let v = vec![7.0f32];
let n = l2_normalize(&v);
assert!((n[0] - 1.0).abs() < 1e-6);
}
#[cfg(feature = "diarize")]
#[test]
fn l2_normalize_negative_values() {
let v = vec![-3.0f32, 4.0];
let n = l2_normalize(&v);
let norm: f32 = n.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 1e-6);
assert!(n[0] < 0.0, "sign should be preserved");
}
#[cfg(feature = "diarize")]
fn make_seg(start_s: f64, end_s: f64, sr: u32) -> SpeechSegment {
SpeechSegment {
start: start_s,
end: end_s,
start_sample: (start_s * sr as f64) as usize,
end_sample: (end_s * sr as f64) as usize,
}
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_empty_input() {
let result = merge_short_segments(vec![], 16000);
assert!(result.is_empty());
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_single_segment() {
let segs = vec![make_seg(0.0, 2.0, 16000)];
let result = merge_short_segments(segs, 16000);
assert_eq!(result.len(), 1);
assert!((result[0].start - 0.0).abs() < 1e-6);
assert!((result[0].end - 2.0).abs() < 1e-6);
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_merges_small_gaps() {
let segs = vec![make_seg(0.0, 1.0, 16000), make_seg(1.2, 2.0, 16000)];
let result = merge_short_segments(segs, 16000);
assert_eq!(result.len(), 1);
assert!((result[0].end - 2.0).abs() < 1e-6);
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_preserves_large_gaps() {
let segs = vec![make_seg(0.0, 1.0, 16000), make_seg(3.0, 4.0, 16000)];
let result = merge_short_segments(segs, 16000);
assert_eq!(result.len(), 2);
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_short_segment_respects_gap_ceiling() {
let segs = vec![make_seg(0.0, 0.3, 16000), make_seg(1.8, 3.0, 16000)];
let result = merge_short_segments(segs, 16000);
assert_eq!(
result.len(),
2,
"short segment should not absorb across >1s gap"
);
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_short_segment_merges_within_ceiling() {
let segs = vec![make_seg(0.0, 0.3, 16000), make_seg(1.1, 2.0, 16000)];
let result = merge_short_segments(segs, 16000);
assert_eq!(
result.len(),
1,
"short segment should absorb within 1s ceiling"
);
}
#[cfg(feature = "diarize")]
#[test]
fn merge_short_segments_all_below_min_duration() {
let segs = vec![
make_seg(0.0, 0.1, 16000),
make_seg(0.2, 0.3, 16000),
make_seg(0.4, 0.5, 16000),
make_seg(3.5, 3.6, 16000),
];
let result = merge_short_segments(segs, 16000);
assert_eq!(
result.len(),
2,
"chain of short segments should merge, but not across 3s gap"
);
assert!((result[0].end - 0.5).abs() < 1e-6);
assert!((result[1].start - 3.5).abs() < 1e-6);
}
}