cake_core/utils/
models.rs1use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5
6use anyhow::Result;
7
/// How complete the weight files of a locally cached model are.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelStatus {
    /// Every expected weight file is present on disk.
    Complete,
    /// At least one expected weight file is missing (or none are expected).
    Partial {
        /// Number of expected weight files found on disk.
        have: usize,
        /// Number of weight files the model is expected to have.
        total: usize,
    },
}
19
20impl std::fmt::Display for ModelStatus {
21 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22 match self {
23 ModelStatus::Complete => write!(f, "complete"),
24 ModelStatus::Partial { have, total } => {
25 write!(f, "partial ({}/{} shards)", have, total)
26 }
27 }
28 }
29}
30
31#[derive(Debug, Clone)]
33pub struct LocalModel {
34 pub name: String,
36 pub path: PathBuf,
38 pub source: ModelSource,
40 pub status: ModelStatus,
42 pub size_bytes: u64,
44}
45
/// Where a locally cached model was discovered.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ModelSource {
    /// The Hugging Face hub cache.
    HuggingFaceCache,
    /// A subdirectory of the cake cluster cache; `cluster_hash` is that
    /// subdirectory's name.
    ClusterCache { cluster_hash: String },
    /// A local model directory. NOTE(review): not produced by the cache scanners in
    /// this file; presumably set by callers for user-supplied paths — confirm.
    Local,
}
56
57impl std::fmt::Display for ModelSource {
58 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59 match self {
60 ModelSource::HuggingFaceCache => write!(f, "huggingface"),
61 ModelSource::ClusterCache { cluster_hash } => {
62 write!(f, "cluster ({})", cluster_hash)
63 }
64 ModelSource::Local => write!(f, "local"),
65 }
66 }
67}
68
69pub fn list_models() -> Result<Vec<LocalModel>> {
71 let mut models = Vec::new();
72
73 if let Some(hf_cache) = hf_cache_dir() {
75 scan_hf_cache(&hf_cache, &mut models)?;
76 }
77
78 if let Some(cake_cache) = cake_cache_dir() {
80 scan_cake_cache(&cake_cache, &mut models)?;
81 }
82
83 models.sort_by(|a, b| {
85 let status_ord = |s: &ModelStatus| match s {
86 ModelStatus::Complete => 0,
87 ModelStatus::Partial { .. } => 1,
88 };
89 status_ord(&a.status)
90 .cmp(&status_ord(&b.status))
91 .then_with(|| a.name.cmp(&b.name))
92 });
93
94 Ok(models)
95}
96
97fn check_model_dir(dir: &Path) -> Option<(ModelStatus, u64)> {
99 if !dir.join("config.json").exists() {
101 return None;
102 }
103
104 let mut total_size: u64 = 0;
105
106 for name in &["config.json", "tokenizer.json"] {
108 let p = dir.join(name);
109 if p.exists() {
110 total_size += std::fs::metadata(&p).map(|m| m.len()).unwrap_or(0);
111 }
112 }
113
114 let index_path = dir.join("model.safetensors.index.json");
115 if index_path.exists() {
116 total_size += std::fs::metadata(&index_path)
117 .map(|m| m.len())
118 .unwrap_or(0);
119
120 let index_data = std::fs::read_to_string(&index_path).ok()?;
122 let index_json: serde_json::Value = serde_json::from_str(&index_data).ok()?;
123 let weight_map = index_json.get("weight_map")?.as_object()?;
124
125 let mut expected_shards: HashSet<String> = HashSet::new();
126 for value in weight_map.values() {
127 if let Some(file) = value.as_str() {
128 expected_shards.insert(file.to_string());
129 }
130 }
131
132 let total = expected_shards.len();
133 let mut have = 0;
134 for shard in &expected_shards {
135 let shard_path = dir.join(shard);
136 if shard_path.exists() {
137 have += 1;
138 total_size += std::fs::metadata(&shard_path)
139 .map(|m| m.len())
140 .unwrap_or(0);
141 }
142 }
143
144 let status = if have == total {
145 ModelStatus::Complete
146 } else {
147 ModelStatus::Partial { have, total }
148 };
149
150 Some((status, total_size))
151 } else {
152 let single = dir.join("model.safetensors");
154 if single.exists() {
155 total_size += std::fs::metadata(&single)
156 .map(|m| m.len())
157 .unwrap_or(0);
158 Some((ModelStatus::Complete, total_size))
159 } else {
160 Some((ModelStatus::Partial { have: 0, total: 1 }, total_size))
162 }
163 }
164}
165
166fn hf_cache_dir() -> Option<PathBuf> {
168 super::hf::hf_cache_dir()
169}
170
171fn cake_cache_dir() -> Option<PathBuf> {
173 let cache = dirs::cache_dir()
174 .unwrap_or_else(|| PathBuf::from("/tmp"))
175 .join("cake");
176 if cache.exists() {
177 Some(cache)
178 } else {
179 None
180 }
181}
182
183fn scan_hf_cache(hf_cache: &Path, models: &mut Vec<LocalModel>) -> Result<()> {
185 let entries = match std::fs::read_dir(hf_cache) {
186 Ok(e) => e,
187 Err(_) => return Ok(()),
188 };
189
190 for entry in entries.flatten() {
191 let dir_name = entry.file_name().to_string_lossy().to_string();
192 if !dir_name.starts_with("models--") {
194 continue;
195 }
196
197 let snapshots_dir = entry.path().join("snapshots");
198 if !snapshots_dir.exists() {
199 continue;
200 }
201
202 let model_name = dir_name
204 .strip_prefix("models--")
205 .unwrap_or(&dir_name)
206 .replacen("--", "/", 1);
207
208 let snapshot_entries = match std::fs::read_dir(&snapshots_dir) {
210 Ok(e) => e,
211 Err(_) => continue,
212 };
213
214 for snap_entry in snapshot_entries.flatten() {
215 let snap_path = snap_entry.path();
216 if !snap_path.is_dir() {
217 continue;
218 }
219
220 if let Some((status, size_bytes)) = check_model_dir(&snap_path) {
221 models.push(LocalModel {
222 name: model_name.clone(),
223 path: snap_path,
224 source: ModelSource::HuggingFaceCache,
225 status,
226 size_bytes,
227 });
228 }
229 }
230 }
231
232 Ok(())
233}
234
235fn scan_cake_cache(cake_cache: &Path, models: &mut Vec<LocalModel>) -> Result<()> {
237 let entries = match std::fs::read_dir(cake_cache) {
238 Ok(e) => e,
239 Err(_) => return Ok(()),
240 };
241
242 for entry in entries.flatten() {
243 let dir = entry.path();
244 if !dir.is_dir() {
245 continue;
246 }
247
248 let cluster_hash = entry.file_name().to_string_lossy().to_string();
249
250 if let Some((status, size_bytes)) = check_model_dir(&dir) {
251 let name = read_model_name_from_config(&dir)
253 .unwrap_or_else(|| format!("cluster:{}", &cluster_hash));
254
255 models.push(LocalModel {
256 name,
257 path: dir,
258 source: ModelSource::ClusterCache { cluster_hash },
259 status,
260 size_bytes,
261 });
262 }
263 }
264
265 Ok(())
266}
267
268fn read_model_name_from_config(dir: &Path) -> Option<String> {
270 let config = std::fs::read_to_string(dir.join("config.json")).ok()?;
271 let json: serde_json::Value = serde_json::from_str(&config).ok()?;
272
273 json.get("_name_or_path")
275 .and_then(|v| v.as_str())
276 .map(|s| s.to_string())
277 .or_else(|| {
278 json.get("model_type")
279 .and_then(|v| v.as_str())
280 .map(|s| s.to_string())
281 })
282}