aic_model_downloader/
lib.rs1use sha2::{Digest, Sha256};
2use std::{
3 fs::{self, File},
4 io::Read,
5 path::{Path, PathBuf},
6};
7use thiserror::Error;
8
9mod manifest;
10use manifest::Manifest;
11
/// Base URL that a model's manifest `url_path` is appended to (verbatim, with no
/// separator added) to form the model's download URL.
const MODEL_BASE_URL: &str = "https://artifacts.ai-coustics.io/";
13
14#[derive(Debug, Error)]
15pub enum Error {
16 #[error("I/O error: {0}")]
17 Io(String),
18 #[error("Failed to download manifest: {0}")]
19 ManifestDownload(String),
20 #[error("Failed to parse manifest: {0}")]
21 ManifestParse(String),
22 #[error("Model `{0}` not found in manifest")]
23 ModelNotFound(String),
24 #[error("Model `{model}` missing compatible version v{compatible_version}")]
25 IncompatibleModel {
26 model: String,
27 compatible_version: u32,
28 },
29 #[error("Failed to download model file: {0}")]
30 ModelDownload(String),
31 #[error("Checksum mismatch for downloaded model")]
32 ChecksumMismatch,
33}
34
35pub fn download<P: AsRef<Path>>(
41 model_id: &str,
42 model_version: u32,
43 download_dir: P,
44) -> Result<PathBuf, Error> {
45 let manifest = Manifest::download()?;
46 let model = manifest.metadata_for_model(model_id, model_version)?;
47
48 let download_dir = download_dir.as_ref();
49 fs::create_dir_all(download_dir).map_err(|err| Error::Io(err.to_string()))?;
50
51 let destination = download_dir.join(&model.file_name);
52 if destination.exists() && checksum_matches(&destination, &model.checksum)? {
53 return Ok(destination);
54 }
55
56 let url = format!("{MODEL_BASE_URL}{}", model.url_path);
57 let bytes = download_bytes(&url)?;
58
59 let temp_path = destination.with_extension("download");
60 fs::write(&temp_path, &bytes).map_err(|err| Error::Io(err.to_string()))?;
61
62 if !checksum_matches(&temp_path, &model.checksum)? {
63 let _ = fs::remove_file(&temp_path);
64 return Err(Error::ChecksumMismatch);
65 }
66
67 fs::rename(&temp_path, &destination).map_err(|err| Error::Io(err.to_string()))?;
68
69 Ok(destination)
70}
71
72fn download_bytes(url: &str) -> Result<Vec<u8>, Error> {
73 let response = ureq::get(url)
74 .call()
75 .map_err(|err| Error::ModelDownload(err.to_string()))?;
76
77 response
78 .into_body()
79 .into_with_config()
80 .read_to_vec()
81 .map_err(|err| Error::ModelDownload(err.to_string()))
82}
83
84fn checksum_matches(path: &Path, expected: &str) -> Result<bool, Error> {
85 let mut file = File::open(path).map_err(|err| Error::Io(err.to_string()))?;
86 let mut hasher = Sha256::new();
87 let mut buffer = [0u8; 8192];
88
89 loop {
90 let read = file
91 .read(&mut buffer)
92 .map_err(|err| Error::Io(err.to_string()))?;
93 if read == 0 {
94 break;
95 }
96 hasher.update(&buffer[..read]);
97 }
98
99 let checksum = hasher
100 .finalize()
101 .iter()
102 .map(|byte| format!("{byte:02x}"))
103 .collect::<String>();
104 Ok(checksum.eq_ignore_ascii_case(expected))
105}