Skip to main content

aft/
extract.rs

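//! Shared extraction utilities for `extract_function` (and future `inline_symbol`).
//!
//! Provides:
//! - `detect_free_variables` — classify identifier references in a byte range
//! - `detect_return_value` — infer what the extracted function should return
//! - `generate_extracted_function` — produce function text for TS/JS or Python
//! - `generate_call_site` — produce the call that replaces the extracted range
//! - `detect_scope_conflicts`, `validate_single_return`, `substitute_params` —
//!   helpers for inlining a function body at a call site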
7
8use std::collections::HashSet;
9
10use tree_sitter::{Node, Tree};
11
12use crate::indent::IndentStyle;
13use crate::parser::{grammar_for, node_text, LangId};
14
15// ---------------------------------------------------------------------------
16// Free variable detection
17// ---------------------------------------------------------------------------
18
/// Outcome of classifying the identifiers referenced inside a selected byte range.
///
/// Derives `Clone`, `PartialEq` and `Eq` so results can be copied and compared
/// directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FreeVariableResult {
    /// Names declared in the enclosing function scope that the range
    /// references; they become the parameters of the extracted function.
    /// Each name appears once, in order of first reference.
    pub parameters: Vec<String>,
    /// `true` when `this` (JS/TS) or `self` (Python) occurs inside the range.
    pub has_this_or_self: bool,
}
28
29/// Walk the AST for a byte range and classify every identifier reference.
30///
31/// Classification rules:
32/// 1. Declared-in-range (local variable) → skip
33/// 2. Declared in enclosing function scope → parameter
34/// 3. Module-level or import → skip
35/// 4. `this` / `self` keyword → flag (error for extract_function)
36/// 5. `property_identifier` / `field_identifier` on the right side of `.` → skip
37///    (these are member accesses, not free variables)
38pub fn detect_free_variables(
39    source: &str,
40    tree: &Tree,
41    start_byte: usize,
42    end_byte: usize,
43    lang: LangId,
44) -> FreeVariableResult {
45    let root = tree.root_node();
46
47    // 1. Collect all identifiers referenced in the range (excluding property access)
48    let mut references: Vec<String> = Vec::new();
49    collect_identifier_refs(&root, source, start_byte, end_byte, lang, &mut references);
50
51    // 2. Collect declarations within the range (these are locals, not free)
52    let mut local_decls: HashSet<String> = HashSet::new();
53    collect_declarations_in_range(&root, source, start_byte, end_byte, lang, &mut local_decls);
54
55    // 3. Find the enclosing function scope boundary
56    let enclosing_fn = find_enclosing_function(&root, start_byte, lang);
57
58    // 4. Collect declarations in the enclosing function but outside the range
59    let mut enclosing_decls: HashSet<String> = HashSet::new();
60    if let Some(fn_node) = enclosing_fn {
61        collect_declarations_in_range(
62            &fn_node,
63            source,
64            fn_node.start_byte(),
65            start_byte, // only before the range
66            lang,
67            &mut enclosing_decls,
68        );
69        // Also collect function parameters
70        collect_function_params(&fn_node, source, lang, &mut enclosing_decls);
71    }
72
73    // 5. Check for this/self
74    let has_this_or_self = check_this_or_self(&root, source, start_byte, end_byte, lang);
75
76    // 6. Classify: a reference is a parameter if it's not a local decl,
77    //    IS declared in the enclosing function scope, and is not module-level.
78    let mut seen = HashSet::new();
79    let mut parameters = Vec::new();
80    for name in &references {
81        if local_decls.contains(name) {
82            continue;
83        }
84        if !seen.insert(name.clone()) {
85            continue; // dedup
86        }
87        if enclosing_decls.contains(name) {
88            parameters.push(name.clone());
89        }
90        // If not in enclosing_decls, it's module-level or global — skip
91    }
92
93    FreeVariableResult {
94        parameters,
95        has_this_or_self,
96    }
97}
98
99/// Collect all `identifier` nodes in [start_byte, end_byte) that are genuine
100/// references (not property accesses on the right side of `.`).
101fn collect_identifier_refs(
102    node: &Node,
103    source: &str,
104    start_byte: usize,
105    end_byte: usize,
106    lang: LangId,
107    out: &mut Vec<String>,
108) {
109    // Skip nodes entirely outside the range
110    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
111        return;
112    }
113
114    let kind = node.kind();
115
116    // An `identifier` node in the range that is NOT a property/field access
117    if kind == "identifier" && node.start_byte() >= start_byte && node.end_byte() <= end_byte {
118        // Check parent: if parent is member_expression and this is the "property" field,
119        // it's a property access, not a free variable.
120        if !is_property_access(node, lang) {
121            let name = node_text(source, node).to_string();
122            // Skip language keywords that parse as identifiers
123            if !is_keyword(&name, lang) {
124                out.push(name);
125            }
126        }
127    }
128
129    // Recurse into children
130    let mut cursor = node.walk();
131    if cursor.goto_first_child() {
132        loop {
133            collect_identifier_refs(&cursor.node(), source, start_byte, end_byte, lang, out);
134            if !cursor.goto_next_sibling() {
135                break;
136            }
137        }
138    }
139}
140
141/// Check if an identifier node is a property access (right side of `.`).
142fn is_property_access(node: &Node, lang: LangId) -> bool {
143    // property_identifier and field_identifier are separate node kinds in TS/JS,
144    // so they won't even reach here. But for Python `attribute` access the child
145    // is still `identifier`.
146    if let Some(parent) = node.parent() {
147        let pk = parent.kind();
148        match lang {
149            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
150                // member_expression: object.property — the "property" child
151                if pk == "member_expression" {
152                    if let Some(prop) = parent.child_by_field_name("property") {
153                        return prop.id() == node.id();
154                    }
155                }
156            }
157            LangId::Python => {
158                // attribute: object.attr — the "attribute" child
159                if pk == "attribute" {
160                    if let Some(attr) = parent.child_by_field_name("attribute") {
161                        return attr.id() == node.id();
162                    }
163                }
164            }
165            _ => {}
166        }
167    }
168    false
169}
170
171/// Identifiers that are language keywords and should not be treated as free variables.
172fn is_keyword(name: &str, lang: LangId) -> bool {
173    match lang {
174        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => matches!(
175            name,
176            "undefined" | "null" | "true" | "false" | "NaN" | "Infinity" | "console" | "require"
177        ),
178        LangId::Python => matches!(
179            name,
180            "None"
181                | "True"
182                | "False"
183                | "print"
184                | "len"
185                | "range"
186                | "str"
187                | "int"
188                | "float"
189                | "list"
190                | "dict"
191                | "set"
192                | "tuple"
193                | "type"
194                | "super"
195                | "isinstance"
196                | "enumerate"
197                | "zip"
198                | "map"
199                | "filter"
200                | "sorted"
201                | "reversed"
202                | "any"
203                | "all"
204                | "min"
205                | "max"
206                | "sum"
207                | "abs"
208                | "open"
209                | "input"
210                | "format"
211                | "hasattr"
212                | "getattr"
213                | "setattr"
214                | "delattr"
215                | "repr"
216                | "iter"
217                | "next"
218                | "ValueError"
219                | "TypeError"
220                | "KeyError"
221                | "IndexError"
222                | "Exception"
223                | "RuntimeError"
224                | "StopIteration"
225                | "NotImplementedError"
226                | "AttributeError"
227                | "ImportError"
228                | "OSError"
229                | "IOError"
230                | "FileNotFoundError"
231        ),
232        _ => false,
233    }
234}
235
236/// Collect names declared (via variable declarations) within a byte range.
237fn collect_declarations_in_range(
238    node: &Node,
239    source: &str,
240    start_byte: usize,
241    end_byte: usize,
242    lang: LangId,
243    out: &mut HashSet<String>,
244) {
245    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
246        return;
247    }
248
249    let kind = node.kind();
250
251    match lang {
252        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
253            // variable_declarator has a "name" child that is the declared identifier
254            if kind == "variable_declarator" {
255                if let Some(name_node) = node.child_by_field_name("name") {
256                    if name_node.start_byte() >= start_byte && name_node.end_byte() <= end_byte {
257                        out.insert(node_text(source, &name_node).to_string());
258                    }
259                }
260            }
261        }
262        LangId::Python => {
263            // assignment: left side
264            if kind == "assignment" {
265                if let Some(left) = node.child_by_field_name("left") {
266                    if left.kind() == "identifier"
267                        && left.start_byte() >= start_byte
268                        && left.end_byte() <= end_byte
269                    {
270                        out.insert(node_text(source, &left).to_string());
271                    }
272                }
273            }
274        }
275        _ => {}
276    }
277
278    // Recurse
279    let mut cursor = node.walk();
280    if cursor.goto_first_child() {
281        loop {
282            collect_declarations_in_range(&cursor.node(), source, start_byte, end_byte, lang, out);
283            if !cursor.goto_next_sibling() {
284                break;
285            }
286        }
287    }
288}
289
290/// Collect parameter names from a function node.
291fn collect_function_params(fn_node: &Node, source: &str, lang: LangId, out: &mut HashSet<String>) {
292    match lang {
293        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
294            // function_declaration / arrow_function have "parameters" field
295            if let Some(params) = fn_node.child_by_field_name("parameters") {
296                collect_param_identifiers(&params, source, lang, out);
297            }
298            // For arrow functions inside lexical_declaration, drill down
299            let mut cursor = fn_node.walk();
300            if cursor.goto_first_child() {
301                loop {
302                    let child = cursor.node();
303                    if child.kind() == "variable_declarator" {
304                        if let Some(value) = child.child_by_field_name("value") {
305                            if value.kind() == "arrow_function" {
306                                if let Some(params) = value.child_by_field_name("parameters") {
307                                    collect_param_identifiers(&params, source, lang, out);
308                                }
309                            }
310                        }
311                    }
312                    if !cursor.goto_next_sibling() {
313                        break;
314                    }
315                }
316            }
317        }
318        LangId::Python => {
319            if let Some(params) = fn_node.child_by_field_name("parameters") {
320                collect_param_identifiers(&params, source, lang, out);
321            }
322        }
323        _ => {}
324    }
325}
326
327/// Walk a parameter list node and collect identifier names.
328fn collect_param_identifiers(
329    params_node: &Node,
330    source: &str,
331    lang: LangId,
332    out: &mut HashSet<String>,
333) {
334    let mut cursor = params_node.walk();
335    if cursor.goto_first_child() {
336        loop {
337            let child = cursor.node();
338            match lang {
339                LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
340                    // required_parameter, optional_parameter have pattern child,
341                    // or directly identifier
342                    if child.kind() == "required_parameter" || child.kind() == "optional_parameter"
343                    {
344                        if let Some(pattern) = child.child_by_field_name("pattern") {
345                            if pattern.kind() == "identifier" {
346                                out.insert(node_text(source, &pattern).to_string());
347                            }
348                        }
349                    } else if child.kind() == "identifier" {
350                        out.insert(node_text(source, &child).to_string());
351                    }
352                }
353                LangId::Python => {
354                    if child.kind() == "identifier" {
355                        let name = node_text(source, &child).to_string();
356                        // Skip `self` parameter
357                        if name != "self" {
358                            out.insert(name);
359                        }
360                    }
361                }
362                _ => {}
363            }
364            if !cursor.goto_next_sibling() {
365                break;
366            }
367        }
368    }
369}
370
371/// Find the innermost function node that encloses `byte_pos`.
372fn find_enclosing_function<'a>(root: &Node<'a>, byte_pos: usize, lang: LangId) -> Option<Node<'a>> {
373    let fn_kinds: &[&str] = match lang {
374        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
375            &[
376                "function_declaration",
377                "method_definition",
378                "arrow_function",
379                "lexical_declaration", // for const foo = () => ...
380            ]
381        }
382        LangId::Python => &["function_definition"],
383        _ => &[],
384    };
385
386    find_deepest_ancestor(root, byte_pos, fn_kinds)
387}
388
389/// Find the deepest ancestor node (of the given kinds) that contains `byte_pos`.
390fn find_deepest_ancestor<'a>(node: &Node<'a>, byte_pos: usize, kinds: &[&str]) -> Option<Node<'a>> {
391    let mut result: Option<Node<'a>> = None;
392    if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
393        result = Some(*node);
394    }
395
396    let child_count = node.child_count();
397    for i in 0..child_count {
398        if let Some(child) = node.child(i as u32) {
399            if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
400                if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
401                    result = Some(deeper);
402                }
403            }
404        }
405    }
406
407    result
408}
409
410/// Check if `this` (JS/TS) or `self` (Python) appears in the byte range.
411fn check_this_or_self(
412    node: &Node,
413    source: &str,
414    start_byte: usize,
415    end_byte: usize,
416    lang: LangId,
417) -> bool {
418    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
419        return false;
420    }
421
422    if node.start_byte() >= start_byte && node.end_byte() <= end_byte {
423        let kind = node.kind();
424        match lang {
425            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
426                if kind == "this" {
427                    return true;
428                }
429            }
430            LangId::Python => {
431                if kind == "identifier" && node_text(source, node) == "self" {
432                    // Check it's not a parameter declaration (like `def foo(self):`)
433                    if let Some(parent) = node.parent() {
434                        if parent.kind() == "parameters" {
435                            return false;
436                        }
437                    }
438                    return true;
439                }
440            }
441            _ => {}
442        }
443    }
444
445    let mut cursor = node.walk();
446    if cursor.goto_first_child() {
447        loop {
448            if check_this_or_self(&cursor.node(), source, start_byte, end_byte, lang) {
449                return true;
450            }
451            if !cursor.goto_next_sibling() {
452                break;
453            }
454        }
455    }
456
457    false
458}
459
460// ---------------------------------------------------------------------------
461// Return value detection
462// ---------------------------------------------------------------------------
463
/// How the extracted function hands a result back to its caller.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ReturnKind {
    /// The range already contains an explicit `return expr;`; the payload is
    /// the text of that expression.
    Expression(String),
    /// A variable declared inside the range is still needed after it in the
    /// enclosing function; the payload is the variable's name.
    Variable(String),
    /// Nothing has to be returned (void).
    Void,
}
474
475/// Detect what the extracted code range should return.
476///
477/// 1. If there's an explicit `return` statement in the range, use its expression.
478/// 2. If a variable declared in-range is referenced after the range (but within
479///    the enclosing function), that variable becomes the return value.
480/// 3. Otherwise, void.
481pub fn detect_return_value(
482    source: &str,
483    tree: &Tree,
484    start_byte: usize,
485    end_byte: usize,
486    enclosing_fn_end_byte: Option<usize>,
487    lang: LangId,
488) -> ReturnKind {
489    let root = tree.root_node();
490
491    // Check for explicit return statements in the range
492    if let Some(expr) = find_return_in_range(&root, source, start_byte, end_byte) {
493        return ReturnKind::Expression(expr);
494    }
495
496    // Collect declarations in the range
497    let mut in_range_decls: HashSet<String> = HashSet::new();
498    collect_declarations_in_range(
499        &root,
500        source,
501        start_byte,
502        end_byte,
503        lang,
504        &mut in_range_decls,
505    );
506
507    // Check if any in-range declaration is used after the range in the enclosing fn
508    if let Some(fn_end) = enclosing_fn_end_byte {
509        let post_range_end = fn_end.min(source.len());
510        if end_byte < post_range_end {
511            let mut post_refs: Vec<String> = Vec::new();
512            collect_identifier_refs(
513                &root,
514                source,
515                end_byte,
516                post_range_end,
517                lang,
518                &mut post_refs,
519            );
520
521            for decl in &in_range_decls {
522                if post_refs.contains(decl) {
523                    return ReturnKind::Variable(decl.clone());
524                }
525            }
526        }
527    }
528
529    ReturnKind::Void
530}
531
532/// Find an explicit `return` statement in the byte range and return its expression text.
533fn find_return_in_range(
534    node: &Node,
535    source: &str,
536    start_byte: usize,
537    end_byte: usize,
538) -> Option<String> {
539    if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
540        return None;
541    }
542
543    if node.kind() == "return_statement"
544        && node.start_byte() >= start_byte
545        && node.end_byte() <= end_byte
546    {
547        // Get the expression after "return"
548        let text = node_text(source, node).trim().to_string();
549        let expr = text
550            .strip_prefix("return")
551            .unwrap_or("")
552            .trim()
553            .trim_end_matches(';')
554            .trim()
555            .to_string();
556        if !expr.is_empty() {
557            return Some(expr);
558        }
559    }
560
561    let mut cursor = node.walk();
562    if cursor.goto_first_child() {
563        loop {
564            if let Some(result) = find_return_in_range(&cursor.node(), source, start_byte, end_byte)
565            {
566                return Some(result);
567            }
568            if !cursor.goto_next_sibling() {
569                break;
570            }
571        }
572    }
573
574    None
575}
576
577// ---------------------------------------------------------------------------
578// Function generation
579// ---------------------------------------------------------------------------
580
581/// Generate the text for an extracted function.
582pub fn generate_extracted_function(
583    name: &str,
584    params: &[String],
585    return_kind: &ReturnKind,
586    body_text: &str,
587    base_indent: &str,
588    lang: LangId,
589    indent_style: IndentStyle,
590) -> String {
591    let indent_unit = indent_style.as_str();
592
593    match lang {
594        LangId::TypeScript | LangId::Tsx | LangId::JavaScript => generate_ts_function(
595            name,
596            params,
597            return_kind,
598            body_text,
599            base_indent,
600            indent_unit,
601        ),
602        LangId::Python => generate_py_function(
603            name,
604            params,
605            return_kind,
606            body_text,
607            base_indent,
608            indent_unit,
609        ),
610        _ => {
611            // Shouldn't reach here due to language guard, but produce something reasonable
612            generate_ts_function(
613                name,
614                params,
615                return_kind,
616                body_text,
617                base_indent,
618                indent_unit,
619            )
620        }
621    }
622}
623
624fn generate_ts_function(
625    name: &str,
626    params: &[String],
627    return_kind: &ReturnKind,
628    body_text: &str,
629    base_indent: &str,
630    indent_unit: &str,
631) -> String {
632    let params_str = params.join(", ");
633    let mut lines = Vec::new();
634
635    lines.push(format!(
636        "{}function {}({}) {{",
637        base_indent, name, params_str
638    ));
639
640    // Re-indent body to be inside the function
641    for line in body_text.lines() {
642        if line.trim().is_empty() {
643            lines.push(String::new());
644        } else {
645            lines.push(format!("{}{}{}", base_indent, indent_unit, line.trim()));
646        }
647    }
648
649    // Add return statement if needed
650    match return_kind {
651        ReturnKind::Variable(var) => {
652            lines.push(format!("{}{}return {};", base_indent, indent_unit, var));
653        }
654        ReturnKind::Expression(_) => {
655            // The return is already in the body text
656        }
657        ReturnKind::Void => {}
658    }
659
660    lines.push(format!("{}}}", base_indent));
661    lines.join("\n")
662}
663
664fn generate_py_function(
665    name: &str,
666    params: &[String],
667    return_kind: &ReturnKind,
668    body_text: &str,
669    base_indent: &str,
670    indent_unit: &str,
671) -> String {
672    let params_str = params.join(", ");
673    let mut lines = Vec::new();
674
675    lines.push(format!("{}def {}({}):", base_indent, name, params_str));
676
677    // Re-indent body
678    for line in body_text.lines() {
679        if line.trim().is_empty() {
680            lines.push(String::new());
681        } else {
682            lines.push(format!("{}{}{}", base_indent, indent_unit, line.trim()));
683        }
684    }
685
686    // Add return statement if needed
687    match return_kind {
688        ReturnKind::Variable(var) => {
689            lines.push(format!("{}{}return {}", base_indent, indent_unit, var));
690        }
691        ReturnKind::Expression(_) => {
692            // Already in body
693        }
694        ReturnKind::Void => {}
695    }
696
697    lines.join("\n")
698}
699
700/// Generate the call site text that replaces the extracted range.
701pub fn generate_call_site(
702    name: &str,
703    params: &[String],
704    return_kind: &ReturnKind,
705    indent: &str,
706    lang: LangId,
707) -> String {
708    let args_str = params.join(", ");
709
710    match return_kind {
711        ReturnKind::Variable(var) => match lang {
712            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
713                format!("{}const {} = {}({});", indent, var, name, args_str)
714            }
715            LangId::Python => {
716                format!("{}{} = {}({})", indent, var, name, args_str)
717            }
718            _ => format!("{}const {} = {}({});", indent, var, name, args_str),
719        },
720        ReturnKind::Expression(_expr) => match lang {
721            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
722                format!("{}return {}({});", indent, name, args_str)
723            }
724            LangId::Python => {
725                format!("{}return {}({})", indent, name, args_str)
726            }
727            _ => format!("{}return {}({});", indent, name, args_str),
728        },
729        ReturnKind::Void => match lang {
730            LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
731                format!("{}{}({});", indent, name, args_str)
732            }
733            LangId::Python => {
734                format!("{}{}({})", indent, name, args_str)
735            }
736            _ => format!("{}{}({});", indent, name, args_str),
737        },
738    }
739}
740
741// ---------------------------------------------------------------------------
742// Inline symbol utilities
743// ---------------------------------------------------------------------------
744
/// A name clash found when a function body is inlined at a call site: the
/// body declares a variable that the call site's scope already declares.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ScopeConflict {
    /// The clashing variable name.
    pub name: String,
    /// A replacement name proposed to avoid the clash.
    pub suggested: String,
}
753
754/// Detect scope conflicts between the call site scope and the function body
755/// being inlined.
756///
757/// Collects all variable declarations at the call site's scope level
758/// (surrounding function body), then checks for collisions with variables
759/// declared in `body_text`.
760pub fn detect_scope_conflicts(
761    source: &str,
762    tree: &Tree,
763    insertion_byte: usize,
764    _param_names: &[String],
765    body_text: &str,
766    lang: LangId,
767) -> Vec<ScopeConflict> {
768    let root = tree.root_node();
769
770    // 1. Find the enclosing function at the call site
771    let enclosing_fn = find_enclosing_function(&root, insertion_byte, lang);
772
773    // 2. Collect all declarations in the call site's scope
774    let mut scope_decls: HashSet<String> = HashSet::new();
775    if let Some(fn_node) = enclosing_fn {
776        collect_declarations_in_range(
777            &fn_node,
778            source,
779            fn_node.start_byte(),
780            fn_node.end_byte(),
781            lang,
782            &mut scope_decls,
783        );
784        collect_function_params(&fn_node, source, lang, &mut scope_decls);
785    } else {
786        // Module-level: collect all top-level declarations
787        collect_declarations_in_range(
788            &root,
789            source,
790            root.start_byte(),
791            root.end_byte(),
792            lang,
793            &mut scope_decls,
794        );
795    }
796
797    // 3. Collect declarations in the body being inlined
798    let mut body_decls: HashSet<String> = HashSet::new();
799    let body_grammar = grammar_for(lang);
800    let mut body_parser = tree_sitter::Parser::new();
801    if body_parser.set_language(&body_grammar).is_ok() {
802        if let Some(body_tree) = body_parser.parse(body_text.as_bytes(), None) {
803            let body_root = body_tree.root_node();
804            collect_declarations_in_range(
805                &body_root,
806                body_text,
807                0,
808                body_text.len(),
809                lang,
810                &mut body_decls,
811            );
812        }
813    }
814
815    // 4. Find collisions
816    let mut conflicts = Vec::new();
817    for decl in &body_decls {
818        if scope_decls.contains(decl) {
819            conflicts.push(ScopeConflict {
820                name: decl.clone(),
821                suggested: format!("{}_inlined", decl),
822            });
823        }
824    }
825
826    // Sort for deterministic output
827    conflicts.sort_by(|a, b| a.name.cmp(&b.name));
828    conflicts
829}
830
831/// Validate that a function has at most one return statement (suitable for inlining).
832///
833/// - Arrow functions with expression bodies (no `return` keyword) → valid (single-return)
834/// - Functions with 0 returns (void) → valid
835/// - Functions with exactly 1 return → valid
836/// - Functions with >1 return → invalid, returns the count
837pub fn validate_single_return(
838    source: &str,
839    _tree: &Tree,
840    fn_node: &Node,
841    lang: LangId,
842) -> Result<(), usize> {
843    // Arrow functions with expression bodies are always single-return
844    if lang != LangId::Python && fn_node.kind() == "arrow_function" {
845        if let Some(body) = fn_node.child_by_field_name("body") {
846            if body.kind() != "statement_block" {
847                // Expression body — implicitly single-return
848                return Ok(());
849            }
850        }
851    }
852
853    let count = count_return_statements(fn_node, source);
854    if count > 1 {
855        Err(count)
856    } else {
857        Ok(())
858    }
859}
860
861/// Count `return_statement` nodes in a function body (non-recursive into nested functions).
862fn count_return_statements(node: &Node, source: &str) -> usize {
863    let _ = source;
864    let mut count = 0;
865
866    // Don't count returns in nested function bodies
867    let nested_fn_kinds = [
868        "function_declaration",
869        "function_definition",
870        "arrow_function",
871        "method_definition",
872    ];
873
874    let kind = node.kind();
875    if kind == "return_statement" {
876        return 1;
877    }
878
879    let child_count = node.child_count();
880    for i in 0..child_count {
881        if let Some(child) = node.child(i as u32) {
882            // Skip nested function definitions
883            if nested_fn_kinds.contains(&child.kind()) {
884                continue;
885            }
886            count += count_return_statements(&child, source);
887        }
888    }
889
890    count
891}
892
893/// Substitute parameter names with argument expressions in a function body.
894///
895/// Uses tree-sitter to find `identifier` nodes matching parameter names,
896/// replacing from end to start to preserve byte offsets. Only replaces
897/// whole-word matches (identifiers, not substrings).
898pub fn substitute_params(
899    body_text: &str,
900    param_to_arg: &std::collections::HashMap<String, String>,
901    lang: LangId,
902) -> String {
903    if param_to_arg.is_empty() {
904        return body_text.to_string();
905    }
906
907    let grammar = grammar_for(lang);
908    let mut parser = tree_sitter::Parser::new();
909    if parser.set_language(&grammar).is_err() {
910        return body_text.to_string();
911    }
912
913    let tree = match parser.parse(body_text.as_bytes(), None) {
914        Some(t) => t,
915        None => return body_text.to_string(),
916    };
917
918    // Collect all identifier nodes that match parameter names
919    let mut replacements: Vec<(usize, usize, String)> = Vec::new();
920    collect_param_replacements(
921        &tree.root_node(),
922        body_text,
923        param_to_arg,
924        lang,
925        &mut replacements,
926    );
927
928    // Sort by start position descending so replacements don't shift offsets
929    replacements.sort_by(|a, b| b.0.cmp(&a.0));
930
931    let mut result = body_text.to_string();
932    for (start, end, replacement) in replacements {
933        result = format!("{}{}{}", &result[..start], replacement, &result[end..]);
934    }
935
936    result
937}
938
939/// Collect identifier nodes that match parameter names for substitution.
940fn collect_param_replacements(
941    node: &Node,
942    source: &str,
943    param_to_arg: &std::collections::HashMap<String, String>,
944    lang: LangId,
945    out: &mut Vec<(usize, usize, String)>,
946) {
947    let kind = node.kind();
948
949    if kind == "identifier" {
950        // Check it's not a property access
951        if !is_property_access(node, lang) {
952            let name = node_text(source, node);
953            if let Some(replacement) = param_to_arg.get(name) {
954                out.push((node.start_byte(), node.end_byte(), replacement.clone()));
955            }
956        }
957    }
958
959    // Also handle Python-specific name node
960    // Recurse into children
961    let child_count = node.child_count();
962    for i in 0..child_count {
963        if let Some(child) = node.child(i as u32) {
964            collect_param_replacements(&child, source, param_to_arg, lang, out);
965        }
966    }
967}
968
969// ---------------------------------------------------------------------------
970// Tests
971// ---------------------------------------------------------------------------
972
973#[cfg(test)]
974mod tests {
975    use super::*;
976    use crate::parser::grammar_for;
977    use std::path::PathBuf;
978    use tree_sitter::Parser;
979
980    fn fixture_path(name: &str) -> PathBuf {
981        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
982            .join("tests")
983            .join("fixtures")
984            .join("extract_function")
985            .join(name)
986    }
987
988    fn parse_source(source: &str, lang: LangId) -> Tree {
989        let grammar = grammar_for(lang);
990        let mut parser = Parser::new();
991        parser.set_language(&grammar).unwrap();
992        parser.parse(source.as_bytes(), None).unwrap()
993    }
994
995    // --- Free variable detection: simple identifiers ---
996
997    #[test]
998    fn free_vars_detects_enclosing_function_params() {
999        // `items` and `prefix` are function params → should be detected as free variables
1000        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1001        let tree = parse_source(&source, LangId::TypeScript);
1002
1003        // Lines 5-8 (0-indexed): the body of processData that uses `items` and `prefix`
1004        // "  const filtered = items.filter(item => item.length > 0);"
1005        // "  const mapped = filtered.map(item => prefix + item);"
1006        // These lines reference `items` and `prefix` from the function params.
1007        let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
1008        let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
1009
1010        let result =
1011            detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
1012        assert!(
1013            result.parameters.contains(&"items".to_string()),
1014            "should detect 'items' as parameter, got: {:?}",
1015            result.parameters
1016        );
1017        assert!(
1018            result.parameters.contains(&"prefix".to_string()),
1019            "should detect 'prefix' as parameter, got: {:?}",
1020            result.parameters
1021        );
1022        assert!(!result.has_this_or_self);
1023    }
1024
1025    // --- Property access filtering ---
1026
1027    #[test]
1028    fn free_vars_filters_property_identifiers() {
1029        // In `items.filter(...)`, `filter` should NOT be a free variable.
1030        // In `item.length`, `length` should NOT be a free variable.
1031        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1032        let tree = parse_source(&source, LangId::TypeScript);
1033
1034        let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
1035        let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
1036
1037        let result =
1038            detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
1039        // "filter", "map", "length" should NOT appear
1040        assert!(
1041            !result.parameters.contains(&"filter".to_string()),
1042            "property 'filter' should not be a free variable"
1043        );
1044        assert!(
1045            !result.parameters.contains(&"length".to_string()),
1046            "property 'length' should not be a free variable"
1047        );
1048        assert!(
1049            !result.parameters.contains(&"map".to_string()),
1050            "property 'map' should not be a free variable"
1051        );
1052    }
1053
1054    // --- Module-level vs function-level classification ---
1055
1056    #[test]
1057    fn free_vars_skips_module_level_refs() {
1058        // `BASE_URL` is module-level → should NOT be a parameter
1059        // `console` is a global → should NOT be a parameter
1060        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1061        let tree = parse_source(&source, LangId::TypeScript);
1062
1063        // processData body: lines 5-9
1064        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1065        let end = crate::edit::line_col_to_byte(&source, 10, 0);
1066
1067        let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
1068        assert!(
1069            !result.parameters.contains(&"BASE_URL".to_string()),
1070            "module-level 'BASE_URL' should not be a parameter, got: {:?}",
1071            result.parameters
1072        );
1073        assert!(
1074            !result.parameters.contains(&"console".to_string()),
1075            "'console' should not be a parameter, got: {:?}",
1076            result.parameters
1077        );
1078    }
1079
1080    // --- this/self detection ---
1081
1082    #[test]
1083    fn free_vars_detects_this_in_ts() {
1084        let source = std::fs::read_to_string(fixture_path("sample_this.ts")).unwrap();
1085        let tree = parse_source(&source, LangId::TypeScript);
1086
1087        // getUser method body lines 4-6 contain `this.users.get(key)`
1088        let start = crate::edit::line_col_to_byte(&source, 4, 0);
1089        let end = crate::edit::line_col_to_byte(&source, 7, 0);
1090
1091        let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
1092        assert!(result.has_this_or_self, "should detect 'this' reference");
1093    }
1094
1095    #[test]
1096    fn free_vars_detects_self_in_python() {
1097        let source = r#"
1098class UserService:
1099    def get_user(self, id):
1100        key = id.lower()
1101        user = self.users.get(key)
1102        return user
1103"#;
1104        let tree = parse_source(source, LangId::Python);
1105
1106        // Lines 3-4 (0-indexed) contain `self.users.get(key)`
1107        let start = crate::edit::line_col_to_byte(source, 4, 0);
1108        let end = crate::edit::line_col_to_byte(source, 5, 0);
1109
1110        let result = detect_free_variables(source, &tree, start, end, LangId::Python);
1111        assert!(result.has_this_or_self, "should detect 'self' reference");
1112    }
1113
1114    // --- Return value detection ---
1115
1116    #[test]
1117    fn return_value_explicit_return() {
1118        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1119        let tree = parse_source(&source, LangId::TypeScript);
1120
1121        // simpleHelper: lines 13-16 — contains "return added;"
1122        let start = crate::edit::line_col_to_byte(&source, 14, 0);
1123        let end = crate::edit::line_col_to_byte(&source, 17, 0);
1124
1125        let result = detect_return_value(&source, &tree, start, end, None, LangId::TypeScript);
1126        assert_eq!(result, ReturnKind::Expression("added".to_string()));
1127    }
1128
1129    #[test]
1130    fn return_value_post_range_usage() {
1131        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1132        let tree = parse_source(&source, LangId::TypeScript);
1133
1134        // processData lines 5-7: declares `filtered`, `mapped`
1135        // Lines 7-9 use `result` (which comes after mapped), but line 7 declares `result`
1136        // Let's extract lines 5-6 only: `filtered` and `mapped` are declared
1137        // and `filtered` is used on line 6, `mapped` is used on line 7
1138        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1139        let end = crate::edit::line_col_to_byte(&source, 6, 0);
1140
1141        // Enclosing function ends around line 10
1142        let fn_end = crate::edit::line_col_to_byte(&source, 10, 0);
1143
1144        let result =
1145            detect_return_value(&source, &tree, start, end, Some(fn_end), LangId::TypeScript);
1146        // `filtered` is declared in-range and used after the range
1147        assert_eq!(result, ReturnKind::Variable("filtered".to_string()));
1148    }
1149
1150    #[test]
1151    fn return_value_void() {
1152        let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
1153        let tree = parse_source(&source, LangId::TypeScript);
1154
1155        // voidWork lines 20-21: no return, `greeting` is only used within
1156        let start = crate::edit::line_col_to_byte(&source, 20, 0);
1157        let end = crate::edit::line_col_to_byte(&source, 22, 0);
1158
1159        let result = detect_return_value(
1160            &source,
1161            &tree,
1162            start,
1163            end,
1164            Some(crate::edit::line_col_to_byte(&source, 23, 0)),
1165            LangId::TypeScript,
1166        );
1167        assert_eq!(result, ReturnKind::Void);
1168    }
1169
1170    // --- Function generation ---
1171
1172    #[test]
1173    fn generate_ts_function_with_params() {
1174        let body = "const doubled = x * 2;\nconst added = doubled + 10;";
1175        let result = generate_extracted_function(
1176            "compute",
1177            &["x".to_string()],
1178            &ReturnKind::Variable("added".to_string()),
1179            body,
1180            "",
1181            LangId::TypeScript,
1182            IndentStyle::Spaces(2),
1183        );
1184        assert!(result.contains("function compute(x)"));
1185        assert!(result.contains("return added;"));
1186        assert!(result.contains("}"));
1187    }
1188
1189    #[test]
1190    fn generate_py_function_with_params() {
1191        let body = "doubled = x * 2\nadded = doubled + 10";
1192        let result = generate_extracted_function(
1193            "compute",
1194            &["x".to_string()],
1195            &ReturnKind::Variable("added".to_string()),
1196            body,
1197            "",
1198            LangId::Python,
1199            IndentStyle::Spaces(4),
1200        );
1201        assert!(result.contains("def compute(x):"));
1202        assert!(result.contains("return added"));
1203    }
1204
1205    #[test]
1206    fn generate_call_site_with_return_var() {
1207        let call = generate_call_site(
1208            "compute",
1209            &["x".to_string()],
1210            &ReturnKind::Variable("result".to_string()),
1211            "  ",
1212            LangId::TypeScript,
1213        );
1214        assert_eq!(call, "  const result = compute(x);");
1215    }
1216
1217    #[test]
1218    fn generate_call_site_void() {
1219        let call = generate_call_site(
1220            "doWork",
1221            &["a".to_string(), "b".to_string()],
1222            &ReturnKind::Void,
1223            "  ",
1224            LangId::TypeScript,
1225        );
1226        assert_eq!(call, "  doWork(a, b);");
1227    }
1228
1229    #[test]
1230    fn generate_call_site_return_expression() {
1231        let call = generate_call_site(
1232            "compute",
1233            &["x".to_string()],
1234            &ReturnKind::Expression("x * 2".to_string()),
1235            "  ",
1236            LangId::TypeScript,
1237        );
1238        assert_eq!(call, "  return compute(x);");
1239    }
1240
1241    // --- Python free variables ---
1242
1243    #[test]
1244    fn free_vars_python_function_params() {
1245        let source = std::fs::read_to_string(fixture_path("sample.py")).unwrap();
1246        let tree = parse_source(&source, LangId::Python);
1247
1248        // process_data body: lines 5-8 reference `items` and `prefix`
1249        let start = crate::edit::line_col_to_byte(&source, 5, 0);
1250        let end = crate::edit::line_col_to_byte(&source, 7, 0);
1251
1252        let result = detect_free_variables(&source, &tree, start, end, LangId::Python);
1253        assert!(
1254            result.parameters.contains(&"items".to_string()),
1255            "should detect 'items': {:?}",
1256            result.parameters
1257        );
1258        assert!(
1259            result.parameters.contains(&"prefix".to_string()),
1260            "should detect 'prefix': {:?}",
1261            result.parameters
1262        );
1263        assert!(!result.has_this_or_self);
1264    }
1265
1266    // --- validate_single_return ---
1267
1268    #[test]
1269    fn validate_single_return_single() {
1270        let source =
1271            "function add(a: number, b: number): number {\n  const sum = a + b;\n  return sum;\n}";
1272        let tree = parse_source(source, LangId::TypeScript);
1273        let root = tree.root_node();
1274        let fn_node = root.child(0).unwrap(); // function_declaration
1275        assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
1276    }
1277
1278    #[test]
1279    fn validate_single_return_void() {
1280        let source = "function greet(name: string): void {\n  console.log(name);\n}";
1281        let tree = parse_source(source, LangId::TypeScript);
1282        let root = tree.root_node();
1283        let fn_node = root.child(0).unwrap();
1284        assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
1285    }
1286
1287    #[test]
1288    fn validate_single_return_expression_body() {
1289        let source = "const double = (n: number): number => n * 2;";
1290        let tree = parse_source(source, LangId::TypeScript);
1291        let root = tree.root_node();
1292        // lexical_declaration > variable_declarator > arrow_function
1293        let lex_decl = root.child(0).unwrap();
1294        let var_decl = lex_decl.child(1).unwrap(); // variable_declarator
1295        let arrow = var_decl.child_by_field_name("value").unwrap();
1296        assert_eq!(arrow.kind(), "arrow_function");
1297        assert!(validate_single_return(source, &tree, &arrow, LangId::TypeScript).is_ok());
1298    }
1299
1300    #[test]
1301    fn validate_single_return_multiple() {
1302        let source = "function abs(x: number): number {\n  if (x > 0) {\n    return x;\n  }\n  return -x;\n}";
1303        let tree = parse_source(source, LangId::TypeScript);
1304        let root = tree.root_node();
1305        let fn_node = root.child(0).unwrap();
1306        let result = validate_single_return(source, &tree, &fn_node, LangId::TypeScript);
1307        assert!(result.is_err());
1308        assert_eq!(result.unwrap_err(), 2);
1309    }
1310
1311    // --- detect_scope_conflicts ---
1312
1313    #[test]
1314    fn scope_conflicts_none() {
1315        // No overlap between call site scope and body vars
1316        let source = "function main() {\n  const x = 10;\n  const y = add(x, 5);\n}";
1317        let tree = parse_source(source, LangId::TypeScript);
1318        let body_text = "const sum = a + b;";
1319        let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
1320        let conflicts =
1321            detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
1322        assert!(
1323            conflicts.is_empty(),
1324            "expected no conflicts, got: {:?}",
1325            conflicts
1326        );
1327    }
1328
1329    #[test]
1330    fn scope_conflicts_detected() {
1331        // `temp` exists at call site and inside body
1332        let source = "function main() {\n  const temp = 99;\n  const result = compute(5);\n}";
1333        let tree = parse_source(source, LangId::TypeScript);
1334        let body_text = "const temp = x * 2;\nconst result2 = temp + 10;";
1335        let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
1336        let conflicts =
1337            detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
1338        assert!(!conflicts.is_empty(), "expected conflict for 'temp'");
1339        assert!(
1340            conflicts.iter().any(|c| c.name == "temp"),
1341            "conflicts: {:?}",
1342            conflicts
1343        );
1344        assert!(
1345            conflicts.iter().any(|c| c.suggested == "temp_inlined"),
1346            "should suggest temp_inlined"
1347        );
1348    }
1349
1350    // --- substitute_params ---
1351
1352    #[test]
1353    fn substitute_params_basic() {
1354        let body = "const sum = a + b;";
1355        let mut map = std::collections::HashMap::new();
1356        map.insert("a".to_string(), "x".to_string());
1357        map.insert("b".to_string(), "y".to_string());
1358        let result = substitute_params(body, &map, LangId::TypeScript);
1359        assert_eq!(result, "const sum = x + y;");
1360    }
1361
1362    #[test]
1363    fn substitute_params_whole_word() {
1364        // Should NOT replace `i` inside `items`
1365        let body = "const result = items.filter(i => i > 0);";
1366        let mut map = std::collections::HashMap::new();
1367        map.insert("i".to_string(), "index".to_string());
1368        let result = substitute_params(body, &map, LangId::TypeScript);
1369        // `items` should be untouched, but the arrow param `i` and its use `i` should be replaced
1370        assert!(
1371            !result.contains("items") || result.contains("items"),
1372            "items should be preserved"
1373        );
1374        // The `i` in `i => i > 0` should be replaced
1375        assert!(
1376            result.contains("index"),
1377            "should contain 'index': {}",
1378            result
1379        );
1380    }
1381
1382    #[test]
1383    fn substitute_params_noop_same_name() {
1384        let body = "const sum = x + y;";
1385        let mut map = std::collections::HashMap::new();
1386        map.insert("x".to_string(), "x".to_string());
1387        let result = substitute_params(body, &map, LangId::TypeScript);
1388        assert_eq!(result, "const sum = x + y;");
1389    }
1390
1391    #[test]
1392    fn substitute_params_empty_map() {
1393        let body = "const sum = a + b;";
1394        let map = std::collections::HashMap::new();
1395        let result = substitute_params(body, &map, LangId::TypeScript);
1396        assert_eq!(result, body);
1397    }
1398}