#![expect(missing_docs)]
use itertools::EitherOrBoth;
use crate::backend::CommitId;
use crate::index::Index;
use crate::index::IndexResult;
use crate::iter_util::fallible_position;
use crate::merge::Diff;
use crate::merge::Merge;
use crate::merge::SameChange;
use crate::merge::trivial_merge;
use crate::op_store::RefTarget;
use crate::op_store::RemoteRef;
/// Compares two name-keyed sequences of ref targets and yields every name
/// whose targets differ, together with the target from each side.
///
/// A name present on only one side is paired with [`RefTarget::absent_ref()`]
/// on the other side, so additions and removals show up as differences
/// against an absent target. Names whose targets compare equal are skipped.
///
/// Both inputs are merge-joined by key, so each is expected to be sorted by
/// `K` in ascending order.
pub fn diff_named_ref_targets<'a, 'b, K: Ord>(
    refs1: impl IntoIterator<Item = (K, &'a RefTarget)>,
    refs2: impl IntoIterator<Item = (K, &'b RefTarget)>,
) -> impl Iterator<Item = (K, (&'a RefTarget, &'b RefTarget))> {
    let all_pairs = iter_named_pairs(
        refs1,
        refs2,
        || RefTarget::absent_ref(),
        || RefTarget::absent_ref(),
    );
    all_pairs.filter(|(_, (lhs, rhs))| lhs != rhs)
}
/// Compares two name-keyed sequences of remote refs and yields every name
/// whose remote refs differ, together with the remote ref from each side.
///
/// A name present on only one side is paired with [`RemoteRef::absent_ref()`]
/// on the other side. Names whose remote refs compare equal (which covers
/// both target and state, via `PartialEq`) are skipped.
///
/// Both inputs are merge-joined by key, so each is expected to be sorted by
/// `K` in ascending order.
pub fn diff_named_remote_refs<'a, 'b, K: Ord>(
    refs1: impl IntoIterator<Item = (K, &'a RemoteRef)>,
    refs2: impl IntoIterator<Item = (K, &'b RemoteRef)>,
) -> impl Iterator<Item = (K, (&'a RemoteRef, &'b RemoteRef))> {
    let all_pairs = iter_named_pairs(
        refs1,
        refs2,
        || RemoteRef::absent_ref(),
        || RemoteRef::absent_ref(),
    );
    all_pairs.filter(|(_, (lhs, rhs))| lhs != rhs)
}
/// Pairs up local ref targets with remote refs by name, yielding every name
/// that appears on either side (no filtering is done).
///
/// A name missing from `refs1` gets [`RefTarget::absent_ref()`] as its local
/// target. A name missing from `refs2` gets [`RemoteRef::absent_ref()`] as its
/// remote ref.
///
/// Both inputs are merge-joined by key, so each is expected to be sorted by
/// `K` in ascending order.
pub fn iter_named_local_remote_refs<'a, 'b, K: Ord>(
    refs1: impl IntoIterator<Item = (K, &'a RefTarget)>,
    refs2: impl IntoIterator<Item = (K, &'b RemoteRef)>,
) -> impl Iterator<Item = (K, (&'a RefTarget, &'b RemoteRef))> {
    let local_refs = refs1.into_iter();
    let remote_refs = refs2.into_iter();
    iter_named_pairs(
        local_refs,
        remote_refs,
        || RefTarget::absent_ref(),
        || RemoteRef::absent_ref(),
    )
}
/// Compares two name-keyed sequences of commit ids and yields every name
/// whose ids differ.
///
/// Each side is reported as `Option<&CommitId>`. A name present on only one
/// side gets `None` for the other side. Names mapped to the same id on both
/// sides are skipped.
///
/// Both inputs are merge-joined by key, so each is expected to be sorted by
/// `K` in ascending order.
pub fn diff_named_commit_ids<'a, 'b, K: Ord>(
    ids1: impl IntoIterator<Item = (K, &'a CommitId)>,
    ids2: impl IntoIterator<Item = (K, &'b CommitId)>,
) -> impl Iterator<Item = (K, (Option<&'a CommitId>, Option<&'b CommitId>))> {
    let present1 = ids1.into_iter().map(|(name, id)| (name, Some(id)));
    let present2 = ids2.into_iter().map(|(name, id)| (name, Some(id)));
    iter_named_pairs(present1, present2, || None, || None)
        .filter(|(_, (lhs, rhs))| lhs != rhs)
}
/// Merge-joins two key-ordered sequences and yields one `(key, (v1, v2))`
/// entry for every key found on either side.
///
/// When a key appears only in `refs1`, `absent_ref2()` supplies the missing
/// right-hand value. When a key appears only in `refs2`, `absent_ref1()`
/// supplies the missing left-hand value. When it appears on both sides, the
/// key from `refs1` is kept.
///
/// The join is [`itertools::merge_join_by`], so both inputs are expected to
/// be sorted by `K` in ascending order.
fn iter_named_pairs<K: Ord, V1, V2>(
    refs1: impl IntoIterator<Item = (K, V1)>,
    refs2: impl IntoIterator<Item = (K, V2)>,
    absent_ref1: impl Fn() -> V1,
    absent_ref2: impl Fn() -> V2,
) -> impl Iterator<Item = (K, (V1, V2))> {
    let joined = itertools::merge_join_by(refs1, refs2, |(lhs, _), (rhs, _)| lhs.cmp(rhs));
    joined.map(move |slot| {
        let (name, left, right) = match slot {
            EitherOrBoth::Both((name, left), (_, right)) => (name, left, right),
            EitherOrBoth::Left((name, left)) => (name, left, absent_ref2()),
            EitherOrBoth::Right((name, right)) => (name, absent_ref1(), right),
        };
        (name, (left, right))
    })
}
/// Three-way merges the `left` and `right` ref targets relative to `base`.
///
/// Resolution proceeds in three steps:
/// 1. Try a trivial merge of the three targets as whole values, using
///    `SameChange::Accept`. Presumably this means identical changes on both
///    sides count as agreement.
/// 2. Otherwise, combine the three targets' underlying merges into one
///    flattened and simplified `Merge<Option<CommitId>>`, then try a trivial
///    resolution again.
/// 3. Otherwise, drop redundant add/remove terms using ancestry information
///    from `index`, and return the result as a (possibly conflicted) target.
///
/// # Errors
///
/// Propagates any error from the ancestry queries made against `index`.
pub fn merge_ref_targets(
    index: &dyn Index,
    left: &RefTarget,
    base: &RefTarget,
    right: &RefTarget,
) -> IndexResult<RefTarget> {
    // Fast path: the three targets resolve as whole values.
    if let Some(&resolved) = trivial_merge(&[left, base, right], SameChange::Accept) {
        return Ok(resolved.clone());
    }

    let sides: Vec<_> = [left, base, right]
        .iter()
        .map(|side| side.as_merge().clone())
        .collect();
    let mut merge = Merge::from_vec(sides).flatten().simplify();

    match merge.resolve_trivial(SameChange::Accept) {
        Some(resolved) => Ok(RefTarget::resolved(resolved.clone())),
        None => {
            // Still conflicted: prune terms that the commit graph makes redundant.
            merge_ref_targets_non_trivial(index, &mut merge)?;
            Ok(RefTarget::from_merge(merge))
        }
    }
}
/// Three-way merges the `left` and `right` remote refs relative to `base`.
///
/// The target is merged with [`merge_ref_targets`]. The state is merged
/// trivially with `SameChange::Accept`. If the three states do not resolve
/// that way, the `base` state is kept.
///
/// # Errors
///
/// Propagates any error returned by [`merge_ref_targets`].
pub fn merge_remote_refs(
    index: &dyn Index,
    left: &RemoteRef,
    base: &RemoteRef,
    right: &RemoteRef,
) -> IndexResult<RemoteRef> {
    let target = merge_ref_targets(index, &left.target, &base.target, &right.target)?;
    let states = [left.state, base.state, right.state];
    let state = match trivial_merge(&states, SameChange::Accept) {
        Some(&merged) => merged,
        // Conflicting state changes: fall back to the base state.
        None => base.state,
    };
    Ok(RemoteRef { target, state })
}
/// Simplifies `conflict` in place by repeatedly dropping the
/// `(remove_index, add_index)` term pair reported by [`find_pair_to_remove`],
/// until no such pair remains.
///
/// # Errors
///
/// Propagates any error from the ancestry queries. In that case `conflict`
/// may already have been partially simplified.
fn merge_ref_targets_non_trivial(
    index: &dyn Index,
    conflict: &mut Merge<Option<CommitId>>,
) -> IndexResult<()> {
    loop {
        let Some((remove_index, add_index)) = find_pair_to_remove(index, conflict)? else {
            return Ok(());
        };
        conflict.swap_remove(remove_index, add_index);
    }
}
/// Searches `conflict` for one `(remove, add)` term pair that can be dropped,
/// and returns it as `(remove_index, add_index)`.
///
/// Every pair of present adds is visited in list order. From each pair, the
/// add that equals or is an ancestor of the other becomes the candidate; on
/// equality, the earlier add is chosen. The candidate is then matched with
/// the first remove that is an ancestor of it. An absent remove is treated
/// like a root commit, so it matches any candidate. Dropping the candidate
/// add therefore keeps its descendant in the conflict.
///
/// Absent adds are never chosen as candidates. Returns `Ok(None)` when no
/// suitable pair exists.
///
/// # Errors
///
/// Propagates any error from `index.is_ancestor`.
fn find_pair_to_remove(
    index: &dyn Index,
    conflict: &Merge<Option<CommitId>>,
) -> IndexResult<Option<(usize, usize)>> {
    for (first_pos, first) in conflict.adds().enumerate() {
        // An absent add never forms a pair, so there is nothing to compare.
        let Some(first_id) = first else {
            continue;
        };
        for (second_pos, second) in conflict.adds().enumerate().skip(first_pos + 1) {
            let Some(second_id) = second else {
                continue;
            };
            // Keep the ancestor side as the candidate to drop.
            let (ancestor_pos, ancestor_id) =
                if first_id == second_id || index.is_ancestor(first_id, second_id)? {
                    (first_pos, first_id)
                } else if index.is_ancestor(second_id, first_id)? {
                    (second_pos, second_id)
                } else {
                    continue;
                };
            let matching_remove = fallible_position(conflict.removes(), |remove| match remove {
                Some(remove_id) => index.is_ancestor(remove_id, ancestor_id),
                // An absent ref can be considered a root commit.
                None => Ok(true),
            })?;
            if let Some(remove_pos) = matching_remove {
                return Ok(Some((remove_pos, ancestor_pos)));
            }
        }
    }
    Ok(None)
}
/// A borrowed pair of a local ref target and a remote ref, as consumed by
/// [`classify_ref_push_action`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocalAndRemoteRef<'a> {
    /// The local ref target.
    pub local_target: &'a RefTarget,
    /// The remote ref (target plus tracking state).
    pub remote_ref: &'a RemoteRef,
}
/// Result of comparing a local ref target with a remote ref, as produced by
/// [`classify_ref_push_action`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum RefPushAction {
    /// The remote should be moved. The diff goes from the remote's current
    /// normal target to the local normal target. `None` stands for an absent
    /// target, e.g. when the ref is being created or deleted.
    Update(Diff<Option<CommitId>>),
    /// The local target already equals the remote's tracked target.
    AlreadyMatches,
    /// The local target is conflicted.
    LocalConflicted,
    /// The remote's tracked target is conflicted.
    RemoteConflicted,
    /// The remote ref exists but is not tracked.
    RemoteUntracked,
}
pub fn classify_ref_push_action(targets: LocalAndRemoteRef) -> RefPushAction {
let local_target = targets.local_target;
let remote_target = targets.remote_ref.tracked_target();
if local_target == remote_target {
RefPushAction::AlreadyMatches
} else if local_target.has_conflict() {
RefPushAction::LocalConflicted
} else if remote_target.has_conflict() {
RefPushAction::RemoteConflicted
} else if targets.remote_ref.is_present() && !targets.remote_ref.is_tracked() {
RefPushAction::RemoteUntracked
} else {
RefPushAction::Update(Diff::new(
remote_target.as_normal().cloned(),
local_target.as_normal().cloned(),
))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::op_store::RemoteRefState;

    /// Builds a remote ref in the `RemoteRefState::New` state.
    fn new_remote_ref(target: RefTarget) -> RemoteRef {
        RemoteRef {
            target,
            state: RemoteRefState::New,
        }
    }

    /// Builds a remote ref in the `RemoteRefState::Tracked` state.
    fn tracked_remote_ref(target: RefTarget) -> RemoteRef {
        RemoteRef {
            target,
            state: RemoteRefState::Tracked,
        }
    }

    /// Classifies a borrowed local target / remote ref pair.
    fn classify(local_target: &RefTarget, remote_ref: &RemoteRef) -> RefPushAction {
        classify_ref_push_action(LocalAndRemoteRef {
            local_target,
            remote_ref,
        })
    }

    #[test]
    fn test_classify_ref_push_action_unchanged() {
        let commit_id1 = CommitId::from_hex("11");
        let local = RefTarget::normal(commit_id1.clone());
        let remote = tracked_remote_ref(RefTarget::normal(commit_id1));
        assert_eq!(classify(&local, &remote), RefPushAction::AlreadyMatches);
    }

    #[test]
    fn test_classify_ref_push_action_added() {
        let commit_id1 = CommitId::from_hex("11");
        let local = RefTarget::normal(commit_id1.clone());
        assert_eq!(
            classify(&local, RemoteRef::absent_ref()),
            RefPushAction::Update(Diff::new(None, Some(commit_id1)))
        );
    }

    #[test]
    fn test_classify_ref_push_action_removed() {
        let commit_id1 = CommitId::from_hex("11");
        let remote = tracked_remote_ref(RefTarget::normal(commit_id1.clone()));
        assert_eq!(
            classify(RefTarget::absent_ref(), &remote),
            RefPushAction::Update(Diff::new(Some(commit_id1), None))
        );
    }

    #[test]
    fn test_classify_ref_push_action_updated() {
        let commit_id1 = CommitId::from_hex("11");
        let commit_id2 = CommitId::from_hex("22");
        let local = RefTarget::normal(commit_id2.clone());
        let remote = tracked_remote_ref(RefTarget::normal(commit_id1.clone()));
        assert_eq!(
            classify(&local, &remote),
            RefPushAction::Update(Diff::new(Some(commit_id1), Some(commit_id2)))
        );
    }

    #[test]
    fn test_classify_ref_push_action_removed_untracked() {
        let commit_id1 = CommitId::from_hex("11");
        let remote = new_remote_ref(RefTarget::normal(commit_id1));
        assert_eq!(
            classify(RefTarget::absent_ref(), &remote),
            RefPushAction::AlreadyMatches
        );
    }

    #[test]
    fn test_classify_ref_push_action_updated_untracked() {
        let commit_id1 = CommitId::from_hex("11");
        let commit_id2 = CommitId::from_hex("22");
        let local = RefTarget::normal(commit_id2);
        let remote = new_remote_ref(RefTarget::normal(commit_id1));
        assert_eq!(classify(&local, &remote), RefPushAction::RemoteUntracked);
    }

    #[test]
    fn test_classify_ref_push_action_local_conflicted() {
        let commit_id1 = CommitId::from_hex("11");
        let commit_id2 = CommitId::from_hex("22");
        let local = RefTarget::from_legacy_form([], [commit_id1.clone(), commit_id2]);
        let remote = tracked_remote_ref(RefTarget::normal(commit_id1));
        assert_eq!(classify(&local, &remote), RefPushAction::LocalConflicted);
    }

    #[test]
    fn test_classify_ref_push_action_remote_conflicted() {
        let commit_id1 = CommitId::from_hex("11");
        let commit_id2 = CommitId::from_hex("22");
        let local = RefTarget::normal(commit_id1.clone());
        let remote = tracked_remote_ref(RefTarget::from_legacy_form(
            [],
            [commit_id1, commit_id2],
        ));
        assert_eq!(classify(&local, &remote), RefPushAction::RemoteConflicted);
    }
}