Skip to main content

omni_dev/cli/voice/
install_model.rs

1//! `omni-dev voice install-model` — one-time fetch of model artefacts.
2//!
3//! Supports two variants: `whisper-tiny.en` for the `whisper-candle` ASR
4//! backend, and `speaker-wespeaker-en` for the speaker-embedding runtime
5//! added in #805 / ADR-0034. Files land in the conventional install
6//! locations beneath `~/.omni-dev/voice/models/`.
7//!
//! Moves the model-download cost from transcribe/enrol time to install
//! time, so network failures surface explicitly when the user opts in to
//! installing rather than unexpectedly on first use.
11
12use std::io::Write;
13use std::path::{Path, PathBuf};
14use std::time::Instant;
15
16use anyhow::{anyhow, bail, Context, Result};
17use clap::{Parser, ValueEnum};
18use hf_hub::{api::sync::Api, Repo, RepoType};
19use sha2::{Digest, Sha256};
20
21use crate::voice::models::{ModelSource, ModelSpec, SPEAKER_WESPEAKER_EN, WHISPER_TINY_EN};
22
/// Which model variant to install.
///
/// `--variant` defaults to `whisper-tiny.en` so bare
/// `voice install-model` continues to install the ASR model — the
/// pre-#805 behaviour.
// The `#[value(name = ...)]` strings are the exact spellings accepted by
// `--variant`; `Variant::spec` maps each variant to its `ModelSpec`.
// NOTE(review): with `ValueEnum`, the `///` text on each variant is likely
// surfaced as per-value help in `--help`, so treat it as user-facing.
#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
pub enum Variant {
    /// OpenAI Whisper `tiny.en` (ADR-0033).
    // `#[default]` keeps `Variant::default()` in agreement with the clap
    // default declared on `InstallModelCommand::variant`.
    #[default]
    #[value(name = "whisper-tiny.en")]
    WhisperTinyEn,
    /// Wespeaker `resnet34_LM` English-only speaker embedding (ADR-0034).
    #[value(name = "speaker-wespeaker-en")]
    SpeakerWespeakerEn,
}
38
39impl Variant {
40    /// Returns the [`ModelSpec`] for this variant.
41    pub fn spec(self) -> &'static ModelSpec {
42        match self {
43            Self::WhisperTinyEn => &WHISPER_TINY_EN,
44            Self::SpeakerWespeakerEn => &SPEAKER_WESPEAKER_EN,
45        }
46    }
47}
48
/// Downloads the model files for a chosen variant into the conventional
/// install location at `~/.omni-dev/voice/models/<variant-subdir>/` (or
/// `--dest` to override).
///
/// Idempotent: if every required file is already present and non-empty,
/// the command prints a "model already installed" line and exits 0. Pass
/// `--force` to re-download anyway.
// NOTE: clap's derive turns the `///` comments on this struct and its
// fields into `--help` text, and field order sets the order of flags in
// that help, so edits to either change user-visible output.
#[derive(Parser)]
pub struct InstallModelCommand {
    /// Override the install directory. Defaults to the variant's
    /// canonical location under `~/.omni-dev/voice/models/`.
    // `None` means "resolve via the spec's `default_dir()`" (see `run`).
    #[arg(long)]
    pub dest: Option<PathBuf>,

    /// Re-download even if all required files are already present.
    // When set, `run` skips the `all_present` idempotency check.
    #[arg(long)]
    pub force: bool,

    /// Which model variant to install. Defaults to `whisper-tiny.en`.
    // Kept in sync with the `#[default]` variant on `Variant`.
    #[arg(long, value_enum, default_value_t = Variant::WhisperTinyEn)]
    pub variant: Variant,
}
71
72impl InstallModelCommand {
73    /// Entry point. Writes user-facing progress to stderr so stdout stays
74    /// reserved for machine-readable output (parity with `voice
75    /// transcribe`'s JSONL pipe-detection convention).
76    pub fn execute(self) -> Result<()> {
77        let mut err = std::io::stderr().lock();
78        self.run(&mut err)
79    }
80
81    /// Writer-generic core, parameterised over stderr so tests can drive
82    /// the success/idempotency paths without touching the global stream.
83    fn run<W: Write>(self, w: &mut W) -> Result<()> {
84        let spec = self.variant.spec();
85        let dest = match self.dest {
86            Some(p) => p,
87            None => spec
88                .default_dir()
89                .ok_or_else(|| anyhow!("could not determine home directory; pass --dest <path>"))?,
90        };
91
92        if !self.force && all_present(spec, &dest) {
93            writeln!(w, "model already installed at {}", dest.display())?;
94            return Ok(());
95        }
96
97        match spec.source {
98            ModelSource::HfHub { repo_id, revision } => {
99                download_hf_hub(spec, repo_id, revision, &dest, w)
100            }
101            ModelSource::HttpReleaseAsset { url, sha256, bytes } => {
102                download_release_asset(spec, url, sha256, bytes, &dest, w)
103            }
104        }
105    }
106}
107
108fn all_present(spec: &ModelSpec, dir: &Path) -> bool {
109    spec.required_files_in(dir)
110        .iter()
111        .all(|p| p.is_file() && p.metadata().is_ok_and(|m| m.len() > 0))
112}
113
114fn download_hf_hub<W: Write>(
115    spec: &ModelSpec,
116    repo_id: &str,
117    revision: &str,
118    dest: &Path,
119    w: &mut W,
120) -> Result<()> {
121    writeln!(
122        w,
123        "Installing {repo_id} (revision {revision}) -> {}",
124        dest.display()
125    )?;
126    std::fs::create_dir_all(dest)
127        .with_context(|| format!("create install directory at {}", dest.display()))?;
128
129    let api = Api::new().context("initialise HuggingFace Hub client")?;
130    let repo = api.repo(Repo::with_revision(
131        repo_id.to_string(),
132        RepoType::Model,
133        revision.to_string(),
134    ));
135
136    for file in spec.required_files {
137        let start = Instant::now();
138        write!(w, "  fetching {file}... ")?;
139        w.flush()?;
140        let downloaded = repo.get(file).with_context(|| {
141            format!(
142                "download {file} from {repo_id} (revision {revision}). \
143                 Check your network or set HTTPS_PROXY"
144            )
145        })?;
146        let target = dest.join(file);
147        atomic_install_copy(&downloaded, &target).with_context(|| {
148            format!(
149                "install {file} into {} (atomic rename failed)",
150                target.display()
151            )
152        })?;
153        let bytes = std::fs::metadata(&target).map_or(0, |m| m.len());
154        writeln!(
155            w,
156            "done ({bytes} bytes in {:.1}s)",
157            start.elapsed().as_secs_f64()
158        )?;
159    }
160
161    writeln!(
162        w,
163        "{} model installed at {}",
164        spec.kind_label,
165        dest.display()
166    )?;
167    Ok(())
168}
169
170fn download_release_asset<W: Write>(
171    spec: &ModelSpec,
172    url: &str,
173    expected_sha256: &str,
174    expected_bytes: u64,
175    dest: &Path,
176    w: &mut W,
177) -> Result<()> {
178    // Wespeaker (and any future single-asset release-driven model) ships
179    // exactly one file. The check is defensive: if a future spec mis-
180    // declares N!=1 with HttpReleaseAsset, fail loudly rather than
181    // silently install only the first.
182    if spec.required_files.len() != 1 {
183        bail!(
184            "HttpReleaseAsset source expects exactly one required_file, \
185             got {} for variant {}",
186            spec.required_files.len(),
187            spec.variant
188        );
189    }
190    let file_name = spec.required_files[0];
191    let target = dest.join(file_name);
192
193    writeln!(
194        w,
195        "Installing {file_name} ({expected_bytes} B) -> {}",
196        dest.display()
197    )?;
198    std::fs::create_dir_all(dest)
199        .with_context(|| format!("create install directory at {}", dest.display()))?;
200
201    let start = Instant::now();
202    write!(w, "  fetching {url}... ")?;
203    w.flush()?;
204
205    let resp = ureq::get(url)
206        .call()
207        .with_context(|| format!("HTTP GET {url}"))?;
208    let status = resp.status();
209    if !status.is_success() {
210        bail!(
211            "HTTP {} fetching {url}: {}",
212            status.as_u16(),
213            status.canonical_reason().unwrap_or("Unknown"),
214        );
215    }
216    let bytes = resp
217        .into_body()
218        .read_to_vec()
219        .with_context(|| format!("read response body for {url}"))?;
220
221    let actual_sha = {
222        let mut hasher = Sha256::new();
223        hasher.update(&bytes);
224        let digest = hasher.finalize();
225        let mut hex = String::with_capacity(digest.len() * 2);
226        for byte in digest {
227            use std::fmt::Write as _;
228            // `write!` into a `String` is infallible.
229            let _ = write!(&mut hex, "{byte:02x}");
230        }
231        hex
232    };
233    if !actual_sha.eq_ignore_ascii_case(expected_sha256) {
234        bail!("SHA-256 mismatch for {file_name}: expected {expected_sha256}, got {actual_sha}");
235    }
236
237    atomic_install_bytes(&bytes, &target).with_context(|| {
238        format!(
239            "install {file_name} into {} (atomic rename failed)",
240            target.display()
241        )
242    })?;
243    writeln!(
244        w,
245        "done ({} bytes in {:.1}s; sha256 verified)",
246        bytes.len(),
247        start.elapsed().as_secs_f64()
248    )?;
249    writeln!(
250        w,
251        "{} model installed at {}",
252        spec.kind_label,
253        dest.display()
254    )?;
255    Ok(())
256}
257
258/// Writes `bytes` to a `.part` sibling of `to`, then atomically renames
259/// so a partial download never leaves a half-written file at `to`.
260fn atomic_install_bytes(bytes: &[u8], to: &Path) -> Result<()> {
261    if let Some(parent) = to.parent() {
262        std::fs::create_dir_all(parent)
263            .with_context(|| format!("create parent dir {}", parent.display()))?;
264    }
265    let tmp = part_sibling(to)?;
266    std::fs::write(&tmp, bytes)
267        .with_context(|| format!("write {} bytes -> {}", bytes.len(), tmp.display()))?;
268    std::fs::rename(&tmp, to)
269        .with_context(|| format!("rename {} -> {}", tmp.display(), to.display()))?;
270    Ok(())
271}
272
273/// Copies `from` into `to` via a temp file sibling + rename so a partial
274/// download never leaves a half-written file at the destination.
275fn atomic_install_copy(from: &Path, to: &Path) -> Result<()> {
276    if let Some(parent) = to.parent() {
277        std::fs::create_dir_all(parent)
278            .with_context(|| format!("create parent dir {}", parent.display()))?;
279    }
280    let tmp = part_sibling(to)?;
281    std::fs::copy(from, &tmp)
282        .with_context(|| format!("copy {} -> {}", from.display(), tmp.display()))?;
283    std::fs::rename(&tmp, to)
284        .with_context(|| format!("rename {} -> {}", tmp.display(), to.display()))?;
285    Ok(())
286}
287
288fn part_sibling(to: &Path) -> Result<PathBuf> {
289    let file_name = to
290        .file_name()
291        .ok_or_else(|| anyhow!("destination path has no file name: {}", to.display()))?;
292    let mut tmp_name = std::ffi::OsString::from(".");
293    tmp_name.push(file_name);
294    tmp_name.push(".part");
295    Ok(to.with_file_name(tmp_name))
296}
297
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::voice::models::REQUIRED_FILES;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // Serializes the HOME-mutating tests in this module so they never
    // observe each other's override.
    // NOTE(review): this static is private to this module, so it cannot keep
    // these tests from racing env-mutating tests elsewhere (e.g. in
    // `voice::models`) unless those lock this same mutex. Consider a single
    // crate-wide test env lock.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    fn env_guard() -> MutexGuard<'static, ()> {
        // A panicking test poisons the lock. The `()` payload has no state
        // to corrupt, so just take the guard back.
        ENV_GUARD.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Points `HOME` at `home` while holding `ENV_GUARD`, and restores the
    /// previous value on drop. Because restoration happens in `Drop`, it
    /// also runs when the test panics, so a failing `unwrap`/assertion can't
    /// leave `HOME` pointing at a deleted tempdir for later tests.
    struct TempHome {
        prev: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl TempHome {
        fn set(home: &Path) -> Self {
            let lock = env_guard();
            let prev = std::env::var_os("HOME");
            std::env::set_var("HOME", home);
            Self { prev, _lock: lock }
        }
    }

    impl Drop for TempHome {
        fn drop(&mut self) {
            // `drop` runs before the fields are dropped, so HOME is restored
            // while `_lock` is still held.
            match self.prev.take() {
                Some(v) => std::env::set_var("HOME", v),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    fn stage_complete_whisper_model(dir: &Path) {
        std::fs::create_dir_all(dir).unwrap();
        for f in REQUIRED_FILES {
            std::fs::write(dir.join(f), b"placeholder").unwrap();
        }
    }

    fn stage_complete_speaker_model(dir: &Path) {
        std::fs::create_dir_all(dir).unwrap();
        for f in SPEAKER_WESPEAKER_EN.required_files {
            std::fs::write(dir.join(f), b"placeholder").unwrap();
        }
    }

    #[test]
    fn idempotent_when_all_files_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        stage_complete_whisper_model(tmp.path());

        let cmd = InstallModelCommand {
            dest: Some(tmp.path().to_path_buf()),
            force: false,
            variant: Variant::WhisperTinyEn,
        };
        let mut out: Vec<u8> = Vec::new();
        cmd.run(&mut out).unwrap();
        let msg = String::from_utf8(out).unwrap();
        assert!(msg.contains("already installed"), "got: {msg}");
    }

    #[test]
    fn idempotent_when_speaker_model_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        stage_complete_speaker_model(tmp.path());

        let cmd = InstallModelCommand {
            dest: Some(tmp.path().to_path_buf()),
            force: false,
            variant: Variant::SpeakerWespeakerEn,
        };
        let mut out: Vec<u8> = Vec::new();
        cmd.run(&mut out).unwrap();
        let msg = String::from_utf8(out).unwrap();
        assert!(msg.contains("already installed"), "got: {msg}");
    }

    #[test]
    fn idempotent_skip_treats_zero_byte_file_as_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::create_dir_all(tmp.path()).unwrap();
        for f in REQUIRED_FILES {
            // A zero-byte file would pass `is_file()` but is not a valid
            // model artifact; the idempotency check must reject it.
            std::fs::write(tmp.path().join(f), b"").unwrap();
        }
        assert!(!all_present(&WHISPER_TINY_EN, tmp.path()));
    }

    #[test]
    fn atomic_install_copy_replaces_target() {
        let tmp = tempfile::TempDir::new().unwrap();
        let src = tmp.path().join("src");
        let dst = tmp.path().join("dst");
        std::fs::write(&src, b"hello").unwrap();
        std::fs::write(&dst, b"old").unwrap();
        atomic_install_copy(&src, &dst).unwrap();
        let got = std::fs::read(&dst).unwrap();
        assert_eq!(got, b"hello");
        // No leftover temp file.
        let leftover = std::fs::read_dir(tmp.path())
            .unwrap()
            .filter_map(Result::ok)
            .any(|e| e.file_name().to_string_lossy().ends_with(".part"));
        assert!(!leftover, "atomic_install_copy must not leave .part files");
    }

    #[test]
    fn atomic_install_bytes_writes_and_renames() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dst = tmp.path().join("out");
        atomic_install_bytes(b"hello", &dst).unwrap();
        assert_eq!(std::fs::read(&dst).unwrap(), b"hello");
        let leftover = std::fs::read_dir(tmp.path())
            .unwrap()
            .filter_map(Result::ok)
            .any(|e| e.file_name().to_string_lossy().ends_with(".part"));
        assert!(!leftover, "atomic_install_bytes must not leave .part files");
    }

    #[test]
    fn parses_no_args() {
        #[derive(Parser)]
        struct T {
            #[command(flatten)]
            c: InstallModelCommand,
        }
        let t = T::try_parse_from(["test"]).unwrap();
        assert!(t.c.dest.is_none());
        assert!(!t.c.force);
        assert_eq!(t.c.variant, Variant::WhisperTinyEn);
    }

    #[test]
    fn parses_dest_and_force() {
        #[derive(Parser)]
        struct T {
            #[command(flatten)]
            c: InstallModelCommand,
        }
        let t = T::try_parse_from(["test", "--dest", "/opt/x", "--force"]).unwrap();
        assert_eq!(t.c.dest.as_deref(), Some(Path::new("/opt/x")));
        assert!(t.c.force);
    }

    #[test]
    fn parses_speaker_variant() {
        #[derive(Parser)]
        struct T {
            #[command(flatten)]
            c: InstallModelCommand,
        }
        let t = T::try_parse_from(["test", "--variant", "speaker-wespeaker-en"]).unwrap();
        assert_eq!(t.c.variant, Variant::SpeakerWespeakerEn);
    }

    #[test]
    fn parses_whisper_variant_explicit() {
        #[derive(Parser)]
        struct T {
            #[command(flatten)]
            c: InstallModelCommand,
        }
        let t = T::try_parse_from(["test", "--variant", "whisper-tiny.en"]).unwrap();
        assert_eq!(t.c.variant, Variant::WhisperTinyEn);
    }

    #[test]
    fn rejects_unknown_variant() {
        #[derive(Parser)]
        struct T {
            #[command(flatten)]
            c: InstallModelCommand,
        }
        let err = T::try_parse_from(["test", "--variant", "klingon"]);
        assert!(err.is_err(), "unknown variant should fail to parse");
    }

    #[test]
    fn run_with_dest_none_resolves_default_install_dir_from_home() {
        // Covers the `dest: None => spec.default_dir()` path that the
        // explicit-dest tests skip. The model files are staged at the default
        // location *under a tempdir HOME*, so the idempotent branch returns
        // Ok and we never touch the network or the real user's home.
        //
        // Declaration order matters: `_home` drops first (restoring HOME and
        // releasing the lock), then `tmp` deletes the directory.
        let tmp = tempfile::TempDir::new().unwrap();
        let _home = TempHome::set(tmp.path());

        let default_dir = WHISPER_TINY_EN.default_dir().unwrap();
        stage_complete_whisper_model(&default_dir);

        let cmd = InstallModelCommand {
            dest: None,
            force: false,
            variant: Variant::WhisperTinyEn,
        };
        let mut out: Vec<u8> = Vec::new();
        cmd.run(&mut out).unwrap();

        let msg = String::from_utf8(out).unwrap();
        assert!(msg.contains("already installed"), "got: {msg}");
        assert!(
            msg.contains("whisper-tiny.en"),
            "expected resolved default dir in message, got: {msg}"
        );
    }

    #[test]
    fn run_speaker_variant_with_dest_none_resolves_default() {
        // Same declaration-order contract as the whisper test above.
        let tmp = tempfile::TempDir::new().unwrap();
        let _home = TempHome::set(tmp.path());

        let default_dir = SPEAKER_WESPEAKER_EN.default_dir().unwrap();
        stage_complete_speaker_model(&default_dir);

        let cmd = InstallModelCommand {
            dest: None,
            force: false,
            variant: Variant::SpeakerWespeakerEn,
        };
        let mut out: Vec<u8> = Vec::new();
        cmd.run(&mut out).unwrap();

        let msg = String::from_utf8(out).unwrap();
        assert!(msg.contains("already installed"), "got: {msg}");
        assert!(
            msg.contains("wespeaker-en-voxceleb-resnet34-LM"),
            "expected resolved default dir in message, got: {msg}"
        );
    }

    #[test]
    fn variant_spec_returns_correct_spec() {
        assert_eq!(
            Variant::WhisperTinyEn.spec().variant,
            WHISPER_TINY_EN.variant
        );
        assert_eq!(
            Variant::SpeakerWespeakerEn.spec().variant,
            SPEAKER_WESPEAKER_EN.variant
        );
    }
}