Skip to main content

krait/commands/
edit.rs

1use std::path::{Path, PathBuf};
2
3use anyhow::{bail, Context};
4use serde_json::{json, Value};
5
6use crate::commands::find::{find_symbol, SymbolMatch};
7use crate::index::watcher::DirtyFiles;
8use crate::lsp::client::LspClient;
9use crate::lsp::files::FileTracker;
10use crate::lsp::symbols::{resolve_symbol_range, SymbolLocation};
11
12// ── Shared helpers ────────────────────────────────────────────────────────────
13
14/// Locate a symbol and return its absolute file path + resolved range.
15///
16/// Mirrors the candidate-iteration logic from `handle_read_symbol`.
17async fn locate_symbol(
18    name: &str,
19    client: &mut LspClient,
20    file_tracker: &mut FileTracker,
21    project_root: &Path,
22) -> anyhow::Result<(PathBuf, SymbolLocation)> {
23    let search_name = name.split('.').next().unwrap_or(name);
24
25    let candidates: Vec<SymbolMatch> = find_symbol(search_name, client, project_root).await?;
26
27    if candidates.is_empty() {
28        bail!("symbol '{name}' not found");
29    }
30
31    let mut last_err: Option<anyhow::Error> = None;
32    for sym in &candidates {
33        let abs = project_root.join(&sym.path);
34        let hint_line = sym.line.checked_sub(1);
35        match resolve_symbol_range(search_name, &abs, hint_line, client, file_tracker).await {
36            Ok(loc) => {
37                let location = if name.contains('.') {
38                    resolve_symbol_range(name, &abs, hint_line, client, file_tracker).await?
39                } else {
40                    loc
41                };
42                return Ok((abs, location));
43            }
44            Err(e) => last_err = Some(e),
45        }
46    }
47
48    Err(last_err.unwrap_or_else(|| anyhow::anyhow!("symbol '{name}' not found")))
49}
50
51/// Atomically write `contents` to `path`.
52///
53/// Writes to a sibling `.tmp` file first, then renames so the write is atomic.
54fn atomic_write(path: &Path, contents: &str) -> anyhow::Result<()> {
55    let tmp = path.with_extension("tmp");
56    std::fs::write(&tmp, contents)
57        .with_context(|| format!("failed to write temp file: {}", tmp.display()))?;
58    std::fs::rename(&tmp, path).with_context(|| {
59        let _ = std::fs::remove_file(&tmp);
60        format!("failed to rename temp file to: {}", path.display())
61    })?;
62    Ok(())
63}
64
65/// Mark a file dirty in the watcher so the index is refreshed on next query.
66fn mark_dirty(abs_path: &Path, project_root: &Path, dirty_files: &DirtyFiles) {
67    if let Ok(rel) = abs_path.strip_prefix(project_root) {
68        dirty_files.mark_dirty(rel.to_string_lossy().into_owned());
69    }
70}
71
/// Return an owned copy of `s` that is guaranteed to end with `'\n'`.
///
/// A newline is appended only when `s` does not already end with one, so an empty
/// input becomes `"\n"`.
fn ensure_trailing_newline(s: &str) -> String {
    let mut owned = String::with_capacity(s.len() + 1);
    owned.push_str(s);
    if !owned.ends_with('\n') {
        owned.push('\n');
    }
    owned
}
80
81// ── edit replace ─────────────────────────────────────────────────────────────
82
83/// Replace a symbol's body with `code`.
84///
85/// # Errors
86/// Returns an error if the symbol can't be found or the file can't be written.
87pub async fn handle_edit_replace(
88    name: &str,
89    code: &str,
90    client: &mut LspClient,
91    file_tracker: &mut FileTracker,
92    project_root: &Path,
93    dirty_files: &DirtyFiles,
94) -> anyhow::Result<Value> {
95    let (abs_path, location) = locate_symbol(name, client, file_tracker, project_root).await?;
96
97    let content = std::fs::read_to_string(&abs_path)
98        .with_context(|| format!("failed to read: {}", abs_path.display()))?;
99
100    let mut lines: Vec<&str> = content.lines().collect();
101
102    let start = location.start_line as usize;
103    let end = (location.end_line as usize + 1).min(lines.len());
104
105    if start >= lines.len() {
106        bail!("symbol range out of bounds in {}", abs_path.display());
107    }
108
109    let original_count = end - start;
110    let new_lines: Vec<&str> = code.lines().collect();
111    let new_count = new_lines.len();
112
113    // Replace lines [start..end] with new_lines
114    lines.splice(start..end, new_lines.iter().copied());
115
116    let new_content = ensure_trailing_newline(&lines.join("\n"));
117    atomic_write(&abs_path, &new_content)?;
118    mark_dirty(&abs_path, project_root, dirty_files);
119
120    let rel_path = abs_path
121        .strip_prefix(project_root)
122        .unwrap_or(&abs_path)
123        .to_string_lossy()
124        .to_string();
125
126    Ok(json!({
127        "path": rel_path,
128        "symbol": name,
129        "from": start + 1,
130        "to": end,
131        "lines_before": original_count,
132        "lines_after": new_count,
133    }))
134}
135
136// ── edit insert-after ─────────────────────────────────────────────────────────
137
138/// Insert `code` immediately after a symbol's end line.
139///
140/// Adds a blank line separator if the line after the symbol is not already blank.
141///
142/// # Errors
143/// Returns an error if the symbol can't be found or the file can't be written.
144pub async fn handle_edit_insert_after(
145    name: &str,
146    code: &str,
147    client: &mut LspClient,
148    file_tracker: &mut FileTracker,
149    project_root: &Path,
150    dirty_files: &DirtyFiles,
151) -> anyhow::Result<Value> {
152    let (abs_path, location) = locate_symbol(name, client, file_tracker, project_root).await?;
153
154    let content = std::fs::read_to_string(&abs_path)
155        .with_context(|| format!("failed to read: {}", abs_path.display()))?;
156
157    let mut lines: Vec<&str> = content.lines().collect();
158    let insert_at = (location.end_line as usize + 1).min(lines.len());
159
160    // Add blank separator if next line is not already blank
161    let needs_blank = lines.get(insert_at).is_some_and(|l| !l.trim().is_empty());
162
163    let new_lines: Vec<&str> = code.lines().collect();
164    let insert_count = new_lines.len();
165
166    if needs_blank {
167        lines.splice(
168            insert_at..insert_at,
169            std::iter::once("").chain(new_lines.iter().copied()),
170        );
171    } else {
172        lines.splice(insert_at..insert_at, new_lines.iter().copied());
173    }
174
175    let new_content = ensure_trailing_newline(&lines.join("\n"));
176    atomic_write(&abs_path, &new_content)?;
177    mark_dirty(&abs_path, project_root, dirty_files);
178
179    let rel_path = abs_path
180        .strip_prefix(project_root)
181        .unwrap_or(&abs_path)
182        .to_string_lossy()
183        .to_string();
184
185    Ok(json!({
186        "path": rel_path,
187        "symbol": name,
188        "operation": "after",
189        "inserted_at": insert_at + 1,
190        "lines_added": insert_count,
191    }))
192}
193
194// ── edit insert-before ────────────────────────────────────────────────────────
195
196/// Insert `code` before a symbol, skipping any leading attributes/decorators/doc comments.
197///
198/// Scans upward from the symbol's start line to find `#[...]`, `@decorator`,
199/// or `///`/`//!` doc comment lines, and inserts before those.
200///
201/// # Errors
202/// Returns an error if the symbol can't be found or the file can't be written.
203pub async fn handle_edit_insert_before(
204    name: &str,
205    code: &str,
206    client: &mut LspClient,
207    file_tracker: &mut FileTracker,
208    project_root: &Path,
209    dirty_files: &DirtyFiles,
210) -> anyhow::Result<Value> {
211    let (abs_path, location) = locate_symbol(name, client, file_tracker, project_root).await?;
212
213    let content = std::fs::read_to_string(&abs_path)
214        .with_context(|| format!("failed to read: {}", abs_path.display()))?;
215
216    let mut lines: Vec<&str> = content.lines().collect();
217
218    // Walk upward from symbol start to skip over attributes/decorators/doc comments
219    let symbol_start = location.start_line as usize;
220    let insert_at = find_insert_before_line(&lines, symbol_start);
221
222    let new_lines: Vec<&str> = code.lines().collect();
223    let insert_count = new_lines.len();
224
225    // Insert code + blank separator before the target line
226    let with_sep: Vec<&str> = new_lines
227        .iter()
228        .copied()
229        .chain(std::iter::once(""))
230        .collect();
231    lines.splice(insert_at..insert_at, with_sep.iter().copied());
232
233    let new_content = ensure_trailing_newline(&lines.join("\n"));
234    atomic_write(&abs_path, &new_content)?;
235    mark_dirty(&abs_path, project_root, dirty_files);
236
237    let rel_path = abs_path
238        .strip_prefix(project_root)
239        .unwrap_or(&abs_path)
240        .to_string_lossy()
241        .to_string();
242
243    Ok(json!({
244        "path": rel_path,
245        "symbol": name,
246        "operation": "before",
247        "inserted_at": insert_at + 1,
248        "lines_added": insert_count,
249    }))
250}
251
/// Find the line index to insert before, walking upward past attributes, decorators,
/// and outer doc comments that belong to the symbol.
///
/// Starting at `symbol_start` (0-based), this counts the unbroken run of directly
/// preceding lines accepted by `is_attr_or_doc_line` and returns the index of the
/// first line in that run. If no such lines precede the symbol, it returns
/// `symbol_start` itself. A `symbol_start` past the end of `lines` is clamped to
/// `lines.len()` instead of panicking.
fn find_insert_before_line(lines: &[&str], symbol_start: usize) -> usize {
    let start = symbol_start.min(lines.len());
    let attached = lines[..start]
        .iter()
        .rev()
        .take_while(|line| is_attr_or_doc_line(line))
        .count();
    start - attached
}

/// Whether `line` looks like an attribute (`#[`), decorator (`@`), outer doc comment
/// (`///`), or part of a block comment (`/*`, ` * …`, ` *`, ` */`).
///
/// Inner docs (`//!`, `/*!`) are excluded. They document the enclosing module, and
/// Rust rejects them after an item, so code must never be inserted above them.
fn is_attr_or_doc_line(line: &str) -> bool {
    let trimmed = line.trim();
    if trimmed.starts_with("//!") || trimmed.starts_with("/*!") {
        return false;
    }
    trimmed.starts_with("#[")
        || trimmed.starts_with('@')
        || trimmed.starts_with("///")
        || trimmed.starts_with("/*")
        || trimmed.starts_with("* ")
        || trimmed == "*"
        || trimmed == "*/"
}
286
287// ── Output formatting ─────────────────────────────────────────────────────────
288
289/// Format an edit replace response for compact output.
290#[must_use]
291pub fn format_replace(data: &Value) -> String {
292    let path = data["path"].as_str().unwrap_or("?");
293    let symbol = data["symbol"].as_str().unwrap_or("?");
294    let from = data["from"].as_u64().unwrap_or(0);
295    let to = data["to"].as_u64().unwrap_or(0);
296    let before = data["lines_before"].as_u64().unwrap_or(0);
297    let after = data["lines_after"].as_u64().unwrap_or(0);
298    format!("replaced {path}:{from}-{to} {symbol} ({before} lines → {after} lines)")
299}
300
301/// Format an insert response for compact output.
302#[must_use]
303pub fn format_insert(data: &Value, kind: &str) -> String {
304    let path = data["path"].as_str().unwrap_or("?");
305    let symbol = data["symbol"].as_str().unwrap_or("?");
306    let at = data["inserted_at"].as_u64().unwrap_or(0);
307    let count = data["lines_added"].as_u64().unwrap_or(0);
308    format!("inserted {kind} {path}:{at} {symbol} ({count} lines added at line {at})")
309}
310
311// ── Tests ─────────────────────────────────────────────────────────────────────
312
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers and output formatters in this module.

    use super::*;

    #[test]
    fn atomic_write_creates_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        atomic_write(&path, "fn hello() {}").unwrap();
        let content = std::fs::read_to_string(&path).unwrap();
        assert_eq!(content, "fn hello() {}");
    }

    #[test]
    fn atomic_write_no_tmp_left_on_success() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        atomic_write(&path, "fn hello() {}").unwrap();
        assert!(!path.with_extension("tmp").exists());
    }

    #[test]
    fn atomic_write_overwrites_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("out.rs");
        std::fs::write(&path, "old contents").unwrap();
        atomic_write(&path, "new contents").unwrap();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "new contents");
    }

    #[test]
    fn find_insert_before_skips_attributes() {
        let lines = vec![
            "fn unrelated() {}",   // 0
            "",                    // 1
            "#[derive(Debug)]",    // 2
            "#[allow(dead_code)]", // 3
            "struct Foo {",        // 4
            "}",                   // 5
        ];
        // Symbol starts at line 4; should insert before line 2
        assert_eq!(find_insert_before_line(&lines, 4), 2);
    }

    #[test]
    fn find_insert_before_skips_doc_comments() {
        let lines = vec![
            "fn other() {}",  // 0
            "",               // 1
            "/// My doc",     // 2
            "fn target() {}", // 3
        ];
        assert_eq!(find_insert_before_line(&lines, 3), 2);
    }

    #[test]
    fn find_insert_before_skips_python_decorators() {
        let lines = vec![
            "x = 1",           // 0
            "@app.route('/')", // 1
            "def index():",    // 2
        ];
        assert_eq!(find_insert_before_line(&lines, 2), 1);
    }

    #[test]
    fn find_insert_before_no_attrs_returns_symbol_start() {
        let lines = vec![
            "fn a() {}", // 0
            "",          // 1
            "fn b() {}", // 2
        ];
        assert_eq!(find_insert_before_line(&lines, 2), 2);
    }

    #[test]
    fn find_insert_before_at_start_of_file() {
        let lines = vec!["fn only() {}"];
        assert_eq!(find_insert_before_line(&lines, 0), 0);
    }

    #[test]
    fn ensure_trailing_newline_adds_newline() {
        assert_eq!(ensure_trailing_newline("hello"), "hello\n");
    }

    #[test]
    fn ensure_trailing_newline_no_double_newline() {
        assert_eq!(ensure_trailing_newline("hello\n"), "hello\n");
    }

    #[test]
    fn ensure_trailing_newline_empty_input() {
        assert_eq!(ensure_trailing_newline(""), "\n");
    }

    #[test]
    fn format_replace_output() {
        let data = json!({
            "path": "src/lib.rs",
            "symbol": "greet",
            "from": 5,
            "to": 15,
            "lines_before": 11,
            "lines_after": 8,
        });
        let out = format_replace(&data);
        assert!(out.contains("replaced"));
        assert!(out.contains("src/lib.rs:5-15"));
        assert!(out.contains("greet"));
        assert!(out.contains("11 lines → 8 lines"));
    }

    #[test]
    fn format_replace_missing_fields_use_placeholders() {
        let out = format_replace(&json!({}));
        assert_eq!(out, "replaced ?:0-0 ? (0 lines → 0 lines)");
    }

    #[test]
    fn format_insert_after_output() {
        let data = json!({
            "path": "src/lib.rs",
            "symbol": "greet",
            "inserted_at": 16,
            "lines_added": 5,
        });
        let out = format_insert(&data, "after");
        assert!(out.contains("inserted after"));
        assert!(out.contains("src/lib.rs:16"));
        assert!(out.contains("5 lines added"));
    }

    #[test]
    fn format_insert_before_output() {
        let data = json!({
            "path": "src/lib.rs",
            "symbol": "greet",
            "inserted_at": 3,
            "lines_added": 2,
        });
        let out = format_insert(&data, "before");
        assert!(out.contains("inserted before"));
        assert!(out.contains("src/lib.rs:3"));
        assert!(out.contains("greet"));
        assert!(out.contains("2 lines added"));
    }
}