1use std::io::BufReader;
4use std::path::PathBuf;
5
6use anyhow::{Context, Result};
7use git2::{Repository, Status};
8use ssh2_config::{ParseRule, SshConfig};
9use tracing::{debug, error, info};
10
11use crate::git::CommitInfo;
12
/// Maximum number of times the credentials callback built by
/// `make_auth_callbacks` will hand out credentials during one remote
/// operation; the next invocation fails with an authentication error instead
/// of letting libgit2 retry forever.
const MAX_AUTH_ATTEMPTS: u32 = 3;
15
/// Wrapper around a libgit2 [`Repository`] that exposes the working-directory,
/// branch, commit-range and remote (push / remote-branch lookup) operations
/// used by this tool.
pub struct GitRepository {
    /// The underlying libgit2 repository handle.
    repo: Repository,
}
20
/// Snapshot of the working directory produced by
/// `GitRepository::get_working_directory_status`.
#[derive(Debug)]
pub struct WorkingDirectoryStatus {
    /// `true` exactly when `untracked_changes` is empty.
    pub clean: bool,
    /// Every non-ignored status entry. Despite the field name this is not
    /// limited to untracked files: staged and modified tracked paths are
    /// included too.
    pub untracked_changes: Vec<FileStatus>,
}
29
/// One changed path in the working directory.
///
/// Derives `Clone`/`PartialEq`/`Eq` so callers can copy status lists and
/// compare them (e.g. in tests) without hand-written impls.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileStatus {
    /// Two-character status code: the first column describes the index
    /// (staged) state, the second the working-tree state; a space means
    /// "unchanged" in that column.
    pub status: String,
    /// Path of the entry as reported by libgit2's status list
    /// (presumably relative to the working-directory root — confirm).
    pub file: String,
}
38
39impl GitRepository {
40 pub fn open() -> Result<Self> {
42 let repo = Repository::open(".").context("Not in a git repository")?;
43
44 Ok(Self { repo })
45 }
46
47 pub fn open_at<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
49 let repo = Repository::open(path).context("Failed to open git repository")?;
50
51 Ok(Self { repo })
52 }
53
54 pub fn get_working_directory_status(&self) -> Result<WorkingDirectoryStatus> {
56 let statuses = self
57 .repo
58 .statuses(None)
59 .context("Failed to get repository status")?;
60
61 let mut untracked_changes = Vec::new();
62
63 for entry in statuses.iter() {
64 if let Some(path) = entry.path() {
65 let status_flags = entry.status();
66
67 if status_flags.contains(Status::IGNORED) {
69 continue;
70 }
71
72 let status_str = format_status_flags(status_flags);
73
74 untracked_changes.push(FileStatus {
75 status: status_str,
76 file: path.to_string(),
77 });
78 }
79 }
80
81 let clean = untracked_changes.is_empty();
82
83 Ok(WorkingDirectoryStatus {
84 clean,
85 untracked_changes,
86 })
87 }
88
89 pub fn is_working_directory_clean(&self) -> Result<bool> {
91 let status = self.get_working_directory_status()?;
92 Ok(status.clean)
93 }
94
95 pub fn path(&self) -> &std::path::Path {
97 self.repo.path()
98 }
99
100 pub fn workdir(&self) -> Option<&std::path::Path> {
102 self.repo.workdir()
103 }
104
105 pub fn repository(&self) -> &Repository {
107 &self.repo
108 }
109
110 pub fn get_current_branch(&self) -> Result<String> {
112 let head = self.repo.head().context("Failed to get HEAD reference")?;
113
114 if let Some(name) = head.shorthand() {
115 if name != "HEAD" {
116 return Ok(name.to_string());
117 }
118 }
119
120 anyhow::bail!("Repository is in detached HEAD state")
121 }
122
123 pub fn branch_exists(&self, branch_name: &str) -> Result<bool> {
125 if self
127 .repo
128 .find_branch(branch_name, git2::BranchType::Local)
129 .is_ok()
130 {
131 return Ok(true);
132 }
133
134 if self
136 .repo
137 .find_branch(branch_name, git2::BranchType::Remote)
138 .is_ok()
139 {
140 return Ok(true);
141 }
142
143 if self.repo.revparse_single(branch_name).is_ok() {
145 return Ok(true);
146 }
147
148 Ok(false)
149 }
150
151 pub fn get_commits_in_range(&self, range: &str) -> Result<Vec<CommitInfo>> {
153 let mut commits = Vec::new();
154
155 if range == "HEAD" {
156 let head = self.repo.head().context("Failed to get HEAD")?;
158 let commit = head
159 .peel_to_commit()
160 .context("Failed to peel HEAD to commit")?;
161 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
162 } else if range.contains("..") {
163 let parts: Vec<&str> = range.split("..").collect();
165 if parts.len() != 2 {
166 anyhow::bail!("Invalid range format: {range}");
167 }
168
169 let start_spec = parts[0];
170 let end_spec = parts[1];
171
172 let start_obj = self
174 .repo
175 .revparse_single(start_spec)
176 .with_context(|| format!("Failed to parse start commit: {start_spec}"))?;
177 let end_obj = self
178 .repo
179 .revparse_single(end_spec)
180 .with_context(|| format!("Failed to parse end commit: {end_spec}"))?;
181
182 let start_commit = start_obj
183 .peel_to_commit()
184 .context("Failed to peel start object to commit")?;
185 let end_commit = end_obj
186 .peel_to_commit()
187 .context("Failed to peel end object to commit")?;
188
189 let mut walker = self.repo.revwalk().context("Failed to create revwalk")?;
191 walker
192 .push(end_commit.id())
193 .context("Failed to push end commit")?;
194 walker
195 .hide(start_commit.id())
196 .context("Failed to hide start commit")?;
197
198 for oid in walker {
199 let oid = oid.context("Failed to get commit OID from walker")?;
200 let commit = self
201 .repo
202 .find_commit(oid)
203 .context("Failed to find commit")?;
204
205 if commit.parent_count() > 1 {
207 continue;
208 }
209
210 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
211 }
212
213 commits.reverse();
215 } else {
216 let obj = self
218 .repo
219 .revparse_single(range)
220 .with_context(|| format!("Failed to parse commit: {range}"))?;
221 let commit = obj
222 .peel_to_commit()
223 .context("Failed to peel object to commit")?;
224 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
225 }
226
227 Ok(commits)
228 }
229}
230
231fn format_status_flags(flags: Status) -> String {
233 let mut status = String::new();
234
235 if flags.contains(Status::INDEX_NEW) {
236 status.push('A');
237 } else if flags.contains(Status::INDEX_MODIFIED) {
238 status.push('M');
239 } else if flags.contains(Status::INDEX_DELETED) {
240 status.push('D');
241 } else if flags.contains(Status::INDEX_RENAMED) {
242 status.push('R');
243 } else if flags.contains(Status::INDEX_TYPECHANGE) {
244 status.push('T');
245 } else {
246 status.push(' ');
247 }
248
249 if flags.contains(Status::WT_NEW) {
250 status.push('?');
251 } else if flags.contains(Status::WT_MODIFIED) {
252 status.push('M');
253 } else if flags.contains(Status::WT_DELETED) {
254 status.push('D');
255 } else if flags.contains(Status::WT_TYPECHANGE) {
256 status.push('T');
257 } else if flags.contains(Status::WT_RENAMED) {
258 status.push('R');
259 } else {
260 status.push(' ');
261 }
262
263 status
264}
265
/// Extracts the host name from a git remote URL (used to pick SSH keys).
///
/// Supported forms:
/// - URL syntax with an `ssh`, `git+ssh`, `ssh+git`, `git`, `http` or `https`
///   scheme, e.g. `ssh://git@host:2222/org/repo.git` or
///   `https://user:token@host/org/repo.git`. User info and a numeric port
///   are stripped.
/// - scp-like syntax `user@host:path`, e.g. `git@github.com:org/repo.git`
///   (any user name, not only `git`).
///
/// Returns `None` for other schemes (e.g. `ftp://`, `file://`), local paths,
/// and URLs whose host part is empty.
fn extract_hostname_from_git_url(url: &str) -> Option<String> {
    /// Drops a trailing `:<digits>` port. A bracketed IPv6 literal such as
    /// `[::1]` is left intact because its last `:` is not followed by digits only.
    fn strip_port(authority: &str) -> &str {
        match authority.rsplit_once(':') {
            Some((host, port)) if port.bytes().all(|b| b.is_ascii_digit()) => host,
            _ => authority,
        }
    }

    let host = match url.split_once("://") {
        Some((scheme, rest)) => {
            if !matches!(scheme, "ssh" | "git+ssh" | "ssh+git" | "git" | "http" | "https") {
                return None;
            }
            let authority = rest.split('/').next().unwrap_or(rest);
            // Credentials may themselves contain '@', so split on the last one.
            let host_and_port = authority.rsplit_once('@').map_or(authority, |(_, host)| host);
            strip_port(host_and_port)
        }
        None => {
            // scp-like `user@host:path`; anything lacking '@' or ':' is a local path.
            let (user, rest) = url.split_once('@')?;
            let (host, _path) = rest.split_once(':')?;
            if user.is_empty() || user.contains('/') || host.contains('/') {
                return None;
            }
            host
        }
    };

    (!host.is_empty()).then(|| host.to_string())
}
281
282fn get_ssh_identity_for_host(hostname: &str) -> Option<PathBuf> {
284 let home = std::env::var("HOME").ok()?;
285 let ssh_config_path = PathBuf::from(&home).join(".ssh/config");
286
287 if !ssh_config_path.exists() {
288 debug!("SSH config file not found at: {:?}", ssh_config_path);
289 return None;
290 }
291
292 let file = std::fs::File::open(&ssh_config_path).ok()?;
294 let mut reader = BufReader::new(file);
295
296 let config = SshConfig::default()
297 .parse(&mut reader, ParseRule::ALLOW_UNKNOWN_FIELDS)
298 .ok()?;
299
300 let params = config.query(hostname);
302
303 if let Some(identity_files) = ¶ms.identity_file {
305 if let Some(first_identity) = identity_files.first() {
306 let identity_str = first_identity.to_string_lossy();
308 let identity_path = identity_str.replace('~', &home);
309 let path = PathBuf::from(identity_path);
310
311 if path.exists() {
312 debug!("Found SSH key for host '{}': {:?}", hostname, path);
313 return Some(path);
314 }
315 debug!("SSH key specified in config but not found: {:?}", path);
316 }
317 }
318
319 None
320}
321
322fn make_auth_callbacks(hostname: String) -> git2::RemoteCallbacks<'static> {
328 let mut callbacks = git2::RemoteCallbacks::new();
329 let mut auth_attempts: u32 = 0;
330
331 callbacks.credentials(move |url, username_from_url, allowed_types| {
332 auth_attempts += 1;
333 debug!(
334 "Credential callback attempt {} - URL: {}, Username: {:?}, Allowed types: {:?}",
335 auth_attempts, url, username_from_url, allowed_types
336 );
337
338 if auth_attempts > MAX_AUTH_ATTEMPTS {
339 error!(
340 "Too many authentication attempts ({}), giving up",
341 auth_attempts
342 );
343 return Err(git2::Error::from_str(
344 "Authentication failed after multiple attempts",
345 ));
346 }
347
348 let username = username_from_url.unwrap_or("git");
349
350 if allowed_types.contains(git2::CredentialType::SSH_KEY) {
351 if let Some(ssh_key_path) = get_ssh_identity_for_host(&hostname) {
353 let pub_key_path = ssh_key_path.with_extension("pub");
354 debug!("Trying SSH key from config: {:?}", ssh_key_path);
355
356 match git2::Cred::ssh_key(username, Some(&pub_key_path), &ssh_key_path, None) {
357 Ok(cred) => {
358 debug!(
359 "Successfully loaded SSH key from config: {:?}",
360 ssh_key_path
361 );
362 return Ok(cred);
363 }
364 Err(e) => {
365 debug!("Failed to load SSH key from config: {}", e);
366 }
367 }
368 }
369
370 if auth_attempts == 1 {
372 match git2::Cred::ssh_key_from_agent(username) {
373 Ok(cred) => {
374 debug!("SSH agent credentials obtained (attempt {})", auth_attempts);
375 return Ok(cred);
376 }
377 Err(e) => {
378 debug!("SSH agent failed: {}, trying default keys", e);
379 }
380 }
381 }
382
383 let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
385 let ssh_keys = [
386 format!("{home}/.ssh/id_ed25519"),
387 format!("{home}/.ssh/id_rsa"),
388 ];
389
390 for key_path in &ssh_keys {
391 let key_path = PathBuf::from(key_path);
392 if key_path.exists() {
393 let pub_key_path = key_path.with_extension("pub");
394 debug!("Trying default SSH key: {:?}", key_path);
395
396 match git2::Cred::ssh_key(username, Some(&pub_key_path), &key_path, None) {
397 Ok(cred) => {
398 debug!("Successfully loaded SSH key from {:?}", key_path);
399 return Ok(cred);
400 }
401 Err(e) => debug!("Failed to load SSH key from {:?}: {}", key_path, e),
402 }
403 }
404 }
405 }
406
407 debug!("Falling back to default credentials");
408 git2::Cred::default()
409 });
410
411 callbacks
412}
413
414fn format_auth_error(operation: &str, error: &git2::Error) -> String {
416 if error.message().contains("authentication") || error.message().contains("SSH") {
417 format!(
418 "Failed to {operation}: {error}. \n\nTroubleshooting steps:\n\
419 1. Check if your SSH key is loaded: ssh-add -l\n\
420 2. Test GitHub SSH connection: ssh -T git@github.com\n\
421 3. Use GitHub CLI auth instead: gh auth setup-git",
422 )
423 } else {
424 format!("Failed to {operation}: {error}")
425 }
426}
427
428impl GitRepository {
429 pub fn push_branch(&self, branch_name: &str, remote_name: &str) -> Result<()> {
431 info!(
432 "Pushing branch '{}' to remote '{}'",
433 branch_name, remote_name
434 );
435
436 debug!("Finding remote '{}'", remote_name);
438 let mut remote = self
439 .repo
440 .find_remote(remote_name)
441 .context("Failed to find remote")?;
442
443 let remote_url = remote.url().unwrap_or("<unknown>");
444 debug!("Remote URL: {}", remote_url);
445
446 let refspec = format!("refs/heads/{branch_name}:refs/heads/{branch_name}");
448 debug!("Using refspec: {}", refspec);
449
450 let hostname =
452 extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
453 debug!(
454 "Extracted hostname '{}' from URL '{}'",
455 hostname, remote_url
456 );
457
458 let mut push_options = git2::PushOptions::new();
460 let callbacks = make_auth_callbacks(hostname);
461 push_options.remote_callbacks(callbacks);
462
463 debug!("Attempting to push to remote...");
465 match remote.push(&[&refspec], Some(&mut push_options)) {
466 Ok(()) => {
467 info!(
468 "Successfully pushed branch '{}' to remote '{}'",
469 branch_name, remote_name
470 );
471
472 debug!("Setting upstream branch for '{}'", branch_name);
474 match self.repo.find_branch(branch_name, git2::BranchType::Local) {
475 Ok(mut branch) => {
476 let remote_ref = format!("{remote_name}/{branch_name}");
477 match branch.set_upstream(Some(&remote_ref)) {
478 Ok(()) => {
479 info!(
480 "Successfully set upstream to '{}'/{}",
481 remote_name, branch_name
482 );
483 }
484 Err(e) => {
485 error!("Failed to set upstream branch: {}", e);
487 }
488 }
489 }
490 Err(e) => {
491 error!("Failed to find local branch to set upstream: {}", e);
493 }
494 }
495
496 Ok(())
497 }
498 Err(e) => {
499 error!("Failed to push branch: {}", e);
500 Err(anyhow::anyhow!(format_auth_error(
501 "push branch to remote",
502 &e
503 )))
504 }
505 }
506 }
507
508 pub fn branch_exists_on_remote(&self, branch_name: &str, remote_name: &str) -> Result<bool> {
510 debug!(
511 "Checking if branch '{}' exists on remote '{}'",
512 branch_name, remote_name
513 );
514
515 let remote = self
516 .repo
517 .find_remote(remote_name)
518 .context("Failed to find remote")?;
519
520 let remote_url = remote.url().unwrap_or("<unknown>");
521 debug!("Remote URL: {}", remote_url);
522
523 let hostname =
525 extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
526 debug!(
527 "Extracted hostname '{}' from URL '{}'",
528 hostname, remote_url
529 );
530
531 let mut remote = remote;
533 let callbacks = make_auth_callbacks(hostname);
534
535 debug!("Attempting to connect to remote...");
536 match remote.connect_auth(git2::Direction::Fetch, Some(callbacks), None) {
537 Ok(_) => debug!("Successfully connected to remote"),
538 Err(e) => {
539 error!("Failed to connect to remote: {}", e);
540 return Err(anyhow::anyhow!(format_auth_error("connect to remote", &e)));
541 }
542 }
543
544 debug!("Listing remote refs...");
546 let refs = remote.list()?;
547 let remote_branch_ref = format!("refs/heads/{branch_name}");
548 debug!("Looking for remote branch ref: {}", remote_branch_ref);
549
550 for remote_head in refs {
551 debug!("Found remote ref: {}", remote_head.name());
552 if remote_head.name() == remote_branch_ref {
553 info!(
554 "Branch '{}' exists on remote '{}'",
555 branch_name, remote_name
556 );
557 return Ok(true);
558 }
559 }
560
561 info!(
562 "Branch '{}' does not exist on remote '{}'",
563 branch_name, remote_name
564 );
565 Ok(false)
566 }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a scratch directory under `./tmp` containing a freshly
    /// initialised git repository. The directory lives as long as the
    /// returned `TempDir` guard.
    fn scratch_repo_dir() -> Result<tempfile::TempDir> {
        std::fs::create_dir_all("tmp")?;
        let dir = tempfile::tempdir_in("tmp")?;
        git2::Repository::init(dir.path())?;
        Ok(dir)
    }

    // --- extract_hostname_from_git_url -------------------------------------

    #[test]
    fn hostname_from_ssh_url() {
        assert_eq!(
            extract_hostname_from_git_url("git@github.com:user/repo.git").as_deref(),
            Some("github.com")
        );
    }

    #[test]
    fn hostname_from_https_url() {
        assert_eq!(
            extract_hostname_from_git_url("https://github.com/user/repo.git").as_deref(),
            Some("github.com")
        );
    }

    #[test]
    fn hostname_from_http_url() {
        assert_eq!(
            extract_hostname_from_git_url("http://gitlab.com/user/repo.git").as_deref(),
            Some("gitlab.com")
        );
    }

    #[test]
    fn hostname_from_unknown_scheme() {
        assert!(extract_hostname_from_git_url("ftp://example.com/repo").is_none());
    }

    #[test]
    fn hostname_from_ssh_custom_host() {
        assert_eq!(
            extract_hostname_from_git_url("git@gitlab.example.com:org/project.git").as_deref(),
            Some("gitlab.example.com")
        );
    }

    // --- format_status_flags -----------------------------------------------

    #[test]
    fn status_flags_new_index() {
        assert_eq!(format_status_flags(Status::INDEX_NEW), "A ");
    }

    #[test]
    fn status_flags_modified_index() {
        assert_eq!(format_status_flags(Status::INDEX_MODIFIED), "M ");
    }

    #[test]
    fn status_flags_deleted_index() {
        assert_eq!(format_status_flags(Status::INDEX_DELETED), "D ");
    }

    #[test]
    fn status_flags_wt_new() {
        assert_eq!(format_status_flags(Status::WT_NEW), " ?");
    }

    #[test]
    fn status_flags_wt_modified() {
        assert_eq!(format_status_flags(Status::WT_MODIFIED), " M");
    }

    #[test]
    fn status_flags_combined() {
        assert_eq!(
            format_status_flags(Status::INDEX_NEW | Status::WT_MODIFIED),
            "AM"
        );
    }

    #[test]
    fn status_flags_empty() {
        assert_eq!(format_status_flags(Status::empty()), "  ");
    }

    // --- format_auth_error -------------------------------------------------

    #[test]
    fn auth_error_with_ssh_message() {
        let rendered = format_auth_error(
            "push",
            &git2::Error::from_str("SSH authentication failed"),
        );
        assert!(rendered.contains("Troubleshooting steps"));
        assert!(rendered.contains("ssh-add -l"));
    }

    #[test]
    fn auth_error_without_auth_message() {
        let rendered = format_auth_error("fetch", &git2::Error::from_str("network timeout"));
        assert!(rendered.contains("Failed to fetch"));
        assert!(!rendered.contains("Troubleshooting"));
    }

    // --- GitRepository on a scratch repository -----------------------------

    #[test]
    fn open_at_temp_repo() -> Result<()> {
        let dir = scratch_repo_dir()?;
        let repo = GitRepository::open_at(dir.path())?;
        assert!(repo.path().exists());
        Ok(())
    }

    #[test]
    fn working_directory_clean_empty_repo() -> Result<()> {
        let dir = scratch_repo_dir()?;
        let status = GitRepository::open_at(dir.path())?.get_working_directory_status()?;
        assert!(status.clean);
        assert!(status.untracked_changes.is_empty());
        Ok(())
    }

    #[test]
    fn working_directory_dirty_with_file() -> Result<()> {
        let dir = scratch_repo_dir()?;
        std::fs::write(dir.path().join("new_file.txt"), "content")?;
        let status = GitRepository::open_at(dir.path())?.get_working_directory_status()?;
        assert!(!status.clean);
        assert!(!status.untracked_changes.is_empty());
        Ok(())
    }

    #[test]
    fn is_working_directory_clean_delegator() -> Result<()> {
        let dir = scratch_repo_dir()?;
        assert!(GitRepository::open_at(dir.path())?.is_working_directory_clean()?);
        Ok(())
    }
}