modelexpress_common/providers/
huggingface.rs1use crate::{Utils, constants, providers::ModelProviderTrait};
5use anyhow::{Context, Result};
6use hf_hub::api::tokio::ApiBuilder;
7use std::env;
8use std::fs;
9use std::path::{Path, PathBuf};
10use tracing::{debug, info, warn};
11
/// Environment variable holding the Hugging Face access token passed to the Hub API client.
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Environment variable naming the Hugging Face hub cache directory (consulted after
/// `MODEL_EXPRESS_CACHE_DIRECTORY` when no explicit cache directory is given).
const HF_HUB_CACHE_ENV_VAR: &str = "HF_HUB_CACHE";
/// Environment variable naming the model cache directory; takes precedence over `HF_HUB_CACHE`.
const MODEL_EXPRESS_CACHE_ENV_VAR: &str = "MODEL_EXPRESS_CACHE_DIRECTORY";
/// Environment variable that, when truthy, makes downloads and path lookups use only the local
/// cache without contacting the Hub.
const HF_HUB_OFFLINE_ENV_VAR: &str = "HF_HUB_OFFLINE";
16
17fn is_offline_mode() -> bool {
20 env::var(HF_HUB_OFFLINE_ENV_VAR)
21 .map(|v| matches!(v.to_uppercase().as_str(), "1" | "ON" | "YES" | "TRUE"))
22 .unwrap_or(false)
23}
24
25fn get_cache_dir(cache_dir: Option<PathBuf>) -> PathBuf {
31 if let Some(dir) = cache_dir {
33 return dir;
34 }
35
36 if let Ok(cache_path) = env::var(MODEL_EXPRESS_CACHE_ENV_VAR) {
38 return PathBuf::from(cache_path);
39 }
40
41 if let Ok(cache_path) = env::var(HF_HUB_CACHE_ENV_VAR) {
43 return PathBuf::from(cache_path);
44 }
45
46 let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
48 PathBuf::from(home).join(constants::DEFAULT_HF_CACHE_PATH)
49}
50
/// Model provider backed by the Hugging Face Hub.
pub struct HuggingFaceProvider;

impl HuggingFaceProvider {
    /// Returns `true` when `filename` has more than one path component, i.e. it sits inside a
    /// sub-directory of the repository (such as `subdir/model.safetensors`) rather than at
    /// the repository root.
    fn is_subdirectory_file(filename: &str) -> bool {
        let mut parts = Path::new(filename).components();
        // Seeing a first and then a second component is equivalent to `count() > 1`,
        // without walking the remainder of the path.
        parts.next().is_some() && parts.next().is_some()
    }
}
62
63#[async_trait::async_trait]
64impl ModelProviderTrait for HuggingFaceProvider {
65 async fn download_model(
68 &self,
69 model_name: &str,
70 cache_dir: Option<PathBuf>,
71 ignore_weights: bool,
72 ) -> Result<PathBuf> {
73 let cache_dir = get_cache_dir(cache_dir);
74 std::fs::create_dir_all(&cache_dir).map_err(|e| {
75 anyhow::anyhow!("Failed to create cache directory {:?}: {}", cache_dir, e)
76 })?;
77
78 if is_offline_mode() {
79 info!("HF_HUB_OFFLINE is set, using cached model for '{model_name}'");
80 return self.get_model_path(model_name, cache_dir).await;
81 }
82
83 let token = env::var(HF_TOKEN_ENV_VAR).ok();
84
85 info!("Using cache directory: {:?}", cache_dir);
86 let api = ApiBuilder::from_env()
93 .with_progress(true)
94 .with_token(token)
95 .high()
96 .with_cache_dir(cache_dir)
97 .build()?;
98 let model_name = model_name.to_string();
99
100 let repo = api.model(model_name.clone());
101
102 let info = repo.info().await.map_err(
103 |e| anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace. Is this a valid HuggingFace ID? Error: {e}"),
104 )?;
105 debug!("Got model info: {info:?}");
106
107 if info.siblings.is_empty() {
108 anyhow::bail!("Model '{model_name}' exists but contains no downloadable files.");
109 }
110
111 let mut p = PathBuf::new();
112 let mut files_downloaded = false;
113
114 for sib in info.siblings {
115 if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
116 continue;
117 }
118
119 if HuggingFaceProvider::is_ignored(&sib.rfilename)
120 || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
121 {
122 continue;
123 }
124
125 if ignore_weights && HuggingFaceProvider::is_weight_file(&sib.rfilename) {
126 continue;
127 }
128
129 match repo.get(&sib.rfilename).await {
130 Ok(path) => {
131 p = path;
132 files_downloaded = true;
133 }
134 Err(e) => {
135 return Err(anyhow::anyhow!(
136 "Failed to download file '{sib}' from model '{model_name}': {e}",
137 sib = sib.rfilename,
138 model_name = model_name,
139 e = e
140 ));
141 }
142 }
143 }
144
145 if !files_downloaded {
146 return Err(anyhow::anyhow!(
147 "No valid files found for model '{}'.",
148 model_name
149 ));
150 }
151
152 info!("Downloaded model files for {model_name}");
153
154 match p.parent() {
155 Some(p) => Ok(p.to_path_buf()),
156 None => Err(anyhow::anyhow!("Invalid HF cache path: {}", p.display())),
157 }
158 }
159
160 async fn delete_model(&self, model_name: &str) -> Result<()> {
163 info!("Deleting model from Hugging Face cache: {model_name}");
164 let token = env::var(HF_TOKEN_ENV_VAR).ok();
165 let api = ApiBuilder::new()
166 .with_token(token)
167 .build()
168 .context("Failed to create Hugging Face API client")?;
169 let model_name = model_name.to_string();
170
171 let repo = api.model(model_name.clone());
172
173 let info = match repo.info().await {
174 Ok(info) => info,
175 Err(_) => {
176 info!("Model '{model_name}' not found or already deleted");
178 return Ok(());
179 }
180 };
181
182 if info.siblings.is_empty() {
183 info!("Model '{model_name}' has no files to delete");
184 return Ok(());
185 }
186
187 let mut files_deleted: u32 = 0;
188 let mut deletion_errors = Vec::new();
189
190 for sib in &info.siblings {
191 if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
192 continue;
193 }
194
195 if HuggingFaceProvider::is_ignored(&sib.rfilename)
196 || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
197 {
198 continue;
199 }
200
201 if let Ok(cached_path) = repo.get(&sib.rfilename).await {
203 match std::fs::remove_file(&cached_path) {
205 Ok(_) => {
206 files_deleted = files_deleted.saturating_add(1);
207 info!("Deleted cached file: {}", cached_path.display());
208 }
209 Err(e) => {
210 let error_msg =
211 format!("Failed to delete cached file '{}'", cached_path.display());
212 deletion_errors.push(anyhow::anyhow!(e).context(error_msg));
213 }
214 }
215 }
216 }
217
218 if files_deleted > 0 && deletion_errors.is_empty() {
220 for sib in &info.siblings {
222 if let Ok(cached_path) = repo.get(&sib.rfilename).await
223 && let Some(model_dir) = cached_path.parent()
224 && let Ok(mut entries) = std::fs::read_dir(model_dir)
225 && entries.next().is_none()
226 {
227 if let Err(e) = std::fs::remove_dir(model_dir) {
228 info!("Could not remove empty model directory: {e}");
229 } else {
230 info!("Removed empty model directory: {}", model_dir.display());
231 }
232 break;
233 }
234 }
235 }
236
237 if !deletion_errors.is_empty() {
238 let mut compound_error =
239 anyhow::anyhow!("Failed to delete some files for model '{model_name}'");
240
241 for (i, error) in deletion_errors.into_iter().enumerate() {
242 compound_error =
243 compound_error.context(format!("Error {}: {:#}", i.saturating_add(1), error));
244 }
245
246 return Err(compound_error);
247 }
248
249 if files_deleted == 0 {
250 info!("No cached files found to delete for model '{model_name}'");
251 } else {
252 info!("Successfully deleted {files_deleted} cached files for model '{model_name}'");
253 }
254
255 Ok(())
256 }
257
258 async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf> {
261 let normalized_name = model_name.replace("/", "--");
262 let path = cache_dir
263 .join(format!["models--{normalized_name}"])
264 .join("snapshots");
265
266 if !path.exists() {
267 anyhow::bail!("Model snapshots for '{model_name}' not found in cache");
268 }
269
270 let mut files: Vec<fs::DirEntry> = fs::read_dir(path)?.filter_map(Result::ok).collect();
271 if files.is_empty() {
272 anyhow::bail!("Model snapshots for '{model_name}' is empty");
273 }
274
275 files.sort_by_key(|e| {
277 e.metadata()
278 .and_then(|m| m.created().or_else(|_| m.modified()))
279 .unwrap_or(std::time::SystemTime::UNIX_EPOCH)
280 });
281 files.reverse();
282
283 if is_offline_mode() {
285 return Ok(files[0].path());
286 }
287
288 let token = env::var(HF_TOKEN_ENV_VAR).ok();
290 let api = ApiBuilder::from_env().with_token(token).build()?;
291 let repo = api.model(model_name.to_string());
292 let info = repo.info().await.map_err(|e| {
293 anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace: {e}")
294 })?;
295
296 for file in &files {
297 if file.file_name().display().to_string() == info.sha {
298 return Ok(file.path());
299 }
300 }
301
302 warn!(
303 "Existing model snapshots do not match the latest commit hash '{0}'. \
304 Returning the best-effort, latest local model snapshot.",
305 info.sha
306 );
307
308 Ok(files[0].path())
309 }
310
311 fn provider_name(&self) -> &'static str {
312 "Hugging Face"
313 }
314}
315
#[cfg(test)]
// `expect` is the idiomatic way to fail a test. The std `Mutex` guarding process-wide
// environment variables is deliberately held across `.await`, so tests that mutate the
// environment (or read it through the code under test) never interleave. Each test runs on
// its own thread and runtime, so this cannot deadlock.
#[allow(clippy::expect_used, clippy::await_holding_lock)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, MutexGuard, PoisonError};
    use tempfile::TempDir;
    use tokio::time::Duration;
    use wiremock::matchers::{method, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Serializes every test in this module that reads or writes environment variables.
    static ENV_MUTEX: Mutex<()> = Mutex::new(());

    /// Acquires `ENV_MUTEX`, recovering from poisoning so that one failing test does not
    /// cascade into lock failures in every other environment-dependent test.
    fn lock_env() -> MutexGuard<'static, ()> {
        ENV_MUTEX.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Wiremock-backed stand-in for the Hugging Face Hub serving a single `test/model` repo.
    ///
    /// While alive it points `HF_ENDPOINT` at the mock server and holds the environment lock,
    /// so concurrently running tests cannot redirect the endpoint or toggle offline mode.
    struct MockHFServer {
        _server: MockServer,
        /// Owns the temporary cache directory; it is deleted when the mock is dropped.
        _temp_dir: TempDir,
        cache_path: PathBuf,
        // Declared last: fields drop in declaration order, so the lock is released only after
        // the server and the cache directory are torn down.
        _env_guard: MutexGuard<'static, ()>,
    }

    impl MockHFServer {
        async fn new() -> Self {
            let env_guard = lock_env();
            let temp_dir = TempDir::new().expect("Failed to create temporary directory");
            let cache_path = temp_dir.path().to_path_buf();
            let server = MockServer::start().await;

            // Repository metadata: root files, plus one file in a sub-directory that
            // downloads must skip.
            Mock::given(method("GET"))
                .and(path_regex(r"^/api/models/test/model(?:/.*)?$"))
                .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                    "id": "test/model",
                    "sha": "def5678",
                    "siblings": [
                        {"rfilename": "config.json"},
                        {"rfilename": "model.safetensors"},
                        {"rfilename": "tokenizer.json"},
                        {"rfilename": "README.md"},
                        {"rfilename": "subdir/model.safetensors"}
                    ]
                })))
                .mount(&server)
                .await;

            // File contents for the root-level files only.
            Mock::given(method("GET"))
                .and(path_regex(r"^/test/model/resolve/(main|[^/]+)/(?:config\.json|tokenizer\.json|README\.md|model\.safetensors)$"))
                .respond_with(
                    ResponseTemplate::new(206)
                        .insert_header("etag", "\"def5678\"")
                        .insert_header("x-repo-commit", "def5678")
                        .insert_header("accept-ranges", "bytes")
                        .insert_header("content-length", "64")
                        .insert_header("content-range", "bytes 0-63/64")
                        .set_body_bytes(vec![0u8; 64]),
                )
                .mount(&server)
                .await;

            // SAFETY: `ENV_MUTEX` is held, and every test in this module that touches the
            // environment acquires it first.
            unsafe {
                std::env::set_var("HF_ENDPOINT", server.uri());
            }

            Self {
                _server: server,
                _temp_dir: temp_dir,
                cache_path,
                _env_guard: env_guard,
            }
        }
    }

    impl Drop for MockHFServer {
        fn drop(&mut self) {
            // Runs before the fields are dropped, i.e. while the environment lock is still held.
            // SAFETY: see `MockHFServer::new`.
            unsafe {
                std::env::remove_var("HF_ENDPOINT");
            }
        }
    }

    #[test]
    fn test_hugging_face_provider_name() {
        let provider = HuggingFaceProvider;
        assert_eq!(provider.provider_name(), "Hugging Face");
    }

    #[test]
    fn test_provider_trait_object() {
        let provider: Box<dyn ModelProviderTrait> = Box::new(HuggingFaceProvider);
        assert_eq!(provider.provider_name(), "Hugging Face");
    }

    #[tokio::test]
    async fn test_delete_model_trait() {
        let _guard = lock_env();
        let provider = HuggingFaceProvider;
        let result = provider.delete_model("nonexistent/model").await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_get_model_path_trait() {
        let mock_server = MockHFServer::new().await;

        let path = mock_server
            .cache_path
            .join("models--test--model")
            .join("snapshots");

        std::fs::create_dir_all(path.join("abc1234")).expect("Failed to create directory");
        // Ensure the two snapshots get distinguishable timestamps.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::create_dir_all(path.join("def5678")).expect("Failed to create directory");

        let provider = HuggingFaceProvider;
        let result = provider
            .get_model_path("test/model", mock_server.cache_path.clone())
            .await;

        assert!(result.is_ok());
        assert_eq!(
            result.expect("Failed to get model path"),
            path.join("def5678")
        );
    }

    #[tokio::test]
    async fn test_download_ignore_weights() {
        let mock_server = MockHFServer::new().await;
        let provider = HuggingFaceProvider;
        let snapshot = provider
            .download_model("test/model", Some(mock_server.cache_path.clone()), true)
            .await
            .expect("Failed to download model");

        let downloaded: Vec<PathBuf> = fs::read_dir(&snapshot)
            .expect("Failed to read directory")
            .filter_map(Result::ok)
            .map(|entry| entry.path())
            .collect();

        assert!(
            !downloaded.is_empty(),
            "Expected non-weight files to be downloaded"
        );
        for path in downloaded {
            info!("File: {}", path.display());
            // `Path::ends_with` compares whole components, so check the extension instead.
            assert_ne!(
                path.extension().and_then(|ext| ext.to_str()),
                Some("safetensors"),
                "Weight file {} should have been skipped",
                path.display()
            );
        }
    }

    #[tokio::test]
    async fn test_download_ignores_subdirectories() {
        let mock_server = MockHFServer::new().await;
        let provider = HuggingFaceProvider;

        let result = provider
            .download_model("test/model", Some(mock_server.cache_path.clone()), false)
            .await
            .expect("Failed to download model");

        assert!(
            !result.join("subdir").exists(),
            "Expected files located in sub-directories to be ignored"
        );
    }

    #[test]
    fn test_is_offline_mode() {
        let _guard = lock_env();
        // SAFETY: `ENV_MUTEX` is held for the duration of the test.
        unsafe {
            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
            assert!(is_offline_mode());

            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "0");
            assert!(!is_offline_mode());

            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
        }
        assert!(!is_offline_mode());
    }

    #[tokio::test]
    async fn test_download_model_offline_mode_with_cache() {
        let _guard = lock_env();
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let snapshots_path = temp_dir
            .path()
            .join("models--test--model")
            .join("snapshots")
            .join("abc1234");
        std::fs::create_dir_all(&snapshots_path).expect("Failed to create directory");

        // SAFETY: `ENV_MUTEX` is held for the duration of the test.
        unsafe {
            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
        }

        let result = HuggingFaceProvider
            .download_model("test/model", Some(temp_dir.path().into()), false)
            .await;

        // SAFETY: `ENV_MUTEX` is held for the duration of the test.
        unsafe {
            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
        }

        assert!(result.is_ok());
        assert!(result.expect("Expected path").ends_with("abc1234"));
    }

    #[tokio::test]
    async fn test_download_model_offline_mode_without_cache() {
        let _guard = lock_env();
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");

        // SAFETY: `ENV_MUTEX` is held for the duration of the test.
        unsafe {
            env::set_var(HF_HUB_OFFLINE_ENV_VAR, "1");
        }

        let result = HuggingFaceProvider
            .download_model("nonexistent/model", Some(temp_dir.path().into()), false)
            .await;

        // SAFETY: `ENV_MUTEX` is held for the duration of the test.
        unsafe {
            env::remove_var(HF_HUB_OFFLINE_ENV_VAR);
        }

        assert!(result.is_err());
        assert!(
            result
                .expect_err("Expected error")
                .to_string()
                .contains("not found in cache")
        );
    }
}