dynamo_llm/
hub.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::env;
5use std::path::{Path, PathBuf};
6
7use modelexpress_client::{
8    Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
9};
10use modelexpress_common::download as mx;
11
/// Name of the environment variable that supplies the ModelExpress server endpoint.
/// Example: export MODEL_EXPRESS_URL=http://localhost:8001
13const MODEL_EXPRESS_ENDPOINT_ENV_VAR: &str = "MODEL_EXPRESS_URL";
14
15/// Download a model using ModelExpress client. The client first requests for the model
16/// from the server and fallbacks to direct download in case of server failure.
17/// If ignore_weights is true, model weight files will be skipped
18/// Returns the path to the model files
19pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
20    let name = name.as_ref();
21    let model_name = name.display().to_string();
22
23    let mut config: MxClientConfig = MxClientConfig::default();
24    if let Ok(endpoint) = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR) {
25        config = config.with_endpoint(endpoint);
26    }
27
28    let result = match MxClient::new(config).await {
29        Ok(mut client) => {
30            tracing::info!("Successfully connected to ModelExpress server");
31            match client
32                .request_model_with_provider_and_fallback(
33                    &model_name,
34                    MxModelProvider::HuggingFace,
35                    ignore_weights,
36                )
37                .await
38            {
39                Ok(()) => {
40                    tracing::info!("Server download succeeded for model: {model_name}");
41                    match client.get_model_path(&model_name).await {
42                        Ok(path) => Ok(path),
43                        Err(e) => {
44                            tracing::warn!(
45                                "Failed to resolve local model path after server download for '{model_name}': {e}. \
46                                Falling back to direct download."
47                            );
48                            mx_download_direct(&model_name, ignore_weights).await
49                        }
50                    }
51                }
52                Err(e) => {
53                    tracing::warn!(
54                        "Server download failed for model '{model_name}': {e}. Falling back to direct download."
55                    );
56                    mx_download_direct(&model_name, ignore_weights).await
57                }
58            }
59        }
60        Err(e) => {
61            tracing::warn!("Cannot connect to ModelExpress server: {e}. Using direct download.");
62            mx_download_direct(&model_name, ignore_weights).await
63        }
64    };
65
66    match result {
67        Ok(path) => {
68            tracing::info!("ModelExpress download completed successfully for model: {model_name}");
69            Ok(path)
70        }
71        Err(e) => {
72            tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
73            Err(e)
74        }
75    }
76}
77
78// Direct download using the ModelExpress client.
79async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
80    let cache_dir = get_model_express_cache_dir();
81    mx::download_model(
82        model_name,
83        MxModelProvider::HuggingFace,
84        Some(cache_dir),
85        ignore_weights,
86    )
87    .await
88}
89
// TODO: remove in the future. This is a temporary workaround to find common
// cache directory between client and server.
//
// Resolves the cache directory handed to direct downloads. Lookup order:
//   1. `HF_HUB_CACHE`
//   2. `MODEL_EXPRESS_CACHE_PATH`
//   3. `<home>/.cache/huggingface/hub`, where `<home>` is `HOME`, then `USERPROFILE`
//   4. `./.cache/huggingface/hub` when no home directory is available
//
// A variable that is unset, not valid Unicode, or set to the empty string is skipped.
// Skipping empty values matters: `PathBuf::from("")` is an empty path, and an empty
// `HOME` would silently turn the default into a path relative to the working directory
// without the explicit `.` anchor.
fn get_model_express_cache_dir() -> PathBuf {
    // Reads an environment variable, treating an empty value the same as an unset one.
    let non_empty_var = |key: &str| env::var(key).ok().filter(|value| !value.is_empty());

    let explicit_cache =
        non_empty_var("HF_HUB_CACHE").or_else(|| non_empty_var("MODEL_EXPRESS_CACHE_PATH"));
    if let Some(cache_path) = explicit_cache {
        return PathBuf::from(cache_path);
    }

    let home = non_empty_var("HOME")
        .or_else(|| non_empty_var("USERPROFILE"))
        .unwrap_or_else(|| ".".to_string());

    PathBuf::from(home).join(".cache/huggingface/hub")
}
106
#[cfg(test)]
mod tests {
    use super::*;

    // Smoke test: drives `from_hf` end to end and discards the result, so it only
    // checks that the call completes without panicking.
    #[tokio::test]
    async fn test_from_hf_with_model_express() {
        let model = PathBuf::from("test-model");
        let _outcome: anyhow::Result<PathBuf> = from_hf(model, false).await;
    }

    // The resolved cache directory must never be empty, and must be either an absolute
    // path or one anchored at `.`.
    #[test]
    fn test_get_model_express_cache_dir() {
        let cache_dir: PathBuf = get_model_express_cache_dir();

        assert!(!cache_dir.to_string_lossy().is_empty());
        assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
    }
}