// aic_model_downloader/lib.rs

use sha2::{Digest, Sha256};
use std::{
    fs::{self, File},
    io::{ErrorKind, Read},
    path::{Path, PathBuf},
};
use thiserror::Error;

mod manifest;
use manifest::Manifest;

/// Base URL for model artifacts. A model's `url_path` from the manifest is
/// appended to it directly (no separator is inserted), which is why it ends
/// with `/`.
const MODEL_BASE_URL: &str = "https://artifacts.ai-coustics.io/";

/// Errors that can occur while fetching the model manifest or downloading a
/// model file.
///
/// Underlying errors are stored as their message strings; the original error
/// value is not retained as a `source`.
#[derive(Debug, Error)]
pub enum Error {
    /// A filesystem operation failed; holds the underlying error's message.
    #[error("I/O error: {0}")]
    Io(String),
    /// The model manifest could not be downloaded.
    #[error("Failed to download manifest: {0}")]
    ManifestDownload(String),
    /// The model manifest could not be parsed.
    #[error("Failed to parse manifest: {0}")]
    ManifestParse(String),
    /// The requested model is not listed in the manifest.
    #[error("Model `{0}` not found in manifest")]
    ModelNotFound(String),
    /// The model is listed, but not in a version compatible with the one
    /// requested.
    #[error("Model `{model}` missing compatible version v{compatible_version}")]
    IncompatibleModel {
        /// The model that was looked up.
        model: String,
        /// The version for which no compatible entry was found.
        compatible_version: u32,
    },
    /// Downloading the model file failed; holds a description of the failure.
    #[error("Failed to download model file: {0}")]
    ModelDownload(String),
    /// The downloaded model file's SHA-256 digest does not match the checksum
    /// listed in the manifest.
    #[error("Checksum mismatch for downloaded model")]
    ChecksumMismatch,
}

35/// Downloads a model file compatible with the provided model version.
36///
37/// The function fetches the model manifest, checks whether the requested model
38/// exists in a version compatible with the given `model_version`, and downloads
39/// the model file into the provided directory.
40pub fn download<P: AsRef<Path>>(
41    model_id: &str,
42    model_version: u32,
43    download_dir: P,
44) -> Result<PathBuf, Error> {
45    let manifest = Manifest::download()?;
46    let model = manifest.metadata_for_model(model_id, model_version)?;
47
48    let download_dir = download_dir.as_ref();
49    fs::create_dir_all(download_dir).map_err(|err| Error::Io(err.to_string()))?;
50
51    let destination = download_dir.join(&model.file_name);
52    if destination.exists() && checksum_matches(&destination, &model.checksum)? {
53        return Ok(destination);
54    }
55
56    let url = format!("{MODEL_BASE_URL}{}", model.url_path);
57    let bytes = download_bytes(&url)?;
58
59    let temp_path = destination.with_extension("download");
60    fs::write(&temp_path, &bytes).map_err(|err| Error::Io(err.to_string()))?;
61
62    if !checksum_matches(&temp_path, &model.checksum)? {
63        let _ = fs::remove_file(&temp_path);
64        return Err(Error::ChecksumMismatch);
65    }
66
67    fs::rename(&temp_path, &destination).map_err(|err| Error::Io(err.to_string()))?;
68
69    Ok(destination)
70}
71
72fn download_bytes(url: &str) -> Result<Vec<u8>, Error> {
73    let response = ureq::get(url)
74        .call()
75        .map_err(|err| Error::ModelDownload(err.to_string()))?;
76
77    response
78        .into_body()
79        .into_with_config()
80        .read_to_vec()
81        .map_err(|err| Error::ModelDownload(err.to_string()))
82}
83
84fn checksum_matches(path: &Path, expected: &str) -> Result<bool, Error> {
85    let mut file = File::open(path).map_err(|err| Error::Io(err.to_string()))?;
86    let mut hasher = Sha256::new();
87    let mut buffer = [0u8; 8192];
88
89    loop {
90        let read = file
91            .read(&mut buffer)
92            .map_err(|err| Error::Io(err.to_string()))?;
93        if read == 0 {
94            break;
95        }
96        hasher.update(&buffer[..read]);
97    }
98
99    let checksum = hasher
100        .finalize()
101        .iter()
102        .map(|byte| format!("{byte:02x}"))
103        .collect::<String>();
104    Ok(checksum.eq_ignore_ascii_case(expected))
105}