Skip to main content

aft/
extract.rs

1//! Shared extraction utilities for `extract_function` (and future `inline_symbol`).
2//!
3//! Provides:
4//! - `detect_free_variables` — classify identifier references in a byte range
5//! - `detect_return_value` — infer what the extracted function should return
//! - `generate_extracted_function` — produce function text for TS/JS or Python
//! - `generate_call_site` — produce the call that replaces the extracted range
7
8use std::collections::HashSet;
9
10use tree_sitter::{Node, Tree};
11
12use crate::indent::IndentStyle;
13use crate::parser::{grammar_for, node_text, LangId};
14
15// ---------------------------------------------------------------------------
16// Free variable detection
17// ---------------------------------------------------------------------------
18
/// Classification result for free variables in a selected byte range.
///
/// Produced by `detect_free_variables`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct FreeVariableResult {
    /// Identifiers declared in an enclosing function scope that the range
    /// references — these become parameters of the extracted function.
    /// Listed once each, in order of first reference.
    pub parameters: Vec<String>,
    /// Whether `this` (JS/TS) or `self` (Python) appears in the range.
    pub has_this_or_self: bool,
}
28
29/// Walk the AST for a byte range and classify every identifier reference.
30///
31/// Classification rules:
32/// 1. Declared-in-range (local variable) → skip
33/// 2. Declared in enclosing function scope → parameter
34/// 3. Module-level or import → skip
35/// 4. `this` / `self` keyword → flag (error for extract_function)
36/// 5. `property_identifier` / `field_identifier` on the right side of `.` → skip
37///    (these are member accesses, not free variables)
38pub fn detect_free_variables(
39    source: &str,
40    tree: &Tree,
41    start_byte: usize,
42    end_byte: usize,
43    lang: LangId,
44) -> FreeVariableResult {
45    let root = tree.root_node();
46
47    // 1. Collect all identifiers referenced in the range (excluding property access)
48    let mut references: Vec<String> = Vec::new();
49    collect_identifier_refs(&root, source, start_byte, end_byte, lang, &mut references);
50
51    // 2. Collect declarations within the range (these are locals, not free)
52    let mut local_decls: HashSet<String> = HashSet::new();
53    collect_declarations_in_range(&root, source, start_byte, end_byte, lang, &mut local_decls);
54
55    // 3. Find the enclosing function scope boundary
56    let enclosing_fn = find_enclosing_function(&root, start_byte, lang);
57
58    // 4. Collect declarations in the enclosing function but outside the range
59    let mut enclosing_decls: HashSet<String> = HashSet::new();
60    if let Some(fn_node) = enclosing_fn {
61        collect_declarations_in_range(
62            &fn_node,
63            source,
64            fn_node.start_byte(),
65            start_byte, // only before the range
66            lang,
67            &mut enclosing_decls,
68        );
69        // Also collect function parameters
70        collect_function_params(&fn_node, source, lang, &mut enclosing_decls);
71    }
72
73    // 5. Check for this/self
74    let has_this_or_self = check_this_or_self(&root, source, start_byte, end_byte, lang);
75
76    // 6. Classify: a reference is a parameter if it's not a local decl,
77    //    IS declared in the enclosing function scope, and is not module-level.
78    let mut seen = HashSet::new();
79    let mut parameters = Vec::new();
80    for name in &references {
81        if local_decls.contains(name) {
82            continue;
83        }
84        if !seen.insert(name.clone()) {
85            continue; // dedup
86        }
87        if enclosing_decls.contains(name) {
88            parameters.push(name.clone());
89        }
90        // If not in enclosing_decls, it's module-level or global — skip
91    }
92
93    FreeVariableResult {
94        parameters,
95        has_this_or_self,
96    }
97}
98
99/// Collect all `identifier` nodes in [start_byte, end_byte) that are genuine
100/// references (not property accesses on the right side of `.`).
101fn collect_identifier_refs(
102    node: &Node,
103    source: &str,
104    start_byte: usize,
105    end_byte: usize,
106    lang: LangId,
107    out: &mut Vec<String>,
108) {
109    // Skip nodes entirely outside the range
110    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
111        return;
112    }
113
114    let kind = node.kind();
115
116    // An `identifier` node in the range that is NOT a property/field access
117    if kind == "identifier" && node.start_byte() >= start_byte && node.end_byte() <= end_byte {
118        // Check parent: if parent is member_expression and this is the "property" field,
119        // it's a property access, not a free variable.
120        if !is_property_access(node, lang) {
121            let name = node_text(source, node).to_string();
122            // Skip language keywords that parse as identifiers
123            if !is_keyword(&name, lang) {
124                out.push(name);
125            }
126        }
127    }
128
129    // Recurse into children
130    let mut cursor = node.walk();
131    if cursor.goto_first_child() {
132        loop {
133            collect_identifier_refs(&cursor.node(), source, start_byte, end_byte, lang, out);
134            if !cursor.goto_next_sibling() {
135                break;
136            }
137        }
138    }
139}
140
141/// Check if an identifier node is a property access (right side of `.`).
142fn is_property_access(node: &Node, lang: LangId) -> bool {
143    // property_identifier and field_identifier are separate node kinds in TS/JS,
144    // so they won't even reach here. But for Python `attribute` access the child
145    // is still `identifier`.
146    if let Some(parent) = node.parent() {
147        let pk = parent.kind();
148        match lang {
149            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
150                // member_expression: object.property — the "property" child
151                if pk == "member_expression" {
152                    if let Some(prop) = parent.child_by_field_name("property") {
153                        return prop.id() == node.id();
154                    }
155                }
156            }
157            LangId::Python => {
158                // attribute: object.attr — the "attribute" child
159                if pk == "attribute" {
160                    if let Some(attr) = parent.child_by_field_name("attribute") {
161                        return attr.id() == node.id();
162                    }
163                }
164            }
165            _ => {}
166        }
167    }
168    false
169}
170
171/// Identifiers that are language keywords and should not be treated as free variables.
172fn is_keyword(name: &str, lang: LangId) -> bool {
173    match lang {
174        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => matches!(
175            name,
176            "undefined" | "null" | "true" | "false" | "NaN" | "Infinity" | "console" | "require"
177        ),
178        LangId::Python => matches!(
179            name,
180            "None"
181                | "True"
182                | "False"
183                | "print"
184                | "len"
185                | "range"
186                | "str"
187                | "int"
188                | "float"
189                | "list"
190                | "dict"
191                | "set"
192                | "tuple"
193                | "type"
194                | "super"
195                | "isinstance"
196                | "enumerate"
197                | "zip"
198                | "map"
199                | "filter"
200                | "sorted"
201                | "reversed"
202                | "any"
203                | "all"
204                | "min"
205                | "max"
206                | "sum"
207                | "abs"
208                | "open"
209                | "input"
210                | "format"
211                | "hasattr"
212                | "getattr"
213                | "setattr"
214                | "delattr"
215                | "repr"
216                | "iter"
217                | "next"
218                | "ValueError"
219                | "TypeError"
220                | "KeyError"
221                | "IndexError"
222                | "Exception"
223                | "RuntimeError"
224                | "StopIteration"
225                | "NotImplementedError"
226                | "AttributeError"
227                | "ImportError"
228                | "OSError"
229                | "IOError"
230                | "FileNotFoundError"
231        ),
232        _ => false,
233    }
234}
235
236/// Collect names declared (via variable declarations) within a byte range.
237fn collect_declarations_in_range(
238    node: &Node,
239    source: &str,
240    start_byte: usize,
241    end_byte: usize,
242    lang: LangId,
243    out: &mut HashSet<String>,
244) {
245    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
246        return;
247    }
248
249    let kind = node.kind();
250
251    match lang {
252        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
253            // variable_declarator has a "name" child that is the declared identifier
254            if kind == "variable_declarator" {
255                if let Some(name_node) = node.child_by_field_name("name") {
256                    if name_node.start_byte() >= start_byte && name_node.end_byte() <= end_byte {
257                        out.insert(node_text(source, &name_node).to_string());
258                    }
259                }
260            }
261        }
262        LangId::Python => {
263            // assignment: left side
264            if kind == "assignment" {
265                if let Some(left) = node.child_by_field_name("left") {
266                    if left.kind() == "identifier"
267                        && left.start_byte() >= start_byte
268                        && left.end_byte() <= end_byte
269                    {
270                        out.insert(node_text(source, &left).to_string());
271                    }
272                }
273            }
274        }
275        _ => {}
276    }
277
278    // Recurse
279    let mut cursor = node.walk();
280    if cursor.goto_first_child() {
281        loop {
282            collect_declarations_in_range(&cursor.node(), source, start_byte, end_byte, lang, out);
283            if !cursor.goto_next_sibling() {
284                break;
285            }
286        }
287    }
288}
289
290/// Collect parameter names from a function node.
291fn collect_function_params(fn_node: &Node, source: &str, lang: LangId, out: &mut HashSet<String>) {
292    match lang {
293        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
294            // function_declaration / arrow_function have "parameters" field
295            if let Some(params) = fn_node.child_by_field_name("parameters") {
296                collect_param_identifiers(&params, source, lang, out);
297            }
298            // For arrow functions inside lexical_declaration, drill down
299            let mut cursor = fn_node.walk();
300            if cursor.goto_first_child() {
301                loop {
302                    let child = cursor.node();
303                    if child.kind() == "variable_declarator" {
304                        if let Some(value) = child.child_by_field_name("value") {
305                            if value.kind() == "arrow_function" {
306                                if let Some(params) = value.child_by_field_name("parameters") {
307                                    collect_param_identifiers(&params, source, lang, out);
308                                }
309                            }
310                        }
311                    }
312                    if !cursor.goto_next_sibling() {
313                        break;
314                    }
315                }
316            }
317        }
318        LangId::Python => {
319            if let Some(params) = fn_node.child_by_field_name("parameters") {
320                collect_param_identifiers(&params, source, lang, out);
321            }
322        }
323        _ => {}
324    }
325}
326
327/// Walk a parameter list node and collect identifier names.
328fn collect_param_identifiers(
329    params_node: &Node,
330    source: &str,
331    lang: LangId,
332    out: &mut HashSet<String>,
333) {
334    let mut cursor = params_node.walk();
335    if cursor.goto_first_child() {
336        loop {
337            let child = cursor.node();
338            match lang {
339                LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
340                    // required_parameter, optional_parameter have pattern child,
341                    // or directly identifier
342                    if child.kind() == "required_parameter" || child.kind() == "optional_parameter"
343                    {
344                        if let Some(pattern) = child.child_by_field_name("pattern") {
345                            if pattern.kind() == "identifier" {
346                                out.insert(node_text(source, &pattern).to_string());
347                            }
348                        }
349                    } else if child.kind() == "identifier" {
350                        out.insert(node_text(source, &child).to_string());
351                    }
352                }
353                LangId::Python => {
354                    if child.kind() == "identifier" {
355                        let name = node_text(source, &child).to_string();
356                        // Skip `self` parameter
357                        if name != "self" {
358                            out.insert(name);
359                        }
360                    }
361                }
362                _ => {}
363            }
364            if !cursor.goto_next_sibling() {
365                break;
366            }
367        }
368    }
369}
370
371/// Find the innermost function node that encloses `byte_pos`.
372fn find_enclosing_function<'a>(root: &Node<'a>, byte_pos: usize, lang: LangId) -> Option<Node<'a>> {
373    find_deepest_function_ancestor(root, byte_pos, lang)
374}
375
376/// Find the deepest function-like ancestor that contains `byte_pos`.
377fn find_deepest_function_ancestor<'a>(
378    node: &Node<'a>,
379    byte_pos: usize,
380    lang: LangId,
381) -> Option<Node<'a>> {
382    let mut result: Option<Node<'a>> = None;
383    if is_function_like_boundary(node, byte_pos, lang)
384        && node.start_byte() <= byte_pos
385        && byte_pos < node.end_byte()
386    {
387        result = Some(*node);
388    }
389
390    let child_count = node.child_count();
391    for i in 0..child_count {
392        if let Some(child) = node.child(i as u32) {
393            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
394                if let Some(deeper) = find_deepest_function_ancestor(&child, byte_pos, lang) {
395                    result = Some(deeper);
396                }
397            }
398        }
399    }
400
401    result
402}
403
404fn is_function_like_boundary(node: &Node, byte_pos: usize, lang: LangId) -> bool {
405    match lang {
406        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => match node.kind() {
407            "function_declaration"
408            | "method_definition"
409            | "arrow_function"
410            | "function_expression" => true,
411            "lexical_declaration" => lexical_declaration_has_function_initializer(node, byte_pos),
412            _ => false,
413        },
414        LangId::Python => node.kind() == "function_definition",
415        _ => false,
416    }
417}
418
419fn lexical_declaration_has_function_initializer(node: &Node, byte_pos: usize) -> bool {
420    let mut cursor = node.walk();
421    if cursor.goto_first_child() {
422        loop {
423            let child = cursor.node();
424            if child.kind() == "variable_declarator" {
425                if let Some(value) = child.child_by_field_name("value") {
426                    if matches!(value.kind(), "arrow_function" | "function_expression")
427                        && child.start_byte() <= byte_pos
428                        && byte_pos < child.end_byte()
429                    {
430                        return true;
431                    }
432                }
433            }
434            if !cursor.goto_next_sibling() {
435                break;
436            }
437        }
438    }
439
440    false
441}
442
443/// Check if `this` (JS/TS) or `self` (Python) appears in the byte range.
444fn check_this_or_self(
445    node: &Node,
446    source: &str,
447    start_byte: usize,
448    end_byte: usize,
449    lang: LangId,
450) -> bool {
451    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
452        return false;
453    }
454
455    if node.start_byte() >= start_byte && node.end_byte() <= end_byte {
456        let kind = node.kind();
457        match lang {
458            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
459                if kind == "this" {
460                    return true;
461                }
462            }
463            LangId::Python => {
464                if kind == "identifier" && node_text(source, node) == "self" {
465                    // Check it's not a parameter declaration (like `def foo(self):`)
466                    if let Some(parent) = node.parent() {
467                        if parent.kind() == "parameters" {
468                            return false;
469                        }
470                    }
471                    return true;
472                }
473            }
474            _ => {}
475        }
476    }
477
478    let mut cursor = node.walk();
479    if cursor.goto_first_child() {
480        loop {
481            if check_this_or_self(&cursor.node(), source, start_byte, end_byte, lang) {
482                return true;
483            }
484            if !cursor.goto_next_sibling() {
485                break;
486            }
487        }
488    }
489
490    false
491}
492
493// ---------------------------------------------------------------------------
494// Return value detection
495// ---------------------------------------------------------------------------
496
/// The value an extracted function hands back to its call site.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReturnKind {
    /// The range itself ends in `return <expr>`; the payload is `<expr>`'s
    /// source text and the body already carries the `return`.
    Expression(String),
    /// A variable bound inside the range is read after it in the enclosing
    /// function. The payload is an encoded binding: the bare name, `let name`,
    /// `var name`, or an internal assignment marker followed by the name.
    Variable(String),
    /// No value flows back (void).
    Void,
}
507
/// Marker that `ReturnVariableBinding::encoded_for_return_kind` puts in front
/// of a plain re-assignment (`x = …`) so it survives the round trip through
/// `ReturnKind::Variable`. NUL is not a valid identifier character, so the
/// marker cannot collide with a real name.
const RETURN_VARIABLE_ASSIGNMENT_PREFIX: &str = "\0assignment:";

/// How a returned JS/TS variable was introduced inside the extracted range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum JsDeclarationKind {
    Const,
    Let,
    Var,
    Assignment,
}

/// A variable the extracted function returns, together with its binding kind.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ReturnVariableBinding {
    name: String,
    js_kind: JsDeclarationKind,
}

impl ReturnVariableBinding {
    /// Encode this binding as a `ReturnKind::Variable` payload: `name`,
    /// `let name`, `var name`, or the assignment marker followed by `name`.
    /// `parse_return_variable` reverses the encoding.
    fn encoded_for_return_kind(&self) -> String {
        let prefix = match self.js_kind {
            JsDeclarationKind::Const => "",
            JsDeclarationKind::Let => "let ",
            JsDeclarationKind::Var => "var ",
            JsDeclarationKind::Assignment => RETURN_VARIABLE_ASSIGNMENT_PREFIX,
        };
        format!("{prefix}{}", self.name)
    }
}

/// Decode a `ReturnKind::Variable` payload back into a binding.
///
/// Prefixes are tried in order: assignment marker, `let `, `var `, `const `.
/// A payload without any of them is a bare `const` name.
fn parse_return_variable(var: &str) -> ReturnVariableBinding {
    let known_prefixes = [
        (RETURN_VARIABLE_ASSIGNMENT_PREFIX, JsDeclarationKind::Assignment),
        ("let ", JsDeclarationKind::Let),
        ("var ", JsDeclarationKind::Var),
        ("const ", JsDeclarationKind::Const),
    ];

    let (name, js_kind) = known_prefixes
        .iter()
        .find_map(|&(prefix, kind)| var.strip_prefix(prefix).map(|rest| (rest, kind)))
        .unwrap_or((var, JsDeclarationKind::Const));

    ReturnVariableBinding {
        name: name.to_string(),
        js_kind,
    }
}
563
564/// Detect what the extracted code range should return.
565///
566/// 1. If there's an explicit `return` statement in the range, use its expression.
567/// 2. If a variable declared in-range is referenced after the range (but within
568///    the enclosing function), that variable becomes the return value.
569/// 3. Otherwise, void.
570pub fn detect_return_value(
571    source: &str,
572    tree: &Tree,
573    start_byte: usize,
574    end_byte: usize,
575    enclosing_fn_end_byte: Option<usize>,
576    lang: LangId,
577) -> ReturnKind {
578    let root = tree.root_node();
579    let effective_enclosing_fn_end_byte = find_enclosing_function(&root, start_byte, lang)
580        .map(|node| node.end_byte())
581        .or(enclosing_fn_end_byte);
582
583    // Check for explicit return statements in the range
584    if let Some(expr) = find_return_in_range(&root, source, start_byte, end_byte) {
585        return ReturnKind::Expression(expr);
586    }
587
588    let in_range_bindings =
589        collect_return_bindings_in_range(&root, source, start_byte, end_byte, lang);
590
591    // Check if any in-range declaration is used after the range in the enclosing fn
592    if let Some(fn_end) = effective_enclosing_fn_end_byte {
593        let post_range_end = fn_end.min(source.len());
594        if end_byte < post_range_end {
595            let mut post_refs: Vec<String> = Vec::new();
596            collect_identifier_refs(
597                &root,
598                source,
599                end_byte,
600                post_range_end,
601                lang,
602                &mut post_refs,
603            );
604
605            for binding in &in_range_bindings {
606                if post_refs.contains(&binding.name) {
607                    return ReturnKind::Variable(binding.encoded_for_return_kind());
608                }
609            }
610        }
611    }
612
613    ReturnKind::Void
614}
615
616/// Collect in-range names that can be returned to preserve post-range uses.
617fn collect_return_bindings_in_range(
618    node: &Node,
619    source: &str,
620    start_byte: usize,
621    end_byte: usize,
622    lang: LangId,
623) -> Vec<ReturnVariableBinding> {
624    let mut bindings = Vec::new();
625    collect_return_bindings_recursive(node, source, start_byte, end_byte, lang, &mut bindings);
626    bindings
627}
628
629fn collect_return_bindings_recursive(
630    node: &Node,
631    source: &str,
632    start_byte: usize,
633    end_byte: usize,
634    lang: LangId,
635    out: &mut Vec<ReturnVariableBinding>,
636) {
637    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
638        return;
639    }
640
641    match lang {
642        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
643            if node.kind() == "variable_declarator" {
644                if let Some(name_node) = node.child_by_field_name("name") {
645                    if name_node.start_byte() >= start_byte && name_node.end_byte() <= end_byte {
646                        let name = node_text(source, &name_node).to_string();
647                        out.push(ReturnVariableBinding {
648                            name,
649                            js_kind: js_declaration_kind_for_declarator(node),
650                        });
651                    }
652                }
653            } else if is_assignment_node(node) {
654                if let Some(left) = node.child_by_field_name("left") {
655                    if left.kind() == "identifier"
656                        && left.start_byte() >= start_byte
657                        && left.end_byte() <= end_byte
658                    {
659                        out.push(ReturnVariableBinding {
660                            name: node_text(source, &left).to_string(),
661                            js_kind: JsDeclarationKind::Assignment,
662                        });
663                    }
664                }
665            }
666        }
667        LangId::Python => {
668            if node.kind() == "assignment" {
669                if let Some(left) = node.child_by_field_name("left") {
670                    if left.kind() == "identifier"
671                        && left.start_byte() >= start_byte
672                        && left.end_byte() <= end_byte
673                    {
674                        out.push(ReturnVariableBinding {
675                            name: node_text(source, &left).to_string(),
676                            js_kind: JsDeclarationKind::Assignment,
677                        });
678                    }
679                }
680            }
681        }
682        _ => {}
683    }
684
685    let child_count = node.child_count();
686    for i in 0..child_count {
687        if let Some(child) = node.child(i as u32) {
688            collect_return_bindings_recursive(&child, source, start_byte, end_byte, lang, out);
689        }
690    }
691}
692
693fn js_declaration_kind_for_declarator(node: &Node) -> JsDeclarationKind {
694    let Some(parent) = node.parent() else {
695        return JsDeclarationKind::Const;
696    };
697
698    match parent.kind() {
699        "variable_declaration" => JsDeclarationKind::Var,
700        "lexical_declaration" => {
701            let mut cursor = parent.walk();
702            if cursor.goto_first_child() {
703                loop {
704                    let child = cursor.node();
705                    match child.kind() {
706                        "let" => return JsDeclarationKind::Let,
707                        "const" => return JsDeclarationKind::Const,
708                        _ => {}
709                    }
710                    if !cursor.goto_next_sibling() {
711                        break;
712                    }
713                }
714            }
715            JsDeclarationKind::Const
716        }
717        _ => JsDeclarationKind::Const,
718    }
719}
720
721fn is_assignment_node(node: &Node) -> bool {
722    matches!(
723        node.kind(),
724        "assignment_expression" | "augmented_assignment_expression" | "assignment"
725    )
726}
727
728/// Find an explicit `return` statement in the byte range and return its expression text.
729fn find_return_in_range(
730    node: &Node,
731    source: &str,
732    start_byte: usize,
733    end_byte: usize,
734) -> Option<String> {
735    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
736        return None;
737    }
738
739    if node.kind() == "return_statement"
740        && node.start_byte() >= start_byte
741        && node.end_byte() <= end_byte
742    {
743        // Get the expression after "return"
744        let text = node_text(source, node).trim().to_string();
745        let expr = text
746            .strip_prefix("return")
747            .unwrap_or("")
748            .trim()
749            .trim_end_matches(';')
750            .trim()
751            .to_string();
752        if !expr.is_empty() {
753            return Some(expr);
754        }
755    }
756
757    let mut cursor = node.walk();
758    if cursor.goto_first_child() {
759        loop {
760            if let Some(result) = find_return_in_range(&cursor.node(), source, start_byte, end_byte)
761            {
762                return Some(result);
763            }
764            if !cursor.goto_next_sibling() {
765                break;
766            }
767        }
768    }
769
770    None
771}
772
773// ---------------------------------------------------------------------------
774// Function generation
775// ---------------------------------------------------------------------------
776
777/// Generate the text for an extracted function.
778pub fn generate_extracted_function(
779    name: &str,
780    params: &[String],
781    return_kind: &ReturnKind,
782    body_text: &str,
783    base_indent: &str,
784    lang: LangId,
785    indent_style: IndentStyle,
786) -> String {
787    let indent_unit = indent_style.as_str();
788
789    match lang {
790        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => generate_ts_function(
791            name,
792            params,
793            return_kind,
794            body_text,
795            base_indent,
796            indent_unit,
797        ),
798        LangId::Python => generate_py_function(
799            name,
800            params,
801            return_kind,
802            body_text,
803            base_indent,
804            indent_unit,
805        ),
806        _ => {
807            // Shouldn't reach here due to language guard, but produce something reasonable
808            generate_ts_function(
809                name,
810                params,
811                return_kind,
812                body_text,
813                base_indent,
814                indent_unit,
815            )
816        }
817    }
818}
819
820fn generate_ts_function(
821    name: &str,
822    params: &[String],
823    return_kind: &ReturnKind,
824    body_text: &str,
825    base_indent: &str,
826    indent_unit: &str,
827) -> String {
828    let params_str = params.join(", ");
829    let mut lines = Vec::new();
830
831    lines.push(format!(
832        "{}function {}({}) {{",
833        base_indent, name, params_str
834    ));
835
836    // Re-indent body to be inside the function while preserving relative nesting.
837    let common_indent = common_leading_indent(body_text);
838    for line in body_text.lines() {
839        if line.trim().is_empty() {
840            lines.push(String::new());
841        } else {
842            let body_line = strip_leading_indent(line, &common_indent);
843            lines.push(format!("{}{}{}", base_indent, indent_unit, body_line));
844        }
845    }
846
847    // Add return statement if needed
848    match return_kind {
849        ReturnKind::Variable(var) => {
850            let binding = parse_return_variable(var);
851            lines.push(format!(
852                "{}{}return {};",
853                base_indent, indent_unit, binding.name
854            ));
855        }
856        ReturnKind::Expression(_) => {
857            // The return is already in the body text
858        }
859        ReturnKind::Void => {}
860    }
861
862    lines.push(format!("{}}}", base_indent));
863    lines.join("\n")
864}
865
866fn generate_py_function(
867    name: &str,
868    params: &[String],
869    return_kind: &ReturnKind,
870    body_text: &str,
871    base_indent: &str,
872    indent_unit: &str,
873) -> String {
874    let params_str = params.join(", ");
875    let mut lines = Vec::new();
876
877    lines.push(format!("{}def {}({}):", base_indent, name, params_str));
878
879    // Re-indent body while preserving relative nesting.
880    let common_indent = common_leading_indent(body_text);
881    for line in body_text.lines() {
882        if line.trim().is_empty() {
883            lines.push(String::new());
884        } else {
885            let body_line = strip_leading_indent(line, &common_indent);
886            lines.push(format!("{}{}{}", base_indent, indent_unit, body_line));
887        }
888    }
889
890    // Add return statement if needed
891    match return_kind {
892        ReturnKind::Variable(var) => {
893            let binding = parse_return_variable(var);
894            lines.push(format!(
895                "{}{}return {}",
896                base_indent, indent_unit, binding.name
897            ));
898        }
899        ReturnKind::Expression(_) => {
900            // Already in body
901        }
902        ReturnKind::Void => {}
903    }
904
905    lines.join("\n")
906}
907
908fn common_leading_indent(text: &str) -> String {
909    let mut lines = text.lines().filter(|line| !line.trim().is_empty());
910    let Some(first) = lines.next() else {
911        return String::new();
912    };
913
914    let mut common = leading_whitespace(first).to_string();
915    for line in lines {
916        let indent = leading_whitespace(line);
917        let common_len = common
918            .char_indices()
919            .zip(indent.char_indices())
920            .take_while(|((_, left), (_, right))| left == right)
921            .map(|((idx, ch), _)| idx + ch.len_utf8())
922            .last()
923            .unwrap_or(0);
924        common.truncate(common_len);
925        if common.is_empty() {
926            break;
927        }
928    }
929
930    common
931}
932
/// The run of spaces and tabs at the start of `line` (other whitespace, such
/// as non-breaking spaces, is not treated as indentation).
fn leading_whitespace(line: &str) -> &str {
    // Space and tab are single-byte ASCII, so a byte count is a valid char boundary.
    let indent_len = line
        .bytes()
        .take_while(|byte| *byte == b' ' || *byte == b'\t')
        .count();
    &line[..indent_len]
}
937
/// Remove `indent` from the front of `line` if present; otherwise return
/// `line` unchanged. An empty `indent` leaves `line` as-is.
fn strip_leading_indent<'a>(line: &'a str, indent: &str) -> &'a str {
    // `strip_prefix("")` yields the whole line, so no empty-indent special case is needed.
    line.strip_prefix(indent).unwrap_or(line)
}
945
946/// Generate the call site text that replaces the extracted range.
947pub fn generate_call_site(
948    name: &str,
949    params: &[String],
950    return_kind: &ReturnKind,
951    indent: &str,
952    lang: LangId,
953) -> String {
954    let args_str = params.join(", ");
955
956    match return_kind {
957        ReturnKind::Variable(var) => match lang {
958            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
959                let binding = parse_return_variable(var);
960                match binding.js_kind {
961                    JsDeclarationKind::Const => {
962                        format!("{}const {} = {}({});", indent, binding.name, name, args_str)
963                    }
964                    JsDeclarationKind::Let => {
965                        format!("{}let {} = {}({});", indent, binding.name, name, args_str)
966                    }
967                    JsDeclarationKind::Var => {
968                        format!("{}var {} = {}({});", indent, binding.name, name, args_str)
969                    }
970                    JsDeclarationKind::Assignment => {
971                        format!("{}{} = {}({});", indent, binding.name, name, args_str)
972                    }
973                }
974            }
975            LangId::Python => {
976                let binding = parse_return_variable(var);
977                format!("{}{} = {}({})", indent, binding.name, name, args_str)
978            }
979            _ => format!("{}const {} = {}({});", indent, var, name, args_str),
980        },
981        ReturnKind::Expression(_expr) => match lang {
982            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
983                format!("{}return {}({});", indent, name, args_str)
984            }
985            LangId::Python => {
986                format!("{}return {}({})", indent, name, args_str)
987            }
988            _ => format!("{}return {}({});", indent, name, args_str),
989        },
990        ReturnKind::Void => match lang {
991            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
992                format!("{}{}({});", indent, name, args_str)
993            }
994            LangId::Python => {
995                format!("{}{}({})", indent, name, args_str)
996            }
997            _ => format!("{}{}({});", indent, name, args_str),
998        },
999    }
1000}
1001
1002// ---------------------------------------------------------------------------
1003// Inline symbol utilities
1004// ---------------------------------------------------------------------------
1005
/// A name declared by a function body being inlined that is already declared
/// in the scope of the call site.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScopeConflict {
    /// Name declared both in the inlined body and in the call-site scope.
    pub name: String,
    /// Alternative name proposed to avoid the clash.
    pub suggested: String,
}
1014
1015/// Detect scope conflicts between the call site scope and the function body
1016/// being inlined.
1017///
1018/// Collects all variable declarations at the call site's scope level
1019/// (surrounding function body), then checks for collisions with variables
1020/// declared in `body_text`.
1021pub fn detect_scope_conflicts(
1022    source: &str,
1023    tree: &Tree,
1024    insertion_byte: usize,
1025    _param_names: &[String],
1026    body_text: &str,
1027    lang: LangId,
1028) -> Vec<ScopeConflict> {
1029    let root = tree.root_node();
1030
1031    // 1. Find the enclosing function at the call site
1032    let enclosing_fn = find_enclosing_function(&root, insertion_byte, lang);
1033
1034    // 2. Collect all declarations in the call site's scope
1035    let mut scope_decls: HashSet<String> = HashSet::new();
1036    if let Some(fn_node) = enclosing_fn {
1037        collect_declarations_in_range(
1038            &fn_node,
1039            source,
1040            fn_node.start_byte(),
1041            fn_node.end_byte(),
1042            lang,
1043            &mut scope_decls,
1044        );
1045        collect_function_params(&fn_node, source, lang, &mut scope_decls);
1046    } else {
1047        // Module-level: collect all top-level declarations
1048        collect_declarations_in_range(
1049            &root,
1050            source,
1051            root.start_byte(),
1052            root.end_byte(),
1053            lang,
1054            &mut scope_decls,
1055        );
1056    }
1057
1058    // 3. Collect declarations in the body being inlined
1059    let mut body_decls: HashSet<String> = HashSet::new();
1060    let body_grammar = grammar_for(lang);
1061    let mut body_parser = tree_sitter::Parser::new();
1062    if body_parser.set_language(&body_grammar).is_ok() {
1063        if let Some(body_tree) = body_parser.parse(body_text.as_bytes(), None) {
1064            let body_root = body_tree.root_node();
1065            collect_declarations_in_range(
1066                &body_root,
1067                body_text,
1068                0,
1069                body_text.len(),
1070                lang,
1071                &mut body_decls,
1072            );
1073        }
1074    }
1075
1076    // 4. Find collisions
1077    let mut conflicts = Vec::new();
1078    for decl in &body_decls {
1079        if scope_decls.contains(decl) {
1080            conflicts.push(ScopeConflict {
1081                name: decl.clone(),
1082                suggested: format!("{}_inlined", decl),
1083            });
1084        }
1085    }
1086
1087    // Sort for deterministic output
1088    conflicts.sort_by(|a, b| a.name.cmp(&b.name));
1089    conflicts
1090}
1091
1092/// Validate that a function has at most one return statement (suitable for inlining).
1093///
1094/// - Arrow functions with expression bodies (no `return` keyword) → valid (single-return)
1095/// - Functions with 0 returns (void) → valid
1096/// - Functions with exactly 1 return → valid
1097/// - Functions with >1 return → invalid, returns the count
1098pub fn validate_single_return(
1099    source: &str,
1100    _tree: &Tree,
1101    fn_node: &Node,
1102    lang: LangId,
1103) -> Result<(), usize> {
1104    // Arrow functions with expression bodies are always single-return
1105    if lang != LangId::Python && fn_node.kind() == "arrow_function" {
1106        if let Some(body) = fn_node.child_by_field_name("body") {
1107            if body.kind() != "statement_block" {
1108                // Expression body — implicitly single-return
1109                return Ok(());
1110            }
1111        }
1112    }
1113
1114    let count = count_return_statements(fn_node, source);
1115    if count > 1 {
1116        Err(count)
1117    } else {
1118        Ok(())
1119    }
1120}
1121
1122/// Count `return_statement` nodes in a function body (non-recursive into nested functions).
1123fn count_return_statements(node: &Node, source: &str) -> usize {
1124    let _ = source;
1125    let mut count = 0;
1126
1127    // Don't count returns in nested function bodies
1128    let nested_fn_kinds = [
1129        "function_declaration",
1130        "function_definition",
1131        "arrow_function",
1132        "method_definition",
1133    ];
1134
1135    let kind = node.kind();
1136    if kind == "return_statement" {
1137        return 1;
1138    }
1139
1140    let child_count = node.child_count();
1141    for i in 0..child_count {
1142        if let Some(child) = node.child(i as u32) {
1143            // Skip nested function definitions
1144            if nested_fn_kinds.contains(&child.kind()) {
1145                continue;
1146            }
1147            count += count_return_statements(&child, source);
1148        }
1149    }
1150
1151    count
1152}
1153
1154/// Substitute parameter names with argument expressions in a function body.
1155///
1156/// Uses tree-sitter to find `identifier` nodes matching parameter names,
1157/// replacing from end to start to preserve byte offsets. Only replaces
1158/// whole-word matches (identifiers, not substrings).
1159pub fn substitute_params(
1160    body_text: &str,
1161    param_to_arg: &std::collections::HashMap<String, String>,
1162    lang: LangId,
1163) -> String {
1164    if param_to_arg.is_empty() {
1165        return body_text.to_string();
1166    }
1167
1168    let grammar = grammar_for(lang);
1169    let mut parser = tree_sitter::Parser::new();
1170    if parser.set_language(&grammar).is_err() {
1171        return body_text.to_string();
1172    }
1173
1174    let tree = match parser.parse(body_text.as_bytes(), None) {
1175        Some(t) => t,
1176        None => return body_text.to_string(),
1177    };
1178
1179    // Collect identifier references that are still bound to the inlined function's
1180    // parameters. Nested functions and local shadowing declarations are skipped.
1181    let mut replacements: Vec<(usize, usize, String)> = Vec::new();
1182    let shadowed = HashSet::new();
1183    collect_param_replacements(
1184        &tree.root_node(),
1185        body_text,
1186        param_to_arg,
1187        lang,
1188        &shadowed,
1189        true,
1190        &mut replacements,
1191    );
1192
1193    // Sort by start position descending so replacements don't shift offsets
1194    replacements.sort_by(|a, b| b.0.cmp(&a.0));
1195
1196    let mut result = body_text.to_string();
1197    for (start, end, replacement) in replacements {
1198        result = format!("{}{}{}", &result[..start], replacement, &result[end..]);
1199    }
1200
1201    result
1202}
1203
1204/// Collect identifier nodes that match parameter names for substitution.
1205fn collect_param_replacements(
1206    node: &Node,
1207    source: &str,
1208    param_to_arg: &std::collections::HashMap<String, String>,
1209    lang: LangId,
1210    shadowed: &HashSet<String>,
1211    is_root: bool,
1212    out: &mut Vec<(usize, usize, String)>,
1213) {
1214    if !is_root && is_function_scope_node(node, lang) {
1215        return;
1216    }
1217
1218    let mut current_shadowed = shadowed.clone();
1219    collect_shadowing_bindings_in_scope(node, source, param_to_arg, lang, &mut current_shadowed);
1220
1221    let kind = node.kind();
1222
1223    if kind == "identifier" {
1224        // Check it's not a property access
1225        if !is_property_access(node, lang) && !is_binding_identifier(node) {
1226            let name = node_text(source, node);
1227            if !current_shadowed.contains(name) {
1228                if let Some(replacement) = param_to_arg.get(name) {
1229                    out.push((node.start_byte(), node.end_byte(), replacement.clone()));
1230                }
1231            }
1232        }
1233    }
1234
1235    // Also handle Python-specific name node
1236    // Recurse into children
1237    let child_count = node.child_count();
1238    for i in 0..child_count {
1239        if let Some(child) = node.child(i as u32) {
1240            collect_param_replacements(
1241                &child,
1242                source,
1243                param_to_arg,
1244                lang,
1245                &current_shadowed,
1246                false,
1247                out,
1248            );
1249        }
1250    }
1251}
1252
1253fn collect_shadowing_bindings_in_scope(
1254    scope: &Node,
1255    source: &str,
1256    param_to_arg: &std::collections::HashMap<String, String>,
1257    lang: LangId,
1258    out: &mut HashSet<String>,
1259) {
1260    collect_shadowing_bindings_in_scope_recursive(
1261        scope,
1262        scope.id(),
1263        source,
1264        param_to_arg,
1265        lang,
1266        out,
1267    );
1268}
1269
1270fn collect_shadowing_bindings_in_scope_recursive(
1271    node: &Node,
1272    scope_id: usize,
1273    source: &str,
1274    param_to_arg: &std::collections::HashMap<String, String>,
1275    lang: LangId,
1276    out: &mut HashSet<String>,
1277) {
1278    if node.id() != scope_id {
1279        if is_function_scope_node(node, lang) || is_block_scope_node(node, lang) {
1280            return;
1281        }
1282    }
1283
1284    match node.kind() {
1285        "variable_declarator" => {
1286            if let Some(name) = node.child_by_field_name("name") {
1287                collect_shadowing_names_from_pattern(&name, source, param_to_arg, out);
1288            }
1289        }
1290        "catch_clause" => {
1291            if let Some(parameter) = node.child_by_field_name("parameter") {
1292                collect_shadowing_names_from_pattern(&parameter, source, param_to_arg, out);
1293            }
1294        }
1295        "for_in_statement" | "for_of_statement" => {
1296            if let Some(left) = node.child_by_field_name("left") {
1297                collect_shadowing_names_from_pattern(&left, source, param_to_arg, out);
1298            }
1299        }
1300        "assignment" if lang == LangId::Python => {
1301            if let Some(left) = node.child_by_field_name("left") {
1302                collect_shadowing_names_from_pattern(&left, source, param_to_arg, out);
1303            }
1304        }
1305        _ => {}
1306    }
1307
1308    let child_count = node.child_count();
1309    for i in 0..child_count {
1310        if let Some(child) = node.child(i as u32) {
1311            collect_shadowing_bindings_in_scope_recursive(
1312                &child,
1313                scope_id,
1314                source,
1315                param_to_arg,
1316                lang,
1317                out,
1318            );
1319        }
1320    }
1321}
1322
1323fn collect_shadowing_names_from_pattern(
1324    node: &Node,
1325    source: &str,
1326    param_to_arg: &std::collections::HashMap<String, String>,
1327    out: &mut HashSet<String>,
1328) {
1329    if node.kind() == "identifier" {
1330        let name = node_text(source, node);
1331        if param_to_arg.contains_key(name) {
1332            out.insert(name.to_string());
1333        }
1334        return;
1335    }
1336
1337    let child_count = node.child_count();
1338    for i in 0..child_count {
1339        if let Some(child) = node.child(i as u32) {
1340            collect_shadowing_names_from_pattern(&child, source, param_to_arg, out);
1341        }
1342    }
1343}
1344
1345fn is_function_scope_node(node: &Node, lang: LangId) -> bool {
1346    match lang {
1347        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => matches!(
1348            node.kind(),
1349            "function_declaration" | "method_definition" | "arrow_function" | "function_expression"
1350        ),
1351        LangId::Python => node.kind() == "function_definition" || node.kind() == "lambda",
1352        _ => false,
1353    }
1354}
1355
1356fn is_block_scope_node(node: &Node, lang: LangId) -> bool {
1357    match lang {
1358        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => node.kind() == "statement_block",
1359        LangId::Python => node.kind() == "block",
1360        _ => false,
1361    }
1362}
1363
1364fn is_binding_identifier(node: &Node) -> bool {
1365    let Some(parent) = node.parent() else {
1366        return false;
1367    };
1368
1369    if let Some(name) = parent.child_by_field_name("name") {
1370        if name.id() == node.id() || node_is_inside(&name, node) {
1371            return true;
1372        }
1373    }
1374    if let Some(pattern) = parent.child_by_field_name("pattern") {
1375        if pattern.id() == node.id() || node_is_inside(&pattern, node) {
1376            return true;
1377        }
1378    }
1379    if let Some(parameter) = parent.child_by_field_name("parameter") {
1380        if parameter.id() == node.id() || node_is_inside(&parameter, node) {
1381            return true;
1382        }
1383    }
1384    if let Some(left) = parent.child_by_field_name("left") {
1385        if matches!(
1386            parent.kind(),
1387            "for_in_statement" | "for_of_statement" | "assignment"
1388        ) && (left.id() == node.id() || node_is_inside(&left, node))
1389        {
1390            return true;
1391        }
1392    }
1393
1394    false
1395}
1396
1397fn node_is_inside(container: &Node, node: &Node) -> bool {
1398    container.start_byte() <= node.start_byte() && node.end_byte() <= container.end_byte()
1399}
1400
1401// ---------------------------------------------------------------------------
1402// Tests
1403// ---------------------------------------------------------------------------
1404
1405#[cfg(test)]
1406mod tests {
1407    use super::*;
1408    use crate::parser::grammar_for;
1409    use std::path::PathBuf;
1410    use tree_sitter::Parser;
1411
1412    fn fixture_path(name: &str) -> PathBuf {
1413        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
1414            .join("tests")
1415            .join("fixtures")
1416            .join("extract_function")
1417            .join(name)
1418    }
1419
1420    fn parse_source(source: &str, lang: LangId) -> Tree {
1421        let grammar = grammar_for(lang);
1422        let mut parser = Parser::new();
1423        parser.set_language(&grammar).unwrap();
1424        parser.parse(source.as_bytes(), None).unwrap()
1425    }
1426
1427    // --- Free variable detection: simple identifiers ---
1428
1429    #[test]
1430    fn free_vars_detects_enclosing_function_params() {
1431        // `items` and `prefix` are function params → should be detected as free variables
1432        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1433        let tree = parse_source(&source, LangId::TypeScript);
1434
1435        // Lines 5-8 (0-indexed): the body of processData that uses `items` and `prefix`
1436        // "  const filtered = items.filter(item => item.length > 0);"
1437        // "  const mapped = filtered.map(item => prefix + item);"
1438        // These lines reference `items` and `prefix` from the function params.
1439        let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
1440        let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
1441
1442        let result =
1443            detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
1444        assert!(
1445            result.parameters.contains(&"items".to_string()),
1446            "should detect 'items' as parameter, got: {:?}",
1447            result.parameters
1448        );
1449        assert!(
1450            result.parameters.contains(&"prefix".to_string()),
1451            "should detect 'prefix' as parameter, got: {:?}",
1452            result.parameters
1453        );
1454        assert!(!result.has_this_or_self);
1455    }
1456
1457    // --- Property access filtering ---
1458
1459    #[test]
1460    fn free_vars_filters_property_identifiers() {
1461        // In `items.filter(...)`, `filter` should NOT be a free variable.
1462        // In `item.length`, `length` should NOT be a free variable.
1463        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1464        let tree = parse_source(&source, LangId::TypeScript);
1465
1466        let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
1467        let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
1468
1469        let result =
1470            detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
1471        // "filter", "map", "length" should NOT appear
1472        assert!(
1473            !result.parameters.contains(&"filter".to_string()),
1474            "property 'filter' should not be a free variable"
1475        );
1476        assert!(
1477            !result.parameters.contains(&"length".to_string()),
1478            "property 'length' should not be a free variable"
1479        );
1480        assert!(
1481            !result.parameters.contains(&"map".to_string()),
1482            "property 'map' should not be a free variable"
1483        );
1484    }
1485
1486    // --- Module-level vs function-level classification ---
1487
1488    #[test]
1489    fn free_vars_skips_module_level_refs() {
1490        // `BASE_URL` is module-level → should NOT be a parameter
1491        // `console` is a global → should NOT be a parameter
1492        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1493        let tree = parse_source(&source, LangId::TypeScript);
1494
1495        // processData body: lines 5-9
1496        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1497        let end = crate::edit::line_col_to_byte(&source, 10, 0);
1498
1499        let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
1500        assert!(
1501            !result.parameters.contains(&"BASE_URL".to_string()),
1502            "module-level 'BASE_URL' should not be a parameter, got: {:?}",
1503            result.parameters
1504        );
1505        assert!(
1506            !result.parameters.contains(&"console".to_string()),
1507            "'console' should not be a parameter, got: {:?}",
1508            result.parameters
1509        );
1510    }
1511
1512    #[test]
1513    fn free_vars_plain_lexical_declaration_uses_real_enclosing_function() {
1514        let source = "function f(a: number) {\n  const x = a + 1;\n  return x;\n}\n";
1515        let tree = parse_source(source, LangId::TypeScript);
1516        let start = crate::edit::line_col_to_byte(source, 1, 0);
1517        let end = crate::edit::line_col_to_byte(source, 2, 0);
1518
1519        let result = detect_free_variables(source, &tree, start, end, LangId::TypeScript);
1520        assert!(
1521            result.parameters.contains(&"a".to_string()),
1522            "plain const declaration should not stop enclosing-function lookup: {:?}",
1523            result.parameters
1524        );
1525    }
1526
1527    // --- this/self detection ---
1528
1529    #[test]
1530    fn free_vars_detects_this_in_ts() {
1531        let source = std::fs::read_to_string(fixture_path("sample_this.ts")).unwrap();
1532        let tree = parse_source(&source, LangId::TypeScript);
1533
1534        // getUser method body lines 4-6 contain `this.users.get(key)`
1535        let start = crate::edit::line_col_to_byte(&source, 4, 0);
1536        let end = crate::edit::line_col_to_byte(&source, 7, 0);
1537
1538        let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
1539        assert!(result.has_this_or_self, "should detect 'this' reference");
1540    }
1541
1542    #[test]
1543    fn free_vars_detects_self_in_python() {
1544        let source = r#"
1545class UserService:
1546    def get_user(self, id):
1547        key = id.lower()
1548        user = self.users.get(key)
1549        return user
1550"#;
1551        let tree = parse_source(source, LangId::Python);
1552
1553        // Lines 3-4 (0-indexed) contain `self.users.get(key)`
1554        let start = crate::edit::line_col_to_byte(source, 4, 0);
1555        let end = crate::edit::line_col_to_byte(source, 5, 0);
1556
1557        let result = detect_free_variables(source, &tree, start, end, LangId::Python);
1558        assert!(result.has_this_or_self, "should detect 'self' reference");
1559    }
1560
1561    // --- Return value detection ---
1562
1563    #[test]
1564    fn return_value_explicit_return() {
1565        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1566        let tree = parse_source(&source, LangId::TypeScript);
1567
1568        // simpleHelper: lines 13-16 — contains "return added;"
1569        let start = crate::edit::line_col_to_byte(&source, 14, 0);
1570        let end = crate::edit::line_col_to_byte(&source, 17, 0);
1571
1572        let result = detect_return_value(&source, &tree, start, end, None, LangId::TypeScript);
1573        assert_eq!(result, ReturnKind::Expression("added".to_string()));
1574    }
1575
1576    #[test]
1577    fn return_value_post_range_usage() {
1578        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1579        let tree = parse_source(&source, LangId::TypeScript);
1580
1581        // processData lines 5-7: declares `filtered`, `mapped`
1582        // Lines 7-9 use `result` (which comes after mapped), but line 7 declares `result`
1583        // Let's extract lines 5-6 only: `filtered` and `mapped` are declared
1584        // and `filtered` is used on line 6, `mapped` is used on line 7
1585        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1586        let end = crate::edit::line_col_to_byte(&source, 6, 0);
1587
1588        // Enclosing function ends around line 10
1589        let fn_end = crate::edit::line_col_to_byte(&source, 10, 0);
1590
1591        let result =
1592            detect_return_value(&source, &tree, start, end, Some(fn_end), LangId::TypeScript);
1593        // `filtered` is declared in-range and used after the range
1594        assert_eq!(result, ReturnKind::Variable("filtered".to_string()));
1595    }
1596
1597    #[test]
1598    fn return_value_void() {
1599        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1600        let tree = parse_source(&source, LangId::TypeScript);
1601
1602        // voidWork lines 20-21: no return, `greeting` is only used within
1603        let start = crate::edit::line_col_to_byte(&source, 20, 0);
1604        let end = crate::edit::line_col_to_byte(&source, 22, 0);
1605
1606        let result = detect_return_value(
1607            &source,
1608            &tree,
1609            start,
1610            end,
1611            Some(crate::edit::line_col_to_byte(&source, 23, 0)),
1612            LangId::TypeScript,
1613        );
1614        assert_eq!(result, ReturnKind::Void);
1615    }
1616
1617    // --- Function generation ---
1618
1619    #[test]
1620    fn generate_ts_function_with_params() {
1621        let body = "const doubled = x * 2;\nconst added = doubled + 10;";
1622        let result = generate_extracted_function(
1623            "compute",
1624            &["x".to_string()],
1625            &ReturnKind::Variable("added".to_string()),
1626            body,
1627            "",
1628            LangId::TypeScript,
1629            IndentStyle::Spaces(2),
1630        );
1631        assert!(result.contains("function compute(x)"));
1632        assert!(result.contains("return added;"));
1633        assert!(result.contains("}"));
1634    }
1635
1636    #[test]
1637    fn generate_ts_function_preserves_relative_indentation() {
1638        let body = "  for (const item of items) {\n    if (item.active) {\n      console.log(item.name);\n    }\n  }";
1639        let result = generate_extracted_function(
1640            "processItems",
1641            &["items".to_string()],
1642            &ReturnKind::Void,
1643            body,
1644            "",
1645            LangId::TypeScript,
1646            IndentStyle::Spaces(2),
1647        );
1648        assert_eq!(
1649            result,
1650            "function processItems(items) {\n  for (const item of items) {\n    if (item.active) {\n      console.log(item.name);\n    }\n  }\n}"
1651        );
1652    }
1653
1654    #[test]
1655    fn generate_py_function_with_params() {
1656        let body = "doubled = x * 2\nadded = doubled + 10";
1657        let result = generate_extracted_function(
1658            "compute",
1659            &["x".to_string()],
1660            &ReturnKind::Variable("added".to_string()),
1661            body,
1662            "",
1663            LangId::Python,
1664            IndentStyle::Spaces(4),
1665        );
1666        assert!(result.contains("def compute(x):"));
1667        assert!(result.contains("return added"));
1668    }
1669
1670    #[test]
1671    fn generate_call_site_with_return_var() {
1672        let call = generate_call_site(
1673            "compute",
1674            &["x".to_string()],
1675            &ReturnKind::Variable("result".to_string()),
1676            "  ",
1677            LangId::TypeScript,
1678        );
1679        assert_eq!(call, "  const result = compute(x);");
1680    }
1681
1682    #[test]
1683    fn generate_call_site_preserves_let_return_var() {
1684        let call = generate_call_site(
1685            "compute",
1686            &[],
1687            &ReturnKind::Variable("let result".to_string()),
1688            "  ",
1689            LangId::TypeScript,
1690        );
1691        assert_eq!(call, "  let result = compute();");
1692    }
1693
1694    #[test]
1695    fn generate_call_site_void() {
1696        let call = generate_call_site(
1697            "doWork",
1698            &["a".to_string(), "b".to_string()],
1699            &ReturnKind::Void,
1700            "  ",
1701            LangId::TypeScript,
1702        );
1703        assert_eq!(call, "  doWork(a, b);");
1704    }
1705
1706    #[test]
1707    fn generate_call_site_return_expression() {
1708        let call = generate_call_site(
1709            "compute",
1710            &["x".to_string()],
1711            &ReturnKind::Expression("x * 2".to_string()),
1712            "  ",
1713            LangId::TypeScript,
1714        );
1715        assert_eq!(call, "  return compute(x);");
1716    }
1717
1718    // --- Python free variables ---
1719
1720    #[test]
1721    fn free_vars_python_function_params() {
1722        let source = std::fs::read_to_string(fixture_path("sample.py")).unwrap();
1723        let tree = parse_source(&source, LangId::Python);
1724
1725        // process_data body: lines 5-8 reference `items` and `prefix`
1726        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1727        let end = crate::edit::line_col_to_byte(&source, 7, 0);
1728
1729        let result = detect_free_variables(&source, &tree, start, end, LangId::Python);
1730        assert!(
1731            result.parameters.contains(&"items".to_string()),
1732            "should detect 'items': {:?}",
1733            result.parameters
1734        );
1735        assert!(
1736            result.parameters.contains(&"prefix".to_string()),
1737            "should detect 'prefix': {:?}",
1738            result.parameters
1739        );
1740        assert!(!result.has_this_or_self);
1741    }
1742
1743    // --- validate_single_return ---
1744
1745    #[test]
1746    fn validate_single_return_single() {
1747        let source =
1748            "function add(a: number, b: number): number {\n  const sum = a + b;\n  return sum;\n}";
1749        let tree = parse_source(source, LangId::TypeScript);
1750        let root = tree.root_node();
1751        let fn_node = root.child(0).unwrap(); // function_declaration
1752        assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
1753    }
1754
1755    #[test]
1756    fn validate_single_return_void() {
1757        let source = "function greet(name: string): void {\n  console.log(name);\n}";
1758        let tree = parse_source(source, LangId::TypeScript);
1759        let root = tree.root_node();
1760        let fn_node = root.child(0).unwrap();
1761        assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
1762    }
1763
1764    #[test]
1765    fn validate_single_return_expression_body() {
1766        let source = "const double = (n: number): number => n * 2;";
1767        let tree = parse_source(source, LangId::TypeScript);
1768        let root = tree.root_node();
1769        // lexical_declaration > variable_declarator > arrow_function
1770        let lex_decl = root.child(0).unwrap();
1771        let var_decl = lex_decl.child(1).unwrap(); // variable_declarator
1772        let arrow = var_decl.child_by_field_name("value").unwrap();
1773        assert_eq!(arrow.kind(), "arrow_function");
1774        assert!(validate_single_return(source, &tree, &arrow, LangId::TypeScript).is_ok());
1775    }
1776
1777    #[test]
1778    fn validate_single_return_multiple() {
1779        let source = "function abs(x: number): number {\n  if (x > 0) {\n    return x;\n  }\n  return -x;\n}";
1780        let tree = parse_source(source, LangId::TypeScript);
1781        let root = tree.root_node();
1782        let fn_node = root.child(0).unwrap();
1783        let result = validate_single_return(source, &tree, &fn_node, LangId::TypeScript);
1784        assert!(result.is_err());
1785        assert_eq!(result.unwrap_err(), 2);
1786    }
1787
1788    // --- detect_scope_conflicts ---
1789
1790    #[test]
1791    fn scope_conflicts_none() {
1792        // No overlap between call site scope and body vars
1793        let source = "function main() {\n  const x = 10;\n  const y = add(x, 5);\n}";
1794        let tree = parse_source(source, LangId::TypeScript);
1795        let body_text = "const sum = a + b;";
1796        let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
1797        let conflicts =
1798            detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
1799        assert!(
1800            conflicts.is_empty(),
1801            "expected no conflicts, got: {:?}",
1802            conflicts
1803        );
1804    }
1805
1806    #[test]
1807    fn scope_conflicts_detected() {
1808        // `temp` exists at call site and inside body
1809        let source = "function main() {\n  const temp = 99;\n  const result = compute(5);\n}";
1810        let tree = parse_source(source, LangId::TypeScript);
1811        let body_text = "const temp = x * 2;\nconst result2 = temp + 10;";
1812        let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
1813        let conflicts =
1814            detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
1815        assert!(!conflicts.is_empty(), "expected conflict for 'temp'");
1816        assert!(
1817            conflicts.iter().any(|c| c.name == "temp"),
1818            "conflicts: {:?}",
1819            conflicts
1820        );
1821        assert!(
1822            conflicts.iter().any(|c| c.suggested == "temp_inlined"),
1823            "should suggest temp_inlined"
1824        );
1825    }
1826
1827    // --- substitute_params ---
1828
1829    #[test]
1830    fn substitute_params_basic() {
1831        let body = "const sum = a + b;";
1832        let mut map = std::collections::HashMap::new();
1833        map.insert("a".to_string(), "x".to_string());
1834        map.insert("b".to_string(), "y".to_string());
1835        let result = substitute_params(body, &map, LangId::TypeScript);
1836        assert_eq!(result, "const sum = x + y;");
1837    }
1838
1839    #[test]
1840    fn substitute_params_whole_word() {
1841        // Should NOT replace `i` inside `items`
1842        let body = "const result = items.filter(i => i > 0);";
1843        let mut map = std::collections::HashMap::new();
1844        map.insert("i".to_string(), "index".to_string());
1845        let result = substitute_params(body, &map, LangId::TypeScript);
1846        // `items` should be untouched, and the nested arrow's shadowing `i`
1847        // parameter/reference pair should also be left alone.
1848        assert_eq!(result, body);
1849    }
1850
1851    #[test]
1852    fn substitute_params_rewrites_outer_reference_not_shadowed_arrow_param() {
1853        let body = "return x + items.map(x => x + 1)[0];";
1854        let mut map = std::collections::HashMap::new();
1855        map.insert("x".to_string(), "5".to_string());
1856        let result = substitute_params(body, &map, LangId::TypeScript);
1857        assert_eq!(result, "return 5 + items.map(x => x + 1)[0];");
1858    }
1859
1860    #[test]
1861    fn substitute_params_noop_same_name() {
1862        let body = "const sum = x + y;";
1863        let mut map = std::collections::HashMap::new();
1864        map.insert("x".to_string(), "x".to_string());
1865        let result = substitute_params(body, &map, LangId::TypeScript);
1866        assert_eq!(result, "const sum = x + y;");
1867    }
1868
1869    #[test]
1870    fn substitute_params_empty_map() {
1871        let body = "const sum = a + b;";
1872        let map = std::collections::HashMap::new();
1873        let result = substitute_params(body, &map, LangId::TypeScript);
1874        assert_eq!(result, body);
1875    }
1876}