1use std::io::Write;
13use std::path::{Path, PathBuf};
14use std::time::Instant;
15
16use anyhow::{anyhow, bail, Context, Result};
17use clap::{Parser, ValueEnum};
18use hf_hub::{api::sync::Api, Repo, RepoType};
19use sha2::{Digest, Sha256};
20
21use crate::voice::models::{ModelSource, ModelSpec, SPEAKER_WESPEAKER_EN, WHISPER_TINY_EN};
22
23#[derive(Copy, Clone, Debug, Default, PartialEq, Eq, ValueEnum)]
29pub enum Variant {
30 #[default]
32 #[value(name = "whisper-tiny.en")]
33 WhisperTinyEn,
34 #[value(name = "speaker-wespeaker-en")]
36 SpeakerWespeakerEn,
37}
38
39impl Variant {
40 pub fn spec(self) -> &'static ModelSpec {
42 match self {
43 Self::WhisperTinyEn => &WHISPER_TINY_EN,
44 Self::SpeakerWespeakerEn => &SPEAKER_WESPEAKER_EN,
45 }
46 }
47}
48
49#[derive(Parser)]
57pub struct InstallModelCommand {
58 #[arg(long)]
61 pub dest: Option<PathBuf>,
62
63 #[arg(long)]
65 pub force: bool,
66
67 #[arg(long, value_enum, default_value_t = Variant::WhisperTinyEn)]
69 pub variant: Variant,
70}
71
72impl InstallModelCommand {
73 pub fn execute(self) -> Result<()> {
77 let mut err = std::io::stderr().lock();
78 self.run(&mut err)
79 }
80
81 fn run<W: Write>(self, w: &mut W) -> Result<()> {
84 let spec = self.variant.spec();
85 let dest = match self.dest {
86 Some(p) => p,
87 None => spec
88 .default_dir()
89 .ok_or_else(|| anyhow!("could not determine home directory; pass --dest <path>"))?,
90 };
91
92 if !self.force && all_present(spec, &dest) {
93 writeln!(w, "model already installed at {}", dest.display())?;
94 return Ok(());
95 }
96
97 match spec.source {
98 ModelSource::HfHub { repo_id, revision } => {
99 download_hf_hub(spec, repo_id, revision, &dest, w)
100 }
101 ModelSource::HttpReleaseAsset { url, sha256, bytes } => {
102 download_release_asset(spec, url, sha256, bytes, &dest, w)
103 }
104 }
105 }
106}
107
108fn all_present(spec: &ModelSpec, dir: &Path) -> bool {
109 spec.required_files_in(dir)
110 .iter()
111 .all(|p| p.is_file() && p.metadata().is_ok_and(|m| m.len() > 0))
112}
113
114fn download_hf_hub<W: Write>(
115 spec: &ModelSpec,
116 repo_id: &str,
117 revision: &str,
118 dest: &Path,
119 w: &mut W,
120) -> Result<()> {
121 writeln!(
122 w,
123 "Installing {repo_id} (revision {revision}) -> {}",
124 dest.display()
125 )?;
126 std::fs::create_dir_all(dest)
127 .with_context(|| format!("create install directory at {}", dest.display()))?;
128
129 let api = Api::new().context("initialise HuggingFace Hub client")?;
130 let repo = api.repo(Repo::with_revision(
131 repo_id.to_string(),
132 RepoType::Model,
133 revision.to_string(),
134 ));
135
136 for file in spec.required_files {
137 let start = Instant::now();
138 write!(w, " fetching {file}... ")?;
139 w.flush()?;
140 let downloaded = repo.get(file).with_context(|| {
141 format!(
142 "download {file} from {repo_id} (revision {revision}). \
143 Check your network or set HTTPS_PROXY"
144 )
145 })?;
146 let target = dest.join(file);
147 atomic_install_copy(&downloaded, &target).with_context(|| {
148 format!(
149 "install {file} into {} (atomic rename failed)",
150 target.display()
151 )
152 })?;
153 let bytes = std::fs::metadata(&target).map_or(0, |m| m.len());
154 writeln!(
155 w,
156 "done ({bytes} bytes in {:.1}s)",
157 start.elapsed().as_secs_f64()
158 )?;
159 }
160
161 writeln!(
162 w,
163 "{} model installed at {}",
164 spec.kind_label,
165 dest.display()
166 )?;
167 Ok(())
168}
169
170fn download_release_asset<W: Write>(
171 spec: &ModelSpec,
172 url: &str,
173 expected_sha256: &str,
174 expected_bytes: u64,
175 dest: &Path,
176 w: &mut W,
177) -> Result<()> {
178 if spec.required_files.len() != 1 {
183 bail!(
184 "HttpReleaseAsset source expects exactly one required_file, \
185 got {} for variant {}",
186 spec.required_files.len(),
187 spec.variant
188 );
189 }
190 let file_name = spec.required_files[0];
191 let target = dest.join(file_name);
192
193 writeln!(
194 w,
195 "Installing {file_name} ({expected_bytes} B) -> {}",
196 dest.display()
197 )?;
198 std::fs::create_dir_all(dest)
199 .with_context(|| format!("create install directory at {}", dest.display()))?;
200
201 let start = Instant::now();
202 write!(w, " fetching {url}... ")?;
203 w.flush()?;
204
205 let resp = ureq::get(url)
206 .call()
207 .with_context(|| format!("HTTP GET {url}"))?;
208 let status = resp.status();
209 if !status.is_success() {
210 bail!(
211 "HTTP {} fetching {url}: {}",
212 status.as_u16(),
213 status.canonical_reason().unwrap_or("Unknown"),
214 );
215 }
216 let bytes = resp
217 .into_body()
218 .read_to_vec()
219 .with_context(|| format!("read response body for {url}"))?;
220
221 let actual_sha = {
222 let mut hasher = Sha256::new();
223 hasher.update(&bytes);
224 let digest = hasher.finalize();
225 let mut hex = String::with_capacity(digest.len() * 2);
226 for byte in digest {
227 use std::fmt::Write as _;
228 let _ = write!(&mut hex, "{byte:02x}");
230 }
231 hex
232 };
233 if !actual_sha.eq_ignore_ascii_case(expected_sha256) {
234 bail!("SHA-256 mismatch for {file_name}: expected {expected_sha256}, got {actual_sha}");
235 }
236
237 atomic_install_bytes(&bytes, &target).with_context(|| {
238 format!(
239 "install {file_name} into {} (atomic rename failed)",
240 target.display()
241 )
242 })?;
243 writeln!(
244 w,
245 "done ({} bytes in {:.1}s; sha256 verified)",
246 bytes.len(),
247 start.elapsed().as_secs_f64()
248 )?;
249 writeln!(
250 w,
251 "{} model installed at {}",
252 spec.kind_label,
253 dest.display()
254 )?;
255 Ok(())
256}
257
258fn atomic_install_bytes(bytes: &[u8], to: &Path) -> Result<()> {
261 if let Some(parent) = to.parent() {
262 std::fs::create_dir_all(parent)
263 .with_context(|| format!("create parent dir {}", parent.display()))?;
264 }
265 let tmp = part_sibling(to)?;
266 std::fs::write(&tmp, bytes)
267 .with_context(|| format!("write {} bytes -> {}", bytes.len(), tmp.display()))?;
268 std::fs::rename(&tmp, to)
269 .with_context(|| format!("rename {} -> {}", tmp.display(), to.display()))?;
270 Ok(())
271}
272
273fn atomic_install_copy(from: &Path, to: &Path) -> Result<()> {
276 if let Some(parent) = to.parent() {
277 std::fs::create_dir_all(parent)
278 .with_context(|| format!("create parent dir {}", parent.display()))?;
279 }
280 let tmp = part_sibling(to)?;
281 std::fs::copy(from, &tmp)
282 .with_context(|| format!("copy {} -> {}", from.display(), tmp.display()))?;
283 std::fs::rename(&tmp, to)
284 .with_context(|| format!("rename {} -> {}", tmp.display(), to.display()))?;
285 Ok(())
286}
287
288fn part_sibling(to: &Path) -> Result<PathBuf> {
289 let file_name = to
290 .file_name()
291 .ok_or_else(|| anyhow!("destination path has no file name: {}", to.display()))?;
292 let mut tmp_name = std::ffi::OsString::from(".");
293 tmp_name.push(file_name);
294 tmp_name.push(".part");
295 Ok(to.with_file_name(tmp_name))
296}
297
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::voice::models::REQUIRED_FILES;
    use std::ffi::OsString;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises the tests that mutate the process-wide `HOME` variable.
    static ENV_GUARD: Mutex<()> = Mutex::new(());

    /// Takes the environment lock, recovering it if an earlier test panicked
    /// while holding it.
    fn env_guard() -> MutexGuard<'static, ()> {
        ENV_GUARD.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Points `HOME` at a directory for the lifetime of the value.
    ///
    /// The previous value is restored in `Drop`, so it is put back even when
    /// the test body panics; otherwise later tests (which recover the
    /// poisoned lock) would run with `HOME` aimed at a deleted temp dir.
    /// `Drop::drop` runs before the fields are dropped, so the restore
    /// happens while the environment lock is still held.
    struct HomeOverride {
        previous: Option<OsString>,
        _lock: MutexGuard<'static, ()>,
    }

    impl HomeOverride {
        fn set(home: &Path) -> Self {
            let lock = env_guard();
            let previous = std::env::var_os("HOME");
            std::env::set_var("HOME", home);
            Self {
                previous,
                _lock: lock,
            }
        }
    }

    impl Drop for HomeOverride {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(value) => std::env::set_var("HOME", value),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    /// Minimal top-level parser that flattens the command under test.
    #[derive(Parser)]
    struct Harness {
        #[command(flatten)]
        c: InstallModelCommand,
    }

    /// Parses `args` (without the program name) into an `InstallModelCommand`.
    fn parse(args: &[&str]) -> Result<InstallModelCommand, clap::Error> {
        let argv = std::iter::once("test").chain(args.iter().copied());
        Harness::try_parse_from(argv).map(|harness| harness.c)
    }

    /// Runs `cmd` against an in-memory sink and returns what it printed.
    fn run_to_string(cmd: InstallModelCommand) -> String {
        let mut out: Vec<u8> = Vec::new();
        cmd.run(&mut out).unwrap();
        String::from_utf8(out).unwrap()
    }

    /// Reports whether any entry in `dir` has a name ending in `.part`.
    fn has_part_leftover(dir: &Path) -> bool {
        std::fs::read_dir(dir)
            .unwrap()
            .filter_map(Result::ok)
            .any(|e| e.file_name().to_string_lossy().ends_with(".part"))
    }

    fn stage_complete_whisper_model(dir: &Path) {
        std::fs::create_dir_all(dir).unwrap();
        for f in REQUIRED_FILES {
            std::fs::write(dir.join(f), b"placeholder").unwrap();
        }
    }

    fn stage_complete_speaker_model(dir: &Path) {
        std::fs::create_dir_all(dir).unwrap();
        for f in SPEAKER_WESPEAKER_EN.required_files {
            std::fs::write(dir.join(f), b"placeholder").unwrap();
        }
    }

    #[test]
    fn idempotent_when_all_files_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        stage_complete_whisper_model(tmp.path());

        let msg = run_to_string(InstallModelCommand {
            dest: Some(tmp.path().to_path_buf()),
            force: false,
            variant: Variant::WhisperTinyEn,
        });
        assert!(msg.contains("already installed"), "got: {msg}");
    }

    #[test]
    fn idempotent_when_speaker_model_present() {
        let tmp = tempfile::TempDir::new().unwrap();
        stage_complete_speaker_model(tmp.path());

        let msg = run_to_string(InstallModelCommand {
            dest: Some(tmp.path().to_path_buf()),
            force: false,
            variant: Variant::SpeakerWespeakerEn,
        });
        assert!(msg.contains("already installed"), "got: {msg}");
    }

    #[test]
    fn idempotent_skip_treats_zero_byte_file_as_missing() {
        let tmp = tempfile::TempDir::new().unwrap();
        std::fs::create_dir_all(tmp.path()).unwrap();
        for f in REQUIRED_FILES {
            // Empty files must not satisfy the "already installed" check.
            std::fs::write(tmp.path().join(f), b"").unwrap();
        }
        assert!(!all_present(&WHISPER_TINY_EN, tmp.path()));
    }

    #[test]
    fn atomic_install_copy_replaces_target() {
        let tmp = tempfile::TempDir::new().unwrap();
        let src = tmp.path().join("src");
        let dst = tmp.path().join("dst");
        std::fs::write(&src, b"hello").unwrap();
        std::fs::write(&dst, b"old").unwrap();
        atomic_install_copy(&src, &dst).unwrap();
        let got = std::fs::read(&dst).unwrap();
        assert_eq!(got, b"hello");
        assert!(
            !has_part_leftover(tmp.path()),
            "atomic_install_copy must not leave .part files"
        );
    }

    #[test]
    fn atomic_install_bytes_writes_and_renames() {
        let tmp = tempfile::TempDir::new().unwrap();
        let dst = tmp.path().join("out");
        atomic_install_bytes(b"hello", &dst).unwrap();
        assert_eq!(std::fs::read(&dst).unwrap(), b"hello");
        assert!(
            !has_part_leftover(tmp.path()),
            "atomic_install_bytes must not leave .part files"
        );
    }

    #[test]
    fn parses_no_args() {
        let cmd = parse(&[]).unwrap();
        assert!(cmd.dest.is_none());
        assert!(!cmd.force);
        assert_eq!(cmd.variant, Variant::WhisperTinyEn);
    }

    #[test]
    fn parses_dest_and_force() {
        let cmd = parse(&["--dest", "/opt/x", "--force"]).unwrap();
        assert_eq!(cmd.dest.as_deref(), Some(Path::new("/opt/x")));
        assert!(cmd.force);
    }

    #[test]
    fn parses_speaker_variant() {
        let cmd = parse(&["--variant", "speaker-wespeaker-en"]).unwrap();
        assert_eq!(cmd.variant, Variant::SpeakerWespeakerEn);
    }

    #[test]
    fn parses_whisper_variant_explicit() {
        let cmd = parse(&["--variant", "whisper-tiny.en"]).unwrap();
        assert_eq!(cmd.variant, Variant::WhisperTinyEn);
    }

    #[test]
    fn rejects_unknown_variant() {
        let parsed = parse(&["--variant", "klingon"]);
        assert!(parsed.is_err(), "unknown variant should fail to parse");
    }

    #[test]
    fn run_with_dest_none_resolves_default_install_dir_from_home() {
        let tmp = tempfile::TempDir::new().unwrap();
        // Declared after `tmp`, so HOME is restored before the temp dir is
        // deleted, on success and on panic alike.
        let _home = HomeOverride::set(tmp.path());

        let default_dir = WHISPER_TINY_EN.default_dir().unwrap();
        stage_complete_whisper_model(&default_dir);

        let msg = run_to_string(InstallModelCommand {
            dest: None,
            force: false,
            variant: Variant::WhisperTinyEn,
        });
        assert!(msg.contains("already installed"), "got: {msg}");
        assert!(
            msg.contains("whisper-tiny.en"),
            "expected resolved default dir in message, got: {msg}"
        );
    }

    #[test]
    fn run_speaker_variant_with_dest_none_resolves_default() {
        let tmp = tempfile::TempDir::new().unwrap();
        // Declared after `tmp`, so HOME is restored before the temp dir is
        // deleted, on success and on panic alike.
        let _home = HomeOverride::set(tmp.path());

        let default_dir = SPEAKER_WESPEAKER_EN.default_dir().unwrap();
        stage_complete_speaker_model(&default_dir);

        let msg = run_to_string(InstallModelCommand {
            dest: None,
            force: false,
            variant: Variant::SpeakerWespeakerEn,
        });
        assert!(msg.contains("already installed"), "got: {msg}");
        assert!(
            msg.contains("wespeaker-en-voxceleb-resnet34-LM"),
            "expected resolved default dir in message, got: {msg}"
        );
    }

    #[test]
    fn variant_spec_returns_correct_spec() {
        assert_eq!(
            Variant::WhisperTinyEn.spec().variant,
            WHISPER_TINY_EN.variant
        );
        assert_eq!(
            Variant::SpeakerWespeakerEn.spec().variant,
            SPEAKER_WESPEAKER_EN.variant
        );
    }
}