// cake_core/utils/models.rs

//! Scan local directories for cached/downloaded models and report their status.
2
3use std::collections::HashSet;
4use std::path::{Path, PathBuf};
5
6use anyhow::Result;
7
/// Completeness state of a model found on disk.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelStatus {
    /// Every expected safetensors shard is on disk — usable by a master or standalone node.
    Complete,
    /// Only `have` of `total` shards are on disk — typical for a worker split or partial push.
    Partial { have: usize, total: usize },
}
19
20impl std::fmt::Display for ModelStatus {
21    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
22        match self {
23            ModelStatus::Complete => write!(f, "complete"),
24            ModelStatus::Partial { have, total } => {
25                write!(f, "partial ({}/{} shards)", have, total)
26            }
27        }
28    }
29}
30
31/// A discovered local model.
32#[derive(Debug, Clone)]
33pub struct LocalModel {
34    /// Human-readable model name (e.g. "Qwen/Qwen2.5-Coder-1.5B-Instruct").
35    pub name: String,
36    /// Absolute path to the model directory.
37    pub path: PathBuf,
38    /// Where this model was found.
39    pub source: ModelSource,
40    /// Completeness status.
41    pub status: ModelStatus,
42    /// Total size of model files on disk (bytes).
43    pub size_bytes: u64,
44}
45
/// Location where a model was discovered.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelSource {
    /// Pulled through `hf-hub`; lives under `~/.cache/huggingface/hub/`.
    HuggingFaceCache,
    /// Received from a master during zero-config setup; lives under `~/.cache/cake/<hash>/`.
    ClusterCache { cluster_hash: String },
    /// A directory the user pointed at directly.
    Local,
}
56
57impl std::fmt::Display for ModelSource {
58    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
59        match self {
60            ModelSource::HuggingFaceCache => write!(f, "huggingface"),
61            ModelSource::ClusterCache { cluster_hash } => {
62                write!(f, "cluster ({})", cluster_hash)
63            }
64            ModelSource::Local => write!(f, "local"),
65        }
66    }
67}
68
69/// Scan all known locations for models and return a list of discovered models.
70pub fn list_models() -> Result<Vec<LocalModel>> {
71    let mut models = Vec::new();
72
73    // 1. Scan HuggingFace cache
74    if let Some(hf_cache) = hf_cache_dir() {
75        scan_hf_cache(&hf_cache, &mut models)?;
76    }
77
78    // 2. Scan zero-config cluster cache
79    if let Some(cake_cache) = cake_cache_dir() {
80        scan_cake_cache(&cake_cache, &mut models)?;
81    }
82
83    // Sort: complete first, then by name
84    models.sort_by(|a, b| {
85        let status_ord = |s: &ModelStatus| match s {
86            ModelStatus::Complete => 0,
87            ModelStatus::Partial { .. } => 1,
88        };
89        status_ord(&a.status)
90            .cmp(&status_ord(&b.status))
91            .then_with(|| a.name.cmp(&b.name))
92    });
93
94    Ok(models)
95}
96
97/// Check a single directory and return its model status, or None if it's not a model dir.
98fn check_model_dir(dir: &Path) -> Option<(ModelStatus, u64)> {
99    // Must have config.json to be considered a model
100    if !dir.join("config.json").exists() {
101        return None;
102    }
103
104    let mut total_size: u64 = 0;
105
106    // Count config + tokenizer sizes
107    for name in &["config.json", "tokenizer.json"] {
108        let p = dir.join(name);
109        if p.exists() {
110            total_size += std::fs::metadata(&p).map(|m| m.len()).unwrap_or(0);
111        }
112    }
113
114    let index_path = dir.join("model.safetensors.index.json");
115    if index_path.exists() {
116        total_size += std::fs::metadata(&index_path)
117            .map(|m| m.len())
118            .unwrap_or(0);
119
120        // Sharded model — check which shards are present
121        let index_data = std::fs::read_to_string(&index_path).ok()?;
122        let index_json: serde_json::Value = serde_json::from_str(&index_data).ok()?;
123        let weight_map = index_json.get("weight_map")?.as_object()?;
124
125        let mut expected_shards: HashSet<String> = HashSet::new();
126        for value in weight_map.values() {
127            if let Some(file) = value.as_str() {
128                expected_shards.insert(file.to_string());
129            }
130        }
131
132        let total = expected_shards.len();
133        let mut have = 0;
134        for shard in &expected_shards {
135            let shard_path = dir.join(shard);
136            if shard_path.exists() {
137                have += 1;
138                total_size += std::fs::metadata(&shard_path)
139                    .map(|m| m.len())
140                    .unwrap_or(0);
141            }
142        }
143
144        let status = if have == total {
145            ModelStatus::Complete
146        } else {
147            ModelStatus::Partial { have, total }
148        };
149
150        Some((status, total_size))
151    } else {
152        // Single safetensors file
153        let single = dir.join("model.safetensors");
154        if single.exists() {
155            total_size += std::fs::metadata(&single)
156                .map(|m| m.len())
157                .unwrap_or(0);
158            Some((ModelStatus::Complete, total_size))
159        } else {
160            // Has config but no weights at all
161            Some((ModelStatus::Partial { have: 0, total: 1 }, total_size))
162        }
163    }
164}
165
166/// Return the HuggingFace hub cache directory if it exists.
167fn hf_cache_dir() -> Option<PathBuf> {
168    super::hf::hf_cache_dir()
169}
170
171/// Return the Cake cluster cache directory if it exists.
172fn cake_cache_dir() -> Option<PathBuf> {
173    let cache = dirs::cache_dir()
174        .unwrap_or_else(|| PathBuf::from("/tmp"))
175        .join("cake");
176    if cache.exists() {
177        Some(cache)
178    } else {
179        None
180    }
181}
182
183/// Scan the HuggingFace hub cache for model snapshots.
184fn scan_hf_cache(hf_cache: &Path, models: &mut Vec<LocalModel>) -> Result<()> {
185    let entries = match std::fs::read_dir(hf_cache) {
186        Ok(e) => e,
187        Err(_) => return Ok(()),
188    };
189
190    for entry in entries.flatten() {
191        let dir_name = entry.file_name().to_string_lossy().to_string();
192        // HF cache dirs look like "models--org--model-name"
193        if !dir_name.starts_with("models--") {
194            continue;
195        }
196
197        let snapshots_dir = entry.path().join("snapshots");
198        if !snapshots_dir.exists() {
199            continue;
200        }
201
202        // Parse model name from dir: "models--Qwen--Qwen2.5-Coder-1.5B-Instruct" → "Qwen/Qwen2.5-Coder-1.5B-Instruct"
203        let model_name = dir_name
204            .strip_prefix("models--")
205            .unwrap_or(&dir_name)
206            .replacen("--", "/", 1);
207
208        // Check each snapshot (usually just one)
209        let snapshot_entries = match std::fs::read_dir(&snapshots_dir) {
210            Ok(e) => e,
211            Err(_) => continue,
212        };
213
214        for snap_entry in snapshot_entries.flatten() {
215            let snap_path = snap_entry.path();
216            if !snap_path.is_dir() {
217                continue;
218            }
219
220            if let Some((status, size_bytes)) = check_model_dir(&snap_path) {
221                models.push(LocalModel {
222                    name: model_name.clone(),
223                    path: snap_path,
224                    source: ModelSource::HuggingFaceCache,
225                    status,
226                    size_bytes,
227                });
228            }
229        }
230    }
231
232    Ok(())
233}
234
235/// Scan the Cake cluster cache for worker-received models.
236fn scan_cake_cache(cake_cache: &Path, models: &mut Vec<LocalModel>) -> Result<()> {
237    let entries = match std::fs::read_dir(cake_cache) {
238        Ok(e) => e,
239        Err(_) => return Ok(()),
240    };
241
242    for entry in entries.flatten() {
243        let dir = entry.path();
244        if !dir.is_dir() {
245            continue;
246        }
247
248        let cluster_hash = entry.file_name().to_string_lossy().to_string();
249
250        if let Some((status, size_bytes)) = check_model_dir(&dir) {
251            // Try to extract a model name from config.json
252            let name = read_model_name_from_config(&dir)
253                .unwrap_or_else(|| format!("cluster:{}", &cluster_hash));
254
255            models.push(LocalModel {
256                name,
257                path: dir,
258                source: ModelSource::ClusterCache { cluster_hash },
259                status,
260                size_bytes,
261            });
262        }
263    }
264
265    Ok(())
266}
267
268/// Try to read a model name from config.json (e.g. _name_or_path field).
269fn read_model_name_from_config(dir: &Path) -> Option<String> {
270    let config = std::fs::read_to_string(dir.join("config.json")).ok()?;
271    let json: serde_json::Value = serde_json::from_str(&config).ok()?;
272
273    // Try common fields that identify the model
274    json.get("_name_or_path")
275        .and_then(|v| v.as_str())
276        .map(|s| s.to_string())
277        .or_else(|| {
278            json.get("model_type")
279                .and_then(|v| v.as_str())
280                .map(|s| s.to_string())
281        })
282}