use std::env;
use std::path::{Path, PathBuf};
use hf_hub::Cache;
use modelexpress_client::{
Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
};
use modelexpress_common::download as mx;
use dynamo_runtime::config::environment_names::model as env_model;
/// Return the local snapshot directory for `model_name` if the cache
/// already contains a complete copy (config + tokenizer, and weights
/// unless `ignore_weights` is set); `None` otherwise.
fn get_cached_model_path(model_name: &str, ignore_weights: bool) -> Option<PathBuf> {
    let cache = Cache::new(get_model_express_cache_dir());
    let repo = cache.model(model_name.to_string());

    // A usable snapshot must at least contain the model config.
    let config_path = repo.get("config.json")?;
    let snapshot_dir = config_path.parent()?;

    // Some tokenizer artifact must sit next to the config; loose
    // `*.tiktoken` files are found by scanning the snapshot directory.
    let has_tokenizer = ["tokenizer.json", "tokenizer_config.json", "tiktoken.model"]
        .into_iter()
        .any(|file| repo.get(file).is_some())
        || has_tiktoken_file(snapshot_dir);
    if !has_tokenizer {
        return None;
    }

    // Unless the caller only needs metadata, weights must be cached too
    // (single-file or sharded; safetensors or pytorch format).
    if !ignore_weights {
        let has_weights = [
            "model.safetensors",
            "pytorch_model.bin",
            "model.safetensors.index.json",
            "pytorch_model.bin.index.json",
        ]
        .into_iter()
        .any(|file| repo.get(file).is_some());
        if !has_weights {
            return None;
        }
    }

    let snapshot_path = snapshot_dir.to_path_buf();
    tracing::info!("Found cached model '{model_name}' at {snapshot_path:?}, skipping download");
    Some(snapshot_path)
}
/// True when `dir` contains at least one entry with a `.tiktoken`
/// extension. An unreadable or missing directory counts as "no".
fn has_tiktoken_file(dir: &Path) -> bool {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return false;
    };
    entries
        .filter_map(Result::ok)
        .any(|entry| entry.path().extension().is_some_and(|ext| ext == "tiktoken"))
}
/// Whether HF offline mode is requested via the environment:
/// the variable set to "1" or any casing of "true" means offline.
fn is_offline_mode() -> bool {
    match env::var(env_model::huggingface::HF_HUB_OFFLINE) {
        Ok(value) => value == "1" || value.to_lowercase() == "true",
        Err(_) => false,
    }
}
/// Fetch (or locate) a HuggingFace model and return its local path.
///
/// Resolution order:
/// 1. In offline mode, serve straight from the local cache when the
///    cached copy is complete.
/// 2. Otherwise ask the ModelExpress server to download the model.
/// 3. On any server failure, fall back to a direct download.
pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let model_name = name.as_ref().display().to_string();

    if is_offline_mode() {
        if let Some(cached_path) = get_cached_model_path(&model_name, ignore_weights) {
            tracing::info!(
                "Offline mode: using cached model '{model_name}' without API validation"
            );
            return Ok(cached_path);
        }
        tracing::warn!(
            "Offline mode enabled but model '{model_name}' not found in cache, attempting download anyway"
        );
    }

    // Honor an explicit server endpoint override from the environment.
    let mut config = MxClientConfig::default();
    if let Ok(endpoint) = env::var(env_model::model_express::MODEL_EXPRESS_URL) {
        config = config.with_endpoint(endpoint);
    }

    download_via_server_or_direct(config, &model_name, ignore_weights)
        .await
        .inspect(|_| {
            tracing::info!(
                "ModelExpress download completed successfully for model: {model_name}"
            );
        })
        .inspect_err(|e| {
            tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
        })
}

/// Try the ModelExpress server first; on any failure (connect, download,
/// or path resolution) fall back to a direct download into the cache.
async fn download_via_server_or_direct(
    config: MxClientConfig,
    model_name: &str,
    ignore_weights: bool,
) -> anyhow::Result<PathBuf> {
    let mut client = match MxClient::new(config).await {
        Ok(client) => client,
        Err(e) => {
            tracing::warn!("Cannot connect to ModelExpress server: {e}. Using direct download.");
            return mx_download_direct(model_name, ignore_weights).await;
        }
    };
    tracing::info!("Successfully connected to ModelExpress server");

    if let Err(e) = client
        .request_model_with_provider_and_fallback(
            model_name,
            MxModelProvider::HuggingFace,
            ignore_weights,
        )
        .await
    {
        tracing::warn!(
            "Server download failed for model '{model_name}': {e}. Falling back to direct download."
        );
        return mx_download_direct(model_name, ignore_weights).await;
    }
    tracing::info!("Server download succeeded for model: {model_name}");

    // The server downloaded the model; resolve where it landed locally.
    match client.get_model_path(model_name).await {
        Ok(path) => Ok(path),
        Err(e) => {
            tracing::warn!(
                "Failed to resolve local model path after server download for '{model_name}': {e}. \
                 Falling back to direct download."
            );
            mx_download_direct(model_name, ignore_weights).await
        }
    }
}
/// Download `model_name` directly from HuggingFace into the shared
/// cache directory, bypassing the ModelExpress server entirely.
async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
    let cache_dir = Some(get_model_express_cache_dir());
    mx::download_model(model_name, MxModelProvider::HuggingFace, cache_dir, ignore_weights).await
}
fn get_model_express_cache_dir() -> PathBuf {
if let Ok(cache_path) = env::var(env_model::huggingface::HF_HUB_CACHE) {
return PathBuf::from(cache_path);
}
if let Ok(hf_home) = env::var(env_model::huggingface::HF_HOME) {
return PathBuf::from(hf_home).join("hub");
}
if let Ok(cache_path) = env::var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH) {
return PathBuf::from(cache_path);
}
let home = env::var("HOME")
.or_else(|_| env::var("USERPROFILE"))
.unwrap_or_else(|_| ".".to_string());
PathBuf::from(home).join(".cache/huggingface/hub")
}
#[cfg(test)]
mod tests {
use super::*;
// Smoke test: drives the full ModelExpress/fallback download path.
// The result is deliberately discarded — no server (or network) may be
// reachable in CI, so only "does not panic" is exercised here.
#[tokio::test]
async fn test_from_hf_with_model_express() {
let test_path = PathBuf::from("test-model");
let _result: anyhow::Result<PathBuf> = from_hf(test_path, false).await;
}
// The resolved cache dir is never empty: every branch of
// get_model_express_cache_dir yields either an absolute path or the
// "." fallback joined with the default suffix.
#[test]
fn test_get_model_express_cache_dir() {
let cache_dir = get_model_express_cache_dir();
assert!(!cache_dir.to_string_lossy().is_empty());
assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
}
// Serialized because it mutates process-global environment variables;
// set_var/remove_var are unsafe since other threads could be reading
// the environment concurrently.
#[serial_test::serial]
#[test]
fn test_get_model_express_cache_dir_with_hf_home() {
unsafe {
// Clear the higher-precedence overrides so HF_HOME decides the result.
env::remove_var(env_model::huggingface::HF_HUB_CACHE);
env::remove_var(env_model::model_express::MODEL_EXPRESS_CACHE_PATH);
env::set_var(env_model::huggingface::HF_HOME, "/custom/cache/path");
let cache_dir = get_model_express_cache_dir();
assert_eq!(cache_dir, PathBuf::from("/custom/cache/path/hub"));
env::remove_var(env_model::huggingface::HF_HOME);
}
}
}