1use std::{
34 path::{Component, PathBuf},
35 sync::atomic::{AtomicU64, AtomicU8, Ordering},
36};
37
38use git2::{build::RepoBuilder, CredentialType, FetchOptions, RemoteCallbacks, Repository};
39
40use crate::{
41 manifest::{Dep, Manifest},
42 PkgError,
43};
44
/// Process-wide counter bumped once per call to `GitFetcher::temp_clone_path`,
/// so temporary directory names created by this process do not repeat.
static TMP_CTR: AtomicU64 = AtomicU64::new(0);
47
/// Something that can obtain the package a [`Dep`] refers to.
pub trait Fetcher {
    /// Fetches `dep` and describes the result as a [`FetchedPkg`].
    ///
    /// # Errors
    ///
    /// Returns a [`PkgError`] when the dependency cannot be fetched.
    fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError>;
}
62
/// Result of a successful [`Fetcher::fetch`].
#[derive(Debug, Clone)]
pub struct FetchedPkg {
    /// Directory that holds the fetched package. [`GitFetcher`] sets this to
    /// its per-URL, per-commit cache directory.
    pub cache_path: PathBuf,

    /// Identifier of the fetched revision. [`GitFetcher`] stores the resolved
    /// commit id rendered as a string.
    pub sha: String,

    /// Manifest parsed from `mlua-pkg.toml` at the root of `cache_path`;
    /// `None` when [`GitFetcher`] finds no such file.
    pub manifest: Option<Manifest>,

    /// Concrete tag name chosen when the dependency was requested by tag;
    /// [`GitFetcher`] leaves it `None` for rev, branch and default-HEAD fetches.
    pub resolved_tag: Option<String>,
}
83
/// [`Fetcher`] that clones git repositories into an on-disk cache.
pub struct GitFetcher {
    /// Root of the cache; everything this fetcher writes lives under
    /// `<cache_root>/git`.
    cache_root: PathBuf,
}
91
92impl GitFetcher {
93 pub fn new(cache_root: PathBuf) -> Self {
95 Self { cache_root }
96 }
97
98 fn cache_dir(&self, url: &str, sha: &str) -> Result<PathBuf, PkgError> {
107 let stripped = url
109 .trim_start_matches("https://")
110 .trim_start_matches("http://")
111 .trim_start_matches("ssh://")
112 .trim_start_matches("git@")
113 .replace(':', "/") .trim_end_matches(".git")
115 .to_owned();
116
117 if stripped.is_empty() {
118 return Err(PkgError::Validation {
119 message: format!("cannot derive cache path from URL: {url:?}"),
120 });
121 }
122
123 for component in stripped.split('/') {
125 if component == ".." || component == "." {
126 return Err(PkgError::Validation {
127 message: format!(
128 "URL {url:?} contains a path traversal component: {component:?}"
129 ),
130 });
131 }
132 }
133
134 if sha.is_empty() || !sha.chars().all(|c| c.is_ascii_hexdigit()) {
136 return Err(PkgError::Validation {
137 message: format!("invalid SHA: {sha:?}"),
138 });
139 }
140
141 let mut path = self.cache_root.join("git");
142 for segment in stripped.split('/') {
143 if segment.is_empty() {
144 continue;
145 }
146 let p = path.join(segment);
148 for c in p.components() {
149 if c == Component::ParentDir {
150 return Err(PkgError::Validation {
151 message: format!(
152 "URL {url:?} resolves to a path with parent-dir traversal"
153 ),
154 });
155 }
156 }
157 path = p;
158 }
159 path = path.join(sha);
160 Ok(path)
161 }
162
163 fn validate_url(url: &str) -> Result<(), PkgError> {
169 let stripped = url
170 .trim_start_matches("https://")
171 .trim_start_matches("http://")
172 .trim_start_matches("ssh://")
173 .trim_start_matches("git@")
174 .replace(':', "/")
175 .trim_end_matches(".git")
176 .to_owned();
177
178 if stripped.is_empty() {
179 return Err(PkgError::Validation {
180 message: format!("cannot derive cache path from URL: {url:?}"),
181 });
182 }
183
184 for component in stripped.split('/') {
185 if component == ".." || component == "." {
186 return Err(PkgError::Validation {
187 message: format!(
188 "URL {url:?} contains a path traversal component: {component:?}"
189 ),
190 });
191 }
192 }
193 Ok(())
194 }
195
196 fn temp_clone_path(git_base: &std::path::Path) -> PathBuf {
199 let n = TMP_CTR.fetch_add(1, Ordering::Relaxed);
200 let pid = std::process::id();
201 git_base.join(format!(".fetch-{pid}-{n}"))
202 }
203
204 fn resolve_ref(repo: &Repository, dep: &Dep) -> Result<(String, Option<String>), PkgError> {
217 if let Some(rev) = &dep.rev {
218 let oid = repo.revparse_single(rev)?.peel_to_commit()?.id();
219 return Ok((oid.to_string(), None));
220 }
221 if let Some(tag) = &dep.tag {
222 let resolved = Self::resolve_tag_pin(repo, tag)?;
223 let refname = format!("refs/tags/{resolved}");
224 let oid = repo.find_reference(&refname)?.peel_to_commit()?.id();
225 return Ok((oid.to_string(), Some(resolved)));
226 }
227 if let Some(branch) = &dep.branch {
228 let refname = format!("refs/remotes/origin/{branch}");
229 let oid = repo.find_reference(&refname)?.peel_to_commit()?.id();
230 return Ok((oid.to_string(), None));
231 }
232 let oid = repo.head()?.peel_to_commit()?.id();
234 Ok((oid.to_string(), None))
235 }
236
237 fn resolve_tag_pin(repo: &Repository, tag: &str) -> Result<String, PkgError> {
247 use crate::version::{classify_tag_pin, pick_latest_for_pin, TagPin};
248
249 let pin = classify_tag_pin(tag);
250 let prefix = match pin {
251 Some(TagPin::Prefix(p)) => p,
252 _ => return Ok(tag.to_string()),
254 };
255
256 let tag_names = repo.tag_names(None)?;
257 let local_tags: Vec<String> = tag_names
258 .iter()
259 .filter_map(|t| t.map(|s| s.to_string()))
260 .collect();
261 pick_latest_for_pin(&local_tags, &prefix).ok_or_else(|| PkgError::Validation {
262 message: format!("tag prefix '{tag}' has no matching SemVer release on remote"),
263 })
264 }
265
266 fn checkout_sha(repo: &Repository, sha: &str) -> Result<(), PkgError> {
268 let oid = git2::Oid::from_str(sha).map_err(|e| PkgError::Validation {
269 message: format!("invalid SHA {sha}: {e}"),
270 })?;
271 let obj = repo.find_object(oid, None)?;
272 repo.reset(&obj, git2::ResetType::Hard, None)?;
273 Ok(())
274 }
275
276 pub fn list_tags(&self, url: &str) -> Result<Vec<String>, PkgError> {
282 Self::validate_url(url)?;
283
284 let scratch = self.cache_root.join("git").join(".ls-remote");
286 std::fs::create_dir_all(&scratch)?;
287 let tmp = Self::temp_clone_path(&scratch);
288 std::fs::create_dir_all(&tmp)?;
289
290 let result = Self::list_tags_inner(&tmp, url);
291 let _ = std::fs::remove_dir_all(&tmp);
292 result
293 }
294
295 fn list_tags_inner(tmp: &std::path::Path, url: &str) -> Result<Vec<String>, PkgError> {
296 let repo = Repository::init(tmp)?;
297 let mut remote = repo.remote_anonymous(url)?;
298 remote.connect_auth(
299 git2::Direction::Fetch,
300 Some(Self::make_credentials_callbacks()),
301 None,
302 )?;
303 let refs = remote.list()?;
304
305 let mut tags: Vec<String> = Vec::new();
306 for head in refs.iter() {
307 if let Some(name) = head.name().strip_prefix("refs/tags/") {
308 let trimmed = name.trim_end_matches("^{}");
309 let s = trimmed.to_string();
310 if !tags.contains(&s) {
311 tags.push(s);
312 }
313 }
314 }
315
316 let _ = remote.disconnect();
317 Ok(tags)
318 }
319
320 fn make_credentials_callbacks() -> RemoteCallbacks<'static> {
323 let mut callbacks = RemoteCallbacks::new();
324
325 let tried = AtomicU8::new(0);
326 callbacks.credentials(move |_url, username, allowed| {
327 let tried_bits = tried.load(Ordering::Relaxed);
328
329 if allowed.contains(CredentialType::SSH_KEY) && (tried_bits & 0b001 == 0) {
330 tried.fetch_or(0b001, Ordering::Relaxed);
331 let user = username.unwrap_or("git");
332 return git2::Cred::ssh_key_from_agent(user);
333 }
334
335 if allowed.contains(CredentialType::USER_PASS_PLAINTEXT) && (tried_bits & 0b010 == 0) {
336 tried.fetch_or(0b010, Ordering::Relaxed);
337 if let Ok(cfg) = git2::Config::open_default() {
338 return git2::Cred::credential_helper(&cfg, _url, username);
339 }
340 }
341
342 if tried_bits & 0b100 == 0 {
343 tried.fetch_or(0b100, Ordering::Relaxed);
344 return git2::Cred::default();
345 }
346
347 Err(git2::Error::from_str("all credential types exhausted"))
348 });
349 callbacks
350 }
351
352 fn make_fetch_options() -> FetchOptions<'static> {
354 let mut fo = FetchOptions::new();
355 fo.remote_callbacks(Self::make_credentials_callbacks());
356 fo
357 }
358}
359
360impl Fetcher for GitFetcher {
361 fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError> {
362 let url = &dep.git;
363
364 Self::validate_url(url)?;
366
367 let git_base = self.cache_root.join("git");
369 std::fs::create_dir_all(&git_base)?;
370
371 let tmp_path = Self::temp_clone_path(&git_base);
374
375 let fo = Self::make_fetch_options();
377 let repo = match RepoBuilder::new().fetch_options(fo).clone(url, &tmp_path) {
378 Ok(r) => r,
379 Err(e) => {
380 let _ = std::fs::remove_dir_all(&tmp_path);
382 return Err(e.into());
383 }
384 };
385
386 let (sha, resolved_tag) = match Self::resolve_ref(&repo, dep) {
388 Ok(s) => s,
389 Err(e) => {
390 let _ = std::fs::remove_dir_all(&tmp_path);
391 return Err(e);
392 }
393 };
394
395 if let Err(e) = Self::checkout_sha(&repo, &sha) {
401 let _ = std::fs::remove_dir_all(&tmp_path);
402 return Err(e);
403 }
404
405 let cache_path = match self.cache_dir(url, &sha) {
407 Ok(p) => p,
408 Err(e) => {
409 let _ = std::fs::remove_dir_all(&tmp_path);
410 return Err(e);
411 }
412 };
413
414 if cache_path.exists() {
415 drop(repo);
417 let _ = std::fs::remove_dir_all(&tmp_path);
418 } else {
419 if let Some(parent) = cache_path.parent() {
421 std::fs::create_dir_all(parent)?;
422 }
423 drop(repo); std::fs::rename(&tmp_path, &cache_path)?;
425 }
426
427 let manifest_path = cache_path.join("mlua-pkg.toml");
429 let manifest = if manifest_path.exists() {
430 Some(Manifest::from_path(&manifest_path)?)
431 } else {
432 None
433 };
434
435 Ok(FetchedPkg {
436 cache_path,
437 sha,
438 manifest,
439 resolved_tag,
440 })
441 }
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use git2::{Repository, Signature};
    use std::fs;
    use std::path::Path;
    use tempfile::TempDir;

    /// Identity used for every commit and tag created by these tests.
    fn test_signature() -> Signature<'static> {
        Signature::now("Test", "test@example.com").unwrap()
    }

    /// Initialises a repository in `dir` and configures a committer identity.
    fn init_repo(dir: &Path) -> Repository {
        let repo = Repository::init(dir).unwrap();
        {
            let mut config = repo.config().unwrap();
            config.set_str("user.name", "Test").unwrap();
            config.set_str("user.email", "test@example.com").unwrap();
        }
        repo
    }

    /// Writes `contents` to `name` in the work tree, stages it and commits it
    /// on top of the current `HEAD` (or as the root commit when there is no
    /// `HEAD` yet). Returns the new commit id.
    fn commit_file(repo: &Repository, name: &str, contents: &str, message: &str) -> git2::Oid {
        let workdir = repo.workdir().expect("test repositories have a work tree");
        fs::write(workdir.join(name), contents).unwrap();

        let mut index = repo.index().unwrap();
        index.add_path(Path::new(name)).unwrap();
        index.write().unwrap();
        let tree = repo.find_tree(index.write_tree().unwrap()).unwrap();

        let sig = test_signature();
        let parent = repo.head().ok().map(|head| head.peel_to_commit().unwrap());
        let parents: Vec<&git2::Commit<'_>> = parent.iter().collect();
        repo.commit(Some("HEAD"), &sig, &sig, message, &tree, &parents)
            .unwrap()
    }

    /// Creates a repository in `dir` with a single commit adding `README.md`
    /// and returns that commit's id as a string.
    fn init_repo_with_commit(dir: &Path) -> String {
        let repo = init_repo(dir);
        commit_file(&repo, "README.md", "# test\n", "initial commit").to_string()
    }

    /// Creates an annotated tag `tag_name` on the current `HEAD` commit and
    /// returns that commit's id as a string.
    fn add_tag(repo: &Repository, tag_name: &str) -> String {
        let head = repo.head().unwrap().peel_to_commit().unwrap();
        let sig = test_signature();
        repo.tag(tag_name, head.as_object(), &sig, tag_name, false)
            .unwrap();
        head.id().to_string()
    }

    /// `file://` URL pointing at a local repository directory.
    fn file_url(dir: &Path) -> String {
        format!("file://{}", dir.display())
    }

    /// Dependency on `url` with no rev, tag or branch requested.
    fn dep_for(url: String) -> Dep {
        Dep {
            git: url,
            tag: None,
            rev: None,
            branch: None,
            entry: None,
            target_dir: None,
        }
    }

    /// Fetcher backed by a fresh temporary cache root. The `TempDir` is
    /// returned so the cache stays alive for as long as the caller holds it.
    fn new_fetcher() -> (TempDir, GitFetcher) {
        let cache_root = TempDir::new().unwrap();
        let fetcher = GitFetcher::new(cache_root.path().to_path_buf());
        (cache_root, fetcher)
    }

    /// A plain fetch resolves to the repository's only commit.
    #[test]
    fn clone_local_repo_happy_path() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = new_fetcher();

        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        assert_eq!(result.sha, sha, "SHA should match the initial commit");
        assert!(result.cache_path.exists(), "cache_path must exist on disk");
        assert!(
            result.manifest.is_none(),
            "no mlua-pkg.toml in bare test repo"
        );
    }

    /// Requesting a tag resolves to the tagged commit.
    #[test]
    fn resolve_tag_sha() {
        let src = TempDir::new().unwrap();
        init_repo_with_commit(src.path());
        let repo = Repository::open(src.path()).unwrap();
        let expected_sha = add_tag(&repo, "v0.1.0");
        drop(repo);

        let (_cache_root, fetcher) = new_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..dep_for(file_url(src.path()))
        };

        let result = fetcher.fetch(&dep).unwrap();
        assert_eq!(result.sha, expected_sha, "tag must resolve to expected SHA");
        assert!(result.cache_path.exists());
    }

    /// Requesting an explicit revision resolves to that revision.
    #[test]
    fn resolve_rev_sha() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = new_fetcher();

        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };

        let result = fetcher.fetch(&dep).unwrap();
        assert_eq!(result.sha, sha, "rev should resolve to the given SHA");
    }

    /// A clone failure surfaces as `PkgError::GitFetch`.
    #[test]
    fn nonexistent_repo_returns_error() {
        let (_cache_root, fetcher) = new_fetcher();
        let dep = dep_for("file:///nonexistent/path/that/does/not/exist".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();
        assert!(
            matches!(err, PkgError::GitFetch { .. }),
            "expected GitFetch error, got: {err}"
        );
    }

    /// Fetching the same revision twice yields the same cache entry.
    #[test]
    fn second_fetch_uses_cache() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());
        let (_cache_root, fetcher) = new_fetcher();

        let dep = Dep {
            rev: Some(sha.clone()),
            ..dep_for(file_url(src.path()))
        };

        let first = fetcher.fetch(&dep).unwrap();
        let second = fetcher.fetch(&dep).unwrap();

        assert_eq!(
            first.cache_path, second.cache_path,
            "cache paths must be identical"
        );
        assert_eq!(first.sha, second.sha);
    }

    /// URLs containing `..` segments are rejected as validation errors.
    #[test]
    fn path_traversal_in_url_is_rejected() {
        let (_cache_root, fetcher) = new_fetcher();
        let dep = dep_for("https://github.com/../../../etc/passwd".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();
        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for path traversal, got: {err}"
        );
    }

    /// A committed `mlua-pkg.toml` is parsed into the fetched package.
    #[test]
    fn manifest_parsed_when_present() {
        let src = TempDir::new().unwrap();
        let repo = init_repo(src.path());
        commit_file(
            &repo,
            "mlua-pkg.toml",
            "[package]\nname = \"test-lib\"\nversion = \"0.1.0\"\n",
            "add manifest",
        );

        let (_cache_root, fetcher) = new_fetcher();
        let result = fetcher.fetch(&dep_for(file_url(src.path()))).unwrap();

        let manifest = result.manifest.expect("manifest should be parsed");
        assert_eq!(manifest.package.name, "test-lib");
        assert_eq!(manifest.package.version, "0.1.0");
    }

    /// The cached work tree holds the tagged commit's content even when the
    /// source repository's `HEAD` has moved past the tag.
    #[test]
    fn fetched_worktree_matches_resolved_tag_not_head() {
        let src = TempDir::new().unwrap();
        let repo = init_repo(src.path());

        let c1 = commit_file(&repo, "VERSION", "0.1.0", "v0.1.0");
        add_tag(&repo, "v0.1.0");
        let v010_sha = c1.to_string();

        let c2 = commit_file(&repo, "VERSION", "0.2.0", "v0.2.0");
        assert_ne!(c1, c2, "HEAD must have advanced past the tag");

        let (_cache_root, fetcher) = new_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..dep_for(file_url(src.path()))
        };
        let fetched = fetcher.fetch(&dep).unwrap();

        assert_eq!(fetched.sha, v010_sha, "SHA must resolve to tag commit");

        let version = fs::read_to_string(fetched.cache_path.join("VERSION")).unwrap();
        assert_eq!(
            version, "0.1.0",
            "fetched worktree must contain tag v0.1.0 content, got HEAD content instead"
        );
    }

    /// A SHA containing non-hex characters is rejected by `cache_dir`.
    #[test]
    fn cache_dir_rejects_invalid_sha() {
        let (_cache_root, fetcher) = new_fetcher();

        let err = fetcher
            .cache_dir("https://github.com/x/y", "../evil")
            .unwrap_err();
        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for invalid SHA, got: {err}"
        );
    }
}