fresh/model/
piece_tree_diff.rs

1use std::ops::Range;
2use std::sync::Arc;
3
4use crate::model::piece_tree::{LeafData, PieceTreeNode};
5
/// Outcome of comparing two piece tree roots with [`diff_piece_trees`].
///
/// Every range is half-open (`start..end`) and expressed in coordinates of the
/// "after" tree.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PieceTreeDiff {
    /// `true` when both trees were found to describe the same piece sequence.
    pub equal: bool,
    /// Byte ranges of the "after" tree that differ from "before". Empty when
    /// `equal` is `true`; an empty range (`n..n`) anchors a change at which no
    /// bytes remain in the "after" tree.
    pub byte_ranges: Vec<Range<usize>>,
    /// Line ranges corresponding to `byte_ranges`, or `None` when the line
    /// counts required to compute them were not available.
    pub line_ranges: Option<Vec<Range<usize>>>,
}
16
17/// Compute a diff between two piece tree roots.
18///
19/// Comparison happens at the byte-span level (not whole leaves) so split leaves
20/// still align. The result identifies the minimal contiguous range in the
21/// "after" tree that differs from "before".
22///
23/// `line_counter` should return the number of line feeds in a slice of a leaf.
24/// If it returns None for any consulted slice, the diff will have line_range=None.
25pub fn diff_piece_trees(
26    before: &Arc<PieceTreeNode>,
27    after: &Arc<PieceTreeNode>,
28    line_counter: &dyn Fn(&LeafData, usize, usize) -> Option<usize>,
29) -> PieceTreeDiff {
30    let mut before_leaves = Vec::new();
31    collect_leaves(before, &mut before_leaves);
32    let before_leaves = normalize_leaves(before_leaves);
33
34    let mut after_leaves = Vec::new();
35    collect_leaves(after, &mut after_leaves);
36    let after_leaves = normalize_leaves(after_leaves);
37
38    // Fast-path: identical leaf sequences.
39    if leaf_slices_equal(&before_leaves, &after_leaves) {
40        return PieceTreeDiff {
41            equal: true,
42            byte_ranges: Vec::new(),
43            line_ranges: Some(Vec::new()),
44        };
45    }
46
47    let before_spans = with_doc_offsets(&before_leaves);
48    let after_spans = with_doc_offsets(&after_leaves);
49
50    let _total_after = sum_bytes(&after_leaves);
51
52    // Longest common prefix at byte granularity.
53    let prefix = common_prefix_bytes(&before_spans, &after_spans);
54    // Longest common suffix without overlapping prefix.
55    let suffix = common_suffix_bytes(&before_spans, &after_spans, prefix);
56
57    let ranges = collect_diff_ranges(&before_spans, &after_spans, prefix, suffix);
58
59    // Map byte ranges to line ranges (best effort).
60    let line_ranges = line_ranges(&after_spans, &ranges, line_counter);
61
62    PieceTreeDiff {
63        equal: false,
64        byte_ranges: ranges,
65        line_ranges,
66    }
67}
68
69fn collect_leaves(node: &Arc<PieceTreeNode>, out: &mut Vec<LeafData>) {
70    match node.as_ref() {
71        PieceTreeNode::Internal { left, right, .. } => {
72            collect_leaves(left, out);
73            collect_leaves(right, out);
74        }
75        PieceTreeNode::Leaf {
76            location,
77            offset,
78            bytes,
79            line_feed_cnt,
80        } => out.push(LeafData::new(*location, *offset, *bytes, *line_feed_cnt)),
81    }
82}
83
84fn leaves_equal(a: &LeafData, b: &LeafData) -> bool {
85    a.location == b.location && a.offset == b.offset && a.bytes == b.bytes
86}
87
88fn leaf_slices_equal(a: &[LeafData], b: &[LeafData]) -> bool {
89    a.len() == b.len() && a.iter().zip(b.iter()).all(|(x, y)| leaves_equal(x, y))
90}
91
92fn normalize_leaves(mut leaves: Vec<LeafData>) -> Vec<LeafData> {
93    if leaves.is_empty() {
94        return leaves;
95    }
96
97    let mut normalized = Vec::with_capacity(leaves.len());
98    let mut current = leaves.remove(0);
99
100    for leaf in leaves.into_iter() {
101        let contiguous =
102            current.location == leaf.location && current.offset + current.bytes == leaf.offset;
103        if contiguous {
104            // Merge by extending bytes and line feeds if known
105            current.bytes += leaf.bytes;
106            current.line_feed_cnt = match (current.line_feed_cnt, leaf.line_feed_cnt) {
107                (Some(a), Some(b)) => Some(a + b),
108                _ => None,
109            };
110        } else {
111            normalized.push(current);
112            current = leaf;
113        }
114    }
115
116    normalized.push(current);
117    normalized
118}
119
120fn sum_bytes(leaves: &[LeafData]) -> usize {
121    leaves.iter().map(|leaf| leaf.bytes).sum()
122}
123
124#[derive(Clone)]
125struct Span {
126    leaf: LeafData,
127    doc_offset: usize,
128}
129
130fn with_doc_offsets(leaves: &[LeafData]) -> Vec<Span> {
131    let mut spans = Vec::with_capacity(leaves.len());
132    let mut offset = 0;
133    for leaf in leaves {
134        spans.push(Span {
135            leaf: *leaf,
136            doc_offset: offset,
137        });
138        offset += leaf.bytes;
139    }
140    spans
141}
142
143fn common_prefix_bytes(before: &[Span], after: &[Span]) -> usize {
144    let mut b_idx = 0;
145    let mut a_idx = 0;
146    let mut b_off = 0;
147    let mut a_off = 0;
148    let mut consumed = 0;
149
150    while b_idx < before.len() && a_idx < after.len() {
151        let b = &before[b_idx].leaf;
152        let a = &after[a_idx].leaf;
153
154        let b_pos = b.offset + b_off;
155        let a_pos = a.offset + a_off;
156
157        if b.location == a.location && b_pos == a_pos {
158            let b_rem = b.bytes - b_off;
159            let a_rem = a.bytes - a_off;
160            let take = b_rem.min(a_rem);
161
162            consumed += take;
163            b_off += take;
164            a_off += take;
165
166            if b_off == b.bytes {
167                b_idx += 1;
168                b_off = 0;
169            }
170            if a_off == a.bytes {
171                a_idx += 1;
172                a_off = 0;
173            }
174        } else {
175            break;
176        }
177    }
178
179    consumed
180}
181
182fn common_suffix_bytes(before: &[Span], after: &[Span], prefix_bytes: usize) -> usize {
183    let total_before = before.iter().map(|s| s.leaf.bytes).sum::<usize>();
184    let total_after = after.iter().map(|s| s.leaf.bytes).sum::<usize>();
185
186    let mut b_idx: isize = before.len() as isize - 1;
187    let mut a_idx: isize = after.len() as isize - 1;
188    let mut b_off = 0;
189    let mut a_off = 0;
190    let mut consumed = 0;
191
192    while b_idx >= 0
193        && a_idx >= 0
194        && (total_before - consumed) > prefix_bytes
195        && (total_after - consumed) > prefix_bytes
196    {
197        let b_leaf = &before[b_idx as usize].leaf;
198        let a_leaf = &after[a_idx as usize].leaf;
199
200        let b_pos = b_leaf.offset + b_leaf.bytes - b_off;
201        let a_pos = a_leaf.offset + a_leaf.bytes - a_off;
202
203        if b_leaf.location == a_leaf.location && b_pos == a_pos {
204            let b_rem = b_leaf.bytes - b_off;
205            let a_rem = a_leaf.bytes - a_off;
206            let take = b_rem.min(a_rem);
207
208            consumed += take;
209            b_off += take;
210            a_off += take;
211
212            if b_off == b_leaf.bytes {
213                b_idx -= 1;
214                b_off = 0;
215            }
216            if a_off == a_leaf.bytes {
217                a_idx -= 1;
218                a_off = 0;
219            }
220        } else {
221            break;
222        }
223    }
224
225    consumed.min(total_after.saturating_sub(prefix_bytes))
226}
227
228fn collect_diff_ranges(
229    before: &[Span],
230    after: &[Span],
231    prefix: usize,
232    suffix: usize,
233) -> Vec<Range<usize>> {
234    let mut ranges = Vec::new();
235    let mut b_idx = 0;
236    let mut a_idx = 0;
237    let mut b_off = 0;
238    let mut a_off = 0;
239    let mut matched_prefix = 0;
240
241    // Skip matching prefix
242    while matched_prefix < prefix && b_idx < before.len() && a_idx < after.len() {
243        let b = &before[b_idx].leaf;
244        let a = &after[a_idx].leaf;
245        let b_rem = b.bytes - b_off;
246        let a_rem = a.bytes - a_off;
247        let take = b_rem.min(a_rem).min(prefix - matched_prefix);
248        matched_prefix += take;
249        b_off += take;
250        a_off += take;
251        if b_off == b.bytes {
252            b_idx += 1;
253            b_off = 0;
254        }
255        if a_off == a.bytes {
256            a_idx += 1;
257            a_off = 0;
258        }
259    }
260
261    let total_after = after.iter().map(|s| s.leaf.bytes).sum::<usize>();
262    let compare_limit = total_after.saturating_sub(suffix);
263
264    let mut current_start: Option<usize> = None;
265    let mut current_end: usize = 0;
266
267    while a_idx < after.len() {
268        let a = &after[a_idx];
269        let pos = a.doc_offset + a_off;
270        if pos >= compare_limit {
271            break;
272        }
273
274        let matches = if b_idx < before.len() {
275            let b = &before[b_idx].leaf;
276            let b_pos = b.offset + b_off;
277            let a_pos = a.leaf.offset + a_off;
278            b.location == a.leaf.location && b_pos == a_pos
279        } else {
280            false
281        };
282
283        if matches {
284            if let Some(start) = current_start.take() {
285                ranges.push(start..current_end);
286            }
287
288            let b = &before[b_idx].leaf;
289            let b_rem = b.bytes - b_off;
290            let a_rem = a.leaf.bytes - a_off;
291            let take = b_rem.min(a_rem).min(compare_limit.saturating_sub(pos));
292
293            b_off += take;
294            a_off += take;
295
296            if b_off == b.bytes {
297                b_idx += 1;
298                b_off = 0;
299            }
300            if a_off == a.leaf.bytes {
301                a_idx += 1;
302                a_off = 0;
303            }
304        } else {
305            if current_start.is_none() {
306                current_start = Some(pos);
307                current_end = pos;
308            }
309            let take = (a.leaf.bytes - a_off).min(compare_limit.saturating_sub(pos));
310            current_end += take;
311            a_off += take;
312            if a_off == a.leaf.bytes {
313                a_idx += 1;
314                a_off = 0;
315            }
316        }
317    }
318
319    if let Some(start) = current_start {
320        ranges.push(start..current_end);
321    }
322
323    // Any trailing unmatched "after" spans up to suffix boundary
324    while a_idx < after.len() {
325        let start = after[a_idx].doc_offset + a_off;
326        if start >= compare_limit {
327            break;
328        }
329        let end = (after[a_idx].doc_offset + after[a_idx].leaf.bytes).min(compare_limit);
330        ranges.push(start..end);
331        a_idx += 1;
332        a_off = 0;
333    }
334
335    if ranges.is_empty() {
336        let total_after = after.iter().map(|s| s.leaf.bytes).sum::<usize>();
337        let compare_limit = total_after.saturating_sub(suffix);
338        ranges.push(prefix..compare_limit);
339    }
340
341    ranges
342}
343
344fn count_lines_in_range(
345    spans: &[Span],
346    start: usize,
347    len: usize,
348    line_counter: &dyn Fn(&LeafData, usize, usize) -> Option<usize>,
349) -> Option<usize> {
350    if len == 0 {
351        return Some(0);
352    }
353
354    let mut remaining = len;
355    let mut offset = start;
356    let mut line_feeds = 0usize;
357
358    for span in spans {
359        if remaining == 0 {
360            break;
361        }
362        let span_start = span.doc_offset;
363        let span_end = span_start + span.leaf.bytes;
364        if offset >= span_end {
365            continue;
366        }
367        let local_start = offset.saturating_sub(span_start);
368        let available = span.leaf.bytes - local_start;
369        let take = available.min(remaining);
370
371        let chunk_lines = line_counter(&span.leaf, local_start, take)?;
372        line_feeds += chunk_lines;
373
374        offset += take;
375        remaining -= take;
376    }
377
378    Some(line_feeds)
379}
380
381fn line_ranges(
382    after_spans: &[Span],
383    byte_ranges: &[Range<usize>],
384    line_counter: &dyn Fn(&LeafData, usize, usize) -> Option<usize>,
385) -> Option<Vec<Range<usize>>> {
386    let mut accum = Vec::with_capacity(byte_ranges.len());
387    for range in byte_ranges {
388        let lf_before = count_lines_in_range(after_spans, 0, range.start, line_counter)?;
389        let lf_in_range = count_lines_in_range(
390            after_spans,
391            range.start,
392            range.end.saturating_sub(range.start),
393            line_counter,
394        )?;
395        let start_line = lf_before;
396        let end_line = if range.start == range.end {
397            lf_before + 1
398        } else {
399            lf_before + lf_in_range + 1
400        };
401        accum.push(start_line..end_line);
402    }
403
404    Some(accum)
405}
406
#[cfg(test)]
// Expected values such as `vec![0..5]` are intentionally one-element Vecs of
// ranges, which is exactly the pattern this clippy lint warns about.
#[allow(clippy::single_range_in_vec_init)]
mod tests {
    use super::*;
    use crate::model::piece_tree::BufferLocation;

    fn leaf(loc: BufferLocation, offset: usize, bytes: usize, lfs: Option<usize>) -> LeafData {
        LeafData::new(loc, offset, bytes, lfs)
    }

    // Minimal balanced builder for tests. An empty slice yields a single
    // zero-byte leaf.
    fn build(leaves: &[LeafData]) -> Arc<PieceTreeNode> {
        if leaves.is_empty() {
            return Arc::new(PieceTreeNode::Leaf {
                location: BufferLocation::Stored(0),
                offset: 0,
                bytes: 0,
                line_feed_cnt: Some(0),
            });
        }
        if leaves.len() == 1 {
            let l = leaves[0];
            return Arc::new(PieceTreeNode::Leaf {
                location: l.location,
                offset: l.offset,
                bytes: l.bytes,
                line_feed_cnt: l.line_feed_cnt,
            });
        }

        let mid = leaves.len() / 2;
        let left = build(&leaves[..mid]);
        let right = build(&leaves[mid..]);

        Arc::new(PieceTreeNode::Internal {
            left_bytes: sum_bytes(&leaves[..mid]),
            lf_left: leaves[..mid]
                .iter()
                .map(|l| l.line_feed_cnt)
                .try_fold(0usize, |acc, v| v.map(|b| acc + b)),
            left,
            right,
        })
    }

    // Test line counter: a leaf's recorded line-feed count is reported only for
    // a slice covering the whole leaf; partial slices are unknown (`None`).
    fn count_line_feeds(leaf: &LeafData, start: usize, len: usize) -> Option<usize> {
        if len == 0 {
            return Some(0);
        }
        if start == 0 && len == leaf.bytes {
            return leaf.line_feed_cnt;
        }
        None
    }

    #[test]
    fn detects_identical_trees() {
        let leaves = vec![leaf(BufferLocation::Stored(0), 0, 10, Some(0))];
        let before = build(&leaves);
        let after = build(&leaves);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(diff.equal);
        assert!(diff.byte_ranges.is_empty());
        assert_eq!(diff.line_ranges, Some(Vec::new()));
    }

    #[test]
    fn detects_single_line_change() {
        let before = build(&[leaf(BufferLocation::Stored(0), 0, 5, Some(0))]);
        let after = build(&[leaf(BufferLocation::Added(1), 0, 5, Some(0))]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        assert_eq!(diff.byte_ranges, vec![0..5]);
        assert_eq!(diff.line_ranges, Some(vec![0..1])); // same line, different content
    }

    #[test]
    fn tracks_newlines_in_changed_span() {
        let before = build(&[leaf(BufferLocation::Stored(0), 0, 6, Some(0))]);
        let after = build(&[leaf(BufferLocation::Added(1), 0, 6, Some(1))]); // introduces a newline

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        assert_eq!(diff.byte_ranges, vec![0..6]);
        assert_eq!(diff.line_ranges, Some(vec![0..2])); // spans two lines after change
    }

    #[test]
    fn handles_deletion_by_marking_anchor_line() {
        let before = build(&[
            leaf(BufferLocation::Stored(0), 0, 6, Some(1)), // two lines
            leaf(BufferLocation::Stored(0), 6, 4, Some(0)), // trailing text
        ]);
        let after = build(&[leaf(BufferLocation::Stored(0), 0, 6, Some(1))]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        assert_eq!(diff.byte_ranges, vec![6..6]); // no bytes remain at the change site
        assert_eq!(diff.line_ranges, Some(vec![1..2])); // anchor after deleted span
    }

    #[test]
    fn tolerates_split_leaves_with_same_content_prefix() {
        let before = build(&[leaf(BufferLocation::Stored(0), 0, 100, Some(1))]);
        let after = build(&[
            leaf(BufferLocation::Stored(0), 0, 50, Some(0)),
            leaf(BufferLocation::Added(1), 0, 10, Some(0)),
            leaf(BufferLocation::Stored(0), 50, 50, Some(1)),
        ]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        // Only the inserted span should be marked.
        assert_eq!(diff.byte_ranges, vec![50..60]);
    }

    #[test]
    fn empty_trees_are_equal() {
        let before = build(&[]);
        let after = build(&[]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(diff.equal);
        assert!(diff.byte_ranges.is_empty());
        assert_eq!(diff.line_ranges, Some(Vec::new()));
    }

    #[test]
    fn insertion_into_empty_tree_marks_inserted_bytes() {
        let before = build(&[]);
        let after = build(&[leaf(BufferLocation::Added(1), 0, 5, Some(0))]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        assert_eq!(diff.byte_ranges, vec![0..5]);
        assert_eq!(diff.line_ranges, Some(vec![0..1]));
    }

    #[test]
    fn appended_leaf_is_marked_after_existing_content() {
        let before = build(&[leaf(BufferLocation::Stored(0), 0, 10, Some(1))]);
        let after = build(&[
            leaf(BufferLocation::Stored(0), 0, 10, Some(1)),
            leaf(BufferLocation::Added(1), 0, 4, Some(0)),
        ]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(!diff.equal);
        assert_eq!(diff.byte_ranges, vec![10..14]);
        assert_eq!(diff.line_ranges, Some(vec![1..2]));
    }

    #[test]
    fn split_but_contiguous_leaves_are_equal() {
        let before = build(&[leaf(BufferLocation::Stored(0), 0, 10, Some(1))]);
        let after = build(&[
            leaf(BufferLocation::Stored(0), 0, 4, Some(0)),
            leaf(BufferLocation::Stored(0), 4, 6, Some(1)),
        ]);

        let diff = diff_piece_trees(&before, &after, &count_line_feeds);
        assert!(diff.equal);
        assert!(diff.byte_ranges.is_empty());
    }
}