xgit 0.2.6

An enhanced AI-powered Git tool
use crate::{github::client::GitHubClient, tui::branch_display::PullRequestState};
use anyhow::{Context, Error};
use serde::Deserialize;
use std::env;
use std::path::{Path, PathBuf};
use std::process::Command;

/// Backend-independent view of a GitHub pull request, produced by
/// [`GitHubPrService`] for both the `gh` CLI backend and the API backend.
#[derive(Debug, Clone)]
pub struct PullRequestRecord {
    /// PR number within the repository.
    pub number: u64,
    /// Open/closed state. With the `gh` backend, both `CLOSED` and `MERGED`
    /// map to `PullRequestState::Closed`.
    pub state: PullRequestState,
    /// PR URL as reported by the backend.
    pub url: String,
    /// Name of the branch the PR targets (gh: `baseRefName`).
    pub base_ref: String,
    /// Name of the PR's source branch (gh: `headRefName`).
    pub head_ref: String,
    /// Commit id at the tip of the head branch (gh: `headRefOid`).
    pub head_sha: String,
    /// Whether the PR was merged (gh: derived from `mergedAt` being non-null).
    pub merged: bool,
}

impl PullRequestRecord {
    /// Reports whether the PR is finished: it has been merged, or its state is
    /// `Closed`.
    pub fn is_closed_or_merged(&self) -> bool {
        if self.merged {
            return true;
        }
        matches!(self.state, PullRequestState::Closed)
    }
}

/// How [`GitHubPrService`] talks to GitHub, selected once in `GitHubPrService::new`.
enum Backend {
    /// Shell out to the GitHub CLI (`gh`).
    GhCli,
    /// Call GitHub through a [`GitHubClient`].
    Api(GitHubClient),
}

/// Pull-request operations for a single repository, dispatched to either the
/// `gh` CLI or a [`GitHubClient`] backend.
pub struct GitHubPrService {
    /// Backend that handles every request.
    backend: Backend,
    /// `owner/repo`, passed to `gh` via `--repo` and used in `gh api` paths.
    repo_slug: String,
    /// Working directory for `gh` invocations.
    repo_path: PathBuf,
}

// NOTE(review): with the `gh` backend the async methods below run `gh`
// synchronously via `std::process::Command`, blocking the calling executor
// thread until the command finishes.
impl GitHubPrService {
    /// Creates a service for the `owner/repo` repository whose local checkout is
    /// `repo_path` (used as the working directory for `gh` invocations).
    ///
    /// The backend is chosen from the `XGIT_GITHUB_BACKEND` environment variable:
    /// `api` selects [`GitHubClient`]. Any other value, including `gh`, and an
    /// unset or non-UTF-8 variable select the GitHub CLI.
    ///
    /// # Errors
    /// Propagates the error from `GitHubClient::new` when the API backend is
    /// selected and the client cannot be constructed.
    pub fn new(repo_path: &Path, owner: String, repo: String) -> Result<Self, Error> {
        let backend = match env::var("XGIT_GITHUB_BACKEND").ok().as_deref() {
            Some("api") => Backend::Api(GitHubClient::new(owner.clone(), repo.clone())?),
            // The CLI is the default: explicit "gh", unset and unrecognized values.
            _ => Backend::GhCli,
        };

        Ok(Self {
            backend,
            repo_slug: format!("{owner}/{repo}"),
            repo_path: repo_path.to_path_buf(),
        })
    }

    /// Checks that the selected backend is usable.
    ///
    /// For the CLI backend this runs `gh --version`. The API backend is not
    /// checked here.
    ///
    /// # Errors
    /// Fails if `gh` cannot be executed or `gh --version` exits unsuccessfully.
    pub fn ensure_ready(&self) -> Result<(), Error> {
        match &self.backend {
            Backend::GhCli => {
                let version = Command::new("gh")
                    .arg("--version")
                    .current_dir(&self.repo_path)
                    .output()
                    .context("Failed to execute gh --version. Please install GitHub CLI (`gh`)")?;
                if !version.status.success() {
                    return Err(anyhow::anyhow!(
                        "GitHub CLI (`gh`) is required for xgit GitHub operations"
                    ));
                }

                Ok(())
            }
            Backend::Api(_) => Ok(()),
        }
    }

    /// Returns the name of the repository's default branch.
    ///
    /// # Errors
    /// Fails if the backend request fails. With the CLI backend it also fails
    /// when `gh` reports an empty default branch.
    pub async fn get_default_branch(&self) -> Result<String, Error> {
        match &self.backend {
            Backend::GhCli => {
                let output = gh_output(
                    &self.repo_path,
                    &[
                        "api",
                        &format!("repos/{}", self.repo_slug),
                        "--jq",
                        ".default_branch",
                    ],
                )?;
                let branch = output.trim();
                // An empty result is not a usable branch name (presumably
                // `.default_branch` was null or missing), so report it rather
                // than returning `Ok("")`.
                if branch.is_empty() {
                    return Err(anyhow::anyhow!(
                        "GitHub returned no default branch for '{}'",
                        self.repo_slug
                    ));
                }
                Ok(branch.to_string())
            }
            Backend::Api(client) => client.get_default_branch().await,
        }
    }

    /// Fetches pull request `pr_number`.
    ///
    /// # Errors
    /// Fails if the backend request fails or (CLI) its output cannot be parsed.
    pub async fn get_pr(&self, pr_number: u64) -> Result<PullRequestRecord, Error> {
        match &self.backend {
            Backend::GhCli => gh_pr_view(&self.repo_path, &self.repo_slug, pr_number),
            Backend::Api(client) => {
                let pr = client.get_pr_by_number(pr_number).await?;
                Ok(PullRequestRecord {
                    number: pr.number,
                    state: pr.state,
                    url: pr.url,
                    base_ref: pr.base_ref,
                    head_ref: pr.head_ref,
                    head_sha: pr.head_sha,
                    merged: pr.merged,
                })
            }
        }
    }

    /// Opens a pull request from `head` into `base`, optionally as a draft, and
    /// returns the created PR.
    ///
    /// # Errors
    /// Fails if creation fails or the created PR cannot be resolved afterwards.
    pub async fn create_pr(
        &self,
        title: &str,
        body: Option<&str>,
        head: &str,
        base: &str,
        draft: bool,
    ) -> Result<PullRequestRecord, Error> {
        match &self.backend {
            Backend::GhCli => gh_pr_create(
                &self.repo_path,
                &self.repo_slug,
                title,
                body,
                head,
                base,
                draft,
            ),
            Backend::Api(client) => {
                let pr = client.create_pr(title, body, head, base, draft).await?;
                Ok(PullRequestRecord {
                    number: pr.number,
                    state: pr.state,
                    url: pr.url,
                    base_ref: pr.base_ref,
                    head_ref: pr.head_ref,
                    head_sha: pr.head_sha,
                    merged: pr.merged,
                })
            }
        }
    }

    /// Updates the base branch, title and/or body of PR `pr_number`. Fields
    /// passed as `None` are left unchanged. Returns the PR after the update;
    /// the CLI backend re-reads it with `gh pr view`.
    ///
    /// # Errors
    /// Fails if the edit or the follow-up read fails.
    pub async fn update_pr(
        &self,
        pr_number: u64,
        base: Option<&str>,
        title: Option<&str>,
        body: Option<&str>,
    ) -> Result<PullRequestRecord, Error> {
        match &self.backend {
            Backend::GhCli => {
                gh_pr_edit(
                    &self.repo_path,
                    &self.repo_slug,
                    pr_number,
                    base,
                    title,
                    body,
                )?;
                gh_pr_view(&self.repo_path, &self.repo_slug, pr_number)
            }
            Backend::Api(client) => {
                let pr = client.update_pr(pr_number, base, title, body).await?;
                Ok(PullRequestRecord {
                    number: pr.number,
                    state: pr.state,
                    url: pr.url,
                    base_ref: pr.base_ref,
                    head_ref: pr.head_ref,
                    head_sha: pr.head_sha,
                    merged: pr.merged,
                })
            }
        }
    }

    /// Finds a pull request whose head branch is `head_branch`.
    ///
    /// The API backend first locates the PR by head branch, then re-fetches it
    /// by number through [`Self::get_pr`].
    ///
    /// # Errors
    /// Fails if no PR matches or a backend request fails.
    pub async fn find_pr_by_head(&self, head_branch: &str) -> Result<PullRequestRecord, Error> {
        match &self.backend {
            Backend::GhCli => gh_pr_find_by_head(&self.repo_path, &self.repo_slug, head_branch),
            Backend::Api(client) => {
                let found = client
                    .find_pr_by_head_branch(head_branch)
                    .await?
                    .ok_or_else(|| {
                        anyhow::anyhow!("No PR found for head branch '{head_branch}'")
                    })?;
                self.get_pr(found.number).await
            }
        }
    }
}

/// One pull request as emitted by `gh pr view --json …` or by each element of
/// `gh pr list --json …`. gh's camelCase keys map onto the snake_case fields.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct GhPrViewResponse {
    number: u64,
    /// Raw state string from gh; converted by `gh_state_to_pull_request_state`.
    state: String,
    url: String,
    base_ref_name: String,
    head_ref_name: String,
    head_ref_oid: String,
    /// Merge timestamp; `None` when gh reports null (not merged).
    merged_at: Option<String>,
}

fn gh_pr_view(
    repo_path: &Path,
    repo_slug: &str,
    pr_number: u64,
) -> Result<PullRequestRecord, Error> {
    let output = gh_output(
        repo_path,
        &[
            "pr",
            "view",
            &pr_number.to_string(),
            "--repo",
            repo_slug,
            "--json",
            "number,state,url,baseRefName,headRefName,headRefOid,mergedAt",
        ],
    )?;
    let parsed: GhPrViewResponse =
        serde_json::from_str(&output).context("Failed to parse `gh pr view` JSON output")?;
    Ok(PullRequestRecord {
        number: parsed.number,
        state: gh_state_to_pull_request_state(&parsed.state),
        url: parsed.url,
        base_ref: parsed.base_ref_name,
        head_ref: parsed.head_ref_name,
        head_sha: parsed.head_ref_oid,
        merged: parsed.merged_at.is_some(),
    })
}

fn gh_pr_create(
    repo_path: &Path,
    repo_slug: &str,
    title: &str,
    body: Option<&str>,
    head: &str,
    base: &str,
    draft: bool,
) -> Result<PullRequestRecord, Error> {
    let mut args = vec![
        "pr".to_string(),
        "create".to_string(),
        "--repo".to_string(),
        repo_slug.to_string(),
        "--title".to_string(),
        title.to_string(),
        "--head".to_string(),
        head.to_string(),
        "--base".to_string(),
        base.to_string(),
    ];

    if let Some(body) = body {
        args.push("--body".to_string());
        args.push(body.to_string());
    } else {
        args.push("--body".to_string());
        args.push("".to_string());
    }

    if draft {
        args.push("--draft".to_string());
    }

    let arg_refs: Vec<&str> = args.iter().map(|s| s.as_str()).collect();
    gh_output(repo_path, &arg_refs).context("`gh pr create` failed")?;
    gh_pr_find_by_head(repo_path, repo_slug, head)
        .context("PR was created but could not be resolved by head branch")
}

/// Edits PR `pr_number` with `gh pr edit`, passing only the fields that are
/// `Some`.
///
/// When `base`, `title` and `body` are all `None` there is nothing to change,
/// and this returns `Ok(())` without running gh. NOTE(review): a flag-less
/// `gh pr edit` falls into gh's interactive mode, which cannot work with the
/// captured stdio used by `gh_output`.
///
/// # Errors
/// Fails if `gh pr edit` fails.
fn gh_pr_edit(
    repo_path: &Path,
    repo_slug: &str,
    pr_number: u64,
    base: Option<&str>,
    title: Option<&str>,
    body: Option<&str>,
) -> Result<(), Error> {
    if base.is_none() && title.is_none() && body.is_none() {
        return Ok(());
    }

    let pr_arg = pr_number.to_string();
    let mut args: Vec<&str> = vec!["pr", "edit", pr_arg.as_str(), "--repo", repo_slug];
    for (flag, value) in [("--base", base), ("--title", title), ("--body", body)] {
        if let Some(value) = value {
            args.push(flag);
            args.push(value);
        }
    }

    gh_output(repo_path, &args).context("`gh pr edit` failed")?;
    Ok(())
}

/// Looks up a PR whose head branch is `head_branch` using `gh pr list`.
///
/// PRs in every state are searched (`--state all`), and only the first result
/// gh returns is kept (`--limit 1`).
///
/// # Errors
/// Fails if gh fails, its JSON cannot be parsed, or no PR matches.
fn gh_pr_find_by_head(
    repo_path: &Path,
    repo_slug: &str,
    head_branch: &str,
) -> Result<PullRequestRecord, Error> {
    const FIELDS: &str = "number,state,url,baseRefName,headRefName,headRefOid,mergedAt";
    let args = [
        "pr",
        "list",
        "--repo",
        repo_slug,
        "--head",
        head_branch,
        "--state",
        "all",
        "--limit",
        "1",
        "--json",
        FIELDS,
    ];
    let json = gh_output(repo_path, &args)?;

    let matches: Vec<GhPrViewResponse> =
        serde_json::from_str(&json).context("Failed to parse `gh pr list` JSON output")?;
    let pr = match matches.into_iter().next() {
        Some(pr) => pr,
        None => return Err(anyhow::anyhow!("No PR found for head branch '{}'", head_branch)),
    };

    let state = gh_state_to_pull_request_state(&pr.state);
    let merged = pr.merged_at.is_some();
    Ok(PullRequestRecord {
        number: pr.number,
        state,
        url: pr.url,
        base_ref: pr.base_ref_name,
        head_ref: pr.head_ref_name,
        head_sha: pr.head_ref_oid,
        merged,
    })
}

/// Maps a gh PR state string to [`PullRequestState`], ignoring ASCII case.
///
/// `closed` and `merged` become `Closed`; every other value, including
/// unknown ones, becomes `Open`.
fn gh_state_to_pull_request_state(state: &str) -> PullRequestState {
    let finished = ["closed", "merged"]
        .iter()
        .any(|terminal| state.eq_ignore_ascii_case(terminal));
    if finished {
        PullRequestState::Closed
    } else {
        PullRequestState::Open
    }
}

fn gh_output(repo_path: &Path, args: &[&str]) -> Result<String, Error> {
    let output = Command::new("gh")
        .args(args)
        .current_dir(repo_path)
        .output()
        .context("Failed to execute gh command")?;

    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        return Err(anyhow::anyhow!(
            "gh {:?} failed (code {:?}): {}",
            args,
            output.status.code(),
            stderr.trim()
        ));
    }

    String::from_utf8(output.stdout).context("Invalid UTF-8 gh output")
}