Skip to main content

omni_dev/cli/voice/
enroll.rs

1//! `omni-dev voice enroll` — capture a microphone sample, compute the
2//! speaker embedding, and persist to `~/.omni-dev/voice/speakers/<name>.json`.
3//!
4//! Stops on the first of: `--idle-after` seconds of trailing silence,
5//! `--max-secs` elapsed since start, or Ctrl-C. Refuses to overwrite an
6//! existing enrolment unless `--force` is set.
7
8use std::path::PathBuf;
9use std::sync::atomic::Ordering;
10use std::thread;
11use std::time::Duration;
12
13use anyhow::{bail, Context, Result};
14use chrono::Utc;
15use clap::Parser;
16
17use crate::voice::capture::{
18    install_ctrl_c_handler, run_capture, CaptureOpts, CaptureSummary, TerminationReason,
19};
20use crate::voice::models::SPEAKER_WESPEAKER_EN;
21use crate::voice::{
22    captures_dir, speaker_file, CpalAudioSource, EnrolledSpeaker, WespeakerEmbedder,
23};
24
/// Seconds of trailing silence after which an enrolment capture stops when
/// `--idle-after` is not given (issue #805 spec).
pub const DEFAULT_IDLE_AFTER_SECS: u32 = 2;

/// Hard cap on enrolment capture length, in seconds, used when `--max-secs`
/// is not given (issue #805 spec — 30 s cap).
pub const DEFAULT_MAX_SECS: u32 = 30;
30
31/// Captures audio from a microphone, computes a speaker embedding, and
32/// persists it to `~/.omni-dev/voice/speakers/<name>.json`.
33///
34/// Stops on the first of: `--idle-after` seconds of trailing silence,
35/// `--max-secs` elapsed (default 30), or Ctrl-C. Refuses to overwrite an
36/// existing enrolment unless `--force` is set.
37#[derive(Parser)]
38pub struct EnrollCommand {
39    /// Identifier under which to store the embedding (the JSON filename
40    /// stem). Defaults to `default`.
41    #[arg(long, default_value = "default")]
42    pub name: String,
43
44    /// Stop after this many seconds of trailing silence.
45    #[arg(long, default_value_t = DEFAULT_IDLE_AFTER_SECS)]
46    pub idle_after: u32,
47
48    /// Hard upper bound on capture duration in seconds. Capture stops as
49    /// soon as this many seconds have elapsed, even if speech continues.
50    /// `0` disables the cap (only idle/Ctrl-C will stop the capture).
51    #[arg(long, default_value_t = DEFAULT_MAX_SECS)]
52    pub max_secs: u32,
53
54    /// Audio input device name. Defaults to the system default input.
55    #[arg(long)]
56    pub device: Option<String>,
57
58    /// Path to the wespeaker ONNX model. Overrides the default at
59    /// `~/.omni-dev/voice/models/wespeaker-en-voxceleb-resnet34-LM/` and
60    /// the `OMNI_DEV_VOICE_SPEAKER_MODEL` env var.
61    #[arg(long)]
62    pub speaker_model: Option<PathBuf>,
63
64    /// Overwrite an existing `<name>.json` enrolment instead of refusing.
65    #[arg(long)]
66    pub force: bool,
67}
68
69impl EnrollCommand {
70    /// Executes the enroll command.
71    pub fn execute(self) -> Result<()> {
72        let speaker_model = resolve_speaker_model_path(self.speaker_model.as_deref())?;
73        let dest = speaker_file(&self.name)?;
74        if dest.is_file() && !self.force {
75            bail!(
76                "speaker {} already enrolled at {}; pass --force to overwrite",
77                self.name,
78                dest.display()
79            );
80        }
81
82        // Capture to a tempfile WAV inside the captures directory, named
83        // with a leading dot so it's clear from `ls` that it's transient.
84        let captures = captures_dir()?;
85        std::fs::create_dir_all(&captures)
86            .with_context(|| format!("create captures dir {}", captures.display()))?;
87        let timestamp = Utc::now().format("%Y%m%dT%H%M%SZ").to_string();
88        let tmp_wav = captures.join(format!(".enroll-{timestamp}.wav"));
89
90        let stop = install_ctrl_c_handler()?;
91        // Optional max-duration cap: a background watchdog flips the same
92        // stop signal Ctrl-C uses. Thread leaks at process exit, which is
93        // fine for a one-shot CLI.
94        if self.max_secs > 0 {
95            let stop = stop.clone();
96            let deadline_secs = u64::from(self.max_secs);
97            thread::spawn(move || {
98                thread::sleep(Duration::from_secs(deadline_secs));
99                stop.store(true, Ordering::Relaxed);
100            });
101        }
102
103        let source = CpalAudioSource::new(self.device.as_deref())?;
104        eprintln!(
105            "Recording enrolment for {} to {} (idle-after: {}s, max: {}s, Ctrl-C to stop)…",
106            self.name,
107            tmp_wav.display(),
108            self.idle_after,
109            self.max_secs,
110        );
111        let opts = CaptureOpts::new(&tmp_wav, self.idle_after);
112        let summary = run_capture(source, opts, stop)?;
113        print_capture_summary(&summary);
114
115        // Decode the captured WAV, embed, persist.
116        let result = embed_and_save(&self.name, &speaker_model, &tmp_wav, &dest);
117
118        // Always try to delete the tempfile — even if the embed step
119        // failed, the WAV is no longer useful.
120        let _ = std::fs::remove_file(&tmp_wav);
121
122        result?;
123        eprintln!(
124            "Enrolled speaker {} ({} dim) -> {}",
125            self.name,
126            embed_dim_hint(),
127            dest.display()
128        );
129        Ok(())
130    }
131}
132
133fn embed_and_save(
134    name: &str,
135    speaker_model: &std::path::Path,
136    tmp_wav: &std::path::Path,
137    dest: &std::path::Path,
138) -> Result<()> {
139    let pcm = read_wav_16k_mono_i16(tmp_wav)?;
140    let embedder = WespeakerEmbedder::new(speaker_model)?;
141    let vector = embedder.embed(&pcm)?;
142    let enrolled = EnrolledSpeaker {
143        name: name.to_string(),
144        model: SPEAKER_WESPEAKER_EN.variant.to_string(),
145        dim: vector.len(),
146        vector,
147        samples_used: 1,
148        enrolled_at: Utc::now(),
149    };
150    enrolled.save(dest)
151}
152
153fn read_wav_16k_mono_i16(path: &std::path::Path) -> Result<Vec<i16>> {
154    let mut reader = hound::WavReader::open(path)
155        .with_context(|| format!("open enrolment WAV at {}", path.display()))?;
156    let spec = reader.spec();
157    if spec.sample_rate != 16_000 || spec.channels != 1 {
158        bail!(
159            "enrolment WAV at {} must be 16 kHz mono (got {} Hz, {} channels)",
160            path.display(),
161            spec.sample_rate,
162            spec.channels
163        );
164    }
165    let samples: Vec<i16> = reader
166        .samples::<i16>()
167        .collect::<Result<Vec<_>, _>>()
168        .context("decode enrolment WAV samples")?;
169    Ok(samples)
170}
171
172fn resolve_speaker_model_path(override_path: Option<&std::path::Path>) -> Result<PathBuf> {
173    let dir = SPEAKER_WESPEAKER_EN.resolve_dir(override_path)?;
174    SPEAKER_WESPEAKER_EN.ensure_present(&dir)?;
175    Ok(dir.join(SPEAKER_WESPEAKER_EN.required_files[0]))
176}
177
178fn print_capture_summary(summary: &CaptureSummary) {
179    eprintln!("{}", format_capture_summary(summary));
180}
181
182fn format_capture_summary(summary: &CaptureSummary) -> String {
183    let reason = match summary.terminated_by {
184        TerminationReason::Idle => "silence threshold reached",
185        TerminationReason::SourceExhausted => "audio source ended",
186        TerminationReason::Signal => "Ctrl-C or max-secs deadline",
187    };
188    let seconds = samples_to_seconds(summary.samples_written);
189    format!(
190        "Captured {seconds:.2}s ({} samples; {} trimmed; stopped: {reason})",
191        summary.samples_written, summary.trimmed_samples,
192    )
193}
194
195fn samples_to_seconds(samples: u64) -> f64 {
196    samples as f64 / f64::from(crate::voice::wav::TARGET_SAMPLE_RATE)
197}
198
/// Embedding width shown in the post-enrolment stderr message.
///
/// This is a fixed display value, not a measurement. The real width is the
/// length of the model's output vector, which is what gets stored in the
/// saved profile's `dim` field.
const fn embed_dim_hint() -> usize {
    const DISPLAY_DIM: usize = 256;
    DISPLAY_DIM
}
204
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use clap::Parser;

    /// Top-level wrapper so `EnrollCommand` can be parsed on its own.
    #[derive(Parser)]
    struct TestCli {
        #[command(flatten)]
        enroll: EnrollCommand,
    }

    /// Parses `args` (program name excluded) into an `EnrollCommand`.
    fn parse(args: &[&str]) -> std::result::Result<EnrollCommand, clap::Error> {
        let argv = std::iter::once("test").chain(args.iter().copied());
        TestCli::try_parse_from(argv).map(|cli| cli.enroll)
    }

    /// Writes a 16-bit integer PCM WAV at `path` with the given layout.
    fn write_pcm16_wav(path: &std::path::Path, channels: u16, sample_rate: u32, samples: &[i16]) {
        let spec = hound::WavSpec {
            channels,
            sample_rate,
            bits_per_sample: 16,
            sample_format: hound::SampleFormat::Int,
        };
        let mut writer = hound::WavWriter::create(path, spec).unwrap();
        for &sample in samples {
            writer.write_sample(sample).unwrap();
        }
        writer.finalize().unwrap();
    }

    /// Builds a `CaptureSummary` with a dummy output path.
    fn summary(reason: TerminationReason, written: u64, trimmed: u64) -> CaptureSummary {
        CaptureSummary {
            output: std::path::PathBuf::from("/tmp/out.wav"),
            samples_written: written,
            trimmed_samples: trimmed,
            terminated_by: reason,
        }
    }

    #[test]
    fn parses_defaults() {
        let cmd = parse(&[]).unwrap();
        assert_eq!(cmd.name, "default");
        assert_eq!(cmd.idle_after, DEFAULT_IDLE_AFTER_SECS);
        assert_eq!(cmd.max_secs, DEFAULT_MAX_SECS);
        assert_eq!(cmd.device, None);
        assert_eq!(cmd.speaker_model, None);
        assert!(!cmd.force);
    }

    #[test]
    fn parses_all_flags() {
        let cmd = parse(&[
            "--name",
            "jky",
            "--idle-after",
            "3",
            "--max-secs",
            "20",
            "--device",
            "Built-in Mic",
            "--speaker-model",
            "/opt/wespeaker.onnx",
            "--force",
        ])
        .unwrap();
        assert_eq!(cmd.name, "jky");
        assert_eq!(cmd.idle_after, 3);
        assert_eq!(cmd.max_secs, 20);
        assert_eq!(cmd.device.as_deref(), Some("Built-in Mic"));
        assert_eq!(
            cmd.speaker_model.as_deref(),
            Some(std::path::Path::new("/opt/wespeaker.onnx"))
        );
        assert!(cmd.force);
    }

    #[test]
    fn parses_max_secs_zero_disables_cap() {
        assert_eq!(parse(&["--max-secs", "0"]).unwrap().max_secs, 0);
    }

    #[test]
    fn rejects_negative_idle_after() {
        assert!(parse(&["--idle-after", "-1"]).is_err());
    }

    #[test]
    fn rejects_negative_max_secs() {
        assert!(parse(&["--max-secs", "-1"]).is_err());
    }

    #[test]
    fn resolve_speaker_model_path_errors_with_install_hint_when_dir_empty() {
        let tmp = tempfile::TempDir::new().unwrap();
        let err = resolve_speaker_model_path(Some(tmp.path()))
            .expect_err("empty model dir should fail the ensure_present check");
        let msg = format!("{err:#}");
        for needle in ["no Speaker model found", "--variant speaker-wespeaker-en"] {
            assert!(msg.contains(needle), "got: {msg}");
        }
    }

    #[test]
    fn resolve_speaker_model_path_returns_onnx_file_when_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        let expected = tmp.path().join(SPEAKER_WESPEAKER_EN.required_files[0]);
        std::fs::write(&expected, b"placeholder").unwrap();
        assert_eq!(
            resolve_speaker_model_path(Some(tmp.path())).unwrap(),
            expected
        );
    }

    #[test]
    fn read_wav_16k_mono_i16_rejects_wrong_format() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("bad.wav");
        write_pcm16_wav(&path, 2, 44_100, &[0, 0]);
        let err = read_wav_16k_mono_i16(&path).expect_err("stereo @ 44.1k should fail");
        let msg = format!("{err:#}");
        assert!(msg.contains("must be 16 kHz mono"), "got: {msg}");
    }

    #[test]
    fn format_capture_summary_idle_termination_mentions_silence() {
        let line = format_capture_summary(&summary(TerminationReason::Idle, 16_000, 3_200));
        assert!(line.contains("silence threshold reached"));
        assert!(
            line.contains("1.00s"),
            "16000 samples @ 16kHz = 1.00s; got: {line}"
        );
        assert!(line.contains("16000 samples"));
        assert!(line.contains("3200 trimmed"));
    }

    #[test]
    fn format_capture_summary_signal_termination_mentions_ctrl_c_or_deadline() {
        let line = format_capture_summary(&summary(TerminationReason::Signal, 48_000, 0));
        assert!(line.contains("Ctrl-C or max-secs deadline"));
        assert!(
            line.contains("3.00s"),
            "48000 samples @ 16kHz = 3.00s; got: {line}"
        );
    }

    #[test]
    fn format_capture_summary_source_exhausted_mentions_source() {
        let line =
            format_capture_summary(&summary(TerminationReason::SourceExhausted, 8_000, 0));
        assert!(line.contains("audio source ended"));
        assert!(line.contains("0.50s"));
    }

    #[test]
    fn samples_to_seconds_round_trips_at_16k() {
        for (samples, seconds) in [(0_u64, 0.0_f64), (16_000, 1.0), (8_000, 0.5)] {
            assert!((samples_to_seconds(samples) - seconds).abs() < f64::EPSILON);
        }
    }

    #[test]
    fn embed_dim_hint_is_256() {
        assert_eq!(embed_dim_hint(), 256);
    }

    #[test]
    fn read_wav_16k_mono_i16_decodes_ok_when_format_matches() {
        let tmp = tempfile::TempDir::new().unwrap();
        let path = tmp.path().join("ok.wav");
        write_pcm16_wav(&path, 1, 16_000, &[100, 200, 300, 400]);
        assert_eq!(
            read_wav_16k_mono_i16(&path).unwrap(),
            vec![100, 200, 300, 400]
        );
    }
}