Skip to main content

model_runtime/
download.rs

1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use crate::{ModelRuntimeError, Result};
5use hf_hub::api::sync::ApiBuilder;
6use hf_hub::{Repo, RepoType};
7
8use crate::{HuggingFaceModelSpec, ModelFileRequest};
9
10#[derive(Debug, Clone)]
11/// Data type for downloaded model.
12pub struct DownloadedModel {
13    /// The spec value.
14    pub spec: HuggingFaceModelSpec,
15    /// The files value.
16    pub files: BTreeMap<String, PathBuf>,
17}
18
19impl DownloadedModel {
20    /// Returns model dir.
21    pub fn model_dir(&self) -> Option<&Path> {
22        self.files.values().next().and_then(|path| path.parent())
23    }
24}
25
/// Builder-style configuration for downloading model files from the
/// Hugging Face Hub.
///
/// `Debug` is implemented by hand so that the access token is never written
/// to logs or panic messages.
#[derive(Clone)]
pub struct HuggingFaceDownloader {
    /// Cache directory override; `None` means no override was requested.
    cache_dir: Option<PathBuf>,
    /// Explicitly configured access token, if any. Treated as a secret.
    token: Option<String>,
    /// Whether download progress reporting is enabled.
    progress: bool,
    /// Number of retries requested for failed transfers.
    max_retries: usize,
}

impl std::fmt::Debug for HuggingFaceDownloader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Show only whether a token is present, never its value.
        let token = self.token.as_ref().map(|_| "<redacted>");
        f.debug_struct("HuggingFaceDownloader")
            .field("cache_dir", &self.cache_dir)
            .field("token", &token)
            .field("progress", &self.progress)
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
34
35impl Default for HuggingFaceDownloader {
36    fn default() -> Self {
37        Self {
38            cache_dir: None,
39            token: None,
40            progress: true,
41            max_retries: 0,
42        }
43    }
44}
45
46impl HuggingFaceDownloader {
47    /// Creates a new value.
48    pub fn new() -> Self {
49        Self::default()
50    }
51
52    /// Returns cache dir.
53    pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
54        self.cache_dir = Some(path.into());
55        self
56    }
57
58    /// Returns token.
59    pub fn token(mut self, value: impl Into<String>) -> Self {
60        self.token = Some(value.into());
61        self
62    }
63
64    /// Returns progress.
65    pub fn progress(mut self, value: bool) -> Self {
66        self.progress = value;
67        self
68    }
69
70    /// Returns max retries.
71    pub fn max_retries(mut self, value: usize) -> Self {
72        self.max_retries = value;
73        self
74    }
75
76    /// Returns download.
77    pub fn download(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel> {
78        if spec.files.is_empty() {
79            return Err(ModelRuntimeError::InvalidArgument(
80                "at least one model file must be requested".to_string(),
81            ));
82        }
83
84        let mut builder = ApiBuilder::from_env()
85            .with_progress(self.progress)
86            .with_retries(self.max_retries)
87            .with_user_agent("video-analysis", env!("CARGO_PKG_VERSION"));
88        if let Some(cache_dir) = &self.cache_dir {
89            builder = builder.with_cache_dir(cache_dir.clone());
90        }
91        builder = builder.with_token(self.token.clone());
92
93        let api = builder
94            .build()
95            .map_err(|err| ModelRuntimeError::Source(format!("huggingface api error: {err}")))?;
96        let repo = api.repo(Repo::with_revision(
97            spec.repo_id.clone(),
98            RepoType::Model,
99            spec.revision.clone(),
100        ));
101
102        let mut files = BTreeMap::new();
103        for request in &spec.files {
104            match request {
105                ModelFileRequest::Required(path) => {
106                    let local = repo.get(path).map_err(|err| {
107                        ModelRuntimeError::Source(format!(
108                            "failed to download `{path}` from `{}`: {err}",
109                            spec.repo_id
110                        ))
111                    })?;
112                    files.insert(path.clone(), local);
113                }
114                ModelFileRequest::Optional(path) => {
115                    if let Ok(local) = repo.get(path) {
116                        files.insert(path.clone(), local);
117                    }
118                }
119                ModelFileRequest::FirstAvailable(paths) => {
120                    let mut last_error = None;
121                    let mut found = None;
122                    for path in paths {
123                        match repo.get(path) {
124                            Ok(local) => {
125                                found = Some((path.clone(), local));
126                                break;
127                            }
128                            Err(err) => last_error = Some(err.to_string()),
129                        }
130                    }
131                    if let Some((path, local)) = found {
132                        files.insert(path, local);
133                    } else {
134                        return Err(ModelRuntimeError::Source(format!(
135                            "none of the alternative files [{}] could be downloaded from `{}`{}",
136                            paths.join(", "),
137                            spec.repo_id,
138                            last_error
139                                .map(|err| format!("; last error: {err}"))
140                                .unwrap_or_default()
141                        )));
142                    }
143                }
144            }
145        }
146
147        Ok(DownloadedModel {
148            spec: spec.clone(),
149            files,
150        })
151    }
152}
153
154/// Minimal downloader seam for bundle resolution tests and alternate materializers.
155pub trait ModelDownloader {
156    /// Downloads or otherwise stages the requested model files.
157    fn download_model(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel>;
158}
159
160impl ModelDownloader for HuggingFaceDownloader {
161    fn download_model(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel> {
162        self.download(spec)
163    }
164}