mod cache;
mod github;
mod gitlab;
mod platform;
use anstyle::{AnsiColor, Color, Style};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::DeserializeOwned};
use worktrunk::git::Repository;
use worktrunk::shell_exec::Cmd;
use worktrunk::utils::epoch_now;
/// A branch name as used for CI lookups, optionally split into a remote part
/// and a branch part.
#[derive(Debug, Clone)]
pub struct CiBranchName {
    /// The ref exactly as passed to [`CiBranchName::from_branch_ref`]
    /// (e.g. `origin/feature` for a remote-tracking ref).
    pub full_name: String,
    /// Remote prefix (`origin` in `origin/feature`); `None` for local branches.
    pub remote: Option<String>,
    /// Branch name with any remote prefix stripped.
    pub name: String,
}

impl CiBranchName {
    /// Build a `CiBranchName` from a ref name.
    ///
    /// When `is_remote` is set, the text before the first `/` becomes the
    /// remote and the remainder becomes the branch name. A local ref, or a
    /// remote ref that contains no `/`, keeps the whole string as `name` and
    /// has no remote.
    pub fn from_branch_ref(branch: &str, is_remote: bool) -> Self {
        let remote_split = is_remote.then(|| branch.split_once('/')).flatten();
        let (remote, name) = match remote_split {
            Some((remote, rest)) => (Some(remote.to_owned()), rest.to_owned()),
            None => (None, branch.to_owned()),
        };
        Self {
            full_name: branch.to_owned(),
            remote,
            name,
        }
    }

    /// Whether this name carries a remote prefix.
    pub fn is_remote(&self) -> bool {
        self.remote.is_some()
    }

    /// Whether this branch has somewhere to compare against remotely.
    ///
    /// Remote refs always count. Local branches count only when an upstream is
    /// configured; a failed upstream lookup is treated as "no upstream".
    pub fn has_upstream(&self, repo: &Repository) -> bool {
        if self.is_remote() {
            return true;
        }
        matches!(repo.branch(&self.name).upstream(), Ok(Some(_)))
    }
}
pub(crate) use cache::CachedCiStatus;
pub use platform::{CiPlatform, platform_for_repo};
/// Upper bound on the number of PRs/MRs requested when looking up CI status.
///
/// NOTE(review): presumably used as a result limit by the GitHub/GitLab
/// platform modules — confirm at the call sites.
const MAX_PRS_TO_FETCH: u8 = 20;
/// Build a [`Cmd`] for `program` whose output is safe to parse.
///
/// Variables that would force colored/TTY-style output are removed from the
/// child environment, color is explicitly disabled, and `gh` prompts are
/// turned off so the command never waits for user input.
fn non_interactive_cmd(program: &str) -> Cmd {
    // First strip the "force" switches inherited from the parent environment...
    let stripped = Cmd::new(program)
        .env_remove("CLICOLOR_FORCE")
        .env_remove("GH_FORCE_TTY");
    // ...then pin the opt-outs explicitly.
    stripped
        .env("NO_COLOR", "1")
        .env("CLICOLOR", "0")
        .env("GH_PROMPT_DISABLED", "1")
}
/// Run `tool` with `args` and report whether it exited successfully.
///
/// A spawn failure (e.g. the binary is not installed) counts as unavailable.
fn tool_available(tool: &str, args: &[&str]) -> bool {
    match Cmd::new(tool).args(args.iter().copied()).run() {
        Ok(output) => output.status.success(),
        Err(_) => false,
    }
}
/// Deserialize a CLI's JSON `stdout` into `T`.
///
/// `command` and `branch` are used only for the warning logged on a parse
/// failure, in which case `None` is returned.
fn parse_json<T: DeserializeOwned>(stdout: &[u8], command: &str, branch: &str) -> Option<T> {
    match serde_json::from_slice::<T>(stdout) {
        Ok(parsed) => Some(parsed),
        Err(err) => {
            log::warn!("Failed to parse {} JSON for {}: {}", command, branch, err);
            None
        }
    }
}
/// Heuristically decide whether a failed forge CLI call is worth retrying.
///
/// Matching is ASCII case-insensitive. Transient failures are recognized by
/// phrases (rate limiting, timeouts, connection/network trouble) and by the
/// HTTP status codes 403 and 429, which GitHub uses for rate limiting.
///
/// The status codes must appear as standalone numbers (not adjacent to an
/// ASCII letter or digit). This prevents digit runs inside a commit SHA or PR
/// number (e.g. `a4039fe`, `#14290`) from marking a permanent error as
/// retriable.
fn is_retriable_error(stderr: &str) -> bool {
    const PHRASES: &[&str] = &["rate limit", "api rate", "timeout", "connection", "network"];
    const STATUS_CODES: &[&str] = &["403", "429"];

    /// True if `token` occurs in `haystack` with no ASCII alphanumeric byte
    /// immediately before or after it.
    fn contains_standalone(haystack: &str, token: &str) -> bool {
        let bytes = haystack.as_bytes();
        let is_word_byte =
            |idx: Option<usize>| idx.and_then(|i| bytes.get(i)).is_some_and(u8::is_ascii_alphanumeric);
        haystack.match_indices(token).any(|(start, _)| {
            !is_word_byte(start.checked_sub(1)) && !is_word_byte(Some(start + token.len()))
        })
    }

    let lower = stderr.to_ascii_lowercase();
    PHRASES.iter().any(|phrase| lower.contains(phrase))
        || STATUS_CODES
            .iter()
            .any(|code| contains_standalone(&lower, code))
}
/// Availability of the forge CLIs: `gh` (GitHub) and `glab` (GitLab).
#[derive(Debug, Clone, Copy)]
pub struct CiToolsStatus {
    /// `gh --version` exited successfully.
    pub gh_installed: bool,
    /// `gh auth status` exited successfully (only probed when `gh` is installed).
    pub gh_authenticated: bool,
    /// `glab --version` exited successfully.
    pub glab_installed: bool,
    /// `glab auth status` exited successfully (only probed when `glab` is installed).
    pub glab_authenticated: bool,
}

impl CiToolsStatus {
    /// Probe both CLIs by running them.
    ///
    /// When `gitlab_host` is given, the `glab` auth probe is scoped to that
    /// host via `--hostname`. Auth is never probed for a missing binary.
    pub fn detect(gitlab_host: Option<&str>) -> Self {
        let (gh_installed, gh_authenticated) = {
            let installed = tool_available("gh", &["--version"]);
            (installed, installed && tool_available("gh", &["auth", "status"]))
        };

        let (glab_installed, glab_authenticated) = {
            let installed = tool_available("glab", &["--version"]);
            let authenticated = installed
                && match gitlab_host {
                    Some(host) => tool_available("glab", &["auth", "status", "--hostname", host]),
                    None => tool_available("glab", &["auth", "status"]),
                };
            (installed, authenticated)
        };

        Self {
            gh_installed,
            gh_authenticated,
            glab_installed,
            glab_authenticated,
        }
    }
}
/// Overall CI state for a branch or PR.
///
/// Both the serde representation and the strum `IntoStaticStr` conversion
/// use kebab-case names (`passed`, `running`, `failed`, `conflicts`,
/// `no-ci`, `error`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, strum::IntoStaticStr)]
#[serde(rename_all = "kebab-case")]
#[strum(serialize_all = "kebab-case")]
pub enum CiStatus {
    Passed,
    Running,
    Failed,
    Conflicts,
    /// serde's `rename_all = "kebab-case"` splits at every capital letter and
    /// would emit `"no-c-i"`, disagreeing with strum's `"no-ci"`. Pin the
    /// serde name to `"no-ci"`, and keep `"no-c-i"` as an alias so values
    /// serialized by older versions still deserialize.
    #[serde(rename = "no-ci", alias = "no-c-i")]
    NoCI,
    Error,
}
/// Whether a CI status is associated with a pull request or with the branch
/// itself.
///
/// Serde uses the short names `"pr"` / `"branch"` (and also accepts
/// `"pull-request"` when deserializing). The strum `IntoStaticStr`
/// conversion instead yields kebab-case `"pull-request"` / `"branch"`.
#[derive(
    Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, strum::IntoStaticStr, JsonSchema,
)]
#[strum(serialize_all = "kebab-case")]
pub enum CiSource {
    /// Status associated with a pull request.
    #[serde(rename = "pr", alias = "pull-request")]
    PullRequest,
    /// Status associated with the branch directly (also the source used by
    /// the generic error status).
    #[serde(rename = "branch")]
    Branch,
}
/// CI status for a branch, plus the metadata needed to render and serialize it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PrStatus {
    /// Overall CI state; selects the indicator color and glyph.
    pub ci_status: CiStatus,
    /// Whether the status came from a pull request or from the branch.
    pub source: CiSource,
    /// When set, the indicator is rendered dimmed.
    /// NOTE(review): presumably means the status is for a commit other than
    /// the local HEAD — confirm in the platform modules.
    pub is_stale: bool,
    /// Hyperlink target for the indicator; omitted from serialized output
    /// when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub url: Option<String>,
}
impl CiStatus {
pub fn color(&self) -> AnsiColor {
match self {
Self::Passed => AnsiColor::Green,
Self::Running => AnsiColor::Blue,
Self::Failed => AnsiColor::Red,
Self::Conflicts | Self::Error => AnsiColor::Yellow,
Self::NoCI => AnsiColor::BrightBlack,
}
}
}
impl PrStatus {
    /// Terminal style for this status: the foreground color from
    /// [`CiStatus::color`], dimmed when `is_stale` is set.
    pub fn style(&self) -> Style {
        let style = Style::new().fg_color(Some(Color::Ansi(self.ci_status.color())));
        if self.is_stale { style.dimmed() } else { style }
    }

    /// Glyph for this status: `⚠` for [`CiStatus::Error`], `●` otherwise.
    pub fn indicator(&self) -> &'static str {
        if matches!(self.ci_status, CiStatus::Error) {
            "⚠"
        } else {
            "●"
        }
    }

    /// Render the styled indicator glyph.
    ///
    /// When `include_link` is true and `url` is present, the glyph is also
    /// underlined and wrapped in an OSC 8 terminal hyperlink to that URL.
    /// Otherwise only the color style (and its reset) is applied.
    pub fn format_indicator(&self, include_link: bool) -> String {
        let indicator = self.indicator();
        if let (true, Some(url)) = (include_link, &self.url) {
            let style = self.style().underline();
            format!(
                "{}{}{}{}{}",
                style,
                osc8::Hyperlink::new(url),
                indicator,
                osc8::Hyperlink::END,
                style.render_reset()
            )
        } else {
            let style = self.style();
            format!("{style}{indicator}{style:#}")
        }
    }

    /// Generic error status: branch-sourced, not stale, no link.
    fn error() -> Self {
        Self {
            ci_status: CiStatus::Error,
            source: CiSource::Branch,
            is_stale: false,
            url: None,
        }
    }

    /// Look up the CI status for `branch`, consulting the CI status cache first.
    ///
    /// A cached entry is returned when `CachedCiStatus::is_valid` accepts it
    /// for `local_head` and the current time. Otherwise the detected platform
    /// is queried, and the result (including `None`) is written back to the
    /// cache under `branch.full_name`.
    ///
    /// Returns `None` if the current worktree root cannot be resolved, if no
    /// CI platform can be detected, or if the platform lookup yields `None`.
    pub fn detect(repo: &Repository, branch: &CiBranchName, local_head: &str) -> Option<Self> {
        let repo_path = repo.current_worktree().root().ok()?;
        let now_secs = epoch_now();

        if let Some(cached) = CachedCiStatus::read(repo, &branch.full_name) {
            // Saturating: an entry stamped by a clock running ahead of ours must
            // not underflow (a panic in debug builds) when computing its age.
            let age_secs = now_secs.saturating_sub(cached.checked_at);
            if cached.is_valid(local_head, now_secs, &repo_path) {
                log::debug!(
                    "Using cached CI status for {} (age={}s, ttl={}s, status={:?})",
                    branch.full_name,
                    age_secs,
                    CachedCiStatus::ttl_for_repo(&repo_path),
                    cached.status.as_ref().map(|s| &s.ci_status)
                );
                return cached.status;
            }
            log::debug!(
                "Cache expired for {} (age={}s, ttl={}s, head_match={})",
                branch.full_name,
                age_secs,
                CachedCiStatus::ttl_for_repo(&repo_path),
                cached.head == local_head
            );
        }

        // Only needed on a cache miss, and it queries the repository, so it is
        // resolved here rather than on every call.
        let has_upstream = branch.has_upstream(repo);
        let status = Self::detect_uncached(repo, branch, local_head, has_upstream);
        let cached = CachedCiStatus {
            status: status.clone(),
            checked_at: now_secs,
            head: local_head.to_string(),
            branch: branch.full_name.clone(),
        };
        cached.write(repo, &branch.full_name);
        status
    }

    /// Query the forge for `branch` without touching the cache.
    ///
    /// Returns `None` (with a debug log) when no CI platform can be detected
    /// for the branch's remote; otherwise delegates to the platform's `detect_ci`.
    fn detect_uncached(
        repo: &Repository,
        branch: &CiBranchName,
        local_head: &str,
        has_upstream: bool,
    ) -> Option<Self> {
        let platform = platform_for_repo(repo, branch.remote.as_deref());
        match platform {
            Some(p) => p.detect_ci(repo, branch, local_head, has_upstream),
            None => {
                log::debug!(
                    "Could not detect CI platform from remote URL; \
                     set forge.platform in .config/wt.toml for CI status"
                );
                None
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_is_retriable_error() {
        assert!(is_retriable_error("API rate limit exceeded"));
        assert!(is_retriable_error("rate limit exceeded for requests"));
        assert!(is_retriable_error("Error 403: forbidden"));
        assert!(is_retriable_error("HTTP 429 Too Many Requests"));
        assert!(is_retriable_error("connection timed out"));
        assert!(is_retriable_error("network error"));
        assert!(is_retriable_error("timeout waiting for response"));
        assert!(is_retriable_error("RATE LIMIT"));
        assert!(is_retriable_error("Connection Reset"));
        assert!(!is_retriable_error("branch not found"));
        assert!(!is_retriable_error("invalid credentials"));
        assert!(!is_retriable_error("permission denied"));
        assert!(!is_retriable_error(""));
    }

    #[test]
    fn test_ci_status_color() {
        use anstyle::AnsiColor;
        assert_eq!(CiStatus::Passed.color(), AnsiColor::Green);
        assert_eq!(CiStatus::Running.color(), AnsiColor::Blue);
        assert_eq!(CiStatus::Failed.color(), AnsiColor::Red);
        assert_eq!(CiStatus::Conflicts.color(), AnsiColor::Yellow);
        assert_eq!(CiStatus::Error.color(), AnsiColor::Yellow);
        assert_eq!(CiStatus::NoCI.color(), AnsiColor::BrightBlack);
    }

    #[test]
    fn test_pr_status_indicator() {
        let pr_passed = PrStatus {
            ci_status: CiStatus::Passed,
            source: CiSource::PullRequest,
            is_stale: false,
            url: None,
        };
        assert_eq!(pr_passed.indicator(), "●");
        let branch_running = PrStatus {
            ci_status: CiStatus::Running,
            source: CiSource::Branch,
            is_stale: false,
            url: None,
        };
        assert_eq!(branch_running.indicator(), "●");
        let error_status = PrStatus {
            ci_status: CiStatus::Error,
            source: CiSource::PullRequest,
            is_stale: false,
            url: None,
        };
        assert_eq!(error_status.indicator(), "⚠");
    }

    #[test]
    fn test_format_indicator() {
        use insta::assert_snapshot;
        let with_url = PrStatus {
            ci_status: CiStatus::Passed,
            source: CiSource::PullRequest,
            is_stale: false,
            url: Some("https://github.com/owner/repo/pull/123".to_string()),
        };
        let no_url = PrStatus {
            ci_status: CiStatus::Passed,
            source: CiSource::PullRequest,
            is_stale: false,
            url: None,
        };
        // ESC (0x1b) is spelled `\u{1b}` rather than embedded as a raw control
        // byte, so the expected sequences stay visible in editors and diffs.
        // Linked form: SGR underline + green, OSC 8 open (ESC ] 8 ; ; URL ESC \),
        // the glyph, OSC 8 close (ESC ] 8 ; ; ESC \), then SGR reset.
        assert_snapshot!(
            with_url.format_indicator(true),
            @"\u{1b}[4m\u{1b}[32m\u{1b}]8;;https://github.com/owner/repo/pull/123\u{1b}\\●\u{1b}]8;;\u{1b}\\\u{1b}[0m"
        );
        assert_snapshot!(with_url.format_indicator(false), @"\u{1b}[32m●\u{1b}[0m");
        assert_snapshot!(no_url.format_indicator(true), @"\u{1b}[32m●\u{1b}[0m");
    }

    #[test]
    fn test_pr_status_error_constructor() {
        let error = PrStatus::error();
        assert_eq!(error.ci_status, CiStatus::Error);
        assert_eq!(error.source, CiSource::Branch);
        assert!(!error.is_stale);
        assert!(error.url.is_none());
    }

    #[test]
    fn test_pr_status_style() {
        let stale = PrStatus {
            ci_status: CiStatus::Running,
            source: CiSource::Branch,
            is_stale: true,
            url: None,
        };
        let style = stale.style();
        let _ = format!("{style}test{style:#}");
    }
}