use std::path::Path;
use crate::commands::checkout::remote_configured;
use crate::cx::Cx;
use crate::error::{Error, Result};
use crate::git::cli::GitCli;
use crate::git::discover::Repo;
use crate::git::{ahead_behind, branch_ref, is_ancestor, ops, resolve_hex, upstream_of};
use crate::worktree_service::enumerate_worktrees;
/// Result of comparing a local base branch with its upstream tracking ref,
/// as produced by `check_base_behind` when the branch is behind.
#[derive(Clone, Debug, Eq, PartialEq)]
pub(crate) struct StaleBase {
    /// How many upstream commits the local branch does not yet contain.
    pub(crate) behind: u32,
    /// Ref of the upstream tracking branch (e.g. `refs/remotes/origin/main`).
    pub(crate) tracking_ref: String,
    /// Human-readable upstream name (e.g. `origin/main`), used in messages.
    pub(crate) upstream_display: String,
    /// `true` when the local branch has no commits of its own relative to the
    /// upstream, i.e. it could be moved forward without merging.
    pub(crate) can_fast_forward: bool,
}
/// Reports whether the local branch `base` is behind its configured upstream.
///
/// Returns `Ok(None)` when there is nothing to report:
/// - `refs/heads/<base>` does not resolve (e.g. `base` is a remote ref or `HEAD`);
/// - `base` has no upstream, or the upstream is marked gone;
/// - the tracking ref no longer resolves after the refresh below;
/// - `base` is not behind its tracking ref.
///
/// Before comparing, the upstream's remote is fetched if it is a configured
/// remote. A failed fetch is not fatal: a warning is written to `cx.err` and
/// the comparison uses the tracking ref as it already exists locally.
///
/// # Errors
/// Propagates errors from `ahead_behind`.
pub(crate) fn check_base_behind(
    cx: &mut Cx,
    git: &dyn GitCli,
    repo: &Repo,
    dir: &Path,
    base: &str,
) -> Result<Option<StaleBase>> {
    let base_ref = branch_ref(base);
    if resolve_hex(repo.gix(), &base_ref).is_none() {
        return Ok(None);
    }
    let Some(up) = upstream_of(repo.gix(), base) else {
        return Ok(None);
    };
    if up.is_gone {
        return Ok(None);
    }
    // The remote is taken as everything before the first '/' of the display
    // name (e.g. "origin/main" -> "origin"). NOTE(review): remote names that
    // themselves contain '/' are not distinguished here.
    let remote = up
        .display
        .split_once('/')
        .map_or(up.display.as_str(), |(name, _)| name)
        .to_string();
    if remote_configured(repo.gix(), &remote)
        && let Err(e) = ops::fetch(git, dir, &remote)
    {
        // Best effort: stale tracking data is still worth comparing against.
        let _ = cx
            .err
            .line(&format!("warning: failed to fetch {remote}: {e}"));
    }
    // The `is_gone` check above predates the fetch. A pruning fetch can delete
    // the tracking ref, and comparing against a missing ref would turn this
    // "nothing to report" case into a hard error.
    if resolve_hex(repo.gix(), &up.tracking_ref).is_none() {
        return Ok(None);
    }
    let (ahead, behind) = ahead_behind(git, dir, &up.tracking_ref, &base_ref)?;
    if behind == 0 {
        return Ok(None);
    }
    Ok(Some(StaleBase {
        behind,
        tracking_ref: up.tracking_ref,
        upstream_display: up.display,
        // No local-only commits means the branch can simply be moved forward.
        can_fast_forward: ahead == 0,
    }))
}
pub(crate) fn fast_forward_base(
cx: &mut Cx,
git: &dyn GitCli,
repo: &Repo,
root: &Path,
base: &str,
stale: &StaleBase,
) -> Result<()> {
if !is_ancestor(repo.gix(), &branch_ref(base), &stale.tracking_ref) {
return Err(Error::operation(format!(
"base {base:?} has diverged from {}; cannot fast-forward",
stale.upstream_display
)));
}
let checked_out = enumerate_worktrees(repo, git)?
.into_iter()
.find(|w| w.branch.as_deref() == Some(base))
.map(|w| w.path);
if let Some(path) = checked_out {
ops::merge_ff_only(git, &path, &stale.tracking_ref)?;
} else {
ops::set_branch_ref(git, root, base, &stale.tracking_ref)?;
}
let _ = cx.err.line(&format!(
"updated {base} to {} (fast-forward)",
stale.upstream_display
));
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::git::RealGit;
    use crate::git::discover::Repo;
    use crate::testutil::{TestRepo, test_cx};

    /// Runs `check_base_behind` for `base` with a fresh repo handle and context.
    fn check(repo: &TestRepo, base: &str) -> Option<StaleBase> {
        let mut t = test_cx(&[], repo.root().to_str().unwrap());
        let r = Repo::discover(repo.root()).unwrap();
        super::check_base_behind(&mut t.cx, &RealGit, &r, repo.root(), base).unwrap()
    }

    /// Runs `fast_forward_base` for `base` with a fresh repo handle and context.
    fn fast_forward(repo: &TestRepo, base: &str, stale: &StaleBase) -> Result<()> {
        let mut t = test_cx(&[], repo.root().to_str().unwrap());
        let r = Repo::discover(repo.root()).unwrap();
        super::fast_forward_base(&mut t.cx, &RealGit, &r, repo.root(), base, stale)
    }

    /// Builds a `StaleBase` that points at `refs/remotes/origin/<base>`.
    fn stale_origin(base: &str, behind: u32, can_fast_forward: bool) -> StaleBase {
        StaleBase {
            behind,
            tracking_ref: format!("refs/remotes/origin/{base}"),
            upstream_display: format!("origin/{base}"),
            can_fast_forward,
        }
    }

    /// Resolves `rev` to a full object id.
    fn rev(repo: &TestRepo, rev: &str) -> String {
        repo.git(&["rev-parse", rev]).trim().to_string()
    }

    /// Leaves `base` one commit behind `refs/remotes/origin/<base>` and
    /// configures that ref as its upstream. Assumes `base` is checked out:
    /// the extra commit is made on HEAD and HEAD is then reset back.
    fn behind_with_upstream(repo: &TestRepo, base: &str) {
        let c1 = rev(repo, "HEAD");
        repo.write("a.txt", "1\n");
        repo.commit_all("c2");
        let c2 = rev(repo, "HEAD");
        repo.git(&["update-ref", &format!("refs/remotes/origin/{base}"), &c2]);
        repo.git(&["reset", "-q", "--hard", &c1]);
        repo.git(&["config", &format!("branch.{base}.remote"), "origin"]);
        repo.git(&[
            "config",
            &format!("branch.{base}.merge"),
            &format!("refs/heads/{base}"),
        ]);
    }

    #[test]
    fn detects_base_behind_upstream() {
        let repo = TestRepo::init();
        behind_with_upstream(&repo, "main");
        let stale = check(&repo, "main").expect("base is behind");
        assert_eq!(stale.behind, 1);
        assert_eq!(stale.upstream_display, "origin/main");
        assert!(stale.can_fast_forward);
    }

    #[test]
    fn no_upstream_is_not_stale() {
        let repo = TestRepo::init();
        repo.git(&["branch", "topic"]);
        assert!(check(&repo, "topic").is_none());
    }

    #[test]
    fn non_local_base_is_not_stale() {
        let repo = TestRepo::init();
        assert!(check(&repo, "origin/main").is_none());
        assert!(check(&repo, "HEAD").is_none());
    }

    #[test]
    fn up_to_date_base_is_not_stale() {
        let repo = TestRepo::init();
        repo.git(&["update-ref", "refs/remotes/origin/main", "refs/heads/main"]);
        repo.git(&["config", "branch.main.remote", "origin"]);
        repo.git(&["config", "branch.main.merge", "refs/heads/main"]);
        assert!(check(&repo, "main").is_none());
    }

    #[test]
    fn fetches_then_detects_behind() {
        let bare = TestRepo::init_bare();
        let repo = TestRepo::init();
        repo.git(&["remote", "add", "origin", bare.root().to_str().unwrap()]);
        repo.git(&["push", "-q", "-u", "origin", "main"]);
        let c1 = rev(&repo, "HEAD");
        repo.write("a.txt", "1\n");
        repo.commit_all("c2");
        repo.git(&["push", "-q", "origin", "main"]);
        // `push` already advanced origin/main; rewind it so that only an
        // actual fetch can reveal the new commit.
        repo.git(&["update-ref", "refs/remotes/origin/main", &c1]);
        repo.git(&["reset", "-q", "--hard", &c1]);
        let stale = check(&repo, "main").expect("base is behind after fetch");
        assert_eq!(stale.behind, 1);
    }

    #[test]
    fn fast_forward_checked_out_base_updates_working_tree() {
        let repo = TestRepo::init();
        behind_with_upstream(&repo, "main");
        let c2 = rev(&repo, "refs/remotes/origin/main");
        fast_forward(&repo, "main", &stale_origin("main", 1, true)).unwrap();
        assert_eq!(rev(&repo, "refs/heads/main"), c2);
        // A bare ref move would leave the checkout at c1; the file written in
        // c2 must be present in the working tree.
        let content = std::fs::read_to_string(repo.root().join("a.txt")).unwrap();
        assert_eq!(content.trim_end(), "1");
    }

    #[test]
    fn fast_forward_non_checked_out_base_moves_ref() {
        let repo = TestRepo::init();
        repo.git(&["branch", "topic"]);
        repo.write("a.txt", "1\n");
        repo.commit_all("c2");
        let c2 = rev(&repo, "HEAD");
        repo.git(&["update-ref", "refs/remotes/origin/topic", &c2]);
        fast_forward(&repo, "topic", &stale_origin("topic", 1, true)).unwrap();
        assert_eq!(rev(&repo, "refs/heads/topic"), c2);
    }

    #[test]
    fn fast_forward_refuses_base_ahead_of_upstream() {
        let repo = TestRepo::init();
        let c1 = rev(&repo, "HEAD");
        repo.write("a.txt", "1\n");
        repo.commit_all("c2");
        repo.git(&["update-ref", "refs/remotes/origin/main", &c1]);
        let err = fast_forward(&repo, "main", &stale_origin("main", 0, false)).unwrap_err();
        assert!(err.to_string().contains("fast-forward"));
    }

    #[test]
    fn fast_forward_refuses_diverged_base() {
        let repo = TestRepo::init();
        let c1 = rev(&repo, "HEAD");
        repo.write("a.txt", "1\n");
        repo.commit_all("local");
        let local = rev(&repo, "HEAD");
        // Build a sibling of `local` on top of c1 to play the remote side.
        repo.git(&["reset", "-q", "--hard", &c1]);
        repo.write("b.txt", "2\n");
        repo.commit_all("remote");
        let remote = rev(&repo, "HEAD");
        repo.git(&["update-ref", "refs/remotes/origin/main", &remote]);
        repo.git(&["reset", "-q", "--hard", &local]);
        let err = fast_forward(&repo, "main", &stale_origin("main", 1, false)).unwrap_err();
        assert!(err.to_string().contains("diverged"));
        assert_eq!(rev(&repo, "refs/heads/main"), local);
    }
}