Skip to main content

cake_core/utils/
hf.rs

1//! HuggingFace Hub integration for automatic model downloading.
2
3use std::collections::HashSet;
4use std::path::PathBuf;
5
6use anyhow::Result;
7use hf_hub::api::sync::ApiBuilder;
8
/// Heuristically decides whether `model` names a HuggingFace repository rather than a local path.
///
/// A repo ID has the form `owner/name` (e.g., "Qwen/Qwen2.5-Coder-1.5B-Instruct"): exactly one
/// `/` separating two non-empty segments. Strings that begin with `/`, `.` or `~` are rejected,
/// so absolute, relative and home-relative filesystem paths are not mistaken for repo IDs.
pub fn looks_like_hf_repo(model: &str) -> bool {
    // A leading path marker disqualifies the string outright.
    if matches!(model.chars().next(), Some('/') | Some('.') | Some('~')) {
        return false;
    }

    match model.split_once('/') {
        // Exactly one separator: nothing after the first `/` may contain another one.
        Some((owner, name)) => !owner.is_empty() && !name.is_empty() && !name.contains('/'),
        None => false,
    }
}
19
20/// Check the HuggingFace cache for an already-downloaded complete model.
21/// Returns the snapshot directory path if found with all shards present.
22fn find_cached_model(repo_id: &str) -> Option<PathBuf> {
23    let hf_cache = hf_cache_dir()?;
24
25    // HF cache dirs look like "models--org--model-name"
26    let cache_dir_name = format!("models--{}", repo_id.replace('/', "--"));
27    let model_dir = hf_cache.join(&cache_dir_name);
28    let snapshots_dir = model_dir.join("snapshots");
29
30    if !snapshots_dir.exists() {
31        return None;
32    }
33
34    // Check each snapshot (usually just one, pick the newest)
35    let mut best: Option<(PathBuf, std::time::SystemTime)> = None;
36
37    for entry in std::fs::read_dir(&snapshots_dir).ok()?.flatten() {
38        let snap_path = entry.path();
39        if !snap_path.is_dir() {
40            continue;
41        }
42
43        // Must have config.json
44        if !snap_path.join("config.json").exists() {
45            continue;
46        }
47
48        // Check model completeness
49        let is_complete = if snap_path.join("model.safetensors").exists() {
50            true
51        } else if let Ok(index_data) =
52            std::fs::read_to_string(snap_path.join("model.safetensors.index.json"))
53        {
54            if let Ok(index_json) = serde_json::from_str::<serde_json::Value>(&index_data) {
55                if let Some(weight_map) = index_json.get("weight_map").and_then(|v| v.as_object())
56                {
57                    let expected: HashSet<&str> =
58                        weight_map.values().filter_map(|v| v.as_str()).collect();
59                    expected.iter().all(|f| snap_path.join(f).exists())
60                } else {
61                    false
62                }
63            } else {
64                false
65            }
66        } else {
67            false
68        };
69
70        if is_complete {
71            let mtime = entry
72                .metadata()
73                .ok()
74                .and_then(|m| m.modified().ok())
75                .unwrap_or(std::time::SystemTime::UNIX_EPOCH);
76            if best.as_ref().map_or(true, |(_, t)| mtime > *t) {
77                best = Some((snap_path, mtime));
78            }
79        }
80    }
81
82    best.map(|(p, _)| p)
83}
84
85/// Return the HuggingFace hub cache directory if it exists.
86pub fn hf_cache_dir() -> Option<PathBuf> {
87    if let Ok(dir) = std::env::var("HF_HUB_CACHE") {
88        let p = PathBuf::from(dir);
89        if p.exists() {
90            return Some(p);
91        }
92    }
93    if let Ok(dir) = std::env::var("HF_HOME") {
94        let p = PathBuf::from(dir).join("hub");
95        if p.exists() {
96            return Some(p);
97        }
98    }
99    let home = dirs::home_dir()?;
100    let p = home.join(".cache/huggingface/hub");
101    if p.exists() {
102        Some(p)
103    } else {
104        None
105    }
106}
107
108/// Downloads all required model files from HuggingFace Hub.
109/// Returns the local directory path containing the downloaded files.
110/// Checks local cache first — returns instantly if model already downloaded.
111pub fn ensure_model_downloaded(repo_id: &str) -> Result<PathBuf> {
112    // Check local cache first
113    if let Some(cached_path) = find_cached_model(repo_id) {
114        log::info!(
115            "model '{}' found in cache at {}",
116            repo_id,
117            cached_path.display()
118        );
119        return Ok(cached_path);
120    }
121
122    log::info!(
123        "downloading model '{}' from HuggingFace Hub...",
124        repo_id
125    );
126
127    let mut builder = ApiBuilder::new().with_progress(true);
128
129    // Use explicit cache dir if HF_HUB_CACHE is set, to avoid hf_hub bugs
130    // with env var handling (lock files created at wrong path).
131    if let Ok(cache_dir) = std::env::var("HF_HUB_CACHE") {
132        builder = builder.with_cache_dir(PathBuf::from(cache_dir));
133    }
134
135    let api = builder.build()?;
136    let repo = api.model(repo_id.to_string());
137
138    // Download config.json first — validates repo access (fails fast on auth errors).
139    log::info!("downloading config.json ...");
140    repo.download("config.json")
141        .map_err(|e| anyhow!("failed to download config.json from '{}': {}", repo_id, e))?;
142
143    // Download tokenizer.
144    log::info!("downloading tokenizer.json ...");
145    repo.download("tokenizer.json")
146        .map_err(|e| anyhow!("failed to download tokenizer.json from '{}': {}", repo_id, e))?;
147
148    // Try sharded model first (model.safetensors.index.json), fall back to single file.
149    let snapshot_dir = if let Ok(index_path) = repo.download("model.safetensors.index.json") {
150        log::info!("found sharded model, parsing index...");
151
152        let index_data = std::fs::read(&index_path)?;
153        let index_json: serde_json::Value = serde_json::from_slice(&index_data)?;
154        let weight_map = index_json
155            .get("weight_map")
156            .and_then(|v| v.as_object())
157            .ok_or_else(|| anyhow!("no weight_map in model.safetensors.index.json"))?;
158
159        let mut shard_files = std::collections::HashSet::new();
160        for value in weight_map.values() {
161            if let Some(file) = value.as_str() {
162                shard_files.insert(file.to_string());
163            }
164        }
165
166        log::info!("downloading {} shard files...", shard_files.len());
167        for (i, shard) in shard_files.iter().enumerate() {
168            log::info!("[{}/{}] downloading {} ...", i + 1, shard_files.len(), shard);
169            repo.download(shard).map_err(|e| {
170                anyhow!(
171                    "failed to download shard '{}' from '{}': {}",
172                    shard,
173                    repo_id,
174                    e
175                )
176            })?;
177        }
178
179        index_path.parent().unwrap().to_path_buf()
180    } else {
181        log::info!("downloading model.safetensors ...");
182        let model_path = repo.download("model.safetensors").map_err(|e| {
183            anyhow!(
184                "failed to download model from '{}': no index.json and no model.safetensors found: {}",
185                repo_id,
186                e
187            )
188        })?;
189        model_path.parent().unwrap().to_path_buf()
190    };
191
192    log::info!("model files ready at {}", snapshot_dir.display());
193    Ok(snapshot_dir)
194}