Skip to main content

claw_patch/
text_line.rs

1use claw_core::types::PatchOp;
2use similar::{DiffTag, TextDiff};
3
4use crate::codec::Codec;
5use crate::PatchError;
6
/// Codec for UTF-8 text using line-based patch addresses.
///
/// Ops address lines as `L<n>` with a 0-based line index `n`. The codec carries no state, so it
/// is a freely copyable unit value.
#[derive(Debug, Clone, Copy, Default)]
pub struct TextLineCodec;
9
/// Hashes a window of up to seven lines around `center`: three before it, the line itself and
/// three after it, clipped to the bounds of `lines`.
///
/// A `center` far past the end of `lines` hashes whatever part of the window still overlaps
/// `lines` (possibly nothing) instead of panicking on an inverted slice range.
///
/// NOTE(review): `DefaultHasher`'s algorithm is documented as unspecified across Rust releases, so
/// these values are only comparable within one toolchain — confirm they are never persisted or
/// exchanged between builds.
fn context_hash(lines: &[&str], center: usize) -> u64 {
    use std::hash::{Hash, Hasher};

    let end = center.saturating_add(4).min(lines.len());
    // Clamp so that a `center` beyond the end yields an empty window rather than `start > end`.
    let start = center.saturating_sub(3).min(end);
    let mut hasher = std::collections::hash_map::DefaultHasher::new();
    for line in &lines[start..end] {
        line.hash(&mut hasher);
    }
    hasher.finish()
}
20
21impl Codec for TextLineCodec {
22    fn id(&self) -> &str {
23        "text/line"
24    }
25
26    fn diff(&self, old: &[u8], new: &[u8]) -> Result<Vec<PatchOp>, PatchError> {
27        let old_str =
28            std::str::from_utf8(old).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
29        let new_str =
30            std::str::from_utf8(new).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
31
32        let diff = TextDiff::from_lines(old_str, new_str);
33        let old_lines: Vec<&str> = old_str.lines().collect();
34        let new_slices = diff.new_slices();
35        let old_slices = diff.old_slices();
36        let mut ops = Vec::new();
37        let mut old_line = 0usize;
38
39        for op in diff.ops() {
40            match op.tag() {
41                DiffTag::Equal => {
42                    old_line = op.old_range().end;
43                }
44                DiffTag::Delete => {
45                    let range = op.old_range();
46                    let deleted: String = old_slices[range.start..range.end].join("");
47                    ops.push(PatchOp {
48                        address: format!("L{}", range.start),
49                        op_type: "delete".to_string(),
50                        old_data: Some(deleted.as_bytes().to_vec()),
51                        new_data: None,
52                        context_hash: Some(context_hash(&old_lines, range.start)),
53                    });
54                    old_line = range.end;
55                }
56                DiffTag::Insert => {
57                    let new_range = op.new_range();
58                    let inserted: String = new_slices[new_range.start..new_range.end].join("");
59                    ops.push(PatchOp {
60                        address: format!("L{}", old_line),
61                        op_type: "insert".to_string(),
62                        old_data: None,
63                        new_data: Some(inserted.as_bytes().to_vec()),
64                        context_hash: if !old_lines.is_empty() {
65                            Some(context_hash(
66                                &old_lines,
67                                old_line.min(old_lines.len().saturating_sub(1)),
68                            ))
69                        } else {
70                            None
71                        },
72                    });
73                }
74                DiffTag::Replace => {
75                    let old_range = op.old_range();
76                    let new_range = op.new_range();
77                    let deleted: String = old_slices[old_range.start..old_range.end].join("");
78                    let inserted: String = new_slices[new_range.start..new_range.end].join("");
79                    ops.push(PatchOp {
80                        address: format!("L{}", old_range.start),
81                        op_type: "replace".to_string(),
82                        old_data: Some(deleted.as_bytes().to_vec()),
83                        new_data: Some(inserted.as_bytes().to_vec()),
84                        context_hash: Some(context_hash(&old_lines, old_range.start)),
85                    });
86                    old_line = old_range.end;
87                }
88            }
89        }
90
91        Ok(ops)
92    }
93
94    fn apply(&self, base: &[u8], ops: &[PatchOp]) -> Result<Vec<u8>, PatchError> {
95        let base_str =
96            std::str::from_utf8(base).map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
97        let mut lines: Vec<String> = base_str.lines().map(|l| l.to_string()).collect();
98        // Track whether original had trailing newline
99        let trailing_newline = base_str.ends_with('\n');
100
101        let mut offset: i64 = 0;
102
103        for op in ops {
104            let line_num = parse_line_address(&op.address)?;
105            let adjusted = (line_num as i64 + offset) as usize;
106
107            match op.op_type.as_str() {
108                "delete" => {
109                    let old_data = op.old_data.as_ref().ok_or_else(|| {
110                        PatchError::ApplyFailed("delete op missing old_data".into())
111                    })?;
112                    let old_str = std::str::from_utf8(old_data)
113                        .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
114                    let count = old_str.lines().count().max(1);
115                    if adjusted + count > lines.len() {
116                        return Err(PatchError::ApplyFailed(format!(
117                            "delete out of bounds: {} + {} > {}",
118                            adjusted,
119                            count,
120                            lines.len()
121                        )));
122                    }
123                    lines.drain(adjusted..adjusted + count);
124                    offset -= count as i64;
125                }
126                "insert" => {
127                    let new_data = op.new_data.as_ref().ok_or_else(|| {
128                        PatchError::ApplyFailed("insert op missing new_data".into())
129                    })?;
130                    let new_str = std::str::from_utf8(new_data)
131                        .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
132                    let new_lines: Vec<String> = new_str.lines().map(|l| l.to_string()).collect();
133                    let count = new_lines.len();
134                    let insert_at = adjusted.min(lines.len());
135                    for (i, line) in new_lines.into_iter().enumerate() {
136                        lines.insert(insert_at + i, line);
137                    }
138                    offset += count as i64;
139                }
140                "replace" => {
141                    let old_data = op.old_data.as_ref().ok_or_else(|| {
142                        PatchError::ApplyFailed("replace op missing old_data".into())
143                    })?;
144                    let old_str = std::str::from_utf8(old_data)
145                        .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
146                    let del_count = old_str.lines().count().max(1);
147                    if adjusted + del_count > lines.len() {
148                        return Err(PatchError::ApplyFailed(format!(
149                            "replace delete out of bounds: {} + {} > {}",
150                            adjusted,
151                            del_count,
152                            lines.len()
153                        )));
154                    }
155                    lines.drain(adjusted..adjusted + del_count);
156
157                    let new_data = op.new_data.as_ref().ok_or_else(|| {
158                        PatchError::ApplyFailed("replace op missing new_data".into())
159                    })?;
160                    let new_str = std::str::from_utf8(new_data)
161                        .map_err(|e| PatchError::ApplyFailed(e.to_string()))?;
162                    let new_lines: Vec<String> = new_str.lines().map(|l| l.to_string()).collect();
163                    let ins_count = new_lines.len();
164                    let insert_at = adjusted.min(lines.len());
165                    for (i, line) in new_lines.into_iter().enumerate() {
166                        lines.insert(insert_at + i, line);
167                    }
168                    offset += ins_count as i64 - del_count as i64;
169                }
170                other => {
171                    return Err(PatchError::ApplyFailed(format!("unknown op type: {other}")));
172                }
173            }
174        }
175
176        let mut result = lines.join("\n");
177        if (trailing_newline || base_str.is_empty()) && !result.is_empty() {
178            result.push('\n');
179        }
180        Ok(result.into_bytes())
181    }
182
183    fn invert(&self, ops: &[PatchOp]) -> Result<Vec<PatchOp>, PatchError> {
184        let mut inverted: Vec<PatchOp> = ops
185            .iter()
186            .map(|op| match op.op_type.as_str() {
187                "delete" => PatchOp {
188                    address: op.address.clone(),
189                    op_type: "insert".to_string(),
190                    old_data: None,
191                    new_data: op.old_data.clone(),
192                    context_hash: op.context_hash,
193                },
194                "insert" => PatchOp {
195                    address: op.address.clone(),
196                    op_type: "delete".to_string(),
197                    old_data: op.new_data.clone(),
198                    new_data: None,
199                    context_hash: op.context_hash,
200                },
201                "replace" => PatchOp {
202                    address: op.address.clone(),
203                    op_type: "replace".to_string(),
204                    old_data: op.new_data.clone(),
205                    new_data: op.old_data.clone(),
206                    context_hash: op.context_hash,
207                },
208                _ => op.clone(),
209            })
210            .collect();
211        inverted.reverse();
212        Ok(inverted)
213    }
214
215    fn commute(
216        &self,
217        left: &[PatchOp],
218        right: &[PatchOp],
219    ) -> Result<(Vec<PatchOp>, Vec<PatchOp>), PatchError> {
220        // Darcs-style commutation: non-overlapping line ranges commute with offset adjustment
221        let mut new_right = Vec::new();
222        let mut new_left = Vec::new();
223
224        for r_op in right {
225            let r_line = parse_line_address(&r_op.address)?;
226            let r_count = op_line_count(r_op);
227
228            let mut r_adjusted = r_line as i64;
229            let mut can_commute = true;
230
231            for l_op in left {
232                let l_line = parse_line_address(&l_op.address)?;
233                let l_count = op_line_count(l_op);
234
235                // Check for overlap
236                let (l_start, l_end) = match l_op.op_type.as_str() {
237                    "delete" | "replace" => (l_line as i64, l_line as i64 + l_count as i64),
238                    "insert" => (l_line as i64, l_line as i64),
239                    _ => (l_line as i64, l_line as i64),
240                };
241                let (r_start, r_end) = match r_op.op_type.as_str() {
242                    "delete" | "replace" => (r_adjusted, r_adjusted + r_count as i64),
243                    "insert" => (r_adjusted, r_adjusted),
244                    _ => (r_adjusted, r_adjusted),
245                };
246
247                // Check overlap
248                if r_start < l_end && r_end > l_start {
249                    can_commute = false;
250                    break;
251                }
252
253                // Adjust offset
254                if r_start >= l_end {
255                    match l_op.op_type.as_str() {
256                        "delete" => r_adjusted -= l_count as i64,
257                        "insert" => r_adjusted += l_count as i64,
258                        _ => {}
259                    }
260                }
261            }
262
263            if !can_commute {
264                return Err(PatchError::CommuteFailed);
265            }
266
267            new_right.push(PatchOp {
268                address: format!("L{}", r_adjusted),
269                ..r_op.clone()
270            });
271        }
272
273        // Adjust left ops considering right was applied first
274        for l_op in left {
275            let l_line = parse_line_address(&l_op.address)?;
276            let mut l_adjusted = l_line as i64;
277
278            for r_op in &new_right {
279                let r_line = parse_line_address(&r_op.address)?;
280                let r_count = op_line_count(r_op);
281
282                if (l_adjusted as usize) > r_line {
283                    match r_op.op_type.as_str() {
284                        "insert" => l_adjusted += r_count as i64,
285                        "delete" => l_adjusted -= r_count as i64,
286                        _ => {}
287                    }
288                }
289            }
290
291            new_left.push(PatchOp {
292                address: format!("L{}", l_adjusted.max(0)),
293                ..l_op.clone()
294            });
295        }
296
297        Ok((new_right, new_left))
298    }
299
300    fn merge3(&self, base: &[u8], left: &[u8], right: &[u8]) -> Result<Vec<u8>, PatchError> {
301        let base_str =
302            std::str::from_utf8(base).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
303        let left_str =
304            std::str::from_utf8(left).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
305        let right_str =
306            std::str::from_utf8(right).map_err(|e| PatchError::Merge3Failed(e.to_string()))?;
307
308        let base_lines: Vec<&str> = base_str.lines().collect();
309
310        let left_diff = TextDiff::from_lines(base_str, left_str);
311        let right_diff = TextDiff::from_lines(base_str, right_str);
312
313        // Build change maps: which base lines were modified
314        let mut left_changes: std::collections::HashMap<usize, Vec<&str>> =
315            std::collections::HashMap::new();
316        let mut right_changes: std::collections::HashMap<usize, Vec<&str>> =
317            std::collections::HashMap::new();
318
319        collect_changes(&left_diff, &mut left_changes);
320        collect_changes(&right_diff, &mut right_changes);
321
322        let mut result = Vec::new();
323        let mut i = 0;
324
325        while i < base_lines.len() {
326            let left_changed = left_changes.contains_key(&i);
327            let right_changed = right_changes.contains_key(&i);
328
329            match (left_changed, right_changed) {
330                (false, false) => {
331                    result.push(base_lines[i].to_string());
332                    i += 1;
333                }
334                (true, false) => {
335                    if let Some(replacement) = left_changes.get(&i) {
336                        result.extend(replacement.iter().map(|s| s.to_string()));
337                    }
338                    i += 1;
339                }
340                (false, true) => {
341                    if let Some(replacement) = right_changes.get(&i) {
342                        result.extend(replacement.iter().map(|s| s.to_string()));
343                    }
344                    i += 1;
345                }
346                (true, true) => {
347                    let left_rep = left_changes.get(&i);
348                    let right_rep = right_changes.get(&i);
349                    if left_rep == right_rep {
350                        if let Some(replacement) = left_rep {
351                            result.extend(replacement.iter().map(|s| s.to_string()));
352                        }
353                    } else {
354                        return Err(PatchError::Merge3Failed(format!(
355                            "conflict at line {i}: both sides changed differently"
356                        )));
357                    }
358                    i += 1;
359                }
360            }
361        }
362
363        // Handle appended lines
364        let max_base = base_lines.len();
365        if let Some(appended) = left_changes.get(&max_base) {
366            result.extend(appended.iter().map(|s| s.to_string()));
367        }
368        if let Some(appended) = right_changes.get(&max_base) {
369            result.extend(appended.iter().map(|s| s.to_string()));
370        }
371
372        let mut output = result.join("\n");
373        let left_trailing = left_str.ends_with('\n');
374        let right_trailing = right_str.ends_with('\n');
375        if (left_trailing || right_trailing) && !output.is_empty() {
376            output.push('\n');
377        }
378        Ok(output.into_bytes())
379    }
380}
381
382fn collect_changes<'a>(
383    diff: &TextDiff<'a, 'a, 'a, str>,
384    changes: &mut std::collections::HashMap<usize, Vec<&'a str>>,
385) {
386    for op in diff.ops() {
387        match op.tag() {
388            similar::DiffTag::Equal => {}
389            similar::DiffTag::Delete | similar::DiffTag::Replace | similar::DiffTag::Insert => {
390                let old_range = op.old_range();
391                let new_range = op.new_range();
392                let new_text: Vec<&str> = diff.new_slices()[new_range.start..new_range.end]
393                    .iter()
394                    .flat_map(|s| s.lines())
395                    .collect();
396                let key = old_range.start;
397                changes.insert(key, new_text);
398            }
399        }
400    }
401}
402
403fn parse_line_address(addr: &str) -> Result<usize, PatchError> {
404    addr.strip_prefix('L')
405        .and_then(|n| n.parse::<usize>().ok())
406        .ok_or_else(|| PatchError::AddressResolutionFailed(format!("invalid line address: {addr}")))
407}
408
409fn op_line_count(op: &PatchOp) -> usize {
410    match op.op_type.as_str() {
411        "delete" | "replace" => {
412            if let Some(data) = &op.old_data {
413                std::str::from_utf8(data)
414                    .map(|s| s.lines().count().max(1))
415                    .unwrap_or(1)
416            } else {
417                1
418            }
419        }
420        "insert" => {
421            if let Some(data) = &op.new_data {
422                std::str::from_utf8(data)
423                    .map(|s| s.lines().count().max(1))
424                    .unwrap_or(1)
425            } else {
426                1
427            }
428        }
429        _ => 0,
430    }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn diff_and_apply_roundtrip() {
        let codec = TextLineCodec;
        let old = b"line1\nline2\nline3\n";
        let new = b"line1\nmodified\nline3\nextra\n";
        let ops = codec.diff(old, new).unwrap();
        let result = codec.apply(old, &ops).unwrap();
        assert_eq!(result, new);
    }

    #[test]
    fn diff_of_identical_inputs_is_empty() {
        let codec = TextLineCodec;
        let text = b"same\nlines\n";
        let ops = codec.diff(text, text).unwrap();
        assert!(ops.is_empty());
        assert_eq!(codec.apply(text, &ops).unwrap(), text);
    }

    #[test]
    fn invert_cancels_patch() {
        let codec = TextLineCodec;
        let old = b"a\nb\nc\n";
        let new = b"a\nx\nc\n";
        let ops = codec.diff(old, new).unwrap();
        let applied = codec.apply(old, &ops).unwrap();
        assert_eq!(applied, new);

        let inv = codec.invert(&ops).unwrap();
        let restored = codec.apply(new, &inv).unwrap();
        assert_eq!(restored, old);
    }

    #[test]
    fn apply_rejects_unknown_op_type() {
        let op = PatchOp {
            address: "L0".to_string(),
            op_type: "frobnicate".to_string(),
            old_data: None,
            new_data: None,
            context_hash: None,
        };
        assert!(TextLineCodec.apply(b"a\n", &[op]).is_err());
    }

    #[test]
    fn apply_rejects_malformed_address() {
        let op = PatchOp {
            address: "7".to_string(),
            op_type: "insert".to_string(),
            old_data: None,
            new_data: Some(b"x\n".to_vec()),
            context_hash: None,
        };
        assert!(TextLineCodec.apply(b"a\n", &[op]).is_err());
    }

    #[test]
    fn merge3_no_conflict() {
        let codec = TextLineCodec;
        let base = b"line1\nline2\nline3\n";
        let left = b"line1\nleft_change\nline3\n";
        let right = b"line1\nline2\nright_change\n";
        let merged = codec.merge3(base, left, right).unwrap();
        // Pin the exact output so ordering, duplication and newline handling are all checked.
        assert_eq!(merged, b"line1\nleft_change\nright_change\n");
    }

    #[test]
    fn merge3_identical_changes_on_both_sides() {
        let codec = TextLineCodec;
        let base = b"a\nb\nc\n";
        let both = b"a\nx\nc\n";
        let merged = codec.merge3(base, both, both).unwrap();
        assert_eq!(merged, b"a\nx\nc\n");
    }

    #[test]
    fn merge3_conflict() {
        let codec = TextLineCodec;
        let base = b"line1\nline2\nline3\n";
        let left = b"line1\nleft_change\nline3\n";
        let right = b"line1\nright_change\nline3\n";
        assert!(codec.merge3(base, left, right).is_err());
    }
}