candle_coreml/download/
git_lfs.rs1use anyhow::{Error as E, Result};
12use std::fs;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// Settings for [`download_hf_model_clean`].
#[derive(Clone, Debug)]
pub struct CleanDownloadConfig {
    /// Hugging Face model identifier (e.g. `"owner/name"`); appended to
    /// `https://huggingface.co/` for the git clone and passed to `hf-hub` for LFS downloads.
    pub model_id: String,
    /// Directory the repository is cloned into. An existing directory at this path is
    /// deleted before cloning.
    pub target_dir: PathBuf,
    /// When `true`, progress messages are printed to stdout.
    pub verbose: bool,
    /// When `true`, the clone's `.git` directory is left in place after the download.
    pub keep_git_dir: bool,
}
28
29impl CleanDownloadConfig {
30 pub fn for_hf_model(model_id: &str, cache_base: &Path) -> Self {
32 let model_cache_name = model_id.replace('/', "--");
33 let target_dir = cache_base.join(format!("clean-{model_cache_name}"));
34
35 Self {
36 model_id: model_id.to_string(),
37 target_dir,
38 verbose: false,
39 keep_git_dir: false,
40 }
41 }
42
43 pub fn with_verbose(mut self, verbose: bool) -> Self {
45 self.verbose = verbose;
46 self
47 }
48
49 pub fn with_keep_git(mut self, keep_git: bool) -> Self {
51 self.keep_git_dir = keep_git;
52 self
53 }
54}
55
/// A Git LFS pointer file discovered in the cloned working tree.
#[derive(Clone, Debug)]
pub struct LfsPointer {
    /// Location of the pointer file on disk; the real content is later copied over it.
    pub file_path: PathBuf,
    /// The pointer's first line, stored verbatim (including the leading `version ` key).
    pub version: String,
    /// The object id: the text following `oid sha256:`.
    pub oid: String,
    /// Size in bytes of the real object, taken from the `size` line.
    pub size: u64,
}
64
65pub fn download_hf_model_clean(config: &CleanDownloadConfig) -> Result<PathBuf> {
67 if config.verbose {
68 println!("🚀 Starting clean git2+LFS download");
69 println!("📦 Model: {}", config.model_id);
70 println!("📁 Target: {}", config.target_dir.display());
71 }
72
73 let repo_path = clone_hf_repo_git2(config)?;
75
76 let lfs_pointers = scan_for_lfs_pointers(&repo_path, config.verbose)?;
78
79 if config.verbose {
80 println!("🔍 Found {} LFS pointer files", lfs_pointers.len());
81 }
82
83 download_lfs_content(&lfs_pointers, &config.model_id, config.verbose)?;
85
86 if !config.keep_git_dir {
88 let git_dir = repo_path.join(".git");
89 if git_dir.exists() {
90 if config.verbose {
91 println!("🗑️ Removing .git directory");
92 }
93 fs::remove_dir_all(&git_dir)
94 .map_err(|e| E::msg(format!("Failed to remove .git directory: {e}")))?;
95 }
96 }
97
98 if config.verbose {
99 println!("✅ Clean download completed successfully");
100 }
101
102 Ok(repo_path)
103}
104
105fn clone_hf_repo_git2(config: &CleanDownloadConfig) -> Result<PathBuf> {
107 let repo_url = format!("https://huggingface.co/{}", config.model_id);
108
109 if config.verbose {
110 println!("📥 Cloning repository: {repo_url}");
111 }
112
113 if config.target_dir.exists() {
115 if config.verbose {
116 println!("🗑️ Removing existing directory");
117 }
118 fs::remove_dir_all(&config.target_dir)
119 .map_err(|e| E::msg(format!("Failed to remove existing directory: {e}")))?;
120 }
121
122 if let Some(parent) = config.target_dir.parent() {
124 fs::create_dir_all(parent)
125 .map_err(|e| E::msg(format!("Failed to create parent directory: {e}")))?;
126 }
127
128 let mut builder = git2::build::RepoBuilder::new();
130
131 let mut fetch_options = git2::FetchOptions::new();
133 fetch_options.depth(1);
134 builder.fetch_options(fetch_options);
135
136 let _repo = builder
137 .clone(&repo_url, &config.target_dir)
138 .map_err(|e| E::msg(format!("Git clone failed: {e}")))?;
139
140 if config.verbose {
141 println!("✅ Repository cloned successfully");
142 }
143
144 Ok(config.target_dir.clone())
145}
146
147fn scan_for_lfs_pointers(repo_path: &Path, verbose: bool) -> Result<Vec<LfsPointer>> {
149 if verbose {
150 println!("🔍 Scanning for LFS pointer files...");
151 }
152
153 let mut lfs_pointers = Vec::new();
154 scan_directory_for_lfs(repo_path, repo_path, &mut lfs_pointers, verbose)?;
155
156 if verbose {
157 println!(" Found {} LFS pointer files", lfs_pointers.len());
158 for pointer in &lfs_pointers {
159 println!(
160 " • {} (size: {} bytes)",
161 pointer
162 .file_path
163 .strip_prefix(repo_path)
164 .unwrap_or(&pointer.file_path)
165 .display(),
166 pointer.size
167 );
168 }
169 }
170
171 Ok(lfs_pointers)
172}
173
174fn scan_directory_for_lfs(
176 dir: &Path,
177 repo_root: &Path,
178 lfs_pointers: &mut Vec<LfsPointer>,
179 _verbose: bool,
180) -> Result<()> {
181 let entries = fs::read_dir(dir)
182 .map_err(|e| E::msg(format!("Failed to read directory {}: {e}", dir.display())))?;
183
184 for entry in entries {
185 let entry = entry.map_err(|e| E::msg(format!("Failed to read directory entry: {e}")))?;
186 let path = entry.path();
187
188 if path.is_dir() {
189 if path.file_name() == Some(std::ffi::OsStr::new(".git")) {
191 continue;
192 }
193
194 scan_directory_for_lfs(&path, repo_root, lfs_pointers, _verbose)?;
196 } else if path.is_file() {
197 if let Ok(pointer) = check_lfs_pointer_file(&path, repo_root) {
199 lfs_pointers.push(pointer);
200 }
201 }
202 }
203
204 Ok(())
205}
206
207fn check_lfs_pointer_file(file_path: &Path, _repo_root: &Path) -> Result<LfsPointer> {
209 let mut file = fs::File::open(file_path)
211 .map_err(|e| E::msg(format!("Failed to open file {}: {e}", file_path.display())))?;
212
213 let mut buffer = [0; 1024];
214 let bytes_read = file
215 .read(&mut buffer)
216 .map_err(|e| E::msg(format!("Failed to read file {}: {e}", file_path.display())))?;
217
218 let content = std::str::from_utf8(&buffer[..bytes_read])
219 .map_err(|_| E::msg("File is not valid UTF-8"))?;
220
221 parse_lfs_pointer(content, file_path.to_path_buf())
223}
224
225fn parse_lfs_pointer(content: &str, file_path: PathBuf) -> Result<LfsPointer> {
227 let lines: Vec<&str> = content.lines().collect();
228
229 if lines.is_empty() {
230 return Err(E::msg("Empty file"));
231 }
232
233 if !lines[0].starts_with("version https://git-lfs.github.com/spec/v") {
235 return Err(E::msg("Not an LFS pointer file"));
236 }
237
238 let version = lines[0].to_string();
239 let mut oid = String::new();
240 let mut size = 0u64;
241
242 for line in &lines[1..] {
244 if let Some(stripped) = line.strip_prefix("oid sha256:") {
245 oid = stripped.to_string();
246 } else if let Some(stripped) = line.strip_prefix("size ") {
247 size = stripped
248 .parse()
249 .map_err(|e| E::msg(format!("Invalid size in LFS pointer: {e}")))?;
250 }
251 }
252
253 if oid.is_empty() || size == 0 {
254 return Err(E::msg("Invalid LFS pointer: missing oid or size"));
255 }
256
257 Ok(LfsPointer {
258 file_path,
259 version,
260 oid,
261 size,
262 })
263}
264
265fn download_lfs_content(lfs_pointers: &[LfsPointer], model_id: &str, verbose: bool) -> Result<()> {
267 if lfs_pointers.is_empty() {
268 if verbose {
269 println!("📄 No LFS files to download");
270 }
271 return Ok(());
272 }
273
274 if verbose {
275 println!(
276 "📥 Downloading {} LFS files using hf-hub...",
277 lfs_pointers.len()
278 );
279 }
280
281 let api = hf_hub::api::sync::Api::new()
283 .map_err(|e| E::msg(format!("Failed to create HF API: {e}")))?;
284 let repo = api.model(model_id.to_string());
285
286 for (i, pointer) in lfs_pointers.iter().enumerate() {
287 if verbose {
288 println!(
289 " 📥 [{}/{}] Downloading: {}",
290 i + 1,
291 lfs_pointers.len(),
292 pointer
293 .file_path
294 .file_name()
295 .unwrap_or_default()
296 .to_string_lossy()
297 );
298 }
299
300 let mut repo_root = pointer.file_path.parent();
303 while let Some(parent) = repo_root {
304 if parent.join(".git").exists() {
305 repo_root = Some(parent);
306 break;
307 }
308 repo_root = parent.parent();
309 }
310
311 let repo_root = repo_root.ok_or_else(|| E::msg("Cannot find repo root"))?;
312 let relative_path = pointer
313 .file_path
314 .strip_prefix(repo_root)
315 .map_err(|e| E::msg(format!("Failed to get relative path: {e}")))?;
316
317 let relative_path_str = relative_path.to_string_lossy();
318
319 match repo.get(&relative_path_str) {
321 Ok(downloaded_path) => {
322 fs::copy(&downloaded_path, &pointer.file_path)
324 .map_err(|e| E::msg(format!("Failed to replace pointer file: {e}")))?;
325
326 if verbose {
327 println!(
328 " ✅ Downloaded and replaced: {} ({} bytes)",
329 relative_path_str, pointer.size
330 );
331 }
332 }
333 Err(e) => {
334 if verbose {
335 println!(" ⚠️ Failed to download {relative_path_str}: {e}");
336 }
337 return Err(E::msg(format!(
338 "Failed to download LFS file {relative_path_str}: {e}"
339 )));
340 }
341 }
342 }
343
344 if verbose {
345 println!("✅ All LFS files downloaded successfully");
346 }
347
348 Ok(())
349}
350
351pub fn verify_download_completeness(
353 model_path: &Path,
354 expected_files: &[&str],
355 verbose: bool,
356) -> Result<()> {
357 if verbose {
358 println!("🔍 Verifying download completeness...");
359 }
360
361 for expected_file in expected_files {
362 let file_path = model_path.join(expected_file);
363 if !file_path.exists() {
364 return Err(E::msg(format!("Expected file not found: {expected_file}")));
365 }
366
367 if is_lfs_pointer_file(&file_path)? {
369 return Err(E::msg(format!(
370 "File {expected_file} is still an LFS pointer"
371 )));
372 }
373
374 if verbose {
375 let size = fs::metadata(&file_path)
376 .map_err(|e| E::msg(format!("Failed to get file metadata: {e}")))?
377 .len();
378 println!(" ✅ {expected_file} ({size} bytes)");
379 }
380 }
381
382 if verbose {
383 println!("✅ Download verification completed");
384 }
385
386 Ok(())
387}
388
389fn is_lfs_pointer_file(file_path: &Path) -> Result<bool> {
391 match check_lfs_pointer_file(file_path, file_path.parent().unwrap_or(file_path)) {
392 Ok(_) => Ok(true),
393 Err(_) => Ok(false),
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;

    /// A minimal well-formed pointer used by the parsing test.
    const VALID_POINTER: &str =
        "version https://git-lfs.github.com/spec/v1\noid sha256:abc123\nsize 12345\n";

    #[test]
    fn test_lfs_pointer_parsing() {
        let pointer = parse_lfs_pointer(VALID_POINTER, PathBuf::from("test.bin"))
            .expect("a well-formed pointer should parse");

        assert_eq!(pointer.oid, "abc123");
        assert_eq!(pointer.size, 12345);
    }

    #[test]
    fn test_invalid_lfs_pointer() {
        let parsed = parse_lfs_pointer(
            "This is not an LFS pointer file",
            PathBuf::from("test.txt"),
        );
        assert!(parsed.is_err());
    }

    #[test]
    fn test_config_creation() {
        let config = CleanDownloadConfig::for_hf_model("test/model", &std::env::temp_dir());

        assert_eq!(config.model_id, "test/model");
        let target = config.target_dir.to_string_lossy();
        assert!(target.contains("clean-test--model"));
        assert!(!config.verbose);
        assert!(!config.keep_git_dir);
    }
}