Skip to main content

modelexpress_common/providers/
huggingface.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{Utils, constants, providers::ModelProviderTrait};
5use anyhow::{Context, Result};
6use hf_hub::api::tokio::ApiBuilder;
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
/// Name of the environment variable read for the Hugging Face token that is
/// handed to the API client.
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";

/// Name of the environment variable that can supply the Hugging Face hub cache directory.
const HF_HUB_CACHE_ENV_VAR: &str = "HF_HUB_CACHE";

/// Name of the Model Express specific environment variable that can supply the cache directory.
const MODEL_EXPRESS_CACHE_ENV_VAR: &str = "MODEL_EXPRESS_CACHE_DIRECTORY";

/// Name of the environment variable that switches the provider into offline mode.
const HF_HUB_OFFLINE_ENV_VAR: &str = "HF_HUB_OFFLINE";
16
17/// Check if offline mode is enabled via HF_HUB_OFFLINE environment variable.
18/// The variable is considered enabled if its value is one of: "1", "ON", "YES", "TRUE" (case-insensitive).
19fn is_offline_mode() -> bool {
20    env::var(HF_HUB_OFFLINE_ENV_VAR)
21        .map(|v| matches!(v.to_uppercase().as_str(), "1" | "ON" | "YES" | "TRUE"))
22        .unwrap_or(false)
23}
24
25/// Get the cache directory for Hugging Face models
26/// Priority order:
27/// 1. Provided cache_dir parameter
28/// 2. HF_HUB_CACHE environment variable
29/// 3. Default location (~/.cache/huggingface/hub)
30fn get_cache_dir(cache_dir: Option<PathBuf>) -> PathBuf {
31    // Use provided cache directory if available
32    if let Some(dir) = cache_dir {
33        return dir;
34    }
35
36    // Try MODEL_EXPRESS_CACHE_DIRECTORY environment variable first
37    if let Ok(cache_path) = env::var(MODEL_EXPRESS_CACHE_ENV_VAR) {
38        return PathBuf::from(cache_path);
39    }
40
41    // Try environment variable
42    if let Ok(cache_path) = env::var(HF_HUB_CACHE_ENV_VAR) {
43        return PathBuf::from(cache_path);
44    }
45
46    // Fall back to default location
47    let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
48    PathBuf::from(home).join(constants::DEFAULT_HF_CACHE_PATH)
49}
50
/// Model provider backed by the Hugging Face hub.
pub struct HuggingFaceProvider;

impl HuggingFaceProvider {
    /// Tell whether `filename` points inside a sub-directory of the repository.
    ///
    /// A name is treated as nested as soon as its path has a second component.
    /// Hugging Face repositories can contain nested folders, but those are never
    /// files we use to run the model, so Model Express ignores them.
    fn is_subdirectory_file(filename: &str) -> bool {
        let mut parts = Path::new(filename).components();
        // Only the existence of a second component matters, not the total count.
        parts.nth(1).is_some()
    }
}
62
63#[async_trait::async_trait]
64impl ModelProviderTrait for HuggingFaceProvider {
65    /// Attempt to download a model from Hugging Face.
66    /// Returns the directory it is in.
67    async fn download_model(
68        &self,
69        model_name: &str,
70        cache_dir: Option<PathBuf>,
71        ignore_weights: bool,
72    ) -> Result<PathBuf> {
73        let cache_dir = get_cache_dir(cache_dir);
74        std::fs::create_dir_all(&cache_dir).map_err(|e| {
75            anyhow::anyhow!("Failed to create cache directory {:?}: {}", cache_dir, e)
76        })?;
77
78        if is_offline_mode() {
79            info!("HF_HUB_OFFLINE is set, using cached model for '{model_name}'");
80            return self.get_model_path(model_name, cache_dir).await;
81        }
82
83        let token = env::var(HF_TOKEN_ENV_VAR).ok();
84
85        info!("Using cache directory: {:?}", cache_dir);
86        // High CPU download
87        //
88        // This may cause issues on regular desktops as it will saturate
89        // CPUs by multiplexing the downloads.
90        // However in data-center focused environments with model express
91        // this may help saturate the bandwidth (>500MB/s) better.
92        let api = ApiBuilder::from_env()
93            .with_progress(true)
94            .with_token(token)
95            .high()
96            .with_cache_dir(cache_dir)
97            .build()?;
98        let model_name = model_name.to_string();
99
100        let repo = api.model(model_name.clone());
101
102        let info = repo.info().await.map_err(
103            |e| anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace. Is this a valid HuggingFace ID? Error: {e}"),
104        )?;
105        debug!("Got model info: {info:?}");
106
107        if info.siblings.is_empty() {
108            anyhow::bail!("Model '{model_name}' exists but contains no downloadable files.");
109        }
110
111        let mut p = PathBuf::new();
112        let mut files_downloaded = false;
113
114        for sib in info.siblings {
115            if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
116                continue;
117            }
118
119            if HuggingFaceProvider::is_ignored(&sib.rfilename)
120                || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
121            {
122                continue;
123            }
124
125            if ignore_weights && HuggingFaceProvider::is_weight_file(&sib.rfilename) {
126                continue;
127            }
128
129            match repo.get(&sib.rfilename).await {
130                Ok(path) => {
131                    p = path;
132                    files_downloaded = true;
133                }
134                Err(e) => {
135                    return Err(anyhow::anyhow!(
136                        "Failed to download file '{sib}' from model '{model_name}': {e}",
137                        sib = sib.rfilename,
138                        model_name = model_name,
139                        e = e
140                    ));
141                }
142            }
143        }
144
145        if !files_downloaded {
146            return Err(anyhow::anyhow!(
147                "No valid files found for model '{}'.",
148                model_name
149            ));
150        }
151
152        info!("Downloaded model files for {model_name}");
153
154        match p.parent() {
155            Some(p) => Ok(p.to_path_buf()),
156            None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
157        }
158    }
159
160    /// Attempt to delete a model from Hugging Face cache
161    /// Returns Ok(()) if the model was successfully deleted or didn't exist
162    async fn delete_model(&self, model_name: &str) -> Result<()> {
163        info!("Deleting model from Hugging Face cache: {model_name}");
164        let token = env::var(HF_TOKEN_ENV_VAR).ok();
165        let api = ApiBuilder::new()
166            .with_token(token)
167            .build()
168            .context("Failed to create Hugging Face API client")?;
169        let model_name = model_name.to_string();
170
171        let repo = api.model(model_name.clone());
172
173        let info = match repo.info().await {
174            Ok(info) => info,
175            Err(_) => {
176                // If we can't get model info, assume it doesn't exist or is already deleted
177                info!("Model '{model_name}' not found or already deleted");
178                return Ok(());
179            }
180        };
181
182        if info.siblings.is_empty() {
183            info!("Model '{model_name}' has no files to delete");
184            return Ok(());
185        }
186
187        let mut files_deleted: u32 = 0;
188        let mut deletion_errors = Vec::new();
189
190        for sib in &info.siblings {
191            if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
192                continue;
193            }
194
195            if HuggingFaceProvider::is_ignored(&sib.rfilename)
196                || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
197            {
198                continue;
199            }
200
201            // Try to get the file path from cache first
202            if let Ok(cached_path) = repo.get(&sib.rfilename).await {
203                // Delete the cached file
204                match std::fs::remove_file(&cached_path) {
205                    Ok(_) => {
206                        files_deleted = files_deleted.saturating_add(1);
207                        info!("Deleted cached file: {}", cached_path.display());
208                    }
209                    Err(e) => {
210                        let error_msg =
211                            format!("Failed to delete cached file '{}'", cached_path.display());
212                        deletion_errors.push(anyhow::anyhow!(e).context(error_msg));
213                    }
214                }
215            }
216        }
217
218        // Try to remove the empty model directory if all files were deleted
219        if files_deleted > 0 && deletion_errors.is_empty() {
220            // Get any file path to find the model directory
221            for sib in &info.siblings {
222                if let Ok(cached_path) = repo.get(&sib.rfilename).await
223                    && let Some(model_dir) = cached_path.parent()
224                    && let Ok(mut entries) = std::fs::read_dir(model_dir)
225                    && entries.next().is_none()
226                {
227                    if let Err(e) = std::fs::remove_dir(model_dir) {
228                        info!("Could not remove empty model directory: {e}");
229                    } else {
230                        info!("Removed empty model directory: {}", model_dir.display());
231                    }
232                    break;
233                }
234            }
235        }
236
237        if !deletion_errors.is_empty() {
238            let mut compound_error =
239                anyhow::anyhow!("Failed to delete some files for model '{model_name}'");
240
241            for (i, error) in deletion_errors.into_iter().enumerate() {
242                compound_error =
243                    compound_error.context(format!("Error {}: {:#}", i.saturating_add(1), error));
244            }
245
246            return Err(compound_error);
247        }
248
249        if files_deleted == 0 {
250            info!("No cached files found to delete for model '{model_name}'");
251        } else {
252            info!("Successfully deleted {files_deleted} cached files for model '{model_name}'");
253        }
254
255        Ok(())
256    }
257
258    /// Get the full path to the latest model snapshot if it exists.
259    /// Returns the path if found, or an error if not found.
260    async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf> {
261        let normalized_name = model_name.replace("/", "--");
262        let path = cache_dir
263            .join(format!["models--{normalized_name}"])
264            .join("snapshots");
265
266        if !path.exists() {
267            anyhow::bail!("Model snapshots for '{model_name}' not found in cache");
268        }
269
270        let mut files: Vec<fs::DirEntry> = fs::read_dir(path)?.filter_map(Result::ok).collect();
271        if files.is_empty() {
272            anyhow::bail!("Model snapshots for '{model_name}' is empty");
273        }
274
275        // Sort by creation/modification time to get the most recent snapshot
276        files.sort_by_key(|e| {
277            e.metadata()
278                .and_then(|m| m.created().or_else(|_| m.modified()))
279                .unwrap_or(std::time::SystemTime::UNIX_EPOCH)
280        });
281        files.reverse();
282
283        // In offline mode, skip network validation and return the latest local snapshot
284        if is_offline_mode() {
285            return Ok(files[0].path());
286        }
287
288        // Check against the latest commit hash from HF
289        let token = env::var(HF_TOKEN_ENV_VAR).ok();
290        let api = ApiBuilder::from_env().with_token(token).build()?;
291        let repo = api.model(model_name.to_string());
292        let info = repo.info().await.map_err(|e| {
293            anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace: {e}")
294        })?;
295
296        for file in &files {
297            if file.file_name().display().to_string() == info.sha {
298                return Ok(file.path());
299            }
300        }
301
302        warn!(
303            "Existing model snapshots do not match the latest commit hash '{0}'. \
304            Returning the best-effort, latest local model snapshot.",
305            info.sha
306        );
307
308        Ok(files[0].path())
309    }
310
311    fn provider_name(&self) -> &'static str {
312        "Hugging Face"
313    }
314}
315
316#[cfg(test)]
317#[allow(clippy::expect_used)]
318mod tests {
319    use super::*;
320    use serde_json::json;
321    use std::sync::Mutex;
322    use tempfile::TempDir;
323    use tokio::time::Duration;
324    use wiremock::matchers::{method, path_regex};
325    use wiremock::{Mock, MockServer, ResponseTemplate};
326
327    /// Mutex to serialize access to HF_HUB_OFFLINE environment variable across tests.
328    /// This prevents race conditions when tests run in parallel.
329    static ENV_MUTEX: Mutex<()> = Mutex::new(());
330
331    /// Minimal mock of the Hugging Face Hub used by tests.
332    ///
333    /// This server stubs:
334    /// - the model info endpoint (`/api/models/<repo>`), returning a fixed `sha` and file list
335    /// - the file resolve endpoints (`/<repo>/resolve/<rev>/<filename>`) for each sibling
336    ///
337    /// The hf_hub client writes files into `cache_path` when the resolve endpoints return
338    /// successful responses with the headers it expects (ETag, commit, range). This allows
339    /// us to simulate a real model download without external network access.
340    struct MockHFServer {
341        /// WireMock instance; keeps the server alive for the lifetime of the test
342        _server: MockServer,
343        /// Temporary HF cache root that tests pass to `ApiBuilder::with_cache_dir`
344        pub cache_path: PathBuf,
345    }
346
347    impl MockHFServer {
348        /// Start a WireMock server and configure stubs compatible with hf_hub's download flow.
349        ///
350        /// Notes on headers and status codes expected by hf_hub:
351        /// - `etag`: used for dedup and cache validation
352        /// - `x-repo-commit`: identifies the snapshot commit (must match `info.sha`)
353        /// - Range download: GETs may be partial; we return 206 with `accept-ranges`,
354        ///   `content-length` and `content-range` to keep the client happy across versions.
355        async fn new() -> Self {
356            let temp_dir = TempDir::new().expect("Failed to create temporary directory");
357            let server = MockServer::start().await;
358
359            // Return the desired sha we want get_model_path to pick
360            // Matches GET /api/models/test/model (and subpaths).
361            Mock::given(method("GET"))
362                .and(path_regex(r"^/api/models/test/model(?:/.*)?$"))
363                .respond_with(ResponseTemplate::new(200).set_body_json(json!({
364                     "id": "test/model",
365                     "sha": "def5678",
366                     "siblings": [
367                         {"rfilename": "config.json"},
368                         {"rfilename": "model.safetensors"},
369                         {"rfilename": "tokenizer.json"},
370                         {"rfilename": "README.md"},
371                         {"rfilename": "subdir/model.safetensors"}
372                     ]
373                })))
374                .mount(&server)
375                .await;
376
377            // Mock resolved file contents so hf_hub can populate the cache
378            // Matches GET /test/model/resolve/<rev>/(config.json|tokenizer.json|README.md|model.safetensors)
379            Mock::given(method("GET"))
380                .and(path_regex(r"^/test/model/resolve/(main|[^/]+)/(?:config\.json|tokenizer\.json|README\.md|model\.safetensors)$"))
381                .respond_with(
382                    ResponseTemplate::new(206)
383                        .insert_header("etag", "\"def5678\"")
384                        .insert_header("x-repo-commit", "def5678")
385                        .insert_header("accept-ranges", "bytes")
386                        .insert_header("content-length", "64")
387                        .insert_header("content-range", "bytes 0-63/64")
388                        .set_body_bytes(vec![0u8; 64]),
389                )
390                .mount(&server)
391                .await;
392
393            unsafe {
394                std::env::set_var("HF_ENDPOINT", server.uri());
395            }
396
397            Self {
398                _server: server,
399                cache_path: temp_dir.path().to_path_buf(),
400            }
401        }
402    }
403
404    impl Drop for MockHFServer {
405        /// Ensure the temporary cache path is removed even if a test fails.
406        fn drop(&mut self) {
407            std::fs::remove_dir_all(&self.cache_path).unwrap_or_else(|e| {
408                warn!("Failed to remove temporary cache path: {e}");
409            });
410        }
411    }
412
413    #[test]
414    fn test_hugging_face_provider_name() {
415        let provider = HuggingFaceProvider;
416        assert_eq!(provider.provider_name(), "Hugging Face");
417    }
418
419    #[test]
420    fn test_provider_trait_object() {
421        let provider: Box<dyn ModelProviderTrait> = Box::new(HuggingFaceProvider);
422        assert_eq!(provider.provider_name(), "Hugging Face");
423    }
424
425    #[tokio::test]
426    async fn test_delete_model_trait() {
427        let provider = HuggingFaceProvider;
428        // Test that the delete method exists and can be called
429        // Note: This won't actually delete anything since we're not providing a real model
430        // but it tests the trait implementation
431        let result = provider.delete_model("nonexistent/model").await;
432        // Should succeed (return Ok(())) even if model doesn't exist
433        assert!(result.is_ok());
434    }
435
436    #[tokio::test]
437    async fn test_get_model_path_trait() {
438        let mock_server = MockHFServer::new().await;
439
440        // Construct a temporary cache dir with a model snapshots
441        let path = mock_server
442            .cache_path
443            .join("models--test--model")
444            .join("snapshots");
445
446        std::fs::create_dir_all(path.join("abc1234")).expect("Failed to create directory");
447        tokio::time::sleep(Duration::from_secs(1)).await;
448        std::fs::create_dir_all(path.join("def5678")).expect("Failed to create directory");
449
450        let provider = HuggingFaceProvider;
451        let result = provider
452            .get_model_path("test/model", mock_server.cache_path.clone())
453            .await;
454
455        assert!(result.is_ok());
456        assert_eq!(
457            result.expect("Failed to get model path"),
458            path.join("def5678")
459        );
460    }
461
462    #[tokio::test]
463    async fn test_download_ignore_weights() {
464        let mock_server = MockHFServer::new().await;
465        let provider = HuggingFaceProvider;
466        let result = provider
467            .download_model("test/model", Some(mock_server.cache_path.clone()), false)
468            .await
469            .expect("Failed to download model");
470
471        let files = fs::read_dir(result)
472            .expect("Failed to read directory")
473            .filter_map(Result::ok);
474
475        for file in files {
476            info!("File: {}", file.path().display());
477            assert!(!file.path().ends_with("safetensors"));
478        }
479    }
480
481    #[tokio::test]
482    async fn test_download_ignores_subdirectories() {
483        let mock_server = MockHFServer::new().await;
484        let provider = HuggingFaceProvider;
485
486        let result = provider
487            .download_model("test/model", Some(mock_server.cache_path.clone()), false)
488            .await
489            .expect("Failed to download model");
490
491        assert!(
492            !result.join("subdir").exists(),
493            "Expected files located in sub-directories to be ignored"
494        );
495    }
496
497    #[test]
498    fn test_is_offline_mode() {
499        let _guard = ENV_MUTEX.lock().expect("Failed to acquire env mutex");
500        unsafe {
501            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
502            assert!(is_offline_mode());
503
504            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "0");
505            assert!(!is_offline_mode());
506
507            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
508        }
509        assert!(!is_offline_mode());
510    }
511
512    #[tokio::test]
513    #[allow(clippy::await_holding_lock)]
514    async fn test_download_model_offline_mode_with_cache() {
515        let _guard = ENV_MUTEX.lock().expect("Failed to acquire env mutex");
516        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
517        let snapshots_path = temp_dir
518            .path()
519            .join("models--test--model")
520            .join("snapshots")
521            .join("abc1234");
522        std::fs::create_dir_all(&snapshots_path).expect("Failed to create directory");
523
524        unsafe {
525            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
526        }
527
528        let result = HuggingFaceProvider
529            .download_model("test/model", Some(temp_dir.path().into()), false)
530            .await;
531
532        unsafe {
533            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
534        }
535
536        assert!(result.is_ok());
537        assert!(result.expect("Expected path").ends_with("abc1234"));
538    }
539
540    #[tokio::test]
541    #[allow(clippy::await_holding_lock)]
542    async fn test_download_model_offline_mode_without_cache() {
543        let _guard = ENV_MUTEX.lock().expect("Failed to acquire env mutex");
544        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
545
546        unsafe {
547            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
548        }
549
550        let result = HuggingFaceProvider
551            .download_model("nonexistent/model", Some(temp_dir.path().into()), false)
552            .await;
553
554        unsafe {
555            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
556        }
557
558        assert!(result.is_err());
559        assert!(
560            result
561                .expect_err("Expected error")
562                .to_string()
563                .contains("not found in cache")
564        );
565    }
566}