entrenar/hf_pipeline/fetcher/
hf_fetcher.rs1use crate::hf_pipeline::error::{FetchError, Result};
6use std::path::PathBuf;
7
8use super::options::FetchOptions;
9use super::types::{ModelArtifact, WeightFormat};
10
/// Downloads model artifacts from the Hugging Face Hub into a local cache.
///
/// Construct with [`HfModelFetcher::new`] (ambient credentials) or
/// [`HfModelFetcher::with_token`] (explicit token).
pub struct HfModelFetcher {
    /// Access token passed to the hf_hub client, if any.
    pub(crate) token: Option<String>,
    /// Default root under which per-repository cache directories are created.
    pub(crate) cache_dir: PathBuf,
    /// Hub endpoint recorded by the constructors.
    /// NOTE(review): not currently read when building the hf_hub client.
    pub(crate) api_base: String,
}

// Hand-written rather than derived so the access token never ends up in logs or panic
// messages: `#[derive(Debug)]` would print it verbatim.
impl std::fmt::Debug for HfModelFetcher {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("HfModelFetcher")
            .field("token", &self.token.as_ref().map(|_| "<redacted>"))
            .field("cache_dir", &self.cache_dir)
            .field("api_base", &self.api_base)
            .finish()
    }
}
21
22impl HfModelFetcher {
23 pub fn new() -> Result<Self> {
29 let token = Self::resolve_token();
30 let cache_dir = Self::default_cache_dir();
31
32 Ok(Self { token, cache_dir, api_base: "https://huggingface.co".into() })
33 }
34
35 #[must_use]
37 pub fn with_token(token: impl Into<String>) -> Self {
38 Self {
39 token: Some(token.into()),
40 cache_dir: Self::default_cache_dir(),
41 api_base: "https://huggingface.co".into(),
42 }
43 }
44
45 #[must_use]
47 pub fn cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
48 self.cache_dir = dir.into();
49 self
50 }
51
52 #[must_use]
58 pub fn resolve_token() -> Option<String> {
59 if let Ok(token) = std::env::var("HF_TOKEN") {
61 if !token.is_empty() {
62 return Some(token);
63 }
64 }
65
66 if let Some(home) = dirs::home_dir() {
68 let token_path = home.join(".huggingface").join("token");
69 if let Ok(token) = std::fs::read_to_string(token_path) {
70 let token = token.trim().to_string();
71 if !token.is_empty() {
72 return Some(token);
73 }
74 }
75 }
76
77 None
78 }
79
80 pub(crate) fn default_cache_dir() -> PathBuf {
82 dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache")).join("huggingface").join("hub")
83 }
84
85 #[must_use]
87 pub fn is_authenticated(&self) -> bool {
88 self.token.is_some()
89 }
90
91 pub(crate) fn parse_repo_id(repo_id: &str) -> Result<(&str, &str)> {
93 let parts: Vec<&str> = repo_id.split('/').collect();
94 if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
95 return Err(FetchError::InvalidRepoId { repo_id: repo_id.to_string() });
96 }
97 Ok((parts[0], parts[1]))
98 }
99
100 fn resolve_files(options: &FetchOptions) -> Vec<String> {
102 if options.files.is_empty() {
103 vec!["model.safetensors".to_string(), "config.json".to_string()]
104 } else {
105 options.files.clone()
106 }
107 }
108
109 fn check_security(files: &[String], allow_pickle: bool) -> Result<()> {
111 for file in files {
112 if let Some(format) = WeightFormat::from_filename(file) {
113 if !format.is_safe() && !allow_pickle {
114 return Err(FetchError::PickleSecurityRisk);
115 }
116 }
117 }
118 Ok(())
119 }
120
121 fn build_api(&self, cache_path: &std::path::Path) -> Result<hf_hub::api::sync::Api> {
123 let mut api_builder =
124 hf_hub::api::sync::ApiBuilder::new().with_cache_dir(cache_path.to_path_buf());
125
126 if let Some(token) = &self.token {
127 api_builder = api_builder.with_token(Some(token.clone()));
128 }
129
130 api_builder.build().map_err(|e| FetchError::ConfigParseError {
131 message: format!("Failed to initialize HF API: {e}"),
132 })
133 }
134
135 fn download_file(
137 repo: &hf_hub::api::sync::ApiRepo,
138 api: &hf_hub::api::sync::Api,
139 repo_id: &str,
140 revision: &str,
141 file: &str,
142 cache_path: &std::path::Path,
143 ) -> Result<()> {
144 let download_result = if revision == "main" {
145 repo.get(file)
146 } else {
147 let revision_repo = api.repo(hf_hub::Repo::with_revision(
148 repo_id.to_string(),
149 hf_hub::RepoType::Model,
150 revision.to_string(),
151 ));
152 revision_repo.get(file)
153 };
154
155 match download_result {
156 Ok(path) => {
157 let dest = cache_path.join(file);
158 if path != dest {
159 if let Some(parent) = dest.parent() {
160 std::fs::create_dir_all(parent)?;
161 }
162 if path.exists() && !dest.exists() {
163 std::fs::copy(&path, &dest)?;
164 }
165 }
166 Ok(())
167 }
168 Err(hf_hub::api::sync::ApiError::RequestError(e)) => {
169 if e.to_string().contains("404") {
170 Err(FetchError::FileNotFound {
171 repo: repo_id.to_string(),
172 file: file.to_string(),
173 })
174 } else {
175 Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
176 }
177 }
178 Err(e) => {
179 Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
180 }
181 }
182 }
183
184 pub fn download_model(&self, repo_id: &str, options: FetchOptions) -> Result<ModelArtifact> {
195 Self::parse_repo_id(repo_id)?;
196
197 let files = Self::resolve_files(&options);
198 Self::check_security(&files, options.allow_pytorch_pickle)?;
199
200 let cache_path = options
202 .cache_dir
203 .clone()
204 .unwrap_or_else(|| self.cache_dir.clone())
205 .join(repo_id.replace('/', "--"))
206 .join(&options.revision);
207 std::fs::create_dir_all(&cache_path)?;
208
209 let format = files
211 .iter()
212 .find_map(|f| WeightFormat::from_filename(f))
213 .unwrap_or(WeightFormat::SafeTensors);
214
215 let api = self.build_api(&cache_path)?;
216 let repo = api.model(repo_id.to_string());
217
218 for file in &files {
219 Self::download_file(&repo, &api, repo_id, &options.revision, file, &cache_path)?;
220 }
221
222 Ok(ModelArtifact {
223 path: cache_path,
224 format,
225 architecture: None,
226 sha256: options.verify_sha256,
227 })
228 }
229
230 #[must_use]
232 pub fn estimate_memory(param_count: u64, dtype_bytes: u8) -> u64 {
233 param_count * u64::from(dtype_bytes)
234 }
235}
236
237impl Default for HfModelFetcher {
238 fn default() -> Self {
239 Self::new().expect("Failed to create HfModelFetcher")
240 }
241}