1use crate::{Error, Result};
10use camino::{Utf8Component, Utf8Path, Utf8PathBuf};
11use serde::{Deserialize, Serialize};
12use std::collections::BTreeSet;
13
14#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
20pub struct SpanReplacement {
21 pub file: Utf8PathBuf,
23 pub start_line: u32,
25 pub end_line: u32,
27 pub replacement: String,
29}
30
31#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
33pub struct AppliedPatch {
34 pub file: Utf8PathBuf,
36 pub start_line: u32,
38 pub end_line: u32,
40 pub replaced_text: String,
42}
43
44pub fn apply_single_span(root: &Utf8Path, edit: &SpanReplacement) -> Result<AppliedPatch> {
50 reject_top_level_command(edit)?;
51 let target = resolve_within(root, &edit.file)?;
52
53 if edit.start_line == 0 || edit.end_line == 0 {
54 return Err(Error::ZeroLineSpan {
55 file: edit.file.clone(),
56 });
57 }
58 if edit.start_line > edit.end_line {
59 return Err(Error::InvertedSpan {
60 file: edit.file.clone(),
61 start_line: edit.start_line,
62 end_line: edit.end_line,
63 });
64 }
65
66 let source = std::fs::read_to_string(&target)?;
67 let ranges = line_content_ranges(&source);
68 if (edit.end_line as usize) > ranges.len() {
69 return Err(Error::SpanOutOfBounds {
70 file: edit.file.clone(),
71 start_line: edit.start_line,
72 end_line: edit.end_line,
73 line_count: ranges.len(),
74 });
75 }
76
77 let start = ranges[(edit.start_line - 1) as usize].0;
78 let end = ranges[(edit.end_line - 1) as usize].1;
79 let replaced_text = source.get(start..end).unwrap_or("").to_owned();
80
81 let mut rewritten = String::with_capacity(source.len() + edit.replacement.len());
82 rewritten.push_str(source.get(..start).unwrap_or(""));
83 rewritten.push_str(&edit.replacement);
84 rewritten.push_str(source.get(end..).unwrap_or(""));
85 std::fs::write(&target, rewritten)?;
86
87 Ok(AppliedPatch {
88 file: edit.file.clone(),
89 start_line: edit.start_line,
90 end_line: edit.end_line,
91 replaced_text,
92 })
93}
94
95pub fn apply_edits(
102 root: &Utf8Path,
103 edits: &[SpanReplacement],
104 allow_multi_file: bool,
105) -> Result<Vec<AppliedPatch>> {
106 let distinct_files: BTreeSet<&Utf8PathBuf> = edits.iter().map(|edit| &edit.file).collect();
107 if distinct_files.len() > 1 && !allow_multi_file {
108 return Err(Error::MultiFileEditNotAllowed {
109 files: distinct_files.len(),
110 });
111 }
112
113 let mut ordered: Vec<&SpanReplacement> = edits.iter().collect();
114 ordered.sort_by(|a, b| a.file.cmp(&b.file).then(b.start_line.cmp(&a.start_line)));
115
116 let mut applied = Vec::with_capacity(ordered.len());
117 for edit in ordered {
118 applied.push(apply_single_span(root, edit)?);
119 }
120 Ok(applied)
121}
122
123const TOP_LEVEL_COMMANDS: [&str; 5] = ["import", "set_option", "macro", "elab", "open"];
125
126fn reject_top_level_command(edit: &SpanReplacement) -> Result<()> {
136 for line in edit.replacement.lines() {
137 if line.starts_with([' ', '\t']) {
140 continue;
141 }
142 let trimmed = line.trim_end();
143 let opens_command = trimmed.starts_with('#')
144 || trimmed
145 .split_whitespace()
146 .next()
147 .is_some_and(|token| TOP_LEVEL_COMMANDS.contains(&token));
148 if opens_command {
149 return Err(Error::DisallowedReplacement {
150 file: edit.file.clone(),
151 detail: format!("line `{trimmed}` opens a top-level command"),
152 });
153 }
154 }
155 Ok(())
156}
157
158fn resolve_within(root: &Utf8Path, rel: &Utf8Path) -> Result<Utf8PathBuf> {
160 if rel.is_absolute() {
161 return Err(Error::OutsideWorkspace {
162 path: rel.to_path_buf(),
163 });
164 }
165 for component in rel.components() {
166 match component {
167 Utf8Component::Normal(_) | Utf8Component::CurDir => {}
168 _ => {
169 return Err(Error::OutsideWorkspace {
170 path: rel.to_path_buf(),
171 });
172 }
173 }
174 }
175 Ok(root.join(rel))
176}
177
/// Returns the byte range of each line's content in `source`.
///
/// Each `(start, end)` pair is a half-open byte range covering one line
/// without its terminator: a trailing `\n` is excluded, and so is a single
/// `\r` immediately before it (or at the very end of an unterminated final
/// line). A trailing newline at the end of `source` does not produce an extra
/// empty line, and an empty `source` yields no ranges.
fn line_content_ranges(source: &str) -> Vec<(usize, usize)> {
    // Byte offset at which the next piece begins.
    let mut offset = 0usize;
    source
        // Every piece keeps its own `\n`; only the last piece may lack one.
        .split_inclusive('\n')
        .map(|piece| {
            let begin = offset;
            offset += piece.len();
            let content = piece.strip_suffix('\n').unwrap_or(piece);
            // Drop at most one carriage return so CRLF files round-trip.
            let content = content.strip_suffix('\r').unwrap_or(content);
            (begin, begin + content.len())
        })
        .collect()
}
206
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Creates a temporary directory holding one file `name` with `contents`.
    ///
    /// The `TempDir` handle is returned alongside the root path so callers can
    /// hold on to it for the duration of the test.
    fn workspace_with(name: &str, contents: &str) -> Result<(TempDir, Utf8PathBuf)> {
        let dir = TempDir::new()?;
        let root = Utf8PathBuf::from_path_buf(dir.path().to_path_buf())
            .map_err(|path| Error::NonUtf8Path { path })?;
        std::fs::write(root.join(name), contents)?;
        Ok((dir, root))
    }

    /// Shorthand constructor for a [`SpanReplacement`].
    fn span(file: &str, start_line: u32, end_line: u32, replacement: &str) -> SpanReplacement {
        SpanReplacement {
            file: Utf8PathBuf::from(file),
            start_line,
            end_line,
            replacement: replacement.to_owned(),
        }
    }

    #[test]
    fn replaces_single_line_and_preserves_newlines() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "line one\n sorry\nline three\n")?;
        let applied = apply_single_span(&root, &span("A.lean", 2, 2, " exact rfl"))?;
        assert_eq!(applied.replaced_text, " sorry");
        let after = std::fs::read_to_string(root.join("A.lean"))?;
        assert_eq!(after, "line one\n exact rfl\nline three\n");
        Ok(())
    }

    #[test]
    fn replaces_multi_line_range() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\nb\nc\nd\n")?;
        apply_single_span(&root, &span("A.lean", 2, 3, "X\nY\nZ"))?;
        let after = std::fs::read_to_string(root.join("A.lean"))?;
        assert_eq!(after, "a\nX\nY\nZ\nd\n");
        Ok(())
    }

    #[test]
    fn replaces_final_line_without_trailing_newline() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\nb")?;
        apply_single_span(&root, &span("A.lean", 2, 2, "bb"))?;
        let after = std::fs::read_to_string(root.join("A.lean"))?;
        assert_eq!(after, "a\nbb");
        Ok(())
    }

    #[test]
    fn preserves_crlf_line_endings() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\r\nb\r\n")?;
        let applied = apply_single_span(&root, &span("A.lean", 1, 1, "x"))?;
        // The carriage return belongs to the terminator, not to the content.
        assert_eq!(applied.replaced_text, "a");
        let after = std::fs::read_to_string(root.join("A.lean"))?;
        assert_eq!(after, "x\r\nb\r\n");
        Ok(())
    }

    #[test]
    fn rejects_span_past_end_of_file() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\nb\n")?;
        let result = apply_single_span(&root, &span("A.lean", 3, 3, "c"));
        assert!(matches!(
            result,
            Err(Error::SpanOutOfBounds { line_count: 2, .. })
        ));
        Ok(())
    }

    #[test]
    fn rejects_zero_and_inverted_spans() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\nb\n")?;
        assert!(matches!(
            apply_single_span(&root, &span("A.lean", 0, 1, "x")),
            Err(Error::ZeroLineSpan { .. })
        ));
        assert!(matches!(
            apply_single_span(&root, &span("A.lean", 2, 1, "x")),
            Err(Error::InvertedSpan { .. })
        ));
        Ok(())
    }

    #[test]
    fn rejects_paths_escaping_workspace() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\n")?;
        assert!(matches!(
            apply_single_span(&root, &span("../escape.lean", 1, 1, "x")),
            Err(Error::OutsideWorkspace { .. })
        ));
        assert!(matches!(
            apply_single_span(&root, &span("/etc/passwd", 1, 1, "x")),
            Err(Error::OutsideWorkspace { .. })
        ));
        Ok(())
    }

    #[test]
    fn multi_file_is_refused_by_default_and_allowed_behind_flag() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "a\n")?;
        std::fs::write(root.join("B.lean"), "b\n")?;
        let edits = vec![span("A.lean", 1, 1, "aa"), span("B.lean", 1, 1, "bb")];
        assert!(matches!(
            apply_edits(&root, &edits, false),
            Err(Error::MultiFileEditNotAllowed { files: 2 })
        ));

        let applied = apply_edits(&root, &edits, true)?;
        assert_eq!(applied.len(), 2);
        assert_eq!(std::fs::read_to_string(root.join("A.lean"))?, "aa\n");
        assert_eq!(std::fs::read_to_string(root.join("B.lean"))?, "bb\n");
        Ok(())
    }

    #[test]
    fn rejects_replacement_that_injects_a_top_level_command() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "theorem t : True := by\n sorry\n")?;
        let injection = " exact trivial\n#eval IO.println \"pwn\"";
        assert!(matches!(
            apply_single_span(&root, &span("A.lean", 2, 2, injection)),
            Err(Error::DisallowedReplacement { .. })
        ));
        for command in ["import Foo", "set_option x true", "open Foo"] {
            assert!(
                matches!(
                    apply_single_span(&root, &span("A.lean", 2, 2, command)),
                    Err(Error::DisallowedReplacement { .. })
                ),
                "expected refusal for `{command}`"
            );
        }
        // The refused attempts above must not have touched the file.
        let applied = apply_single_span(&root, &span("A.lean", 2, 2, " exact trivial"))?;
        assert_eq!(applied.replaced_text, " sorry");
        Ok(())
    }

    #[test]
    fn allows_column_zero_declaration_replacement() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "theorem t : 2 = 3 := rfl\n")?;
        let applied = apply_single_span(&root, &span("A.lean", 1, 1, "theorem t : 2 = 2 := rfl"))?;
        assert_eq!(applied.start_line, 1);
        assert_eq!(
            std::fs::read_to_string(root.join("A.lean"))?,
            "theorem t : 2 = 2 := rfl\n"
        );
        Ok(())
    }

    #[test]
    fn same_file_edits_apply_bottom_up_without_offset_drift() -> Result<()> {
        let (_dir, root) = workspace_with("A.lean", "one\ntwo\nthree\n")?;
        // The upper edit grows the file by one line. If it were applied before
        // the lower edit, original line 3 would have moved to line 4 and the
        // second replacement would land on `two` instead of `three`.
        let edits = vec![
            span("A.lean", 1, 1, "ONE\nUNO"),
            span("A.lean", 3, 3, "THREE"),
        ];
        apply_edits(&root, &edits, false)?;
        assert_eq!(
            std::fs::read_to_string(root.join("A.lean"))?,
            "ONE\nUNO\ntwo\nTHREE\n"
        );
        Ok(())
    }
}