Skip to main content

omni_dev/git/
repository.rs

1//! Git repository operations.
2
3use std::io::BufReader;
4use std::path::PathBuf;
5
6use anyhow::{Context, Result};
7use git2::{Repository, Status};
8use ssh2_config::{ParseRule, SshConfig};
9use tracing::{debug, error, info};
10
11use crate::git::CommitInfo;
12
/// Maximum number of credential callback invocations (per `RemoteCallbacks`
/// instance) before the callback gives up.
///
/// Guards against infinite loops when the offered credentials keep being rejected.
const MAX_AUTH_ATTEMPTS: u32 = 3;
15
/// Git repository wrapper.
///
/// Thin wrapper around a [`git2::Repository`] that adds higher-level helpers for
/// working-tree status, branch queries, commit ranges, and pushing / querying remotes.
pub struct GitRepository {
    /// The underlying libgit2 repository handle.
    repo: Repository,
}
20
/// Working directory status.
///
/// Produced by [`GitRepository::get_working_directory_status`].
#[derive(Debug)]
pub struct WorkingDirectoryStatus {
    /// Whether the working directory has no changes (no entries in `untracked_changes`).
    pub clean: bool,
    /// Every non-ignored file with changes: staged, unstaged, or untracked.
    ///
    /// Despite the field name, this is not limited to untracked files.
    pub untracked_changes: Vec<FileStatus>,
}
29
/// Status information for a single changed file.
#[derive(Debug)]
pub struct FileStatus {
    /// Two-character status code in the style of `git status --short`
    /// (e.g. "AM", " ?", "M "): the first character is the index (staged) state,
    /// the second the working-tree state.
    pub status: String,
    /// Path to the file relative to repository root.
    pub file: String,
}
38
39impl GitRepository {
40    /// Opens a repository at the current directory.
41    pub fn open() -> Result<Self> {
42        let repo = Repository::open(".").context("Not in a git repository")?;
43
44        Ok(Self { repo })
45    }
46
47    /// Opens a repository at the specified path.
48    pub fn open_at<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
49        let repo = Repository::open(path).context("Failed to open git repository")?;
50
51        Ok(Self { repo })
52    }
53
54    /// Returns the working directory status.
55    pub fn get_working_directory_status(&self) -> Result<WorkingDirectoryStatus> {
56        let statuses = self
57            .repo
58            .statuses(None)
59            .context("Failed to get repository status")?;
60
61        let mut untracked_changes = Vec::new();
62
63        for entry in statuses.iter() {
64            if let Some(path) = entry.path() {
65                let status_flags = entry.status();
66
67                // Skip ignored files - they should not affect clean status
68                if status_flags.contains(Status::IGNORED) {
69                    continue;
70                }
71
72                let status_str = format_status_flags(status_flags);
73
74                untracked_changes.push(FileStatus {
75                    status: status_str,
76                    file: path.to_string(),
77                });
78            }
79        }
80
81        let clean = untracked_changes.is_empty();
82
83        Ok(WorkingDirectoryStatus {
84            clean,
85            untracked_changes,
86        })
87    }
88
89    /// Checks if the working directory is clean.
90    pub fn is_working_directory_clean(&self) -> Result<bool> {
91        let status = self.get_working_directory_status()?;
92        Ok(status.clean)
93    }
94
95    /// Returns the repository path.
96    pub fn path(&self) -> &std::path::Path {
97        self.repo.path()
98    }
99
100    /// Returns the workdir path.
101    pub fn workdir(&self) -> Option<&std::path::Path> {
102        self.repo.workdir()
103    }
104
105    /// Returns access to the underlying `git2::Repository`.
106    pub fn repository(&self) -> &Repository {
107        &self.repo
108    }
109
110    /// Returns the current branch name.
111    pub fn get_current_branch(&self) -> Result<String> {
112        let head = self.repo.head().context("Failed to get HEAD reference")?;
113
114        if let Some(name) = head.shorthand() {
115            if name != "HEAD" {
116                return Ok(name.to_string());
117            }
118        }
119
120        anyhow::bail!("Repository is in detached HEAD state")
121    }
122
123    /// Checks if a branch exists.
124    pub fn branch_exists(&self, branch_name: &str) -> Result<bool> {
125        // Check if it exists as a local branch
126        if self
127            .repo
128            .find_branch(branch_name, git2::BranchType::Local)
129            .is_ok()
130        {
131            return Ok(true);
132        }
133
134        // Check if it exists as a remote branch
135        if self
136            .repo
137            .find_branch(branch_name, git2::BranchType::Remote)
138            .is_ok()
139        {
140            return Ok(true);
141        }
142
143        // Check if we can resolve it as a reference
144        if self.repo.revparse_single(branch_name).is_ok() {
145            return Ok(true);
146        }
147
148        Ok(false)
149    }
150
151    /// Parses a commit range and returns the commits.
152    pub fn get_commits_in_range(&self, range: &str) -> Result<Vec<CommitInfo>> {
153        let mut commits = Vec::new();
154
155        if range == "HEAD" {
156            // Single HEAD commit
157            let head = self.repo.head().context("Failed to get HEAD")?;
158            let commit = head
159                .peel_to_commit()
160                .context("Failed to peel HEAD to commit")?;
161            commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
162        } else if range.contains("..") {
163            // Range format like HEAD~3..HEAD
164            let parts: Vec<&str> = range.split("..").collect();
165            if parts.len() != 2 {
166                anyhow::bail!("Invalid range format: {range}");
167            }
168
169            let start_spec = parts[0];
170            let end_spec = parts[1];
171
172            // Parse start and end commits
173            let start_obj = self
174                .repo
175                .revparse_single(start_spec)
176                .with_context(|| format!("Failed to parse start commit: {start_spec}"))?;
177            let end_obj = self
178                .repo
179                .revparse_single(end_spec)
180                .with_context(|| format!("Failed to parse end commit: {end_spec}"))?;
181
182            let start_commit = start_obj
183                .peel_to_commit()
184                .context("Failed to peel start object to commit")?;
185            let end_commit = end_obj
186                .peel_to_commit()
187                .context("Failed to peel end object to commit")?;
188
189            // Walk from end_commit back to start_commit (exclusive)
190            let mut walker = self.repo.revwalk().context("Failed to create revwalk")?;
191            walker
192                .push(end_commit.id())
193                .context("Failed to push end commit")?;
194            walker
195                .hide(start_commit.id())
196                .context("Failed to hide start commit")?;
197
198            for oid in walker {
199                let oid = oid.context("Failed to get commit OID from walker")?;
200                let commit = self
201                    .repo
202                    .find_commit(oid)
203                    .context("Failed to find commit")?;
204
205                // Skip merge commits
206                if commit.parent_count() > 1 {
207                    continue;
208                }
209
210                commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
211            }
212
213            // Reverse to get chronological order (oldest first)
214            commits.reverse();
215        } else {
216            // Single commit by hash or reference
217            let obj = self
218                .repo
219                .revparse_single(range)
220                .with_context(|| format!("Failed to parse commit: {range}"))?;
221            let commit = obj
222                .peel_to_commit()
223                .context("Failed to peel object to commit")?;
224            commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
225        }
226
227        Ok(commits)
228    }
229}
230
231/// Formats git status flags into a string representation.
232fn format_status_flags(flags: Status) -> String {
233    let mut status = String::new();
234
235    if flags.contains(Status::INDEX_NEW) {
236        status.push('A');
237    } else if flags.contains(Status::INDEX_MODIFIED) {
238        status.push('M');
239    } else if flags.contains(Status::INDEX_DELETED) {
240        status.push('D');
241    } else if flags.contains(Status::INDEX_RENAMED) {
242        status.push('R');
243    } else if flags.contains(Status::INDEX_TYPECHANGE) {
244        status.push('T');
245    } else {
246        status.push(' ');
247    }
248
249    if flags.contains(Status::WT_NEW) {
250        status.push('?');
251    } else if flags.contains(Status::WT_MODIFIED) {
252        status.push('M');
253    } else if flags.contains(Status::WT_DELETED) {
254        status.push('D');
255    } else if flags.contains(Status::WT_TYPECHANGE) {
256        status.push('T');
257    } else if flags.contains(Status::WT_RENAMED) {
258        status.push('R');
259    } else {
260        status.push(' ');
261    }
262
263    status
264}
265
/// Extracts the hostname from a git remote URL, for SSH config lookups.
///
/// Supported forms:
/// - scp-like SSH: `git@github.com:user/repo.git`, `alice@host:repo.git` → the host;
/// - URLs with an `https://`, `http://`, `ssh://`, `git+ssh://` or `ssh+git://`
///   scheme, e.g. `ssh://git@host:2222/repo.git` or
///   `https://user:token@host/repo.git` → `host`.
///
/// User info (`user@`, `user:token@`) and any `:port` suffix are stripped so the
/// result can be matched against `Host` entries in `~/.ssh/config`. Returns `None`
/// for other schemes (e.g. `ftp://`), local paths, and URLs with an empty host.
fn extract_hostname_from_git_url(url: &str) -> Option<String> {
    const URL_SCHEMES: [&str; 5] = ["https://", "http://", "ssh://", "git+ssh://", "ssh+git://"];

    let authority = if url.contains("://") {
        // URL form: scheme://[userinfo@]host[:port]/path
        let rest = URL_SCHEMES
            .iter()
            .find_map(|scheme| url.strip_prefix(*scheme))?;
        rest.split_once('/').map_or(rest, |(authority, _)| authority)
    } else {
        // scp-like form: user@host:path. As in git, this only applies when no '/'
        // precedes the first ':'; otherwise it is a local path.
        let (user_and_host, _path) = url.split_once(':')?;
        if user_and_host.contains('/') || !user_and_host.contains('@') {
            return None;
        }
        user_and_host
    };

    // The last '@' separates any userinfo from the host.
    let host_and_port = authority
        .rsplit_once('@')
        .map_or(authority, |(_, host)| host);
    // Strip a ":port" suffix; bracketed IPv6 literals keep their inner colons.
    let host = match host_and_port.strip_prefix('[') {
        Some(bracketed) => bracketed.split_once(']').map_or(bracketed, |(ip, _)| ip),
        None => host_and_port
            .split_once(':')
            .map_or(host_and_port, |(host, _)| host),
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
281
282/// Returns the SSH identity file for a given host from SSH config.
283fn get_ssh_identity_for_host(hostname: &str) -> Option<PathBuf> {
284    let home = std::env::var("HOME").ok()?;
285    let ssh_config_path = PathBuf::from(&home).join(".ssh/config");
286
287    if !ssh_config_path.exists() {
288        debug!("SSH config file not found at: {:?}", ssh_config_path);
289        return None;
290    }
291
292    // Open and parse the SSH config file
293    let file = std::fs::File::open(&ssh_config_path).ok()?;
294    let mut reader = BufReader::new(file);
295
296    let config = SshConfig::default()
297        .parse(&mut reader, ParseRule::ALLOW_UNKNOWN_FIELDS)
298        .ok()?;
299
300    // Query the config for the specific host
301    let params = config.query(hostname);
302
303    // Get the identity file from the config
304    if let Some(identity_files) = &params.identity_file {
305        if let Some(first_identity) = identity_files.first() {
306            // Expand ~ to home directory
307            let identity_str = first_identity.to_string_lossy();
308            let identity_path = identity_str.replace('~', &home);
309            let path = PathBuf::from(identity_path);
310
311            if path.exists() {
312                debug!("Found SSH key for host '{}': {:?}", hostname, path);
313                return Some(path);
314            }
315            debug!("SSH key specified in config but not found: {:?}", path);
316        }
317    }
318
319    None
320}
321
322/// Creates `RemoteCallbacks` with SSH credential resolution for the given hostname.
323///
324/// Tries credentials in order: SSH config identity → SSH agent → default key
325/// locations (`~/.ssh/id_ed25519`, `~/.ssh/id_rsa`). Bails after
326/// [`MAX_AUTH_ATTEMPTS`] to prevent infinite callback loops.
327fn make_auth_callbacks(hostname: String) -> git2::RemoteCallbacks<'static> {
328    let mut callbacks = git2::RemoteCallbacks::new();
329    let mut auth_attempts: u32 = 0;
330
331    callbacks.credentials(move |url, username_from_url, allowed_types| {
332        auth_attempts += 1;
333        debug!(
334            "Credential callback attempt {} - URL: {}, Username: {:?}, Allowed types: {:?}",
335            auth_attempts, url, username_from_url, allowed_types
336        );
337
338        if auth_attempts > MAX_AUTH_ATTEMPTS {
339            error!(
340                "Too many authentication attempts ({}), giving up",
341                auth_attempts
342            );
343            return Err(git2::Error::from_str(
344                "Authentication failed after multiple attempts",
345            ));
346        }
347
348        let username = username_from_url.unwrap_or("git");
349
350        if allowed_types.contains(git2::CredentialType::SSH_KEY) {
351            // Try SSH config identity first — avoids agent returning OK with no valid keys
352            if let Some(ssh_key_path) = get_ssh_identity_for_host(&hostname) {
353                let pub_key_path = ssh_key_path.with_extension("pub");
354                debug!("Trying SSH key from config: {:?}", ssh_key_path);
355
356                match git2::Cred::ssh_key(username, Some(&pub_key_path), &ssh_key_path, None) {
357                    Ok(cred) => {
358                        debug!(
359                            "Successfully loaded SSH key from config: {:?}",
360                            ssh_key_path
361                        );
362                        return Ok(cred);
363                    }
364                    Err(e) => {
365                        debug!("Failed to load SSH key from config: {}", e);
366                    }
367                }
368            }
369
370            // Only try SSH agent on first attempt
371            if auth_attempts == 1 {
372                match git2::Cred::ssh_key_from_agent(username) {
373                    Ok(cred) => {
374                        debug!("SSH agent credentials obtained (attempt {})", auth_attempts);
375                        return Ok(cred);
376                    }
377                    Err(e) => {
378                        debug!("SSH agent failed: {}, trying default keys", e);
379                    }
380                }
381            }
382
383            // Try default SSH key locations as fallback
384            let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
385            let ssh_keys = [
386                format!("{home}/.ssh/id_ed25519"),
387                format!("{home}/.ssh/id_rsa"),
388            ];
389
390            for key_path in &ssh_keys {
391                let key_path = PathBuf::from(key_path);
392                if key_path.exists() {
393                    let pub_key_path = key_path.with_extension("pub");
394                    debug!("Trying default SSH key: {:?}", key_path);
395
396                    match git2::Cred::ssh_key(username, Some(&pub_key_path), &key_path, None) {
397                        Ok(cred) => {
398                            debug!("Successfully loaded SSH key from {:?}", key_path);
399                            return Ok(cred);
400                        }
401                        Err(e) => debug!("Failed to load SSH key from {:?}: {}", key_path, e),
402                    }
403                }
404            }
405        }
406
407        debug!("Falling back to default credentials");
408        git2::Cred::default()
409    });
410
411    callbacks
412}
413
414/// Formats a user-friendly SSH authentication error message with troubleshooting steps.
415fn format_auth_error(operation: &str, error: &git2::Error) -> String {
416    if error.message().contains("authentication") || error.message().contains("SSH") {
417        format!(
418            "Failed to {operation}: {error}. \n\nTroubleshooting steps:\n\
419            1. Check if your SSH key is loaded: ssh-add -l\n\
420            2. Test GitHub SSH connection: ssh -T git@github.com\n\
421            3. Use GitHub CLI auth instead: gh auth setup-git",
422        )
423    } else {
424        format!("Failed to {operation}: {error}")
425    }
426}
427
428impl GitRepository {
429    /// Pushes the current branch to remote.
430    pub fn push_branch(&self, branch_name: &str, remote_name: &str) -> Result<()> {
431        info!(
432            "Pushing branch '{}' to remote '{}'",
433            branch_name, remote_name
434        );
435
436        // Get remote
437        debug!("Finding remote '{}'", remote_name);
438        let mut remote = self
439            .repo
440            .find_remote(remote_name)
441            .context("Failed to find remote")?;
442
443        let remote_url = remote.url().unwrap_or("<unknown>");
444        debug!("Remote URL: {}", remote_url);
445
446        // Set up refspec for push
447        let refspec = format!("refs/heads/{branch_name}:refs/heads/{branch_name}");
448        debug!("Using refspec: {}", refspec);
449
450        // Extract hostname from remote URL for SSH config lookup
451        let hostname =
452            extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
453        debug!(
454            "Extracted hostname '{}' from URL '{}'",
455            hostname, remote_url
456        );
457
458        // Push with authentication callbacks
459        let mut push_options = git2::PushOptions::new();
460        let callbacks = make_auth_callbacks(hostname);
461        push_options.remote_callbacks(callbacks);
462
463        // Perform the push
464        debug!("Attempting to push to remote...");
465        match remote.push(&[&refspec], Some(&mut push_options)) {
466            Ok(()) => {
467                info!(
468                    "Successfully pushed branch '{}' to remote '{}'",
469                    branch_name, remote_name
470                );
471
472                // Set upstream branch after successful push
473                debug!("Setting upstream branch for '{}'", branch_name);
474                match self.repo.find_branch(branch_name, git2::BranchType::Local) {
475                    Ok(mut branch) => {
476                        let remote_ref = format!("{remote_name}/{branch_name}");
477                        match branch.set_upstream(Some(&remote_ref)) {
478                            Ok(()) => {
479                                info!(
480                                    "Successfully set upstream to '{}'/{}",
481                                    remote_name, branch_name
482                                );
483                            }
484                            Err(e) => {
485                                // Log but don't fail - the push succeeded
486                                error!("Failed to set upstream branch: {}", e);
487                            }
488                        }
489                    }
490                    Err(e) => {
491                        // Log but don't fail - the push succeeded
492                        error!("Failed to find local branch to set upstream: {}", e);
493                    }
494                }
495
496                Ok(())
497            }
498            Err(e) => {
499                error!("Failed to push branch: {}", e);
500                Err(anyhow::anyhow!(format_auth_error(
501                    "push branch to remote",
502                    &e
503                )))
504            }
505        }
506    }
507
508    /// Checks if a branch exists on remote.
509    pub fn branch_exists_on_remote(&self, branch_name: &str, remote_name: &str) -> Result<bool> {
510        debug!(
511            "Checking if branch '{}' exists on remote '{}'",
512            branch_name, remote_name
513        );
514
515        let remote = self
516            .repo
517            .find_remote(remote_name)
518            .context("Failed to find remote")?;
519
520        let remote_url = remote.url().unwrap_or("<unknown>");
521        debug!("Remote URL: {}", remote_url);
522
523        // Extract hostname from remote URL for SSH config lookup
524        let hostname =
525            extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
526        debug!(
527            "Extracted hostname '{}' from URL '{}'",
528            hostname, remote_url
529        );
530
531        // Connect to remote to get refs
532        let mut remote = remote;
533        let callbacks = make_auth_callbacks(hostname);
534
535        debug!("Attempting to connect to remote...");
536        match remote.connect_auth(git2::Direction::Fetch, Some(callbacks), None) {
537            Ok(_) => debug!("Successfully connected to remote"),
538            Err(e) => {
539                error!("Failed to connect to remote: {}", e);
540                return Err(anyhow::anyhow!(format_auth_error("connect to remote", &e)));
541            }
542        }
543
544        // Check if the remote branch exists
545        debug!("Listing remote refs...");
546        let refs = remote.list()?;
547        let remote_branch_ref = format!("refs/heads/{branch_name}");
548        debug!("Looking for remote branch ref: {}", remote_branch_ref);
549
550        for remote_head in refs {
551            debug!("Found remote ref: {}", remote_head.name());
552            if remote_head.name() == remote_branch_ref {
553                info!(
554                    "Branch '{}' exists on remote '{}'",
555                    branch_name, remote_name
556                );
557                return Ok(true);
558            }
559        }
560
561        info!(
562            "Branch '{}' does not exist on remote '{}'",
563            branch_name, remote_name
564        );
565        Ok(false)
566    }
567}
568
#[cfg(test)]
mod tests {
    use super::*;

    /// Initialises an empty git repository in a fresh temporary directory under
    /// `./tmp`. The returned guard deletes the directory when dropped, so it must
    /// stay alive for as long as the repository is used.
    fn init_temp_repo() -> Result<tempfile::TempDir> {
        std::fs::create_dir_all("tmp")?;
        let scratch = tempfile::tempdir_in("tmp")?;
        git2::Repository::init(scratch.path())?;
        Ok(scratch)
    }

    // ── extract_hostname_from_git_url ──────────────────────────────

    #[test]
    fn hostname_from_ssh_url() {
        let expected = Some(String::from("github.com"));
        assert_eq!(
            extract_hostname_from_git_url("git@github.com:user/repo.git"),
            expected
        );
    }

    #[test]
    fn hostname_from_https_url() {
        let expected = Some(String::from("github.com"));
        assert_eq!(
            extract_hostname_from_git_url("https://github.com/user/repo.git"),
            expected
        );
    }

    #[test]
    fn hostname_from_http_url() {
        let expected = Some(String::from("gitlab.com"));
        assert_eq!(
            extract_hostname_from_git_url("http://gitlab.com/user/repo.git"),
            expected
        );
    }

    #[test]
    fn hostname_from_unknown_scheme() {
        assert!(extract_hostname_from_git_url("ftp://example.com/repo").is_none());
    }

    #[test]
    fn hostname_from_ssh_custom_host() {
        let expected = Some(String::from("gitlab.example.com"));
        assert_eq!(
            extract_hostname_from_git_url("git@gitlab.example.com:org/project.git"),
            expected
        );
    }

    // ── format_status_flags ────────────────────────────────────────

    #[test]
    fn status_flags_new_index() {
        assert_eq!(format_status_flags(Status::INDEX_NEW), "A ");
    }

    #[test]
    fn status_flags_modified_index() {
        assert_eq!(format_status_flags(Status::INDEX_MODIFIED), "M ");
    }

    #[test]
    fn status_flags_deleted_index() {
        assert_eq!(format_status_flags(Status::INDEX_DELETED), "D ");
    }

    #[test]
    fn status_flags_wt_new() {
        assert_eq!(format_status_flags(Status::WT_NEW), " ?");
    }

    #[test]
    fn status_flags_wt_modified() {
        assert_eq!(format_status_flags(Status::WT_MODIFIED), " M");
    }

    #[test]
    fn status_flags_combined() {
        let staged_and_edited = Status::INDEX_NEW | Status::WT_MODIFIED;
        assert_eq!(format_status_flags(staged_and_edited), "AM");
    }

    #[test]
    fn status_flags_empty() {
        assert_eq!(format_status_flags(Status::empty()), "  ");
    }

    // ── format_auth_error ──────────────────────────────────────────

    #[test]
    fn auth_error_with_ssh_message() {
        let rendered = format_auth_error(
            "push",
            &git2::Error::from_str("SSH authentication failed"),
        );
        for needle in ["Troubleshooting steps", "ssh-add -l"] {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn auth_error_without_auth_message() {
        let rendered = format_auth_error("fetch", &git2::Error::from_str("network timeout"));
        assert!(rendered.contains("Failed to fetch"));
        assert!(!rendered.contains("Troubleshooting"));
    }

    // ── GitRepository with temp repo ───────────────────────────────

    #[test]
    fn open_at_temp_repo() -> Result<()> {
        let scratch = init_temp_repo()?;
        let opened = GitRepository::open_at(scratch.path())?;
        assert!(opened.path().exists());
        Ok(())
    }

    #[test]
    fn working_directory_clean_empty_repo() -> Result<()> {
        let scratch = init_temp_repo()?;
        let opened = GitRepository::open_at(scratch.path())?;
        let snapshot = opened.get_working_directory_status()?;
        assert!(snapshot.clean);
        assert!(snapshot.untracked_changes.is_empty());
        Ok(())
    }

    #[test]
    fn working_directory_dirty_with_file() -> Result<()> {
        let scratch = init_temp_repo()?;
        std::fs::write(scratch.path().join("new_file.txt"), "content")?;
        let opened = GitRepository::open_at(scratch.path())?;
        let snapshot = opened.get_working_directory_status()?;
        assert!(!snapshot.clean);
        assert!(!snapshot.untracked_changes.is_empty());
        Ok(())
    }

    #[test]
    fn is_working_directory_clean_delegator() -> Result<()> {
        let scratch = init_temp_repo()?;
        let opened = GitRepository::open_at(scratch.path())?;
        assert!(opened.is_working_directory_clean()?);
        Ok(())
    }
}