1use std::io::IsTerminal;
11use std::io::Write;
12use std::path::{Path, PathBuf};
13
14use anyhow::{bail, Context, Result};
15use clap::Parser;
16
17use crate::voice::models::SPEAKER_WESPEAKER_EN;
18use crate::voice::{
19 cosine, create_default_transcriber, detect_format, render_jsonl, render_markdown, speaker_file,
20 EnrolledSpeaker, OutputFormat, TranscriptEvent, VecAudioInput, VoiceOpts, WespeakerEmbedder,
21 MIN_EMBED_SAMPLES,
22};
23
/// Chunk size, in samples, handed to `VecAudioInput::from_wav_path` when feeding the WAV
/// to the transcriber.
const DEFAULT_CHUNK_SAMPLES: usize = 1 << 10;

/// Minimum cosine similarity between a segment's embedding and the enrolled speaker's vector
/// required for `--speaker` to keep the segment; used when `--threshold` is not given.
/// The unit tests pin this as the spike-calibrated default.
pub const DEFAULT_SPEAKER_THRESHOLD: f32 = 0.5;
39
// CLI arguments for transcribing a single WAV file.
//
// Plain `//` comments are used deliberately in this struct: with clap's derive, `///` doc
// comments become `--help` text, so adding them would change the CLI's output.
#[derive(Parser)]
pub struct TranscribeCommand {
    // Positional path of the WAV file to transcribe. When `--speaker` is given, the file is
    // also decoded by `read_wav_pcm_16k_mono`, which only accepts 16 kHz mono 16-bit PCM.
    pub wav: PathBuf,

    // `--backend`: forwarded as `VoiceOpts::backend` to `create_default_transcriber`.
    // An unrecognised name makes `run` fail (covered by `run_propagates_unknown_backend_error`).
    #[arg(long)]
    pub backend: Option<String>,

    // `--model`: forwarded as `VoiceOpts::model` to `create_default_transcriber`.
    #[arg(long)]
    pub model: Option<PathBuf>,

    // `--format jsonl|md`. The optional choice is passed to `detect_format` together with
    // whether stdout is a terminal.
    #[arg(long, value_enum)]
    pub format: Option<OutputFormatArg>,

    // `--speaker NAME`: keep only final segments whose audio matches the enrolled speaker
    // NAME (located via `speaker_file`), labelling the kept segments with that name.
    #[arg(long)]
    pub speaker: Option<String>,

    // `--threshold`: minimum cosine similarity for `--speaker` to keep a segment; falls back
    // to `DEFAULT_SPEAKER_THRESHOLD`. Only consulted when `--speaker` is given.
    #[arg(long)]
    pub threshold: Option<f32>,

    // `--speaker-model`: optional location handed to `SPEAKER_WESPEAKER_EN.resolve_dir` when
    // loading the speaker-embedding model. Only consulted when `--speaker` is given.
    #[arg(long)]
    pub speaker_model: Option<PathBuf>,
}
85
// CLI spelling of `--format`, converted into the renderer's `OutputFormat` via `From`.
//
// `//` rather than `///` comments are used here because clap's `ValueEnum` derive turns doc
// comments into help text for the possible values.
#[derive(Clone, Copy, Debug, clap::ValueEnum)]
// Lowercases the variant names, so the accepted values are `jsonl` and `md`.
#[value(rename_all = "lowercase")]
pub enum OutputFormatArg {
    // Maps to `OutputFormat::Jsonl`, rendered by `render_jsonl`.
    Jsonl,
    // Maps to `OutputFormat::Md`, rendered by `render_markdown`.
    Md,
}
95
96impl From<OutputFormatArg> for OutputFormat {
97 fn from(value: OutputFormatArg) -> Self {
98 match value {
99 OutputFormatArg::Jsonl => Self::Jsonl,
100 OutputFormatArg::Md => Self::Md,
101 }
102 }
103}
104
105impl TranscribeCommand {
106 pub fn execute(self) -> Result<()> {
113 let format = detect_format(
114 self.format.map(OutputFormat::from),
115 std::io::stdout().is_terminal(),
116 );
117 let mut out = std::io::stdout().lock();
118 self.run(&mut out, format)
119 }
120
121 fn run<W: Write>(self, w: &mut W, format: OutputFormat) -> Result<()> {
127 let speaker_filter = self
128 .speaker
129 .as_deref()
130 .map(|name| {
131 SpeakerFilter::load(
132 name,
133 self.speaker_model.as_deref(),
134 self.threshold.unwrap_or(DEFAULT_SPEAKER_THRESHOLD),
135 &self.wav,
136 )
137 })
138 .transpose()?;
139 let opts = VoiceOpts {
140 backend: self.backend,
141 model: self.model,
142 };
143 let transcriber = create_default_transcriber(&opts)?;
144 let input = VecAudioInput::from_wav_path(&self.wav, DEFAULT_CHUNK_SAMPLES)?;
145 let stream = transcriber.transcribe(Box::new(input))?;
146
147 let events: Vec<Result<TranscriptEvent>> = stream.collect();
151 let filtered: Vec<Result<TranscriptEvent>> = match &speaker_filter {
152 Some(f) => events
153 .into_iter()
154 .filter_map(|ev| f.transform(ev))
155 .collect(),
156 None => events,
157 };
158
159 match format {
160 OutputFormat::Jsonl => render_jsonl(filtered, w)?,
161 OutputFormat::Md => render_markdown(filtered, w)?,
162 }
163 w.flush()?;
164 Ok(())
165 }
166}
167
/// State needed to keep only the transcript segments spoken by one enrolled speaker.
///
/// Built once by `SpeakerFilter::load`; `SpeakerFilter::transform` is then applied to every
/// transcript event.
struct SpeakerFilter {
    /// Enrolled speaker name; written into the `speaker` field of every kept `Final` event.
    name: String,
    /// Enrolled speaker profile; its `vector` is the reference embedding segments are compared to.
    enrolled: EnrolledSpeaker,
    /// Embedding model applied to the audio window of each `Final` event.
    embedder: WespeakerEmbedder,
    /// The whole input WAV as 16 kHz mono 16-bit samples, sliced by event timestamps.
    pcm: Vec<i16>,
    /// Minimum cosine similarity to `enrolled.vector` for an event to be kept.
    threshold: f32,
}
177
178impl SpeakerFilter {
179 fn load(name: &str, speaker_model: Option<&Path>, threshold: f32, wav: &Path) -> Result<Self> {
180 let enrolled_path = speaker_file(name)?;
181 let enrolled = EnrolledSpeaker::load(&enrolled_path).with_context(|| {
182 format!(
183 "load enrolled speaker {} from {}",
184 name,
185 enrolled_path.display()
186 )
187 })?;
188 let dir = SPEAKER_WESPEAKER_EN.resolve_dir(speaker_model)?;
189 SPEAKER_WESPEAKER_EN.ensure_present(&dir)?;
190 let model_path = dir.join(SPEAKER_WESPEAKER_EN.required_files[0]);
191 let embedder = WespeakerEmbedder::new(&model_path)?;
192 let pcm = read_wav_pcm_16k_mono(wav)?;
193 Ok(Self {
194 name: name.to_string(),
195 enrolled,
196 embedder,
197 pcm,
198 threshold,
199 })
200 }
201
202 fn transform(&self, ev: Result<TranscriptEvent>) -> Option<Result<TranscriptEvent>> {
207 let ev = match ev {
208 Ok(ev) => ev,
209 err @ Err(_) => return Some(err),
210 };
211 match ev {
212 TranscriptEvent::Final {
213 event_id,
214 text,
215 start,
216 end,
217 confidence,
218 words,
219 speaker: _,
220 revisable,
221 } => {
222 let s = (start.as_secs_f64() * 16_000.0) as usize;
223 let e = (end.as_secs_f64() * 16_000.0) as usize;
224 let lo = s.min(self.pcm.len());
225 let hi = e.min(self.pcm.len());
226 let window = &self.pcm[lo..hi.max(lo)];
227 if window.len() < MIN_EMBED_SAMPLES {
228 return None;
230 }
231 let emb = match self.embedder.embed(window) {
232 Ok(v) => v,
233 Err(err) => return Some(Err(err)),
234 };
235 if cosine(&emb, &self.enrolled.vector) >= self.threshold {
236 Some(Ok(TranscriptEvent::Final {
237 event_id,
238 text,
239 start,
240 end,
241 confidence,
242 words,
243 speaker: Some(self.name.clone()),
244 revisable,
245 }))
246 } else {
247 None
248 }
249 }
250 other => Some(Ok(other)),
251 }
252 }
253}
254
255fn read_wav_pcm_16k_mono(path: &Path) -> Result<Vec<i16>> {
262 let mut reader = hound::WavReader::open(path)
263 .with_context(|| format!("open WAV at {} for speaker filter", path.display()))?;
264 let spec = reader.spec();
265 if spec.sample_rate != 16_000
266 || spec.channels != 1
267 || spec.bits_per_sample != 16
268 || spec.sample_format != hound::SampleFormat::Int
269 {
270 bail!(
271 "WAV at {} must be 16 kHz mono 16-bit PCM for --speaker filtering",
272 path.display()
273 );
274 }
275 reader
276 .samples::<i16>()
277 .collect::<Result<Vec<_>, _>>()
278 .with_context(|| format!("decode PCM samples from {}", path.display()))
279}
280
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use clap::Parser;

    // Minimal root parser so `TranscribeCommand`'s arguments can be parsed on their own.
    #[derive(Parser)]
    struct TestCli {
        #[command(flatten)]
        transcribe: TranscribeCommand,
    }

    #[test]
    fn parses_required_wav_only() {
        let cli = TestCli::try_parse_from(["test", "/tmp/x.wav"]).unwrap();
        assert_eq!(cli.transcribe.wav.to_str().unwrap(), "/tmp/x.wav");
        assert!(cli.transcribe.backend.is_none());
        assert!(cli.transcribe.model.is_none());
        assert!(cli.transcribe.format.is_none());
    }

    #[test]
    fn parses_model_flag() {
        let cli =
            TestCli::try_parse_from(["test", "/tmp/x.wav", "--model", "/opt/whisper"]).unwrap();
        assert_eq!(
            cli.transcribe.model.as_deref().and_then(|p| p.to_str()),
            Some("/opt/whisper")
        );
    }

    // Passes every argument at once and checks each field, so a flag wired to the wrong
    // field (or clashing with another) is caught when they are combined.
    #[test]
    fn parses_all_flags() {
        let cli = TestCli::try_parse_from([
            "test",
            "/tmp/x.wav",
            "--backend",
            "mock",
            "--model",
            "/opt/whisper",
            "--format",
            "jsonl",
            "--speaker",
            "alice",
            "--threshold",
            "0.65",
            "--speaker-model",
            "/opt/wespeaker.onnx",
        ])
        .unwrap();
        let t = &cli.transcribe;
        assert_eq!(t.wav.to_str(), Some("/tmp/x.wav"));
        assert_eq!(t.backend.as_deref(), Some("mock"));
        assert_eq!(
            t.model.as_deref().and_then(|p| p.to_str()),
            Some("/opt/whisper")
        );
        assert!(matches!(t.format, Some(OutputFormatArg::Jsonl)));
        assert_eq!(t.speaker.as_deref(), Some("alice"));
        assert!((t.threshold.unwrap() - 0.65).abs() < f32::EPSILON);
        assert_eq!(
            t.speaker_model.as_deref().and_then(|p| p.to_str()),
            Some("/opt/wespeaker.onnx")
        );
    }

    #[test]
    fn parses_speaker_flag() {
        let cli = TestCli::try_parse_from(["test", "/tmp/x.wav", "--speaker", "alice"]).unwrap();
        assert_eq!(cli.transcribe.speaker.as_deref(), Some("alice"));
        assert!(cli.transcribe.threshold.is_none());
    }

    #[test]
    fn parses_threshold_flag() {
        let cli = TestCli::try_parse_from(["test", "/tmp/x.wav", "--threshold", "0.65"]).unwrap();
        assert!((cli.transcribe.threshold.unwrap() - 0.65).abs() < f32::EPSILON);
    }

    #[test]
    fn parses_speaker_model_flag() {
        let cli = TestCli::try_parse_from([
            "test",
            "/tmp/x.wav",
            "--speaker-model",
            "/opt/wespeaker.onnx",
        ])
        .unwrap();
        assert_eq!(
            cli.transcribe
                .speaker_model
                .as_deref()
                .and_then(|p| p.to_str()),
            Some("/opt/wespeaker.onnx")
        );
    }

    #[test]
    fn rejects_non_numeric_threshold() {
        let result = TestCli::try_parse_from(["test", "/tmp/x.wav", "--threshold", "high"]);
        assert!(result.is_err(), "non-numeric threshold should fail");
    }

    #[test]
    fn default_speaker_threshold_is_half() {
        assert!(
            (DEFAULT_SPEAKER_THRESHOLD - 0.5).abs() < f32::EPSILON,
            "default threshold must be 0.5 to match the spike-calibrated default"
        );
    }

    #[test]
    fn parses_md_format() {
        let cli = TestCli::try_parse_from(["test", "/tmp/x.wav", "--format", "md"]).unwrap();
        assert!(matches!(cli.transcribe.format, Some(OutputFormatArg::Md)));
    }

    #[test]
    fn rejects_missing_wav() {
        let result = TestCli::try_parse_from(["test"]);
        assert!(result.is_err(), "wav argument is required");
    }

    #[test]
    fn rejects_unknown_format() {
        let result = TestCli::try_parse_from(["test", "/tmp/x.wav", "--format", "yaml"]);
        assert!(result.is_err(), "only md/jsonl are valid formats");
    }

    #[test]
    fn output_format_arg_maps_to_output_format() {
        assert_eq!(
            OutputFormat::from(OutputFormatArg::Jsonl),
            OutputFormat::Jsonl
        );
        assert_eq!(OutputFormat::from(OutputFormatArg::Md), OutputFormat::Md);
    }

    // Writer whose every `write` fails, so rendering errors must propagate out of `run`.
    struct AlwaysFailWriter;

    impl Write for AlwaysFailWriter {
        fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
            Err(std::io::Error::other("forced write failure"))
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    // Accepts every write but fails on `flush`, isolating the flush error path.
    struct FlushFailWriter;

    impl Write for FlushFailWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            Ok(buf.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Err(std::io::Error::other("forced flush failure"))
        }
    }

    // Path of the `short_en.wav` fixture under `tests/fixtures/voice`.
    fn fixture_wav() -> PathBuf {
        let manifest = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
        manifest.join("tests/fixtures/voice/short_en.wav")
    }

    // Command for `wav` with every optional argument unset except `backend`.
    fn cmd(wav: PathBuf, backend: Option<&str>) -> TranscribeCommand {
        TranscribeCommand {
            wav,
            backend: backend.map(str::to_string),
            model: None,
            format: None,
            speaker: None,
            threshold: None,
            speaker_model: None,
        }
    }

    // Writes a tiny WAV with the given spec: four Int samples (0..=3), or one 0.0 Float sample.
    fn write_test_wav(
        path: &std::path::Path,
        sample_rate: u32,
        channels: u16,
        bits: u16,
        format: hound::SampleFormat,
    ) {
        let spec = hound::WavSpec {
            channels,
            sample_rate,
            bits_per_sample: bits,
            sample_format: format,
        };
        let mut writer = hound::WavWriter::create(path, spec).unwrap();
        match format {
            hound::SampleFormat::Int => {
                for s in [0_i16, 1, 2, 3] {
                    writer.write_sample(s).unwrap();
                }
            }
            hound::SampleFormat::Float => {
                writer.write_sample(0.0_f32).unwrap();
            }
        }
        writer.finalize().unwrap();
    }

    #[test]
    fn read_wav_pcm_16k_mono_accepts_valid_wav() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("ok.wav");
        write_test_wav(&path, 16_000, 1, 16, hound::SampleFormat::Int);
        let pcm = read_wav_pcm_16k_mono(&path).unwrap();
        assert_eq!(pcm, vec![0, 1, 2, 3]);
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_wrong_sample_rate() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("44k.wav");
        write_test_wav(&path, 44_100, 1, 16, hound::SampleFormat::Int);
        let err = read_wav_pcm_16k_mono(&path).unwrap_err();
        assert!(
            err.to_string().contains("must be 16 kHz mono 16-bit PCM"),
            "got: {err}"
        );
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_stereo() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("stereo.wav");
        write_test_wav(&path, 16_000, 2, 16, hound::SampleFormat::Int);
        let err = read_wav_pcm_16k_mono(&path).unwrap_err();
        assert!(err.to_string().contains("16 kHz mono"), "got: {err}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_wrong_bit_depth() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("24bit.wav");
        write_test_wav(&path, 16_000, 1, 24, hound::SampleFormat::Int);
        let err = read_wav_pcm_16k_mono(&path).unwrap_err();
        assert!(err.to_string().contains("16 kHz mono"), "got: {err}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_float_format() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("f32.wav");
        write_test_wav(&path, 16_000, 1, 32, hound::SampleFormat::Float);
        let err = read_wav_pcm_16k_mono(&path).unwrap_err();
        assert!(err.to_string().contains("16 kHz mono"), "got: {err}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_missing_file_errors() {
        let err = read_wav_pcm_16k_mono(std::path::Path::new("/nope/missing.wav")).unwrap_err();
        assert!(err.to_string().contains("open WAV"), "got: {err}");
    }

    #[test]
    fn run_propagates_unknown_backend_error() {
        let mut buf: Vec<u8> = Vec::new();
        let err = cmd(fixture_wav(), Some("nope"))
            .run(&mut buf, OutputFormat::Jsonl)
            .unwrap_err();
        assert!(
            err.to_string().contains("unknown voice backend"),
            "got: {err}"
        );
    }

    #[test]
    fn run_propagates_missing_wav_error() {
        let mut buf: Vec<u8> = Vec::new();
        let err = cmd(PathBuf::from("/nonexistent/should/not/exist.wav"), None)
            .run(&mut buf, OutputFormat::Jsonl)
            .unwrap_err();
        assert!(err.to_string().contains("Failed to open WAV"), "got: {err}");
    }

    #[test]
    fn run_propagates_writer_error_jsonl() {
        let err = cmd(fixture_wav(), None)
            .run(&mut AlwaysFailWriter, OutputFormat::Jsonl)
            .unwrap_err();
        assert!(
            err.to_string().contains("forced write failure"),
            "got: {err}"
        );
    }

    #[test]
    fn run_propagates_writer_error_md() {
        let err = cmd(fixture_wav(), None)
            .run(&mut AlwaysFailWriter, OutputFormat::Md)
            .unwrap_err();
        assert!(
            err.to_string().contains("forced write failure"),
            "got: {err}"
        );
    }

    #[test]
    fn run_propagates_flush_error() {
        // Writes succeed on this writer, so the only possible failure is the flush.
        let err = cmd(fixture_wav(), None)
            .run(&mut FlushFailWriter, OutputFormat::Jsonl)
            .unwrap_err();
        assert!(
            err.to_string().contains("forced flush failure"),
            "got: {err}"
        );
    }
}