1use std::env;
5use std::path::{Path, PathBuf};
6
7use modelexpress_client::{
8 Client as MxClient, ClientConfig as MxClientConfig, ModelProvider as MxModelProvider,
9};
10use modelexpress_common::download as mx;
11
/// Name of the environment variable holding the ModelExpress server endpoint.
/// When it is set, its value is passed to `ClientConfig::with_endpoint`,
/// replacing the endpoint from `ClientConfig::default()`.
const MODEL_EXPRESS_ENDPOINT_ENV_VAR: &str = "MODEL_EXPRESS_URL";
14
15pub async fn from_hf(name: impl AsRef<Path>, ignore_weights: bool) -> anyhow::Result<PathBuf> {
20 let name = name.as_ref();
21 let model_name = name.display().to_string();
22
23 let mut config: MxClientConfig = MxClientConfig::default();
24 if let Ok(endpoint) = env::var(MODEL_EXPRESS_ENDPOINT_ENV_VAR) {
25 config = config.with_endpoint(endpoint);
26 }
27
28 let result = match MxClient::new(config).await {
29 Ok(mut client) => {
30 tracing::info!("Successfully connected to ModelExpress server");
31 match client
32 .request_model_with_provider_and_fallback(
33 &model_name,
34 MxModelProvider::HuggingFace,
35 ignore_weights,
36 )
37 .await
38 {
39 Ok(()) => {
40 tracing::info!("Server download succeeded for model: {model_name}");
41 match client.get_model_path(&model_name).await {
42 Ok(path) => Ok(path),
43 Err(e) => {
44 tracing::warn!(
45 "Failed to resolve local model path after server download for '{model_name}': {e}. \
46 Falling back to direct download."
47 );
48 mx_download_direct(&model_name, ignore_weights).await
49 }
50 }
51 }
52 Err(e) => {
53 tracing::warn!(
54 "Server download failed for model '{model_name}': {e}. Falling back to direct download."
55 );
56 mx_download_direct(&model_name, ignore_weights).await
57 }
58 }
59 }
60 Err(e) => {
61 tracing::warn!("Cannot connect to ModelExpress server: {e}. Using direct download.");
62 mx_download_direct(&model_name, ignore_weights).await
63 }
64 };
65
66 match result {
67 Ok(path) => {
68 tracing::info!("ModelExpress download completed successfully for model: {model_name}");
69 Ok(path)
70 }
71 Err(e) => {
72 tracing::warn!("ModelExpress download failed for model '{model_name}': {e}");
73 Err(e)
74 }
75 }
76}
77
78async fn mx_download_direct(model_name: &str, ignore_weights: bool) -> anyhow::Result<PathBuf> {
80 let cache_dir = get_model_express_cache_dir();
81 mx::download_model(
82 model_name,
83 MxModelProvider::HuggingFace,
84 Some(cache_dir),
85 ignore_weights,
86 )
87 .await
88}
89
/// Returns the directory used as the local cache for direct model downloads.
///
/// Resolution order (the first non-empty value wins):
/// 1. `HF_HUB_CACHE`
/// 2. `MODEL_EXPRESS_CACHE_PATH`
/// 3. `<home>/.cache/huggingface/hub`, where `<home>` is `HOME`, or `USERPROFILE`
///    when `HOME` is unavailable
/// 4. `./.cache/huggingface/hub` when no home directory is known
///
/// Variables that are set but empty (or not valid UTF-8) are treated as unset. This
/// prevents an exported-but-blank variable from turning the cache directory into an
/// empty path.
fn get_model_express_cache_dir() -> PathBuf {
    // Reads an environment variable, ignoring values that are missing, non-UTF-8, or empty.
    fn non_empty_var(key: &str) -> Option<String> {
        env::var(key).ok().filter(|value| !value.is_empty())
    }

    if let Some(cache_path) =
        non_empty_var("HF_HUB_CACHE").or_else(|| non_empty_var("MODEL_EXPRESS_CACHE_PATH"))
    {
        return PathBuf::from(cache_path);
    }

    let home = non_empty_var("HOME")
        .or_else(|| non_empty_var("USERPROFILE"))
        .unwrap_or_else(|| ".".to_string());

    // Join one component at a time so the platform's native separator is used
    // (relevant for the Windows `USERPROFILE` fallback).
    PathBuf::from(home).join(".cache").join("huggingface").join("hub")
}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test for the full `from_hf` flow (server attempt plus direct fallback).
    ///
    /// This test reaches out to a ModelExpress server and/or Hugging Face, so it is
    /// opt-in. Run it with `cargo test -- --ignored` on a machine with network access.
    /// The result is not asserted; the test only checks that the call completes
    /// without panicking.
    #[tokio::test]
    #[ignore = "requires network access to a ModelExpress server or Hugging Face"]
    async fn test_from_hf_with_model_express() {
        let test_path = PathBuf::from("test-model");
        let _result: anyhow::Result<PathBuf> = from_hf(test_path, false).await;
    }

    /// The resolved cache directory must be non-empty and either absolute or rooted
    /// at the current directory.
    ///
    /// NOTE(review): this depends on the ambient environment. A relative
    /// `HF_HUB_CACHE` such as `cache` would fail the second assertion.
    #[test]
    fn test_get_model_express_cache_dir() {
        let cache_dir = get_model_express_cache_dir();
        assert!(!cache_dir.to_string_lossy().is_empty());
        assert!(cache_dir.is_absolute() || cache_dir.starts_with("."));
    }
}