Skip to main content

lean_ctx/tools/
ctx_edit.rs

1use std::path::Path;
2
3use crate::core::cache::SessionCache;
4use crate::core::tokens::count_tokens;
5
/// Parameters for one edit request handled by [`handle`].
#[derive(Debug, Clone)]
pub struct EditParams {
    /// Path of the file to edit (or to create when `create` is set).
    pub path: String,
    /// Text to find; must be non-empty unless `create` is set.
    pub old_string: String,
    /// Replacement text, or the full file contents when `create` is set.
    pub new_string: String,
    /// Replace every occurrence instead of requiring a unique match.
    pub replace_all: bool,
    /// Write `new_string` to `path` (creating or overwriting it) instead of editing.
    pub create: bool,
}
13
/// Everything `do_replace` needs to perform and summarise one replacement.
struct ReplaceArgs<'a> {
    /// Current full text of the file.
    content: &'a str,
    /// Exact text to search for in `content` (already adapted to the file's line endings).
    old_str: &'a str,
    /// Text to put in place of `old_str`.
    new_str: &'a str,
    /// When false, a match found more than once is rejected as ambiguous.
    replace_all: bool,
    /// How many times the caller found `old_str`; drives the ambiguity check and the summary.
    occurrences: usize,
    /// Token count of the caller's original `old_string`, shown in the summary.
    old_tokens: usize,
    /// Token count of the caller's original `new_string`, shown in the summary.
    new_tokens: usize,
}
23
24pub fn handle(cache: &mut SessionCache, params: EditParams) -> String {
25    let file_path = &params.path;
26
27    if params.create {
28        return handle_create(cache, file_path, &params.new_string);
29    }
30
31    let cap = crate::core::limits::max_read_bytes();
32    if let Ok(meta) = std::fs::metadata(file_path) {
33        if meta.len() > cap as u64 {
34            return format!(
35                "ERROR: file too large ({} bytes, cap {} via LCTX_MAX_READ_BYTES): {file_path}",
36                meta.len(),
37                cap
38            );
39        }
40    }
41
42    let mut file = match std::fs::OpenOptions::new()
43        .read(true)
44        .write(true)
45        .open(file_path)
46    {
47        Ok(f) => f,
48        Err(e) => return format!("ERROR: cannot open {file_path}: {e}"),
49    };
50
51    let mut raw_bytes: Vec<u8> = Vec::new();
52    {
53        use std::io::Read;
54        let mut limited = (&mut file).take((cap as u64).saturating_add(1));
55        if let Err(e) = limited.read_to_end(&mut raw_bytes) {
56            return format!("ERROR: cannot read {file_path}: {e}");
57        }
58    }
59    if raw_bytes.len() > cap {
60        return format!(
61            "ERROR: file too large (cap {} via LCTX_MAX_READ_BYTES): {file_path}",
62            cap
63        );
64    }
65
66    let content = String::from_utf8_lossy(&raw_bytes).into_owned();
67
68    if params.old_string.is_empty() {
69        return "ERROR: old_string must not be empty (use create=true to create a new file)".into();
70    }
71
72    let uses_crlf = content.contains("\r\n");
73    let old_str = &params.old_string;
74    let new_str = &params.new_string;
75
76    let occurrences = content.matches(old_str).count();
77
78    if occurrences > 0 {
79        let args = ReplaceArgs {
80            content: &content,
81            old_str,
82            new_str,
83            occurrences,
84            replace_all: params.replace_all,
85            old_tokens: count_tokens(&params.old_string),
86            new_tokens: count_tokens(&params.new_string),
87        };
88        return do_replace(cache, &mut file, file_path, args);
89    }
90
91    // Direct match failed -- try CRLF/LF normalization
92    if uses_crlf && !old_str.contains('\r') {
93        let old_crlf = old_str.replace('\n', "\r\n");
94        let occ = content.matches(&old_crlf).count();
95        if occ > 0 {
96            let new_crlf = new_str.replace('\n', "\r\n");
97            let args = ReplaceArgs {
98                content: &content,
99                old_str: &old_crlf,
100                new_str: &new_crlf,
101                occurrences: occ,
102                replace_all: params.replace_all,
103                old_tokens: count_tokens(&params.old_string),
104                new_tokens: count_tokens(&params.new_string),
105            };
106            return do_replace(cache, &mut file, file_path, args);
107        }
108    } else if !uses_crlf && old_str.contains("\r\n") {
109        let old_lf = old_str.replace("\r\n", "\n");
110        let occ = content.matches(&old_lf).count();
111        if occ > 0 {
112            let new_lf = new_str.replace("\r\n", "\n");
113            let args = ReplaceArgs {
114                content: &content,
115                old_str: &old_lf,
116                new_str: &new_lf,
117                occurrences: occ,
118                replace_all: params.replace_all,
119                old_tokens: count_tokens(&params.old_string),
120                new_tokens: count_tokens(&params.new_string),
121            };
122            return do_replace(cache, &mut file, file_path, args);
123        }
124    }
125
126    // Still not found -- try trimmed trailing whitespace per line
127    let normalized_content = trim_trailing_per_line(&content);
128    let normalized_old = trim_trailing_per_line(old_str);
129    if !normalized_old.is_empty() && normalized_content.contains(&normalized_old) {
130        let line_sep = if uses_crlf { "\r\n" } else { "\n" };
131        let adapted_new = adapt_new_string_to_line_sep(new_str, line_sep);
132        let adapted_old = find_original_span(&content, &normalized_old);
133        if let Some(original_match) = adapted_old {
134            let occ = content.matches(&original_match).count();
135            let args = ReplaceArgs {
136                content: &content,
137                old_str: &original_match,
138                new_str: &adapted_new,
139                occurrences: occ,
140                replace_all: params.replace_all,
141                old_tokens: count_tokens(&params.old_string),
142                new_tokens: count_tokens(&params.new_string),
143            };
144            return do_replace(cache, &mut file, file_path, args);
145        }
146    }
147
148    let preview = if old_str.len() > 80 {
149        format!("{}...", &old_str[..77])
150    } else {
151        old_str.clone()
152    };
153    let hint = if uses_crlf {
154        " (file uses CRLF line endings)"
155    } else {
156        ""
157    };
158    format!(
159        "ERROR: old_string not found in {file_path}{hint}. \
160         Make sure it matches exactly (including whitespace/indentation).\n\
161         Searched for: {preview}"
162    )
163}
164
165fn do_replace(
166    cache: &mut SessionCache,
167    file: &mut std::fs::File,
168    file_path: &str,
169    args: ReplaceArgs<'_>,
170) -> String {
171    if args.occurrences > 1 && !args.replace_all {
172        return format!(
173            "ERROR: old_string found {} times in {file_path}. \
174             Use replace_all=true to replace all, or provide more context to make old_string unique."
175            , args.occurrences
176        );
177    }
178
179    let new_content = if args.replace_all {
180        args.content.replace(args.old_str, args.new_str)
181    } else {
182        args.content.replacen(args.old_str, args.new_str, 1)
183    };
184
185    use std::io::{Seek, SeekFrom, Write};
186    if let Err(e) = file.set_len(0) {
187        return format!("ERROR: cannot write {file_path}: {e}");
188    }
189    if let Err(e) = file.seek(SeekFrom::Start(0)) {
190        return format!("ERROR: cannot write {file_path}: {e}");
191    }
192    if let Err(e) = file.write_all(new_content.as_bytes()) {
193        return format!("ERROR: cannot write {file_path}: {e}");
194    }
195    let _ = file.flush();
196    let _ = file.sync_all();
197
198    cache.invalidate(file_path);
199
200    let old_lines = args.content.lines().count();
201    let new_lines = new_content.lines().count();
202    let line_delta = new_lines as i64 - old_lines as i64;
203    let delta_str = if line_delta > 0 {
204        format!("+{line_delta}")
205    } else {
206        format!("{line_delta}")
207    };
208
209    let old_tokens = args.old_tokens;
210    let new_tokens = args.new_tokens;
211
212    let replaced_str = if args.replace_all && args.occurrences > 1 {
213        format!("{} replacements", args.occurrences)
214    } else {
215        "1 replacement".into()
216    };
217
218    let short = Path::new(file_path)
219        .file_name()
220        .map(|f| f.to_string_lossy().to_string())
221        .unwrap_or_else(|| file_path.to_string());
222
223    format!("✓ {short}: {replaced_str}, {delta_str} lines ({old_tokens}→{new_tokens} tok)")
224}
225
226fn handle_create(cache: &mut SessionCache, file_path: &str, content: &str) -> String {
227    if let Some(parent) = Path::new(file_path).parent() {
228        if !parent.exists() {
229            if let Err(e) = std::fs::create_dir_all(parent) {
230                return format!("ERROR: cannot create directory {}: {e}", parent.display());
231            }
232        }
233    }
234
235    if let Err(e) = std::fs::write(file_path, content) {
236        return format!("ERROR: cannot write {file_path}: {e}");
237    }
238
239    cache.invalidate(file_path);
240
241    let lines = content.lines().count();
242    let tokens = count_tokens(content);
243    let short = Path::new(file_path)
244        .file_name()
245        .map(|f| f.to_string_lossy().to_string())
246        .unwrap_or_else(|| file_path.to_string());
247
248    format!("✓ created {short}: {lines} lines, {tokens} tok")
249}
250
/// Strips trailing whitespace from every line of `s` and rejoins the lines
/// with `'\n'`.
///
/// Input line endings (`\n` or `\r\n`) are not preserved, and a final line
/// terminator is dropped.
fn trim_trailing_per_line(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for (index, line) in s.lines().enumerate() {
        if index > 0 {
            out.push('\n');
        }
        out.push_str(line.trim_end());
    }
    out
}
257
/// Rewrites the line endings of `s` for a file whose separator is `sep`.
///
/// CRLF is first collapsed to LF. The result is expanded back to CRLF only
/// when `sep` is exactly `"\r\n"`; any other `sep` yields LF-only text.
fn adapt_new_string_to_line_sep(s: &str, sep: &str) -> String {
    let lf_only = s.replace("\r\n", "\n");
    match sep {
        "\r\n" => lf_only.replace('\n', "\r\n"),
        _ => lf_only,
    }
}
266
/// Finds the first run of consecutive lines in `content` that equals
/// `normalized_needle` once trailing whitespace is trimmed from each line.
/// Returns that run exactly as it appears in `content`.
///
/// `normalized_needle` is expected to be `'\n'`-joined with trailing
/// whitespace already removed from each line, as produced by
/// `trim_trailing_per_line`.
///
/// The returned text keeps each matched line's original trailing whitespace
/// and the original terminators between matched lines, even when the file
/// mixes `\n` and `\r\n`. It excludes the terminator after the last matched
/// line. Because it is sliced out of `content`, it always occurs there
/// verbatim.
///
/// Returns `None` when the needle is empty or no run matches.
fn find_original_span(content: &str, normalized_needle: &str) -> Option<String> {
    if normalized_needle.is_empty() {
        return None;
    }
    // Use `split` rather than `lines`: a trailing empty needle line must
    // still count. For example, two whitespace-only input lines normalize to
    // "\n", which must match two lines, not one empty line.
    let needle_lines: Vec<&str> = normalized_needle.split('\n').collect();

    // Byte range of each content line, excluding its "\n" or "\r\n" terminator.
    let mut line_ranges: Vec<(usize, usize)> = Vec::new();
    let mut offset = 0;
    for piece in content.split_inclusive('\n') {
        let text = piece
            .strip_suffix('\n')
            .map_or(piece, |t| t.strip_suffix('\r').unwrap_or(t));
        line_ranges.push((offset, offset + text.len()));
        offset += piece.len();
    }

    line_ranges
        .windows(needle_lines.len())
        .find(|window| {
            window
                .iter()
                .zip(&needle_lines)
                .all(|(&(start, end), needle)| content[start..end].trim_end() == *needle)
        })
        .map(|window| content[window[0].0..window[window.len() - 1].1].to_string())
}
295
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Creates a temporary file holding `content`.
    fn make_temp(content: &str) -> NamedTempFile {
        let mut tmp = NamedTempFile::new().unwrap();
        tmp.write_all(content.as_bytes()).unwrap();
        tmp
    }

    /// Runs a non-create edit on `file` with a fresh session cache.
    fn run_edit(file: &NamedTempFile, old: &str, new: &str, replace_all: bool) -> String {
        let mut cache = SessionCache::new();
        let params = EditParams {
            path: file.path().to_str().unwrap().to_string(),
            old_string: old.to_string(),
            new_string: new.to_string(),
            replace_all,
            create: false,
        };
        handle(&mut cache, params)
    }

    /// Reads the current on-disk contents of `file`.
    fn contents_of(file: &NamedTempFile) -> String {
        std::fs::read_to_string(file.path()).unwrap()
    }

    #[test]
    fn replace_single_occurrence() {
        // "hello" occurs twice, so a single (non-replace_all) edit is rejected.
        let f = make_temp("fn hello() {\n    println!(\"hello\");\n}\n");
        let result = run_edit(&f, "hello", "world", false);
        assert!(result.contains("ERROR"), "should fail: 'hello' appears 2x");
    }

    #[test]
    fn replace_all() {
        let f = make_temp("aaa bbb aaa\n");
        let result = run_edit(&f, "aaa", "ccc", true);
        assert!(result.contains("2 replacements"));
        assert_eq!(contents_of(&f), "ccc bbb ccc\n");
    }

    #[test]
    fn not_found_error() {
        let f = make_temp("some content\n");
        let result = run_edit(&f, "nonexistent", "x", false);
        assert!(result.contains("ERROR: old_string not found"));
    }

    #[test]
    fn create_new_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("sub/new_file.txt");
        let mut cache = SessionCache::new();
        let params = EditParams {
            path: path.to_str().unwrap().to_string(),
            old_string: String::new(),
            new_string: "line1\nline2\nline3\n".into(),
            replace_all: false,
            create: true,
        };
        let result = handle(&mut cache, params);
        assert!(result.contains("created new_file.txt"));
        assert!(result.contains("3 lines"));
        assert!(path.exists());
    }

    #[test]
    fn unique_match_succeeds() {
        let f = make_temp("fn main() {\n    let x = 42;\n}\n");
        let result = run_edit(&f, "let x = 42", "let x = 99", false);
        assert!(result.contains("✓"));
        assert!(result.contains("1 replacement"));
        assert!(contents_of(&f).contains("let x = 99"));
    }

    #[test]
    fn crlf_file_with_lf_search() {
        let f = make_temp("line1\r\nline2\r\nline3\r\n");
        let result = run_edit(&f, "line1\nline2", "changed1\nchanged2", false);
        assert!(result.contains("✓"), "CRLF fallback should work: {result}");
        let content = contents_of(&f);
        assert!(
            content.contains("changed1\r\nchanged2"),
            "new_string should be adapted to CRLF: {content:?}"
        );
        assert!(
            content.contains("\r\nline3\r\n"),
            "rest of file should keep CRLF: {content:?}"
        );
    }

    #[test]
    fn lf_file_with_crlf_search() {
        let f = make_temp("line1\nline2\nline3\n");
        let result = run_edit(&f, "line1\r\nline2", "a\r\nb", false);
        assert!(result.contains("✓"), "LF fallback should work: {result}");
        let content = contents_of(&f);
        assert!(
            content.contains("a\nb"),
            "new_string should be adapted to LF: {content:?}"
        );
    }

    #[test]
    fn trailing_whitespace_tolerance() {
        let f = make_temp("  let x = 1;  \n  let y = 2;\n");
        let result = run_edit(
            &f,
            "  let x = 1;\n  let y = 2;",
            "  let x = 10;\n  let y = 20;",
            false,
        );
        assert!(
            result.contains("✓"),
            "trailing whitespace tolerance should work: {result}"
        );
        let content = contents_of(&f);
        assert!(content.contains("let x = 10;"));
        assert!(content.contains("let y = 20;"));
    }

    #[test]
    fn crlf_with_trailing_whitespace() {
        let f = make_temp("  const a = 1;  \r\n  const b = 2;\r\n");
        let result = run_edit(
            &f,
            "  const a = 1;\n  const b = 2;",
            "  const a = 10;\n  const b = 20;",
            false,
        );
        assert!(
            result.contains("✓"),
            "CRLF + trailing whitespace should work: {result}"
        );
        let content = contents_of(&f);
        assert!(content.contains("const a = 10;"));
        assert!(content.contains("const b = 20;"));
    }
}