use std::path::Path;
use normalize_languages::parsers::parse_with_grammar;
use normalize_languages::support_for_path;
use crate::{PlannedEdit, RefactoringPlan};
/// A half-open byte span `start..end` into a source string.
///
/// Offsets are byte indices (not char indices): they are compared against
/// tree-sitter node byte offsets and used directly to slice `&str` content.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ByteRange {
    /// Inclusive start offset, in bytes.
    pub start: usize,
    /// Exclusive end offset, in bytes.
    pub end: usize,
}
/// Result of [`plan_introduce_variable`]: the planned edit plus details about
/// where the new binding was placed and which text it replaced.
pub struct IntroduceVariableOutcome {
    /// The refactoring plan; its single edit carries both the original and the
    /// rewritten file content.
    pub plan: RefactoringPlan,
    /// Name of the introduced variable.
    pub name: String,
    /// 1-based line number at which the binding was inserted (the line of the
    /// insertion point in the original content).
    pub inserted_line: usize,
    /// Byte offset in the ORIGINAL content where the replaced expression starts.
    pub replaced_start: usize,
    /// Byte offset in the ORIGINAL content where the replaced expression ends
    /// (exclusive).
    pub replaced_end: usize,
}
/// Plans an "introduce variable" refactoring.
///
/// The expression node covering `range` in `content` is extracted into a new
/// binding called `name`. The binding is inserted just before the statement that
/// contains the expression, and the expression itself is replaced by `name`.
/// Nothing is written to disk; the result is a [`RefactoringPlan`] with a single
/// edit for `file`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - the range is backwards or extends past the end of `content`;
/// - the file's language is unsupported, or its grammar is not installed;
/// - no AST node covers the range, or the covered text is blank;
/// - the selection is a statement rather than an expression;
/// - no enclosing statement can be found.
pub fn plan_introduce_variable(
    file: &Path,
    content: &str,
    range: ByteRange,
    name: &str,
) -> Result<IntroduceVariableOutcome, String> {
    if range.start > range.end || range.end > content.len() {
        return Err(format!(
            "Invalid range {}..{} for file of length {}",
            range.start,
            range.end,
            content.len()
        ));
    }

    let support = support_for_path(file)
        .ok_or_else(|| format!("No language support for {}", file.display()))?;
    let grammar = support.grammar_name();
    let tree = parse_with_grammar(grammar, content).ok_or_else(|| {
        format!(
            "Grammar '{}' not available — install grammars with `normalize grammars install`",
            grammar
        )
    })?;
    let root = tree.root_node();

    let expr_node = root
        .descendant_for_byte_range(range.start, range.end)
        .ok_or_else(|| {
            format!(
                "No AST node found at byte range {}..{}",
                range.start, range.end
            )
        })?;
    let expr_node = find_best_expression_node(expr_node, range);
    // The node may be wider than the selection; the node's own span is what gets
    // replaced and reported.
    let actual_start = expr_node.start_byte();
    let actual_end = expr_node.end_byte();
    let expr_text = &content[actual_start..actual_end];
    if expr_text.trim().is_empty() {
        return Err("Selected range is empty or whitespace only".to_string());
    }

    let kind = expr_node.kind();
    if is_statement_kind(kind) {
        return Err(format!(
            "Selected node '{}' is a statement, not an expression. Select the expression inside it.",
            kind
        ));
    }

    let stmt_node = find_parent_statement(&expr_node)
        .ok_or_else(|| "Could not find a parent statement for the expression".to_string())?;
    let (insert_pos, binding) =
        binding_insertion(content, stmt_node.start_byte(), grammar, name, expr_text);

    // `insert_pos` is never after the expression (it is at or before the start of
    // the enclosing statement), so the pieces below are in document order.
    let mut new_content = String::with_capacity(content.len() + binding.len() + name.len());
    new_content.push_str(&content[..insert_pos]);
    new_content.push_str(&binding);
    new_content.push_str(&content[insert_pos..actual_start]);
    new_content.push_str(name);
    new_content.push_str(&content[actual_end..]);

    let inserted_line = content[..insert_pos].matches('\n').count() + 1;

    let plan = RefactoringPlan {
        operation: "introduce_variable".to_string(),
        edits: vec![PlannedEdit {
            file: file.to_path_buf(),
            original: content.to_string(),
            new_content,
            description: format!("introduce variable '{}'", name),
        }],
        warnings: vec![],
    };

    Ok(IntroduceVariableOutcome {
        plan,
        name: name.to_string(),
        inserted_line,
        replaced_start: actual_start,
        replaced_end: actual_end,
    })
}

/// Decides where the new binding goes and renders it; returns the byte offset
/// in `content` to insert at and the text to insert.
///
/// Two cases:
/// - The statement is the first non-whitespace text on its line: the binding
///   becomes a full line, with the statement's indentation, inserted at the
///   start of that line.
/// - Other code precedes the statement on the same line (e.g.
///   `fn main() { let r = f(a + b); }`): inserting at the line start would put
///   the binding before that code. Instead the binding is placed inline,
///   immediately before the statement, and terminated with `; `.
///
/// NOTE(review): either way the binding is evaluated once, before the whole
/// statement runs. Extracting from e.g. a loop condition therefore changes
/// evaluation order.
fn binding_insertion(
    content: &str,
    stmt_start: usize,
    grammar: &str,
    name: &str,
    expr: &str,
) -> (usize, String) {
    let line_begin = line_start(content, stmt_start);
    if content[line_begin..stmt_start].trim().is_empty() {
        let indent = leading_whitespace(content, stmt_start);
        return (line_begin, make_binding(grammar, name, expr, &indent));
    }

    let rendered = make_binding(grammar, name, expr, "");
    let stmt = rendered.trim_end_matches('\n');
    // Python bindings have no trailing `;`; add one so the binding can share
    // the line with the statement that follows it.
    let separator = if stmt.ends_with(';') { " " } else { "; " };
    (stmt_start, format!("{}{}", stmt, separator))
}
/// Refines `node` toward a node whose byte span equals `range` exactly.
///
/// If `node` already spans `range` exactly, it is returned unchanged.
/// Otherwise, ancestors are climbed for as long as each parent spans `range`
/// exactly; the first parent that does not stops the climb. If no parent
/// matches, `node` itself is returned.
///
/// NOTE(review): when `node` comes from `descendant_for_byte_range` and does
/// not match exactly, its ancestors are at least as wide, so the climb will
/// rarely move. Confirm whether a different refinement was intended.
fn find_best_expression_node<'a>(
    node: tree_sitter::Node<'a>,
    range: ByteRange,
) -> tree_sitter::Node<'a> {
    let spans_range = |start: usize, end: usize| start == range.start && end == range.end;

    if spans_range(node.start_byte(), node.end_byte()) {
        return node;
    }

    let mut best = node;
    while let Some(parent) = best.parent() {
        if !spans_range(parent.start_byte(), parent.end_byte()) {
            break;
        }
        best = parent;
    }
    best
}
/// Returns `true` when `kind` names a tree-sitter node that cannot be extracted
/// into a variable.
///
/// That covers statements, declarations/definitions, and containers (blocks and
/// whole-file roots), as opposed to expressions. Without this check, selecting
/// e.g. a whole `do … while` loop would be turned into `const x = do { … };`.
fn is_statement_kind(kind: &str) -> bool {
    matches!(
        kind,
        // Bindings and assignments.
        "let_declaration"
            | "lexical_declaration"
            | "variable_declaration"
            | "assignment"
            | "augmented_assignment"
            // Simple statements.
            | "expression_statement"
            | "assert_statement"
            | "return_statement"
            | "pass_statement"
            | "break_statement"
            | "continue_statement"
            | "delete_statement"
            | "import_statement"
            | "import_from_statement"
            | "raise_statement"
            | "throw_statement"
            | "global_statement"
            | "nonlocal_statement"
            | "empty_statement"
            | "debugger_statement"
            | "labeled_statement"
            // Compound statements.
            | "if_statement"
            | "while_statement"
            | "do_statement"
            | "for_statement"
            | "for_in_statement"
            | "switch_statement"
            | "try_statement"
            | "with_statement"
            // Function/class declarations and definitions.
            | "function_item"
            | "function_declaration"
            | "class_declaration"
            | "function_definition"
            | "class_definition"
            | "decorated_definition"
            // Containers.
            | "block"
            | "statement_block"
            | "source_file"
            | "program"
            | "module"
    )
}
/// Finds the statement that contains `node`.
///
/// This is the nearest ancestor-or-self whose parent is a block-like container
/// (as classified by `is_block_kind`). If the walk reaches a node with no
/// parent before finding such a container, that topmost node is returned. The
/// result is therefore always `Some`.
fn find_parent_statement<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
    let mut stmt = *node;
    while let Some(parent) = stmt.parent() {
        if is_block_kind(parent.kind()) {
            break;
        }
        stmt = parent;
    }
    Some(stmt)
}
/// Returns `true` when `kind` names a tree-sitter node that holds a sequence of
/// statements (a block, a body, or a whole-file root).
///
/// Its direct children are what `find_parent_statement` treats as statements.
fn is_block_kind(kind: &str) -> bool {
    const BLOCK_KINDS: [&str; 8] = [
        "block",
        "module",
        "body",
        "program",
        "statement_block",
        "source_file",
        "class_body",
        "enum_body",
    ];
    BLOCK_KINDS.contains(&kind)
}
/// Returns the byte offset where the line containing `pos` begins.
///
/// That is one past the last `'\n'` strictly before `pos`, or `0` when `pos`
/// is on the first line.
///
/// Panics if `pos` is out of bounds or not on a char boundary, because it
/// slices `content`.
fn line_start(content: &str, pos: usize) -> usize {
    match content[..pos].rfind('\n') {
        Some(newline) => newline + 1,
        None => 0,
    }
}
/// Returns the indentation of the line containing `pos`.
///
/// This is the run of whitespace at the start of that line. It never extends
/// past `pos` or onto another line.
///
/// Panics if `pos` is out of bounds or not on a char boundary.
fn leading_whitespace(content: &str, pos: usize) -> String {
    // Only look at the text between the line start and `pos`. Scanning
    // everything after the line start with `char::is_whitespace` could continue
    // through `'\n'` into later lines, because `'\n'` is whitespace too.
    let line_prefix = content[..pos].rsplit('\n').next().unwrap_or("");
    let indent_len = line_prefix.len() - line_prefix.trim_start().len();
    line_prefix[..indent_len].to_string()
}
/// Renders a binding of `expr` to `name` for the given grammar, prefixed by
/// `indent` and terminated by a newline.
///
/// - `python`: `name = expr` (no keyword, no semicolon). A multi-line
///   expression is wrapped in parentheses. Its line breaks were only legal
///   because it sat inside enclosing brackets, and a bare `name = a +\n b` is a
///   syntax error. Redundant parentheses are harmless.
/// - `javascript` / `typescript` / `tsx`: `const name = expr;`
/// - anything else: `let name = expr;`
fn make_binding(grammar: &str, name: &str, expr: &str, indent: &str) -> String {
    match grammar {
        "python" if expr.contains('\n') => {
            format!("{}{} = ({})\n", indent, name, expr)
        }
        "python" => {
            format!("{}{} = {}\n", indent, name, expr)
        }
        "javascript" | "typescript" | "tsx" => {
            format!("{}const {} = {};\n", indent, name, expr)
        }
        _ => {
            format!("{}let {} = {};\n", indent, name, expr)
        }
    }
}
/// Converts a `start_line:start_col-end_line:end_col` string into a
/// [`ByteRange`] over `content`.
///
/// Lines and columns are 1-based, and columns are counted in chars. The end
/// position is exclusive: `"2:5-2:8"` selects the characters at columns 5, 6
/// and 7 of line 2. A column one past the last character of a line, or of the
/// file, is accepted.
///
/// # Errors
///
/// Returns a message if:
/// - the string is malformed, or a line/column is zero;
/// - a position lies outside `content`;
/// - the start comes after the end.
pub fn parse_line_col_range(content: &str, range_str: &str) -> Result<ByteRange, String> {
    let (start_part, end_part) = range_str.split_once('-').ok_or_else(|| {
        format!(
            "Invalid range '{}': expected format start_line:start_col-end_line:end_col",
            range_str
        )
    })?;
    let (sl, sc) = parse_line_col(start_part, range_str)?;
    let (el, ec) = parse_line_col(end_part, range_str)?;
    // The file size is reported in chars (not `content.len()`, which is bytes)
    // because the message says "chars" and columns are counted in chars; the
    // two differ for non-ASCII files.
    let start_byte = line_col_to_byte(content, sl, sc).ok_or_else(|| {
        format!(
            "Start {}:{} is out of bounds for file of {} chars",
            sl,
            sc,
            content.chars().count()
        )
    })?;
    let end_byte = line_col_to_byte(content, el, ec).ok_or_else(|| {
        format!(
            "End {}:{} is out of bounds for file of {} chars",
            el,
            ec,
            content.chars().count()
        )
    })?;
    if start_byte > end_byte {
        return Err(format!(
            "Start byte {} > end byte {} — range is backwards",
            start_byte, end_byte
        ));
    }
    Ok(ByteRange {
        start: start_byte,
        end: end_byte,
    })
}
/// Parses one `line:col` position (both 1-based) out of a range string.
///
/// `full` is the whole range string; it is used only to make error messages
/// self-explanatory.
///
/// # Errors
///
/// Returns a message when:
/// - there is no `:`;
/// - either half is not an unsigned integer;
/// - either number is zero.
fn parse_line_col(s: &str, full: &str) -> Result<(usize, usize), String> {
    let Some((line_str, col_str)) = s.split_once(':') else {
        return Err(format!(
            "Invalid position '{}' in range '{}': expected line:col",
            s, full
        ));
    };

    let line = line_str
        .parse::<usize>()
        .map_err(|_| format!("Invalid line number '{}' in range '{}'", line_str, full))?;
    let col = col_str
        .parse::<usize>()
        .map_err(|_| format!("Invalid column number '{}' in range '{}'", col_str, full))?;

    match (line, col) {
        (0, _) | (_, 0) => Err(format!(
            "Line and column numbers are 1-based; got {}:{} in range '{}'",
            line, col, full
        )),
        _ => Ok((line, col)),
    }
}
/// Maps a 1-based `line`/`col` position to a byte offset in `content`.
///
/// Columns count chars, not bytes. The position of a line's `'\n'` (one past
/// its last character) and the position `content.len()` (end of file) are both
/// addressable. Returns `None` for any position the text does not reach.
fn line_col_to_byte(content: &str, line: usize, col: usize) -> Option<usize> {
    let (mut row, mut column) = (1usize, 1usize);
    // A sentinel at `content.len()` lets the end-of-file position be checked by
    // the same comparison as every real character. Its char is never inspected
    // before the loop ends.
    let positions = content
        .char_indices()
        .chain(std::iter::once((content.len(), '\0')));
    for (offset, ch) in positions {
        if (row, column) == (line, col) {
            return Some(offset);
        }
        if ch == '\n' {
            row += 1;
            column = 1;
        } else {
            column += 1;
        }
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    fn rust_file() -> PathBuf {
        PathBuf::from("test.rs")
    }

    fn py_file() -> PathBuf {
        PathBuf::from("test.py")
    }

    fn ts_file() -> PathBuf {
        PathBuf::from("test.ts")
    }

    fn js_file() -> PathBuf {
        PathBuf::from("test.js")
    }

    /// Byte range of the first occurrence of `needle` in `content`.
    fn byte_range_of(content: &str, needle: &str) -> ByteRange {
        let start = content
            .find(needle)
            .unwrap_or_else(|| panic!("needle {:?} not found in content: {:?}", needle, content));
        ByteRange {
            start,
            end: start + needle.len(),
        }
    }

    #[test]
    fn test_rust_introduce_variable() {
        let content = "fn main() {\n    let result = some_function(x + y * 2);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        assert_eq!(outcome.name, "sum");
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("let sum = x + y * 2;"),
            "expected let binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_reports_replaced_offsets_and_inserted_line() {
        let content = "fn main() {\n    let result = some_function(x + y * 2);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        assert_eq!(outcome.replaced_start, range.start);
        assert_eq!(outcome.replaced_end, range.end);
        assert_eq!(outcome.inserted_line, 2);
        assert_eq!(outcome.plan.edits[0].original, content);
    }

    #[test]
    fn test_python_introduce_variable() {
        let content = "def main():\n    result = some_function(x + y * 2)\n    print(result)\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&py_file(), content, range, "total").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("total = x + y * 2"),
            "expected python binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(total)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_typescript_introduce_variable() {
        let content =
            "function main() {\n    const result = someFunction(x + y * 2);\n    console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&ts_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("someFunction(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_javascript_introduce_variable() {
        let content =
            "function main() {\n    const result = someFunction(x + y * 2);\n    console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&js_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_indentation_preserved() {
        let content = "fn main() {\n    if true {\n        let x = foo(a + b);\n    }\n}\n";
        let range = byte_range_of(content, "a + b");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("        let sum = a + b;"),
            "expected indented binding, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_error_on_statement_selection() {
        let content = "fn main() {\n    let x = 1 + 2;\n}\n";
        let range = byte_range_of(content, "let x = 1 + 2;");
        let result = plan_introduce_variable(&rust_file(), content, range, "y");
        assert!(result.is_err(), "should error on statement selection");
    }

    #[test]
    fn test_error_on_out_of_bounds_range() {
        let content = "fn main() {}\n";
        let range = ByteRange {
            start: 0,
            end: content.len() + 1,
        };
        let err = plan_introduce_variable(&rust_file(), content, range, "x")
            .err()
            .expect("out-of-bounds range must be rejected");
        assert!(err.contains("Invalid range"), "unexpected error: {}", err);
    }

    #[test]
    fn test_error_on_backwards_byte_range() {
        let content = "fn main() {}\n";
        let range = ByteRange { start: 5, end: 2 };
        let result = plan_introduce_variable(&rust_file(), content, range, "x");
        assert!(result.is_err(), "backwards range must be rejected");
    }

    #[test]
    fn test_parse_line_col_range() {
        let content = "fn main() {\n    let x = 1;\n}\n";
        let range = parse_line_col_range(content, "2:5-2:8").unwrap();
        assert_eq!(&content[range.start..range.end], "let");
    }

    #[test]
    fn test_parse_line_col_range_end_of_file() {
        let range = parse_line_col_range("ab", "1:1-1:3").unwrap();
        assert_eq!((range.start, range.end), (0, 2));
    }

    #[test]
    fn test_parse_line_col_range_counts_columns_in_chars() {
        let content = "é = 1";
        let range = parse_line_col_range(content, "1:1-1:2").unwrap();
        assert_eq!(&content[range.start..range.end], "é");
    }

    #[test]
    fn test_parse_line_col_range_rejects_invalid_input() {
        let content = "fn main() {\n    let x = 1;\n}\n";
        assert!(parse_line_col_range(content, "2:8-2:5").is_err(), "backwards");
        assert!(parse_line_col_range(content, "0:1-1:1").is_err(), "zero line");
        assert!(parse_line_col_range(content, "2:5").is_err(), "missing '-'");
        assert!(parse_line_col_range(content, "9:1-9:2").is_err(), "out of bounds");
        assert!(parse_line_col_range(content, "a:1-2:2").is_err(), "non-numeric");
    }
}