Skip to main content

normalize_refactor/
introduce_variable.rs

1//! Introduce-variable recipe: extract an expression at a given range into a named variable binding.
2//!
3//! Algorithm:
4//! 1. Parse the file with tree-sitter.
5//! 2. Find the innermost complete expression node covering the given byte range.
6//! 3. Walk up to the parent statement of that expression.
7//! 4. Insert `let <name> = <expr>;` (or the language-appropriate binding) before the statement.
8//! 5. Replace the original expression text with `<name>`.
9//!
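//! Language-specific keyword mapping (every other grammar falls back to `let`):
//! - Python: `<name> = <expr>` (no keyword)
//! - JavaScript / TypeScript: `const <name> = <expr>;`
//! - Everything else (Rust, Go, Swift, Kotlin, Scala, Dart, etc.): `let <name> = <expr>;`.
//!   This is valid for Rust and Swift; languages without a `let` binding of this form
//!   (e.g. Go, Kotlin, Scala, Dart) still receive it and need a manual touch-up.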
14
15use std::path::Path;
16
17use normalize_languages::parsers::parse_with_grammar;
18use normalize_languages::support_for_path;
19
20use crate::{PlannedEdit, RefactoringPlan};
21
/// A byte range selected by the user.
///
/// Offsets are byte offsets into the file content, not character or column indices.
/// `plan_introduce_variable` rejects ranges where `start > end` or where `end` lies
/// past the end of the content.
#[derive(Debug, Clone, Copy)]
pub struct ByteRange {
    /// Byte offset of the first selected byte.
    pub start: usize,
    /// Byte offset one past the last selected byte (the range is end-exclusive).
    pub end: usize,
}
28
/// Outcome of a successful introduce-variable plan.
pub struct IntroduceVariableOutcome {
    /// The refactoring plan; `plan_introduce_variable` fills it with a single edit that
    /// carries both the original and the rewritten file content.
    pub plan: RefactoringPlan,
    /// The variable name that was introduced.
    pub name: String,
    /// 1-based line number where the `let` binding was inserted.
    pub inserted_line: usize,
    /// Byte range that was replaced with the variable name.
    ///
    /// Both offsets refer to the ORIGINAL content (before the binding was inserted).
    /// `replaced_start` is the first replaced byte.
    pub replaced_start: usize,
    /// One past the last replaced byte, again as an offset into the original content.
    pub replaced_end: usize,
}
40
41/// Build an introduce-variable plan without touching the filesystem.
42///
43/// `file` is the absolute path to the file.
44/// `content` is the current file content.
45/// `range` is the byte range of the expression to extract.
46/// `name` is the variable name to introduce.
47pub fn plan_introduce_variable(
48    file: &Path,
49    content: &str,
50    range: ByteRange,
51    name: &str,
52) -> Result<IntroduceVariableOutcome, String> {
53    // Validate range bounds.
54    if range.start > range.end || range.end > content.len() {
55        return Err(format!(
56            "Invalid range {}..{} for file of length {}",
57            range.start,
58            range.end,
59            content.len()
60        ));
61    }
62
63    // Determine grammar from path.
64    let support = support_for_path(file)
65        .ok_or_else(|| format!("No language support for {}", file.display()))?;
66    let grammar = support.grammar_name();
67
68    let tree = parse_with_grammar(grammar, content).ok_or_else(|| {
69        format!(
70            "Grammar '{}' not available — install grammars with `normalize grammars install`",
71            grammar
72        )
73    })?;
74
75    let root = tree.root_node();
76
77    // Find the innermost node that covers the selection range.
78    let expr_node = root
79        .descendant_for_byte_range(range.start, range.end)
80        .ok_or_else(|| {
81            format!(
82                "No AST node found at byte range {}..{}",
83                range.start, range.end
84            )
85        })?;
86
87    // Validate that the selected range corresponds to a reasonably complete expression.
88    // The node should start at or before the selection start and end at or after end.
89    let node_start = expr_node.start_byte();
90    let node_end = expr_node.end_byte();
91
92    // If the node's range doesn't match the selection closely (allowing for whitespace),
93    // search upward for a better match (a node whose range closely contains the selection).
94    let expr_node = find_best_expression_node(expr_node, range);
95
96    let actual_start = expr_node.start_byte();
97    let actual_end = expr_node.end_byte();
98
99    // The selected text must be a meaningful expression (not purely structural tokens).
100    let selected_text = content[actual_start..actual_end].trim();
101    if selected_text.is_empty() {
102        return Err("Selected range is empty or whitespace only".to_string());
103    }
104
105    // Verify the node kind looks like an expression (not a statement wrapper, keyword, etc.).
106    let kind = expr_node.kind();
107    if is_statement_kind(kind) {
108        return Err(format!(
109            "Selected node '{}' is a statement, not an expression. Select the expression inside it.",
110            kind
111        ));
112    }
113
114    // Unused — suppressed by the assignment above which overwrites it.
115    let _ = (node_start, node_end);
116
117    // Walk up to find the parent statement that contains this expression.
118    let stmt_node = find_parent_statement(&expr_node)
119        .ok_or_else(|| "Could not find a parent statement for the expression".to_string())?;
120
121    // Determine indentation of the statement.
122    let stmt_start = stmt_node.start_byte();
123    let indent = leading_whitespace(content, stmt_start);
124
125    // Generate the binding declaration.
126    let expr_text = content[actual_start..actual_end].to_string();
127    let binding = make_binding(grammar, name, &expr_text, &indent);
128
129    // Build new file content:
130    // 1. Insert the binding before the statement line.
131    // 2. Replace the expression with the variable name.
132    //
133    // We must do these two edits carefully because inserting text shifts byte offsets.
134    // Strategy: apply the replacement first (it's inside the statement), then insert
135    // before the statement. Because the insertion is before the replacement site, we
136    // need to do insertion first and adjust the replacement offset.
137
138    // Compute the byte position of the start of the statement's line.
139    let insert_pos = line_start(content, stmt_start);
140
141    // After inserting `binding` at `insert_pos`, the expression bytes shift by `binding.len()`.
142    let new_expr_start = actual_start + binding.len();
143    let new_expr_end = actual_end + binding.len();
144
145    let mut new_content = content.to_string();
146    // Insert binding before the statement line.
147    new_content.insert_str(insert_pos, &binding);
148    // Replace the expression with the variable name.
149    new_content.replace_range(new_expr_start..new_expr_end, name);
150
151    // Compute the 1-based line number of the inserted binding.
152    let inserted_line = content[..insert_pos].chars().filter(|&c| c == '\n').count() + 1;
153
154    let plan = RefactoringPlan {
155        operation: "introduce_variable".to_string(),
156        edits: vec![PlannedEdit {
157            file: file.to_path_buf(),
158            original: content.to_string(),
159            new_content,
160            description: format!("introduce variable '{}'", name),
161        }],
162        warnings: vec![],
163    };
164
165    Ok(IntroduceVariableOutcome {
166        plan,
167        name: name.to_string(),
168        inserted_line,
169        replaced_start: actual_start,
170        replaced_end: actual_end,
171    })
172}
173
174/// Walk up the tree to find the node whose range best matches the selection.
175///
176/// We prefer the most specific (innermost) node whose byte range exactly covers
177/// the trimmed selection. If the direct match is inside a larger expression that
178/// exactly matches, prefer the exact match.
179fn find_best_expression_node<'a>(
180    mut node: tree_sitter::Node<'a>,
181    range: ByteRange,
182) -> tree_sitter::Node<'a> {
183    // If the node already exactly covers the range, keep it.
184    if node.start_byte() == range.start && node.end_byte() == range.end {
185        return node;
186    }
187
188    // Walk up while the parent is a better (closer) match for the range.
189    loop {
190        let Some(parent) = node.parent() else { break };
191        // If the parent exactly covers the range, prefer it (it's the "expression" the
192        // user intends rather than an inner token).
193        if parent.start_byte() == range.start && parent.end_byte() == range.end {
194            node = parent;
195            continue;
196        }
197        // If the parent covers more than the range, stop — the current node is the
198        // innermost covering node.
199        if parent.start_byte() <= range.start && parent.end_byte() >= range.end {
200            break;
201        }
202        break;
203    }
204
205    node
206}
207
/// Reports whether `kind` names a statement-level (or container) node rather than an
/// expression.
///
/// The comparison is an exact, case-sensitive match against a fixed list of node kind
/// names; anything not listed is treated as a potential expression.
fn is_statement_kind(kind: &str) -> bool {
    const STATEMENT_KINDS: &[&str] = &[
        // Rust
        "let_declaration",
        "expression_statement",
        // Python
        "assignment",
        "augmented_assignment",
        "assert_statement",
        "return_statement",
        "pass_statement",
        "break_statement",
        "continue_statement",
        "delete_statement",
        "import_statement",
        "import_from_statement",
        "raise_statement",
        "global_statement",
        "nonlocal_statement",
        // JS/TS (not already covered above)
        "lexical_declaration",
        "variable_declaration",
        "throw_statement",
        "if_statement",
        "while_statement",
        "for_statement",
        "for_in_statement",
        "switch_statement",
        "try_statement",
        // General containers
        "block",
        "source_file",
        "program",
        "module",
    ];

    STATEMENT_KINDS.contains(&kind)
}
246
247/// Return the statement-level parent of an expression node.
248///
249/// Walks up the tree until we find a node that is at the statement level
250/// (i.e., its parent is a block / function body / module).
251fn find_parent_statement<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
252    let mut current = *node;
253    loop {
254        let Some(parent) = current.parent() else {
255            // Reached root without finding a statement — return current (the
256            // expression itself, which lives at the top level).
257            return Some(current);
258        };
259        let parent_kind = parent.kind();
260        if is_block_kind(parent_kind) {
261            // current is a direct child of a block → it IS the statement.
262            return Some(current);
263        }
264        current = parent;
265    }
266}
267
/// Reports whether `kind` names a block / body container, i.e. a node whose direct
/// children are treated as statements.
///
/// The comparison is an exact, case-sensitive match against a fixed list.
fn is_block_kind(kind: &str) -> bool {
    const BLOCK_KINDS: &[&str] = &[
        // Rust
        "block",
        // Python
        "module",
        "body",
        // JS / TS
        "program",
        "statement_block",
        // Generic
        "source_file",
        "class_body",
        "enum_body",
    ];

    BLOCK_KINDS.contains(&kind)
}
286
/// Byte offset at which the line containing `pos` begins.
///
/// This is the offset just past the last `'\n'` found before `pos`, or `0` when no
/// newline precedes `pos`. `pos` must be a valid slice boundary of `content`.
fn line_start(content: &str, pos: usize) -> usize {
    match content[..pos].rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    }
}
291
292/// Extract the leading whitespace from the line containing `pos`.
293fn leading_whitespace(content: &str, pos: usize) -> String {
294    let ls = line_start(content, pos);
295    let line = &content[ls..];
296    let ws_end = line
297        .find(|c: char| !c.is_whitespace())
298        .unwrap_or(line.len());
299    line[..ws_end].to_string()
300}
301
/// Render the binding line that declares `name` with the value `expr`.
///
/// The result is `indent`, an optional keyword, `name = expr`, an optional `;` and a
/// trailing newline. The keyword and terminator depend on `grammar`:
/// - `"python"`: no keyword, no semicolon;
/// - `"javascript"`, `"typescript"`, `"tsx"`: `const` with a semicolon;
/// - any other grammar: `let` with a semicolon.
///
/// NOTE(review): the `let` fallback is not valid syntax in every language (Go, Kotlin,
/// Scala and Dart have no `let` binding of this form) — consider dedicated arms.
fn make_binding(grammar: &str, name: &str, expr: &str, indent: &str) -> String {
    let (keyword, terminator) = match grammar {
        "python" => ("", ""),
        "javascript" | "typescript" | "tsx" => ("const ", ";"),
        _ => ("let ", ";"),
    };
    format!("{indent}{keyword}{name} = {expr}{terminator}\n")
}
319
320// ── Range parsing helpers ─────────────────────────────────────────────────────
321
322/// Parse a `line:col-line:col` range string into a byte range.
323///
324/// Lines and columns are **1-based** (matching editor conventions).
325/// Returns `Err` with a descriptive message on any parse failure.
326pub fn parse_line_col_range(content: &str, range_str: &str) -> Result<ByteRange, String> {
327    // Expected format: `<start_line>:<start_col>-<end_line>:<end_col>`
328    let (start_part, end_part) = range_str.split_once('-').ok_or_else(|| {
329        format!(
330            "Invalid range '{}': expected format start_line:start_col-end_line:end_col",
331            range_str
332        )
333    })?;
334
335    let (sl, sc) = parse_line_col(start_part, range_str)?;
336    let (el, ec) = parse_line_col(end_part, range_str)?;
337
338    let start_byte = line_col_to_byte(content, sl, sc).ok_or_else(|| {
339        format!(
340            "Start {}:{} is out of bounds for file of {} chars",
341            sl,
342            sc,
343            content.len()
344        )
345    })?;
346    let end_byte = line_col_to_byte(content, el, ec).ok_or_else(|| {
347        format!(
348            "End {}:{} is out of bounds for file of {} chars",
349            el,
350            ec,
351            content.len()
352        )
353    })?;
354
355    if start_byte > end_byte {
356        return Err(format!(
357            "Start byte {} > end byte {} — range is backwards",
358            start_byte, end_byte
359        ));
360    }
361
362    Ok(ByteRange {
363        start: start_byte,
364        end: end_byte,
365    })
366}
367
/// Parse a single `line:col` position.
///
/// `s` is the position text and `full` is the complete range string it came from; `full`
/// is only used to give error messages context. Both numbers must be positive integers
/// (positions are 1-based). The line is validated before the column.
fn parse_line_col(s: &str, full: &str) -> Result<(usize, usize), String> {
    let Some((line_text, col_text)) = s.split_once(':') else {
        return Err(format!(
            "Invalid position '{}' in range '{}': expected line:col",
            s, full
        ));
    };

    let line = match line_text.parse::<usize>() {
        Ok(value) => value,
        Err(_) => {
            return Err(format!(
                "Invalid line number '{}' in range '{}'",
                line_text, full
            ));
        }
    };
    let col = match col_text.parse::<usize>() {
        Ok(value) => value,
        Err(_) => {
            return Err(format!(
                "Invalid column number '{}' in range '{}'",
                col_text, full
            ));
        }
    };

    match (line, col) {
        (0, _) | (_, 0) => Err(format!(
            "Line and column numbers are 1-based; got {}:{} in range '{}'",
            line, col, full
        )),
        position => Ok(position),
    }
}
389
/// Translate a 1-based `line:col` position into a byte offset within `content`.
///
/// Columns count characters, so multi-byte characters advance the column by one. The
/// newline that ends a line occupies the column just past the line's last character,
/// which lets an exclusive end position point at the end of a line. The position
/// immediately after the final character maps to `content.len()`. Returns `None` when
/// the position does not exist in `content`.
fn line_col_to_byte(content: &str, line: usize, col: usize) -> Option<usize> {
    let target = (line, col);
    // (line, column) of the character about to be visited.
    let mut cursor = (1usize, 1usize);

    for (offset, ch) in content.char_indices() {
        if cursor == target {
            return Some(offset);
        }
        cursor = if ch == '\n' {
            (cursor.0 + 1, 1)
        } else {
            (cursor.0, cursor.1 + 1)
        };
    }

    // Allow pointing at end of content (e.g. end of last line without newline).
    (cursor == target).then_some(content.len())
}
411
// Unit tests for the introduce-variable recipe and the range-parsing helpers.
//
// The `plan_introduce_variable` tests `unwrap()` the plan, so they fail when language
// support or the grammar for the file's extension is unavailable.
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // Path fixtures: only the extension matters, it selects the language under test.
    fn rust_file() -> PathBuf {
        PathBuf::from("test.rs")
    }

    fn py_file() -> PathBuf {
        PathBuf::from("test.py")
    }

    fn ts_file() -> PathBuf {
        PathBuf::from("test.ts")
    }

    fn js_file() -> PathBuf {
        PathBuf::from("test.js")
    }

    // Helper: get byte range for a substring (first occurrence).
    // Panics with the needle and content when the substring is absent.
    fn byte_range_of(content: &str, needle: &str) -> ByteRange {
        let start = content
            .find(needle)
            .unwrap_or_else(|| panic!("needle {:?} not found in content: {:?}", needle, content));
        ByteRange {
            start,
            end: start + needle.len(),
        }
    }

    // Rust: expects a `let` binding and the call argument replaced by the new name.
    #[test]
    fn test_rust_introduce_variable() {
        let content = "fn main() {\n    let result = some_function(x + y * 2);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        assert_eq!(outcome.name, "sum");
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("let sum = x + y * 2;"),
            "expected let binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    // Python: expects a bare `name = expr` assignment and the argument replaced.
    #[test]
    fn test_python_introduce_variable() {
        let content = "def main():\n    result = some_function(x + y * 2)\n    print(result)\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&py_file(), content, range, "total").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        // Python uses `name = expr` (no let)
        assert!(
            new_content.contains("total = x + y * 2"),
            "expected python binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(total)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    // TypeScript: expects a `const` binding and the argument replaced.
    #[test]
    fn test_typescript_introduce_variable() {
        let content = "function main() {\n    const result = someFunction(x + y * 2);\n    console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&ts_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("someFunction(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    // JavaScript: only the `const` binding is checked here, not the replacement.
    #[test]
    fn test_javascript_introduce_variable() {
        let content = "function main() {\n    const result = someFunction(x + y * 2);\n    console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&js_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
    }

    // The binding must reuse the indentation of the statement it is inserted above.
    #[test]
    fn test_indentation_preserved() {
        let content = "fn main() {\n    if true {\n        let x = foo(a + b);\n    }\n}\n";
        let range = byte_range_of(content, "a + b");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        // Should preserve the 8-space indent of the statement.
        assert!(
            new_content.contains("        let sum = a + b;"),
            "expected indented binding, got:\n{}",
            new_content
        );
    }

    // `line:col` positions are 1-based and the end position is exclusive.
    #[test]
    fn test_parse_line_col_range() {
        let content = "fn main() {\n    let x = 1;\n}\n";
        // "let" starts at line 2, col 5
        let range = parse_line_col_range(content, "2:5-2:8").unwrap();
        assert_eq!(&content[range.start..range.end], "let");
    }

    // Selecting a whole statement instead of an expression must be rejected.
    #[test]
    fn test_error_on_statement_selection() {
        let content = "fn main() {\n    let x = 1 + 2;\n}\n";
        // Select the entire let_declaration
        let range = byte_range_of(content, "let x = 1 + 2;");
        let result = plan_introduce_variable(&rust_file(), content, range, "y");
        assert!(result.is_err(), "should error on statement selection");
    }
}
543}