Skip to main content

omni_dev/cli/voice/
transcribe.rs

1//! `omni-dev voice transcribe` — feed a 16 kHz mono WAV file through the
2//! configured [`crate::voice::Transcriber`] and emit JSONL events to stdout
3//! (markdown when stdout is a tty).
4//!
5//! WAV validation is delegated to
6//! [`crate::voice::VecAudioInput::from_wav_path`] — non-16 kHz, non-mono,
7//! non-16-bit-PCM files error with a descriptive message pointing at
8//! `voice capture` as the source of normalised audio.
9
10use std::io::IsTerminal;
11use std::io::Write;
12use std::path::{Path, PathBuf};
13
14use anyhow::{bail, Context, Result};
15use clap::Parser;
16
17use crate::voice::models::SPEAKER_WESPEAKER_EN;
18use crate::voice::{
19    cosine, create_default_transcriber, detect_format, render_jsonl, render_markdown, speaker_file,
20    EnrolledSpeaker, OutputFormat, TranscriptEvent, VecAudioInput, VoiceOpts, WespeakerEmbedder,
21    MIN_EMBED_SAMPLES,
22};
23
/// Number of samples per chunk passed to [`VecAudioInput`].
///
/// Spelled out as 64 ms of audio at 16 kHz, i.e. 1024 samples. The mock
/// backend's output does not depend on it; the size mirrors the
/// streaming pipeline (#806), where chunks of roughly that length keep
/// latency low without thrashing the inference loop.
const DEFAULT_CHUNK_SAMPLES: usize = 16_000 * 64 / 1_000;
29
/// Cosine-similarity cut-off applied by `--speaker` when `--threshold`
/// is not given.
///
/// The value comes from the calibration recorded in SPIKE.md on
/// `issue-805-spike-tract-speaker`, measured against
/// [tests/fixtures/voice/two_speakers.wav](../../../tests/fixtures/voice/two_speakers.wav):
/// within-speaker mean ≈ 0.91, cross-speaker mean ≈ 0.07. At 0.5 the
/// cut-off sits ~0.4 above the cross-speaker max and ~0.4 below the
/// within-speaker min, so there is comfortable margin on both sides.
pub const DEFAULT_SPEAKER_THRESHOLD: f32 = 0.5_f32;
39
40/// Transcribes a 16 kHz mono WAV file to JSONL or markdown.
41///
42/// Output format defaults to `md` on a tty and `jsonl` when stdout is
43/// piped; pass `--format` to override. The transcriber backend is chosen
44/// by `--backend`, then `OMNI_DEV_VOICE_BACKEND`, then the default
45/// (`"mock"` until a real ASR backend lands — see ADR-0032).
46#[derive(Parser)]
47pub struct TranscribeCommand {
48    /// Path to a 16 kHz mono 16-bit PCM WAV file. Use `voice capture` to
49    /// produce one — `transcribe` does not resample.
50    pub wav: PathBuf,
51
52    /// Transcriber backend (`mock`, `whisper-candle`). Defaults to `mock`;
53    /// see ADR-0033 for the `whisper-candle` runtime choice.
54    #[arg(long)]
55    pub backend: Option<String>,
56
57    /// Path to a backend-specific model directory. For `whisper-candle`,
58    /// this overrides `OMNI_DEV_VOICE_WHISPER_MODEL` and the default at
59    /// `~/.omni-dev/voice/models/whisper-tiny.en/`. Ignored by `mock`.
60    #[arg(long)]
61    pub model: Option<PathBuf>,
62
63    /// Output format. Defaults to `md` on a tty, `jsonl` when piped.
64    #[arg(long, value_enum)]
65    pub format: Option<OutputFormatArg>,
66
67    /// Enrolled speaker to filter on. Drops any `Final` event whose
68    /// segment doesn't match the enrolled embedding by cosine
69    /// similarity at or above `--threshold`.
70    #[arg(long)]
71    pub speaker: Option<String>,
72
73    /// Cosine-similarity threshold for `--speaker`. Defaults to 0.5;
74    /// see [`DEFAULT_SPEAKER_THRESHOLD`].
75    #[arg(long)]
76    pub threshold: Option<f32>,
77
78    /// Path to the wespeaker ONNX model. Overrides the default at
79    /// `~/.omni-dev/voice/models/wespeaker-en-voxceleb-resnet34-LM/` and
80    /// `OMNI_DEV_VOICE_SPEAKER_MODEL`. Ignored unless `--speaker` is
81    /// set.
82    #[arg(long)]
83    pub speaker_model: Option<PathBuf>,
84}
85
86/// `clap` value enum matching [`OutputFormat`].
87#[derive(Clone, Copy, Debug, clap::ValueEnum)]
88#[value(rename_all = "lowercase")]
89pub enum OutputFormatArg {
90    /// JSON Lines — one event per line, machine-readable.
91    Jsonl,
92    /// Markdown — human-readable transcript view.
93    Md,
94}
95
96impl From<OutputFormatArg> for OutputFormat {
97    fn from(value: OutputFormatArg) -> Self {
98        match value {
99            OutputFormatArg::Jsonl => Self::Jsonl,
100            OutputFormatArg::Md => Self::Md,
101        }
102    }
103}
104
105impl TranscribeCommand {
106    /// Executes the transcribe command.
107    ///
108    /// Thin shim around `Self::run`: locks stdout and resolves the
109    /// effective format from `--format` plus tty auto-detection, then
110    /// delegates to the writer-generic helper. The split keeps stdout-
111    /// locking and tty-detection out of the testable business logic.
112    pub fn execute(self) -> Result<()> {
113        let format = detect_format(
114            self.format.map(OutputFormat::from),
115            std::io::stdout().is_terminal(),
116        );
117        let mut out = std::io::stdout().lock();
118        self.run(&mut out, format)
119    }
120
121    /// Runs the transcribe pipeline against an arbitrary writer.
122    ///
123    /// Decoupled from stdout so unit tests can drive the error paths
124    /// (writer failures, flush failures, backend-construction failures)
125    /// without spawning a subprocess.
126    fn run<W: Write>(self, w: &mut W, format: OutputFormat) -> Result<()> {
127        let speaker_filter = self
128            .speaker
129            .as_deref()
130            .map(|name| {
131                SpeakerFilter::load(
132                    name,
133                    self.speaker_model.as_deref(),
134                    self.threshold.unwrap_or(DEFAULT_SPEAKER_THRESHOLD),
135                    &self.wav,
136                )
137            })
138            .transpose()?;
139        let opts = VoiceOpts {
140            backend: self.backend,
141            model: self.model,
142        };
143        let transcriber = create_default_transcriber(&opts)?;
144        let input = VecAudioInput::from_wav_path(&self.wav, DEFAULT_CHUNK_SAMPLES)?;
145        let stream = transcriber.transcribe(Box::new(input))?;
146
147        // Collect the (small, batch) stream so we can fold the speaker
148        // filter over it without juggling lifetime gymnastics on the
149        // boxed event iterator.
150        let events: Vec<Result<TranscriptEvent>> = stream.collect();
151        let filtered: Vec<Result<TranscriptEvent>> = match &speaker_filter {
152            Some(f) => events
153                .into_iter()
154                .filter_map(|ev| f.transform(ev))
155                .collect(),
156            None => events,
157        };
158
159        match format {
160            OutputFormat::Jsonl => render_jsonl(filtered, w)?,
161            OutputFormat::Md => render_markdown(filtered, w)?,
162        }
163        w.flush()?;
164        Ok(())
165    }
166}
167
168/// Wraps the enrolled-speaker embedding + embedder + source PCM needed
169/// to filter the `Final` event stream on a single speaker.
170struct SpeakerFilter {
171    name: String,
172    enrolled: EnrolledSpeaker,
173    embedder: WespeakerEmbedder,
174    pcm: Vec<i16>,
175    threshold: f32,
176}
177
178impl SpeakerFilter {
179    fn load(name: &str, speaker_model: Option<&Path>, threshold: f32, wav: &Path) -> Result<Self> {
180        let enrolled_path = speaker_file(name)?;
181        let enrolled = EnrolledSpeaker::load(&enrolled_path).with_context(|| {
182            format!(
183                "load enrolled speaker {} from {}",
184                name,
185                enrolled_path.display()
186            )
187        })?;
188        let dir = SPEAKER_WESPEAKER_EN.resolve_dir(speaker_model)?;
189        SPEAKER_WESPEAKER_EN.ensure_present(&dir)?;
190        let model_path = dir.join(SPEAKER_WESPEAKER_EN.required_files[0]);
191        let embedder = WespeakerEmbedder::new(&model_path)?;
192        let pcm = read_wav_pcm_16k_mono(wav)?;
193        Ok(Self {
194            name: name.to_string(),
195            enrolled,
196            embedder,
197            pcm,
198            threshold,
199        })
200    }
201
202    /// Filters a single event. Returns `Some(event)` to keep it (with
203    /// `speaker` set on `Final`) or `None` to drop it. `Partial` and
204    /// `Endpoint` events always pass through unchanged. Errors pass
205    /// through so downstream rendering can fail loudly.
206    fn transform(&self, ev: Result<TranscriptEvent>) -> Option<Result<TranscriptEvent>> {
207        let ev = match ev {
208            Ok(ev) => ev,
209            err @ Err(_) => return Some(err),
210        };
211        match ev {
212            TranscriptEvent::Final {
213                event_id,
214                text,
215                start,
216                end,
217                confidence,
218                words,
219                speaker: _,
220                revisable,
221            } => {
222                let s = (start.as_secs_f64() * 16_000.0) as usize;
223                let e = (end.as_secs_f64() * 16_000.0) as usize;
224                let lo = s.min(self.pcm.len());
225                let hi = e.min(self.pcm.len());
226                let window = &self.pcm[lo..hi.max(lo)];
227                if window.len() < MIN_EMBED_SAMPLES {
228                    // Too short for a stable embedding; conservatively drop.
229                    return None;
230                }
231                let emb = match self.embedder.embed(window) {
232                    Ok(v) => v,
233                    Err(err) => return Some(Err(err)),
234                };
235                if cosine(&emb, &self.enrolled.vector) >= self.threshold {
236                    Some(Ok(TranscriptEvent::Final {
237                        event_id,
238                        text,
239                        start,
240                        end,
241                        confidence,
242                        words,
243                        speaker: Some(self.name.clone()),
244                        revisable,
245                    }))
246                } else {
247                    None
248                }
249            }
250            other => Some(Ok(other)),
251        }
252    }
253}
254
255/// Reads a 16 kHz mono 16-bit signed PCM WAV from `path`, returning the
256/// raw samples for re-windowing by [`SpeakerFilter::transform`].
257///
258/// Delegates format validation to the same invariants
259/// [`VecAudioInput::from_wav_path`] enforces; the two paths read the
260/// file independently because the transcriber moves its input.
261fn read_wav_pcm_16k_mono(path: &Path) -> Result<Vec<i16>> {
262    let mut reader = hound::WavReader::open(path)
263        .with_context(|| format!("open WAV at {} for speaker filter", path.display()))?;
264    let spec = reader.spec();
265    if spec.sample_rate != 16_000
266        || spec.channels != 1
267        || spec.bits_per_sample != 16
268        || spec.sample_format != hound::SampleFormat::Int
269    {
270        bail!(
271            "WAV at {} must be 16 kHz mono 16-bit PCM for --speaker filtering",
272            path.display()
273        );
274    }
275    reader
276        .samples::<i16>()
277        .collect::<Result<Vec<_>, _>>()
278        .with_context(|| format!("decode PCM samples from {}", path.display()))
279}
280
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    use clap::Parser;

    // Minimal host CLI so `TranscribeCommand` can be parsed on its own.
    #[derive(Parser)]
    struct TestCli {
        #[command(flatten)]
        transcribe: TranscribeCommand,
    }

    // Parses `args` (program name included), panicking if clap rejects them.
    fn parse(args: &[&str]) -> TranscribeCommand {
        TestCli::try_parse_from(args.iter().copied())
            .unwrap()
            .transcribe
    }

    // True when clap rejects `args` (program name included).
    fn parse_fails(args: &[&str]) -> bool {
        TestCli::try_parse_from(args.iter().copied()).is_err()
    }

    // ── Argument parsing ──

    #[test]
    fn parses_required_wav_only() {
        let parsed = parse(&["test", "/tmp/x.wav"]);
        assert_eq!(parsed.wav.to_str(), Some("/tmp/x.wav"));
        assert!(parsed.backend.is_none());
        assert!(parsed.model.is_none());
        assert!(parsed.format.is_none());
    }

    #[test]
    fn parses_model_flag() {
        let parsed = parse(&["test", "/tmp/x.wav", "--model", "/opt/whisper"]);
        assert_eq!(
            parsed.model.as_deref().and_then(Path::to_str),
            Some("/opt/whisper")
        );
    }

    #[test]
    fn parses_all_flags() {
        let parsed = parse(&[
            "test",
            "/tmp/x.wav",
            "--backend",
            "mock",
            "--format",
            "jsonl",
        ]);
        assert_eq!(parsed.backend.as_deref(), Some("mock"));
        assert!(matches!(parsed.format, Some(OutputFormatArg::Jsonl)));
    }

    #[test]
    fn parses_speaker_flag() {
        let parsed = parse(&["test", "/tmp/x.wav", "--speaker", "alice"]);
        assert_eq!(parsed.speaker.as_deref(), Some("alice"));
        // Parsing leaves the threshold unset; the run path substitutes
        // DEFAULT_SPEAKER_THRESHOLD when a speaker is given without one.
        assert!(parsed.threshold.is_none());
    }

    #[test]
    fn parses_threshold_flag() {
        let parsed = parse(&["test", "/tmp/x.wav", "--threshold", "0.65"]);
        let delta = (parsed.threshold.unwrap() - 0.65).abs();
        assert!(delta < f32::EPSILON);
    }

    #[test]
    fn parses_speaker_model_flag() {
        let parsed = parse(&[
            "test",
            "/tmp/x.wav",
            "--speaker-model",
            "/opt/wespeaker.onnx",
        ]);
        assert_eq!(
            parsed.speaker_model.as_deref().and_then(Path::to_str),
            Some("/opt/wespeaker.onnx")
        );
    }

    #[test]
    fn rejects_non_numeric_threshold() {
        assert!(
            parse_fails(&["test", "/tmp/x.wav", "--threshold", "high"]),
            "non-numeric threshold should fail"
        );
    }

    #[test]
    fn default_speaker_threshold_is_half() {
        let delta = (DEFAULT_SPEAKER_THRESHOLD - 0.5).abs();
        assert!(
            delta < f32::EPSILON,
            "default threshold must be 0.5 to match the spike-calibrated default"
        );
    }

    #[test]
    fn parses_md_format() {
        let parsed = parse(&["test", "/tmp/x.wav", "--format", "md"]);
        assert!(matches!(parsed.format, Some(OutputFormatArg::Md)));
    }

    #[test]
    fn rejects_missing_wav() {
        assert!(parse_fails(&["test"]), "wav argument is required");
    }

    #[test]
    fn rejects_unknown_format() {
        assert!(
            parse_fails(&["test", "/tmp/x.wav", "--format", "yaml"]),
            "only md/jsonl are valid formats"
        );
    }

    #[test]
    fn output_format_arg_maps_to_output_format() {
        let jsonl: OutputFormat = OutputFormatArg::Jsonl.into();
        let md: OutputFormat = OutputFormatArg::Md.into();
        assert_eq!(jsonl, OutputFormat::Jsonl);
        assert_eq!(md, OutputFormat::Md);
    }

    // ── In-process error-path tests for `TranscribeCommand::run` ──
    //
    // `run` is called directly rather than through a subprocess, so each
    // `?` site (backend factory, WAV load, render write, render flush) is
    // reached deterministically.

    // Writer whose every `write` fails; `flush` succeeds.
    struct AlwaysFailWriter;

    impl Write for AlwaysFailWriter {
        fn write(&mut self, _buf: &[u8]) -> std::io::Result<usize> {
            Err(std::io::Error::other("forced write failure"))
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Ok(())
        }
    }

    // Writer that swallows every `write` but fails every `flush`.
    struct FlushFailWriter;

    impl Write for FlushFailWriter {
        fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
            Ok(buf.len())
        }
        fn flush(&mut self) -> std::io::Result<()> {
            Err(std::io::Error::other("forced flush failure"))
        }
    }

    // Path of the checked-in fixture used by the `run` tests.
    fn fixture_wav() -> PathBuf {
        Path::new(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures/voice/short_en.wav")
    }

    // Command with only `wav` and `backend` set; every other option is left unset.
    fn cmd(wav: PathBuf, backend: Option<&str>) -> TranscribeCommand {
        TranscribeCommand {
            wav,
            backend: backend.map(str::to_string),
            model: None,
            format: None,
            speaker: None,
            threshold: None,
            speaker_model: None,
        }
    }

    // Runs `command` against `sink` and returns the resulting error's message.
    fn run_failure<W: Write>(command: TranscribeCommand, mut sink: W, format: OutputFormat) -> String {
        command.run(&mut sink, format).unwrap_err().to_string()
    }

    // Writes a tiny WAV with the requested header: four integer samples
    // for `Int`, a single `0.0` sample for `Float`.
    fn write_test_wav(
        path: &std::path::Path,
        sample_rate: u32,
        channels: u16,
        bits: u16,
        format: hound::SampleFormat,
    ) {
        let mut wav = hound::WavWriter::create(
            path,
            hound::WavSpec {
                channels,
                sample_rate,
                bits_per_sample: bits,
                sample_format: format,
            },
        )
        .unwrap();
        match format {
            hound::SampleFormat::Float => wav.write_sample(0.0_f32).unwrap(),
            hound::SampleFormat::Int => {
                for sample in [0_i16, 1, 2, 3] {
                    wav.write_sample(sample).unwrap();
                }
            }
        }
        wav.finalize().unwrap();
    }

    // Writes a WAV with the given header into a temp dir and returns the
    // message of the error `read_wav_pcm_16k_mono` produces for it.
    fn rejection_message(
        file_name: &str,
        sample_rate: u32,
        channels: u16,
        bits: u16,
        format: hound::SampleFormat,
    ) -> String {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join(file_name);
        write_test_wav(&path, sample_rate, channels, bits, format);
        read_wav_pcm_16k_mono(&path).unwrap_err().to_string()
    }

    #[test]
    fn read_wav_pcm_16k_mono_accepts_valid_wav() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("ok.wav");
        write_test_wav(&path, 16_000, 1, 16, hound::SampleFormat::Int);
        let samples = read_wav_pcm_16k_mono(&path).unwrap();
        assert_eq!(samples, vec![0, 1, 2, 3]);
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_wrong_sample_rate() {
        let msg = rejection_message("44k.wav", 44_100, 1, 16, hound::SampleFormat::Int);
        assert!(
            msg.contains("must be 16 kHz mono 16-bit PCM"),
            "got: {msg}"
        );
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_stereo() {
        let msg = rejection_message("stereo.wav", 16_000, 2, 16, hound::SampleFormat::Int);
        assert!(msg.contains("16 kHz mono"), "got: {msg}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_wrong_bit_depth() {
        let msg = rejection_message("24bit.wav", 16_000, 1, 24, hound::SampleFormat::Int);
        assert!(msg.contains("16 kHz mono"), "got: {msg}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_rejects_float_format() {
        let msg = rejection_message("f32.wav", 16_000, 1, 32, hound::SampleFormat::Float);
        assert!(msg.contains("16 kHz mono"), "got: {msg}");
    }

    #[test]
    fn read_wav_pcm_16k_mono_missing_file_errors() {
        let msg = read_wav_pcm_16k_mono(Path::new("/nope/missing.wav"))
            .unwrap_err()
            .to_string();
        assert!(msg.contains("open WAV"), "got: {msg}");
    }

    #[test]
    fn run_propagates_unknown_backend_error() {
        let msg = run_failure(
            cmd(fixture_wav(), Some("nope")),
            Vec::<u8>::new(),
            OutputFormat::Jsonl,
        );
        assert!(msg.contains("unknown voice backend"), "got: {msg}");
    }

    #[test]
    fn run_propagates_missing_wav_error() {
        let msg = run_failure(
            cmd(PathBuf::from("/nonexistent/should/not/exist.wav"), None),
            Vec::<u8>::new(),
            OutputFormat::Jsonl,
        );
        assert!(msg.contains("Failed to open WAV"), "got: {msg}");
    }

    #[test]
    fn run_propagates_writer_error_jsonl() {
        let msg = run_failure(cmd(fixture_wav(), None), AlwaysFailWriter, OutputFormat::Jsonl);
        assert!(msg.contains("forced write failure"), "got: {msg}");
    }

    #[test]
    fn run_propagates_writer_error_md() {
        let msg = run_failure(cmd(fixture_wav(), None), AlwaysFailWriter, OutputFormat::Md);
        assert!(msg.contains("forced write failure"), "got: {msg}");
    }

    #[test]
    fn run_propagates_flush_error() {
        // FlushFailWriter accepts every write but fails every flush, so
        // the error surfaces at whichever flush runs first. With the
        // per-event flushes inside render_jsonl that is a mid-stream
        // flush; the final `w.flush()?` in `run` would only be reached if
        // all earlier flushes succeeded, which this writer never allows.
        // What is locked in here is "flush errors propagate", not which
        // `?` site catches them first.
        let msg = run_failure(cmd(fixture_wav(), None), FlushFailWriter, OutputFormat::Jsonl);
        assert!(msg.contains("forced flush failure"), "got: {msg}");
    }
}