use crate::cli::args::PrCommands;
use crate::cli::UI;
use crate::core::{Config, ScanEngine, Severity};
use crate::platform::{self, types::CreatePR};
use anyhow::{bail, Result};
use std::path::PathBuf;
/// Entry point for `securegit pr …`: routes each subcommand to its handler.
///
/// `Create` arguments are repackaged into a [`CreatePrParams`] (the CLI's
/// repeated `label` flag becomes `labels`); `List` and `View` forward their
/// single argument directly.
pub async fn execute(action: PrCommands, ui: &UI) -> Result<()> {
    match action {
        PrCommands::List { state } => list(&state, ui).await,
        PrCommands::View { number } => view(number, ui).await,
        PrCommands::Create {
            title,
            body,
            head,
            base,
            draft,
            no_scan,
            fail_on,
            label: labels,
        } => {
            let params = CreatePrParams {
                title,
                body,
                head,
                base,
                draft,
                no_scan,
                fail_on,
                labels,
            };
            create(params, ui).await
        }
    }
}
/// Options for `pr create`, gathered from the command-line arguments.
struct CreatePrParams {
    /// Source branch; `None` means use the currently checked-out branch.
    head: Option<String>,
    /// Target branch; `None` means `main`.
    base: Option<String>,
    /// PR title; `None` means the user is prompted for one.
    title: Option<String>,
    /// PR description; `None` becomes an empty body.
    body: Option<String>,
    /// Extra labels to attach to the created PR.
    labels: Vec<String>,
    /// Open the PR as a draft.
    draft: bool,
    /// Skip the pre-PR security scan entirely.
    no_scan: bool,
    /// Minimum severity (as text) at which scan findings require confirmation.
    fail_on: String,
}
/// Create a pull request for the repository in the current directory,
/// optionally gated by a security scan.
///
/// Steps:
/// 1. Detect the remote and resolve its token.
/// 2. Pick the head branch (`params.head`, or the checked-out branch) and the
///    base branch (`params.base`, or `main`).
/// 3. Unless `params.no_scan` is set, run the security gate.
/// 4. Prompt for a title if none was given.
/// 5. Try to push the head branch and create the PR.
/// 6. Apply `params.labels` plus any label added by the security gate.
///
/// # Errors
/// Returns an error if:
/// - the remote cannot be detected or no token is found;
/// - no `head` was given and the current branch cannot be determined
///   (detached HEAD or no commits yet);
/// - the scan or a prompt fails;
/// - the user declines to proceed past blocking findings;
/// - the platform rejects the PR.
///
/// A failed auto-push or label update is reported as a warning and does not
/// abort.
async fn create(params: CreatePrParams, ui: &UI) -> Result<()> {
    let path = PathBuf::from(".");
    let remote = platform::detect_remote(&path)?;
    let token = platform::resolve_token(&remote.host).ok_or_else(|| {
        anyhow::anyhow!(
            "Not authenticated for {}. Run: securegit auth login --provider {}",
            remote.host,
            remote.host.to_string().to_lowercase()
        )
    })?;
    let client = platform::create_client(&remote, token);
    let repo = crate::ops::open_repo(&path)?;
    let head_branch = match params.head {
        Some(head) => head,
        // A detached HEAD or unborn branch has no branch name a PR could use;
        // fail early instead of sending the literal "HEAD" to push and the API.
        None => repo
            .head()
            .ok()
            .filter(|h| h.is_branch())
            .and_then(|h| h.shorthand().map(|s| s.to_string()))
            .ok_or_else(|| {
                anyhow::anyhow!(
                    "Could not determine the current branch (detached HEAD or no commits yet); \
                     pass --head <branch>"
                )
            })?,
    };
    let base_branch = params.base.unwrap_or_else(|| "main".to_string());

    ui.header("SecureGit Pull Request");
    ui.blank();
    ui.field("Head", &head_branch);
    ui.field("Base", &base_branch);
    ui.field(
        "Provider",
        format!("{} ({}/{})", remote.host, remote.owner, remote.repo),
    );
    ui.blank();

    let security_labels = if params.no_scan {
        Vec::new()
    } else {
        run_security_gate(&path, &params.fail_on, ui).await?
    };

    let pr_title = match params.title {
        Some(t) => t,
        None => dialoguer::Input::<String>::new()
            .with_prompt("PR title")
            .interact_text()?,
    };
    let pr_body = params.body.unwrap_or_default();

    let spinner = ui.spinner("Ensuring branch is pushed...");
    if let Err(e) = ensure_branch_pushed(&path, &head_branch) {
        ui.finish_progress(&spinner, "");
        ui.warning(format!("Could not auto-push branch: {}", e));
        ui.info("Make sure your branch is pushed before creating a PR");
    } else {
        ui.finish_progress(&spinner, "Branch up to date");
    }

    let spinner = ui.spinner("Creating pull request...");
    let pr = client
        .create_pull_request(&CreatePR {
            title: pr_title,
            body: pr_body,
            head: head_branch,
            base: base_branch,
            draft: params.draft,
        })
        .await?;
    ui.finish_progress(&spinner, "Pull request created");

    let mut all_labels = params.labels;
    all_labels.extend(security_labels);
    if !all_labels.is_empty() {
        // The PR already exists, so a labeling failure must not abort. But
        // silently dropping `security-review-needed` would hide that the
        // security gate was overridden, so tell the user.
        if let Err(e) = client.add_labels(pr.number, &all_labels).await {
            ui.warning(format!(
                "Could not apply label(s) {} to #{}: {}",
                all_labels.join(", "),
                pr.number,
                e
            ));
        }
    }

    ui.blank();
    ui.success("Pull request created");
    ui.blank();
    ui.field("Title", &pr.title);
    ui.field("PR", format!("#{}", pr.number));
    ui.field("URL", &pr.html_url);
    if pr.draft {
        ui.field("Status", "Draft");
    }
    ui.blank();
    if ui.json {
        ui.json_out(&serde_json::json!({
            "number": pr.number,
            "title": pr.title,
            "url": pr.html_url,
            "state": pr.state,
            "draft": pr.draft,
        }));
    }
    Ok(())
}

/// Scan the working tree at `path` and decide whether PR creation may proceed.
///
/// `fail_on` is parsed with `Severity::parse_str`; if it is not accepted, the
/// threshold falls back to `Severity::High`.
///
/// Returns the extra labels to attach:
/// - empty when no finding is at or above the threshold;
/// - `["security-review-needed"]` when such findings exist and the user
///   confirms.
///
/// # Errors
/// Fails if the scan fails, the confirmation prompt fails, or the user
/// declines to proceed.
async fn run_security_gate(path: &PathBuf, fail_on: &str, ui: &UI) -> Result<Vec<String>> {
    let threshold = Severity::parse_str(fail_on).unwrap_or(Severity::High);

    let spinner = ui.spinner("Running security scan...");
    let engine = ScanEngine::new(Config::default());
    let scanned = engine.scan_directory(path).await;
    // Stop the spinner on both outcomes so an error isn't printed under a live spinner.
    ui.finish_progress(
        &spinner,
        if scanned.is_ok() {
            "Scan complete"
        } else {
            "Scan failed"
        },
    );
    let report = scanned?;

    let critical = report.count_by_severity(Severity::Critical);
    let high = report.count_by_severity(Severity::High);
    let medium = report.count_by_severity(Severity::Medium);
    let low = report.count_by_severity(Severity::Low);
    let info_count = report.count_by_severity(Severity::Info);
    ui.blank();
    ui.section("Scan Results");
    ui.severity_row(critical, high, medium, low, info_count);

    if !report.has_findings_at_or_above(threshold) {
        ui.blank();
        ui.success(format!("Security gate passed (threshold: {})", threshold));
        return Ok(Vec::new());
    }

    let blocking: Vec<_> = report
        .findings
        .iter()
        .filter(|f| f.severity >= threshold)
        .collect();
    ui.blank();
    for &finding in &blocking {
        ui.finding(finding);
    }
    ui.blank();
    ui.warning(format!(
        "Security scan found {} finding(s) at or above {} severity",
        blocking.len(),
        threshold
    ));
    let proceed = dialoguer::Confirm::new()
        .with_prompt("Create PR anyway with 'security-review-needed' label?")
        .default(false)
        .interact()?;
    if !proceed {
        bail!("PR creation aborted due to security findings");
    }
    Ok(vec!["security-review-needed".to_string()])
}
/// List pull requests (called merge requests on GitLab) in the given `state`.
///
/// `state` is forwarded unchanged to the platform client's
/// `list_pull_requests`. In JSON mode the fetched list is also emitted as a
/// JSON value.
///
/// # Errors
/// Fails if:
/// - the remote cannot be detected;
/// - no token is available;
/// - the API call fails;
/// - the result cannot be serialized to JSON.
async fn list(state: &str, ui: &UI) -> Result<()> {
    let path = PathBuf::from(".");
    let remote = platform::detect_remote(&path)?;
    // Same actionable message as `pr create`, including the provider to log in to.
    let token = platform::resolve_token(&remote.host).ok_or_else(|| {
        anyhow::anyhow!(
            "Not authenticated for {}. Run: securegit auth login --provider {}",
            remote.host,
            remote.host.to_string().to_lowercase()
        )
    })?;
    let client = platform::create_client(&remote, token);
    let pr_label = if remote.host == platform::PlatformHost::GitLab {
        "Merge Requests"
    } else {
        "Pull Requests"
    };

    ui.header(&format!("SecureGit {}", pr_label));
    ui.blank();
    ui.field("Repository", format!("{}/{}", remote.owner, remote.repo));
    ui.field("State", state);
    ui.blank();

    let spinner = ui.spinner(&format!("Fetching {}...", pr_label.to_lowercase()));
    let prs = client.list_pull_requests(state).await?;
    ui.finish_progress(&spinner, "");

    if prs.is_empty() {
        ui.info(format!("No {} found", pr_label.to_lowercase()));
    } else {
        for pr in &prs {
            let draft_marker = if pr.draft { " [draft]" } else { "" };
            ui.list_item(format!(
                "#{:<5} {}{} ({} -> {})",
                pr.number, pr.title, draft_marker, pr.head_ref, pr.base_ref
            ));
        }
    }
    if ui.json {
        // Propagate serialization failures instead of silently printing `null`.
        ui.json_out(&serde_json::to_value(&prs)?);
    }
    ui.blank();
    Ok(())
}
/// Show the details of pull request `number` (a merge request on GitLab).
///
/// Prints state, author, branches, draft flag, and URL. In JSON mode it also
/// emits the fetched PR as a JSON value.
///
/// # Errors
/// Fails if:
/// - the remote cannot be detected;
/// - no token is available;
/// - the API call fails;
/// - the PR cannot be serialized to JSON.
async fn view(number: u64, ui: &UI) -> Result<()> {
    let path = PathBuf::from(".");
    let remote = platform::detect_remote(&path)?;
    // Same actionable message as `pr create`, including the provider to log in to.
    let token = platform::resolve_token(&remote.host).ok_or_else(|| {
        anyhow::anyhow!(
            "Not authenticated for {}. Run: securegit auth login --provider {}",
            remote.host,
            remote.host.to_string().to_lowercase()
        )
    })?;
    let client = platform::create_client(&remote, token);
    // Use the same terminology as `pr list` for GitLab.
    let noun = if remote.host == platform::PlatformHost::GitLab {
        "merge request"
    } else {
        "pull request"
    };

    let spinner = ui.spinner(&format!("Fetching {}...", noun));
    let pr = client.get_pull_request(number).await?;
    ui.finish_progress(&spinner, "");

    ui.header(&format!("PR #{}: {}", pr.number, pr.title));
    ui.blank();
    ui.field("State", &pr.state);
    ui.field("Author", &pr.user);
    ui.field("Head", &pr.head_ref);
    ui.field("Base", &pr.base_ref);
    ui.field("Draft", pr.draft);
    ui.field("URL", &pr.html_url);
    ui.blank();
    if ui.json {
        // Propagate serialization failures instead of silently printing `null`.
        ui.json_out(&serde_json::to_value(&pr)?);
    }
    Ok(())
}
fn ensure_branch_pushed(path: &PathBuf, branch: &str) -> Result<()> {
let repo = crate::ops::open_repo(path)?;
let remote_ref = format!("refs/remotes/origin/{}", branch);
if repo.find_reference(&remote_ref).is_ok() {
return Ok(());
}
let mut remote = repo.find_remote("origin").or_else(|_| {
let remotes = repo.remotes()?;
let name = remotes
.get(0)
.ok_or_else(|| git2::Error::from_str("no remotes"))?;
repo.find_remote(name)
})?;
let callbacks = crate::auth::build_git2_callbacks(None, None, None);
let mut push_opts = git2::PushOptions::new();
push_opts.remote_callbacks(callbacks);
let refspec = format!("refs/heads/{}:refs/heads/{}", branch, branch);
remote.push(&[&refspec], Some(&mut push_opts))?;
Ok(())
}