1use std::io::BufReader;
4use std::path::PathBuf;
5
6use anyhow::{Context, Result};
7use git2::{Repository, Status};
8use ssh2_config::{ParseRule, SshConfig};
9use tracing::{debug, error, info};
10
11use crate::git::CommitInfo;
12
/// Maximum number of times the credentials callback may be invoked for one
/// remote operation; once the counter exceeds this value the callback returns
/// an error instead of offering further credentials.
const MAX_AUTH_ATTEMPTS: u32 = 3;
15
/// Wrapper that owns a `git2::Repository` handle and exposes the status,
/// branch, commit-range and remote helpers implemented in this file.
pub struct GitRepository {
    /// Underlying `git2` repository handle used by every method.
    repo: Repository,
}
20
/// Snapshot returned by `GitRepository::get_working_directory_status`.
#[derive(Debug)]
pub struct WorkingDirectoryStatus {
    /// `true` when `untracked_changes` is empty.
    pub clean: bool,
    /// Every non-ignored status entry that has a UTF-8 path. Despite the
    /// field name, this is not limited to untracked files: staged and
    /// modified entries are listed here as well.
    pub untracked_changes: Vec<FileStatus>,
}
29
/// One entry of a working-directory status listing.
#[derive(Debug)]
pub struct FileStatus {
    /// Two-character code built by `format_status_flags`: the first character
    /// describes the index state, the second the working-tree state.
    pub status: String,
    /// Path of the entry as reported by the status listing.
    pub file: String,
}
38
39impl GitRepository {
40 pub fn open() -> Result<Self> {
42 let repo = Repository::open(".").context("Not in a git repository")?;
43
44 Ok(Self { repo })
45 }
46
47 pub fn open_at<P: AsRef<std::path::Path>>(path: P) -> Result<Self> {
49 let repo = Repository::open(path).context("Failed to open git repository")?;
50
51 Ok(Self { repo })
52 }
53
54 pub fn get_working_directory_status(&self) -> Result<WorkingDirectoryStatus> {
56 let statuses = self
57 .repo
58 .statuses(None)
59 .context("Failed to get repository status")?;
60
61 let mut untracked_changes = Vec::new();
62
63 for entry in statuses.iter() {
64 if let Some(path) = entry.path() {
65 let status_flags = entry.status();
66
67 if status_flags.contains(Status::IGNORED) {
69 continue;
70 }
71
72 let status_str = format_status_flags(status_flags);
73
74 untracked_changes.push(FileStatus {
75 status: status_str,
76 file: path.to_string(),
77 });
78 }
79 }
80
81 let clean = untracked_changes.is_empty();
82
83 Ok(WorkingDirectoryStatus {
84 clean,
85 untracked_changes,
86 })
87 }
88
89 pub fn is_working_directory_clean(&self) -> Result<bool> {
91 let status = self.get_working_directory_status()?;
92 Ok(status.clean)
93 }
94
    /// Returns the path reported by `git2::Repository::path` for the wrapped
    /// repository.
    // NOTE(review): in git2 this is presumably the `.git` directory rather than
    // the checkout root (see `workdir`) — confirm against callers.
    pub fn path(&self) -> &std::path::Path {
        self.repo.path()
    }
99
    /// Returns the working directory reported by `git2::Repository::workdir`,
    /// or `None` when the repository has none.
    // NOTE(review): `None` presumably means a bare repository — confirm.
    pub fn workdir(&self) -> Option<&std::path::Path> {
        self.repo.workdir()
    }
104
    /// Gives read access to the wrapped `git2::Repository` for operations this
    /// wrapper does not expose itself.
    pub fn repository(&self) -> &Repository {
        &self.repo
    }
109
110 pub fn get_current_branch(&self) -> Result<String> {
112 let head = self.repo.head().context("Failed to get HEAD reference")?;
113
114 if let Some(name) = head.shorthand() {
115 if name != "HEAD" {
116 return Ok(name.to_string());
117 }
118 }
119
120 anyhow::bail!("Repository is in detached HEAD state")
121 }
122
123 pub fn branch_exists(&self, branch_name: &str) -> Result<bool> {
125 if self
127 .repo
128 .find_branch(branch_name, git2::BranchType::Local)
129 .is_ok()
130 {
131 return Ok(true);
132 }
133
134 if self
136 .repo
137 .find_branch(branch_name, git2::BranchType::Remote)
138 .is_ok()
139 {
140 return Ok(true);
141 }
142
143 if self.repo.revparse_single(branch_name).is_ok() {
145 return Ok(true);
146 }
147
148 Ok(false)
149 }
150
151 pub fn get_commits_in_range(&self, range: &str) -> Result<Vec<CommitInfo>> {
153 let mut commits = Vec::new();
154
155 if range == "HEAD" {
156 let head = self.repo.head().context("Failed to get HEAD")?;
158 let commit = head
159 .peel_to_commit()
160 .context("Failed to peel HEAD to commit")?;
161 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
162 } else if range.contains("..") {
163 let parts: Vec<&str> = range.split("..").collect();
165 if parts.len() != 2 {
166 anyhow::bail!("Invalid range format: {}", range);
167 }
168
169 let start_spec = parts[0];
170 let end_spec = parts[1];
171
172 let start_obj = self
174 .repo
175 .revparse_single(start_spec)
176 .with_context(|| format!("Failed to parse start commit: {}", start_spec))?;
177 let end_obj = self
178 .repo
179 .revparse_single(end_spec)
180 .with_context(|| format!("Failed to parse end commit: {}", end_spec))?;
181
182 let start_commit = start_obj
183 .peel_to_commit()
184 .context("Failed to peel start object to commit")?;
185 let end_commit = end_obj
186 .peel_to_commit()
187 .context("Failed to peel end object to commit")?;
188
189 let mut walker = self.repo.revwalk().context("Failed to create revwalk")?;
191 walker
192 .push(end_commit.id())
193 .context("Failed to push end commit")?;
194 walker
195 .hide(start_commit.id())
196 .context("Failed to hide start commit")?;
197
198 for oid in walker {
199 let oid = oid.context("Failed to get commit OID from walker")?;
200 let commit = self
201 .repo
202 .find_commit(oid)
203 .context("Failed to find commit")?;
204
205 if commit.parent_count() > 1 {
207 continue;
208 }
209
210 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
211 }
212
213 commits.reverse();
215 } else {
216 let obj = self
218 .repo
219 .revparse_single(range)
220 .with_context(|| format!("Failed to parse commit: {}", range))?;
221 let commit = obj
222 .peel_to_commit()
223 .context("Failed to peel object to commit")?;
224 commits.push(CommitInfo::from_git_commit(&self.repo, &commit)?);
225 }
226
227 Ok(commits)
228 }
229}
230
231fn format_status_flags(flags: Status) -> String {
233 let mut status = String::new();
234
235 if flags.contains(Status::INDEX_NEW) {
236 status.push('A');
237 } else if flags.contains(Status::INDEX_MODIFIED) {
238 status.push('M');
239 } else if flags.contains(Status::INDEX_DELETED) {
240 status.push('D');
241 } else if flags.contains(Status::INDEX_RENAMED) {
242 status.push('R');
243 } else if flags.contains(Status::INDEX_TYPECHANGE) {
244 status.push('T');
245 } else {
246 status.push(' ');
247 }
248
249 if flags.contains(Status::WT_NEW) {
250 status.push('?');
251 } else if flags.contains(Status::WT_MODIFIED) {
252 status.push('M');
253 } else if flags.contains(Status::WT_DELETED) {
254 status.push('D');
255 } else if flags.contains(Status::WT_TYPECHANGE) {
256 status.push('T');
257 } else if flags.contains(Status::WT_RENAMED) {
258 status.push('R');
259 } else {
260 status.push(' ');
261 }
262
263 status
264}
265
/// Extracts the host part of a git remote URL.
///
/// Supported shapes:
/// * `scheme://[user[:password]@]host[:port]/path` for any scheme
///   (`https`, `http`, `ssh`, `git`, ...). User info and port are stripped,
///   and a bracketed IPv6 literal is returned without the brackets.
/// * scp-like `[user@]host:path`, including SSH config aliases without a
///   user (`work-alias:org/repo.git`).
///
/// Returns `None` for local paths (including Windows drive paths such as
/// `C:\repo`), for `file://` URLs without a host, and whenever the host part
/// would be empty.
fn extract_hostname_from_git_url(url: &str) -> Option<String> {
    let is_scheme = |s: &str| {
        !s.is_empty()
            && s.chars()
                .all(|c| c.is_ascii_alphanumeric() || matches!(c, '+' | '-' | '.'))
    };

    let host = match url.split_once("://") {
        Some((scheme, rest)) if is_scheme(scheme) => {
            // The authority component ends at the first '/'.
            let authority = rest.split('/').next().unwrap_or("");
            // Drop "user[:password]@"; the host follows the last '@'.
            let host_port = authority.rsplit('@').next().unwrap_or("");
            match host_port.strip_prefix('[') {
                // Bracketed IPv6 literal, optionally followed by ":port".
                Some(bracketed) => bracketed.split(']').next().unwrap_or(""),
                None => host_port.split(':').next().unwrap_or(""),
            }
        }
        _ => {
            // scp-like syntax: everything before the first ':' is "[user@]host".
            let target = url.split(':').next().unwrap_or("");
            let has_user = target.contains('@');
            let has_path_separator = target.len() < url.len();
            // A '/' before the first ':' means a local path; a bare word with
            // neither '@' nor ':' is not a remote either.
            if target.contains('/') || !(has_user || has_path_separator) {
                return None;
            }
            let host = target.rsplit('@').next().unwrap_or("");
            // "C:\repo" is a Windows drive path, not host "C".
            let looks_like_drive = !has_user
                && host.len() == 1
                && host.chars().all(|c| c.is_ascii_alphabetic());
            if looks_like_drive {
                return None;
            }
            host
        }
    };

    if host.is_empty() {
        None
    } else {
        Some(host.to_string())
    }
}
281
282fn get_ssh_identity_for_host(hostname: &str) -> Option<PathBuf> {
284 let home = std::env::var("HOME").ok()?;
285 let ssh_config_path = PathBuf::from(&home).join(".ssh/config");
286
287 if !ssh_config_path.exists() {
288 debug!("SSH config file not found at: {:?}", ssh_config_path);
289 return None;
290 }
291
292 let file = std::fs::File::open(&ssh_config_path).ok()?;
294 let mut reader = BufReader::new(file);
295
296 let config = SshConfig::default()
297 .parse(&mut reader, ParseRule::ALLOW_UNKNOWN_FIELDS)
298 .ok()?;
299
300 let params = config.query(hostname);
302
303 if let Some(identity_files) = ¶ms.identity_file {
305 if let Some(first_identity) = identity_files.first() {
306 let identity_str = first_identity.to_string_lossy();
308 let identity_path = identity_str.replace("~", &home);
309 let path = PathBuf::from(identity_path);
310
311 if path.exists() {
312 debug!("Found SSH key for host '{}': {:?}", hostname, path);
313 return Some(path);
314 } else {
315 debug!("SSH key specified in config but not found: {:?}", path);
316 }
317 }
318 }
319
320 None
321}
322
323fn make_auth_callbacks(hostname: String) -> git2::RemoteCallbacks<'static> {
329 let mut callbacks = git2::RemoteCallbacks::new();
330 let mut auth_attempts: u32 = 0;
331
332 callbacks.credentials(move |url, username_from_url, allowed_types| {
333 auth_attempts += 1;
334 debug!(
335 "Credential callback attempt {} - URL: {}, Username: {:?}, Allowed types: {:?}",
336 auth_attempts, url, username_from_url, allowed_types
337 );
338
339 if auth_attempts > MAX_AUTH_ATTEMPTS {
340 error!(
341 "Too many authentication attempts ({}), giving up",
342 auth_attempts
343 );
344 return Err(git2::Error::from_str(
345 "Authentication failed after multiple attempts",
346 ));
347 }
348
349 let username = username_from_url.unwrap_or("git");
350
351 if allowed_types.contains(git2::CredentialType::SSH_KEY) {
352 if let Some(ssh_key_path) = get_ssh_identity_for_host(&hostname) {
354 let pub_key_path = ssh_key_path.with_extension("pub");
355 debug!("Trying SSH key from config: {:?}", ssh_key_path);
356
357 match git2::Cred::ssh_key(username, Some(&pub_key_path), &ssh_key_path, None) {
358 Ok(cred) => {
359 debug!(
360 "Successfully loaded SSH key from config: {:?}",
361 ssh_key_path
362 );
363 return Ok(cred);
364 }
365 Err(e) => {
366 debug!("Failed to load SSH key from config: {}", e);
367 }
368 }
369 }
370
371 if auth_attempts == 1 {
373 match git2::Cred::ssh_key_from_agent(username) {
374 Ok(cred) => {
375 debug!("SSH agent credentials obtained (attempt {})", auth_attempts);
376 return Ok(cred);
377 }
378 Err(e) => {
379 debug!("SSH agent failed: {}, trying default keys", e);
380 }
381 }
382 }
383
384 let home = std::env::var("HOME").unwrap_or_else(|_| "~".to_string());
386 let ssh_keys = [
387 format!("{}/.ssh/id_ed25519", home),
388 format!("{}/.ssh/id_rsa", home),
389 ];
390
391 for key_path in &ssh_keys {
392 let key_path = PathBuf::from(key_path);
393 if key_path.exists() {
394 let pub_key_path = key_path.with_extension("pub");
395 debug!("Trying default SSH key: {:?}", key_path);
396
397 match git2::Cred::ssh_key(username, Some(&pub_key_path), &key_path, None) {
398 Ok(cred) => {
399 debug!("Successfully loaded SSH key from {:?}", key_path);
400 return Ok(cred);
401 }
402 Err(e) => debug!("Failed to load SSH key from {:?}: {}", key_path, e),
403 }
404 }
405 }
406 }
407
408 debug!("Falling back to default credentials");
409 git2::Cred::default()
410 });
411
412 callbacks
413}
414
415fn format_auth_error(operation: &str, error: &git2::Error) -> String {
417 if error.message().contains("authentication") || error.message().contains("SSH") {
418 format!(
419 "Failed to {operation}: {error}. \n\nTroubleshooting steps:\n\
420 1. Check if your SSH key is loaded: ssh-add -l\n\
421 2. Test GitHub SSH connection: ssh -T git@github.com\n\
422 3. Use GitHub CLI auth instead: gh auth setup-git",
423 )
424 } else {
425 format!("Failed to {operation}: {error}")
426 }
427}
428
429impl GitRepository {
430 pub fn push_branch(&self, branch_name: &str, remote_name: &str) -> Result<()> {
432 info!(
433 "Pushing branch '{}' to remote '{}'",
434 branch_name, remote_name
435 );
436
437 debug!("Finding remote '{}'", remote_name);
439 let mut remote = self
440 .repo
441 .find_remote(remote_name)
442 .context("Failed to find remote")?;
443
444 let remote_url = remote.url().unwrap_or("<unknown>");
445 debug!("Remote URL: {}", remote_url);
446
447 let refspec = format!("refs/heads/{}:refs/heads/{}", branch_name, branch_name);
449 debug!("Using refspec: {}", refspec);
450
451 let hostname =
453 extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
454 debug!(
455 "Extracted hostname '{}' from URL '{}'",
456 hostname, remote_url
457 );
458
459 let mut push_options = git2::PushOptions::new();
461 let callbacks = make_auth_callbacks(hostname);
462 push_options.remote_callbacks(callbacks);
463
464 debug!("Attempting to push to remote...");
466 match remote.push(&[&refspec], Some(&mut push_options)) {
467 Ok(_) => {
468 info!(
469 "Successfully pushed branch '{}' to remote '{}'",
470 branch_name, remote_name
471 );
472
473 debug!("Setting upstream branch for '{}'", branch_name);
475 match self.repo.find_branch(branch_name, git2::BranchType::Local) {
476 Ok(mut branch) => {
477 let remote_ref = format!("{}/{}", remote_name, branch_name);
478 match branch.set_upstream(Some(&remote_ref)) {
479 Ok(_) => {
480 info!(
481 "Successfully set upstream to '{}'/{}",
482 remote_name, branch_name
483 );
484 }
485 Err(e) => {
486 error!("Failed to set upstream branch: {}", e);
488 }
489 }
490 }
491 Err(e) => {
492 error!("Failed to find local branch to set upstream: {}", e);
494 }
495 }
496
497 Ok(())
498 }
499 Err(e) => {
500 error!("Failed to push branch: {}", e);
501 Err(anyhow::anyhow!(format_auth_error(
502 "push branch to remote",
503 &e
504 )))
505 }
506 }
507 }
508
509 pub fn branch_exists_on_remote(&self, branch_name: &str, remote_name: &str) -> Result<bool> {
511 debug!(
512 "Checking if branch '{}' exists on remote '{}'",
513 branch_name, remote_name
514 );
515
516 let remote = self
517 .repo
518 .find_remote(remote_name)
519 .context("Failed to find remote")?;
520
521 let remote_url = remote.url().unwrap_or("<unknown>");
522 debug!("Remote URL: {}", remote_url);
523
524 let hostname =
526 extract_hostname_from_git_url(remote_url).unwrap_or("github.com".to_string());
527 debug!(
528 "Extracted hostname '{}' from URL '{}'",
529 hostname, remote_url
530 );
531
532 let mut remote = remote;
534 let callbacks = make_auth_callbacks(hostname);
535
536 debug!("Attempting to connect to remote...");
537 match remote.connect_auth(git2::Direction::Fetch, Some(callbacks), None) {
538 Ok(_) => debug!("Successfully connected to remote"),
539 Err(e) => {
540 error!("Failed to connect to remote: {}", e);
541 return Err(anyhow::anyhow!(format_auth_error("connect to remote", &e)));
542 }
543 }
544
545 debug!("Listing remote refs...");
547 let refs = remote.list()?;
548 let remote_branch_ref = format!("refs/heads/{}", branch_name);
549 debug!("Looking for remote branch ref: {}", remote_branch_ref);
550
551 for remote_head in refs {
552 debug!("Found remote ref: {}", remote_head.name());
553 if remote_head.name() == remote_branch_ref {
554 info!(
555 "Branch '{}' exists on remote '{}'",
556 branch_name, remote_name
557 );
558 return Ok(true);
559 }
560 }
561
562 info!(
563 "Branch '{}' does not exist on remote '{}'",
564 branch_name, remote_name
565 );
566 Ok(false)
567 }
568}