1use std::path::Path;
16
17use normalize_languages::parsers::parse_with_grammar;
18use normalize_languages::support_for_path;
19
20use crate::{PlannedEdit, RefactoringPlan};
21
/// A half-open byte span `start..end` into a source file's contents.
///
/// Both offsets are byte offsets (not character or column indices). The
/// planner rejects ranges where `start > end` or `end` exceeds the content
/// length.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ByteRange {
    /// Byte offset of the first byte in the span.
    pub start: usize,
    /// Byte offset one past the last byte in the span.
    pub end: usize,
}
28
29pub struct IntroduceVariableOutcome {
31 pub plan: RefactoringPlan,
32 pub name: String,
34 pub inserted_line: usize,
36 pub replaced_start: usize,
38 pub replaced_end: usize,
39}
40
41pub fn plan_introduce_variable(
48 file: &Path,
49 content: &str,
50 range: ByteRange,
51 name: &str,
52) -> Result<IntroduceVariableOutcome, String> {
53 if range.start > range.end || range.end > content.len() {
55 return Err(format!(
56 "Invalid range {}..{} for file of length {}",
57 range.start,
58 range.end,
59 content.len()
60 ));
61 }
62
63 let support = support_for_path(file)
65 .ok_or_else(|| format!("No language support for {}", file.display()))?;
66 let grammar = support.grammar_name();
67
68 let tree = parse_with_grammar(grammar, content).ok_or_else(|| {
69 format!(
70 "Grammar '{}' not available — install grammars with `normalize grammars install`",
71 grammar
72 )
73 })?;
74
75 let root = tree.root_node();
76
77 let expr_node = root
79 .descendant_for_byte_range(range.start, range.end)
80 .ok_or_else(|| {
81 format!(
82 "No AST node found at byte range {}..{}",
83 range.start, range.end
84 )
85 })?;
86
87 let node_start = expr_node.start_byte();
90 let node_end = expr_node.end_byte();
91
92 let expr_node = find_best_expression_node(expr_node, range);
95
96 let actual_start = expr_node.start_byte();
97 let actual_end = expr_node.end_byte();
98
99 let selected_text = content[actual_start..actual_end].trim();
101 if selected_text.is_empty() {
102 return Err("Selected range is empty or whitespace only".to_string());
103 }
104
105 let kind = expr_node.kind();
107 if is_statement_kind(kind) {
108 return Err(format!(
109 "Selected node '{}' is a statement, not an expression. Select the expression inside it.",
110 kind
111 ));
112 }
113
114 let _ = (node_start, node_end);
116
117 let stmt_node = find_parent_statement(&expr_node)
119 .ok_or_else(|| "Could not find a parent statement for the expression".to_string())?;
120
121 let stmt_start = stmt_node.start_byte();
123 let indent = leading_whitespace(content, stmt_start);
124
125 let expr_text = content[actual_start..actual_end].to_string();
127 let binding = make_binding(grammar, name, &expr_text, &indent);
128
129 let insert_pos = line_start(content, stmt_start);
140
141 let new_expr_start = actual_start + binding.len();
143 let new_expr_end = actual_end + binding.len();
144
145 let mut new_content = content.to_string();
146 new_content.insert_str(insert_pos, &binding);
148 new_content.replace_range(new_expr_start..new_expr_end, name);
150
151 let inserted_line = content[..insert_pos].chars().filter(|&c| c == '\n').count() + 1;
153
154 let plan = RefactoringPlan {
155 operation: "introduce_variable".to_string(),
156 edits: vec![PlannedEdit {
157 file: file.to_path_buf(),
158 original: content.to_string(),
159 new_content,
160 description: format!("introduce variable '{}'", name),
161 }],
162 warnings: vec![],
163 };
164
165 Ok(IntroduceVariableOutcome {
166 plan,
167 name: name.to_string(),
168 inserted_line,
169 replaced_start: actual_start,
170 replaced_end: actual_end,
171 })
172}
173
174fn find_best_expression_node<'a>(
180 mut node: tree_sitter::Node<'a>,
181 range: ByteRange,
182) -> tree_sitter::Node<'a> {
183 if node.start_byte() == range.start && node.end_byte() == range.end {
185 return node;
186 }
187
188 loop {
190 let Some(parent) = node.parent() else { break };
191 if parent.start_byte() == range.start && parent.end_byte() == range.end {
194 node = parent;
195 continue;
196 }
197 if parent.start_byte() <= range.start && parent.end_byte() >= range.end {
200 break;
201 }
202 break;
203 }
204
205 node
206}
207
/// Report whether a tree-sitter node kind names a statement, block, or file
/// root — i.e. something that cannot be extracted into a variable.
fn is_statement_kind(kind: &str) -> bool {
    const STATEMENT_KINDS: &[&str] = &[
        "let_declaration",
        "expression_statement",
        "assignment",
        "augmented_assignment",
        "assert_statement",
        "return_statement",
        "pass_statement",
        "break_statement",
        "continue_statement",
        "delete_statement",
        "import_statement",
        "import_from_statement",
        "raise_statement",
        "global_statement",
        "nonlocal_statement",
        "lexical_declaration",
        "variable_declaration",
        "throw_statement",
        "if_statement",
        "while_statement",
        "for_statement",
        "for_in_statement",
        "switch_statement",
        "try_statement",
        "block",
        "source_file",
        "program",
        "module",
    ];

    STATEMENT_KINDS.contains(&kind)
}
246
247fn find_parent_statement<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
252 let mut current = *node;
253 loop {
254 let Some(parent) = current.parent() else {
255 return Some(current);
258 };
259 let parent_kind = parent.kind();
260 if is_block_kind(parent_kind) {
261 return Some(current);
263 }
264 current = parent;
265 }
266}
267
/// Report whether a tree-sitter node kind names a container whose direct
/// children are treated as statements.
fn is_block_kind(kind: &str) -> bool {
    const BLOCK_KINDS: &[&str] = &[
        "block",
        "module",
        "body",
        "program",
        "statement_block",
        "source_file",
        "class_body",
        "enum_body",
    ];

    BLOCK_KINDS.contains(&kind)
}
286
/// Byte offset of the first character of the line containing `pos`.
///
/// This is the offset just after the last `'\n'` before `pos`, or `0` when
/// `pos` is on the first line. `pos` must be a valid char-boundary offset
/// into `content` (slicing panics otherwise).
fn line_start(content: &str, pos: usize) -> usize {
    content[..pos].rfind('\n').map_or(0, |newline| newline + 1)
}

/// The leading whitespace (indentation) of the line containing `pos`.
///
/// Only the current line is inspected: the scan stops at the line terminator,
/// so a blank or whitespace-only line yields just its own whitespace rather
/// than running on into the following lines.
fn leading_whitespace(content: &str, pos: usize) -> String {
    let rest = &content[line_start(content, pos)..];
    // `lines()` strips the trailing "\n" / "\r\n", confining the scan to this line.
    let line = rest.lines().next().unwrap_or("");
    let ws_end = line
        .find(|c: char| !c.is_whitespace())
        .unwrap_or(line.len());
    line[..ws_end].to_string()
}
301
/// Render the declaration line that binds `name` to `expr` for `grammar`.
///
/// The line is prefixed with `indent` and terminated with a newline:
/// - `python`: `name = expr`
/// - `javascript` / `typescript` / `tsx`: `const name = expr;`
/// - any other grammar: `let name = expr;`
fn make_binding(grammar: &str, name: &str, expr: &str, indent: &str) -> String {
    let (keyword, terminator) = match grammar {
        "python" => ("", ""),
        "javascript" | "typescript" | "tsx" => ("const ", ";"),
        _ => ("let ", ";"),
    };
    format!("{}{}{} = {}{}\n", indent, keyword, name, expr, terminator)
}
319
320pub fn parse_line_col_range(content: &str, range_str: &str) -> Result<ByteRange, String> {
327 let (start_part, end_part) = range_str.split_once('-').ok_or_else(|| {
329 format!(
330 "Invalid range '{}': expected format start_line:start_col-end_line:end_col",
331 range_str
332 )
333 })?;
334
335 let (sl, sc) = parse_line_col(start_part, range_str)?;
336 let (el, ec) = parse_line_col(end_part, range_str)?;
337
338 let start_byte = line_col_to_byte(content, sl, sc).ok_or_else(|| {
339 format!(
340 "Start {}:{} is out of bounds for file of {} chars",
341 sl,
342 sc,
343 content.len()
344 )
345 })?;
346 let end_byte = line_col_to_byte(content, el, ec).ok_or_else(|| {
347 format!(
348 "End {}:{} is out of bounds for file of {} chars",
349 el,
350 ec,
351 content.len()
352 )
353 })?;
354
355 if start_byte > end_byte {
356 return Err(format!(
357 "Start byte {} > end byte {} — range is backwards",
358 start_byte, end_byte
359 ));
360 }
361
362 Ok(ByteRange {
363 start: start_byte,
364 end: end_byte,
365 })
366}
367
/// Parse a single `line:col` position; both components are 1-based.
///
/// `full` is the complete range string the position came from and is used
/// only to give context in error messages.
fn parse_line_col(s: &str, full: &str) -> Result<(usize, usize), String> {
    let Some((line_s, col_s)) = s.split_once(':') else {
        return Err(format!(
            "Invalid position '{}' in range '{}': expected line:col",
            s, full
        ));
    };

    let line = match line_s.parse::<usize>() {
        Ok(value) => value,
        Err(_) => {
            return Err(format!(
                "Invalid line number '{}' in range '{}'",
                line_s, full
            ))
        }
    };
    let col = match col_s.parse::<usize>() {
        Ok(value) => value,
        Err(_) => {
            return Err(format!(
                "Invalid column number '{}' in range '{}'",
                col_s, full
            ))
        }
    };

    match (line, col) {
        (0, _) | (_, 0) => Err(format!(
            "Line and column numbers are 1-based; got {}:{} in range '{}'",
            line, col, full
        )),
        position => Ok(position),
    }
}
389
/// Convert a 1-based `(line, col)` position into a byte offset in `content`.
///
/// Columns count characters. The column just past the last character of a
/// line maps to that line's `'\n'`, and the position just past the final
/// character of the file maps to `content.len()`. Returns `None` when the
/// position does not exist.
fn line_col_to_byte(content: &str, line: usize, col: usize) -> Option<usize> {
    let target = (line, col);
    let mut cursor = (1usize, 1usize);

    for (offset, ch) in content.char_indices() {
        if cursor == target {
            return Some(offset);
        }
        cursor = if ch == '\n' {
            (cursor.0 + 1, 1)
        } else {
            (cursor.0, cursor.1 + 1)
        };
    }

    // The position immediately after the last character is still addressable.
    (cursor == target).then_some(content.len())
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // The file name only matters for its extension, which selects the grammar.
    fn rust_file() -> PathBuf {
        PathBuf::from("test.rs")
    }

    fn py_file() -> PathBuf {
        PathBuf::from("test.py")
    }

    fn ts_file() -> PathBuf {
        PathBuf::from("test.ts")
    }

    fn js_file() -> PathBuf {
        PathBuf::from("test.js")
    }

    /// Byte range of the first occurrence of `needle` in `content`.
    ///
    /// Panics with a descriptive message if `needle` is absent, so a broken
    /// fixture fails loudly.
    fn byte_range_of(content: &str, needle: &str) -> ByteRange {
        let start = content
            .find(needle)
            .unwrap_or_else(|| panic!("needle {:?} not found in content: {:?}", needle, content));
        ByteRange {
            start,
            end: start + needle.len(),
        }
    }

    #[test]
    fn test_rust_introduce_variable() {
        let content = "fn main() {\n let result = some_function(x + y * 2);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        assert_eq!(outcome.name, "sum");
        // The binding goes on the line of the enclosing `let` statement.
        assert_eq!(outcome.inserted_line, 2);
        assert_eq!(outcome.replaced_start, range.start);
        assert_eq!(outcome.replaced_end, range.end);
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("let sum = x + y * 2;"),
            "expected let binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_python_introduce_variable() {
        let content = "def main():\n result = some_function(x + y * 2)\n print(result)\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&py_file(), content, range, "total").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("total = x + y * 2"),
            "expected python binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("some_function(total)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_typescript_introduce_variable() {
        let content = "function main() {\n const result = someFunction(x + y * 2);\n console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&ts_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
        assert!(
            new_content.contains("someFunction(sum)"),
            "expected expression replaced, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_javascript_introduce_variable() {
        let content = "function main() {\n const result = someFunction(x + y * 2);\n console.log(result);\n}\n";
        let range = byte_range_of(content, "x + y * 2");
        let outcome = plan_introduce_variable(&js_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains("const sum = x + y * 2;"),
            "expected const binding, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_indentation_preserved() {
        let content = "fn main() {\n if true {\n let x = foo(a + b);\n }\n}\n";
        let range = byte_range_of(content, "a + b");
        let outcome = plan_introduce_variable(&rust_file(), content, range, "sum").unwrap();
        let new_content = &outcome.plan.edits[0].new_content;
        assert!(
            new_content.contains(" let sum = a + b;"),
            "expected indented binding, got:\n{}",
            new_content
        );
    }

    #[test]
    fn test_parse_line_col_range() {
        // Four-space indent: columns 5..8 on line 2 cover exactly `let`.
        let content = "fn main() {\n    let x = 1;\n}\n";
        let range = parse_line_col_range(content, "2:5-2:8").unwrap();
        assert_eq!(&content[range.start..range.end], "let");
    }

    #[test]
    fn test_parse_line_col_range_rejects_bad_input() {
        let content = "fn main() {\n    let x = 1;\n}\n";
        assert!(
            parse_line_col_range(content, "2:5").is_err(),
            "missing '-' separator should be rejected"
        );
        assert!(
            parse_line_col_range(content, "2:8-2:5").is_err(),
            "backwards range should be rejected"
        );
        assert!(
            parse_line_col_range(content, "0:1-1:2").is_err(),
            "zero line should be rejected"
        );
        assert!(
            parse_line_col_range(content, "9:1-9:2").is_err(),
            "out-of-bounds position should be rejected"
        );
    }

    #[test]
    fn test_line_col_to_byte_boundaries() {
        let content = "ab\ncd";
        assert_eq!(line_col_to_byte(content, 1, 1), Some(0));
        // One past the end of a line maps to that line's newline.
        assert_eq!(line_col_to_byte(content, 1, 3), Some(2));
        // One past the end of the file is addressable.
        assert_eq!(line_col_to_byte(content, 2, 3), Some(content.len()));
        assert_eq!(line_col_to_byte(content, 3, 1), None);
    }

    #[test]
    fn test_error_on_statement_selection() {
        let content = "fn main() {\n let x = 1 + 2;\n}\n";
        let range = byte_range_of(content, "let x = 1 + 2;");
        let result = plan_introduce_variable(&rust_file(), content, range, "y");
        assert!(result.is_err(), "should error on statement selection");
    }
}