Skip to main content

entrenar/hf_pipeline/fetcher/
hf_fetcher.rs

//! HuggingFace model fetcher implementation.
//!
//! Downloads models from HuggingFace Hub with authentication and caching.
4
5use crate::hf_pipeline::error::{FetchError, Result};
6use std::path::PathBuf;
7
8use super::options::FetchOptions;
9use super::types::{ModelArtifact, WeightFormat};
10
/// Fetcher for model files hosted on the HuggingFace Hub.
///
/// Carries the optional credentials and the local cache location that the
/// download methods of this type work with.
pub struct HfModelFetcher {
    /// Authentication token; `None` when no token was supplied or found.
    pub(crate) token: Option<String>,
    /// Directory under which downloaded repositories are cached.
    pub(crate) cache_dir: PathBuf,
    /// Base URL of the Hub API (kept for future HTTP client integration,
    /// hence the `dead_code` allowance).
    #[allow(dead_code)]
    pub(crate) api_base: String,
}
21
22impl HfModelFetcher {
23    /// Create new fetcher using HF_TOKEN environment variable
24    ///
25    /// # Errors
26    ///
27    /// Does not error on missing token (allows anonymous pulls).
28    pub fn new() -> Result<Self> {
29        let token = Self::resolve_token();
30        let cache_dir = Self::default_cache_dir();
31
32        Ok(Self { token, cache_dir, api_base: "https://huggingface.co".into() })
33    }
34
35    /// Create fetcher with explicit token
36    #[must_use]
37    pub fn with_token(token: impl Into<String>) -> Self {
38        Self {
39            token: Some(token.into()),
40            cache_dir: Self::default_cache_dir(),
41            api_base: "https://huggingface.co".into(),
42        }
43    }
44
45    /// Set cache directory
46    #[must_use]
47    pub fn cache_dir(mut self, dir: impl Into<PathBuf>) -> Self {
48        self.cache_dir = dir.into();
49        self
50    }
51
52    /// Resolve token from multiple sources
53    ///
54    /// Priority:
55    /// 1. HF_TOKEN environment variable
56    /// 2. ~/.huggingface/token file
57    #[must_use]
58    pub fn resolve_token() -> Option<String> {
59        // Try environment variable first
60        if let Ok(token) = std::env::var("HF_TOKEN") {
61            if !token.is_empty() {
62                return Some(token);
63            }
64        }
65
66        // Try ~/.huggingface/token file
67        if let Some(home) = dirs::home_dir() {
68            let token_path = home.join(".huggingface").join("token");
69            if let Ok(token) = std::fs::read_to_string(token_path) {
70                let token = token.trim().to_string();
71                if !token.is_empty() {
72                    return Some(token);
73                }
74            }
75        }
76
77        None
78    }
79
80    /// Get default cache directory
81    pub(crate) fn default_cache_dir() -> PathBuf {
82        dirs::cache_dir().unwrap_or_else(|| PathBuf::from(".cache")).join("huggingface").join("hub")
83    }
84
85    /// Check if client has authentication
86    #[must_use]
87    pub fn is_authenticated(&self) -> bool {
88        self.token.is_some()
89    }
90
91    /// Parse and validate repository ID
92    pub(crate) fn parse_repo_id(repo_id: &str) -> Result<(&str, &str)> {
93        let parts: Vec<&str> = repo_id.split('/').collect();
94        if parts.len() != 2 || parts[0].is_empty() || parts[1].is_empty() {
95            return Err(FetchError::InvalidRepoId { repo_id: repo_id.to_string() });
96        }
97        Ok((parts[0], parts[1]))
98    }
99
100    /// Resolve the list of files to download, falling back to defaults if empty.
101    fn resolve_files(options: &FetchOptions) -> Vec<String> {
102        if options.files.is_empty() {
103            vec!["model.safetensors".to_string(), "config.json".to_string()]
104        } else {
105            options.files.clone()
106        }
107    }
108
109    /// Check that no file uses an unsafe format (e.g. pickle) unless explicitly allowed.
110    fn check_security(files: &[String], allow_pickle: bool) -> Result<()> {
111        for file in files {
112            if let Some(format) = WeightFormat::from_filename(file) {
113                if !format.is_safe() && !allow_pickle {
114                    return Err(FetchError::PickleSecurityRisk);
115                }
116            }
117        }
118        Ok(())
119    }
120
121    /// Build the hf-hub sync API client with optional authentication.
122    fn build_api(&self, cache_path: &std::path::Path) -> Result<hf_hub::api::sync::Api> {
123        let mut api_builder =
124            hf_hub::api::sync::ApiBuilder::new().with_cache_dir(cache_path.to_path_buf());
125
126        if let Some(token) = &self.token {
127            api_builder = api_builder.with_token(Some(token.clone()));
128        }
129
130        api_builder.build().map_err(|e| FetchError::ConfigParseError {
131            message: format!("Failed to initialize HF API: {e}"),
132        })
133    }
134
135    /// Download a single file from a repo, copying it into the cache directory.
136    fn download_file(
137        repo: &hf_hub::api::sync::ApiRepo,
138        api: &hf_hub::api::sync::Api,
139        repo_id: &str,
140        revision: &str,
141        file: &str,
142        cache_path: &std::path::Path,
143    ) -> Result<()> {
144        let download_result = if revision == "main" {
145            repo.get(file)
146        } else {
147            let revision_repo = api.repo(hf_hub::Repo::with_revision(
148                repo_id.to_string(),
149                hf_hub::RepoType::Model,
150                revision.to_string(),
151            ));
152            revision_repo.get(file)
153        };
154
155        match download_result {
156            Ok(path) => {
157                let dest = cache_path.join(file);
158                if path != dest {
159                    if let Some(parent) = dest.parent() {
160                        std::fs::create_dir_all(parent)?;
161                    }
162                    if path.exists() && !dest.exists() {
163                        std::fs::copy(&path, &dest)?;
164                    }
165                }
166                Ok(())
167            }
168            Err(hf_hub::api::sync::ApiError::RequestError(e)) => {
169                if e.to_string().contains("404") {
170                    Err(FetchError::FileNotFound {
171                        repo: repo_id.to_string(),
172                        file: file.to_string(),
173                    })
174                } else {
175                    Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
176                }
177            }
178            Err(e) => {
179                Err(FetchError::ConfigParseError { message: format!("Download failed: {e}") })
180            }
181        }
182    }
183
184    /// Download a model from HuggingFace Hub
185    ///
186    /// # Arguments
187    ///
188    /// * `repo_id` - Repository ID in "org/name" format
189    /// * `options` - Fetch options
190    ///
191    /// # Errors
192    ///
193    /// Returns error if download fails, repo not found, or security check fails.
194    pub fn download_model(&self, repo_id: &str, options: FetchOptions) -> Result<ModelArtifact> {
195        Self::parse_repo_id(repo_id)?;
196
197        let files = Self::resolve_files(&options);
198        Self::check_security(&files, options.allow_pytorch_pickle)?;
199
200        // Create local cache path
201        let cache_path = options
202            .cache_dir
203            .clone()
204            .unwrap_or_else(|| self.cache_dir.clone())
205            .join(repo_id.replace('/', "--"))
206            .join(&options.revision);
207        std::fs::create_dir_all(&cache_path)?;
208
209        // Detect format from files
210        let format = files
211            .iter()
212            .find_map(|f| WeightFormat::from_filename(f))
213            .unwrap_or(WeightFormat::SafeTensors);
214
215        let api = self.build_api(&cache_path)?;
216        let repo = api.model(repo_id.to_string());
217
218        for file in &files {
219            Self::download_file(&repo, &api, repo_id, &options.revision, file, &cache_path)?;
220        }
221
222        Ok(ModelArtifact {
223            path: cache_path,
224            format,
225            architecture: None,
226            sha256: options.verify_sha256,
227        })
228    }
229
230    /// Estimate memory required to load a model
231    #[must_use]
232    pub fn estimate_memory(param_count: u64, dtype_bytes: u8) -> u64 {
233        param_count * u64::from(dtype_bytes)
234    }
235}
236
237impl Default for HfModelFetcher {
238    fn default() -> Self {
239        Self::new().expect("Failed to create HfModelFetcher")
240    }
241}