1use claw_core::types::PatchOp;
2use similar::{DiffTag, TextDiff};
3
4use crate::codec::Codec;
5use crate::PatchError;
6
/// Codec that diffs, patches and merges UTF-8 text one line at a time.
///
/// The type carries no state, so it is cheap to copy and can be created with
/// `TextLineCodec` or `TextLineCodec::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TextLineCodec;
9
/// Hashes the window of lines around `center` to fingerprint a patch site.
///
/// The window spans up to three lines before `center`, the line at `center` itself, and up to
/// three lines after it, clipped to the bounds of `lines`. When `center` lies at or beyond the
/// end of `lines` the window is clipped as well (possibly down to nothing) instead of panicking.
///
/// The value is produced by `std::collections::hash_map::DefaultHasher`, whose algorithm the
/// standard library does not promise to keep stable across Rust releases.
fn context_hash(lines: &[&str], center: usize) -> u64 {
    use std::hash::{Hash, Hasher};

    // Clamp the end first and then the start against it: without the second clamp a `center`
    // more than three past the end would give start > end and panic on the slice below.
    // `saturating_add` keeps a huge `center` from overflowing.
    let end = center.saturating_add(4).min(lines.len());
    let start = center.saturating_sub(3).min(end);

    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    for line in &lines[start..end] {
        line.hash(&mut hasher);
    }
    hasher.finish()
}
20
21impl Codec for TextLineCodec {
22 fn id(&self) -> &str {
23 "text/line"
24 }
25
26 fn diff(&self, old: &[u8], new: &[u8]) -> Result<Vec<PatchOp>, PatchError> {
27 let old_str =
28 std::str::from_utf8(old).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
29 let new_str =
30 std::str::from_utf8(new).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
31
32 let diff = TextDiff::from_lines(old_str, new_str);
33 let old_lines: Vec<&str> = old_str.lines().collect();
34 let new_slices = diff.new_slices();
35 let old_slices = diff.old_slices();
36 let mut ops = Vec::new();
37 let mut old_line = 0usize;
38
39 for op in diff.ops() {
40 match op.tag() {
41 DiffTag::Equal => {
42 old_line = op.old_range().end;
43 }
44 DiffTag::Delete => {
45 let range = op.old_range();
46 let deleted: String = old_slices[range.start..range.end].join("");
47 ops.push(PatchOp {
48 address: format!("L{}", range.start),
49 op_type: "delete".to_string(),
50 old_data: Some(deleted.as_bytes().to_vec()),
51 new_data: None,
52 context_hash: Some(context_hash(&old_lines, range.start)),
53 });
54 old_line = range.end;
55 }
56 DiffTag::Insert => {
57 let new_range = op.new_range();
58 let inserted: String = new_slices[new_range.start..new_range.end].join("");
59 ops.push(PatchOp {
60 address: format!("L{}", old_line),
61 op_type: "insert".to_string(),
62 old_data: None,
63 new_data: Some(inserted.as_bytes().to_vec()),
64 context_hash: if !old_lines.is_empty() {
65 Some(context_hash(
66 &old_lines,
67 old_line.min(old_lines.len().saturating_sub(1)),
68 ))
69 } else {
70 None
71 },
72 });
73 }
74 DiffTag::Replace => {
75 let old_range = op.old_range();
76 let new_range = op.new_range();
77 let deleted: String = old_slices[old_range.start..old_range.end].join("");
78 let inserted: String = new_slices[new_range.start..new_range.end].join("");
79 ops.push(PatchOp {
80 address: format!("L{}", old_range.start),
81 op_type: "replace".to_string(),
82 old_data: Some(deleted.as_bytes().to_vec()),
83 new_data: Some(inserted.as_bytes().to_vec()),
84 context_hash: Some(context_hash(&old_lines, old_range.start)),
85 });
86 old_line = old_range.end;
87 }
88 }
89 }
90
91 Ok(ops)
92 }
93
94 fn apply(&self, base: &[u8], ops: &[PatchOp]) -> Result<Vec<u8>, PatchError> {
95 let base_str =
96 std::str::from_utf8(base).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
97 let mut lines: Vec<String> = base_str.lines().map(|l| l.to_string()).collect();
98 let trailing_newline = base_str.ends_with('\n');
100
101 let mut offset: i64 = 0;
102
103 for op in ops {
104 let line_num = parse_line_address(&op.address)?;
105 let adjusted = (line_num as i64 + offset) as usize;
106
107 match op.op_type.as_str() {
108 "delete" => {
109 let old_data = op.old_data.as_ref().ok_or_else(|| {
110 PatchError::ApplyFailed("delete op missing old_data".into())
111 })?;
112 let old_str = std::str::from_utf8(old_data)
113 .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
114 let count = old_str.lines().count().max(1);
115 if adjusted + count > lines.len() {
116 return Err(PatchError::ApplyFailed(format!(
117 "delete out of bounds: {} + {} > {}",
118 adjusted,
119 count,
120 lines.len()
121 )));
122 }
123 lines.drain(adjusted..adjusted + count);
124 offset -= count as i64;
125 }
126 "insert" => {
127 let new_data = op.new_data.as_ref().ok_or_else(|| {
128 PatchError::ApplyFailed("insert op missing new_data".into())
129 })?;
130 let new_str = std::str::from_utf8(new_data)
131 .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
132 let new_lines: Vec<String> = new_str.lines().map(|l| l.to_string()).collect();
133 let count = new_lines.len();
134 let insert_at = adjusted.min(lines.len());
135 for (i, line) in new_lines.into_iter().enumerate() {
136 lines.insert(insert_at + i, line);
137 }
138 offset += count as i64;
139 }
140 "replace" => {
141 let old_data = op.old_data.as_ref().ok_or_else(|| {
142 PatchError::ApplyFailed("replace op missing old_data".into())
143 })?;
144 let old_str = std::str::from_utf8(old_data)
145 .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
146 let del_count = old_str.lines().count().max(1);
147 if adjusted + del_count > lines.len() {
148 return Err(PatchError::ApplyFailed(format!(
149 "replace delete out of bounds: {} + {} > {}",
150 adjusted,
151 del_count,
152 lines.len()
153 )));
154 }
155 lines.drain(adjusted..adjusted + del_count);
156
157 let new_data = op.new_data.as_ref().ok_or_else(|| {
158 PatchError::ApplyFailed("replace op missing new_data".into())
159 })?;
160 let new_str = std::str::from_utf8(new_data)
161 .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
162 let new_lines: Vec<String> = new_str.lines().map(|l| l.to_string()).collect();
163 let ins_count = new_lines.len();
164 let insert_at = adjusted.min(lines.len());
165 for (i, line) in new_lines.into_iter().enumerate() {
166 lines.insert(insert_at + i, line);
167 }
168 offset += ins_count as i64 - del_count as i64;
169 }
170 other => {
171 return Err(PatchError::ApplyFailed(format!("unknown op type: {other}")));
172 }
173 }
174 }
175
176 let mut result = lines.join("\n");
177 if (trailing_newline || base_str.is_empty()) && !result.is_empty() {
178 result.push('\n');
179 }
180 Ok(result.into_bytes())
181 }
182
183 fn invert(&self, ops: &[PatchOp]) -> Result<Vec<PatchOp>, PatchError> {
184 let mut inverted: Vec<PatchOp> = ops
185 .iter()
186 .map(|op| match op.op_type.as_str() {
187 "delete" => PatchOp {
188 address: op.address.clone(),
189 op_type: "insert".to_string(),
190 old_data: None,
191 new_data: op.old_data.clone(),
192 context_hash: op.context_hash,
193 },
194 "insert" => PatchOp {
195 address: op.address.clone(),
196 op_type: "delete".to_string(),
197 old_data: op.new_data.clone(),
198 new_data: None,
199 context_hash: op.context_hash,
200 },
201 "replace" => PatchOp {
202 address: op.address.clone(),
203 op_type: "replace".to_string(),
204 old_data: op.new_data.clone(),
205 new_data: op.old_data.clone(),
206 context_hash: op.context_hash,
207 },
208 _ => op.clone(),
209 })
210 .collect();
211 inverted.reverse();
212 Ok(inverted)
213 }
214
215 fn commute(
216 &self,
217 left: &[PatchOp],
218 right: &[PatchOp],
219 ) -> Result<(Vec<PatchOp>, Vec<PatchOp>), PatchError> {
220 let mut new_right = Vec::new();
222 let mut new_left = Vec::new();
223
224 for r_op in right {
225 let r_line = parse_line_address(&r_op.address)?;
226 let r_count = op_line_count(r_op);
227
228 let mut r_adjusted = r_line as i64;
229 let mut can_commute = true;
230
231 for l_op in left {
232 let l_line = parse_line_address(&l_op.address)?;
233 let l_count = op_line_count(l_op);
234
235 let (l_start, l_end) = match l_op.op_type.as_str() {
237 "delete" | "replace" => (l_line as i64, l_line as i64 + l_count as i64),
238 "insert" => (l_line as i64, l_line as i64),
239 _ => (l_line as i64, l_line as i64),
240 };
241 let (r_start, r_end) = match r_op.op_type.as_str() {
242 "delete" | "replace" => (r_adjusted, r_adjusted + r_count as i64),
243 "insert" => (r_adjusted, r_adjusted),
244 _ => (r_adjusted, r_adjusted),
245 };
246
247 if r_start < l_end && r_end > l_start {
249 can_commute = false;
250 break;
251 }
252
253 if r_start >= l_end {
255 match l_op.op_type.as_str() {
256 "delete" => r_adjusted -= l_count as i64,
257 "insert" => r_adjusted += l_count as i64,
258 _ => {}
259 }
260 }
261 }
262
263 if !can_commute {
264 return Err(PatchError::CommuteFailed);
265 }
266
267 new_right.push(PatchOp {
268 address: format!("L{}", r_adjusted),
269 ..r_op.clone()
270 });
271 }
272
273 for l_op in left {
275 let l_line = parse_line_address(&l_op.address)?;
276 let mut l_adjusted = l_line as i64;
277
278 for r_op in &new_right {
279 let r_line = parse_line_address(&r_op.address)?;
280 let r_count = op_line_count(r_op);
281
282 if (l_adjusted as usize) > r_line {
283 match r_op.op_type.as_str() {
284 "insert" => l_adjusted += r_count as i64,
285 "delete" => l_adjusted -= r_count as i64,
286 _ => {}
287 }
288 }
289 }
290
291 new_left.push(PatchOp {
292 address: format!("L{}", l_adjusted.max(0)),
293 ..l_op.clone()
294 });
295 }
296
297 Ok((new_right, new_left))
298 }
299
300 fn merge3(&self, base: &[u8], left: &[u8], right: &[u8]) -> Result<Vec<u8>, PatchError> {
301 let base_str =
302 std::str::from_utf8(base).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
303 let left_str =
304 std::str::from_utf8(left).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
305 let right_str =
306 std::str::from_utf8(right).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
307
308 let base_lines: Vec<&str> = base_str.lines().collect();
309
310 let left_diff = TextDiff::from_lines(base_str, left_str);
311 let right_diff = TextDiff::from_lines(base_str, right_str);
312
313 let mut left_changes: std::collections::HashMap<usize, Vec<&str>> =
315 std::collections::HashMap::new();
316 let mut right_changes: std::collections::HashMap<usize, Vec<&str>> =
317 std::collections::HashMap::new();
318
319 collect_changes(&left_diff, &mut left_changes);
320 collect_changes(&right_diff, &mut right_changes);
321
322 let mut result = Vec::new();
323 let mut i = 0;
324
325 while i < base_lines.len() {
326 let left_changed = left_changes.contains_key(&i);
327 let right_changed = right_changes.contains_key(&i);
328
329 match (left_changed, right_changed) {
330 (false, false) => {
331 result.push(base_lines[i].to_string());
332 i += 1;
333 }
334 (true, false) => {
335 if let Some(replacement) = left_changes.get(&i) {
336 result.extend(replacement.iter().map(|s| s.to_string()));
337 }
338 i += 1;
339 }
340 (false, true) => {
341 if let Some(replacement) = right_changes.get(&i) {
342 result.extend(replacement.iter().map(|s| s.to_string()));
343 }
344 i += 1;
345 }
346 (true, true) => {
347 let left_rep = left_changes.get(&i);
348 let right_rep = right_changes.get(&i);
349 if left_rep == right_rep {
350 if let Some(replacement) = left_rep {
351 result.extend(replacement.iter().map(|s| s.to_string()));
352 }
353 } else {
354 return Err(PatchError::Merge3Failed(format!(
355 "conflict at line {i}: both sides changed differently"
356 )));
357 }
358 i += 1;
359 }
360 }
361 }
362
363 let max_base = base_lines.len();
365 if let Some(appended) = left_changes.get(&max_base) {
366 result.extend(appended.iter().map(|s| s.to_string()));
367 }
368 if let Some(appended) = right_changes.get(&max_base) {
369 result.extend(appended.iter().map(|s| s.to_string()));
370 }
371
372 let mut output = result.join("\n");
373 let left_trailing = left_str.ends_with('\n');
374 let right_trailing = right_str.ends_with('\n');
375 if (left_trailing || right_trailing) && !output.is_empty() {
376 output.push('\n');
377 }
378 Ok(output.into_bytes())
379 }
380}
381
382fn collect_changes<'a>(
383 diff: &TextDiff<'a, 'a, 'a, str>,
384 changes: &mut std::collections::HashMap<usize, Vec<&'a str>>,
385) {
386 for op in diff.ops() {
387 match op.tag() {
388 similar::DiffTag::Equal => {}
389 similar::DiffTag::Delete | similar::DiffTag::Replace | similar::DiffTag::Insert => {
390 let old_range = op.old_range();
391 let new_range = op.new_range();
392 let new_text: Vec<&str> = diff.new_slices()[new_range.start..new_range.end]
393 .iter()
394 .flat_map(|s| s.lines())
395 .collect();
396 let key = old_range.start;
397 changes.insert(key, new_text);
398 }
399 }
400 }
401}
402
403fn parse_line_address(addr: &str) -> Result<usize, PatchError> {
404 addr.strip_prefix('L')
405 .and_then(|n| n.parse::<usize>().ok())
406 .ok_or_else(|| PatchError::AddressResolutionFailed(format!("invalid line address: {addr}")))
407}
408
409fn op_line_count(op: &PatchOp) -> usize {
410 match op.op_type.as_str() {
411 "delete" | "replace" => {
412 if let Some(data) = &op.old_data {
413 std::str::from_utf8(data)
414 .map(|s| s.lines().count().max(1))
415 .unwrap_or(1)
416 } else {
417 1
418 }
419 }
420 "insert" => {
421 if let Some(data) = &op.new_data {
422 std::str::from_utf8(data)
423 .map(|s| s.lines().count().max(1))
424 .unwrap_or(1)
425 } else {
426 1
427 }
428 }
429 _ => 0,
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    /// Applying the diff of two documents to the first one must produce the second.
    #[test]
    fn diff_and_apply_roundtrip() {
        let before = b"line1\nline2\nline3\n";
        let after = b"line1\nmodified\nline3\nextra\n";

        let patch = TextLineCodec.diff(before, after).unwrap();
        let patched = TextLineCodec.apply(before, &patch).unwrap();

        assert_eq!(patched, after);
    }

    /// Applying a patch and then its inverse must give back the starting document.
    #[test]
    fn invert_cancels_patch() {
        let before = b"a\nb\nc\n";
        let after = b"a\nx\nc\n";

        let patch = TextLineCodec.diff(before, after).unwrap();
        let forward = TextLineCodec.apply(before, &patch).unwrap();
        assert_eq!(forward, after);

        let undo = TextLineCodec.invert(&patch).unwrap();
        let backward = TextLineCodec.apply(after, &undo).unwrap();
        assert_eq!(backward, before);
    }

    /// Edits to different lines must both survive the merge.
    #[test]
    fn merge3_no_conflict() {
        let ancestor = b"line1\nline2\nline3\n";
        let ours = b"line1\nleft_change\nline3\n";
        let theirs = b"line1\nline2\nright_change\n";

        let merged = TextLineCodec.merge3(ancestor, ours, theirs).unwrap();
        let text = std::str::from_utf8(&merged).unwrap();

        for expected in ["left_change", "right_change"] {
            assert!(text.contains(expected));
        }
    }

    /// Different edits to the same line must be reported as a conflict.
    #[test]
    fn merge3_conflict() {
        let ancestor = b"line1\nline2\nline3\n";
        let ours = b"line1\nleft_change\nline3\n";
        let theirs = b"line1\nright_change\nline3\n";

        let outcome = TextLineCodec.merge3(ancestor, ours, theirs);
        assert!(outcome.is_err());
    }
}