Skip to main content

mlua_pkg/
fetcher.rs

1//! Git-based package fetcher.
2//!
3//! The [`Fetcher`] trait abstracts over different fetch backends (git, luarocks,
4//! http).  [`GitFetcher`] implements the git backend using libgit2 via the
5//! [`git2`] crate.  No subprocess `git` invocations are used.
6//!
7//! # Cache layout
8//!
9//! ```text
10//! <cache_root>/git/<host>/<path…>/<sha>/
11//! ```
12//!
13//! For example, `https://github.com/ynishi/lshape` at SHA `abc123` becomes:
14//!
15//! ```text
16//! <cache_root>/git/github.com/ynishi/lshape/abc123/
17//! ```
18//!
19//! If that directory already exists it is reused as-is, and any freshly cloned copy is discarded.
20//!
21//! # Authentication
22//!
23//! `GitFetcher` uses a `RemoteCallbacks`-based credential cascade:
24//!
25//! 1. SSH agent (`Cred::ssh_key_from_agent`) — tried only when the remote
26//!    advertises `SSH_KEY` in `allowed_types`.
27//! 2. Credential helper (`Cred::credential_helper`) — tried when `USER_PASS_PLAINTEXT`
28//!    is advertised.
29//! 3. `Cred::default()` — last resort.
30//!
31//! The callback tracks attempted credential types to avoid infinite retry loops.
32
33use std::{
34    path::{Component, PathBuf},
35    sync::atomic::{AtomicU64, AtomicU8, Ordering},
36};
37
38use git2::{build::RepoBuilder, CredentialType, FetchOptions, RemoteCallbacks, Repository};
39
40use crate::{
41    manifest::{Dep, Manifest},
42    PkgError,
43};
44
/// Process-wide sequence number that makes every temporary clone directory
/// name distinct within this process.
///
/// It only ever increases; the process id is paired with it in the temp
/// directory name to separate different processes sharing one cache root.
static TMP_CTR: AtomicU64 = AtomicU64::new(0);
47
48// ── Fetcher trait ─────────────────────────────────────────────────────────────
49
50/// Abstraction over package fetch backends.
51///
52/// Implementations are *not* required to be [`Send`] or [`Sync`] — the MVP
53/// is single-threaded.
54pub trait Fetcher {
55    /// Fetch the package described by `dep` and return a [`FetchedPkg`].
56    ///
57    /// # Errors
58    ///
59    /// Returns [`PkgError`] on any git, I/O, or validation failure.
60    fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError>;
61}
62
63// ── FetchedPkg ────────────────────────────────────────────────────────────────
64
65/// Result of a successful [`Fetcher::fetch`] call.
66#[derive(Debug, Clone)]
67pub struct FetchedPkg {
68    /// Absolute path to the cloned repository on disk.
69    pub cache_path: PathBuf,
70
71    /// The resolved commit SHA (40-character hex string).
72    pub sha: String,
73
74    /// The parsed `mlua-pkg.toml` found at `cache_path`, if present.
75    pub manifest: Option<Manifest>,
76}
77
78// ── GitFetcher ────────────────────────────────────────────────────────────────
79
/// [`Fetcher`] implementation backed by libgit2.
///
/// Clones are cached under `<cache_root>/git/<host>/<path…>/<sha>/`.
///
/// `Debug` and `Clone` are derived so the fetcher can be logged and shared
/// by value like other configuration objects; its only state is a path.
#[derive(Debug, Clone)]
pub struct GitFetcher {
    /// Root directory under which all git caches are stored.
    cache_root: PathBuf,
}
85
86impl GitFetcher {
87    /// Create a new `GitFetcher` that stores clones under `cache_root`.
88    pub fn new(cache_root: PathBuf) -> Self {
89        Self { cache_root }
90    }
91
92    /// Compute the cache directory for the given URL and SHA.
93    ///
94    /// Layout: `<cache_root>/git/<host>/<path…>/<sha>/`
95    ///
96    /// # Errors
97    ///
98    /// Returns [`PkgError::Validation`] if the URL cannot be parsed or if any
99    /// URL-derived path component contains `..` (path traversal defence).
100    fn cache_dir(&self, url: &str, sha: &str) -> Result<PathBuf, PkgError> {
101        // Strip protocol prefix and trailing `.git`.
102        let stripped = url
103            .trim_start_matches("https://")
104            .trim_start_matches("http://")
105            .trim_start_matches("ssh://")
106            .trim_start_matches("git@")
107            .replace(':', "/") // git@github.com:user/repo → github.com/user/repo
108            .trim_end_matches(".git")
109            .to_owned();
110
111        if stripped.is_empty() {
112            return Err(PkgError::Validation {
113                message: format!("cannot derive cache path from URL: {url:?}"),
114            });
115        }
116
117        // Defend against path traversal in every component.
118        for component in stripped.split('/') {
119            if component == ".." || component == "." {
120                return Err(PkgError::Validation {
121                    message: format!(
122                        "URL {url:?} contains a path traversal component: {component:?}"
123                    ),
124                });
125            }
126        }
127
128        // Validate SHA is safe (hex chars only).
129        if sha.is_empty() || !sha.chars().all(|c| c.is_ascii_hexdigit()) {
130            return Err(PkgError::Validation {
131                message: format!("invalid SHA: {sha:?}"),
132            });
133        }
134
135        let mut path = self.cache_root.join("git");
136        for segment in stripped.split('/') {
137            if segment.is_empty() {
138                continue;
139            }
140            // Extra check: ensure no path component resolves to `..` via PathBuf.
141            let p = path.join(segment);
142            for c in p.components() {
143                if c == Component::ParentDir {
144                    return Err(PkgError::Validation {
145                        message: format!(
146                            "URL {url:?} resolves to a path with parent-dir traversal"
147                        ),
148                    });
149                }
150            }
151            path = p;
152        }
153        path = path.join(sha);
154        Ok(path)
155    }
156
157    /// Validate the URL structure before any network operation.
158    ///
159    /// Rejects URLs that contain `..` or `.` path components (path traversal
160    /// defence).  Called at the start of [`Fetcher::fetch`] so that invalid
161    /// URLs are rejected without making a network connection.
162    fn validate_url(url: &str) -> Result<(), PkgError> {
163        let stripped = url
164            .trim_start_matches("https://")
165            .trim_start_matches("http://")
166            .trim_start_matches("ssh://")
167            .trim_start_matches("git@")
168            .replace(':', "/")
169            .trim_end_matches(".git")
170            .to_owned();
171
172        if stripped.is_empty() {
173            return Err(PkgError::Validation {
174                message: format!("cannot derive cache path from URL: {url:?}"),
175            });
176        }
177
178        for component in stripped.split('/') {
179            if component == ".." || component == "." {
180                return Err(PkgError::Validation {
181                    message: format!(
182                        "URL {url:?} contains a path traversal component: {component:?}"
183                    ),
184                });
185            }
186        }
187        Ok(())
188    }
189
190    /// Return a unique temp-clone path inside `git_base` (same filesystem, so
191    /// `std::fs::rename` works atomically).
192    fn temp_clone_path(git_base: &std::path::Path) -> PathBuf {
193        let n = TMP_CTR.fetch_add(1, Ordering::Relaxed);
194        let pid = std::process::id();
195        git_base.join(format!(".fetch-{pid}-{n}"))
196    }
197
198    /// Resolve the git ref from `dep` to a full 40-char SHA.
199    ///
200    /// Resolution order:
201    /// 1. `rev` — treated as a revspec; the commit it peels to is returned.
202    /// 2. `tag` — resolved via `refs/tags/<tag>` (peeled to commit).
203    /// 3. `branch` — resolved via `refs/remotes/origin/<branch>` (peeled to commit).
204    /// 4. No ref — `HEAD` (latest commit on default branch).
205    ///
206    /// The repository must already have been fetched (all refs present).
207    fn resolve_sha(repo: &Repository, dep: &Dep) -> Result<String, PkgError> {
208        let oid = if let Some(rev) = &dep.rev {
209            repo.revparse_single(rev)?.peel_to_commit()?.id()
210        } else if let Some(tag) = &dep.tag {
211            let refname = format!("refs/tags/{tag}");
212            repo.find_reference(&refname)?.peel_to_commit()?.id()
213        } else if let Some(branch) = &dep.branch {
214            let refname = format!("refs/remotes/origin/{branch}");
215            repo.find_reference(&refname)?.peel_to_commit()?.id()
216        } else {
217            // Default: HEAD
218            repo.head()?.peel_to_commit()?.id()
219        };
220        Ok(oid.to_string())
221    }
222
223    /// Build `FetchOptions` with the credential cascade callback.
224    fn make_fetch_options() -> FetchOptions<'static> {
225        let mut callbacks = RemoteCallbacks::new();
226
227        // Track which credential types we've already tried to avoid infinite loops.
228        // Bits: 0 = SSH_KEY, 1 = USER_PASS_PLAINTEXT, 2 = DEFAULT
229        let tried = AtomicU8::new(0);
230
231        callbacks.credentials(move |_url, username, allowed| {
232            let tried_bits = tried.load(Ordering::Relaxed);
233
234            // 1. SSH agent
235            if allowed.contains(CredentialType::SSH_KEY) && (tried_bits & 0b001 == 0) {
236                tried.fetch_or(0b001, Ordering::Relaxed);
237                let user = username.unwrap_or("git");
238                return git2::Cred::ssh_key_from_agent(user);
239            }
240
241            // 2. Credential helper (HTTPS)
242            if allowed.contains(CredentialType::USER_PASS_PLAINTEXT) && (tried_bits & 0b010 == 0) {
243                tried.fetch_or(0b010, Ordering::Relaxed);
244                if let Ok(cfg) = git2::Config::open_default() {
245                    return git2::Cred::credential_helper(&cfg, _url, username);
246                }
247                // If we cannot open git config fall through to default cred.
248            }
249
250            // 3. Default
251            if tried_bits & 0b100 == 0 {
252                tried.fetch_or(0b100, Ordering::Relaxed);
253                return git2::Cred::default();
254            }
255
256            Err(git2::Error::from_str("all credential types exhausted"))
257        });
258
259        let mut fo = FetchOptions::new();
260        fo.remote_callbacks(callbacks);
261        fo
262    }
263}
264
265impl Fetcher for GitFetcher {
266    fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError> {
267        let url = &dep.git;
268
269        // Reject obviously malformed / traversal URLs before any I/O.
270        Self::validate_url(url)?;
271
272        // All git caches live under `<cache_root>/git/`.
273        let git_base = self.cache_root.join("git");
274        std::fs::create_dir_all(&git_base)?;
275
276        // Clone into a temp directory that lives on the same filesystem as the
277        // final cache location so that `std::fs::rename` is atomic.
278        let tmp_path = Self::temp_clone_path(&git_base);
279
280        // ── Clone ────────────────────────────────────────────────────────────
281        let fo = Self::make_fetch_options();
282        let repo = match RepoBuilder::new().fetch_options(fo).clone(url, &tmp_path) {
283            Ok(r) => r,
284            Err(e) => {
285                // Best-effort cleanup on error.
286                let _ = std::fs::remove_dir_all(&tmp_path);
287                return Err(e.into());
288            }
289        };
290
291        // ── Resolve SHA ──────────────────────────────────────────────────────
292        let sha = match Self::resolve_sha(&repo, dep) {
293            Ok(s) => s,
294            Err(e) => {
295                let _ = std::fs::remove_dir_all(&tmp_path);
296                return Err(e);
297            }
298        };
299
300        // ── Compute final cache path ─────────────────────────────────────────
301        let cache_path = match self.cache_dir(url, &sha) {
302            Ok(p) => p,
303            Err(e) => {
304                let _ = std::fs::remove_dir_all(&tmp_path);
305                return Err(e);
306            }
307        };
308
309        if cache_path.exists() {
310            // Already cached — discard the temp clone.
311            drop(repo);
312            let _ = std::fs::remove_dir_all(&tmp_path);
313        } else {
314            // Ensure the parent directory exists, then rename temp → final.
315            if let Some(parent) = cache_path.parent() {
316                std::fs::create_dir_all(parent)?;
317            }
318            drop(repo); // Release file handles before rename.
319            std::fs::rename(&tmp_path, &cache_path)?;
320        }
321
322        // ── Parse manifest if present ────────────────────────────────────────
323        let manifest_path = cache_path.join("mlua-pkg.toml");
324        let manifest = if manifest_path.exists() {
325            Some(Manifest::from_path(&manifest_path)?)
326        } else {
327            None
328        };
329
330        Ok(FetchedPkg {
331            cache_path,
332            sha,
333            manifest,
334        })
335    }
336}
337
338// ── Unit tests ────────────────────────────────────────────────────────────────
339
#[cfg(test)]
mod tests {
    use super::*;
    use git2::{Repository, Signature};
    use std::fs;
    use tempfile::TempDir;

    /// Identity used for every commit and tag these tests create.
    fn test_signature() -> Signature<'static> {
        Signature::now("Test", "test@example.com").unwrap()
    }

    /// Initialise a repo in `dir` and commit `files`.
    ///
    /// A local identity is configured first.  `files` are paths relative to
    /// `dir` that must already exist on disk; they are staged and committed
    /// with `message`.  Returns the new commit's SHA.
    fn commit_files(dir: &std::path::Path, files: &[&str], message: &str) -> String {
        let repo = Repository::init(dir).unwrap();

        {
            let mut cfg = repo.config().unwrap();
            cfg.set_str("user.name", "Test").unwrap();
            cfg.set_str("user.email", "test@example.com").unwrap();
        }

        let mut index = repo.index().unwrap();
        for &file in files {
            index.add_path(std::path::Path::new(file)).unwrap();
        }
        index.write().unwrap();

        let tree = repo.find_tree(index.write_tree().unwrap()).unwrap();
        let sig = test_signature();
        let commit_id = repo
            .commit(Some("HEAD"), &sig, &sig, message, &tree, &[])
            .unwrap();
        commit_id.to_string()
    }

    /// Create a minimal git repo in `dir` with one commit and return the SHA.
    fn init_repo_with_commit(dir: &std::path::Path) -> String {
        fs::write(dir.join("README.md"), "# test\n").unwrap();
        commit_files(dir, &["README.md"], "initial commit")
    }

    /// Add an annotated tag to the HEAD commit of `repo`.
    ///
    /// Returns the SHA of the tagged commit.
    fn add_tag(repo: &Repository, tag_name: &str) -> String {
        let target = repo.head().unwrap().peel_to_commit().unwrap();
        repo.tag(tag_name, target.as_object(), &test_signature(), tag_name, false)
            .unwrap();
        target.id().to_string()
    }

    /// A `Dep` on `git` with no ref selector and no entry override.
    fn dep_for(git: String) -> Dep {
        Dep {
            git,
            tag: None,
            rev: None,
            branch: None,
            entry: None,
        }
    }

    /// `file://` URL for a local repository directory.
    fn file_url(dir: &std::path::Path) -> String {
        format!("file://{}", dir.display())
    }

    /// A fetcher rooted in its own temporary cache directory.
    ///
    /// Keep the returned `TempDir` alive for as long as the fetcher is used.
    fn fresh_fetcher() -> (TempDir, GitFetcher) {
        let root = TempDir::new().unwrap();
        let fetcher = GitFetcher::new(root.path().to_path_buf());
        (root, fetcher)
    }

    // ── 1. clone a local file:// repo (happy path) ───────────────────────────

    #[test]
    fn clone_local_repo_happy_path() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = fresh_fetcher();

        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        assert_eq!(result.sha, sha, "SHA should match the initial commit");
        assert!(result.cache_path.exists(), "cache_path must exist on disk");
        assert!(
            result.manifest.is_none(),
            "no mlua-pkg.toml in bare test repo"
        );
    }

    // ── 2. resolve tag → SHA ─────────────────────────────────────────────────

    #[test]
    fn resolve_tag_sha() {
        let src = TempDir::new().unwrap();
        init_repo_with_commit(src.path());
        let expected_sha = {
            let repo = Repository::open(src.path()).unwrap();
            add_tag(&repo, "v0.1.0")
        };
        let (_cache_root, fetcher) = fresh_fetcher();

        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..dep_for(file_url(src.path()))
        };
        let result = fetcher.fetch(&dep).unwrap();

        assert_eq!(result.sha, expected_sha, "tag must resolve to expected SHA");
        assert!(result.cache_path.exists());
    }

    // ── 3. resolve rev → SHA ─────────────────────────────────────────────────

    #[test]
    fn resolve_rev_sha() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = fresh_fetcher();

        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };
        let result = fetcher.fetch(&dep).unwrap();

        assert_eq!(result.sha, sha, "rev should resolve to the given SHA");
    }

    // ── 4. nonexistent repo returns GitFetch error ───────────────────────────

    #[test]
    fn nonexistent_repo_returns_error() {
        let (_cache_root, fetcher) = fresh_fetcher();

        let dep = dep_for("file:///nonexistent/path/that/does/not/exist".to_string());
        let err = fetcher.fetch(&dep).unwrap_err();

        assert!(
            matches!(err, PkgError::GitFetch { .. }),
            "expected GitFetch error, got: {err}"
        );
    }

    // ── 5. second fetch of same repo uses cache (skip re-clone) ──────────────

    #[test]
    fn second_fetch_uses_cache() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = fresh_fetcher();

        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };
        let first = fetcher.fetch(&dep).unwrap();
        let second = fetcher.fetch(&dep).unwrap();

        assert_eq!(
            first.cache_path, second.cache_path,
            "cache paths must be identical"
        );
        assert_eq!(first.sha, second.sha);
    }

    // ── 6. path traversal in URL is rejected ─────────────────────────────────

    #[test]
    fn path_traversal_in_url_is_rejected() {
        let (_cache_root, fetcher) = fresh_fetcher();

        let dep = dep_for("https://github.com/../../../etc/passwd".to_string());
        let err = fetcher.fetch(&dep).unwrap_err();

        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for path traversal, got: {err}"
        );
    }

    // ── 7. manifest is parsed when mlua-pkg.toml is present ──────────────────

    #[test]
    fn manifest_parsed_when_present() {
        let src = TempDir::new().unwrap();

        // The manifest must exist on disk before it is staged and committed.
        fs::write(
            src.path().join("mlua-pkg.toml"),
            r#"[package]
name = "test-lib"
version = "0.1.0"
"#,
        )
        .unwrap();
        commit_files(src.path(), &["mlua-pkg.toml"], "add manifest");

        let (_cache_root, fetcher) = fresh_fetcher();
        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        let manifest = result.manifest.expect("manifest should be parsed");
        assert_eq!(manifest.package.name, "test-lib");
        assert_eq!(manifest.package.version, "0.1.0");
    }

    // ── 8. cache_dir rejects SHA with non-hex chars ──────────────────────────

    #[test]
    fn cache_dir_rejects_invalid_sha() {
        let (_cache_root, fetcher) = fresh_fetcher();

        let err = fetcher
            .cache_dir("https://github.com/x/y", "../evil")
            .unwrap_err();

        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for invalid SHA, got: {err}"
        );
    }
}