Skip to main content

omni_dev/voice/
models.rs

1//! Model storage convention and path resolution.
2//!
3//! Two distinct kinds of model are tracked by this module:
4//!
5//! - **Whisper ASR** (`tiny.en`), loaded by the `whisper-candle` backend.
6//! - **Wespeaker speaker embedding** (`resnet34_LM`), loaded by the
7//!   speaker-embedding subsystem added in #805 / ADR-0034.
8//!
9//! Both follow the same three-tier resolution priority:
10//!
11//! 1. Explicit `--model <path>` (Whisper) or `--speaker-model <path>`
12//!    (wespeaker) on the relevant CLI command.
13//! 2. `OMNI_DEV_VOICE_WHISPER_MODEL` / `OMNI_DEV_VOICE_SPEAKER_MODEL`
14//!    env var.
15//! 3. Default install location under the user's home directory.
16//!
//! Because the install command and the backend resolve paths through the
//! same helper, the install command writes to exactly the place the
//! backend later reads from — the download target and the load target
//! cannot drift apart.
20
21use std::path::{Path, PathBuf};
22
23use anyhow::{anyhow, Context, Result};
24
25use crate::voice::VoiceOpts;
26
27// ── Whisper constants (retained for backwards compatibility) ──────────────
28
/// Repository identifier on the HuggingFace Hub for the `tiny.en` Whisper
/// variant.
pub const MODEL_ID: &str = "openai/whisper-tiny.en";
31
/// HuggingFace revision the Whisper download is pinned to.
///
/// `refs/pr/15` is the revision that adds safetensors weights to
/// `openai/whisper-tiny.en`; this exact revision was validated end-to-end
/// by the candle spike in #813.
pub const REVISION: &str = "refs/pr/15";
36
/// File names the Whisper backend needs in its model directory.
///
/// The order is significant only for the install command's progress
/// messages; the backend loads the files via [`required_files_in`] and
/// does not depend on it.
pub const REQUIRED_FILES: &[&str] = &["config.json", "tokenizer.json", "model.safetensors"];
41
/// Name of the Whisper model's subdirectory under
/// `~/.omni-dev/voice/models/`.
///
/// This is [`MODEL_ID`] without its `openai/` org prefix, which leaves
/// room for future variants (`whisper-base.en`, multilingual) to live in
/// sibling directories.
pub const DEFAULT_VARIANT_DIR: &str = "whisper-tiny.en";
48
49// ── ModelSpec shape ──────────────────────────────────────────────────────
50
/// Where the bytes of a model come from.
///
/// Each variant carries the transport-specific metadata the install
/// command needs to fetch the model exactly once and verify its
/// integrity.
///
/// All payload fields are `'static` strings or integers, so the type is
/// `Copy` and supports structural equality (`PartialEq` / `Eq`), which
/// lets callers and tests compare sources directly.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelSource {
    /// HuggingFace Hub — Whisper's distribution. The install command
    /// uses `hf_hub::api::sync::Api` to download `required_files` at a
    /// pinned revision.
    HfHub {
        /// HF repository identifier, e.g. `"openai/whisper-tiny.en"`.
        repo_id: &'static str,
        /// Pinned revision (branch, tag, or ref).
        revision: &'static str,
    },
    /// A single signed GitHub release asset — wespeaker's distribution.
    /// The install command downloads the asset, verifies SHA-256, and
    /// atomically installs into `required_files[0]`.
    HttpReleaseAsset {
        /// Direct download URL.
        url: &'static str,
        /// Expected SHA-256 of the downloaded bytes, hex-encoded.
        sha256: &'static str,
        /// Expected size in bytes; informational, for progress messages.
        bytes: u64,
    },
}
77
78/// Fully describes a model variant's storage, install transport, and CLI
79/// surface. Static lifetime: every field is `&'static str` (or
80/// `&'static [&'static str]`) so `ModelSpec` is `Copy` and `'static`.
81#[derive(Debug, Clone, Copy)]
82pub struct ModelSpec {
83    /// CLI-facing variant identifier: `"whisper-tiny.en"` or
84    /// `"speaker-wespeaker-en"`. Matches the `--variant` value the user
85    /// passes to `voice install-model`.
86    pub variant: &'static str,
87    /// Human label used in error messages: `"Whisper"` or `"Speaker"`.
88    pub kind_label: &'static str,
89    /// Subdirectory beneath `~/.omni-dev/voice/models/` where this
90    /// model's files live.
91    pub default_subdir: &'static str,
92    /// Files that must exist in the install directory for the model to
93    /// be considered installed.
94    pub required_files: &'static [&'static str],
95    /// Environment-variable override for the install directory.
96    pub env_var: &'static str,
97    /// Recommended `install-model` invocation, used verbatim in the
98    /// `ensure_model_present` error hint.
99    pub install_command: &'static str,
100    /// CLI flag that overrides the model path on consumer commands,
101    /// e.g. `"--model"` (Whisper) or `"--speaker-model"` (wespeaker).
102    pub model_flag: &'static str,
103    /// How to fetch the bytes.
104    pub source: ModelSource,
105}
106
107impl ModelSpec {
108    /// Default install directory: `~/.omni-dev/voice/models/<default_subdir>/`.
109    ///
110    /// `None` when the user's home directory cannot be located — same
111    /// failure mode as `dirs::home_dir()`.
112    pub fn default_dir(&self) -> Option<PathBuf> {
113        dirs::home_dir().map(|home| {
114            home.join(".omni-dev")
115                .join("voice")
116                .join("models")
117                .join(self.default_subdir)
118        })
119    }
120
121    /// Resolves the install directory for this spec.
122    ///
123    /// Priority: `override_path` → env var → default. The returned path
124    /// is *not* validated for existence; pair with [`Self::ensure_present`]
125    /// for fail-fast.
126    pub fn resolve_dir(&self, override_path: Option<&Path>) -> Result<PathBuf> {
127        if let Some(p) = override_path {
128            return Ok(p.to_path_buf());
129        }
130        if let Ok(env) = crate::utils::settings::get_env_var(self.env_var) {
131            if !env.is_empty() {
132                return Ok(PathBuf::from(env));
133            }
134        }
135        self.default_dir().ok_or_else(|| {
136            anyhow!(
137                "could not determine home directory; \
138                 pass {} <path> or set {}",
139                self.model_flag,
140                self.env_var
141            )
142        })
143    }
144
145    /// Returns the absolute path of each required file inside `dir`.
146    pub fn required_files_in(&self, dir: &Path) -> Vec<PathBuf> {
147        self.required_files.iter().map(|f| dir.join(f)).collect()
148    }
149
150    /// Verifies that `dir` contains every file in `self.required_files`.
151    ///
152    /// On failure, returns the install hint shaped for this spec (the
153    /// `install_command` / `model_flag` baked into the spec).
154    pub fn ensure_present(&self, dir: &Path) -> Result<()> {
155        for file in self.required_files {
156            let path = dir.join(file);
157            if !path.is_file() {
158                return Err(anyhow!(
159                    "no {} model found at {}; \
160                     run `{}` or pass {} <path>",
161                    self.kind_label,
162                    dir.display(),
163                    self.install_command,
164                    self.model_flag,
165                ))
166                .with_context(|| format!("missing required file: {}", path.display()));
167            }
168        }
169        Ok(())
170    }
171}
172
173// ── Registered specs ──────────────────────────────────────────────────────
174
175/// Whisper `tiny.en` — production ASR runtime per ADR-0033.
176pub const WHISPER_TINY_EN: ModelSpec = ModelSpec {
177    variant: "whisper-tiny.en",
178    kind_label: "Whisper",
179    default_subdir: DEFAULT_VARIANT_DIR,
180    required_files: REQUIRED_FILES,
181    env_var: "OMNI_DEV_VOICE_WHISPER_MODEL",
182    install_command: "omni-dev voice install-model",
183    model_flag: "--model",
184    source: ModelSource::HfHub {
185        repo_id: MODEL_ID,
186        revision: REVISION,
187    },
188};
189
190/// Wespeaker `voxceleb_resnet34_LM` — production speaker-embedding
191/// runtime per ADR-0034. Not yet wired to consumers; the speaker
192/// install variant lands in a follow-up commit.
193pub const SPEAKER_WESPEAKER_EN: ModelSpec = ModelSpec {
194    variant: "speaker-wespeaker-en",
195    kind_label: "Speaker",
196    default_subdir: "wespeaker-en-voxceleb-resnet34-LM",
197    required_files: &["wespeaker_en_voxceleb_resnet34_LM.onnx"],
198    env_var: "OMNI_DEV_VOICE_SPEAKER_MODEL",
199    install_command: "omni-dev voice install-model --variant speaker-wespeaker-en",
200    model_flag: "--speaker-model",
201    source: ModelSource::HttpReleaseAsset {
202        url: "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_en_voxceleb_resnet34_LM.onnx",
203        sha256: "e9848563da86f263117134dfd7ad63c92355b37de492b55e325400c9d9c39012",
204        bytes: 26_530_550,
205    },
206};
207
208// ── Backwards-compatible Whisper helpers (thin shims) ────────────────────
209
210/// Returns the absolute path of each required model file inside `dir`.
211pub fn required_files_in(dir: &Path) -> Vec<PathBuf> {
212    WHISPER_TINY_EN.required_files_in(dir)
213}
214
215/// Computes the default install location: `~/.omni-dev/voice/models/whisper-tiny.en/`.
216///
217/// Returns `None` only when the user's home directory cannot be located
218/// (i.e. `dirs::home_dir()` returns `None`) — vanishingly rare in practice.
219pub fn default_whisper_model_dir() -> Option<PathBuf> {
220    WHISPER_TINY_EN.default_dir()
221}
222
223/// Resolves the Whisper model directory for the current invocation.
224///
225/// Priority: `opts.model` → `OMNI_DEV_VOICE_WHISPER_MODEL` → default.
226/// The returned path is *not* validated for existence; callers that need
227/// to fail-fast on missing files should pair this with [`ensure_model_present`].
228pub fn resolve_whisper_model_dir(opts: &VoiceOpts) -> Result<PathBuf> {
229    WHISPER_TINY_EN.resolve_dir(opts.model.as_deref())
230}
231
232/// Verifies that `dir` contains every file in [`REQUIRED_FILES`].
233///
234/// On failure, returns the install hint specified by issue #802:
235/// `"no Whisper model found at <path>; run `omni-dev voice install-model`
236/// or pass --model <path>"`.
237pub fn ensure_model_present(dir: &Path) -> Result<()> {
238    WHISPER_TINY_EN.ensure_present(dir)
239}
240
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests in this module that mutate process-wide
    /// environment variables.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_GUARD`], recovering the guard if an earlier test
    /// panicked while holding it.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_GUARD.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets an environment variable for the lifetime of the value while
    /// holding [`ENV_GUARD`].
    ///
    /// On drop — including during unwinding from a failed assertion — the
    /// variable is put back to whatever it was before (or removed if it
    /// was unset), so a failing test cannot leak its value into the rest
    /// of the test process.
    struct ScopedEnvVar {
        key: &'static str,
        previous: Option<OsString>,
        // Declared last: `Drop::drop` restores the variable first, then
        // the fields are dropped and the lock is released.
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedEnvVar {
        /// Takes the env lock, remembers the current value of `key`, and
        /// sets it to `value`.
        fn set(key: &'static str, value: &str) -> Self {
            let lock = env_guard();
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self {
                key,
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn opts_model_takes_top_priority() {
        let _env = ScopedEnvVar::set("OMNI_DEV_VOICE_WHISPER_MODEL", "/should/not/be/read");
        let opts = VoiceOpts {
            backend: None,
            model: Some(PathBuf::from("/explicit/path")),
        };
        let resolved = resolve_whisper_model_dir(&opts).unwrap();
        assert_eq!(resolved, PathBuf::from("/explicit/path"));
    }

    #[test]
    fn env_var_used_when_opts_absent() {
        let _env = ScopedEnvVar::set("OMNI_DEV_VOICE_WHISPER_MODEL", "/from/env");
        let resolved = resolve_whisper_model_dir(&VoiceOpts::default()).unwrap();
        assert_eq!(resolved, PathBuf::from("/from/env"));
    }

    #[test]
    fn empty_env_var_falls_through_to_default() {
        let _env = ScopedEnvVar::set("OMNI_DEV_VOICE_WHISPER_MODEL", "");
        let resolved = resolve_whisper_model_dir(&VoiceOpts::default()).unwrap();
        let expected = default_whisper_model_dir().unwrap();
        assert_eq!(resolved, expected);
    }

    #[test]
    fn default_path_uses_omni_dev_voice_models_subdir() {
        let dir = default_whisper_model_dir().unwrap();
        assert!(dir.ends_with(".omni-dev/voice/models/whisper-tiny.en"));
    }

    #[test]
    fn ensure_model_present_succeeds_when_all_files_exist() {
        let tmp = tempfile::TempDir::new().unwrap();
        for f in REQUIRED_FILES {
            std::fs::write(tmp.path().join(f), b"placeholder").unwrap();
        }
        ensure_model_present(tmp.path()).unwrap();
    }

    #[test]
    fn ensure_model_present_errors_with_hint_when_files_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        let err = ensure_model_present(tmp.path()).unwrap_err();
        // `{:#}` renders the whole error chain on one line.
        let msg = format!("{err:#}");
        assert!(msg.contains("no Whisper model found"), "got: {msg}");
        assert!(msg.contains("voice install-model"), "got: {msg}");
        assert!(msg.contains("--model"), "got: {msg}");
    }

    #[test]
    fn ensure_model_present_errors_when_any_file_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        // Write two of three required files; tokenizer.json missing.
        std::fs::write(tmp.path().join("config.json"), b"x").unwrap();
        std::fs::write(tmp.path().join("model.safetensors"), b"x").unwrap();
        let err = ensure_model_present(tmp.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("tokenizer.json"), "got: {msg}");
    }

    #[test]
    fn required_files_in_returns_three_paths() {
        let paths = required_files_in(Path::new("/x"));
        assert_eq!(paths.len(), 3);
        assert_eq!(paths[0], PathBuf::from("/x/config.json"));
        assert_eq!(paths[1], PathBuf::from("/x/tokenizer.json"));
        assert_eq!(paths[2], PathBuf::from("/x/model.safetensors"));
    }

    // ── ModelSpec-shaped API tests ──────────────────────────────────────

    #[test]
    fn speaker_spec_default_dir_ends_with_wespeaker_subdir() {
        let dir = SPEAKER_WESPEAKER_EN.default_dir().unwrap();
        assert!(dir.ends_with(".omni-dev/voice/models/wespeaker-en-voxceleb-resnet34-LM"));
    }

    #[test]
    fn speaker_spec_resolve_dir_override_takes_priority() {
        let _env = ScopedEnvVar::set("OMNI_DEV_VOICE_SPEAKER_MODEL", "/should/not/be/read");
        let resolved = SPEAKER_WESPEAKER_EN
            .resolve_dir(Some(Path::new("/explicit/path")))
            .unwrap();
        assert_eq!(resolved, PathBuf::from("/explicit/path"));
    }

    #[test]
    fn speaker_spec_resolve_dir_env_var_used_when_override_absent() {
        let _env = ScopedEnvVar::set("OMNI_DEV_VOICE_SPEAKER_MODEL", "/from/env");
        let resolved = SPEAKER_WESPEAKER_EN.resolve_dir(None).unwrap();
        assert_eq!(resolved, PathBuf::from("/from/env"));
    }

    #[test]
    fn speaker_spec_ensure_present_errors_with_install_hint() {
        let tmp = tempfile::TempDir::new().unwrap();
        let err = SPEAKER_WESPEAKER_EN.ensure_present(tmp.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("no Speaker model found"), "got: {msg}");
        assert!(msg.contains("--variant speaker-wespeaker-en"), "got: {msg}");
        assert!(msg.contains("--speaker-model"), "got: {msg}");
        assert!(
            msg.contains("wespeaker_en_voxceleb_resnet34_LM.onnx"),
            "got: {msg}"
        );
    }

    #[test]
    fn speaker_spec_ensure_present_succeeds_when_file_exists() {
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::write(
            tmp.path().join("wespeaker_en_voxceleb_resnet34_LM.onnx"),
            b"placeholder",
        )
        .unwrap();
        SPEAKER_WESPEAKER_EN.ensure_present(tmp.path()).unwrap();
    }

    #[test]
    fn whisper_spec_required_files_matches_legacy_helper() {
        let dir = Path::new("/x");
        assert_eq!(
            WHISPER_TINY_EN.required_files_in(dir),
            required_files_in(dir)
        );
    }

    #[test]
    fn whisper_spec_source_carries_pinned_hf_metadata() {
        match WHISPER_TINY_EN.source {
            ModelSource::HfHub { repo_id, revision } => {
                assert_eq!(repo_id, MODEL_ID);
                assert_eq!(revision, REVISION);
            }
            ModelSource::HttpReleaseAsset { .. } => {
                panic!("WHISPER_TINY_EN should be HfHub-sourced");
            }
        }
    }

    #[test]
    fn speaker_spec_source_carries_pinned_release_metadata() {
        match SPEAKER_WESPEAKER_EN.source {
            ModelSource::HttpReleaseAsset { url, sha256, bytes } => {
                assert!(url.contains("wespeaker_en_voxceleb_resnet34_LM.onnx"));
                assert_eq!(
                    sha256,
                    "e9848563da86f263117134dfd7ad63c92355b37de492b55e325400c9d9c39012"
                );
                assert_eq!(bytes, 26_530_550);
            }
            ModelSource::HfHub { .. } => {
                panic!("SPEAKER_WESPEAKER_EN should be HttpReleaseAsset-sourced");
            }
        }
    }
}