candle_coreml/download/git_lfs.rs

//! Clean git2 + HF Hub LFS downloader.
//!
//! This module downloads Hugging Face model repositories in four steps:
//! 1. Clone the repository with `git2` (this yields the file tree, with LFS files as pointer files).
//! 2. Detect the LFS pointer files in the cloned working tree.
//! 3. Fetch the real content of each LFS file with `hf-hub`.
//! 4. Overwrite each pointer file with the downloaded content.
//!
//! No external `git` or `git-lfs` executables are required.
10
11use anyhow::{Error as E, Result};
12use std::fs;
13use std::io::Read;
14use std::path::{Path, PathBuf};
15
/// Configuration for the clean git+LFS downloader.
///
/// Build one with [`CleanDownloadConfig::for_hf_model`] and adjust it with the
/// consuming `with_*` builder methods.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CleanDownloadConfig {
    /// HuggingFace model ID (e.g., "microsoft/DialoGPT-medium")
    pub model_id: String,
    /// Target directory for the complete download.
    ///
    /// NOTE: the clone step deletes this directory first if it already exists.
    pub target_dir: PathBuf,
    /// Whether to enable verbose logging
    pub verbose: bool,
    /// Whether to keep the .git directory after download
    pub keep_git_dir: bool,
}

impl CleanDownloadConfig {
    /// Create a config for downloading `model_id` under `cache_base`.
    ///
    /// The target directory is `cache_base/clean-<id>`, where `<id>` is `model_id`
    /// with every `/` replaced by `--` (e.g. `"org/model"` becomes
    /// `cache_base/clean-org--model`). Verbose logging and keeping `.git` are
    /// both disabled.
    pub fn for_hf_model(model_id: &str, cache_base: &Path) -> Self {
        let model_cache_name = model_id.replace('/', "--");
        let target_dir = cache_base.join(format!("clean-{model_cache_name}"));

        Self {
            model_id: model_id.to_string(),
            target_dir,
            verbose: false,
            keep_git_dir: false,
        }
    }

    /// Enable or disable verbose logging.
    ///
    /// Consumes `self` and returns the updated config; dropping the result
    /// discards the change, hence `#[must_use]`.
    #[must_use]
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Choose whether the `.git` directory is kept after the download.
    ///
    /// Consumes `self` and returns the updated config.
    #[must_use]
    pub fn with_keep_git(mut self, keep_git: bool) -> Self {
        self.keep_git_dir = keep_git;
        self
    }
}
55
/// A parsed Git LFS pointer file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct LfsPointer {
    /// Path of the pointer file on disk.
    pub file_path: PathBuf,
    /// The complete first line of the pointer
    /// (it starts with `version https://git-lfs.github.com/spec/v`).
    pub version: String,
    /// Object id: the text following `oid sha256:` in the pointer.
    pub oid: String,
    /// Size in bytes of the real file content, as recorded by the `size` line.
    pub size: u64,
}
64
65/// Download a HuggingFace model using the clean git2 + LFS approach
66pub fn download_hf_model_clean(config: &CleanDownloadConfig) -> Result<PathBuf> {
67    if config.verbose {
68        println!("🚀 Starting clean git2+LFS download");
69        println!("📦 Model: {}", config.model_id);
70        println!("📁 Target: {}", config.target_dir.display());
71    }
72
73    // Step 1: Clone the git repository using git2
74    let repo_path = clone_hf_repo_git2(config)?;
75
76    // Step 2: Scan for LFS pointer files
77    let lfs_pointers = scan_for_lfs_pointers(&repo_path, config.verbose)?;
78
79    if config.verbose {
80        println!("🔍 Found {} LFS pointer files", lfs_pointers.len());
81    }
82
83    // Step 3: Download actual LFS content using hf-hub
84    download_lfs_content(&lfs_pointers, &config.model_id, config.verbose)?;
85
86    // Step 4: Cleanup .git directory if requested
87    if !config.keep_git_dir {
88        let git_dir = repo_path.join(".git");
89        if git_dir.exists() {
90            if config.verbose {
91                println!("🗑️  Removing .git directory");
92            }
93            fs::remove_dir_all(&git_dir)
94                .map_err(|e| E::msg(format!("Failed to remove .git directory: {e}")))?;
95        }
96    }
97
98    if config.verbose {
99        println!("✅ Clean download completed successfully");
100    }
101
102    Ok(repo_path)
103}
104
105/// Clone HuggingFace repository using git2
106fn clone_hf_repo_git2(config: &CleanDownloadConfig) -> Result<PathBuf> {
107    let repo_url = format!("https://huggingface.co/{}", config.model_id);
108
109    if config.verbose {
110        println!("📥 Cloning repository: {repo_url}");
111    }
112
113    // Remove existing directory if it exists
114    if config.target_dir.exists() {
115        if config.verbose {
116            println!("🗑️  Removing existing directory");
117        }
118        fs::remove_dir_all(&config.target_dir)
119            .map_err(|e| E::msg(format!("Failed to remove existing directory: {e}")))?;
120    }
121
122    // Create parent directory
123    if let Some(parent) = config.target_dir.parent() {
124        fs::create_dir_all(parent)
125            .map_err(|e| E::msg(format!("Failed to create parent directory: {e}")))?;
126    }
127
128    // Clone with git2
129    let mut builder = git2::build::RepoBuilder::new();
130
131    // Use shallow clone for efficiency
132    let mut fetch_options = git2::FetchOptions::new();
133    fetch_options.depth(1);
134    builder.fetch_options(fetch_options);
135
136    let _repo = builder
137        .clone(&repo_url, &config.target_dir)
138        .map_err(|e| E::msg(format!("Git clone failed: {e}")))?;
139
140    if config.verbose {
141        println!("✅ Repository cloned successfully");
142    }
143
144    Ok(config.target_dir.clone())
145}
146
147/// Scan the cloned repository for LFS pointer files
148fn scan_for_lfs_pointers(repo_path: &Path, verbose: bool) -> Result<Vec<LfsPointer>> {
149    if verbose {
150        println!("🔍 Scanning for LFS pointer files...");
151    }
152
153    let mut lfs_pointers = Vec::new();
154    scan_directory_for_lfs(repo_path, repo_path, &mut lfs_pointers, verbose)?;
155
156    if verbose {
157        println!("  Found {} LFS pointer files", lfs_pointers.len());
158        for pointer in &lfs_pointers {
159            println!(
160                "    • {} (size: {} bytes)",
161                pointer
162                    .file_path
163                    .strip_prefix(repo_path)
164                    .unwrap_or(&pointer.file_path)
165                    .display(),
166                pointer.size
167            );
168        }
169    }
170
171    Ok(lfs_pointers)
172}
173
174/// Recursively scan a directory for LFS pointer files
175fn scan_directory_for_lfs(
176    dir: &Path,
177    repo_root: &Path,
178    lfs_pointers: &mut Vec<LfsPointer>,
179    _verbose: bool,
180) -> Result<()> {
181    let entries = fs::read_dir(dir)
182        .map_err(|e| E::msg(format!("Failed to read directory {}: {e}", dir.display())))?;
183
184    for entry in entries {
185        let entry = entry.map_err(|e| E::msg(format!("Failed to read directory entry: {e}")))?;
186        let path = entry.path();
187
188        if path.is_dir() {
189            // Skip .git directory
190            if path.file_name() == Some(std::ffi::OsStr::new(".git")) {
191                continue;
192            }
193
194            // Recursively scan subdirectories
195            scan_directory_for_lfs(&path, repo_root, lfs_pointers, _verbose)?;
196        } else if path.is_file() {
197            // Check if this file is an LFS pointer
198            if let Ok(pointer) = check_lfs_pointer_file(&path, repo_root) {
199                lfs_pointers.push(pointer);
200            }
201        }
202    }
203
204    Ok(())
205}
206
207/// Check if a file is an LFS pointer file
208fn check_lfs_pointer_file(file_path: &Path, _repo_root: &Path) -> Result<LfsPointer> {
209    // Read first 1024 bytes (LFS pointers must be < 1024 bytes)
210    let mut file = fs::File::open(file_path)
211        .map_err(|e| E::msg(format!("Failed to open file {}: {e}", file_path.display())))?;
212
213    let mut buffer = [0; 1024];
214    let bytes_read = file
215        .read(&mut buffer)
216        .map_err(|e| E::msg(format!("Failed to read file {}: {e}", file_path.display())))?;
217
218    let content = std::str::from_utf8(&buffer[..bytes_read])
219        .map_err(|_| E::msg("File is not valid UTF-8"))?;
220
221    // Parse LFS pointer format
222    parse_lfs_pointer(content, file_path.to_path_buf())
223}
224
225/// Parse LFS pointer file content
226fn parse_lfs_pointer(content: &str, file_path: PathBuf) -> Result<LfsPointer> {
227    let lines: Vec<&str> = content.lines().collect();
228
229    if lines.is_empty() {
230        return Err(E::msg("Empty file"));
231    }
232
233    // First line must be version
234    if !lines[0].starts_with("version https://git-lfs.github.com/spec/v") {
235        return Err(E::msg("Not an LFS pointer file"));
236    }
237
238    let version = lines[0].to_string();
239    let mut oid = String::new();
240    let mut size = 0u64;
241
242    // Parse remaining lines
243    for line in &lines[1..] {
244        if let Some(stripped) = line.strip_prefix("oid sha256:") {
245            oid = stripped.to_string();
246        } else if let Some(stripped) = line.strip_prefix("size ") {
247            size = stripped
248                .parse()
249                .map_err(|e| E::msg(format!("Invalid size in LFS pointer: {e}")))?;
250        }
251    }
252
253    if oid.is_empty() || size == 0 {
254        return Err(E::msg("Invalid LFS pointer: missing oid or size"));
255    }
256
257    Ok(LfsPointer {
258        file_path,
259        version,
260        oid,
261        size,
262    })
263}
264
265/// Download actual LFS content using hf-hub and replace pointer files
266fn download_lfs_content(lfs_pointers: &[LfsPointer], model_id: &str, verbose: bool) -> Result<()> {
267    if lfs_pointers.is_empty() {
268        if verbose {
269            println!("📄 No LFS files to download");
270        }
271        return Ok(());
272    }
273
274    if verbose {
275        println!(
276            "📥 Downloading {} LFS files using hf-hub...",
277            lfs_pointers.len()
278        );
279    }
280
281    // Setup HuggingFace API
282    let api = hf_hub::api::sync::Api::new()
283        .map_err(|e| E::msg(format!("Failed to create HF API: {e}")))?;
284    let repo = api.model(model_id.to_string());
285
286    for (i, pointer) in lfs_pointers.iter().enumerate() {
287        if verbose {
288            println!(
289                "  📥 [{}/{}] Downloading: {}",
290                i + 1,
291                lfs_pointers.len(),
292                pointer
293                    .file_path
294                    .file_name()
295                    .unwrap_or_default()
296                    .to_string_lossy()
297            );
298        }
299
300        // Get relative path from repo root for hf-hub API
301        // We need to find the repo root by looking for .git directory
302        let mut repo_root = pointer.file_path.parent();
303        while let Some(parent) = repo_root {
304            if parent.join(".git").exists() {
305                repo_root = Some(parent);
306                break;
307            }
308            repo_root = parent.parent();
309        }
310
311        let repo_root = repo_root.ok_or_else(|| E::msg("Cannot find repo root"))?;
312        let relative_path = pointer
313            .file_path
314            .strip_prefix(repo_root)
315            .map_err(|e| E::msg(format!("Failed to get relative path: {e}")))?;
316
317        let relative_path_str = relative_path.to_string_lossy();
318
319        // Download the actual file content using hf-hub
320        match repo.get(&relative_path_str) {
321            Ok(downloaded_path) => {
322                // Copy the downloaded content over the pointer file
323                fs::copy(&downloaded_path, &pointer.file_path)
324                    .map_err(|e| E::msg(format!("Failed to replace pointer file: {e}")))?;
325
326                if verbose {
327                    println!(
328                        "    ✅ Downloaded and replaced: {} ({} bytes)",
329                        relative_path_str, pointer.size
330                    );
331                }
332            }
333            Err(e) => {
334                if verbose {
335                    println!("    ⚠️  Failed to download {relative_path_str}: {e}");
336                }
337                return Err(E::msg(format!(
338                    "Failed to download LFS file {relative_path_str}: {e}"
339                )));
340            }
341        }
342    }
343
344    if verbose {
345        println!("✅ All LFS files downloaded successfully");
346    }
347
348    Ok(())
349}
350
351/// Verify that the downloaded model is complete
352pub fn verify_download_completeness(
353    model_path: &Path,
354    expected_files: &[&str],
355    verbose: bool,
356) -> Result<()> {
357    if verbose {
358        println!("🔍 Verifying download completeness...");
359    }
360
361    for expected_file in expected_files {
362        let file_path = model_path.join(expected_file);
363        if !file_path.exists() {
364            return Err(E::msg(format!("Expected file not found: {expected_file}")));
365        }
366
367        // Check that it's not still an LFS pointer
368        if is_lfs_pointer_file(&file_path)? {
369            return Err(E::msg(format!(
370                "File {expected_file} is still an LFS pointer"
371            )));
372        }
373
374        if verbose {
375            let size = fs::metadata(&file_path)
376                .map_err(|e| E::msg(format!("Failed to get file metadata: {e}")))?
377                .len();
378            println!("  ✅ {expected_file} ({size} bytes)");
379        }
380    }
381
382    if verbose {
383        println!("✅ Download verification completed");
384    }
385
386    Ok(())
387}
388
389/// Check if a file is still an LFS pointer file
390fn is_lfs_pointer_file(file_path: &Path) -> Result<bool> {
391    match check_lfs_pointer_file(file_path, file_path.parent().unwrap_or(file_path)) {
392        Ok(_) => Ok(true),
393        Err(_) => Ok(false),
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    const VERSION_LINE: &str = "version https://git-lfs.github.com/spec/v1";

    #[test]
    fn test_lfs_pointer_parsing() {
        let content = "version https://git-lfs.github.com/spec/v1\noid sha256:abc123\nsize 12345\n";
        let result = parse_lfs_pointer(content, PathBuf::from("test.bin"));

        assert!(result.is_ok());
        let pointer = result.unwrap();
        assert_eq!(pointer.oid, "abc123");
        assert_eq!(pointer.size, 12345);
    }

    #[test]
    fn test_lfs_pointer_parsing_crlf() {
        // Pointers checked out on Windows may use CRLF line endings.
        let content = format!("{VERSION_LINE}\r\noid sha256:abc123\r\nsize 42\r\n");
        let pointer = parse_lfs_pointer(&content, PathBuf::from("crlf.bin"))
            .expect("CRLF pointer should parse");

        assert_eq!(pointer.version, VERSION_LINE);
        assert_eq!(pointer.oid, "abc123");
        assert_eq!(pointer.size, 42);
        assert_eq!(pointer.file_path, PathBuf::from("crlf.bin"));
    }

    #[test]
    fn test_invalid_lfs_pointer() {
        let content = "This is not an LFS pointer file";
        let result = parse_lfs_pointer(content, PathBuf::from("test.txt"));
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_content_is_not_a_pointer() {
        assert!(parse_lfs_pointer("", PathBuf::from("empty.bin")).is_err());
    }

    #[test]
    fn test_lfs_pointer_missing_or_invalid_fields() {
        let missing_oid = format!("{VERSION_LINE}\nsize 10\n");
        assert!(parse_lfs_pointer(&missing_oid, PathBuf::from("a.bin")).is_err());

        let missing_size = format!("{VERSION_LINE}\noid sha256:abc\n");
        assert!(parse_lfs_pointer(&missing_size, PathBuf::from("b.bin")).is_err());

        let zero_size = format!("{VERSION_LINE}\noid sha256:abc\nsize 0\n");
        assert!(parse_lfs_pointer(&zero_size, PathBuf::from("c.bin")).is_err());

        let bad_size = format!("{VERSION_LINE}\noid sha256:abc\nsize lots\n");
        assert!(parse_lfs_pointer(&bad_size, PathBuf::from("d.bin")).is_err());
    }

    #[test]
    fn test_config_creation() {
        let temp_dir = env::temp_dir();
        let config = CleanDownloadConfig::for_hf_model("test/model", &temp_dir);

        assert_eq!(config.model_id, "test/model");
        assert!(config
            .target_dir
            .to_string_lossy()
            .contains("clean-test--model"));
        assert!(!config.verbose);
        assert!(!config.keep_git_dir);
    }

    #[test]
    fn test_config_builders() {
        let base = Path::new("cache");
        let config = CleanDownloadConfig::for_hf_model("org/model", base)
            .with_verbose(true)
            .with_keep_git(true);

        assert_eq!(config.target_dir, base.join("clean-org--model"));
        assert!(config.verbose);
        assert!(config.keep_git_dir);
    }
}