1use std::time::Duration;
2use std::{fmt, process::Command};
3
4use anyhow::{Context, Result, anyhow, bail};
5
6use crate::git;
7use crate::settings;
8
/// Number of "grace" polls related to waiting on a review's checks.
///
/// NOTE(review): not referenced anywhere else in this file; presumably read by the
/// provider `wait_for_checks` implementations as a count of extra polls spaced by
/// `check_poll_interval()` before treating checks as stuck — confirm at the call sites.
pub(super) const CHECK_GRACE_POLLS: u32 = 6;
14
15pub(super) fn check_poll_interval() -> Duration {
17 Duration::from_secs(5)
18}
19
20pub(super) fn checks_timed_out(review: &ReviewRequest, timeout: Duration) -> anyhow::Error {
24 anyhow!(
25 "{}'s checks have not settled within {}; rerun `git stk merge` once they pass, \
26 or raise stk.checkTimeout",
27 review.id,
28 humanize(timeout),
29 )
30}
31
/// Renders `duration` for user-facing messages at whole-second precision.
///
/// Whole minutes (60s, 120s, …) are shown as `"Nm"`. Anything else, including zero and
/// sub-minute values, is shown as `"Ns"`. Sub-second remainders are truncated.
fn humanize(duration: Duration) -> String {
    let total = duration.as_secs();
    match (total / 60, total % 60) {
        // At least one minute and no leftover seconds.
        (minutes @ 1.., 0) => format!("{minutes}m"),
        _ => format!("{total}s"),
    }
}
41
42mod demo;
43mod github;
44mod gitlab;
45mod json;
46
47use demo::DemoProvider;
48use github::GitHubProvider;
49use gitlab::GitLabProvider;
50
/// Which review host stk talks to.
///
/// Parsed case-insensitively from the `stk.provider` setting, which also accepts the
/// aliases `gh` and `glab`. Otherwise it is inferred from the remote URL. Displayed as
/// `github`, `gitlab`, or `demo`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ProviderKind {
    /// Served by `GitHubProvider`.
    GitHub,
    /// Served by `GitLabProvider`; also chosen for a configured self-hosted GitLab host.
    GitLab,
    /// Served by `DemoProvider`; never inferred from a remote URL, only selected via
    /// `stk.provider`.
    Demo,
}
59
60impl ProviderKind {
61 fn parse(value: &str) -> Option<Self> {
62 match value.to_ascii_lowercase().as_str() {
63 "github" | "gh" => Some(Self::GitHub),
64 "gitlab" | "glab" => Some(Self::GitLab),
65 "demo" => Some(Self::Demo),
66 _ => None,
67 }
68 }
69}
70
71impl fmt::Display for ProviderKind {
72 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
73 match self {
74 Self::GitHub => write!(formatter, "github"),
75 Self::GitLab => write!(formatter, "gitlab"),
76 Self::Demo => write!(formatter, "demo"),
77 }
78 }
79}
80
/// The provider stk resolved to, together with how it was chosen (see `detect_provider`).
#[derive(Debug, Eq, PartialEq)]
pub struct DetectedProvider {
    /// The selected review host.
    pub kind: ProviderKind,
    /// Whether `kind` came from the `stk.provider` setting or from a git remote's URL.
    pub source: ProviderSource,
}
86
/// Where a `DetectedProvider` came from.
#[derive(Debug, Eq, PartialEq)]
pub enum ProviderSource {
    /// Set explicitly through the `stk.provider` git config key.
    Config,
    /// Inferred from the URL of the git remote named `remote`.
    ///
    /// `url` is stored verbatim and may embed credentials (e.g. a token in the userinfo).
    /// The `Display` impl passes it through `redact_url`, so format this value instead of
    /// printing `url` directly.
    Remote { remote: String, url: String },
}
92
93impl fmt::Display for ProviderSource {
94 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
95 match self {
96 Self::Config => write!(formatter, "config"),
97 Self::Remote { remote, url } => {
98 write!(formatter, "remote {remote} ({})", redact_url(url))
99 }
100 }
101 }
102}
103
/// Lifecycle state of a review as reported by the provider.
///
/// Displayed as `open`, `merged`, or `closed`; `Unknown` displays its payload verbatim.
#[derive(Debug, Eq, PartialEq)]
pub enum ReviewState {
    Open,
    Merged,
    Closed,
    /// A provider-reported state that presumably does not correspond to the variants
    /// above. The raw string is kept so it can still be shown to the user.
    Unknown(String),
}
111
/// What currently keeps a review from merging, as returned by
/// `ReviewProvider::merge_blocker`.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum MergeBlocker {
    /// The review's checks have not finished yet (per the variant name; confirm per provider).
    ChecksPending,
    /// The review conflicts with its base (per the variant name; confirm per provider).
    Conflicts,
    /// The provider reports nothing blocking the merge.
    ///
    /// This is a variant of this enum, not `Option::None`. Spell it `MergeBlocker::None`
    /// in patterns, because a bare `None` would match the `Option` variant instead.
    None,
}
125
/// A single review (pull request / merge request) as returned by a `ReviewProvider`.
#[derive(Debug, Eq, PartialEq)]
pub struct ReviewRequest {
    /// Display identifier. `id_value` strips one leading `#` or `!` from it, so it is
    /// expected to look like `#123` or `!123`. NOTE(review): confirm per provider.
    pub id: String,
    /// The branch the review proposes to merge; compared against the looked-up branch by
    /// `owned_review_for_branch`.
    pub branch: String,
    /// The branch the review merges into.
    pub base: String,
    /// Current lifecycle state on the host.
    pub state: ReviewState,
    /// Link to the review on the host.
    pub url: String,
    /// Review title; may be empty, in which case `label` falls back to `id` alone.
    pub title: String,
    /// Whether the review is a draft.
    pub draft: bool,
}
136
/// How `ReviewProvider::wait_for_checks` finished.
pub enum WaitOutcome {
    /// The review's checks passed.
    Passed,
    /// The review's checks did not pass.
    Failed,
    /// NOTE(review): presumably the review was merged (landed) while waiting, so there is
    /// nothing left to merge. Confirm in the provider implementations.
    Landed,
}
147
/// The operations stk needs from a code-review host.
///
/// `review_provider` hands out `GitHubProvider`, `GitLabProvider`, or `DemoProvider`
/// behind this trait. The per-method descriptions below follow the method names and
/// signatures; exact semantics live in each implementation.
///
/// NOTE(review): most mutating methods return `Result<String>`. Presumably this is text
/// reported by the host tool (compare `command_output`, which returns a CLI's trimmed
/// stdout). Confirm in the implementations before relying on its contents.
pub trait ReviewProvider {
    /// Returns the review associated with `branch`, or `None` if there is none.
    ///
    /// NOTE(review): `owned_review_for_branch` re-checks `review.branch == branch` on the
    /// result, which suggests an implementation may return a review for a different
    /// branch. `review_merged_out_of_band` looks for `ReviewState::Merged` in the result,
    /// so merged reviews can apparently be returned too. Confirm per implementation.
    fn review_for_branch(&self, branch: &str) -> Result<Option<ReviewRequest>>;

    /// Like `review_for_branch` but, going by the name, closed reviews are candidates as
    /// well. TODO(review): confirm which states each implementation includes.
    fn review_for_branch_including_closed(&self, branch: &str) -> Result<Option<ReviewRequest>>;

    /// Opens a review proposing `branch` onto `base`; `draft` asks for a draft review.
    fn create_review(&self, branch: &str, base: &str, draft: bool) -> Result<String>;

    /// Retargets `review` onto `base`.
    fn update_review_base(&self, review: &ReviewRequest, base: &str) -> Result<String>;

    /// Fetches the description (body) of `review`.
    fn review_body(&self, review: &ReviewRequest) -> Result<String>;

    /// Replaces the description (body) of `review` with `body`.
    fn update_review_body(&self, review: &ReviewRequest, body: &str) -> Result<String>;

    /// Merges `review` with the given `strategy`.
    ///
    /// `strategy` is a free-form string whose accepted values are defined by the
    /// implementations. NOTE(review): `auto` presumably requests auto-merge once
    /// requirements are met instead of an immediate merge. Confirm per implementation.
    fn merge_review(&self, review: &ReviewRequest, strategy: &str, auto: bool) -> Result<String>;

    /// Reports what, if anything, currently prevents `review` from merging.
    fn merge_blocker(&self, review: &ReviewRequest) -> Result<MergeBlocker>;

    /// Waits for `review`'s checks and reports how that wait ended.
    ///
    /// NOTE(review): implementations presumably pace themselves with
    /// `check_poll_interval()` / `CHECK_GRACE_POLLS` and report timeouts via
    /// `checks_timed_out`. Confirm.
    fn wait_for_checks(&self, review: &ReviewRequest) -> Result<WaitOutcome>;

    /// Lists open reviews. NOTE(review): the scope (whole repository vs. the current
    /// user's reviews) is defined by the implementations.
    fn open_reviews(&self) -> Result<Vec<ReviewRequest>>;

    /// Marks a draft `review` as ready for review.
    fn mark_ready(&self, review: &ReviewRequest) -> Result<String>;

    /// Closes `review`; `delete_branch` asks for its branch to be deleted as well.
    fn close_review(&self, review: &ReviewRequest, delete_branch: bool) -> Result<String>;

    /// Opens `review` for viewing. NOTE(review): presumably in a web browser; confirm.
    fn open_review(&self, review: &ReviewRequest) -> Result<String>;
}
195
196pub fn detect_review_provider() -> Result<(DetectedProvider, Box<dyn ReviewProvider>)> {
200 let provider = detect_provider()?;
201 let client = review_provider(provider.kind);
202 Ok((provider, client))
203}
204
205pub fn owned_review_for_branch(
209 provider: &dyn ReviewProvider,
210 branch: &str,
211) -> Result<Option<ReviewRequest>> {
212 Ok(provider
213 .review_for_branch(branch)?
214 .filter(|review| review.branch == branch))
215}
216
217pub(super) fn review_merged_out_of_band(
221 provider: &dyn ReviewProvider,
222 review: &ReviewRequest,
223) -> Result<bool> {
224 Ok(matches!(
225 provider.review_for_branch(&review.branch)?,
226 Some(current) if current.state == ReviewState::Merged
227 ))
228}
229
230pub fn detect_provider() -> Result<DetectedProvider> {
231 if let Some(value) = git::config_get(settings::PROVIDER_KEY)? {
232 let Some(kind) = ProviderKind::parse(&value) else {
233 bail!("unsupported stk.provider value {value:?}; expected github, gitlab, or demo");
234 };
235
236 return Ok(DetectedProvider {
237 kind,
238 source: ProviderSource::Config,
239 });
240 }
241
242 let remote = settings::remote()?;
243 let Some(url) = git::remote_url(&remote)? else {
244 bail!("could not detect provider: remote {remote:?} does not exist");
245 };
246
247 let gitlab_host = settings::gitlab_host()?;
248 let Some(kind) = detect_provider_from_url(&url, gitlab_host.as_deref()) else {
249 bail!(
250 "could not detect provider from remote {remote} ({})",
251 redact_url(&url)
252 );
253 };
254
255 Ok(DetectedProvider {
256 kind,
257 source: ProviderSource::Remote { remote, url },
258 })
259}
260
261fn detect_provider_from_url(url: &str, gitlab_host: Option<&str>) -> Option<ProviderKind> {
264 let normalized = url.to_ascii_lowercase();
265 let host = host_of(&normalized);
266 let is = |domain: &str| host == domain || host.ends_with(&format!(".{domain}"));
269
270 let gitlab_self_hosted = || {
273 gitlab_host.is_some_and(|configured| {
274 let configured = configured.to_ascii_lowercase();
275 is(host_of(&configured))
276 })
277 };
278
279 if is("github.com") {
280 Some(ProviderKind::GitHub)
281 } else if is("gitlab.com") || gitlab_self_hosted() {
282 Some(ProviderKind::GitLab)
283 } else {
284 None
285 }
286}
287
/// Extracts the host from a remote URL, a scp-like `user@host:path` remote, or a bare host.
///
/// Steps:
/// 1. Drop any `scheme://` prefix.
/// 2. Keep the text before the first `/`.
/// 3. Drop userinfo up to the last `@`.
/// 4. Drop a `:port` (or scp `:path`) suffix.
///
/// Bracketed IPv6 literals are returned without brackets; an unterminated `[` is left
/// as-is. No case folding is done here. The input may yield an empty host (e.g. a
/// local path).
fn host_of(url: &str) -> &str {
    let without_scheme = match url.split_once("://") {
        Some((_, rest)) => rest,
        None => url,
    };
    let authority = without_scheme
        .split_once('/')
        .map_or(without_scheme, |(head, _)| head);
    // Last '@' so an unencoded '@' inside the userinfo does not leak into the host.
    let host_and_port = match authority.rsplit_once('@') {
        Some((_, host)) => host,
        None => authority,
    };
    match host_and_port.strip_prefix('[') {
        Some(bracketed) => match bracketed.split_once(']') {
            Some((address, _)) => address,
            None => host_and_port,
        },
        None => host_and_port
            .split_once(':')
            .map_or(host_and_port, |(host, _)| host),
    }
}
310
/// Returns `url` with any userinfo (`user[:password]@`) removed from its authority.
///
/// Only `scheme://` URLs are rewritten. scp-like remotes and URLs without an `@` before
/// the first `/` are returned unchanged, and so is everything after the authority.
fn redact_url(url: &str) -> String {
    let Some((scheme, rest)) = url.split_once("://") else {
        return url.to_owned();
    };
    let (authority, tail) = rest
        .split_once('/')
        .map_or((rest, None), |(authority, path)| (authority, Some(path)));
    match authority.rsplit_once('@') {
        None => url.to_owned(),
        Some((_, host)) => {
            let mut redacted = format!("{scheme}://{host}");
            if let Some(path) = tail {
                redacted.push('/');
                redacted.push_str(path);
            }
            redacted
        }
    }
}
332
333pub(crate) fn review_provider(kind: ProviderKind) -> Box<dyn ReviewProvider> {
334 match kind {
335 ProviderKind::GitHub => Box::new(GitHubProvider),
336 ProviderKind::GitLab => Box::new(GitLabProvider),
337 ProviderKind::Demo => Box::new(DemoProvider),
338 }
339}
340
/// Describes the review-host CLIs stk shells out to.
///
/// For `gh` or `glab`, returns `(display name, install URL, sign-in command)`. Any other
/// program name yields `None`. The match is exact and case-sensitive.
fn provider_cli(program: &str) -> Option<(&'static str, &'static str, &'static str)> {
    // (binary, display name, install URL, sign-in command)
    const KNOWN_CLIS: [(&str, &str, &str, &str); 2] = [
        ("gh", "GitHub CLI", "https://cli.github.com", "gh auth login"),
        (
            "glab",
            "GitLab CLI",
            "https://gitlab.com/gitlab-org/cli",
            "glab auth login",
        ),
    ];
    KNOWN_CLIS
        .iter()
        .find(|(binary, ..)| *binary == program)
        .map(|&(_, name, url, auth)| (name, url, auth))
}
354
/// Heuristically reports whether a CLI failure (its stderr) looks like a sign-in problem.
///
/// Matching ignores ASCII case and looks for phrases such as `auth login`, `not logged`,
/// `unauthorized`, or `authentication required`. It also accepts an HTTP 401 status that
/// appears as a standalone number. A `401` inside a longer number (`1401`) or used as a
/// review reference (`#401`, `!401`) does not count. Otherwise ordinary failures such as
/// "pull request #401 not found" would be decorated with a misleading sign-in hint.
fn looks_unauthenticated(stderr: &str) -> bool {
    const SIGN_IN_PHRASES: [&str; 4] = [
        "auth login",
        "not logged",
        "unauthorized",
        "authentication required",
    ];
    let stderr = stderr.to_ascii_lowercase();
    SIGN_IN_PHRASES
        .iter()
        .any(|phrase| stderr.contains(phrase))
        || mentions_status_401(&stderr)
}

/// Reports whether `text` contains `401` as a standalone number (not part of a longer
/// alphanumeric run, and not right after a `#`/`!` review-reference sigil).
fn mentions_status_401(text: &str) -> bool {
    text.match_indices("401").any(|(start, needle)| {
        // "401" is ASCII, so both slice boundaries fall on char boundaries.
        let before = text[..start].chars().next_back();
        let after = text[start + needle.len()..].chars().next();
        let joined_before =
            matches!(before, Some(c) if c.is_ascii_alphanumeric() || c == '#' || c == '!');
        let joined_after = matches!(after, Some(c) if c.is_ascii_alphanumeric());
        !joined_before && !joined_after
    })
}
369
370fn command_output(program: &str, args: &[&str]) -> Result<String> {
371 let output = match Command::new(program).args(args).output() {
372 Ok(output) => output,
373 Err(error) if error.kind() == std::io::ErrorKind::NotFound => {
376 if let Some((name, url, auth)) = provider_cli(program) {
377 bail!("{program} ({name}) is not installed - get it from {url}, then run `{auth}`");
378 }
379 return Err(error).with_context(|| format!("failed to run {program}"));
380 }
381 Err(error) => return Err(error).with_context(|| format!("failed to run {program}")),
382 };
383
384 if output.status.success() {
385 return Ok(String::from_utf8_lossy(&output.stdout).trim().to_owned());
386 }
387
388 let stderr = String::from_utf8_lossy(&output.stderr).trim().to_owned();
389 if let Some((_, _, auth)) = provider_cli(program)
392 && looks_unauthenticated(&stderr)
393 {
394 bail!("{program} failed: {stderr}\n(if you are not signed in, run `{auth}`)");
395 }
396 if stderr.is_empty() {
397 Err(anyhow!("{program} exited with status {}", output.status))
398 } else {
399 Err(anyhow!("{program} failed: {stderr}"))
400 }
401}
402
/// Total number of merge attempts `merge_with_retry` makes, counting the first try, before
/// a still-transient error is surfaced to the caller.
const MERGE_ATTEMPTS: u32 = 3;
/// Pause taken after a transient merge failure, before the next attempt.
const MERGE_RETRY_BACKOFF: Duration = Duration::from_millis(1500);
408
409fn is_transient_merge_error(error: &anyhow::Error) -> bool {
416 let text = error.to_string().to_lowercase();
417 [
418 "base branch was modified",
419 "head branch was modified",
420 "try the merge again",
421 "method not allowed",
422 "bad gateway",
425 "service unavailable",
426 "gateway time",
427 "internal server error",
428 ]
429 .iter()
430 .any(|signature| text.contains(signature))
431}
432
433fn merge_with_retry(attempt: impl FnMut() -> Result<String>) -> Result<String> {
436 retry_transient_merge(MERGE_ATTEMPTS, MERGE_RETRY_BACKOFF, attempt)
437}
438
439fn retry_transient_merge(
440 attempts: u32,
441 backoff: Duration,
442 mut attempt: impl FnMut() -> Result<String>,
443) -> Result<String> {
444 for remaining in (0..attempts).rev() {
445 match attempt() {
446 Ok(output) => return Ok(output),
447 Err(error) if remaining > 0 && is_transient_merge_error(&error) => {
448 std::thread::sleep(backoff);
449 }
450 Err(error) => return Err(error),
451 }
452 }
453 Err(anyhow!("merge retried with no attempts left"))
455}
456
457impl fmt::Display for ReviewState {
458 fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
459 match self {
460 Self::Open => write!(formatter, "open"),
461 Self::Merged => write!(formatter, "merged"),
462 Self::Closed => write!(formatter, "closed"),
463 Self::Unknown(state) => write!(formatter, "{state}"),
464 }
465 }
466}
467
468impl ReviewRequest {
469 pub(crate) fn id_value(&self) -> &str {
470 self.id
471 .strip_prefix('#')
472 .or_else(|| self.id.strip_prefix('!'))
473 .unwrap_or(&self.id)
474 }
475
476 pub fn label(&self) -> String {
478 label(&self.title, &self.id)
479 }
480}
481
/// Formats a review for display as `"title (id)"`, falling back to `id` alone when `title`
/// is empty.
pub(crate) fn label(title: &str, id: &str) -> String {
    match title {
        "" => id.to_owned(),
        _ => format!("{title} ({id})"),
    }
}
490
// Unit tests for the pure helpers in this module: CLI error classification, the
// merge-retry policy, remote-URL host parsing, credential redaction, and provider
// detection from a remote URL. Retry tests pass `Duration::ZERO` so they never sleep.
#[cfg(test)]
mod tests {
    use super::*;

    // Only the review-host CLIs get install / sign-in guidance; other programs fall
    // through to the generic error path.
    #[test]
    fn provider_cli_maps_only_the_provider_clis() {
        assert!(provider_cli("gh").is_some());
        assert!(provider_cli("glab").is_some());
        assert!(provider_cli("git").is_none());
    }

    #[test]
    fn looks_unauthenticated_matches_signin_failures_only() {
        assert!(looks_unauthenticated(
            "error: not logged into any GitHub hosts"
        ));
        assert!(looks_unauthenticated(
            "To get started, please run: gh auth login"
        ));
        assert!(looks_unauthenticated("GET ...: 401 Unauthorized"));
        assert!(!looks_unauthenticated("pull request not found"));
        // Ordinary failures must not be decorated with a sign-in hint.
        assert!(!looks_unauthenticated("merge conflict in src/lib.rs"));
    }

    #[test]
    fn transient_error_is_retried_then_succeeds() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            if calls < 2 {
                Err(anyhow!(
                    "gh failed: GraphQL: Base branch was modified. Review and try the merge again."
                ))
            } else {
                Ok("merged".to_owned())
            }
        });
        assert_eq!(result.unwrap(), "merged");
        assert_eq!(calls, 2, "should retry once then succeed");
    }

    #[test]
    fn a_gitlab_405_while_the_merge_status_recomputes_is_retried() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            if calls < 2 {
                Err(anyhow!("glab failed: ... /merge: 405 Method Not Allowed"))
            } else {
                Ok("merged".to_owned())
            }
        });
        assert_eq!(result.unwrap(), "merged");
        assert_eq!(calls, 2, "GitLab's transient 405 should be retried");
    }

    #[test]
    fn a_transient_5xx_from_the_api_is_retried() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            if calls < 2 {
                Err(anyhow!(
                    "gh failed: non-200 OK status code: 502 Bad Gateway"
                ))
            } else {
                Ok("merged".to_owned())
            }
        });
        assert_eq!(result.unwrap(), "merged");
        assert_eq!(calls, 2, "a 502 is a server hiccup, not a merge verdict");
    }

    // The final attempt's transient error is surfaced rather than retried forever.
    #[test]
    fn a_persistent_transient_error_gives_up_after_the_attempt_budget() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            Err(anyhow!("gh failed: Base branch was modified"))
        });
        assert!(result.is_err());
        assert_eq!(calls, 3, "should try exactly the budgeted number of times");
    }

    #[test]
    fn a_real_failure_is_not_retried() {
        let mut calls = 0;
        let result = retry_transient_merge(3, Duration::ZERO, || {
            calls += 1;
            Err(anyhow!(
                "gh failed: Pull request is not mergeable: conflicts"
            ))
        });
        assert!(result.is_err());
        assert_eq!(calls, 1, "a non-transient error must surface immediately");
    }

    // Covers https, scp-like, ssh with port, userinfo, explicit port, bracketed IPv6,
    // a bare host, and an unencoded '@' inside the userinfo.
    #[test]
    fn host_of_extracts_the_host_across_url_shapes() {
        assert_eq!(host_of("https://github.com/owner/repo.git"), "github.com");
        assert_eq!(host_of("git@github.com:owner/repo.git"), "github.com");
        assert_eq!(
            host_of("ssh://git@gitlab.example.com:22/g/r"),
            "gitlab.example.com"
        );
        assert_eq!(host_of("https://user@github.com/owner/repo"), "github.com");
        assert_eq!(host_of("https://github.com:8443/owner/repo"), "github.com");
        assert_eq!(
            host_of("https://[2001:db8::1]:443/owner/repo"),
            "2001:db8::1"
        );
        assert_eq!(host_of("gitlab.example.com"), "gitlab.example.com");
        // The host is taken after the last '@', so an '@' in the userinfo is skipped.
        assert_eq!(host_of("https://user@name@github.com/r"), "github.com");
    }

    #[test]
    fn redact_url_strips_embedded_credentials() {
        assert_eq!(
            // Token-in-userinfo form used for HTTPS auth; the secret must not be echoed.
            redact_url("https://x-access-token:ghp_SECRET@github.com/owner/repo.git"),
            "https://github.com/owner/repo.git"
        );
        assert_eq!(
            redact_url("https://glpat-SECRET@gitlab.com/owner/repo"),
            "https://gitlab.com/owner/repo"
        );
        // A port survives redaction; only the userinfo is removed.
        assert_eq!(redact_url("ssh://git@host:22/g/r"), "ssh://host:22/g/r");
    }

    #[test]
    fn redact_url_leaves_credential_free_urls_unchanged() {
        assert_eq!(
            redact_url("https://github.com/owner/repo.git"),
            "https://github.com/owner/repo.git"
        );
        assert_eq!(
            // scp-like remotes have no scheme and are returned as-is.
            redact_url("git@github.com:owner/repo.git"),
            "git@github.com:owner/repo.git"
        );
    }

    #[test]
    fn self_hosted_gitlab_accepts_a_bare_host_or_a_full_url() {
        let remote = "git@gitlab.example.com:team/repo.git";
        for configured in ["gitlab.example.com", "https://gitlab.example.com"] {
            assert_eq!(
                detect_provider_from_url(remote, Some(configured)),
                Some(ProviderKind::GitLab),
                "configured {configured:?} should detect the self-hosted host"
            );
        }
        assert_eq!(
            // A lookalike host that merely ends with "gitlab.com" must not match.
            detect_provider_from_url("git@notgitlab.com:o/r", Some("gitlab.example.com")),
            None
        );
    }
}