1use std::collections::HashSet;
4use std::path::PathBuf;
5
6use anyhow::Result;
7use hf_hub::api::sync::ApiBuilder;
8
/// Heuristically decides whether `model` names a Hugging Face Hub repository
/// (`owner/name`) rather than a local filesystem path.
///
/// Returns `true` only when the string contains exactly one `/` with a
/// non-empty segment on each side and does not begin with `/`, `.` or `~`
/// (prefixes that indicate absolute, relative or home-relative local paths).
pub fn looks_like_hf_repo(model: &str) -> bool {
    // Anything that starts like a filesystem path is never a repo id.
    if model.starts_with(|c: char| matches!(c, '/' | '.' | '~')) {
        return false;
    }
    match model.split_once('/') {
        // Exactly one separator: the remainder after the first '/' must not
        // contain another one.
        Some((owner, name)) => !owner.is_empty() && !name.is_empty() && !name.contains('/'),
        None => false,
    }
}
19
20fn find_cached_model(repo_id: &str) -> Option<PathBuf> {
23 let hf_cache = hf_cache_dir()?;
24
25 let cache_dir_name = format!("models--{}", repo_id.replace('/', "--"));
27 let model_dir = hf_cache.join(&cache_dir_name);
28 let snapshots_dir = model_dir.join("snapshots");
29
30 if !snapshots_dir.exists() {
31 return None;
32 }
33
34 let mut best: Option<(PathBuf, std::time::SystemTime)> = None;
36
37 for entry in std::fs::read_dir(&snapshots_dir).ok()?.flatten() {
38 let snap_path = entry.path();
39 if !snap_path.is_dir() {
40 continue;
41 }
42
43 if !snap_path.join("config.json").exists() {
45 continue;
46 }
47
48 let is_complete = if snap_path.join("model.safetensors").exists() {
50 true
51 } else if let Ok(index_data) =
52 std::fs::read_to_string(snap_path.join("model.safetensors.index.json"))
53 {
54 if let Ok(index_json) = serde_json::from_str::<serde_json::Value>(&index_data) {
55 if let Some(weight_map) = index_json.get("weight_map").and_then(|v| v.as_object())
56 {
57 let expected: HashSet<&str> =
58 weight_map.values().filter_map(|v| v.as_str()).collect();
59 expected.iter().all(|f| snap_path.join(f).exists())
60 } else {
61 false
62 }
63 } else {
64 false
65 }
66 } else {
67 false
68 };
69
70 if is_complete {
71 let mtime = entry
72 .metadata()
73 .ok()
74 .and_then(|m| m.modified().ok())
75 .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
76 if best.as_ref().map_or(true, |(_, t)| mtime > *t) {
77 best = Some((snap_path, mtime));
78 }
79 }
80 }
81
82 best.map(|(p, _)| p)
83}
84
85pub fn hf_cache_dir() -> Option<PathBuf> {
87 if let Ok(dir) = std::env::var("HF_HUB_CACHE") {
88 let p = PathBuf::from(dir);
89 if p.exists() {
90 return Some(p);
91 }
92 }
93 if let Ok(dir) = std::env::var("HF_HOME") {
94 let p = PathBuf::from(dir).join("hub");
95 if p.exists() {
96 return Some(p);
97 }
98 }
99 let home = dirs::home_dir()?;
100 let p = home.join(".cache/huggingface/hub");
101 if p.exists() {
102 Some(p)
103 } else {
104 None
105 }
106}
107
108pub fn ensure_model_downloaded(repo_id: &str) -> Result<PathBuf> {
112 if let Some(cached_path) = find_cached_model(repo_id) {
114 log::info!(
115 "model '{}' found in cache at {}",
116 repo_id,
117 cached_path.display()
118 );
119 return Ok(cached_path);
120 }
121
122 log::info!(
123 "downloading model '{}' from HuggingFace Hub...",
124 repo_id
125 );
126
127 let mut builder = ApiBuilder::new().with_progress(true);
128
129 if let Ok(cache_dir) = std::env::var("HF_HUB_CACHE") {
132 builder = builder.with_cache_dir(PathBuf::from(cache_dir));
133 }
134
135 let api = builder.build()?;
136 let repo = api.model(repo_id.to_string());
137
138 log::info!("downloading config.json ...");
140 repo.download("config.json")
141 .map_err(|e| anyhow!("failed to download config.json from '{}': {}", repo_id, e))?;
142
143 log::info!("downloading tokenizer.json ...");
145 repo.download("tokenizer.json")
146 .map_err(|e| anyhow!("failed to download tokenizer.json from '{}': {}", repo_id, e))?;
147
148 let snapshot_dir = if let Ok(index_path) = repo.download("model.safetensors.index.json") {
150 log::info!("found sharded model, parsing index...");
151
152 let index_data = std::fs::read(&index_path)?;
153 let index_json: serde_json::Value = serde_json::from_slice(&index_data)?;
154 let weight_map = index_json
155 .get("weight_map")
156 .and_then(|v| v.as_object())
157 .ok_or_else(|| anyhow!("no weight_map in model.safetensors.index.json"))?;
158
159 let mut shard_files = std::collections::HashSet::new();
160 for value in weight_map.values() {
161 if let Some(file) = value.as_str() {
162 shard_files.insert(file.to_string());
163 }
164 }
165
166 log::info!("downloading {} shard files...", shard_files.len());
167 for (i, shard) in shard_files.iter().enumerate() {
168 log::info!("[{}/{}] downloading {} ...", i + 1, shard_files.len(), shard);
169 repo.download(shard).map_err(|e| {
170 anyhow!(
171 "failed to download shard '{}' from '{}': {}",
172 shard,
173 repo_id,
174 e
175 )
176 })?;
177 }
178
179 index_path.parent().unwrap().to_path_buf()
180 } else {
181 log::info!("downloading model.safetensors ...");
182 let model_path = repo.download("model.safetensors").map_err(|e| {
183 anyhow!(
184 "failed to download model from '{}': no index.json and no model.safetensors found: {}",
185 repo_id,
186 e
187 )
188 })?;
189 model_path.parent().unwrap().to_path_buf()
190 };
191
192 log::info!("model files ready at {}", snapshot_dir.display());
193 Ok(snapshot_dir)
194}