use crate::operations::{FileInfo, VersionInfo};
/// Classification of how a local and a remote history of the same file disagree,
/// as stored in [`ConflictReport::conflict_type`].
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum ConflictType {
/// The histories share a (possibly empty) prefix and then branch apart.
Divergent {
/// `version_number` of the last version whose `rules_hash` both sides agree on;
/// 0 when they share no version (see `find_common_ancestor`).
common_ancestor: u64,
/// `current_version` of the local file.
local_version: u64,
/// `current_version` of the remote file.
remote_version: u64,
},
/// The two sides disagree on content at a single point.
///
/// For equal-length histories `version` is the local `current_version`, and the hashes are
/// the hex of the first 8 bytes of each side's newest `rules_hash` (or `"empty"`).
/// [`detect_conflict`] also uses this variant when the file ids differ. In that case
/// `version` is 0 and the "hashes" are the 16-digit hex file ids.
ContentMismatch {
version: u64,
local_hash: String,
remote_hash: String,
},
/// A version number other than the expected one was encountered
/// (Display: "Version gap: expected {expected}, found {found}").
/// NOTE(review): not produced by [`detect_conflict`]; confirm semantics with its producers.
VersionGap {
expected: u64,
found: u64,
},
/// No conflict: the newest versions match, or one history linearly extends the other.
None,
}
impl std::fmt::Display for ConflictType {
    /// Writes a one-line, human-readable description of the conflict.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::None => write!(f, "No conflict"),
            Self::VersionGap { expected, found } => {
                write!(f, "Version gap: expected {expected}, found {found}")
            }
            Self::ContentMismatch {
                version,
                local_hash,
                remote_hash,
            } => write!(
                f,
                "Content mismatch at version {version}: local={local_hash}, remote={remote_hash}"
            ),
            Self::Divergent {
                common_ancestor,
                local_version,
                remote_version,
            } => write!(
                f,
                "Divergent histories from version {common_ancestor}: local={local_version}, remote={remote_version}"
            ),
        }
    }
}
/// Outcome of comparing a local and a remote history with [`detect_conflict`].
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ConflictReport {
/// What kind of disagreement was found, or [`ConflictType::None`].
pub conflict_type: ConflictType,
/// Copied from the local `FileInfo::version_count`.
pub local_version_count: u64,
/// Copied from the remote `FileInfo::version_count`.
pub remote_version_count: u64,
/// Recommended resolution. [`detect_conflict`] suggests [`MergeStrategy::Manual`] for every
/// conflict it reports.
pub suggested_strategy: MergeStrategy,
}
/// How two histories should be reconciled; see [`ConflictReport::suggested_strategy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum MergeStrategy {
/// Keep the local side. Suggested when the newest versions match, or when local
/// linearly extends remote.
KeepLocal,
/// Keep the remote side. Suggested when remote linearly extends local.
KeepRemote,
/// Display text "Keep newest"; presumably keeps the more recent side — confirm with callers.
/// Not suggested by [`detect_conflict`].
KeepNewest,
/// A human has to resolve the conflict. Suggested for every conflict `detect_conflict` finds.
Manual,
/// Display text "Append remote versions". Not suggested by [`detect_conflict`].
Append,
}
impl std::fmt::Display for MergeStrategy {
    /// Writes the user-facing label for this strategy.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::KeepLocal => "Keep local",
            Self::KeepRemote => "Keep remote",
            Self::KeepNewest => "Keep newest",
            Self::Manual => "Manual merge required",
            Self::Append => "Append remote versions",
        };
        f.write_str(label)
    }
}
/// Compares a local and a remote history of a file and reports whether they conflict.
///
/// - Different `file_id`s: reported as a `ContentMismatch` of the ids, with a manual merge.
/// - Equal `version_count`s: the sides are compared by the `rules_hash` of their newest
///   version only.
/// - Different `version_count`s: there is no conflict if the shorter history is a prefix
///   (by `rules_hash`) of the longer one; otherwise the report is `Divergent`.
#[must_use]
pub fn detect_conflict(local: &FileInfo, remote: &FileInfo) -> ConflictReport {
    if local.file_id != remote.file_id {
        file_id_mismatch_report(local, remote)
    } else if local.version_count == remote.version_count {
        same_length_report(local, remote)
    } else {
        differing_length_report(local, remote)
    }
}
/// Builds the report for two files whose ids differ.
///
/// The ids take the place of content hashes (16-digit zero-padded hex). The version is 0
/// because there is no shared history to point at. Always requires a manual merge.
fn file_id_mismatch_report(local: &FileInfo, remote: &FileInfo) -> ConflictReport {
    let hex_id = |id| format!("{id:016x}");
    let conflict_type = ConflictType::ContentMismatch {
        version: 0,
        local_hash: hex_id(&local.file_id),
        remote_hash: hex_id(&remote.file_id),
    };
    ConflictReport {
        conflict_type,
        local_version_count: local.version_count,
        remote_version_count: remote.version_count,
        suggested_strategy: MergeStrategy::Manual,
    }
}
/// Builds the report for two histories with the same `version_count`.
///
/// Only the newest versions are compared (via `versions_match`):
/// - Matching tips mean no conflict, and `KeepLocal` is suggested.
/// - Differing tips yield a `ContentMismatch` at the local `current_version`, which needs a
///   manual merge.
fn same_length_report(local: &FileInfo, remote: &FileInfo) -> ConflictReport {
    let (conflict_type, suggested_strategy) = if versions_match(local, remote) {
        (ConflictType::None, MergeStrategy::KeepLocal)
    } else {
        let mismatch = ConflictType::ContentMismatch {
            version: local.current_version,
            local_hash: format_version_hash(local),
            remote_hash: format_version_hash(remote),
        };
        (mismatch, MergeStrategy::Manual)
    };
    ConflictReport {
        conflict_type,
        local_version_count: local.version_count,
        remote_version_count: remote.version_count,
        suggested_strategy,
    }
}
/// Builds the report for two histories with different `version_count`s.
///
/// If the shorter history is a prefix of the longer one, there is no conflict and the side
/// that is ahead is kept. Otherwise the histories are `Divergent` from their last shared
/// version, and a manual merge is required.
fn differing_length_report(local: &FileInfo, remote: &FileInfo) -> ConflictReport {
    let local_is_behind = local.version_count < remote.version_count;
    let (shorter, longer, keep_ahead) = if local_is_behind {
        (local, remote, MergeStrategy::KeepRemote)
    } else {
        (remote, local, MergeStrategy::KeepLocal)
    };

    let (conflict_type, suggested_strategy) = if is_linear_extension(shorter, longer) {
        (ConflictType::None, keep_ahead)
    } else {
        let divergence = ConflictType::Divergent {
            common_ancestor: find_common_ancestor(local, remote),
            local_version: local.current_version,
            remote_version: remote.current_version,
        };
        (divergence, MergeStrategy::Manual)
    };

    ConflictReport {
        conflict_type,
        local_version_count: local.version_count,
        remote_version_count: remote.version_count,
        suggested_strategy,
    }
}
/// Returns whether the newest versions of both histories carry the same `rules_hash`.
///
/// Two empty histories match; an empty history never matches a non-empty one.
fn versions_match(local: &FileInfo, remote: &FileInfo) -> bool {
    let local_tip = local.versions.last().map(|version| &version.rules_hash);
    let remote_tip = remote.versions.last().map(|version| &version.rules_hash);
    local_tip == remote_tip
}
/// Renders a short hex fingerprint of the newest version's `rules_hash`.
///
/// Returns `"empty"` when `info` has no versions. Only the first 8 bytes of the hash are
/// encoded. A hash shorter than that is encoded in full rather than panicking on an
/// out-of-range slice.
fn format_version_hash(info: &FileInfo) -> String {
    info.versions.last().map_or_else(
        || "empty".to_string(),
        |latest| {
            let prefix = latest
                .rules_hash
                .get(..8)
                .unwrap_or(&latest.rules_hash[..]);
            hex::encode(prefix)
        },
    )
}
/// Returns whether `longer` starts with every version of `shorter`, compared by `rules_hash`.
///
/// A `shorter` history that actually has more versions than `longer` is never an extension.
fn is_linear_extension(shorter: &FileInfo, longer: &FileInfo) -> bool {
    shorter.versions.len() <= longer.versions.len()
        && shorter
            .versions
            .iter()
            .zip(longer.versions.iter())
            .all(|(prefix, candidate)| prefix.rules_hash == candidate.rules_hash)
}
/// Returns the `version_number` of the last version in the shared prefix of both histories.
///
/// Versions are paired by position and compared by `rules_hash`. Returns 0 if the histories
/// already differ at their first version, or if either history is empty.
fn find_common_ancestor(local: &FileInfo, remote: &FileInfo) -> u64 {
    let last_shared: Option<&VersionInfo> = local
        .versions
        .iter()
        .zip(remote.versions.iter())
        .take_while(|(ours, theirs)| ours.rules_hash == theirs.rules_hash)
        .map(|(ours, _)| ours)
        .last();
    last_shared.map_or(0, |version| version.version_number)
}
/// Location of one conflict-marker line within some content.
///
/// NOTE(review): neither `create_conflict_markers` nor `parse_conflict_markers` uses this type.
/// `start`/`end` are presumably byte offsets of the marker, and `source` presumably the label
/// following it — confirm with its producers.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ConflictMarker {
/// Which marker line this is.
pub marker_type: MarkerType,
/// Start position of the marker (unit unconfirmed — see note above).
pub start: usize,
/// End position of the marker (unit and inclusivity unconfirmed).
pub end: usize,
/// Text associated with the marker; presumably the side's label — confirm.
pub source: String,
}
/// Which of the three conflict-marker lines a [`ConflictMarker`] refers to.
#[derive(Debug, Clone, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub enum MarkerType {
/// Presumably the `<<<<<<< ` line that opens the local side — confirm.
ConflictStart,
/// Presumably the `=======` line between the local and remote sides — confirm.
Separator,
/// Presumably the `>>>>>>> ` line that closes the remote side — confirm.
ConflictEnd,
}
/// Wraps both sides of a conflict in git-style markers.
///
/// Output layout:
///
/// ```text
/// <<<<<<< {local_label}
/// {local_content}
/// =======
/// {remote_content}
/// >>>>>>> {remote_label}
/// ```
///
/// A newline is appended to either content only when it does not already end with one, so
/// each marker always starts its own line. Labels are written verbatim, followed by a newline.
#[must_use]
pub fn create_conflict_markers(
    local_content: &[u8],
    remote_content: &[u8],
    local_label: &str,
    remote_label: &str,
) -> Vec<u8> {
    /// Appends `marker`, a space, `label`, and an unconditional newline.
    fn push_marker_line(out: &mut Vec<u8>, marker: &[u8], label: &str) {
        out.extend_from_slice(marker);
        out.push(b' ');
        out.extend_from_slice(label.as_bytes());
        out.push(b'\n');
    }

    /// Appends `body`, then a newline unless `body` already ends with one.
    fn push_terminated(out: &mut Vec<u8>, body: &[u8]) {
        out.extend_from_slice(body);
        if body.last() != Some(&b'\n') {
            out.push(b'\n');
        }
    }

    let mut out = Vec::new();
    push_marker_line(&mut out, b"<<<<<<<", local_label);
    push_terminated(&mut out, local_content);
    out.extend_from_slice(b"=======\n");
    push_terminated(&mut out, remote_content);
    push_marker_line(&mut out, b">>>>>>>", remote_label);
    out
}
/// Heuristically reports whether `content` contains conflict markers.
///
/// Returns `true` when both a run of seven `<` and a run of seven `>` appear anywhere in the
/// bytes. Their order and position within lines are not checked. The scan runs over the raw
/// bytes, so no UTF-8 decoding or allocation is needed.
#[must_use]
pub fn has_conflict_markers(content: &[u8]) -> bool {
    let contains = |needle: &[u8]| content.windows(needle.len()).any(|window| window == needle);
    contains(b"<<<<<<<") && contains(b">>>>>>>")
}
/// Extracts the local and remote sides of the first conflict block in `content`.
///
/// The expected layout is the one written by [`create_conflict_markers`]:
///
/// ```text
/// <<<<<<< <local label>
/// <local lines>
/// =======
/// <remote lines>
/// >>>>>>> <remote label>
/// ```
///
/// Each marker must begin a line. The separator must be exactly `=======` (LF or CRLF
/// terminated). Markers are located in order: the separator after the start line, then the
/// end marker after the separator. This keeps marker-like text inside a label or inside either
/// side from derailing the parse. Content is handled as raw bytes, so non-UTF-8 data is
/// returned unchanged. Trailing ASCII whitespace, including the line break before the next
/// marker, is trimmed from each side.
///
/// Returns `None` if any of the three markers is missing or out of order.
#[must_use]
pub fn parse_conflict_markers(content: &[u8]) -> Option<(Vec<u8>, Vec<u8>)> {
    const START: &[u8] = b"<<<<<<< ";
    const SEPARATOR: &[u8] = b"=======";
    const END: &[u8] = b">>>>>>> ";

    /// Index just past the `\n` ending the line that starts at `line_start`;
    /// `None` if that line has no terminating newline.
    fn next_line_start(content: &[u8], line_start: usize) -> Option<usize> {
        let newline = content
            .get(line_start..)?
            .iter()
            .position(|&b| b == b'\n')?;
        line_start.checked_add(newline)?.checked_add(1)
    }

    /// Start index of the first line at or after `from` (itself a line start) for which
    /// `is_marker` holds when given the remainder of the content from that line on.
    fn find_marker_line(
        content: &[u8],
        from: usize,
        is_marker: impl Fn(&[u8]) -> bool,
    ) -> Option<usize> {
        let mut line_start = from;
        loop {
            if is_marker(content.get(line_start..)?) {
                return Some(line_start);
            }
            line_start = next_line_start(content, line_start)?;
        }
    }

    /// Drops trailing ASCII whitespace without allocating.
    fn trim_trailing_whitespace(mut bytes: &[u8]) -> &[u8] {
        while let [rest @ .., last] = bytes {
            if !last.is_ascii_whitespace() {
                break;
            }
            bytes = rest;
        }
        bytes
    }

    let start_idx = find_marker_line(content, 0, |line| line.starts_with(START))?;
    // The local side begins on the line after the start marker's label.
    let local_begin = next_line_start(content, start_idx)?;
    // Require the separator to be the whole line so e.g. a `========` underline is not taken
    // for it.
    let separator_idx = find_marker_line(content, local_begin, |line| {
        line.starts_with(SEPARATOR)
            && matches!(line.get(SEPARATOR.len()), None | Some(b'\n' | b'\r'))
    })?;
    let remote_begin = next_line_start(content, separator_idx)?;
    let end_idx = find_marker_line(content, remote_begin, |line| line.starts_with(END))?;

    let local = content.get(local_begin..separator_idx)?;
    let remote = content.get(remote_begin..end_idx)?;
    Some((
        trim_trailing_whitespace(local).to_vec(),
        trim_trailing_whitespace(remote).to_vec(),
    ))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_conflict_markers() {
        let local = b"local version";
        let remote = b"remote version";
        let merged = create_conflict_markers(local, remote, "LOCAL", "REMOTE");
        assert!(has_conflict_markers(&merged));
        // Compare the whole Option: a parse failure is then an ordinary test failure with a
        // message, instead of `process::abort` taking down the entire test binary.
        assert_eq!(
            parse_conflict_markers(&merged),
            Some((local.to_vec(), remote.to_vec()))
        );
    }

    #[test]
    fn test_conflict_markers_empty_sides() {
        let merged = create_conflict_markers(b"", b"", "LOCAL", "REMOTE");
        assert!(has_conflict_markers(&merged));
        assert_eq!(
            parse_conflict_markers(&merged),
            Some((Vec::new(), Vec::new()))
        );
    }

    #[test]
    fn test_no_conflict_markers() {
        let content = b"normal content without markers";
        assert!(!has_conflict_markers(content));
        assert_eq!(parse_conflict_markers(content), None);
    }

    #[test]
    fn test_merge_strategy_display() {
        assert_eq!(MergeStrategy::KeepLocal.to_string(), "Keep local");
        assert_eq!(MergeStrategy::KeepRemote.to_string(), "Keep remote");
        assert_eq!(MergeStrategy::KeepNewest.to_string(), "Keep newest");
        assert_eq!(MergeStrategy::Manual.to_string(), "Manual merge required");
        assert_eq!(MergeStrategy::Append.to_string(), "Append remote versions");
    }

    #[test]
    fn test_conflict_type_display() {
        let conflict = ConflictType::Divergent {
            common_ancestor: 5,
            local_version: 7,
            remote_version: 8,
        };
        let display = conflict.to_string();
        assert!(display.contains("Divergent"));
        assert!(display.contains('5'));
        assert_eq!(
            display,
            "Divergent histories from version 5: local=7, remote=8"
        );
        assert_eq!(
            ConflictType::VersionGap {
                expected: 3,
                found: 9,
            }
            .to_string(),
            "Version gap: expected 3, found 9"
        );
        assert_eq!(ConflictType::None.to_string(), "No conflict");
    }
}