// omni_dev/git/repository.rs

//! Git repository operations.
2
3use std::io::BufReader;
4use std::path::PathBuf;
5
6use anyhow::{Context, Result};
7use git2::{Repository, Status};
8use ssh2_config::{ParseRule, SshConfig};
9use tracing::{debug, error, info};
10
11use crate::git::CommitInfo;
12
/// Upper bound on how many times a credential callback may be invoked before
/// authentication is abandoned; guards against libgit2 re-invoking the
/// callback forever when every offered credential is rejected.
const MAX_AUTH_ATTEMPTS: u32 = 3;
15
16/// Git repository wrapper.
17pub struct GitRepository {
18    repo: Repository,
19}
20
/// Snapshot of the working directory's status.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WorkingDirectoryStatus {
    /// `true` when no file has uncommitted changes.
    pub clean: bool,
    /// Every file with uncommitted changes. Despite the field name this is not
    /// limited to untracked files: staged and unstaged modifications are listed
    /// as well.
    pub untracked_changes: Vec<FileStatus>,
}

/// Status of a single changed file.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct FileStatus {
    /// Two-character status code: index (staged) state first, working-tree
    /// state second, e.g. `"AM"` or `"M "`.
    pub status: String,
    /// Path to the file relative to the repository root.
    pub file: String,
}
38
39impl GitRepository {
40    /// Opens a repository at the current directory.
41    pub fn open() -> Result<Self> {
42        let repo = Repository::open(".").context("Not in a git repository")?;
43
44        Ok(Self { repo })
45    }
46
47    /// Opens a repository at the specified path.
48    pub fn open_at<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
49        let repo = Repository::open(path).context("Failed to open git repository")?;
50
51        Ok(Self { repo })
52    }
53
54    /// Returns the working directory status.
55    pub fn get_working_directory_status(&self) -> Result<WorkingDirectoryStatus> {
56        let statuses = self
57            .repo
58            .statuses(None)
59            .context("Failed to get repository status")?;
60
61        let mut untracked_changes = Vec::new();
62
63        for entry in statuses.iter() {
64            if let Some(path) = entry.path() {
65                let status_flags = entry.status();
66
67                // Skip ignored files - they should not affect clean status
68                if status_flags.contains(Status::IGNORED) {
69                    continue;
70                }
71
72                let status_str = format_status_flags(status_flags);
73
74                untracked_changes.push(FileStatus {
75                    status: status_str,
76                    file: path.to_string(),
77                });
78            }
79        }
80
81        let clean = untracked_changes.is_empty();
82
83        Ok(WorkingDirectoryStatus {
84            clean,
85            untracked_changes,
86        })
87    }
88
89    /// Checks if the working directory is clean.
90    pub fn is_working_directory_clean(&self) -> Result<bool> {
91        let status = self.get_working_directory_status()?;
92        Ok(status.clean)
93    }
94
95    /// Returns the repository path.
96    pub fn path(&self) -> &std::path::Path {
97        self.repo.path()
98    }
99
100    /// Returns the workdir path.
101    pub fn workdir(&self) -> Option<&std::path::Path> {
102        self.repo.workdir()
103    }
104
105    /// Returns access to the underlying `git2::Repository`.
106    pub fn repository(&self) -> &Repository {
107        &self.repo
108    }
109
110    /// Returns the current branch name.
111    pub fn get_current_branch(&self) -> Result<String> {
112        let head = self.repo.head().context("Failed to get HEAD reference")?;
113
114        if let Some(name) = head.shorthand() {
115            if name != "HEAD" {
116                return Ok(name.to_string());
117            }
118        }
119
120        anyhow::bail!("Repository is in detached HEAD state")
121    }
122
123    /// Checks if a branch exists.
124    pub fn branch_exists(&self, branch_name: &str) -> Result<bool> {
125        // Check if it exists as a local branch
126        if self
127            .repo
128            .find_branch(branch_name, git2::BranchType::Local)
129            .is_ok()
130        {
131            return Ok(true);
132        }
133
134        // Check if it exists as a remote branch
135        if self
136            .repo
137            .find_branch(branch_name, git2::BranchType::Remote)
138            .is_ok()
139        {
140            return Ok(true);
141        }
142
143        // Check if we can resolve it as a reference
144        if self.repo.revparse_single(branch_name).is_ok() {
145            return Ok(true);
146        }
147
148        Ok(false)
149    }
150
151    /// Parses a commit range and returns the commits.
152    pub fn get_commits_in_range(&self, range: &str) -> Result<Vec<CommitInfo>> {
153        let mut commits = Vec::new();
154
155        if range == "HEAD" {
156            // Single HEAD commit
157            let head = self.repo.head().context("Failed to get HEAD")?;
158            let commit = head
159                .peel_to_commit()
160                .context("Failed to peel HEAD to commit")?;
161            commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
162        } else if range.contains("..") {
163            // Range format like HEAD~3..HEAD
164            let parts: Vec<&str> = range.split("..").collect();
165            if parts.len() != 2 {
166                anyhow::bail!("Invalid range format: {}", range);
167            }
168
169            let start_spec = parts[0];
170            let end_spec = parts[1];
171
172            // Parse start and end commits
173            let start_obj = self
174                .repo
175                .revparse_single(start_spec)
176                .with_context(|| format!("Failed to parse start commit: {}", start_spec))?;
177            let end_obj = self
178                .repo
179                .revparse_single(end_spec)
180                .with_context(|| format!("Failed to parse end commit: {}", end_spec))?;
181
182            let start_commit = start_obj
183                .peel_to_commit()
184                .context("Failed to peel start object to commit")?;
185            let end_commit = end_obj
186                .peel_to_commit()
187                .context("Failed to peel end object to commit")?;
188
189            // Walk from end_commit back to start_commit (exclusive)
190            let mut walker = self.repo.revwalk().context("Failed to create revwalk")?;
191            walker
192                .push(end_commit.id())
193                .context("Failed to push end commit")?;
194            walker
195                .hide(start_commit.id())
196                .context("Failed to hide start commit")?;
197
198            for oid in walker {
199                let oid = oid.context("Failed to get commit OID from walker")?;
200                let commit = self
201                    .repo
202                    .find_commit(oid)
203                    .context("Failed to find commit")?;
204
205                // Skip merge commits
206                if commit.parent_count() > 1 {
207                    continue;
208                }
209
210                commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
211            }
212
213            // Reverse to get chronological order (oldest first)
214            commits.reverse();
215        } else {
216            // Single commit by hash or reference
217            let obj = self
218                .repo
219                .revparse_single(range)
220                .with_context(|| format!("Failed to parse commit: {}", range))?;
221            let commit = obj
222                .peel_to_commit()
223                .context("Failed to peel object to commit")?;
224            commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
225        }
226
227        Ok(commits)
228    }
229}
230
231/// Formats git status flags into a string representation.
232fn format_status_flags(flags: Status) -> String {
233    let mut status = String::new();
234
235    if flags.contains(Status::INDEX_NEW) {
236        status.push('A');
237    } else if flags.contains(Status::INDEX_MODIFIED) {
238        status.push('M');
239    } else if flags.contains(Status::INDEX_DELETED) {
240        status.push('D');
241    } else if flags.contains(Status::INDEX_RENAMED) {
242        status.push('R');
243    } else if flags.contains(Status::INDEX_TYPECHANGE) {
244        status.push('T');
245    } else {
246        status.push(' ');
247    }
248
249    if flags.contains(Status::WT_NEW) {
250        status.push('?');
251    } else if flags.contains(Status::WT_MODIFIED) {
252        status.push('M');
253    } else if flags.contains(Status::WT_DELETED) {
254        status.push('D');
255    } else if flags.contains(Status::WT_TYPECHANGE) {
256        status.push('T');
257    } else if flags.contains(Status::WT_RENAMED) {
258        status.push('R');
259    } else {
260        status.push(' ');
261    }
262
263    status
264}
265
/// Extracts the host name from a git remote URL, for SSH config lookup.
///
/// Supports URL syntax (`ssh://`, `git+ssh://`, `https://`, `http://`,
/// `git://`, …) and scp-like syntax (`[user@]host:path`). User-info
/// (`user@`, `user:token@`) and ports are stripped, and bracketed IPv6
/// literals are unwrapped. Returns `None` for local paths, `file://` URLs, or
/// an empty host.
///
/// Examples:
/// - `"git@github.com:user/repo.git"` → `Some("github.com")`
/// - `"ssh://git@gitlab.example.com:2222/group/repo.git"` → `Some("gitlab.example.com")`
/// - `"https://user:token@github.com/user/repo"` → `Some("github.com")`
fn extract_hostname_from_git_url(url: &str) -> Option<String> {
    let url = url.trim();

    // Only treat "<scheme>://" as URL syntax when the scheme looks like one, so
    // an scp-like path that happens to contain "://" is not misparsed.
    let url_syntax = url.split_once("://").filter(|(scheme, _)| {
        !scheme.is_empty()
            && scheme
                .chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    });

    let host_port = if let Some((_scheme, rest)) = url_syntax {
        let authority = rest.split('/').next().unwrap_or(rest);
        // User-info may itself contain ':' (user:token), so cut at the last '@'.
        authority
            .rsplit_once('@')
            .map_or(authority, |(_, host)| host)
    } else {
        // scp-like syntax applies only when ':' comes before any '/'; otherwise
        // (or with no ':' at all) git treats the string as a local path.
        let (user_host, _path) = url.split_once(':')?;
        if user_host.contains('/') {
            return None;
        }
        user_host
            .rsplit_once('@')
            .map_or(user_host, |(_, host)| host)
    };

    let host = match host_port.strip_prefix('[') {
        // IPv6 literal: "[addr]" optionally followed by ":port".
        Some(bracketed) => bracketed.split(']').next().unwrap_or(bracketed),
        None => host_port.split(':').next().unwrap_or(host_port),
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
281
282/// Returns the SSH identity file for a given host from SSH config.
283fn get_ssh_identity_for_host(hostname: &str) -> Option<PathBuf> {
284    let home = std::env::var("HOME").ok()?;
285    let ssh_config_path = PathBuf::from(&home).join(".ssh/config");
286
287    if !ssh_config_path.exists() {
288        debug!("SSH config file not found at: {:?}", ssh_config_path);
289        return None;
290    }
291
292    // Open and parse the SSH config file
293    let file = std::fs::File::open(&ssh_config_path).ok()?;
294    let mut reader = BufReader::new(file);
295
296    let config = SshConfig::default()
297        .parse(&mut reader, ParseRule::ALLOW_UNKNOWN_FIELDS)
298        .ok()?;
299
300    // Query the config for the specific host
301    let params = config.query(hostname);
302
303    // Get the identity file from the config
304    if let Some(identity_files) = &params.identity_file {
305        if let Some(first_identity) = identity_files.first() {
306            // Expand ~ to home directory
307            let identity_str = first_identity.to_string_lossy();
308            let identity_path = identity_str.replace("~", &home);
309            let path = PathBuf::from(identity_path);
310
311            if path.exists() {
312                debug!("Found SSH key for host '{}': {:?}", hostname, path);
313                return Some(path);
314            } else {
315                debug!("SSH key specified in config but not found: {:?}", path);
316            }
317        }
318    }
319
320    None
321}
322
323/// Creates `RemoteCallbacks` with SSH credential resolution for the given hostname.
324///
325/// Tries credentials in order: SSH config identity → SSH agent → default key
326/// locations (`~/.ssh/id_ed25519`, `~/.ssh/id_rsa`). Bails after
327/// [`MAX_AUTH_ATTEMPTS`] to prevent infinite callback loops.
328fn make_auth_callbacks(hostname: String) -> git2::RemoteCallbacks<'static> {
329    let mut callbacks = git2::RemoteCallbacks::new();
330    let mut auth_attempts: u32 = 0;
331
332    callbacks.credentials(move |url, username_from_url, allowed_types| {
333        auth_attempts += 1;
334        debug!(
335            "Credential callback attempt {} - URL: {}, Username: {:?}, Allowed types: {:?}",
336            auth_attempts, url, username_from_url, allowed_types
337        );
338
339        if auth_attempts > MAX_AUTH_ATTEMPTS {
340            error!(
341                "Too many authentication attempts ({}), giving up",
342                auth_attempts
343            );
344            return Err(git2::Error::from_str(
345                "Authentication failed after multiple attempts",
346            ));
347        }
348
349        let username = username_from_url.unwrap_or("git");
350
351        if allowed_types.contains(git2::CredentialType::SSH_KEY) {
352            // Try SSH config identity first — avoids agent returning OK with no valid keys
353            if let Some(ssh_key_path) = get_ssh_identity_for_host(&hostname) {
354                let pub_key_path = ssh_key_path.with_extension("pub");
355                debug!("Trying SSH key from config: {:?}", ssh_key_path);
356
357                match git2::Cred::ssh_key(username, Some(&pub_key_path), &ssh_key_path, None) {
358                    Ok(cred) => {
359                        debug!(
360                            "Successfully loaded SSH key from config: {:?}",
361                            ssh_key_path
362                        );
363                        return Ok(cred);
364                    }
365                    Err(e) => {
366                        debug!("Failed to load SSH key from config: {}", e);
367                    }
368                }
369            }
370
371            // Only try SSH agent on first attempt
372            if auth_attempts == 1 {
373                match git2::Cred::ssh_key_from_agent(username) {
374                    Ok(cred) => {
375                        debug!("SSH agent credentials obtained (attempt {})", auth_attempts);
376                        return Ok(cred);
377                    }
378                    Err(e) => {
379                        debug!("SSH agent failed: {}, trying default keys", e);
380                    }
381                }
382            }
383
384            // Try default SSH key locations as fallback
385            let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
386            let ssh_keys = [
387                format!("{}/.ssh/id_ed25519", home),
388                format!("{}/.ssh/id_rsa", home),
389            ];
390
391            for key_path in &ssh_keys {
392                let key_path = PathBuf::from(key_path);
393                if key_path.exists() {
394                    let pub_key_path = key_path.with_extension("pub");
395                    debug!("Trying default SSH key: {:?}", key_path);
396
397                    match git2::Cred::ssh_key(username, Some(&pub_key_path), &key_path, None) {
398                        Ok(cred) => {
399                            debug!("Successfully loaded SSH key from {:?}", key_path);
400                            return Ok(cred);
401                        }
402                        Err(e) => debug!("Failed to load SSH key from {:?}: {}", key_path, e),
403                    }
404                }
405            }
406        }
407
408        debug!("Falling back to default credentials");
409        git2::Cred::default()
410    });
411
412    callbacks
413}
414
415/// Formats a user-friendly SSH authentication error message with troubleshooting steps.
416fn format_auth_error(operation: &str, error: &git2::Error) -> String {
417    if error.message().contains("authentication") || error.message().contains("SSH") {
418        format!(
419            "Failed to {operation}: {error}. \n\nTroubleshooting steps:\n\
420            1. Check if your SSH key is loaded: ssh-add -l\n\
421            2. Test GitHub SSH connection: ssh -T git@github.com\n\
422            3. Use GitHub CLI auth instead: gh auth setup-git",
423        )
424    } else {
425        format!("Failed to {operation}: {error}")
426    }
427}
428
429impl GitRepository {
430    /// Pushes the current branch to remote.
431    pub fn push_branch(&self, branch_name: &str, remote_name: &str) -> Result<()> {
432        info!(
433            "Pushing branch '{}' to remote '{}'",
434            branch_name, remote_name
435        );
436
437        // Get remote
438        debug!("Finding remote '{}'", remote_name);
439        let mut remote = self
440            .repo
441            .find_remote(remote_name)
442            .context("Failed to find remote")?;
443
444        let remote_url = remote.url().unwrap_or("<unknown>");
445        debug!("Remote URL: {}", remote_url);
446
447        // Set up refspec for push
448        let refspec = format!("refs/heads/{}:refs/heads/{}", branch_name, branch_name);
449        debug!("Using refspec: {}", refspec);
450
451        // Extract hostname from remote URL for SSH config lookup
452        let hostname =
453            extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
454        debug!(
455            "Extracted hostname '{}' from URL '{}'",
456            hostname, remote_url
457        );
458
459        // Push with authentication callbacks
460        let mut push_options = git2::PushOptions::new();
461        let callbacks = make_auth_callbacks(hostname);
462        push_options.remote_callbacks(callbacks);
463
464        // Perform the push
465        debug!("Attempting to push to remote...");
466        match remote.push(&[&refspec], Some(&mut push_options)) {
467            Ok(_) => {
468                info!(
469                    "Successfully pushed branch '{}' to remote '{}'",
470                    branch_name, remote_name
471                );
472
473                // Set upstream branch after successful push
474                debug!("Setting upstream branch for '{}'", branch_name);
475                match self.repo.find_branch(branch_name, git2::BranchType::Local) {
476                    Ok(mut branch) => {
477                        let remote_ref = format!("{}/{}", remote_name, branch_name);
478                        match branch.set_upstream(Some(&remote_ref)) {
479                            Ok(_) => {
480                                info!(
481                                    "Successfully set upstream to '{}'/{}",
482                                    remote_name, branch_name
483                                );
484                            }
485                            Err(e) => {
486                                // Log but don't fail - the push succeeded
487                                error!("Failed to set upstream branch: {}", e);
488                            }
489                        }
490                    }
491                    Err(e) => {
492                        // Log but don't fail - the push succeeded
493                        error!("Failed to find local branch to set upstream: {}", e);
494                    }
495                }
496
497                Ok(())
498            }
499            Err(e) => {
500                error!("Failed to push branch: {}", e);
501                Err(anyhow::anyhow!(format_auth_error(
502                    "push branch to remote",
503                    &e
504                )))
505            }
506        }
507    }
508
509    /// Checks if a branch exists on remote.
510    pub fn branch_exists_on_remote(&self, branch_name: &str, remote_name: &str) -> Result<bool> {
511        debug!(
512            "Checking if branch '{}' exists on remote '{}'",
513            branch_name, remote_name
514        );
515
516        let remote = self
517            .repo
518            .find_remote(remote_name)
519            .context("Failed to find remote")?;
520
521        let remote_url = remote.url().unwrap_or("<unknown>");
522        debug!("Remote URL: {}", remote_url);
523
524        // Extract hostname from remote URL for SSH config lookup
525        let hostname =
526            extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
527        debug!(
528            "Extracted hostname '{}' from URL '{}'",
529            hostname, remote_url
530        );
531
532        // Connect to remote to get refs
533        let mut remote = remote;
534        let callbacks = make_auth_callbacks(hostname);
535
536        debug!("Attempting to connect to remote...");
537        match remote.connect_auth(git2::Direction::Fetch, Some(callbacks), None) {
538            Ok(_) => debug!("Successfully connected to remote"),
539            Err(e) => {
540                error!("Failed to connect to remote: {}", e);
541                return Err(anyhow::anyhow!(format_auth_error("connect to remote", &e)));
542            }
543        }
544
545        // Check if the remote branch exists
546        debug!("Listing remote refs...");
547        let refs = remote.list()?;
548        let remote_branch_ref = format!("refs/heads/{}", branch_name);
549        debug!("Looking for remote branch ref: {}", remote_branch_ref);
550
551        for remote_head in refs {
552            debug!("Found remote ref: {}", remote_head.name());
553            if remote_head.name() == remote_branch_ref {
554                info!(
555                    "Branch '{}' exists on remote '{}'",
556                    branch_name, remote_name
557                );
558                return Ok(true);
559            }
560        }
561
562        info!(
563            "Branch '{}' does not exist on remote '{}'",
564            branch_name, remote_name
565        );
566        Ok(false)
567    }
568}