use std::{
path::{Component, PathBuf},
sync::atomic::{AtomicU64, AtomicU8, Ordering},
};
use git2::{build::RepoBuilder, CredentialType, FetchOptions, RemoteCallbacks, Repository};
use crate::{
manifest::{Dep, Manifest},
PkgError,
};
/// Process-wide counter mixed into scratch-directory names by
/// `GitFetcher::temp_clone_path`, so repeated fetches within one process never
/// reuse a temporary path.
static TMP_CTR: AtomicU64 = AtomicU64::new(0);
/// Source of package contents for a manifest dependency.
pub trait Fetcher {
/// Resolves `dep` to a concrete revision and materializes it on disk,
/// returning where the files live and what the dependency resolved to.
fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError>;
}
/// Outcome of a successful [`Fetcher::fetch`].
#[derive(Debug, Clone)]
pub struct FetchedPkg {
/// Directory holding the fetched package's files.
pub cache_path: PathBuf,
/// Hex commit id the dependency resolved to.
pub sha: String,
/// Parsed `mlua-pkg.toml` from the package root, or `None` if the package
/// has no such file.
pub manifest: Option<Manifest>,
/// Concrete tag name selected for a `tag` pin (a prefix pin is expanded to a
/// specific tag); `None` when the dependency was not resolved via a tag.
pub resolved_tag: Option<String>,
}
/// A [`Fetcher`] that clones Git repositories with libgit2 and keeps one
/// checked-out worktree per (repository URL, commit) pair on disk.
pub struct GitFetcher {
/// Root of the on-disk cache; all Git data lives under its `git/` subdirectory.
cache_root: PathBuf,
}
impl GitFetcher {
    /// Creates a fetcher rooted at `cache_root`.
    ///
    /// Nothing is written until a fetch or tag listing runs; all data is then
    /// placed below `<cache_root>/git/`.
    pub fn new(cache_root: PathBuf) -> Self {
        Self { cache_root }
    }

    /// Normalizes `url` into the directory segments that key its cache entry.
    ///
    /// The `https://`, `http://`, `ssh://` scheme and a leading `git@` are
    /// dropped, `:` (scp-style `host:path`, ports) becomes a separator, and a
    /// trailing `.git` is removed, so `https://host/a/b.git` and `git@host:a/b`
    /// share one entry. Both `/` and `\` separate segments, so a Windows-style
    /// path can never hand `PathBuf::push` a rooted component that would
    /// replace the cache root.
    ///
    /// # Errors
    ///
    /// [`PkgError::Validation`] if any segment is `.` or `..`, or if no
    /// non-empty segment remains.
    fn url_segments(url: &str) -> Result<Vec<String>, PkgError> {
        let replaced = url
            .trim_start_matches("https://")
            .trim_start_matches("http://")
            .trim_start_matches("ssh://")
            .trim_start_matches("git@")
            .replace(':', "/");
        let trimmed = replaced.trim_end_matches(".git");

        let mut segments = Vec::new();
        for component in trimmed.split(|c: char| c == '/' || c == '\\') {
            if component == ".." || component == "." {
                return Err(PkgError::Validation {
                    message: format!(
                        "URL {url:?} contains a path traversal component: {component:?}"
                    ),
                });
            }
            if !component.is_empty() {
                segments.push(component.to_owned());
            }
        }
        // Something like "https:///" has no usable segment; without this check
        // it would be cached directly as `git/<sha>`.
        if segments.is_empty() {
            return Err(PkgError::Validation {
                message: format!("cannot derive cache path from URL: {url:?}"),
            });
        }
        Ok(segments)
    }

    /// Returns the cache directory for commit `sha` of repository `url`:
    /// `<cache_root>/git/<url segments...>/<sha>`.
    ///
    /// # Errors
    ///
    /// [`PkgError::Validation`] if `url` cannot be normalized safely (see
    /// `url_segments`) or `sha` is empty or not hexadecimal.
    fn cache_dir(&self, url: &str, sha: &str) -> Result<PathBuf, PkgError> {
        let segments = Self::url_segments(url)?;
        if sha.is_empty() || !sha.chars().all(|c| c.is_ascii_hexdigit()) {
            return Err(PkgError::Validation {
                message: format!("invalid SHA: {sha:?}"),
            });
        }

        let mut path = self.cache_root.join("git");
        for segment in &segments {
            // Defense in depth: each segment must be exactly one plain path
            // component, otherwise `push` could climb out of the cache root.
            let mut parts = std::path::Path::new(segment).components();
            let is_plain =
                matches!(parts.next(), Some(Component::Normal(_))) && parts.next().is_none();
            if !is_plain {
                return Err(PkgError::Validation {
                    message: format!("URL {url:?} resolves to a path outside the cache root"),
                });
            }
            path.push(segment);
        }
        path.push(sha);
        Ok(path)
    }

    /// Checks that `url` maps to a cache location without path traversal, so
    /// unusable URLs are rejected before any network access.
    fn validate_url(url: &str) -> Result<(), PkgError> {
        Self::url_segments(url).map(|_| ())
    }

    /// Returns a fresh scratch path inside `git_base` for an in-progress clone.
    ///
    /// The process id plus [`TMP_CTR`] make the name unique per process and
    /// per call; the directory itself is not created here.
    fn temp_clone_path(git_base: &std::path::Path) -> PathBuf {
        let n = TMP_CTR.fetch_add(1, Ordering::Relaxed);
        let pid = std::process::id();
        git_base.join(format!(".fetch-{pid}-{n}"))
    }

    /// Resolves the commit `dep` pins, returning its SHA and, for tag pins,
    /// the concrete tag name that was selected.
    ///
    /// Precedence: `rev` (any revspec), then `tag`, then `branch` (read from
    /// the `origin` remote-tracking ref of the clone), then the clone's `HEAD`.
    fn resolve_ref(repo: &Repository, dep: &Dep) -> Result<(String, Option<String>), PkgError> {
        if let Some(rev) = &dep.rev {
            let oid = repo.revparse_single(rev)?.peel_to_commit()?.id();
            return Ok((oid.to_string(), None));
        }
        if let Some(tag) = &dep.tag {
            let resolved = Self::resolve_tag_pin(repo, tag)?;
            let refname = format!("refs/tags/{resolved}");
            // `peel_to_commit` also unwraps annotated tag objects.
            let oid = repo.find_reference(&refname)?.peel_to_commit()?.id();
            return Ok((oid.to_string(), Some(resolved)));
        }
        if let Some(branch) = &dep.branch {
            let refname = format!("refs/remotes/origin/{branch}");
            let oid = repo.find_reference(&refname)?.peel_to_commit()?.id();
            return Ok((oid.to_string(), None));
        }
        let oid = repo.head()?.peel_to_commit()?.id();
        Ok((oid.to_string(), None))
    }

    /// Maps a `tag` pin to a concrete tag name.
    ///
    /// Pins that `classify_tag_pin` reports as `TagPin::Prefix` are expanded
    /// with `pick_latest_for_pin` over the tags present in the clone; any other
    /// pin is returned unchanged.
    ///
    /// # Errors
    ///
    /// [`PkgError::Validation`] if a prefix pin matches none of the tags.
    fn resolve_tag_pin(repo: &Repository, tag: &str) -> Result<String, PkgError> {
        use crate::version::{classify_tag_pin, pick_latest_for_pin, TagPin};
        let pin = classify_tag_pin(tag);
        let prefix = match pin {
            Some(TagPin::Prefix(p)) => p,
            _ => return Ok(tag.to_string()),
        };
        let tag_names = repo.tag_names(None)?;
        let local_tags: Vec<String> = tag_names
            .iter()
            .filter_map(|t| t.map(|s| s.to_string()))
            .collect();
        pick_latest_for_pin(&local_tags, &prefix).ok_or_else(|| PkgError::Validation {
            message: format!("tag prefix '{tag}' has no matching SemVer release on remote"),
        })
    }

    /// Hard-resets `repo` to commit `sha`, so its worktree matches that commit.
    fn checkout_sha(repo: &Repository, sha: &str) -> Result<(), PkgError> {
        let oid = git2::Oid::from_str(sha).map_err(|e| PkgError::Validation {
            message: format!("invalid SHA {sha}: {e}"),
        })?;
        let obj = repo.find_object(oid, None)?;
        repo.reset(&obj, git2::ResetType::Hard, None)?;
        Ok(())
    }

    /// Lists the tag names advertised by the remote at `url` without cloning.
    ///
    /// A throwaway repository is initialized under `<cache_root>/git/.ls-remote`
    /// to host an anonymous remote. It is deleted afterwards, best-effort,
    /// whatever the outcome.
    pub fn list_tags(&self, url: &str) -> Result<Vec<String>, PkgError> {
        Self::validate_url(url)?;
        let scratch = self.cache_root.join("git").join(".ls-remote");
        std::fs::create_dir_all(&scratch)?;
        let tmp = Self::temp_clone_path(&scratch);
        std::fs::create_dir_all(&tmp)?;
        let result = Self::list_tags_inner(&tmp, url);
        let _ = std::fs::remove_dir_all(&tmp);
        result
    }

    /// Connects to `url` from a scratch repository at `tmp` and collects tag
    /// names from the ref advertisement.
    ///
    /// Names keep advertisement order. Each appears once even when the remote
    /// also advertises a peeled `^{}` entry for it.
    fn list_tags_inner(tmp: &std::path::Path, url: &str) -> Result<Vec<String>, PkgError> {
        let repo = Repository::init(tmp)?;
        let mut remote = repo.remote_anonymous(url)?;
        remote.connect_auth(
            git2::Direction::Fetch,
            Some(Self::make_credentials_callbacks()),
            None,
        )?;
        // Scoped so the borrows of `remote` held by the ref list end before
        // `disconnect` needs it mutably. A set keeps deduplication linear.
        let tags: Vec<String> = {
            let mut seen = std::collections::HashSet::new();
            remote
                .list()?
                .iter()
                .filter_map(|head| head.name().strip_prefix("refs/tags/"))
                .map(|name| name.strip_suffix("^{}").unwrap_or(name))
                .filter(|name| seen.insert(*name))
                .map(str::to_owned)
                .collect()
        };
        let _ = remote.disconnect();
        Ok(tags)
    }

    /// Builds credential callbacks that try each strategy at most once.
    ///
    /// Strategies, in order:
    /// 1. a bare username (libgit2 asks for one when an SSH URL has no user);
    /// 2. the SSH agent;
    /// 3. the configured Git credential helper;
    /// 4. default (Negotiate/NTLM) credentials.
    ///
    /// Each is used only when `allowed` includes its credential type. libgit2
    /// re-invokes the callback after each rejected credential, so the
    /// per-strategy bits prevent endless retries. Once nothing is left to try,
    /// the callback fails with "all credential types exhausted".
    fn make_credentials_callbacks() -> RemoteCallbacks<'static> {
        const TRIED_SSH_AGENT: u8 = 0b0001;
        const TRIED_HELPER: u8 = 0b0010;
        const TRIED_DEFAULT: u8 = 0b0100;
        const TRIED_USERNAME: u8 = 0b1000;

        let mut callbacks = RemoteCallbacks::new();
        let tried = AtomicU8::new(0);
        callbacks.credentials(move |url, username, allowed| {
            let tried_bits = tried.load(Ordering::Relaxed);
            if allowed.contains(CredentialType::USERNAME) && tried_bits & TRIED_USERNAME == 0 {
                tried.fetch_or(TRIED_USERNAME, Ordering::Relaxed);
                return git2::Cred::username(username.unwrap_or("git"));
            }
            if allowed.contains(CredentialType::SSH_KEY) && tried_bits & TRIED_SSH_AGENT == 0 {
                tried.fetch_or(TRIED_SSH_AGENT, Ordering::Relaxed);
                return git2::Cred::ssh_key_from_agent(username.unwrap_or("git"));
            }
            if allowed.contains(CredentialType::USER_PASS_PLAINTEXT)
                && tried_bits & TRIED_HELPER == 0
            {
                tried.fetch_or(TRIED_HELPER, Ordering::Relaxed);
                if let Ok(cfg) = git2::Config::open_default() {
                    return git2::Cred::credential_helper(&cfg, url, username);
                }
            }
            // Returning a credential of a type libgit2 did not offer is an
            // error on its side, so only fall back to DEFAULT when offered.
            if allowed.contains(CredentialType::DEFAULT) && tried_bits & TRIED_DEFAULT == 0 {
                tried.fetch_or(TRIED_DEFAULT, Ordering::Relaxed);
                return git2::Cred::default();
            }
            Err(git2::Error::from_str("all credential types exhausted"))
        });
        callbacks
    }

    /// Fetch options carrying fresh credential callbacks (fresh retry state).
    fn make_fetch_options() -> FetchOptions<'static> {
        let mut fo = FetchOptions::new();
        fo.remote_callbacks(Self::make_credentials_callbacks());
        fo
    }
}
impl Fetcher for GitFetcher {
fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError> {
let url = &dep.git;
Self::validate_url(url)?;
let git_base = self.cache_root.join("git");
std::fs::create_dir_all(&git_base)?;
let tmp_path = Self::temp_clone_path(&git_base);
let fo = Self::make_fetch_options();
let repo = match RepoBuilder::new().fetch_options(fo).clone(url, &tmp_path) {
Ok(r) => r,
Err(e) => {
let _ = std::fs::remove_dir_all(&tmp_path);
return Err(e.into());
}
};
let (sha, resolved_tag) = match Self::resolve_ref(&repo, dep) {
Ok(s) => s,
Err(e) => {
let _ = std::fs::remove_dir_all(&tmp_path);
return Err(e);
}
};
if let Err(e) = Self::checkout_sha(&repo, &sha) {
let _ = std::fs::remove_dir_all(&tmp_path);
return Err(e);
}
let cache_path = match self.cache_dir(url, &sha) {
Ok(p) => p,
Err(e) => {
let _ = std::fs::remove_dir_all(&tmp_path);
return Err(e);
}
};
if cache_path.exists() {
drop(repo);
let _ = std::fs::remove_dir_all(&tmp_path);
} else {
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent)?;
}
drop(repo); std::fs::rename(&tmp_path, &cache_path)?;
}
let manifest_path = cache_path.join("mlua-pkg.toml");
let manifest = if manifest_path.exists() {
Some(Manifest::from_path(&manifest_path)?)
} else {
None
};
Ok(FetchedPkg {
cache_path,
sha,
manifest,
resolved_tag,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use git2::{Repository, Signature};
    use std::fs;
    use std::path::Path;
    use tempfile::TempDir;

    /// Initializes a non-bare repository in `dir` with a committer identity,
    /// so commits and tags can be created without any global Git config.
    fn init_repo(dir: &Path) -> Repository {
        let repo = Repository::init(dir).unwrap();
        let mut config = repo.config().unwrap();
        config.set_str("user.name", "Test").unwrap();
        config.set_str("user.email", "test@example.com").unwrap();
        drop(config);
        repo
    }

    /// Writes `contents` to `name` in the worktree, stages it, and commits on
    /// top of `HEAD` (as a root commit while the repository is still empty).
    fn commit_file(repo: &Repository, name: &str, contents: &str, message: &str) -> git2::Oid {
        let workdir = repo.workdir().expect("test repositories are not bare");
        fs::write(workdir.join(name), contents).unwrap();
        let mut index = repo.index().unwrap();
        index.add_path(Path::new(name)).unwrap();
        index.write().unwrap();
        let tree = repo.find_tree(index.write_tree().unwrap()).unwrap();
        let sig = Signature::now("Test", "test@example.com").unwrap();
        let parent = repo.head().ok().and_then(|head| head.peel_to_commit().ok());
        let parents: Vec<&git2::Commit<'_>> = parent.iter().collect();
        repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &parents)
            .unwrap()
    }

    /// Creates a repository in `dir` holding a single README commit and
    /// returns that commit's SHA.
    fn init_repo_with_commit(dir: &Path) -> String {
        let repo = init_repo(dir);
        commit_file(&repo, "README.md", "# test\n", "initial commit").to_string()
    }

    /// Creates an annotated tag named `tag_name` at `HEAD` and returns the SHA
    /// of the tagged commit.
    fn add_tag(repo: &Repository, tag_name: &str) -> String {
        let head = repo.head().unwrap().peel_to_commit().unwrap();
        let sig = Signature::now("Test", "test@example.com").unwrap();
        repo.tag(tag_name, head.as_object(), &sig, tag_name, false)
            .unwrap();
        head.id().to_string()
    }

    /// `file://` URL for a local repository directory.
    fn file_url(dir: &Path) -> String {
        format!("file://{}", dir.display())
    }

    /// A dependency on `git` with no `tag`/`rev`/`branch` pin.
    fn dep_for(git: String) -> Dep {
        Dep {
            git,
            tag: None,
            rev: None,
            branch: None,
            entry: None,
            target_dir: None,
        }
    }

    /// A fetcher over a fresh cache root. The returned `TempDir` must stay
    /// bound for the whole test, or the cache root is deleted early.
    fn new_fetcher() -> (TempDir, GitFetcher) {
        let cache_root = TempDir::new().unwrap();
        let fetcher = GitFetcher::new(cache_root.path().to_path_buf());
        (cache_root, fetcher)
    }

    #[test]
    fn clone_local_repo_happy_path() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache, fetcher) = new_fetcher();

        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        assert_eq!(result.sha, sha, "SHA should match the initial commit");
        assert!(result.cache_path.exists(), "cache_path must exist on disk");
        assert!(
            result.manifest.is_none(),
            "no mlua-pkg.toml in the test repo"
        );
    }

    #[test]
    fn resolve_tag_sha() {
        let src = TempDir::new().unwrap();
        init_repo_with_commit(src.path());
        let expected_sha = {
            let repo = Repository::open(src.path()).unwrap();
            add_tag(&repo, "v0.1.0")
        };
        let (_cache, fetcher) = new_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..dep_for(file_url(src.path()))
        };

        let result = fetcher.fetch(&dep).unwrap();

        assert_eq!(result.sha, expected_sha, "tag must resolve to expected SHA");
        assert!(result.cache_path.exists());
    }

    #[test]
    fn resolve_rev_sha() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache, fetcher) = new_fetcher();
        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };

        let result = fetcher.fetch(&dep).unwrap();

        assert_eq!(result.sha, sha, "rev should resolve to the given SHA");
    }

    #[test]
    fn nonexistent_repo_returns_error() {
        let (_cache, fetcher) = new_fetcher();
        let dep = dep_for("file:///nonexistent/path/that/does/not/exist".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();

        assert!(
            matches!(err, PkgError::GitFetch { .. }),
            "expected GitFetch error, got: {err}"
        );
    }

    #[test]
    fn second_fetch_uses_cache() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache, fetcher) = new_fetcher();
        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };

        let first = fetcher.fetch(&dep).unwrap();
        let second = fetcher.fetch(&dep).unwrap();

        assert_eq!(
            first.cache_path, second.cache_path,
            "cache paths must be identical"
        );
        assert_eq!(first.sha, second.sha);
    }

    #[test]
    fn path_traversal_in_url_is_rejected() {
        let (_cache, fetcher) = new_fetcher();
        let dep = dep_for("https://github.com/../../../etc/passwd".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();

        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for path traversal, got: {err}"
        );
    }

    #[test]
    fn manifest_parsed_when_present() {
        let src = TempDir::new().unwrap();
        let repo = init_repo(src.path());
        commit_file(
            &repo,
            "mlua-pkg.toml",
            "[package]\nname = \"test-lib\"\nversion = \"0.1.0\"\n",
            "add manifest",
        );
        let (_cache, fetcher) = new_fetcher();

        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        let manifest = result.manifest.expect("manifest should be parsed");
        assert_eq!(manifest.package.name, "test-lib");
        assert_eq!(manifest.package.version, "0.1.0");
    }

    #[test]
    fn fetched_worktree_matches_resolved_tag_not_head() {
        let src = TempDir::new().unwrap();
        let repo = init_repo(src.path());
        let c1 = commit_file(&repo, "VERSION", "0.1.0", "v0.1.0");
        let v010_sha = add_tag(&repo, "v0.1.0");
        assert_eq!(v010_sha, c1.to_string(), "tag must point at the first commit");
        let c2 = commit_file(&repo, "VERSION", "0.2.0", "v0.2.0");
        assert_ne!(c1, c2, "HEAD must have advanced past the tag");
        let (_cache, fetcher) = new_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..dep_for(file_url(src.path()))
        };

        let fetched = fetcher.fetch(&dep).unwrap();

        assert_eq!(fetched.sha, v010_sha, "SHA must resolve to tag commit");
        let version = fs::read_to_string(fetched.cache_path.join("VERSION")).unwrap();
        assert_eq!(
            version, "0.1.0",
            "fetched worktree must contain tag v0.1.0 content, got HEAD content instead"
        );
    }

    #[test]
    fn cache_dir_rejects_invalid_sha() {
        let (_cache, fetcher) = new_fetcher();

        let err = fetcher
            .cache_dir("https://github.com/x/y", "../evil")
            .unwrap_err();

        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for invalid SHA, got: {err}"
        );
    }

    #[test]
    fn cache_dir_normalizes_equivalent_urls() {
        let (cache, fetcher) = new_fetcher();
        let sha = "0123456789abcdef0123456789abcdef01234567";

        let https = fetcher
            .cache_dir("https://github.com/x/y.git", sha)
            .unwrap();
        let scp = fetcher.cache_dir("git@github.com:x/y", sha).unwrap();

        assert_eq!(https, scp, "https and scp-style URLs must share a cache slot");
        assert_eq!(
            https,
            cache
                .path()
                .join("git")
                .join("github.com")
                .join("x")
                .join("y")
                .join(sha)
        );
    }

    #[test]
    fn cache_dir_rejects_url_without_path() {
        let (_cache, fetcher) = new_fetcher();
        let sha = "0123456789abcdef0123456789abcdef01234567";

        for url in ["", "https://"] {
            let err = fetcher.cache_dir(url, sha).unwrap_err();
            assert!(
                matches!(err, PkgError::Validation { .. }),
                "expected Validation error for {url:?}, got: {err}"
            );
        }
    }

    #[test]
    fn list_tags_reports_each_tag_once() {
        let src = TempDir::new().unwrap();
        init_repo_with_commit(src.path());
        {
            let repo = Repository::open(src.path()).unwrap();
            add_tag(&repo, "v0.1.0");
        }
        let (_cache, fetcher) = new_fetcher();

        let tags = fetcher.list_tags(&file_url(src.path())).unwrap();

        // The annotated tag may also be advertised as a peeled `^{}` entry;
        // it must still be reported exactly once.
        assert_eq!(tags, vec!["v0.1.0".to_string()]);
    }
}