1use std::{
34 path::{Component, PathBuf},
35 sync::atomic::{AtomicU64, AtomicU8, Ordering},
36};
37
38use git2::{build::RepoBuilder, CredentialType, FetchOptions, RemoteCallbacks, Repository};
39
40use crate::{
41 manifest::{Dep, Manifest},
42 PkgError,
43};
44
/// Process-wide counter used to give each temporary clone directory a
/// distinct numeric suffix.
static TMP_CTR: AtomicU64 = AtomicU64::new(0);
47
48pub trait Fetcher {
55 fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError>;
61}
62
63#[derive(Debug, Clone)]
67pub struct FetchedPkg {
68 pub cache_path: PathBuf,
70
71 pub sha: String,
73
74 pub manifest: Option<Manifest>,
76}
77
/// [`Fetcher`] implementation that clones Git repositories into an on-disk
/// cache.
#[derive(Debug)]
pub struct GitFetcher {
    /// Root directory of the cache; Git clones live under its `git`
    /// subdirectory.
    cache_root: PathBuf,
}
85
86impl GitFetcher {
87 pub fn new(cache_root: PathBuf) -> Self {
89 Self { cache_root }
90 }
91
92 fn cache_dir(&self, url: &str, sha: &str) -> Result<PathBuf, PkgError> {
101 let stripped = url
103 .trim_start_matches("https://")
104 .trim_start_matches("http://")
105 .trim_start_matches("ssh://")
106 .trim_start_matches("git@")
107 .replace(':', "/") .trim_end_matches(".git")
109 .to_owned();
110
111 if stripped.is_empty() {
112 return Err(PkgError::Validation {
113 message: format!("cannot derive cache path from URL: {url:?}"),
114 });
115 }
116
117 for component in stripped.split('/') {
119 if component == ".." || component == "." {
120 return Err(PkgError::Validation {
121 message: format!(
122 "URL {url:?} contains a path traversal component: {component:?}"
123 ),
124 });
125 }
126 }
127
128 if sha.is_empty() || !sha.chars().all(|c| c.is_ascii_hexdigit()) {
130 return Err(PkgError::Validation {
131 message: format!("invalid SHA: {sha:?}"),
132 });
133 }
134
135 let mut path = self.cache_root.join("git");
136 for segment in stripped.split('/') {
137 if segment.is_empty() {
138 continue;
139 }
140 let p = path.join(segment);
142 for c in p.components() {
143 if c == Component::ParentDir {
144 return Err(PkgError::Validation {
145 message: format!(
146 "URL {url:?} resolves to a path with parent-dir traversal"
147 ),
148 });
149 }
150 }
151 path = p;
152 }
153 path = path.join(sha);
154 Ok(path)
155 }
156
157 fn validate_url(url: &str) -> Result<(), PkgError> {
163 let stripped = url
164 .trim_start_matches("https://")
165 .trim_start_matches("http://")
166 .trim_start_matches("ssh://")
167 .trim_start_matches("git@")
168 .replace(':', "/")
169 .trim_end_matches(".git")
170 .to_owned();
171
172 if stripped.is_empty() {
173 return Err(PkgError::Validation {
174 message: format!("cannot derive cache path from URL: {url:?}"),
175 });
176 }
177
178 for component in stripped.split('/') {
179 if component == ".." || component == "." {
180 return Err(PkgError::Validation {
181 message: format!(
182 "URL {url:?} contains a path traversal component: {component:?}"
183 ),
184 });
185 }
186 }
187 Ok(())
188 }
189
190 fn temp_clone_path(git_base: &std::path::Path) -> PathBuf {
193 let n = TMP_CTR.fetch_add(1, Ordering::Relaxed);
194 let pid = std::process::id();
195 git_base.join(format!(".fetch-{pid}-{n}"))
196 }
197
198 fn resolve_sha(repo: &Repository, dep: &Dep) -> Result<String, PkgError> {
208 let oid = if let Some(rev) = &dep.rev {
209 repo.revparse_single(rev)?.peel_to_commit()?.id()
210 } else if let Some(tag) = &dep.tag {
211 let refname = format!("refs/tags/{tag}");
212 repo.find_reference(&refname)?.peel_to_commit()?.id()
213 } else if let Some(branch) = &dep.branch {
214 let refname = format!("refs/remotes/origin/{branch}");
215 repo.find_reference(&refname)?.peel_to_commit()?.id()
216 } else {
217 repo.head()?.peel_to_commit()?.id()
219 };
220 Ok(oid.to_string())
221 }
222
223 fn checkout_sha(repo: &Repository, sha: &str) -> Result<(), PkgError> {
225 let oid = git2::Oid::from_str(sha).map_err(|e| PkgError::Validation {
226 message: format!("invalid SHA {sha}: {e}"),
227 })?;
228 let obj = repo.find_object(oid, None)?;
229 repo.reset(&obj, git2::ResetType::Hard, None)?;
230 Ok(())
231 }
232
233 fn make_fetch_options() -> FetchOptions<'static> {
235 let mut callbacks = RemoteCallbacks::new();
236
237 let tried = AtomicU8::new(0);
240
241 callbacks.credentials(move |_url, username, allowed| {
242 let tried_bits = tried.load(Ordering::Relaxed);
243
244 if allowed.contains(CredentialType::SSH_KEY) && (tried_bits & 0b001 == 0) {
246 tried.fetch_or(0b001, Ordering::Relaxed);
247 let user = username.unwrap_or("git");
248 return git2::Cred::ssh_key_from_agent(user);
249 }
250
251 if allowed.contains(CredentialType::USER_PASS_PLAINTEXT) && (tried_bits & 0b010 == 0) {
253 tried.fetch_or(0b010, Ordering::Relaxed);
254 if let Ok(cfg) = git2::Config::open_default() {
255 return git2::Cred::credential_helper(&cfg, _url, username);
256 }
257 }
259
260 if tried_bits & 0b100 == 0 {
262 tried.fetch_or(0b100, Ordering::Relaxed);
263 return git2::Cred::default();
264 }
265
266 Err(git2::Error::from_str("all credential types exhausted"))
267 });
268
269 let mut fo = FetchOptions::new();
270 fo.remote_callbacks(callbacks);
271 fo
272 }
273}
274
275impl Fetcher for GitFetcher {
276 fn fetch(&self, dep: &Dep) -> Result<FetchedPkg, PkgError> {
277 let url = &dep.git;
278
279 Self::validate_url(url)?;
281
282 let git_base = self.cache_root.join("git");
284 std::fs::create_dir_all(&git_base)?;
285
286 let tmp_path = Self::temp_clone_path(&git_base);
289
290 let fo = Self::make_fetch_options();
292 let repo = match RepoBuilder::new().fetch_options(fo).clone(url, &tmp_path) {
293 Ok(r) => r,
294 Err(e) => {
295 let _ = std::fs::remove_dir_all(&tmp_path);
297 return Err(e.into());
298 }
299 };
300
301 let sha = match Self::resolve_sha(&repo, dep) {
303 Ok(s) => s,
304 Err(e) => {
305 let _ = std::fs::remove_dir_all(&tmp_path);
306 return Err(e);
307 }
308 };
309
310 if let Err(e) = Self::checkout_sha(&repo, &sha) {
316 let _ = std::fs::remove_dir_all(&tmp_path);
317 return Err(e);
318 }
319
320 let cache_path = match self.cache_dir(url, &sha) {
322 Ok(p) => p,
323 Err(e) => {
324 let _ = std::fs::remove_dir_all(&tmp_path);
325 return Err(e);
326 }
327 };
328
329 if cache_path.exists() {
330 drop(repo);
332 let _ = std::fs::remove_dir_all(&tmp_path);
333 } else {
334 if let Some(parent) = cache_path.parent() {
336 std::fs::create_dir_all(parent)?;
337 }
338 drop(repo); std::fs::rename(&tmp_path, &cache_path)?;
340 }
341
342 let manifest_path = cache_path.join("mlua-pkg.toml");
344 let manifest = if manifest_path.exists() {
345 Some(Manifest::from_path(&manifest_path)?)
346 } else {
347 None
348 };
349
350 Ok(FetchedPkg {
351 cache_path,
352 sha,
353 manifest,
354 })
355 }
356}
357
#[cfg(test)]
mod tests {
    use super::*;
    use git2::{Repository, Signature};
    use std::fs;
    use tempfile::TempDir;

    /// Signature used for every commit and tag made by these tests.
    fn test_sig() -> Signature<'static> {
        Signature::now("Test", "test@example.com").unwrap()
    }

    /// Stores the test user name and e-mail in the repository's config.
    fn set_identity(repo: &Repository) {
        let mut cfg = repo.config().unwrap();
        cfg.set_str("user.name", "Test").unwrap();
        cfg.set_str("user.email", "test@example.com").unwrap();
    }

    /// Stages the file at `rel_path` and commits the resulting tree to
    /// `HEAD` with the given parents, returning the new commit id.
    fn stage_and_commit(
        repo: &Repository,
        sig: &Signature<'_>,
        rel_path: &str,
        message: &str,
        parents: &[&git2::Commit<'_>],
    ) -> git2::Oid {
        let mut index = repo.index().unwrap();
        index.add_path(std::path::Path::new(rel_path)).unwrap();
        index.write().unwrap();

        let tree_id = index.write_tree().unwrap();
        let tree = repo.find_tree(tree_id).unwrap();
        let oid = repo
            .commit(Some("HEAD"), sig, sig, message, &tree, parents)
            .unwrap();
        oid
    }

    /// Initializes a repository in `dir` holding a single commit that adds
    /// `README.md`, and returns that commit's id.
    fn init_repo_with_commit(dir: &std::path::Path) -> String {
        let repo = Repository::init(dir).unwrap();
        set_identity(&repo);

        fs::write(dir.join("README.md"), "# test\n").unwrap();

        let sig = test_sig();
        let oid = stage_and_commit(&repo, &sig, "README.md", "initial commit", &[]);
        oid.to_string()
    }

    /// Creates an annotated tag on the current `HEAD` commit and returns the
    /// id of that commit.
    fn add_tag(repo: &Repository, tag_name: &str) -> String {
        let head = repo.head().unwrap().peel_to_commit().unwrap();
        let sig = test_sig();
        repo.tag(tag_name, head.as_object(), &sig, tag_name, false)
            .unwrap();
        head.id().to_string()
    }

    /// `file://` URL for a local directory.
    fn file_url(dir: &std::path::Path) -> String {
        format!("file://{}", dir.display())
    }

    /// Dependency on `git` with no tag, rev, branch or entry selected.
    fn plain_dep(git: String) -> Dep {
        Dep {
            git,
            tag: None,
            rev: None,
            branch: None,
            entry: None,
        }
    }

    /// Fetcher backed by a fresh temporary cache root. The `TempDir` is
    /// returned so the caller keeps the directory alive.
    fn cached_fetcher() -> (TempDir, GitFetcher) {
        let cache_root = TempDir::new().unwrap();
        let fetcher = GitFetcher::new(cache_root.path().to_path_buf());
        (cache_root, fetcher)
    }

    #[test]
    fn clone_local_repo_happy_path() {
        let src = TempDir::new().unwrap();
        let head_sha = init_repo_with_commit(src.path());

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = plain_dep(file_url(src.path()));

        let fetched = fetcher.fetch(&dep).unwrap();
        assert_eq!(fetched.sha, head_sha, "SHA should match the initial commit");
        assert!(fetched.cache_path.exists(), "cache_path must exist on disk");
        assert!(
            fetched.manifest.is_none(),
            "no mlua-pkg.toml in bare test repo"
        );
    }

    #[test]
    fn resolve_tag_sha() {
        let src = TempDir::new().unwrap();
        init_repo_with_commit(src.path());
        let expected_sha = {
            // Scope the handle so it is closed before fetching.
            let repo = Repository::open(src.path()).unwrap();
            add_tag(&repo, "v0.1.0")
        };

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..plain_dep(file_url(src.path()))
        };

        let fetched = fetcher.fetch(&dep).unwrap();
        assert_eq!(fetched.sha, expected_sha, "tag must resolve to expected SHA");
        assert!(fetched.cache_path.exists());
    }

    #[test]
    fn resolve_rev_sha() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = Dep {
            rev: Some(sha.clone()),
            ..plain_dep(file_url(src.path()))
        };

        let fetched = fetcher.fetch(&dep).unwrap();
        assert_eq!(fetched.sha, sha, "rev should resolve to the given SHA");
    }

    #[test]
    fn nonexistent_repo_returns_error() {
        let (_cache_root, fetcher) = cached_fetcher();
        let dep = plain_dep("file:///nonexistent/path/that/does/not/exist".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();
        assert!(
            matches!(err, PkgError::GitFetch { .. }),
            "expected GitFetch error, got: {err}"
        );
    }

    #[test]
    fn second_fetch_uses_cache() {
        let src = TempDir::new().unwrap();
        let sha = init_repo_with_commit(src.path());

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = Dep {
            rev: Some(sha.clone()),
            ..plain_dep(file_url(src.path()))
        };

        let first = fetcher.fetch(&dep).unwrap();
        let second = fetcher.fetch(&dep).unwrap();

        assert_eq!(
            first.cache_path, second.cache_path,
            "cache paths must be identical"
        );
        assert_eq!(first.sha, second.sha);
    }

    #[test]
    fn path_traversal_in_url_is_rejected() {
        let (_cache_root, fetcher) = cached_fetcher();
        let dep = plain_dep("https://github.com/../../../etc/passwd".to_string());

        let err = fetcher.fetch(&dep).unwrap_err();
        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for path traversal, got: {err}"
        );
    }

    #[test]
    fn manifest_parsed_when_present() {
        let src = TempDir::new().unwrap();

        fs::write(
            src.path().join("mlua-pkg.toml"),
            r#"[package]
name = "test-lib"
version = "0.1.0"
"#,
        )
        .unwrap();

        let repo = Repository::init(src.path()).unwrap();
        set_identity(&repo);
        let sig = test_sig();
        stage_and_commit(&repo, &sig, "mlua-pkg.toml", "add manifest", &[]);

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = plain_dep(file_url(src.path()));

        let fetched = fetcher.fetch(&dep).unwrap();
        let manifest = fetched.manifest.expect("manifest should be parsed");
        assert_eq!(manifest.package.name, "test-lib");
        assert_eq!(manifest.package.version, "0.1.0");
    }

    #[test]
    fn fetched_worktree_matches_resolved_tag_not_head() {
        let src = TempDir::new().unwrap();
        let repo = Repository::init(src.path()).unwrap();
        set_identity(&repo);
        let sig = test_sig();
        let version_file = src.path().join("VERSION");

        // First commit, tagged v0.1.0.
        fs::write(&version_file, "0.1.0").unwrap();
        let c1 = stage_and_commit(&repo, &sig, "VERSION", "v0.1.0", &[]);
        let c1_obj = repo.find_object(c1, None).unwrap();
        repo.tag("v0.1.0", &c1_obj, &sig, "v0.1.0", false).unwrap();
        let v010_sha = c1.to_string();

        // Second commit moves HEAD past the tag.
        fs::write(&version_file, "0.2.0").unwrap();
        let parent = repo.find_commit(c1).unwrap();
        let c2 = stage_and_commit(&repo, &sig, "VERSION", "v0.2.0", &[&parent]);
        assert_ne!(c1, c2, "HEAD must have advanced past the tag");

        let (_cache_root, fetcher) = cached_fetcher();
        let dep = Dep {
            tag: Some("v0.1.0".to_string()),
            ..plain_dep(file_url(src.path()))
        };
        let fetched = fetcher.fetch(&dep).unwrap();

        assert_eq!(fetched.sha, v010_sha, "SHA must resolve to tag commit");

        let version = fs::read_to_string(fetched.cache_path.join("VERSION")).unwrap();
        assert_eq!(
            version, "0.1.0",
            "fetched worktree must contain tag v0.1.0 content, got HEAD content instead"
        );
    }

    #[test]
    fn cache_dir_rejects_invalid_sha() {
        let (_cache_root, fetcher) = cached_fetcher();

        let err = fetcher
            .cache_dir("https://github.com/x/y", "../evil")
            .unwrap_err();
        assert!(
            matches!(err, PkgError::Validation { .. }),
            "expected Validation error for invalid SHA, got: {err}"
        );
    }
}