Skip to main content

aster/checkpoint/
diff.rs

//! Diff engine.
//!
//! Computes and applies line-based file diffs.
4
5use serde::{Deserialize, Serialize};
6
7/// Diff 操作类型
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub enum DiffOp {
10    #[serde(rename = "add")]
11    Add,
12    #[serde(rename = "del")]
13    Del,
14    #[serde(rename = "eq")]
15    Eq,
16}
17
18/// Diff 条目
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct DiffEntry {
21    pub op: DiffOp,
22    pub line: String,
23    pub num: usize,
24}
25
/// Diff engine.
///
/// A stateless unit type; it exists only to namespace the diff operations,
/// so copies are free and all instances are interchangeable.
#[derive(Debug, Clone, Copy)]
pub struct DiffEngine;
28
29impl DiffEngine {
30    /// 创建新的 Diff 引擎
31    pub fn new() -> Self {
32        Self
33    }
34
35    /// 计算两个字符串之间的 diff
36    pub fn calculate_diff(&self, old_content: &str, new_content: &str) -> String {
37        let old_lines: Vec<&str> = old_content.lines().collect();
38        let new_lines: Vec<&str> = new_content.lines().collect();
39
40        let lcs = self.longest_common_subsequence(&old_lines, &new_lines);
41        let mut diff: Vec<DiffEntry> = Vec::new();
42
43        let mut old_idx = 0;
44        let mut new_idx = 0;
45        let mut lcs_idx = 0;
46
47        while old_idx < old_lines.len() || new_idx < new_lines.len() {
48            if lcs_idx < lcs.len() {
49                // 找到下一个公共行
50                while old_idx < old_lines.len() && old_lines[old_idx] != lcs[lcs_idx] {
51                    diff.push(DiffEntry {
52                        op: DiffOp::Del,
53                        line: old_lines[old_idx].to_string(),
54                        num: old_idx,
55                    });
56                    old_idx += 1;
57                }
58                while new_idx < new_lines.len() && new_lines[new_idx] != lcs[lcs_idx] {
59                    diff.push(DiffEntry {
60                        op: DiffOp::Add,
61                        line: new_lines[new_idx].to_string(),
62                        num: new_idx,
63                    });
64                    new_idx += 1;
65                }
66                if old_idx < old_lines.len() && new_idx < new_lines.len() {
67                    diff.push(DiffEntry {
68                        op: DiffOp::Eq,
69                        line: old_lines[old_idx].to_string(),
70                        num: old_idx,
71                    });
72                    old_idx += 1;
73                    new_idx += 1;
74                    lcs_idx += 1;
75                }
76            } else {
77                // 剩余行
78                while old_idx < old_lines.len() {
79                    diff.push(DiffEntry {
80                        op: DiffOp::Del,
81                        line: old_lines[old_idx].to_string(),
82                        num: old_idx,
83                    });
84                    old_idx += 1;
85                }
86                while new_idx < new_lines.len() {
87                    diff.push(DiffEntry {
88                        op: DiffOp::Add,
89                        line: new_lines[new_idx].to_string(),
90                        num: new_idx,
91                    });
92                    new_idx += 1;
93                }
94            }
95        }
96
97        serde_json::to_string(&diff).unwrap_or_default()
98    }
99
100    /// 应用 diff 到内容
101    pub fn apply_diff(&self, old_content: &str, diff_str: &str) -> String {
102        let diff: Vec<DiffEntry> = match serde_json::from_str(diff_str) {
103            Ok(d) => d,
104            Err(_) => return old_content.to_string(),
105        };
106
107        let mut result: Vec<String> = Vec::new();
108
109        for entry in diff {
110            match entry.op {
111                DiffOp::Add | DiffOp::Eq => {
112                    result.push(entry.line);
113                }
114                DiffOp::Del => {
115                    // 删除的行不添加到结果
116                }
117            }
118        }
119
120        result.join("\n")
121    }
122
123    /// 最长公共子序列算法
124    fn longest_common_subsequence<'a>(&self, arr1: &[&'a str], arr2: &[&'a str]) -> Vec<&'a str> {
125        let m = arr1.len();
126        let n = arr2.len();
127        let mut dp: Vec<Vec<usize>> = vec![vec![0; n + 1]; m + 1];
128
129        for i in 1..=m {
130            for j in 1..=n {
131                if arr1[i - 1] == arr2[j - 1] {
132                    dp[i][j] = dp[i - 1][j - 1] + 1;
133                } else {
134                    dp[i][j] = dp[i - 1][j].max(dp[i][j - 1]);
135                }
136            }
137        }
138
139        // 回溯找到 LCS
140        let mut lcs: Vec<&'a str> = Vec::new();
141        let mut i = m;
142        let mut j = n;
143
144        while i > 0 && j > 0 {
145            if arr1[i - 1] == arr2[j - 1] {
146                lcs.push(arr1[i - 1]);
147                i -= 1;
148                j -= 1;
149            } else if dp[i - 1][j] > dp[i][j - 1] {
150                i -= 1;
151            } else {
152                j -= 1;
153            }
154        }
155
156        lcs.reverse();
157        lcs
158    }
159}
160
161impl Default for DiffEngine {
162    fn default() -> Self {
163        Self::new()
164    }
165}