1use std::path::{Path, PathBuf};
22
23use anyhow::{anyhow, Context, Result};
24
25use crate::voice::VoiceOpts;
26
/// Hugging Face Hub repository ID of the default Whisper model.
pub const MODEL_ID: &str = "openai/whisper-tiny.en";

/// Pinned revision of [`MODEL_ID`] recorded in the default Whisper spec's
/// Hugging Face source.
///
/// NOTE(review): this is a PR ref rather than a branch such as `main`;
/// presumably it was chosen because it ships `model.safetensors` — confirm
/// before bumping.
pub const REVISION: &str = "refs/pr/15";

/// File names that must all exist in a Whisper model directory.
pub const REQUIRED_FILES: &[&str] = &["config.json", "tokenizer.json", "model.safetensors"];

/// Directory name, under `~/.omni-dev/voice/models`, for the default Whisper
/// variant.
pub const DEFAULT_VARIANT_DIR: &str = "whisper-tiny.en";
48
/// Upstream location a model's files are obtained from.
///
/// All payloads are `'static` string slices and integers. The type is
/// therefore `Copy`, and can be compared with `==` (for example, to assert
/// which source a spec is pinned to).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelSource {
    /// A repository on the Hugging Face Hub, pinned to a specific revision.
    HfHub {
        /// Repository ID, e.g. `"openai/whisper-tiny.en"`.
        repo_id: &'static str,
        /// Git revision to use: a branch, tag, commit, or ref such as `refs/pr/15`.
        revision: &'static str,
    },
    /// A single file served over HTTP(S), such as a GitHub release asset.
    HttpReleaseAsset {
        /// Direct download URL of the file.
        url: &'static str,
        /// Expected SHA-256 digest of the file, as lowercase hex.
        sha256: &'static str,
        /// Expected size of the file in bytes.
        bytes: u64,
    },
}
77
/// Static description of a voice model.
///
/// A spec records where the model lives on disk, which files it needs, where
/// it comes from upstream, and how users are told to install or locate it.
#[derive(Debug, Clone, Copy)]
pub struct ModelSpec {
    /// Variant identifier, e.g. `"whisper-tiny.en"`.
    /// NOTE(review): presumably the value accepted by `install-model --variant` — confirm.
    pub variant: &'static str,
    /// Human-readable model kind interpolated into errors as `no {kind_label} model found`.
    pub kind_label: &'static str,
    /// Directory name appended to `~/.omni-dev/voice/models` by `default_dir`.
    pub default_subdir: &'static str,
    /// File names, relative to the model directory, that must all be present.
    pub required_files: &'static [&'static str],
    /// Name looked up via `crate::utils::settings::get_env_var` to override the directory.
    pub env_var: &'static str,
    /// Install command suggested when the model is missing.
    pub install_command: &'static str,
    /// CLI flag suggested for passing an explicit model directory.
    pub model_flag: &'static str,
    /// Upstream location of the model files.
    pub source: ModelSource,
}
106
107impl ModelSpec {
108 pub fn default_dir(&self) -> Option<PathBuf> {
113 dirs::home_dir().map(|home| {
114 home.join(".omni-dev")
115 .join("voice")
116 .join("models")
117 .join(self.default_subdir)
118 })
119 }
120
121 pub fn resolve_dir(&self, override_path: Option<&Path>) -> Result<PathBuf> {
127 if let Some(p) = override_path {
128 return Ok(p.to_path_buf());
129 }
130 if let Ok(env) = crate::utils::settings::get_env_var(self.env_var) {
131 if !env.is_empty() {
132 return Ok(PathBuf::from(env));
133 }
134 }
135 self.default_dir().ok_or_else(|| {
136 anyhow!(
137 "could not determine home directory; \
138 pass {} <path> or set {}",
139 self.model_flag,
140 self.env_var
141 )
142 })
143 }
144
145 pub fn required_files_in(&self, dir: &Path) -> Vec<PathBuf> {
147 self.required_files.iter().map(|f| dir.join(f)).collect()
148 }
149
150 pub fn ensure_present(&self, dir: &Path) -> Result<()> {
155 for file in self.required_files {
156 let path = dir.join(file);
157 if !path.is_file() {
158 return Err(anyhow!(
159 "no {} model found at {}; \
160 run `{}` or pass {} <path>",
161 self.kind_label,
162 dir.display(),
163 self.install_command,
164 self.model_flag,
165 ))
166 .with_context(|| format!("missing required file: {}", path.display()));
167 }
168 }
169 Ok(())
170 }
171}
172
/// Spec for the default Whisper model, `openai/whisper-tiny.en`.
///
/// The model is sourced from the Hugging Face Hub at the pinned [`REVISION`].
/// Its directory can be overridden with `--model` or
/// `OMNI_DEV_VOICE_WHISPER_MODEL`.
pub const WHISPER_TINY_EN: ModelSpec = ModelSpec {
    variant: "whisper-tiny.en",
    kind_label: "Whisper",
    default_subdir: DEFAULT_VARIANT_DIR,
    required_files: REQUIRED_FILES,
    env_var: "OMNI_DEV_VOICE_WHISPER_MODEL",
    // The default variant, so the install hint needs no `--variant` flag.
    install_command: "omni-dev voice install-model",
    model_flag: "--model",
    source: ModelSource::HfHub {
        repo_id: MODEL_ID,
        revision: REVISION,
    },
};
189
/// Spec for the English WeSpeaker speaker model.
///
/// The model is a single ONNX file published as a sherpa-onnx GitHub release
/// asset, with a pinned SHA-256 digest and byte size. Its directory can be
/// overridden with `--speaker-model` or `OMNI_DEV_VOICE_SPEAKER_MODEL`.
pub const SPEAKER_WESPEAKER_EN: ModelSpec = ModelSpec {
    variant: "speaker-wespeaker-en",
    kind_label: "Speaker",
    default_subdir: "wespeaker-en-voxceleb-resnet34-LM",
    required_files: &["wespeaker_en_voxceleb_resnet34_LM.onnx"],
    env_var: "OMNI_DEV_VOICE_SPEAKER_MODEL",
    install_command: "omni-dev voice install-model --variant speaker-wespeaker-en",
    model_flag: "--speaker-model",
    source: ModelSource::HttpReleaseAsset {
        // NOTE(review): the `recongition` spelling appears to match the upstream
        // sherpa-onnx release tag; verify against upstream before "correcting" it.
        url: "https://github.com/k2-fsa/sherpa-onnx/releases/download/speaker-recongition-models/wespeaker_en_voxceleb_resnet34_LM.onnx",
        sha256: "e9848563da86f263117134dfd7ad63c92355b37de492b55e325400c9d9c39012",
        bytes: 26_530_550,
    },
};
207
208pub fn required_files_in(dir: &Path) -> Vec<PathBuf> {
212 WHISPER_TINY_EN.required_files_in(dir)
213}
214
215pub fn default_whisper_model_dir() -> Option<PathBuf> {
220 WHISPER_TINY_EN.default_dir()
221}
222
223pub fn resolve_whisper_model_dir(opts: &VoiceOpts) -> Result<PathBuf> {
229 WHISPER_TINY_EN.resolve_dir(opts.model.as_deref())
230}
231
232pub fn ensure_model_present(dir: &Path) -> Result<()> {
238 WHISPER_TINY_EN.ensure_present(dir)
239}
240
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes tests that mutate process environment variables. The
    /// environment is global state shared by every test thread.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    /// Acquires [`ENV_GUARD`], recovering from poisoning so that one failed test
    /// does not make every later env-mutating test fail as well.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_GUARD.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Sets an environment variable for as long as the guard lives, while
    /// holding [`ENV_GUARD`].
    ///
    /// The previous value (or its absence) is restored on drop. Drop also runs
    /// when an assertion panics, so a failing test cannot leak state into
    /// later tests.
    struct ScopedEnvVar {
        key: &'static str,
        previous: Option<OsString>,
        // Fields drop after `Drop::drop` runs, so the variable is restored
        // while the lock is still held.
        _lock: MutexGuard<'static, ()>,
    }

    impl ScopedEnvVar {
        fn set(key: &'static str, value: &str) -> Self {
            let lock = env_guard();
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self {
                key,
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for ScopedEnvVar {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var(self.key, value),
                None => std::env::remove_var(self.key),
            }
        }
    }

    const WHISPER_ENV: &str = "OMNI_DEV_VOICE_WHISPER_MODEL";
    const SPEAKER_ENV: &str = "OMNI_DEV_VOICE_SPEAKER_MODEL";

    #[test]
    fn opts_model_takes_top_priority() {
        let _env = ScopedEnvVar::set(WHISPER_ENV, "/should/not/be/read");
        let opts = VoiceOpts {
            backend: None,
            model: Some(PathBuf::from("/explicit/path")),
        };
        let resolved = resolve_whisper_model_dir(&opts).unwrap();
        assert_eq!(resolved, PathBuf::from("/explicit/path"));
    }

    #[test]
    fn env_var_used_when_opts_absent() {
        let _env = ScopedEnvVar::set(WHISPER_ENV, "/from/env");
        let resolved = resolve_whisper_model_dir(&VoiceOpts::default()).unwrap();
        assert_eq!(resolved, PathBuf::from("/from/env"));
    }

    #[test]
    fn empty_env_var_falls_through_to_default() {
        let _env = ScopedEnvVar::set(WHISPER_ENV, "");
        let resolved = resolve_whisper_model_dir(&VoiceOpts::default()).unwrap();
        let expected = default_whisper_model_dir().unwrap();
        assert_eq!(resolved, expected);
    }

    #[test]
    fn default_path_uses_omni_dev_voice_models_subdir() {
        let dir = default_whisper_model_dir().unwrap();
        assert!(dir.ends_with(".omni-dev/voice/models/whisper-tiny.en"));
    }

    #[test]
    fn ensure_model_present_succeeds_when_all_files_exist() {
        let tmp = tempfile::TempDir::new().unwrap();
        for f in REQUIRED_FILES {
            std::fs::write(tmp.path().join(f), b"placeholder").unwrap();
        }
        ensure_model_present(tmp.path()).unwrap();
    }

    #[test]
    fn ensure_model_present_errors_with_hint_when_files_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        let err = ensure_model_present(tmp.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("no Whisper model found"), "got: {msg}");
        assert!(msg.contains("voice install-model"), "got: {msg}");
        assert!(msg.contains("--model"), "got: {msg}");
    }

    #[test]
    fn ensure_model_present_errors_when_any_file_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        // Deliberately omit tokenizer.json.
        std::fs::write(tmp.path().join("config.json"), b"x").unwrap();
        std::fs::write(tmp.path().join("model.safetensors"), b"x").unwrap();
        let err = ensure_model_present(tmp.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("tokenizer.json"), "got: {msg}");
    }

    #[test]
    fn required_files_in_returns_three_paths() {
        let paths = required_files_in(Path::new("/x"));
        assert_eq!(paths.len(), 3);
        assert_eq!(paths[0], PathBuf::from("/x/config.json"));
        assert_eq!(paths[1], PathBuf::from("/x/tokenizer.json"));
        assert_eq!(paths[2], PathBuf::from("/x/model.safetensors"));
    }

    #[test]
    fn speaker_spec_default_dir_ends_with_wespeaker_subdir() {
        let dir = SPEAKER_WESPEAKER_EN.default_dir().unwrap();
        assert!(dir.ends_with(".omni-dev/voice/models/wespeaker-en-voxceleb-resnet34-LM"));
    }

    #[test]
    fn speaker_spec_resolve_dir_override_takes_priority() {
        let _env = ScopedEnvVar::set(SPEAKER_ENV, "/should/not/be/read");
        let resolved = SPEAKER_WESPEAKER_EN
            .resolve_dir(Some(Path::new("/explicit/path")))
            .unwrap();
        assert_eq!(resolved, PathBuf::from("/explicit/path"));
    }

    #[test]
    fn speaker_spec_resolve_dir_env_var_used_when_override_absent() {
        let _env = ScopedEnvVar::set(SPEAKER_ENV, "/from/env");
        let resolved = SPEAKER_WESPEAKER_EN.resolve_dir(None).unwrap();
        assert_eq!(resolved, PathBuf::from("/from/env"));
    }

    #[test]
    fn speaker_spec_ensure_present_errors_with_install_hint() {
        let tmp = tempfile::TempDir::new().unwrap();
        let err = SPEAKER_WESPEAKER_EN.ensure_present(tmp.path()).unwrap_err();
        let msg = format!("{err:#}");
        assert!(msg.contains("no Speaker model found"), "got: {msg}");
        assert!(msg.contains("--variant speaker-wespeaker-en"), "got: {msg}");
        assert!(msg.contains("--speaker-model"), "got: {msg}");
        assert!(
            msg.contains("wespeaker_en_voxceleb_resnet34_LM.onnx"),
            "got: {msg}"
        );
    }

    #[test]
    fn speaker_spec_ensure_present_succeeds_when_file_exists() {
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::write(
            tmp.path().join("wespeaker_en_voxceleb_resnet34_LM.onnx"),
            b"placeholder",
        )
        .unwrap();
        SPEAKER_WESPEAKER_EN.ensure_present(tmp.path()).unwrap();
    }

    #[test]
    fn whisper_spec_required_files_matches_legacy_helper() {
        let dir = Path::new("/x");
        assert_eq!(
            WHISPER_TINY_EN.required_files_in(dir),
            required_files_in(dir)
        );
    }

    #[test]
    fn whisper_spec_source_carries_pinned_hf_metadata() {
        match WHISPER_TINY_EN.source {
            ModelSource::HfHub { repo_id, revision } => {
                assert_eq!(repo_id, MODEL_ID);
                assert_eq!(revision, REVISION);
            }
            ModelSource::HttpReleaseAsset { .. } => {
                panic!("WHISPER_TINY_EN should be HfHub-sourced");
            }
        }
    }

    #[test]
    fn speaker_spec_source_carries_pinned_release_metadata() {
        match SPEAKER_WESPEAKER_EN.source {
            ModelSource::HttpReleaseAsset { url, sha256, bytes } => {
                assert!(url.contains("wespeaker_en_voxceleb_resnet34_LM.onnx"));
                assert_eq!(
                    sha256,
                    "e9848563da86f263117134dfd7ad63c92355b37de492b55e325400c9d9c39012"
                );
                assert_eq!(bytes, 26_530_550);
            }
            ModelSource::HfHub { .. } => {
                panic!("SPEAKER_WESPEAKER_EN should be HttpReleaseAsset-sourced");
            }
        }
    }
}
424}