1use crate::error::{CollabError, Result};
4use serde::{Deserialize, Serialize};
5use similar::{ChangeTag, TextDiff};
6
/// Strategy used to reconcile two diverging versions of a JSON value.
///
/// Variants are (de)serialized under their lowercase names
/// (`"ours"`, `"theirs"`, `"auto"`, `"manual"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MergeStrategy {
    /// Adopt our version wholesale; no conflicts are reported.
    Ours,
    /// Adopt their version wholesale; no conflicts are reported.
    Theirs,
    /// Merge object keys automatically and report a conflict for every key
    /// that cannot be reconciled.
    Auto,
    /// Do not merge: keep our version as the provisional result and list
    /// every difference between the two sides as a conflict.
    Manual,
}
20
/// Outcome of reconciling two versions of a JSON value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConflictResolution {
    /// `true` when `conflicts` is non-empty.
    pub has_conflicts: bool,
    /// The merged value. Where a conflict was found, our side is kept as the
    /// provisional value.
    pub resolved: serde_json::Value,
    /// Every difference that could not be reconciled automatically.
    pub conflicts: Vec<Conflict>,
}
31
/// A single location at which the two sides disagree.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Conflict {
    /// Location of the disagreement: an object key, a dot-separated chain of
    /// keys for nested objects, or the empty string for the root value.
    pub path: String,
    /// The value on our side.
    pub ours: serde_json::Value,
    /// The value on their side.
    pub theirs: serde_json::Value,
    /// The value in the common ancestor, when one was supplied and contains
    /// this location.
    pub base: Option<serde_json::Value>,
}
44
/// Reconciles two diverging versions of a JSON value or of a text document.
pub struct ConflictResolver {
    /// Strategy used by `resolve` when the caller does not pass one.
    default_strategy: MergeStrategy,
}
50
51impl ConflictResolver {
    /// Creates a resolver that falls back to `default_strategy` whenever
    /// `resolve` is called without an explicit strategy.
    #[must_use]
    pub const fn new(default_strategy: MergeStrategy) -> Self {
        Self { default_strategy }
    }
57
58 pub fn resolve(
60 &self,
61 base: Option<&serde_json::Value>,
62 ours: &serde_json::Value,
63 theirs: &serde_json::Value,
64 strategy: Option<MergeStrategy>,
65 ) -> Result<ConflictResolution> {
66 let strategy = strategy.unwrap_or(self.default_strategy);
67
68 if ours == theirs {
70 return Ok(ConflictResolution {
71 has_conflicts: false,
72 resolved: ours.clone(),
73 conflicts: Vec::new(),
74 });
75 }
76
77 match strategy {
78 MergeStrategy::Ours => Ok(ConflictResolution {
79 has_conflicts: false,
80 resolved: ours.clone(),
81 conflicts: Vec::new(),
82 }),
83 MergeStrategy::Theirs => Ok(ConflictResolution {
84 has_conflicts: false,
85 resolved: theirs.clone(),
86 conflicts: Vec::new(),
87 }),
88 MergeStrategy::Auto => self.auto_merge(base, ours, theirs),
89 MergeStrategy::Manual => {
90 let conflicts = self.detect_conflicts("", base, ours, theirs);
92 Ok(ConflictResolution {
93 has_conflicts: !conflicts.is_empty(),
94 resolved: ours.clone(), conflicts,
96 })
97 }
98 }
99 }
100
101 fn auto_merge(
103 &self,
104 base: Option<&serde_json::Value>,
105 ours: &serde_json::Value,
106 theirs: &serde_json::Value,
107 ) -> Result<ConflictResolution> {
108 match (ours, theirs) {
110 (serde_json::Value::Object(ours_obj), serde_json::Value::Object(theirs_obj)) => {
111 let mut resolved = serde_json::Map::new();
112 let mut conflicts = Vec::new();
113
114 let base_obj = base.and_then(|b| b.as_object());
115
116 let all_keys: std::collections::HashSet<_> =
118 ours_obj.keys().chain(theirs_obj.keys()).collect();
119
120 for key in all_keys {
121 let ours_val = ours_obj.get(key);
122 let theirs_val = theirs_obj.get(key);
123 let base_val = base_obj.and_then(|b| b.get(key));
124
125 match (ours_val, theirs_val) {
126 (Some(o), Some(t)) if o == t => {
127 resolved.insert(key.clone(), o.clone());
129 }
130 (Some(o), Some(t)) => {
131 if let Some(base_val) = base_val {
133 if o == base_val {
134 resolved.insert(key.clone(), t.clone());
136 } else if t == base_val {
137 resolved.insert(key.clone(), o.clone());
139 } else {
140 conflicts.push(Conflict {
142 path: key.clone(),
143 ours: o.clone(),
144 theirs: t.clone(),
145 base: Some(base_val.clone()),
146 });
147 resolved.insert(key.clone(), o.clone()); }
149 } else {
150 conflicts.push(Conflict {
152 path: key.clone(),
153 ours: o.clone(),
154 theirs: t.clone(),
155 base: None,
156 });
157 resolved.insert(key.clone(), o.clone());
158 }
159 }
160 (Some(o), None) => {
161 resolved.insert(key.clone(), o.clone());
163 }
164 (None, Some(t)) => {
165 resolved.insert(key.clone(), t.clone());
167 }
168 (None, None) => unreachable!(),
169 }
170 }
171
172 Ok(ConflictResolution {
173 has_conflicts: !conflicts.is_empty(),
174 resolved: serde_json::Value::Object(resolved),
175 conflicts,
176 })
177 }
178 _ => {
179 Ok(ConflictResolution {
181 has_conflicts: true,
182 resolved: ours.clone(),
183 conflicts: vec![Conflict {
184 path: String::new(),
185 ours: ours.clone(),
186 theirs: theirs.clone(),
187 base: base.cloned(),
188 }],
189 })
190 }
191 }
192 }
193
194 fn detect_conflicts(
196 &self,
197 path: &str,
198 base: Option<&serde_json::Value>,
199 ours: &serde_json::Value,
200 theirs: &serde_json::Value,
201 ) -> Vec<Conflict> {
202 let mut conflicts = Vec::new();
203
204 if ours == theirs {
205 return conflicts;
206 }
207
208 match (ours, theirs) {
209 (serde_json::Value::Object(ours_obj), serde_json::Value::Object(theirs_obj)) => {
210 let base_obj = base.and_then(|b| b.as_object());
211 let all_keys: std::collections::HashSet<_> =
212 ours_obj.keys().chain(theirs_obj.keys()).collect();
213
214 for key in all_keys {
215 let new_path = if path.is_empty() {
216 key.clone()
217 } else {
218 format!("{path}.{key}")
219 };
220
221 let ours_val = ours_obj.get(key);
222 let theirs_val = theirs_obj.get(key);
223 let base_val = base_obj.and_then(|b| b.get(key));
224
225 if let (Some(o), Some(t)) = (ours_val, theirs_val) {
226 conflicts.extend(self.detect_conflicts(&new_path, base_val, o, t));
227 } else if ours_val != theirs_val {
228 conflicts.push(Conflict {
229 path: new_path,
230 ours: ours_val.cloned().unwrap_or(serde_json::Value::Null),
231 theirs: theirs_val.cloned().unwrap_or(serde_json::Value::Null),
232 base: base_val.cloned(),
233 });
234 }
235 }
236 }
237 _ => {
238 conflicts.push(Conflict {
239 path: path.to_string(),
240 ours: ours.clone(),
241 theirs: theirs.clone(),
242 base: base.cloned(),
243 });
244 }
245 }
246
247 conflicts
248 }
249
250 pub fn merge_text(&self, base: &str, ours: &str, theirs: &str) -> Result<String> {
257 if ours == theirs {
258 return Ok(ours.to_string());
259 }
260
261 let base_lines: Vec<&str> = base.lines().collect();
263 let ours_lines: Vec<&str> = ours.lines().collect();
264 let theirs_lines: Vec<&str> = theirs.lines().collect();
265
266 let diff_ours = TextDiff::from_lines(base, ours);
268 let diff_theirs = TextDiff::from_lines(base, theirs);
269
270 let ours_changes = Self::collect_line_changes(&diff_ours);
272 let theirs_changes = Self::collect_line_changes(&diff_theirs);
273
274 let mut result = String::new();
275 let mut has_conflict = false;
276
277 for (i, base_line) in base_lines.iter().enumerate() {
279 let ours_changed = ours_changes.get(&i);
280 let theirs_changed = theirs_changes.get(&i);
281
282 match (ours_changed, theirs_changed) {
283 (None, None) => {
284 result.push_str(base_line);
286 result.push('\n');
287 }
288 (Some(ours_replacement), None) => {
289 for line in ours_replacement {
291 result.push_str(line);
292 result.push('\n');
293 }
294 }
295 (None, Some(theirs_replacement)) => {
296 for line in theirs_replacement {
298 result.push_str(line);
299 result.push('\n');
300 }
301 }
302 (Some(ours_replacement), Some(theirs_replacement)) => {
303 if ours_replacement == theirs_replacement {
304 for line in ours_replacement {
306 result.push_str(line);
307 result.push('\n');
308 }
309 } else {
310 has_conflict = true;
312 }
313 }
314 }
315 }
316
317 for line in ours_lines.iter().skip(base_lines.len()) {
319 result.push_str(line);
320 result.push('\n');
321 }
322 for line in theirs_lines.iter().skip(base_lines.len()) {
324 result.push_str(line);
325 result.push('\n');
326 }
327
328 if has_conflict {
329 Err(CollabError::ConflictDetected("Text merge conflict".to_string()))
330 } else {
331 Ok(result)
332 }
333 }
334
335 fn collect_line_changes<'a>(
340 diff: &TextDiff<'a, 'a, 'a, str>,
341 ) -> std::collections::HashMap<usize, Vec<&'a str>> {
342 let mut changes: std::collections::HashMap<usize, Vec<&str>> =
343 std::collections::HashMap::new();
344 let mut base_idx: usize = 0;
345
346 for change in diff.iter_all_changes() {
347 match change.tag() {
348 ChangeTag::Equal => {
349 base_idx += 1;
350 }
351 ChangeTag::Delete => {
352 changes.entry(base_idx).or_default();
354 base_idx += 1;
355 }
356 ChangeTag::Insert => {
357 let idx = if base_idx > 0 { base_idx - 1 } else { 0 };
359 changes.entry(idx).or_default().push(change.value().trim_end_matches('\n'));
360 }
361 }
362 }
363
364 changes
365 }
366}
367
368impl Default for ConflictResolver {
369 fn default() -> Self {
370 Self::new(MergeStrategy::Auto)
371 }
372}
373
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Identical inputs resolve cleanly whatever the strategy.
    #[test]
    fn test_no_conflict() {
        let resolver = ConflictResolver::default();
        let value = json!({"key": "value"});

        let result = resolver.resolve(None, &value, &value, None).unwrap();

        assert!(!result.has_conflicts);
        assert_eq!(result.resolved, value);
        assert!(result.conflicts.is_empty());
    }

    #[test]
    fn test_strategy_ours() {
        let resolver = ConflictResolver::default();
        let ours = json!({"key": "ours"});
        let theirs = json!({"key": "theirs"});

        let result = resolver
            .resolve(None, &ours, &theirs, Some(MergeStrategy::Ours))
            .unwrap();

        assert!(!result.has_conflicts);
        assert_eq!(result.resolved, ours);
    }

    #[test]
    fn test_strategy_theirs() {
        let resolver = ConflictResolver::default();
        let ours = json!({"key": "ours"});
        let theirs = json!({"key": "theirs"});

        let result = resolver
            .resolve(None, &ours, &theirs, Some(MergeStrategy::Theirs))
            .unwrap();

        assert!(!result.has_conflicts);
        assert_eq!(result.resolved, theirs);
    }

    /// Keys added on different sides are combined without a base.
    #[test]
    fn test_auto_merge_no_base() {
        let resolver = ConflictResolver::default();
        let ours = json!({"key1": "value1"});
        let theirs = json!({"key2": "value2"});

        let result = resolver
            .resolve(None, &ours, &theirs, Some(MergeStrategy::Auto))
            .unwrap();

        assert!(!result.has_conflicts);
        assert_eq!(result.resolved["key1"], "value1");
        assert_eq!(result.resolved["key2"], "value2");
    }

    /// A key changed on one side only takes that side's value.
    #[test]
    fn test_auto_merge_with_base() {
        let resolver = ConflictResolver::default();
        let base = json!({"key": "base"});
        let ours = json!({"key": "ours"});
        let theirs = json!({"key": "base"});

        let result = resolver
            .resolve(Some(&base), &ours, &theirs, Some(MergeStrategy::Auto))
            .unwrap();

        assert!(!result.has_conflicts);
        assert_eq!(result.resolved["key"], "ours");
    }

    /// A key changed differently on both sides is reported as a conflict.
    #[test]
    fn test_conflict_detection() {
        let resolver = ConflictResolver::default();
        let base = json!({"key": "base"});
        let ours = json!({"key": "ours"});
        let theirs = json!({"key": "theirs"});

        let result = resolver
            .resolve(Some(&base), &ours, &theirs, Some(MergeStrategy::Auto))
            .unwrap();

        assert!(result.has_conflicts);
        assert_eq!(result.conflicts.len(), 1);
        assert_eq!(result.conflicts[0].path, "key");
    }

    /// Manual mode keeps our side and reports nested differences by dotted path.
    #[test]
    fn test_manual_reports_nested_paths() {
        let resolver = ConflictResolver::default();
        let ours = json!({"outer": {"inner": 1}});
        let theirs = json!({"outer": {"inner": 2}});

        let result = resolver
            .resolve(None, &ours, &theirs, Some(MergeStrategy::Manual))
            .unwrap();

        assert!(result.has_conflicts);
        assert_eq!(result.conflicts.len(), 1);
        assert_eq!(result.conflicts[0].path, "outer.inner");
        assert_eq!(result.resolved, ours);
    }

    #[test]
    fn test_merge_text_identical_sides() {
        let resolver = ConflictResolver::default();

        let merged = resolver.merge_text("a\n", "b\n", "b\n").unwrap();

        assert_eq!(merged, "b\n");
    }

    /// Edits to different lines are combined.
    #[test]
    fn test_merge_text_disjoint_changes() {
        let resolver = ConflictResolver::default();

        let merged = resolver.merge_text("a\nb\n", "A\nb\n", "a\nB\n").unwrap();

        assert_eq!(merged, "A\nB\n");
    }

    /// Different edits to the same line cannot be merged.
    #[test]
    fn test_merge_text_conflict() {
        let resolver = ConflictResolver::default();

        assert!(resolver.merge_text("a\n", "b\n", "c\n").is_err());
    }
}