#[cfg(feature = "sortformer")]
use hound;
#[cfg(feature = "sortformer")]
use parakeet_rs::sortformer::{DiarizationConfig, Sortformer};
#[cfg(feature = "sortformer")]
use parakeet_rs::{TimestampMode, Transcriber};
#[cfg(feature = "sortformer")]
use std::env;
#[cfg(feature = "sortformer")]
use std::time::Instant;
/// Diarizes and transcribes the WAV file named by the first command-line argument.
///
/// With the `sortformer` feature enabled this runs three steps:
/// 1. load the WAV file into normalized `f32` samples,
/// 2. run Sortformer speaker diarization over the samples,
/// 3. transcribe with Parakeet-TDT and label every transcript segment with the
///    speaker whose diarization span overlaps it the longest (`UNKNOWN` when
///    no span overlaps).
///
/// Without the feature it prints a usage hint to stderr and returns an error.
///
/// # Errors
///
/// Returns an error when the audio path argument is missing, when the WAV file
/// cannot be opened or decoded, or when model loading, diarization or
/// transcription fails.
// `unreachable_code` is allowed because, in builds without the `sortformer`
// feature, the trailing `unreachable!()` follows an unconditional `return`.
#[allow(unreachable_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
    #[cfg(not(feature = "sortformer"))]
    {
        eprintln!("Error: This example requires the 'sortformer' feature.");
        eprintln!(
            "Please run with: cargo run --example diarization --features sortformer <audio.wav>"
        );
        return Err("sortformer feature not enabled".into());
    }

    #[cfg(feature = "sortformer")]
    {
        // Divisor used to turn diarization segment bounds into seconds.
        // NOTE(review): presumably Sortformer reports `start`/`end` as sample
        // indices at 16 kHz regardless of the input rate — confirm against the
        // parakeet_rs documentation.
        const SEGMENT_RATE_HZ: f32 = 16_000.0;

        let start_time = Instant::now();

        // A missing argument is a usage error, not a bug: report it through the
        // normal error path instead of panicking.
        let audio_path = env::args().nth(1).ok_or(
            "Please specify audio file: cargo run --example diarization --features sortformer <audio.wav>",
        )?;

        println!("{}", "=".repeat(80));
        println!("Step 1/3: Loading audio...");

        let mut reader = hound::WavReader::open(&audio_path)?;
        let spec = reader.spec();
        let audio: Vec<f32> = match spec.sample_format {
            hound::SampleFormat::Float => reader.samples::<f32>().collect::<Result<Vec<_>, _>>()?,
            hound::SampleFormat::Int => {
                // Read as i32 so every integer bit depth is accepted, then
                // normalize by that depth's full-scale value (32768 for 16-bit).
                let full_scale = 2f32.powi(i32::from(spec.bits_per_sample) - 1);
                reader
                    .samples::<i32>()
                    .map(|sample| sample.map(|value| value as f32 / full_scale))
                    .collect::<Result<Vec<_>, _>>()?
            }
        };

        // `audio` holds interleaved samples, hence the division by channel count.
        let duration = audio.len() as f32 / spec.sample_rate as f32 / spec.channels as f32;
        println!(
            "Loaded {} samples ({} Hz, {} channels, {:.1}s)",
            audio.len(),
            spec.sample_rate,
            spec.channels,
            duration
        );

        println!("{}", "=".repeat(80));
        println!("Step 2/3: Performing speaker diarization with Sortformer v2 (streaming)...");

        let mut sortformer = Sortformer::with_config(
            "diar_streaming_sortformer_4spk-v2.onnx",
            None,
            DiarizationConfig::callhome(),
        )?;
        println!(
            "  Config: chunk_len={}, fifo_len={}, spkcache_len={}, right_context={}",
            sortformer.chunk_len,
            sortformer.fifo_len,
            sortformer.spkcache_len,
            sortformer.right_context
        );
        println!("  Latency: {:.2}s", sortformer.latency());

        // The samples are cloned because they are handed over by value again to
        // the transcriber in step 3.
        let speaker_segments =
            sortformer.diarize(audio.clone(), spec.sample_rate, spec.channels)?;
        println!(
            "Found {} speaker segments from Sortformer",
            speaker_segments.len()
        );

        println!("\nRaw diarization segments:");
        for seg in &speaker_segments {
            println!(
                "  [{:06.2}s - {:06.2}s] Speaker {}",
                seg.start as f64 / f64::from(SEGMENT_RATE_HZ),
                seg.end as f64 / f64::from(SEGMENT_RATE_HZ),
                seg.speaker_id
            );
        }

        println!("\n{}", "=".repeat(80));
        println!("Step 3/3: Transcribing with Parakeet-TDT and attributing speakers...\n");

        let mut parakeet = parakeet_rs::ParakeetTDT::from_pretrained("./tdt", None)?;
        // Propagate transcription failures instead of silently printing nothing
        // and then reporting success.
        let result = parakeet.transcribe_samples(
            audio,
            spec.sample_rate,
            spec.channels,
            Some(TimestampMode::Sentences),
        )?;

        // Convert the speaker spans to seconds once rather than once per
        // transcript segment.
        let speaker_spans: Vec<_> = speaker_segments
            .iter()
            .map(|s| {
                (
                    s.speaker_id,
                    s.start as f32 / SEGMENT_RATE_HZ,
                    s.end as f32 / SEGMENT_RATE_HZ,
                )
            })
            .collect();

        for segment in &result.tokens {
            // Pick the speaker whose span shares the most time with this segment.
            let speaker = speaker_spans
                .iter()
                .filter_map(|&(id, span_start, span_end)| {
                    let overlap = segment.end.min(span_end) - segment.start.max(span_start);
                    (overlap > 0.0).then_some((id, overlap))
                })
                .max_by(|a, b| a.1.total_cmp(&b.1))
                .map(|(id, _)| format!("Speaker {}", id))
                .unwrap_or_else(|| "UNKNOWN".to_string());

            println!(
                "[{:.2}s - {:.2}s] {}: {}",
                segment.start, segment.end, speaker, segment.text
            );
        }

        println!("\n{}", "=".repeat(80));
        let elapsed = start_time.elapsed();
        println!(
            "\n✓ Diarization and transcription completed in {:.2}s",
            elapsed.as_secs_f32()
        );
        println!("• UNKNOWN: Segments where no speaker was detected by Sortformer");
        println!("• Config: callhome v2 (onset=0.641, offset=0.561, min_on=0.511, min_off=0.296)");

        Ok(())
    }

    #[cfg(not(feature = "sortformer"))]
    unreachable!()
}