use std::env;
use std::path::{Path, PathBuf};
use hf_hub::Cache;
use modelexpress_client::{
Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
use modelexpress_common::download as mx;
use dynamo_runtime::config::environment_names::model as env_model;
/// Returns the snapshot directory of `model_name` in the local Hugging Face cache
/// when the cached copy is complete enough to be used without downloading.
///
/// The cache root comes from [`get_model_express_cache_dir`]. A cached copy qualifies when:
/// - `config.json` is present,
/// - a tokenizer is present (`tokenizer.json`, `tiktoken.model`, or any `*.tiktoken`
///   file next to `config.json`), and
/// - unless `ignore_weights` is `true`, weights are present: either a single-file
///   checkpoint (`model.safetensors` / `pytorch_model.bin`) or a shard index whose
///   referenced shard files all exist.
///
/// Returns `None` as soon as any requirement is not met.
fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathBuf> {
    const TOKENIZER_FILES: [&str; 2] = ["tokenizer.json", "tiktoken.model"];
    const SINGLE_FILE_WEIGHTS: [&str; 2] = ["model.safetensors", "pytorch_model.bin"];
    const SHARD_INDEXES: [&str; 2] = [
        "model.safetensors.index.json",
        "pytorch_model.bin.index.json",
    ];

    let cache = Cache::new(get_model_express_cache_dir());
    let repo = cache.model(model_name.to_string());

    // The directory holding config.json is the snapshot directory we hand back.
    let config_path = repo.get("config.json")?;
    let snapshot_dir = config_path.parent()?;

    let tokenizer_found = TOKENIZER_FILES
        .into_iter()
        .any(|file| repo.get(file).is_some())
        || has_tiktoken_file(snapshot_dir);
    if !tokenizer_found {
        return None;
    }

    if !ignore_weights {
        let weights_found = SINGLE_FILE_WEIGHTS
            .into_iter()
            .any(|file| repo.get(file).is_some())
            // An index file alone is not enough: every shard it names must exist.
            || SHARD_INDEXES.into_iter().any(|index| {
                repo.get(index)
                    .is_some_and(|index_path| shard_files_present(&index_path))
            });
        if !weights_found {
            return None;
        }
    }

    let snapshot_path = snapshot_dir.to_path_buf();
    tracing::info!("Found cached model '{model_name}' at {snapshot_path:?}, skipping download");
    Some(snapshot_path)
}
/// Returns `true` when `dir` directly contains at least one entry whose extension
/// is `tiktoken`.
///
/// An unreadable or missing directory yields `false`; entries that fail to be read
/// are skipped.
fn has_tiktoken_file(dir: &Path) -> bool {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    for entry in entries {
        let Ok(entry) = entry else {
            continue;
        };
        if entry
            .path()
            .extension()
            .is_some_and(|ext| ext == "tiktoken")
        {
            return true;
        }
    }
    false
}
/// Checks that every shard file named by a checkpoint index exists next to the index.
///
/// `index_path` is parsed as JSON; the string values of its `weight_map` object are
/// taken as shard file names relative to the index's directory. Returns `false` when
/// the index cannot be read or parsed, has no `weight_map` object, names no shards,
/// or when any named shard is missing.
fn shard_files_present(index_path: &Path) -> bool {
    let Some(snapshot_dir) = index_path.parent() else {
        return false;
    };

    let Some(index) = std::fs::read_to_string(index_path)
        .ok()
        .and_then(|raw| serde_json::from_str::<serde_json::Value>(&raw).ok())
    else {
        return false;
    };

    let Some(weight_map) = index.get("weight_map").and_then(|map| map.as_object()) else {
        return false;
    };

    // Many tensors map to the same shard; de-duplicate so each file is checked once.
    let shard_names: std::collections::HashSet<&str> = weight_map
        .values()
        .filter_map(|entry| entry.as_str())
        .collect();

    !shard_names.is_empty()
        && shard_names
            .iter()
            .all(|name| snapshot_dir.join(name).exists())
}
/// Reports whether Hugging Face offline mode is requested through `HF_HUB_OFFLINE`.
///
/// The value is interpreted the way `huggingface_hub` documents its boolean
/// environment variables: `1`, `ON`, `YES` and `TRUE` (case-insensitive) mean true;
/// anything else, or an unset / non-Unicode variable, means false.
fn is_offline_mode() -> bool {
    const TRUTHY: [&str; 4] = ["1", "on", "yes", "true"];
    env::var(env_model::huggingface::HF_HUB_OFFLINE).is_ok_and(|value| {
        TRUTHY
            .iter()
            .any(|truthy| value.eq_ignore_ascii_case(truthy))
    })
}
/// Reports whether `MODEL_EXPRESS_NO_SHARED_STORAGE` is set to `1` or to `true`
/// (any letter case). Unset or unreadable values count as `false`.
fn is_no_shared_storage() -> bool {
    match env::var(env_model::model_express::MODEL_EXPRESS_NO_SHARED_STORAGE) {
        Ok(flag) => flag == "1" || flag.to_lowercase() == "true",
        Err(_) => false,
    }
}
/// Resolves a Hugging Face model to a local directory, downloading it if necessary.
///
/// Resolution order:
/// 1. A usable copy in the local cache (see `get_cached_model_path`) is returned
///    immediately.
/// 2. Otherwise a ModelExpress client is created (endpoint taken from
///    `MODEL_EXPRESS_URL` when set) and asked to fetch the model; the resulting
///    local path is then requested from the client.
/// 3. If the client cannot be created, the server request fails, or the local path
///    cannot be resolved afterwards, the model is downloaded directly.
///
/// When offline mode is enabled and the model is not cached, a warning is logged and
/// the download is still attempted.
///
/// `ignore_weights` is forwarded to the cache check and to every download path.
///
/// # Errors
///
/// Returns the error of the direct download when that fallback is used and fails.
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();

    if let Some(cached_path) = get_cached_model_path(&model_name, ignore_weights) {
        return Ok(cached_path);
    }

    if is_offline_mode() {
        tracing::warn!(
            "Offline mode enabled but model '{model_name}' not found in cache, attempting download anyway"
        );
    }

    let mut config: MxClientConfig = MxClientConfig::default();
    if let Ok(endpoint) = env::var(env_model::model_express::MODEL_EXPRESS_URL) {
        config = config.with_endpoint(endpoint);
    }
    if is_no_shared_storage() {
        config.cache.shared_storage = false;
    }

    // Each failure point leaves the block early with the direct-download result.
    let outcome: anyhow::Result<PathBuf> = 'download: {
        let mut client = match MxClient::new(config).await {
            Ok(client) => client,
            Err(e) => {
                tracing::warn!(
                    "Cannot connect to ModelExpress server: {e}. Using direct download."
                );
                break 'download mx_download_direct(&model_name, ignore_weights).await;
            }
        };
        tracing::info!("Successfully connected to ModelExpress server");

        if let Err(e) = client
            .request_model_with_provider_and_fallback(
                &model_name,
                MxModelProvider::HuggingFace,
                ignore_weights,
            )
            .await
        {
            tracing::warn!(
                "Server download failed for model '{model_name}': {e}. Falling back to direct download."
            );
            break 'download mx_download_direct(&model_name, ignore_weights).await;
        }
        tracing::info!("Server download succeeded for model: {model_name}");

        match client
            .get_model_path(&model_name, MxModelProvider::HuggingFace)
            .await
        {
            Ok(path) => Ok(path),
            Err(e) => {
                tracing::warn!(
                    "Failed to resolve local model path after server download for '{model_name}': {e}. \
                     Falling back to direct download."
                );
                mx_download_direct(&model_name, ignore_weights).await
            }
        }
    };

    match &outcome {
        Ok(_) => {
            tracing::info!("ModelExpress download completed successfully for model: {model_name}");
        }
        Err(e) => {
            tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
        }
    }
    outcome
}
/// Downloads `model_name` from Hugging Face through `mx::download_model`, without
/// going through a ModelExpress client, into the directory returned by
/// `get_model_express_cache_dir`.
///
/// `ignore_weights` is passed through unchanged.
async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let provider = MxModelProvider::HuggingFace;
    let cache_dir = Some(get_model_express_cache_dir());
    mx::download_model(model_name, provider, cache_dir, ignore_weights).await
}
/// Determines the directory used as the Hugging Face hub cache.
///
/// The first variable that is set wins:
/// 1. `HF_HUB_CACHE` — used as is.
/// 2. `HF_HOME` — with `hub` appended.
/// 3. `MODEL_EXPRESS_CACHE_PATH` — used as is.
/// 4. Otherwise `<home>/.cache/huggingface/hub`, where `<home>` is `HOME`, then
///    `USERPROFILE`, then `.` when neither is set.
///
/// Values are read with `env::var_os`, so a path that is not valid Unicode is
/// honoured instead of being treated as unset.
fn get_model_express_cache_dir() -> PathBuf {
    if let Some(hub_cache) = env::var_os(env_model::huggingface::HF_HUB_CACHE) {
        return PathBuf::from(hub_cache);
    }
    if let Some(hf_home) = env::var_os(env_model::huggingface::HF_HOME) {
        return PathBuf::from(hf_home).join("hub");
    }
    if let Some(mx_cache) = env::var_os(env_model::model_express::MODEL_EXPRESS_CACHE_PATH) {
        return PathBuf::from(mx_cache);
    }

    let home = env::var_os("HOME")
        .or_else(|| env::var_os("USERPROFILE"))
        .map_or_else(|| PathBuf::from("."), |dir| PathBuf::from(dir));
    home.join(".cache/huggingface/hub")
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Smoke test: `from_hf` must run to completion for an arbitrary model name;
    /// the outcome itself is not asserted.
    ///
    /// Serialized because `from_hf` reads the cache-related environment variables
    /// that other tests in this module mutate.
    #[serial_test::serial]
    #[tokio::test]
    async fn test_from_hf_with_model_express() {
        let test_path = PathBuf::from("test-model");
        let _result: anyhow::Result<PathBuf> = from_hf(test_path, false).await;
    }

    /// The resolved cache directory is never empty and is absolute or `.`-relative.
    ///
    /// Serialized because it reads environment variables that other tests mutate.
    #[serial_test::serial]
    #[test]
    fn test_get_model_express_cache_dir() {
        let cache_dir = get_model_express_cache_dir();
        assert!(!cache_dir.to_string_lossy().is_empty());
        assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
    }

    /// With only `HF_HOME` set, the cache directory is `<HF_HOME>/hub`.
    #[serial_test::serial]
    #[test]
    fn test_get_model_express_cache_dir_with_hf_home() {
        // Restores every variable touched below when the test ends, even on panic.
        let _guard = EnvGuard::capture();
        // SAFETY: env-touching tests in this module are `#[serial]`, so they do not
        // access the environment concurrently with each other.
        unsafe {
            env::remove_var(env_model::huggingface::HF_HUB_CACHE);
            env::remove_var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH);
            env::set_var(env_model::huggingface::HF_HOME, "/custom/cache/path");
        }
        let cache_dir = get_model_express_cache_dir();
        assert_eq!(cache_dir, PathBuf::from("/custom/cache/path/hub"));
    }

    /// Builds a minimal hub-cache layout for `model_name` under `cache_root`:
    /// `models--<name with '/' replaced by '--'>/snapshots/<hash>/` containing each
    /// of `files` (content `{}`), plus `refs/main` holding the snapshot hash.
    /// Returns the snapshot directory.
    fn build_hf_cache(cache_root: &Path, model_name: &str, files: &[&str]) -> PathBuf {
        let repo_dir = cache_root.join(format!("models--{}", model_name.replace('/', "--")));
        let snapshot_hash = "0000000000000000000000000000000000000000";
        let snapshot_dir = repo_dir.join("snapshots").join(snapshot_hash);
        let refs_dir = repo_dir.join("refs");
        fs::create_dir_all(&snapshot_dir).unwrap();
        fs::create_dir_all(&refs_dir).unwrap();
        fs::write(refs_dir.join("main"), snapshot_hash).unwrap();
        for f in files {
            fs::write(snapshot_dir.join(f), "{}").unwrap();
        }
        snapshot_dir
    }

    /// Snapshot of the cache-related environment variables, restored on drop.
    struct EnvGuard {
        hub_cache: Option<String>,
        hub_offline: Option<String>,
        hf_home: Option<String>,
        mx_cache_path: Option<String>,
    }

    impl EnvGuard {
        /// Records the current values without changing the environment.
        fn capture() -> Self {
            Self {
                hub_cache: env::var(env_model::huggingface::HF_HUB_CACHE).ok(),
                hub_offline: env::var(env_model::huggingface::HF_HUB_OFFLINE).ok(),
                hf_home: env::var(env_model::huggingface::HF_HOME).ok(),
                mx_cache_path: env::var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH).ok(),
            }
        }

        /// Records the current values, then points `HF_HUB_CACHE` at `path` and
        /// clears the other cache-related variables and the offline flag.
        fn with_hub_cache(path: &Path) -> Self {
            let guard = Self::capture();
            // SAFETY: env-touching tests in this module are `#[serial]`, so they do
            // not access the environment concurrently with each other.
            unsafe {
                // Pass the path itself so non-UTF-8 temp directories do not panic.
                env::set_var(env_model::huggingface::HF_HUB_CACHE, path);
                env::remove_var(env_model::huggingface::HF_HOME);
                env::remove_var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH);
                env::remove_var(env_model::huggingface::HF_HUB_OFFLINE);
            }
            guard
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            // SAFETY: guards are only created inside `#[serial]` tests, so the
            // restore does not overlap with another env-touching test of this module.
            unsafe {
                restore(env_model::huggingface::HF_HUB_CACHE, &self.hub_cache);
                restore(env_model::huggingface::HF_HUB_OFFLINE, &self.hub_offline);
                restore(env_model::huggingface::HF_HOME, &self.hf_home);
                restore(
                    env_model::model_express::MODEL_EXPRESS_CACHE_PATH,
                    &self.mx_cache_path,
                );
            }
        }
    }

    /// Sets `key` back to `value`, or removes it when `value` is `None`.
    ///
    /// # Safety
    ///
    /// Mutates the process environment; the caller must ensure no other thread is
    /// reading or writing the environment at the same time.
    unsafe fn restore(key: &str, value: &Option<String>) {
        // SAFETY: upheld by the caller, see the function's safety contract.
        unsafe {
            match value {
                Some(v) => env::set_var(key, v),
                None => env::remove_var(key),
            }
        }
    }

    /// Config + tokenizer without weights is a hit only when weights are ignored.
    #[serial_test::serial]
    #[test]
    fn test_cached_path_metadata_only_satisfies_ignore_weights_true() {
        let temp = TempDir::new().unwrap();
        let model = "test-org/metadata-only";
        let snapshot = build_hf_cache(temp.path(), model, &["config.json", "tokenizer.json"]);
        let _guard = EnvGuard::with_hub_cache(temp.path());
        let with_weights = get_cached_model_path(model, false);
        let no_weights = get_cached_model_path(model, true);
        assert!(
            with_weights.is_none(),
            "metadata-only cache must NOT satisfy ignore_weights=false"
        );
        assert_eq!(
            no_weights.as_deref(),
            Some(snapshot.as_path()),
            "metadata-only cache must satisfy ignore_weights=true"
        );
    }

    /// Config + tokenizer + single-file weights is a hit in both modes.
    #[serial_test::serial]
    #[test]
    fn test_cached_path_full_cache_satisfies_both_modes() {
        let temp = TempDir::new().unwrap();
        let model = "test-org/full-cache";
        let snapshot = build_hf_cache(
            temp.path(),
            model,
            &["config.json", "tokenizer.json", "model.safetensors"],
        );
        let _guard = EnvGuard::with_hub_cache(temp.path());
        let with_weights = get_cached_model_path(model, false);
        let no_weights = get_cached_model_path(model, true);
        assert_eq!(with_weights.as_deref(), Some(snapshot.as_path()));
        assert_eq!(no_weights.as_deref(), Some(snapshot.as_path()));
    }

    /// A shard index counts as weights only once every referenced shard exists.
    #[serial_test::serial]
    #[test]
    fn test_cached_path_sharded_requires_all_shard_files() {
        let temp = TempDir::new().unwrap();
        let model = "test-org/sharded";
        let snapshot = build_hf_cache(temp.path(), model, &["config.json", "tokenizer.json"]);
        fs::write(
            snapshot.join("model.safetensors.index.json"),
            r#"{"weight_map": {"a.weight": "model-00001-of-00002.safetensors", "b.weight": "model-00002-of-00002.safetensors"}}"#,
        )
        .unwrap();
        let _guard = EnvGuard::with_hub_cache(temp.path());
        let incomplete = get_cached_model_path(model, false);
        assert!(
            incomplete.is_none(),
            "sharded cache without shard files must NOT satisfy ignore_weights=false"
        );
        fs::write(snapshot.join("model-00001-of-00002.safetensors"), "").unwrap();
        fs::write(snapshot.join("model-00002-of-00002.safetensors"), "").unwrap();
        let complete = get_cached_model_path(model, false);
        assert_eq!(complete.as_deref(), Some(snapshot.as_path()));
    }

    /// `tokenizer_config.json` is not accepted as a tokenizer in either mode.
    #[serial_test::serial]
    #[test]
    fn test_cached_path_rejects_tokenizer_config_without_real_tokenizer() {
        let temp = TempDir::new().unwrap();
        let model = "test-org/tokenizer-config-only";
        build_hf_cache(
            temp.path(),
            model,
            &["config.json", "tokenizer_config.json"],
        );
        let _guard = EnvGuard::with_hub_cache(temp.path());
        assert!(
            get_cached_model_path(model, true).is_none(),
            "tokenizer_config.json alone must NOT satisfy ignore_weights=true",
        );
        assert!(
            get_cached_model_path(model, false).is_none(),
            "tokenizer_config.json alone must NOT satisfy ignore_weights=false",
        );
    }

    /// With the offline flag cleared, a complete cache entry is returned by
    /// `from_hf` directly.
    #[serial_test::serial]
    #[tokio::test]
    async fn test_from_hf_cache_first_in_online_mode() {
        let temp = TempDir::new().unwrap();
        let model = "test-org/cache-first-online";
        let snapshot = build_hf_cache(
            temp.path(),
            model,
            &["config.json", "tokenizer.json", "model.safetensors"],
        );
        let _guard = EnvGuard::with_hub_cache(temp.path());
        let result = from_hf(PathBuf::from(model), false).await;
        assert_eq!(
            result.ok().as_deref(),
            Some(snapshot.as_path()),
            "from_hf must return cached path in online mode without network"
        );
    }
}