cersei_tools/tool_primitives/
diff.rs1use similar::{ChangeTag as SimilarTag, TextDiff};
7use std::fmt;
8
/// Classification of a single line in a line-level diff.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChangeTag {
    /// The line exists only in the new text.
    Added,
    /// The line exists only in the old text.
    Removed,
    /// The line is present, unchanged, in both texts.
    Unchanged,
}
16
17#[derive(Debug, Clone)]
19pub struct DiffLine {
20 pub tag: ChangeTag,
21 pub line_number_old: Option<usize>,
22 pub line_number_new: Option<usize>,
23 pub content: String,
24}
25
/// Error returned by `apply_patch` when a patch cannot be applied.
#[derive(Debug)]
pub struct PatchError {
    /// Human-readable description of what went wrong.
    pub message: String,
}

impl fmt::Display for PatchError {
    /// Renders as `patch error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("patch error: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for PatchError {}
39
40pub fn unified_diff(old: &str, new: &str, context_lines: usize) -> String {
48 let diff = TextDiff::from_lines(old, new);
49 diff.unified_diff()
50 .context_radius(context_lines)
51 .header("old", "new")
52 .to_string()
53}
54
55pub fn line_diff(old: &str, new: &str) -> Vec<DiffLine> {
59 let diff = TextDiff::from_lines(old, new);
60 let mut result = Vec::new();
61 let mut old_line: usize = 1;
62 let mut new_line: usize = 1;
63
64 for change in diff.iter_all_changes() {
65 let tag = match change.tag() {
66 SimilarTag::Equal => ChangeTag::Unchanged,
67 SimilarTag::Insert => ChangeTag::Added,
68 SimilarTag::Delete => ChangeTag::Removed,
69 };
70
71 let (ln_old, ln_new) = match tag {
72 ChangeTag::Unchanged => {
73 let r = (Some(old_line), Some(new_line));
74 old_line += 1;
75 new_line += 1;
76 r
77 }
78 ChangeTag::Removed => {
79 let r = (Some(old_line), None);
80 old_line += 1;
81 r
82 }
83 ChangeTag::Added => {
84 let r = (None, Some(new_line));
85 new_line += 1;
86 r
87 }
88 };
89
90 result.push(DiffLine {
91 tag,
92 line_number_old: ln_old,
93 line_number_new: ln_new,
94 content: change.to_string_lossy().to_string(),
95 });
96 }
97
98 result
99}
100
101pub fn apply_patch(original: &str, patch: &str) -> Result<String, PatchError> {
107 let original_lines: Vec<&str> = original.lines().collect();
108 let mut result_lines: Vec<String> = Vec::new();
109 let mut orig_idx: usize = 0;
110
111 let patch_lines: Vec<&str> = patch.lines().collect();
112 let mut patch_idx: usize = 0;
113
114 while patch_idx < patch_lines.len() {
116 let line = patch_lines[patch_idx];
117 if line.starts_with("@@") {
118 break;
119 }
120 patch_idx += 1;
121 }
122
123 while patch_idx < patch_lines.len() {
124 let line = patch_lines[patch_idx];
125
126 if line.starts_with("@@") {
127 let parts: Vec<&str> = line.split_whitespace().collect();
129 if parts.len() < 3 {
130 return Err(PatchError {
131 message: format!("malformed hunk header: {}", line),
132 });
133 }
134
135 let old_part = parts[1].trim_start_matches('-');
136 let old_start: usize = old_part
137 .split(',')
138 .next()
139 .and_then(|s| s.parse().ok())
140 .unwrap_or(1);
141
142 while orig_idx + 1 < old_start && orig_idx < original_lines.len() {
144 result_lines.push(original_lines[orig_idx].to_string());
145 orig_idx += 1;
146 }
147
148 patch_idx += 1;
149 continue;
150 }
151
152 if line.starts_with('-') {
153 orig_idx += 1;
155 } else if line.starts_with('+') {
156 result_lines.push(line[1..].to_string());
158 } else if line.starts_with(' ') || line.is_empty() {
159 if orig_idx < original_lines.len() {
161 result_lines.push(original_lines[orig_idx].to_string());
162 orig_idx += 1;
163 }
164 }
165
166 patch_idx += 1;
167 }
168
169 while orig_idx < original_lines.len() {
171 result_lines.push(original_lines[orig_idx].to_string());
172 orig_idx += 1;
173 }
174
175 Ok(result_lines.join("\n"))
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the entries of `lines` whose tag equals `tag`, in order.
    fn tagged(lines: &[DiffLine], tag: ChangeTag) -> Vec<&DiffLine> {
        lines.iter().filter(|line| line.tag == tag).collect()
    }

    #[test]
    fn test_unified_diff_basic() {
        let rendered = unified_diff("hello\nworld\n", "hello\nearth\n", 3);
        for needle in ["-world", "+earth", "@@"].iter() {
            assert!(
                rendered.contains(*needle),
                "expected {:?} in {:?}",
                needle,
                rendered
            );
        }
    }

    #[test]
    fn test_unified_diff_identical() {
        let text = "same\ncontent\n";
        // Identical inputs must not produce any hunk (an empty string
        // trivially satisfies this).
        assert!(!unified_diff(text, text, 3).contains("@@"));
    }

    #[test]
    fn test_line_diff_basic() {
        let lines = line_diff("a\nb\nc\n", "a\nB\nc\n");
        let removed = tagged(&lines, ChangeTag::Removed);
        let added = tagged(&lines, ChangeTag::Added);

        assert_eq!((removed.len(), added.len()), (1, 1));
        assert!(removed[0].content.contains('b'));
        assert!(added[0].content.contains('B'));
    }

    #[test]
    fn test_line_diff_empty() {
        assert!(line_diff("", "").is_empty());
    }

    #[test]
    fn test_apply_patch_basic() {
        let before = "hello\nworld\nfoo\n";
        let after = "hello\nearth\nfoo\n";
        let patched = apply_patch(before, &unified_diff(before, after, 3)).unwrap();
        assert!(patched.contains("earth"));
        assert!(!patched.contains("world"));
    }

    #[test]
    fn test_line_numbers() {
        let lines = line_diff("a\nb\nc\n", "a\nc\n");
        let removed = lines
            .iter()
            .find(|line| line.tag == ChangeTag::Removed)
            .expect("deleting `b` must yield a removed line");
        assert_eq!(removed.line_number_old, Some(2));
        assert_eq!(removed.line_number_new, None);
    }
}