use crate::{
Utils,
cache::{ModelInfo, ProviderCache, directory_size},
constants,
models::ModelProvider,
providers::ModelProviderTrait,
};
use anyhow::{Context, Result};
use hf_hub::Cache;
use hf_hub::api::tokio::{ApiBuilder, ApiError};
use std::collections::HashSet;
use std::env;
use std::fs;
use std::path::{Path, PathBuf};
use tracing::{debug, info, warn};
/// Environment variable whose value (if set) is passed to the Hub API client as the access token.
const HF_TOKEN_ENV_VAR: &str = "HF_TOKEN";
/// Environment variable used by `get_cache_dir` as the second fallback for the cache location.
const HF_HUB_CACHE_ENV_VAR: &str = "HF_HUB_CACHE";
/// Environment variable used by `get_cache_dir` as the first fallback for the cache location.
const MODEL_EXPRESS_CACHE_ENV_VAR: &str = "MODEL_EXPRESS_CACHE_DIRECTORY";
/// Environment variable inspected by `is_offline_mode` to decide whether network access is skipped.
const HF_HUB_OFFLINE_ENV_VAR: &str = "HF_HUB_OFFLINE";
/// Reports whether offline mode is requested through the `HF_HUB_OFFLINE` environment variable.
///
/// The value is compared case-insensitively against `1`, `ON`, `YES` and `TRUE`.
/// An unset (or non-Unicode) variable, or any other value, yields `false`.
fn is_offline_mode() -> bool {
    // Accepted spellings, already upper-cased so one normalisation of the input suffices.
    const TRUTHY: [&str; 4] = ["1", "ON", "YES", "TRUE"];

    match env::var(HF_HUB_OFFLINE_ENV_VAR) {
        Ok(raw) => {
            let normalized = raw.to_uppercase();
            TRUTHY.contains(&normalized.as_str())
        }
        Err(_) => false,
    }
}
/// Picks the cache directory to use.
///
/// Precedence, highest first:
/// 1. the explicit `cache_dir` argument,
/// 2. the `MODEL_EXPRESS_CACHE_DIRECTORY` environment variable,
/// 3. the `HF_HUB_CACHE` environment variable,
/// 4. `constants::DEFAULT_HF_CACHE_PATH` joined onto the home directory
///    (or onto `"."` when the home directory cannot be determined).
fn get_cache_dir(cache_dir: Option<PathBuf>) -> PathBuf {
    cache_dir
        .or_else(|| {
            env::var(MODEL_EXPRESS_CACHE_ENV_VAR)
                .ok()
                .map(PathBuf::from)
        })
        .or_else(|| env::var(HF_HUB_CACHE_ENV_VAR).ok().map(PathBuf::from))
        .unwrap_or_else(|| {
            let home = Utils::get_home_dir().unwrap_or_else(|_| ".".to_string());
            PathBuf::from(home).join(constants::DEFAULT_HF_CACHE_PATH)
        })
}
/// Stateless model provider for Hugging Face; this file implements `ModelProviderTrait` for it.
pub struct HuggingFaceProvider;
/// Stateless helper for the on-disk cache layout `models--<id with '/' as '--'>/snapshots/<revision>`;
/// this file implements `ProviderCache` for it.
pub(crate) struct HuggingFaceProviderCache;
impl HuggingFaceProviderCache {
    /// Builds the repository folder for `model_name` under `cache_root`:
    /// `models--` followed by the model name with every `/` replaced by `--`.
    fn repo_root(cache_root: &Path, model_name: &str) -> PathBuf {
        let folder = format!("models--{}", model_name.replace('/', "--"));
        cache_root.join(folder)
    }

    /// Returns the `snapshots` directory inside the model's repository folder.
    fn snapshots_dir(cache_root: &Path, model_name: &str) -> PathBuf {
        let mut dir = Self::repo_root(cache_root, model_name);
        dir.push("snapshots");
        dir
    }

    /// Inverse of the folder naming used by `repo_root`: strips the `models--`
    /// prefix and turns each `--` back into `/`. Names without the prefix are
    /// returned unchanged.
    fn folder_name_to_model_id(folder_name: &str) -> String {
        match folder_name.strip_prefix("models--") {
            Some(encoded_id) => encoded_id.replace("--", "/"),
            None => folder_name.to_owned(),
        }
    }

    /// Returns the most recent entry of the model's `snapshots` directory.
    ///
    /// Recency is the creation time, falling back to the modification time and
    /// finally to the Unix epoch when no timestamp can be read. Entries that
    /// fail to be read from the directory are skipped.
    ///
    /// # Errors
    /// Fails when the snapshots directory does not exist, cannot be listed, or
    /// contains no readable entries.
    fn latest_local_snapshot_path(cache_root: &Path, model_name: &str) -> Result<PathBuf> {
        let snapshots = Self::snapshots_dir(cache_root, model_name);
        if !snapshots.exists() {
            anyhow::bail!("Model snapshots for '{model_name}' not found in cache");
        }

        // `max_by_key` returns the last of several equal maxima, which matches
        // "stable ascending sort, reverse, take the first".
        let newest = fs::read_dir(snapshots)?
            .filter_map(Result::ok)
            .max_by_key(|candidate| {
                candidate
                    .metadata()
                    .and_then(|meta| meta.created().or_else(|_| meta.modified()))
                    .unwrap_or(std::time::SystemTime::UNIX_EPOCH)
            });

        match newest {
            Some(entry) => Ok(entry.path()),
            None => anyhow::bail!("Model snapshots for '{model_name}' is empty"),
        }
    }
}
impl ProviderCache for HuggingFaceProviderCache {
    /// Removes the whole repository folder of `model_name` from `cache_root`.
    ///
    /// A missing folder is not an error; it is only logged as a warning.
    ///
    /// # Errors
    /// Fails when the folder exists but cannot be removed.
    fn clear_model(&self, cache_root: &Path, model_name: &str) -> Result<()> {
        let model_path = Self::repo_root(cache_root, model_name);

        if !model_path.exists() {
            warn!(
                "Model not found in cache: {} ({:?})",
                model_name,
                ModelProvider::HuggingFace
            );
            return Ok(());
        }

        fs::remove_dir_all(&model_path)
            .with_context(|| format!("Failed to remove model: {model_path:?}"))?;
        info!(
            "Cleared model: {} ({:?})",
            model_name,
            ModelProvider::HuggingFace
        );
        Ok(())
    }

    /// Resolves the snapshot directory for `model_name`.
    ///
    /// With a `revision`, the path `snapshots/<revision>` is returned without
    /// checking that it exists. Without one, the latest local snapshot is
    /// looked up on disk.
    ///
    /// # Errors
    /// Only the lookup without a revision can fail (no snapshots available).
    fn resolve_model_path(
        &self,
        cache_root: &Path,
        model_name: &str,
        revision: Option<&str>,
    ) -> Result<PathBuf> {
        if let Some(rev) = revision {
            return Ok(Self::snapshots_dir(cache_root, model_name).join(rev));
        }
        Self::latest_local_snapshot_path(cache_root, model_name)
    }

    /// Lists every `models--*` directory directly under `cache_root` together
    /// with its decoded model id and the size reported by `directory_size`.
    ///
    /// A non-existent `cache_root` yields an empty list.
    ///
    /// # Errors
    /// Fails when the directory cannot be listed, an entry cannot be read, or
    /// `directory_size` fails.
    fn list_models(&self, cache_root: &Path) -> Result<Vec<ModelInfo>> {
        if !cache_root.exists() {
            return Ok(Vec::new());
        }

        let mut found = Vec::new();
        for dir_entry in fs::read_dir(cache_root)? {
            let path = dir_entry?.path();
            if !path.is_dir() {
                continue;
            }

            // Only folders with a UTF-8 name following the `models--` convention count.
            let name = match path.file_name().and_then(|raw| raw.to_str()) {
                Some(folder) if folder.starts_with("models--") => {
                    Self::folder_name_to_model_id(folder)
                }
                _ => continue,
            };

            let size = directory_size(&path)?;
            found.push(ModelInfo {
                provider: ModelProvider::HuggingFace,
                name,
                size,
                path,
            });
        }
        Ok(found)
    }
}
impl HuggingFaceProvider {
    /// Returns `true` when `filename` consists of more than one path component,
    /// i.e. it names something below the repository's top level.
    fn is_subdirectory_file(filename: &str) -> bool {
        let mut parts = Path::new(filename).components();
        // A second component exists exactly when the component count exceeds one.
        parts.nth(1).is_some()
    }
}
#[async_trait::async_trait]
impl ModelProviderTrait for HuggingFaceProvider {
    /// Downloads the top-level files of `model_name` into the cache and returns
    /// the directory that contains them.
    ///
    /// The cache directory is chosen by `get_cache_dir` and created if missing.
    /// In offline mode no network request is made and the call is delegated to
    /// `get_model_path`. Otherwise every sibling reported by the Hub is fetched,
    /// except:
    /// * files located in sub-directories,
    /// * files rejected by `is_ignored` / `is_image`,
    /// * weight files (per `is_weight_file`) when `ignore_weights` is `true`,
    /// * files whose download fails with HTTP status 416 (logged and skipped).
    ///
    /// # Errors
    /// Fails when the cache directory cannot be created, the API client cannot
    /// be built, the model info cannot be fetched, the model has no files, a
    /// download fails for any reason other than status 416, or no file was
    /// downloaded at all.
    async fn download_model(
        &self,
        model_name: &str,
        cache_dir: Option<PathBuf>,
        ignore_weights: bool,
    ) -> Result<PathBuf> {
        let cache_dir = get_cache_dir(cache_dir);
        std::fs::create_dir_all(&cache_dir).map_err(|e| {
            anyhow::anyhow!("Failed to create cache directory {:?}: {}", cache_dir, e)
        })?;

        if is_offline_mode() {
            info!("HF_HUB_OFFLINE is set, using cached model for '{model_name}'");
            return self.get_model_path(model_name, cache_dir).await;
        }

        let token = env::var(HF_TOKEN_ENV_VAR).ok();
        info!("Using cache directory: {:?}", cache_dir);
        let api = ApiBuilder::from_env()
            .with_progress(true)
            .with_token(token)
            .high()
            .with_cache_dir(cache_dir)
            .build()?;

        let model_name = model_name.to_string();
        let repo = api.model(model_name.clone());
        let info = repo.info().await.map_err(
            |e| anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace. Is this a valid HuggingFace ID? Error: {e}"),
        )?;
        debug!("Got model info: {info:?}");

        if info.siblings.is_empty() {
            anyhow::bail!("Model '{model_name}' exists but contains no downloadable files.");
        }

        // Path of the most recently downloaded file; `None` until the first
        // successful download, which doubles as the "anything downloaded?" flag.
        let mut last_downloaded: Option<PathBuf> = None;

        for sib in info.siblings {
            if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
                continue;
            }
            if HuggingFaceProvider::is_ignored(&sib.rfilename)
                || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
            {
                continue;
            }
            if ignore_weights && HuggingFaceProvider::is_weight_file(&sib.rfilename) {
                continue;
            }

            match repo.get(&sib.rfilename).await {
                Ok(path) => {
                    last_downloaded = Some(path);
                }
                Err(e) => {
                    // Status 416 is treated as an empty remote file (see the
                    // log message below): skip it instead of failing the download.
                    if let ApiError::RequestError(req_err) = &e
                        && req_err.status().is_some_and(|s| s.as_u16() == 416)
                    {
                        warn!(
                            "Skipping empty file '{}' from model '{}': {}",
                            sib.rfilename, model_name, e
                        );
                        continue;
                    }
                    return Err(anyhow::anyhow!(
                        "Failed to download file '{sib}' from model '{model_name}': {e}",
                        sib = sib.rfilename,
                        model_name = model_name,
                        e = e
                    ));
                }
            }
        }

        let Some(last_downloaded) = last_downloaded else {
            return Err(anyhow::anyhow!(
                "No valid files found for model '{}'.",
                model_name
            ));
        };

        info!("Downloaded model files for {model_name}");
        // Only top-level files are downloaded, so the parent of any of them is
        // the directory holding all downloaded files.
        match last_downloaded.parent() {
            Some(snapshot_dir) => Ok(snapshot_dir.to_path_buf()),
            None => Err(anyhow::anyhow!(
                "Invalid HF cache path: {}",
                last_downloaded.display()
            )),
        }
    }

    /// Deletes the cached top-level files of `model_name` from `cache_dir`.
    ///
    /// The list of files is obtained from the Hub (`repo.info()`); each file
    /// that the local cache resolves is removed, and directories left empty by
    /// the removal are deleted as well (only when no removal failed). If the
    /// model info cannot be fetched the call logs the cause and returns `Ok(())`
    /// without touching the cache.
    ///
    /// # Errors
    /// Fails when the API client cannot be built or when at least one cached
    /// file could not be removed; the individual failures are attached as
    /// context to the returned error.
    async fn delete_model(&self, model_name: &str, cache_dir: PathBuf) -> Result<()> {
        info!("Deleting model from Hugging Face cache: {model_name}");
        let token = env::var(HF_TOKEN_ENV_VAR).ok();
        let api = ApiBuilder::from_env()
            .with_token(token)
            .with_cache_dir(cache_dir.clone())
            .build()
            .context("Failed to create Hugging Face API client")?;

        let model_name = model_name.to_string();
        let repo = api.model(model_name.clone());
        let cache_repo = Cache::new(cache_dir).model(model_name.clone());

        let info = match repo.info().await {
            Ok(info) => info,
            Err(e) => {
                // Deliberately best-effort: keep returning Ok, but surface the
                // cause so network/auth failures are not mistaken for "not found".
                warn!(
                    "Could not fetch file list for model '{model_name}' ({e}); \
                     treating it as not found or already deleted"
                );
                return Ok(());
            }
        };

        if info.siblings.is_empty() {
            info!("Model '{model_name}' has no files to delete");
            return Ok(());
        }

        let mut files_deleted: u32 = 0;
        let mut deletion_errors = Vec::new();
        // Parent directories of removed files; candidates for empty-dir cleanup.
        let mut model_dirs: HashSet<PathBuf> = HashSet::new();

        for sib in &info.siblings {
            // Same filters as `download_model`, except that weight files are
            // always considered here.
            if HuggingFaceProvider::is_subdirectory_file(&sib.rfilename) {
                continue;
            }
            if HuggingFaceProvider::is_ignored(&sib.rfilename)
                || HuggingFaceProvider::is_image(Path::new(&sib.rfilename))
            {
                continue;
            }

            // NOTE(review): only the path returned by `Cache::get` is removed;
            // confirm whether the cache keeps the file content elsewhere
            // (e.g. a blob store) that would need separate cleanup.
            if let Some(cached_path) = cache_repo.get(&sib.rfilename) {
                match std::fs::remove_file(&cached_path) {
                    Ok(_) => {
                        files_deleted = files_deleted.saturating_add(1);
                        info!("Deleted cached file: {}", cached_path.display());
                        if let Some(model_dir) = cached_path.parent() {
                            model_dirs.insert(model_dir.to_path_buf());
                        }
                    }
                    Err(e) => {
                        let error_msg =
                            format!("Failed to delete cached file '{}'", cached_path.display());
                        deletion_errors.push(anyhow::anyhow!(e).context(error_msg));
                    }
                }
            }
        }

        // Remove directories emptied by the deletions; failures here are only logged.
        if files_deleted > 0 && deletion_errors.is_empty() {
            for model_dir in model_dirs {
                if let Ok(mut entries) = std::fs::read_dir(&model_dir)
                    && entries.next().is_none()
                {
                    if let Err(e) = std::fs::remove_dir(&model_dir) {
                        info!("Could not remove empty model directory: {e}");
                    } else {
                        info!("Removed empty model directory: {}", model_dir.display());
                    }
                }
            }
        }

        if !deletion_errors.is_empty() {
            // Fold every individual failure into one error as numbered context.
            let mut compound_error =
                anyhow::anyhow!("Failed to delete some files for model '{model_name}'");
            for (i, error) in deletion_errors.into_iter().enumerate() {
                compound_error =
                    compound_error.context(format!("Error {}: {:#}", i.saturating_add(1), error));
            }
            return Err(compound_error);
        }

        if files_deleted == 0 {
            info!("No cached files found to delete for model '{model_name}'");
        } else {
            info!("Successfully deleted {files_deleted} cached files for model '{model_name}'");
        }
        Ok(())
    }

    /// Returns the local snapshot directory of `model_name` inside `cache_dir`.
    ///
    /// The latest local snapshot must exist. In offline mode it is returned
    /// directly. Otherwise the Hub is asked for the model's current commit
    /// (`info.sha`): if a snapshot directory with that name exists locally it is
    /// returned, else a warning is logged and the latest local snapshot is
    /// returned as a best effort.
    ///
    /// # Errors
    /// Fails when no local snapshot exists, or (when not offline) when the API
    /// client cannot be built or the model info cannot be fetched.
    async fn get_model_path(&self, model_name: &str, cache_dir: PathBuf) -> Result<PathBuf> {
        let latest_local_snapshot =
            HuggingFaceProviderCache.resolve_model_path(&cache_dir, model_name, None)?;
        if is_offline_mode() {
            return Ok(latest_local_snapshot);
        }

        let token = env::var(HF_TOKEN_ENV_VAR).ok();
        let api = ApiBuilder::from_env().with_token(token).build()?;
        let repo = api.model(model_name.to_string());
        let info = repo.info().await.map_err(|e| {
            anyhow::anyhow!("Failed to fetch model '{model_name}' from HuggingFace: {e}")
        })?;

        // With an explicit revision this only builds the path; existence is checked below.
        let latest_remote_snapshot =
            HuggingFaceProviderCache.resolve_model_path(&cache_dir, model_name, Some(&info.sha))?;
        if latest_remote_snapshot.exists() {
            return Ok(latest_remote_snapshot);
        }

        warn!(
            "Existing model snapshots do not match the latest commit hash '{0}'. \
            Returning the best-effort, latest local model snapshot.",
            info.sha
        );
        Ok(latest_local_snapshot)
    }

    /// Human-readable provider name.
    fn provider_name(&self) -> &'static str {
        "Hugging Face"
    }
}
#[cfg(test)]
#[allow(clippy::expect_used)]
mod tests {
    use super::*;
    use crate::test_support::{EnvVarGuard, acquire_env_mutex};
    use serde_json::json;
    use std::sync::MutexGuard;
    use tempfile::TempDir;
    use tokio::time::Duration;
    use wiremock::matchers::{method, path_regex};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Mock Hugging Face endpoint serving the model `test/model`, plus a cache
    /// path for the test to use. `HF_ENDPOINT` points at the mock server for as
    /// long as this value is alive.
    struct MockHFServer<'a> {
        _server: MockServer,
        pub cache_path: PathBuf,
        _hf_endpoint_guard: EnvVarGuard<'a>,
    }

    impl<'a> MockHFServer<'a> {
        /// Starts the mock server, mounts the model-info and file-download
        /// routes and sets `HF_ENDPOINT` to the server's URI.
        async fn new(env_lock: &'a MutexGuard<'static, ()>) -> Self {
            // NOTE(review): `temp_dir` goes out of scope when `new` returns, so
            // only its path is kept in `cache_path`; the tests (re-)create the
            // directory themselves and `Drop` below removes it. Confirm that
            // this is intended.
            let temp_dir = TempDir::new().expect("Failed to create temporary directory");
            let server = MockServer::start().await;

            // Model info: four top-level files and one file inside a sub-directory.
            Mock::given(method("GET"))
                .and(path_regex(r"^/api/models/test/model(?:/.*)?$"))
                .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                    "id": "test/model",
                    "sha": "def5678",
                    "siblings": [
                        {"rfilename": "config.json"},
                        {"rfilename": "model.safetensors"},
                        {"rfilename": "tokenizer.json"},
                        {"rfilename": "README.md"},
                        {"rfilename": "subdir/model.safetensors"}
                    ]
                })))
                .mount(&server)
                .await;

            // File download: a 64-byte ranged response for each top-level file.
            Mock::given(method("GET"))
                .and(path_regex(r"^/test/model/resolve/(main|[^/]+)/(?:config\.json|tokenizer\.json|README\.md|model\.safetensors)$"))
                .respond_with(
                    ResponseTemplate::new(206)
                        .insert_header("etag", "\"def5678\"")
                        .insert_header("x-repo-commit", "def5678")
                        .insert_header("accept-ranges", "bytes")
                        .insert_header("content-length", "64")
                        .insert_header("content-range", "bytes 0-63/64")
                        .set_body_bytes(vec![0u8; 64]),
                )
                .mount(&server)
                .await;

            let hf_endpoint_guard = EnvVarGuard::set(env_lock, "HF_ENDPOINT", &server.uri());
            Self {
                _server: server,
                cache_path: temp_dir.path().to_path_buf(),
                _hf_endpoint_guard: hf_endpoint_guard,
            }
        }
    }

    impl Drop for MockHFServer<'_> {
        /// Best-effort removal of the cache directory; failures are only logged.
        fn drop(&mut self) {
            std::fs::remove_dir_all(&self.cache_path).unwrap_or_else(|e| {
                warn!("Failed to remove temporary cache path: {e}");
            });
        }
    }

    /// The provider reports its display name.
    #[test]
    fn test_hugging_face_provider_name() {
        let provider = HuggingFaceProvider;
        assert_eq!(provider.provider_name(), "Hugging Face");
    }

    /// The provider is usable through a `dyn ModelProviderTrait` object.
    #[test]
    fn test_provider_trait_object() {
        let provider: Box<dyn ModelProviderTrait> = Box::new(HuggingFaceProvider);
        assert_eq!(provider.provider_name(), "Hugging Face");
    }

    /// Deleting a model that does not exist succeeds.
    #[tokio::test]
    async fn test_delete_model_trait() {
        let provider = HuggingFaceProvider;
        let cache_dir = TempDir::new().expect("Failed to create temporary cache directory");
        let result = provider
            .delete_model("nonexistent/model", cache_dir.path().to_path_buf())
            .await;
        assert!(result.is_ok());
    }

    /// `delete_model` acts on the cache directory it is given, even when the
    /// cache environment variables point somewhere else.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_delete_model_prefers_explicit_cache_dir_over_env() {
        let env_lock = acquire_env_mutex();
        let mock_server = MockHFServer::new(&env_lock).await;
        let provider = HuggingFaceProvider;
        let explicit_cache = TempDir::new().expect("Failed to create explicit cache directory");
        let env_cache = TempDir::new().expect("Failed to create env cache directory");

        // Seed both caches with the same model.
        let explicit_snapshot = provider
            .download_model(
                "test/model",
                Some(explicit_cache.path().to_path_buf()),
                false,
            )
            .await
            .expect("Failed to seed explicit cache");
        let explicit_config = explicit_snapshot.join("config.json");
        assert!(
            explicit_config.exists(),
            "Expected explicit cache to contain model file before delete"
        );
        let env_snapshot = provider
            .download_model("test/model", Some(env_cache.path().to_path_buf()), false)
            .await
            .expect("Failed to seed env cache");
        let env_config = env_snapshot.join("config.json");
        assert!(
            env_config.exists(),
            "Expected env cache to contain model file before delete"
        );

        // Point both cache environment variables at the "env" cache.
        let env_cache_path = env_cache.path().to_str().expect("Expected env cache path");
        let _model_express_cache_guard =
            EnvVarGuard::set(&env_lock, MODEL_EXPRESS_CACHE_ENV_VAR, env_cache_path);
        let _hf_hub_cache_guard = EnvVarGuard::set(&env_lock, HF_HUB_CACHE_ENV_VAR, env_cache_path);

        let delete_result = provider
            .delete_model("test/model", explicit_cache.path().to_path_buf())
            .await;
        assert!(
            delete_result.is_ok(),
            "Delete request should succeed: {delete_result:?}"
        );
        assert!(
            !explicit_config.exists(),
            "Expected explicit cache file to be deleted when explicit cache dir is provided"
        );
        assert!(
            env_config.exists(),
            "Expected env cache file to remain untouched when explicit cache dir is provided"
        );
        drop(mock_server);
    }

    /// `delete_model` queries the model info once and never hits the download
    /// endpoint; the expectations are declared on the mocks via `expect(..)`.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_delete_model_does_not_download_uncached_files() {
        let env_lock = acquire_env_mutex();
        let temp_cache = TempDir::new().expect("Failed to create temporary cache directory");
        let server = MockServer::start().await;
        let provider = HuggingFaceProvider;
        let model_name = "modelexpress-tests/delete-no-download";
        Mock::given(method("GET"))
            .and(path_regex(
                r"^/api/models/modelexpress-tests/delete-no-download(?:/.*)?$",
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": model_name,
                "sha": "def5678",
                "siblings": [
                    {"rfilename": "config.json"}
                ]
            })))
            .expect(1)
            .named("delete_model should query repo info exactly once")
            .mount(&server)
            .await;
        Mock::given(method("GET"))
            .and(path_regex(
                r"^/modelexpress-tests/delete-no-download/resolve/(main|[^/]+)/config\.json$",
            ))
            .respond_with(
                ResponseTemplate::new(206)
                    .insert_header("etag", "\"def5678\"")
                    .insert_header("x-repo-commit", "def5678")
                    .insert_header("accept-ranges", "bytes")
                    .insert_header("content-length", "64")
                    .insert_header("content-range", "bytes 0-63/64")
                    .set_body_bytes(vec![0u8; 64]),
            )
            .expect(0)
            .named("delete_model should not call Hugging Face resolve endpoint")
            .mount(&server)
            .await;
        let temp_cache_path = temp_cache.path().to_str().expect("Expected cache path");
        let _model_express_cache_guard =
            EnvVarGuard::set(&env_lock, MODEL_EXPRESS_CACHE_ENV_VAR, temp_cache_path);
        let _hf_endpoint_guard = EnvVarGuard::set(&env_lock, "HF_ENDPOINT", &server.uri());
        let result = provider
            .delete_model(model_name, temp_cache.path().to_path_buf())
            .await;
        assert!(result.is_ok(), "Delete should succeed when cache is empty");
    }

    /// `get_model_path` returns the snapshot named after the commit hash the
    /// mock server reports (`def5678`).
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_get_model_path_trait() {
        let env_lock = acquire_env_mutex();
        let mock_server = MockHFServer::new(&env_lock).await;
        let path = mock_server
            .cache_path
            .join("models--test--model")
            .join("snapshots");
        std::fs::create_dir_all(path.join("abc1234")).expect("Failed to create directory");
        // Give the two snapshot directories distinct timestamps.
        tokio::time::sleep(Duration::from_secs(1)).await;
        std::fs::create_dir_all(path.join("def5678")).expect("Failed to create directory");
        let provider = HuggingFaceProvider;
        let result = provider
            .get_model_path("test/model", mock_server.cache_path.clone())
            .await;
        assert!(result.is_ok());
        assert_eq!(
            result.expect("Failed to get model path"),
            path.join("def5678")
        );
    }

    /// With `ignore_weights = true` no `.safetensors` file ends up in the
    /// snapshot directory, while non-weight files are still downloaded.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_download_ignore_weights() {
        let env_lock = acquire_env_mutex();
        let mock_server = MockHFServer::new(&env_lock).await;
        let provider = HuggingFaceProvider;
        // The flag under test must be `true`; with `false` the weights are downloaded.
        let snapshot_dir = provider
            .download_model("test/model", Some(mock_server.cache_path.clone()), true)
            .await
            .expect("Failed to download model");
        // Guards against the loop below passing on an empty directory.
        assert!(
            snapshot_dir.join("config.json").exists(),
            "Expected non-weight files to be downloaded"
        );
        let files = fs::read_dir(snapshot_dir)
            .expect("Failed to read directory")
            .filter_map(Result::ok);
        for file in files {
            info!("File: {}", file.path().display());
            // `Path::ends_with` compares whole components and would never match
            // `model.safetensors`, so compare the extension instead.
            assert!(
                !file
                    .path()
                    .extension()
                    .is_some_and(|ext| ext == "safetensors"),
                "Expected weight file to be skipped: {}",
                file.path().display()
            );
        }
    }

    /// Files located in sub-directories of the repository are not downloaded.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_download_ignores_subdirectories() {
        let env_lock = acquire_env_mutex();
        let mock_server = MockHFServer::new(&env_lock).await;
        let provider = HuggingFaceProvider;
        let result = provider
            .download_model("test/model", Some(mock_server.cache_path.clone()), false)
            .await
            .expect("Failed to download model");
        assert!(
            !result.join("subdir").exists(),
            "Expected files located in sub-directories to be ignored"
        );
    }

    /// The download succeeds although the mock server offers no download route
    /// for the dotfiles listed in the model info.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_download_ignores_dotfiles() {
        let env_lock = acquire_env_mutex();
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path_regex(r"^/api/models/test/model(?:/.*)?$"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "test/model",
                "sha": "def5678",
                "siblings": [
                    {"rfilename": "config.json"},
                    {"rfilename": ".gitkeep"},
                    {"rfilename": ".gitignore"},
                    {"rfilename": ".hidden"}
                ]
            })))
            .mount(&server)
            .await;
        // Only `config.json` is downloadable.
        Mock::given(method("GET"))
            .and(path_regex(
                r"^/test/model/resolve/(main|[^/]+)/config\.json$",
            ))
            .respond_with(
                ResponseTemplate::new(206)
                    .insert_header("etag", "\"def5678\"")
                    .insert_header("x-repo-commit", "def5678")
                    .insert_header("accept-ranges", "bytes")
                    .insert_header("content-length", "64")
                    .insert_header("content-range", "bytes 0-63/64")
                    .set_body_bytes(vec![0u8; 64]),
            )
            .mount(&server)
            .await;
        let _hf_endpoint_guard = EnvVarGuard::set(&env_lock, "HF_ENDPOINT", &server.uri());
        let provider = HuggingFaceProvider;
        let result = provider
            .download_model("test/model", Some(temp_dir.path().to_path_buf()), false)
            .await;
        assert!(
            result.is_ok(),
            "Download should succeed with dotfiles ignored"
        );
    }

    /// `is_offline_mode` is true for `1`, false for `0` and false when unset.
    #[test]
    fn test_is_offline_mode() {
        let env_lock = acquire_env_mutex();
        {
            let _offline_guard = EnvVarGuard::set(&env_lock, HF_HUB_OFFLINE_ENV_VAR, "1");
            assert!(is_offline_mode());
        }
        {
            let _offline_guard = EnvVarGuard::set(&env_lock, HF_HUB_OFFLINE_ENV_VAR, "0");
            assert!(!is_offline_mode());
        }
        {
            let _offline_guard = EnvVarGuard::remove(&env_lock, HF_HUB_OFFLINE_ENV_VAR);
            assert!(!is_offline_mode());
        }
    }

    /// In offline mode `download_model` returns the snapshot already in the cache.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_download_model_offline_mode_with_cache() {
        let env_lock = acquire_env_mutex();
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let snapshots_path = temp_dir
            .path()
            .join("models--test--model")
            .join("snapshots")
            .join("abc1234");
        std::fs::create_dir_all(&snapshots_path).expect("Failed to create directory");
        let _offline_guard = EnvVarGuard::set(&env_lock, HF_HUB_OFFLINE_ENV_VAR, "1");
        let result = HuggingFaceProvider
            .download_model("test/model", Some(temp_dir.path().into()), false)
            .await;
        assert!(result.is_ok());
        assert!(result.expect("Expected path").ends_with("abc1234"));
    }

    /// In offline mode `download_model` fails when the model is not cached.
    #[tokio::test]
    #[allow(clippy::await_holding_lock)]
    async fn test_download_model_offline_mode_without_cache() {
        let env_lock = acquire_env_mutex();
        let temp_dir = TempDir::new().expect("Failed to create temporary directory");
        let _offline_guard = EnvVarGuard::set(&env_lock, HF_HUB_OFFLINE_ENV_VAR, "1");
        let result = HuggingFaceProvider
            .download_model("nonexistent/model", Some(temp_dir.path().into()), false)
            .await;
        assert!(result.is_err());
        assert!(
            result
                .expect_err("Expected error")
                .to_string()
                .contains("not found in cache")
        );
    }
}