use std::time::Duration;
use std::{fmt, process::Command};
use anyhow::{Context, Result, anyhow, bail};
use crate::git;
use crate::settings;
/// Number of check polls allowed as a grace period.
///
/// NOTE(review): presumably consumed by the provider submodules while waiting
/// for a review's checks (e.g. before any check has been reported); confirm
/// the exact semantics at the use sites.
pub(super) const CHECK_GRACE_POLLS: u32 = 6;

/// Interval between polls of a review's checks: five seconds.
///
/// NOTE(review): presumably used by the `wait_for_checks` implementations;
/// confirm at the call sites.
pub(super) fn check_poll_interval() -> Duration {
    Duration::from_millis(5_000)
}
mod demo;
mod github;
mod gitlab;
mod json;
use demo::DemoProvider;
use github::GitHubProvider;
use gitlab::GitLabProvider;
/// A supported review-hosting backend.
///
/// Displays as its canonical lowercase name (`github`, `gitlab`, `demo`),
/// which [`ProviderKind::parse`] accepts back.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ProviderKind {
    GitHub,
    GitLab,
    Demo,
}

impl ProviderKind {
    /// Parses a provider name, ignoring ASCII case.
    ///
    /// Besides the canonical names, the short aliases `gh` (GitHub) and
    /// `glab` (GitLab) are accepted. The value is not trimmed. Returns `None`
    /// for anything else.
    fn parse(value: &str) -> Option<Self> {
        const NAMES: [(&str, ProviderKind); 5] = [
            ("github", ProviderKind::GitHub),
            ("gh", ProviderKind::GitHub),
            ("gitlab", ProviderKind::GitLab),
            ("glab", ProviderKind::GitLab),
            ("demo", ProviderKind::Demo),
        ];
        NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(value))
            .map(|&(_, kind)| kind)
    }
}

impl fmt::Display for ProviderKind {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `write_str` (like the plain `write!` it mirrors) ignores width/fill flags.
        formatter.write_str(match self {
            Self::GitHub => "github",
            Self::GitLab => "gitlab",
            Self::Demo => "demo",
        })
    }
}
/// The provider chosen by `detect_provider`, together with where the choice
/// came from.
#[derive(Debug, Eq, PartialEq)]
pub struct DetectedProvider {
    /// Which backend to use.
    pub kind: ProviderKind,
    /// Whether `kind` came from explicit git config or was inferred from a
    /// remote URL.
    pub source: ProviderSource,
}
/// How a provider was determined.
#[derive(Debug, Eq, PartialEq)]
pub enum ProviderSource {
    /// Taken from the explicit provider git-config setting.
    Config,
    /// Inferred from the URL of the named git remote.
    Remote { remote: String, url: String },
}

impl fmt::Display for ProviderSource {
    /// Renders `config`, or `remote <name> (<url>)`.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Remote { remote, url } => write!(formatter, "remote {remote} ({url})"),
            Self::Config => formatter.write_str("config"),
        }
    }
}
/// Lifecycle state of a [`ReviewRequest`].
///
/// Displays in lowercase (`open`, `merged`, `closed`).
#[derive(Debug, Eq, PartialEq)]
pub enum ReviewState {
    Open,
    Merged,
    Closed,
    /// Any other state, displayed verbatim.
    /// NOTE(review): presumably the raw state string a provider reported
    /// that matched none of the named variants — confirm in the providers.
    Unknown(String),
}
/// A review request as returned by a [`ReviewProvider`].
///
/// NOTE(review): presumably a GitHub pull request or GitLab merge request;
/// confirm the field sources in each provider implementation.
#[derive(Debug, Eq, PartialEq)]
pub struct ReviewRequest {
    /// Display identifier. It may start with a `#` or `!` marker, which
    /// `id_value` strips.
    pub id: String,
    /// Branch the review is for (presumably its head/source branch).
    pub branch: String,
    /// Branch the review presumably targets (see `update_review_base`).
    pub base: String,
    /// Current lifecycle state.
    pub state: ReviewState,
    /// Presumably the review's web URL — confirm per provider.
    pub url: String,
    /// Review title. It may be empty, in which case `label` shows only the id.
    pub title: String,
    /// Whether the review is a draft (see `ReviewProvider::mark_ready`).
    pub draft: bool,
}
/// Operations a review host must support. There is one implementation per
/// [`ProviderKind`] (see [`review_provider`]).
///
/// NOTE(review): the per-method descriptions below are inferred from names and
/// signatures. Confirm them against the `github`, `gitlab`, and `demo`
/// implementations. Methods returning `Result<String>` presumably return the
/// provider's command output (or a review URL) — verify per implementation.
pub trait ReviewProvider {
    /// Finds the review for `branch`, presumably among open reviews only.
    fn review_for_branch(&self, branch: &str) -> Result<Option<ReviewRequest>>;
    /// Like `review_for_branch`, but presumably also considers closed/merged reviews.
    fn review_for_branch_including_closed(&self, branch: &str) -> Result<Option<ReviewRequest>>;
    /// Opens a new review of `branch` against `base`, optionally as a draft.
    fn create_review(&self, branch: &str, base: &str, draft: bool) -> Result<String>;
    /// Retargets `review` onto `base`.
    fn update_review_base(&self, review: &ReviewRequest, base: &str) -> Result<String>;
    /// Fetches the current description of `review`.
    fn review_body(&self, review: &ReviewRequest) -> Result<String>;
    /// Replaces the description of `review` with `body`.
    fn update_review_body(&self, review: &ReviewRequest, body: &str) -> Result<String>;
    /// Merges `review` with a provider-specific `strategy` string (TODO confirm
    /// accepted values). `auto` presumably requests auto-merge once requirements pass.
    fn merge_review(&self, review: &ReviewRequest, strategy: &str, auto: bool) -> Result<String>;
    /// Waits on `review`'s checks. The `bool` presumably reports whether they passed.
    fn wait_for_checks(&self, review: &ReviewRequest) -> Result<bool>;
    /// Lists open reviews.
    fn open_reviews(&self) -> Result<Vec<ReviewRequest>>;
    /// Marks a draft `review` as ready for review.
    fn mark_ready(&self, review: &ReviewRequest) -> Result<String>;
    /// Closes `review`, optionally deleting its branch.
    fn close_review(&self, review: &ReviewRequest, delete_branch: bool) -> Result<String>;
    /// Opens `review` — presumably in a browser; confirm.
    fn open_review(&self, review: &ReviewRequest) -> Result<String>;
}
/// Works out which review provider this repository uses.
///
/// An explicit `settings::PROVIDER_KEY` git-config value wins, and it must name
/// a known provider. Otherwise the provider is inferred from the URL of the
/// remote returned by `settings::remote`. `settings::gitlab_host` is consulted
/// to recognise a self-hosted GitLab.
///
/// # Errors
///
/// Fails in any of these cases:
/// - the configured value is not a recognised provider;
/// - the remote does not exist;
/// - the remote URL matches no known provider;
/// - an underlying git or settings lookup fails.
pub fn detect_provider() -> Result<DetectedProvider> {
    if let Some(value) = git::config_get(settings::PROVIDER_KEY)? {
        return match ProviderKind::parse(&value) {
            Some(kind) => Ok(DetectedProvider {
                kind,
                source: ProviderSource::Config,
            }),
            None => {
                bail!("unsupported stk.provider value {value:?}; expected github, gitlab, or demo")
            }
        };
    }

    let remote = settings::remote()?;
    let url = match git::remote_url(&remote)? {
        Some(url) => url,
        None => bail!("could not detect provider: remote {remote:?} does not exist"),
    };
    // Only consulted once we know the remote exists, matching the original order.
    let gitlab_host = settings::gitlab_host()?;
    let detected = detect_provider_from_url(&url, gitlab_host.as_deref());
    match detected {
        Some(kind) => Ok(DetectedProvider {
            kind,
            source: ProviderSource::Remote { remote, url },
        }),
        None => bail!("could not detect provider from remote {remote} ({url})"),
    }
}
/// Maps a remote URL to the provider hosting it, if recognisable.
///
/// `github.com` and `gitlab.com`, and any subdomain of either, are always
/// recognised. Matching ignores ASCII case. `gitlab_host`, when set, names an
/// additional self-hosted GitLab instance. It may be a bare host or a full
/// URL, and surrounding whitespace is ignored. A blank `gitlab_host` matches
/// nothing, so an empty setting cannot claim host-less remotes such as local
/// paths.
fn detect_provider_from_url(url: &str, gitlab_host: Option<&str>) -> Option<ProviderKind> {
    let normalized = url.to_ascii_lowercase();
    let host = host_of(&normalized);
    // True when `host` is `domain` itself or a subdomain of it. Requiring the
    // `.` separator keeps e.g. `notgithub.com` from matching `github.com`.
    // Requiring a non-empty `domain` prevents `""` from matching host-less
    // remotes (`host == ""`) and hosts with a trailing dot.
    let is = |domain: &str| {
        !domain.is_empty()
            && (host == domain
                || host
                    .strip_suffix(domain)
                    .is_some_and(|subdomains| subdomains.ends_with('.')))
    };
    // Reduce the configured value to a host the same way as the remote, so
    // both "gitlab.example.com" and "https://gitlab.example.com" work.
    let gitlab_self_hosted = || {
        gitlab_host.is_some_and(|configured| {
            let configured = configured.trim().to_ascii_lowercase();
            is(host_of(&configured))
        })
    };
    if is("github.com") {
        Some(ProviderKind::GitHub)
    } else if is("gitlab.com") || gitlab_self_hosted() {
        Some(ProviderKind::GitLab)
    } else {
        None
    }
}
/// Extracts the host component from a git remote URL.
///
/// Handles `scheme://[user@]host[:port]/path`, scp-style `[user@]host:path`,
/// and bare hosts.
/// - Userinfo is cut at the *last* `@`, so usernames containing `@` work.
/// - Anything after the first `:` (a port or an scp path) is dropped.
/// - A bracketed IPv6 literal is returned without its brackets. An
///   unterminated bracket is returned as-is.
///
/// Inputs with no host, such as absolute local paths, yield `""`.
fn host_of(url: &str) -> &str {
    let rest = match url.split_once("://") {
        Some((_scheme, tail)) => tail,
        None => url,
    };
    // The authority ends at the first path separator.
    let authority = match rest.find('/') {
        Some(slash) => &rest[..slash],
        None => rest,
    };
    let host_and_port = match authority.rfind('@') {
        Some(at) => &authority[at + 1..],
        None => authority,
    };
    if let Some(inner) = host_and_port.strip_prefix('[') {
        return match inner.find(']') {
            Some(close) => &inner[..close],
            None => host_and_port,
        };
    }
    match host_and_port.find(':') {
        Some(colon) => &host_and_port[..colon],
        None => host_and_port,
    }
}
/// Builds the [`ReviewProvider`] implementation that serves `kind`.
pub(crate) fn review_provider(kind: ProviderKind) -> Box<dyn ReviewProvider> {
    match kind {
        ProviderKind::Demo => Box::new(DemoProvider),
        ProviderKind::GitLab => Box::new(GitLabProvider),
        ProviderKind::GitHub => Box::new(GitHubProvider),
    }
}
/// Runs `program` with `args` and returns its standard output, trimmed.
///
/// Output bytes are decoded as UTF-8 lossily. `Command::output` does not
/// inherit stdin, so the child cannot block waiting for input.
///
/// # Errors
///
/// - Fails when the process cannot be started.
/// - Fails when the process exits unsuccessfully. The error carries the
///   trimmed stderr when it is non-empty, and the exit status otherwise.
fn command_output(program: &str, args: &[&str]) -> Result<String> {
    let output = Command::new(program)
        .args(args)
        .output()
        .with_context(|| format!("failed to run {program}"))?;

    if !output.status.success() {
        let raw_stderr = String::from_utf8_lossy(&output.stderr);
        let stderr = raw_stderr.trim();
        return Err(if stderr.is_empty() {
            anyhow!("{program} exited with status {}", output.status)
        } else {
            anyhow!("{program} failed: {stderr}")
        });
    }

    let stdout = String::from_utf8_lossy(&output.stdout);
    Ok(stdout.trim().to_owned())
}
/// Total number of merge attempts made by [`merge_with_retry`], including the first.
const MERGE_ATTEMPTS: u32 = 3;
/// Pause between a transient merge failure and the next attempt.
const MERGE_RETRY_BACKOFF: Duration = Duration::from_millis(1500);

/// Reports whether `error` looks like a race that retrying the merge may resolve.
///
/// The recognised texts include the one in "Base branch was modified. Review and
/// try the merge again." Matching is case-insensitive. Every error in the
/// anyhow chain is inspected, not just the outermost message
/// (`to_string()` renders only that). A provider that wraps the command
/// failure in extra `.context(..)` therefore still gets its merge retried.
fn is_transient_merge_error(error: &anyhow::Error) -> bool {
    const SIGNATURES: [&str; 3] = [
        "base branch was modified",
        "head branch was modified",
        "try the merge again",
    ];
    error.chain().any(|cause| {
        let text = cause.to_string().to_lowercase();
        SIGNATURES.iter().any(|signature| text.contains(signature))
    })
}
fn merge_with_retry(attempt: impl FnMut() -> Result<String>) -> Result<String> {
retry_transient_merge(MERGE_ATTEMPTS, MERGE_RETRY_BACKOFF, attempt)
}
fn retry_transient_merge(
attempts: u32,
backoff: Duration,
mut attempt: impl FnMut() -> Result<String>,
) -> Result<String> {
for remaining in (0..attempts).rev() {
match attempt() {
Ok(output) => return Ok(output),
Err(error) if remaining > 0 && is_transient_merge_error(&error) => {
std::thread::sleep(backoff);
}
Err(error) => return Err(error),
}
}
Err(anyhow!("merge retried with no attempts left"))
}
impl fmt::Display for ReviewState {
    /// Lowercase name for known states; `Unknown` prints its text verbatim.
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        let text = match self {
            Self::Open => "open",
            Self::Merged => "merged",
            Self::Closed => "closed",
            Self::Unknown(raw) => raw.as_str(),
        };
        formatter.write_str(text)
    }
}
impl ReviewRequest {
    /// The identifier with at most one leading `#` or `!` marker removed.
    pub(crate) fn id_value(&self) -> &str {
        self.id.strip_prefix(['#', '!']).unwrap_or(&self.id)
    }

    /// Human-readable label: `"<title> (<id>)"`, or just the id when untitled.
    pub fn label(&self) -> String {
        label(&self.title, &self.id)
    }
}
/// Formats a review label as `"<title> (<id>)"`, falling back to just `id`
/// when `title` is empty. A whitespace-only title counts as non-empty.
pub(crate) fn label(title: &str, id: &str) -> String {
    match title {
        "" => id.to_string(),
        _ => format!("{title} ({id})"),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- merge retry policy ---

    #[test]
    fn transient_error_is_retried_then_succeeds() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            if calls < 2 {
                Err(anyhow!(
                    "gh failed: GraphQL: Base branch was modified. Review and try the merge again."
                ))
            } else {
                Ok("merged".to_owned())
            }
        });
        assert_eq!(result.unwrap(), "merged");
        assert_eq!(calls, 2, "should retry once then succeed");
    }

    #[test]
    fn a_persistent_transient_error_gives_up_after_the_attempt_budget() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            Err(anyhow!("gh failed: Base branch was modified"))
        });
        assert!(result.is_err());
        assert_eq!(calls, 3, "should try exactly the budgeted number of times");
    }

    #[test]
    fn a_real_failure_is_not_retried() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            Err(anyhow!(
                "gh failed: Pull request is not mergeable: conflicts"
            ))
        });
        assert!(result.is_err());
        assert_eq!(calls, 1, "a non-transient error must surface immediately");
    }

    #[test]
    fn an_empty_attempt_budget_never_calls_the_merge() {
        let mut calls = 0;
        let result = retry_transient_merge(0, Duration::ZERO, || {
            calls += 1;
            Ok("merged".to_owned())
        });
        assert!(result.is_err());
        assert_eq!(calls, 0, "a zero budget must not invoke the closure");
    }

    #[test]
    fn transient_signatures_match_case_insensitively() {
        assert!(is_transient_merge_error(&anyhow!("HEAD BRANCH WAS MODIFIED")));
        assert!(!is_transient_merge_error(&anyhow!("merge conflict in src/lib.rs")));
    }

    // --- provider detection ---

    #[test]
    fn host_of_extracts_the_host_across_url_shapes() {
        assert_eq!(host_of("https://github.com/owner/repo.git"), "github.com");
        assert_eq!(host_of("git@github.com:owner/repo.git"), "github.com");
        assert_eq!(
            host_of("ssh://git@gitlab.example.com:22/g/r"),
            "gitlab.example.com"
        );
        assert_eq!(host_of("https://user@github.com/owner/repo"), "github.com");
        assert_eq!(host_of("https://github.com:8443/owner/repo"), "github.com");
        assert_eq!(
            host_of("https://[2001:db8::1]:443/owner/repo"),
            "2001:db8::1"
        );
        assert_eq!(host_of("gitlab.example.com"), "gitlab.example.com");
        assert_eq!(host_of("https://user@name@github.com/r"), "github.com");
    }

    #[test]
    fn self_hosted_gitlab_accepts_a_bare_host_or_a_full_url() {
        let remote = "git@gitlab.example.com:team/repo.git";
        for configured in ["gitlab.example.com", "https://gitlab.example.com"] {
            assert_eq!(
                detect_provider_from_url(remote, Some(configured)),
                Some(ProviderKind::GitLab),
                "configured {configured:?} should detect the self-hosted host"
            );
        }
        assert_eq!(
            detect_provider_from_url("git@notgitlab.com:o/r", Some("gitlab.example.com")),
            None
        );
    }

    #[test]
    fn public_hosts_and_their_subdomains_are_detected() {
        assert_eq!(
            detect_provider_from_url("https://github.com/o/r.git", None),
            Some(ProviderKind::GitHub)
        );
        assert_eq!(
            detect_provider_from_url("https://api.github.com/o/r", None),
            Some(ProviderKind::GitHub)
        );
        assert_eq!(
            detect_provider_from_url("git@GitLab.com:o/r.git", None),
            Some(ProviderKind::GitLab)
        );
        assert_eq!(detect_provider_from_url("https://notgithub.com/o/r", None), None);
        assert_eq!(detect_provider_from_url("/srv/git/repo.git", None), None);
    }

    #[test]
    fn provider_names_parse_case_insensitively_with_aliases() {
        for (value, expected) in [
            ("github", ProviderKind::GitHub),
            ("GH", ProviderKind::GitHub),
            ("GitLab", ProviderKind::GitLab),
            ("glab", ProviderKind::GitLab),
            ("Demo", ProviderKind::Demo),
        ] {
            assert_eq!(ProviderKind::parse(value), Some(expected), "value {value:?}");
        }
        assert_eq!(ProviderKind::parse("bitbucket"), None);
        // Display output must round-trip through parse.
        for kind in [ProviderKind::GitHub, ProviderKind::GitLab, ProviderKind::Demo] {
            assert_eq!(ProviderKind::parse(&kind.to_string()), Some(kind));
        }
    }

    // --- review helpers ---

    #[test]
    fn id_value_strips_a_single_provider_prefix() {
        let review = |id: &str| ReviewRequest {
            id: id.to_owned(),
            branch: "feature".to_owned(),
            base: "main".to_owned(),
            state: ReviewState::Open,
            url: String::new(),
            title: String::new(),
            draft: false,
        };
        assert_eq!(review("#12").id_value(), "12");
        assert_eq!(review("!34").id_value(), "34");
        assert_eq!(review("56").id_value(), "56");
        assert_eq!(review("#!7").id_value(), "!7");
    }

    #[test]
    fn labels_fall_back_to_the_id_without_a_title() {
        assert_eq!(label("", "#3"), "#3");
        assert_eq!(label("Add retries", "!9"), "Add retries (!9)");
    }

    #[test]
    fn review_states_display_lowercase_or_verbatim() {
        assert_eq!(ReviewState::Open.to_string(), "open");
        assert_eq!(ReviewState::Merged.to_string(), "merged");
        assert_eq!(ReviewState::Closed.to_string(), "closed");
        assert_eq!(ReviewState::Unknown("LOCKED".to_owned()).to_string(), "LOCKED");
    }
}