model_runtime/
download.rs1use std::collections::BTreeMap;
2use std::path::{Path, PathBuf};
3
4use crate::{ModelRuntimeError, Result};
5use hf_hub::api::sync::ApiBuilder;
6use hf_hub::{Repo, RepoType};
7
8use crate::{HuggingFaceModelSpec, ModelFileRequest};
9
/// Result of a successful download: the originating spec plus the local location of every
/// file that was fetched.
#[derive(Debug, Clone)]
pub struct DownloadedModel {
    /// The specification the download was performed for.
    pub spec: HuggingFaceModelSpec,
    /// Local path of each fetched file, keyed by the repository path that was requested.
    pub files: BTreeMap<String, PathBuf>,
}
18
19impl DownloadedModel {
20 pub fn model_dir(&self) -> Option<&Path> {
22 self.files.values().next().and_then(|path| path.parent())
23 }
24}
25
/// Configurable downloader that fetches model files through the `hf_hub` synchronous API.
#[derive(Debug, Clone)]
pub struct HuggingFaceDownloader {
    /// Cache directory handed to the API builder when set.
    cache_dir: Option<PathBuf>,
    /// Access token handed to the API builder, if any.
    token: Option<String>,
    /// Whether the API builder is asked to show download progress.
    progress: bool,
    /// Retry count handed to the API builder.
    max_retries: usize,
}
34
35impl Default for HuggingFaceDownloader {
36 fn default() -> Self {
37 Self {
38 cache_dir: None,
39 token: None,
40 progress: true,
41 max_retries: 0,
42 }
43 }
44}
45
46impl HuggingFaceDownloader {
47 pub fn new() -> Self {
49 Self::default()
50 }
51
52 pub fn cache_dir(mut self, path: impl Into<PathBuf>) -> Self {
54 self.cache_dir = Some(path.into());
55 self
56 }
57
58 pub fn token(mut self, value: impl Into<String>) -> Self {
60 self.token = Some(value.into());
61 self
62 }
63
64 pub fn progress(mut self, value: bool) -> Self {
66 self.progress = value;
67 self
68 }
69
70 pub fn max_retries(mut self, value: usize) -> Self {
72 self.max_retries = value;
73 self
74 }
75
76 pub fn download(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel> {
78 if spec.files.is_empty() {
79 return Err(ModelRuntimeError::InvalidArgument(
80 "at least one model file must be requested".to_string(),
81 ));
82 }
83
84 let mut builder = ApiBuilder::from_env()
85 .with_progress(self.progress)
86 .with_retries(self.max_retries)
87 .with_user_agent("video-analysis", env!("CARGO_PKG_VERSION"));
88 if let Some(cache_dir) = &self.cache_dir {
89 builder = builder.with_cache_dir(cache_dir.clone());
90 }
91 builder = builder.with_token(self.token.clone());
92
93 let api = builder
94 .build()
95 .map_err(|err| ModelRuntimeError::Source(format!("huggingface api error: {err}")))?;
96 let repo = api.repo(Repo::with_revision(
97 spec.repo_id.clone(),
98 RepoType::Model,
99 spec.revision.clone(),
100 ));
101
102 let mut files = BTreeMap::new();
103 for request in &spec.files {
104 match request {
105 ModelFileRequest::Required(path) => {
106 let local = repo.get(path).map_err(|err| {
107 ModelRuntimeError::Source(format!(
108 "failed to download `{path}` from `{}`: {err}",
109 spec.repo_id
110 ))
111 })?;
112 files.insert(path.clone(), local);
113 }
114 ModelFileRequest::Optional(path) => {
115 if let Ok(local) = repo.get(path) {
116 files.insert(path.clone(), local);
117 }
118 }
119 ModelFileRequest::FirstAvailable(paths) => {
120 let mut last_error = None;
121 let mut found = None;
122 for path in paths {
123 match repo.get(path) {
124 Ok(local) => {
125 found = Some((path.clone(), local));
126 break;
127 }
128 Err(err) => last_error = Some(err.to_string()),
129 }
130 }
131 if let Some((path, local)) = found {
132 files.insert(path, local);
133 } else {
134 return Err(ModelRuntimeError::Source(format!(
135 "none of the alternative files [{}] could be downloaded from `{}`{}",
136 paths.join(", "),
137 spec.repo_id,
138 last_error
139 .map(|err| format!("; last error: {err}"))
140 .unwrap_or_default()
141 )));
142 }
143 }
144 }
145 }
146
147 Ok(DownloadedModel {
148 spec: spec.clone(),
149 files,
150 })
151 }
152}
153
/// Abstraction over anything that can turn a `HuggingFaceModelSpec` into a `DownloadedModel`.
pub trait ModelDownloader {
    /// Downloads the model described by `spec` and returns the resulting `DownloadedModel`.
    fn download_model(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel>;
}
159
160impl ModelDownloader for HuggingFaceDownloader {
161 fn download_model(&self, spec: &HuggingFaceModelSpec) -> Result<DownloadedModel> {
162 self.download(spec)
163 }
164}