Skip to main content

lean_ctx/tools/
ctx_edit.rs

1use std::path::Path;
2
3use crate::core::cache::SessionCache;
4use crate::core::tokens::count_tokens;
5
/// Input for a single edit request processed by `handle`.
pub struct EditParams {
    /// Path of the file to read and rewrite (or to create when `create` is set).
    pub path: String,
    /// Text to search for in the file. Must be non-empty unless `create` is set.
    pub old_string: String,
    /// Replacement text; when `create` is set this is the full content of the new file.
    pub new_string: String,
    /// Replace every occurrence. When `false`, more than one occurrence is reported as an error.
    pub replace_all: bool,
    /// Write `new_string` to `path` instead of editing, creating missing parent directories.
    pub create: bool,
}
13
14pub fn handle(cache: &mut SessionCache, params: EditParams) -> String {
15    let file_path = &params.path;
16
17    if params.create {
18        return handle_create(cache, file_path, &params.new_string);
19    }
20
21    let raw_bytes = match std::fs::read(file_path) {
22        Ok(b) => b,
23        Err(e) => return format!("ERROR: cannot read {file_path}: {e}"),
24    };
25
26    let content = String::from_utf8_lossy(&raw_bytes).into_owned();
27
28    if params.old_string.is_empty() {
29        return "ERROR: old_string must not be empty (use create=true to create a new file)".into();
30    }
31
32    let uses_crlf = content.contains("\r\n");
33    let old_str = &params.old_string;
34    let new_str = &params.new_string;
35
36    let occurrences = content.matches(old_str).count();
37
38    if occurrences > 0 {
39        return do_replace(
40            cache,
41            file_path,
42            &content,
43            old_str,
44            new_str,
45            occurrences,
46            &params,
47        );
48    }
49
50    // Direct match failed -- try CRLF/LF normalization
51    if uses_crlf && !old_str.contains('\r') {
52        let old_crlf = old_str.replace('\n', "\r\n");
53        let occ = content.matches(&old_crlf).count();
54        if occ > 0 {
55            let new_crlf = new_str.replace('\n', "\r\n");
56            return do_replace(
57                cache, file_path, &content, &old_crlf, &new_crlf, occ, &params,
58            );
59        }
60    } else if !uses_crlf && old_str.contains("\r\n") {
61        let old_lf = old_str.replace("\r\n", "\n");
62        let occ = content.matches(&old_lf).count();
63        if occ > 0 {
64            let new_lf = new_str.replace("\r\n", "\n");
65            return do_replace(cache, file_path, &content, &old_lf, &new_lf, occ, &params);
66        }
67    }
68
69    // Still not found -- try trimmed trailing whitespace per line
70    let normalized_content = trim_trailing_per_line(&content);
71    let normalized_old = trim_trailing_per_line(old_str);
72    if !normalized_old.is_empty() && normalized_content.contains(&normalized_old) {
73        let line_sep = if uses_crlf { "\r\n" } else { "\n" };
74        let adapted_new = adapt_new_string_to_line_sep(new_str, line_sep);
75        let adapted_old = find_original_span(&content, &normalized_old);
76        if let Some(original_match) = adapted_old {
77            let occ = content.matches(&original_match).count();
78            return do_replace(
79                cache,
80                file_path,
81                &content,
82                &original_match,
83                &adapted_new,
84                occ,
85                &params,
86            );
87        }
88    }
89
90    let preview = if old_str.len() > 80 {
91        format!("{}...", &old_str[..77])
92    } else {
93        old_str.clone()
94    };
95    let hint = if uses_crlf {
96        " (file uses CRLF line endings)"
97    } else {
98        ""
99    };
100    format!(
101        "ERROR: old_string not found in {file_path}{hint}. \
102         Make sure it matches exactly (including whitespace/indentation).\n\
103         Searched for: {preview}"
104    )
105}
106
107fn do_replace(
108    cache: &mut SessionCache,
109    file_path: &str,
110    content: &str,
111    old_str: &str,
112    new_str: &str,
113    occurrences: usize,
114    params: &EditParams,
115) -> String {
116    if occurrences > 1 && !params.replace_all {
117        return format!(
118            "ERROR: old_string found {occurrences} times in {file_path}. \
119             Use replace_all=true to replace all, or provide more context to make old_string unique."
120        );
121    }
122
123    let new_content = if params.replace_all {
124        content.replace(old_str, new_str)
125    } else {
126        content.replacen(old_str, new_str, 1)
127    };
128
129    if let Err(e) = std::fs::write(file_path, &new_content) {
130        return format!("ERROR: cannot write {file_path}: {e}");
131    }
132
133    cache.invalidate(file_path);
134
135    let old_lines = content.lines().count();
136    let new_lines = new_content.lines().count();
137    let line_delta = new_lines as i64 - old_lines as i64;
138    let delta_str = if line_delta > 0 {
139        format!("+{line_delta}")
140    } else {
141        format!("{line_delta}")
142    };
143
144    let old_tokens = count_tokens(&params.old_string);
145    let new_tokens = count_tokens(&params.new_string);
146
147    let replaced_str = if params.replace_all && occurrences > 1 {
148        format!("{occurrences} replacements")
149    } else {
150        "1 replacement".into()
151    };
152
153    let short = Path::new(file_path)
154        .file_name()
155        .map(|f| f.to_string_lossy().to_string())
156        .unwrap_or_else(|| file_path.to_string());
157
158    format!("✓ {short}: {replaced_str}, {delta_str} lines ({old_tokens}→{new_tokens} tok)")
159}
160
161fn handle_create(cache: &mut SessionCache, file_path: &str, content: &str) -> String {
162    if let Some(parent) = Path::new(file_path).parent() {
163        if !parent.exists() {
164            if let Err(e) = std::fs::create_dir_all(parent) {
165                return format!("ERROR: cannot create directory {}: {e}", parent.display());
166            }
167        }
168    }
169
170    if let Err(e) = std::fs::write(file_path, content) {
171        return format!("ERROR: cannot write {file_path}: {e}");
172    }
173
174    cache.invalidate(file_path);
175
176    let lines = content.lines().count();
177    let tokens = count_tokens(content);
178    let short = Path::new(file_path)
179        .file_name()
180        .map(|f| f.to_string_lossy().to_string())
181        .unwrap_or_else(|| file_path.to_string());
182
183    format!("✓ created {short}: {lines} lines, {tokens} tok")
184}
185
/// Returns `s` with trailing whitespace removed from every line and the lines re-joined
/// with `\n`. Line terminators are taken from `str::lines`, so a final newline is not kept.
fn trim_trailing_per_line(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for (idx, line) in s.lines().enumerate() {
        if idx > 0 {
            out.push('\n');
        }
        out.push_str(line.trim_end());
    }
    out
}
192
/// Converts every line break in `s` to `sep`: CRLF pairs are first collapsed to LF, and the
/// LFs are expanded back to CRLF only when `sep` is `"\r\n"`.
fn adapt_new_string_to_line_sep(s: &str, sep: &str) -> String {
    let unix_style = s.replace("\r\n", "\n");
    match sep {
        "\r\n" => unix_style.replace('\n', "\r\n"),
        _ => unix_style,
    }
}
201
/// Find the original (un-trimmed) span in `content` that matches `normalized_needle`
/// after trailing-whitespace trimming per line.
///
/// The needle is compared line by line against whole lines of `content`, so it must start
/// and end on line boundaries. The first matching run of lines is returned as an exact
/// slice of `content`: it keeps the original trailing whitespace and the original line
/// terminators between the lines, but not the terminator after the last line. Because the
/// slice is copied verbatim, the result is always a substring of `content`, even when the
/// file mixes `\n` and `\r\n`.
///
/// Returns `None` when the needle has no lines or no run of lines matches.
fn find_original_span(content: &str, normalized_needle: &str) -> Option<String> {
    let needle_lines: Vec<&str> = normalized_needle.lines().collect();
    if needle_lines.is_empty() {
        return None;
    }

    // Byte range of each line's text (terminator excluded), so the match can be sliced
    // straight out of `content` instead of being rebuilt with a guessed separator.
    let mut line_spans: Vec<(usize, usize)> = Vec::new();
    let mut offset = 0;
    for raw in content.split_inclusive('\n') {
        let text = match raw.strip_suffix('\n') {
            Some(without_lf) => without_lf.strip_suffix('\r').unwrap_or(without_lf),
            None => raw,
        };
        line_spans.push((offset, offset + text.len()));
        offset += raw.len();
    }

    // `needle_lines` is non-empty, so the window size is never zero.
    line_spans.windows(needle_lines.len()).find_map(|window| {
        let matches = window
            .iter()
            .zip(&needle_lines)
            .all(|(&(start, end), needle)| content[start..end].trim_end() == *needle);
        if !matches {
            return None;
        }
        let (first_start, _) = window[0];
        let (_, last_end) = window[window.len() - 1];
        Some(content[first_start..last_end].to_string())
    })
}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file pre-filled with `content`.
    fn make_temp(content: &str) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(content.as_bytes()).unwrap();
        f
    }

    /// Runs `handle` against `path` with a fresh cache and returns its status line.
    fn run_edit(
        path: &std::path::Path,
        old: &str,
        new: &str,
        replace_all: bool,
        create: bool,
    ) -> String {
        let mut cache = SessionCache::new();
        handle(
            &mut cache,
            EditParams {
                path: path.to_str().unwrap().to_string(),
                old_string: old.into(),
                new_string: new.into(),
                replace_all,
                create,
            },
        )
    }

    // An ambiguous needle without replace_all must be rejected.
    #[test]
    fn replace_single_occurrence() {
        let f = make_temp("fn hello() {\n    println!(\"hello\");\n}\n");
        let result = run_edit(f.path(), "hello", "world", false, false);
        assert!(result.contains("ERROR"), "should fail: 'hello' appears 2x");
    }

    #[test]
    fn replace_all() {
        let f = make_temp("aaa bbb aaa\n");
        let result = run_edit(f.path(), "aaa", "ccc", true, false);
        assert!(result.contains("2 replacements"));
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert_eq!(content, "ccc bbb ccc\n");
    }

    #[test]
    fn not_found_error() {
        let f = make_temp("some content\n");
        let result = run_edit(f.path(), "nonexistent", "x", false, false);
        assert!(result.contains("ERROR: old_string not found"));
    }

    // Creation must also build the missing `sub/` directory.
    #[test]
    fn create_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sub/new_file.txt");
        let result = run_edit(&path, "", "line1\nline2\nline3\n", false, true);
        assert!(result.contains("created new_file.txt"));
        assert!(result.contains("3 lines"));
        assert!(path.exists());
    }

    #[test]
    fn unique_match_succeeds() {
        let f = make_temp("fn main() {\n    let x = 42;\n}\n");
        let result = run_edit(f.path(), "let x = 42", "let x = 99", false, false);
        assert!(result.contains("✓"));
        assert!(result.contains("1 replacement"));
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert!(content.contains("let x = 99"));
    }

    #[test]
    fn crlf_file_with_lf_search() {
        let f = make_temp("line1\r\nline2\r\nline3\r\n");
        let result = run_edit(f.path(), "line1\nline2", "changed1\nchanged2", false, false);
        assert!(result.contains("✓"), "CRLF fallback should work: {result}");
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert!(
            content.contains("changed1\r\nchanged2"),
            "new_string should be adapted to CRLF: {content:?}"
        );
        assert!(
            content.contains("\r\nline3\r\n"),
            "rest of file should keep CRLF: {content:?}"
        );
    }

    #[test]
    fn lf_file_with_crlf_search() {
        let f = make_temp("line1\nline2\nline3\n");
        let result = run_edit(f.path(), "line1\r\nline2", "a\r\nb", false, false);
        assert!(result.contains("✓"), "LF fallback should work: {result}");
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert!(
            content.contains("a\nb"),
            "new_string should be adapted to LF: {content:?}"
        );
    }

    #[test]
    fn trailing_whitespace_tolerance() {
        let f = make_temp("  let x = 1;  \n  let y = 2;\n");
        let result = run_edit(
            f.path(),
            "  let x = 1;\n  let y = 2;",
            "  let x = 10;\n  let y = 20;",
            false,
            false,
        );
        assert!(
            result.contains("✓"),
            "trailing whitespace tolerance should work: {result}"
        );
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert!(content.contains("let x = 10;"));
        assert!(content.contains("let y = 20;"));
    }

    #[test]
    fn crlf_with_trailing_whitespace() {
        let f = make_temp("  const a = 1;  \r\n  const b = 2;\r\n");
        let result = run_edit(
            f.path(),
            "  const a = 1;\n  const b = 2;",
            "  const a = 10;\n  const b = 20;",
            false,
            false,
        );
        assert!(
            result.contains("✓"),
            "CRLF + trailing whitespace should work: {result}"
        );
        let content = std::fs::read_to_string(f.path()).unwrap();
        assert!(content.contains("const a = 10;"));
        assert!(content.contains("const b = 20;"));
    }
}
429}