use std::collections::HashSet;
use tree_sitter::{Node, Tree};
use crate::indent::IndentStyle;
use crate::parser::{grammar_for, node_text, LangId};
#[derive(Debug)]
pub struct FreeVariableResult {
pub parameters: Vec<String>,
pub has_this_or_self: bool,
}
pub fn detect_free_variables(
source: &str,
tree: &Tree,
start_byte: usize,
end_byte: usize,
lang: LangId,
) -> FreeVariableResult {
let root = tree.root_node();
let mut references: Vec<String> = Vec::new();
collect_identifier_refs(&root, source, start_byte, end_byte, lang, &mut references);
let mut local_decls: HashSet<String> = HashSet::new();
collect_declarations_in_range(&root, source, start_byte, end_byte, lang, &mut local_decls);
let enclosing_fn = find_enclosing_function(&root, start_byte, lang);
let mut enclosing_decls: HashSet<String> = HashSet::new();
if let Some(fn_node) = enclosing_fn {
collect_declarations_in_range(
&fn_node,
source,
fn_node.start_byte(),
start_byte, lang,
&mut enclosing_decls,
);
collect_function_params(&fn_node, source, lang, &mut enclosing_decls);
}
let has_this_or_self = check_this_or_self(&root, source, start_byte, end_byte, lang);
let mut seen = HashSet::new();
let mut parameters = Vec::new();
for name in &references {
if local_decls.contains(name) {
continue;
}
if !seen.insert(name.clone()) {
continue; }
if enclosing_decls.contains(name) {
parameters.push(name.clone());
}
}
FreeVariableResult {
parameters,
has_this_or_self,
}
}
fn collect_identifier_refs(
node: &Node,
source: &str,
start_byte: usize,
end_byte: usize,
lang: LangId,
out: &mut Vec<String>,
) {
if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
return;
}
let kind = node.kind();
if kind == "identifier" && node.start_byte() >= start_byte && node.end_byte() <= end_byte {
if !is_property_access(node, lang) {
let name = node_text(source, node).to_string();
if !is_keyword(&name, lang) {
out.push(name);
}
}
}
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
collect_identifier_refs(&cursor.node(), source, start_byte, end_byte, lang, out);
if !cursor.goto_next_sibling() {
break;
}
}
}
}
fn is_property_access(node: &Node, lang: LangId) -> bool {
if let Some(parent) = node.parent() {
let pk = parent.kind();
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if pk == "member_expression" {
if let Some(prop) = parent.child_by_field_name("property") {
return prop.id() == node.id();
}
}
}
LangId::Python => {
if pk == "attribute" {
if let Some(attr) = parent.child_by_field_name("attribute") {
return attr.id() == node.id();
}
}
}
_ => {}
}
}
false
}
fn is_keyword(name: &str, lang: LangId) -> bool {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => matches!(
name,
"undefined" | "null" | "true" | "false" | "NaN" | "Infinity" | "console" | "require"
),
LangId::Python => matches!(
name,
"None"
| "True"
| "False"
| "print"
| "len"
| "range"
| "str"
| "int"
| "float"
| "list"
| "dict"
| "set"
| "tuple"
| "type"
| "super"
| "isinstance"
| "enumerate"
| "zip"
| "map"
| "filter"
| "sorted"
| "reversed"
| "any"
| "all"
| "min"
| "max"
| "sum"
| "abs"
| "open"
| "input"
| "format"
| "hasattr"
| "getattr"
| "setattr"
| "delattr"
| "repr"
| "iter"
| "next"
| "ValueError"
| "TypeError"
| "KeyError"
| "IndexError"
| "Exception"
| "RuntimeError"
| "StopIteration"
| "NotImplementedError"
| "AttributeError"
| "ImportError"
| "OSError"
| "IOError"
| "FileNotFoundError"
),
_ => false,
}
}
fn collect_declarations_in_range(
node: &Node,
source: &str,
start_byte: usize,
end_byte: usize,
lang: LangId,
out: &mut HashSet<String>,
) {
if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
return;
}
let kind = node.kind();
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if kind == "variable_declarator" {
if let Some(name_node) = node.child_by_field_name("name") {
if name_node.start_byte() >= start_byte && name_node.end_byte() <= end_byte {
out.insert(node_text(source, &name_node).to_string());
}
}
}
}
LangId::Python => {
if kind == "assignment" {
if let Some(left) = node.child_by_field_name("left") {
if left.kind() == "identifier"
&& left.start_byte() >= start_byte
&& left.end_byte() <= end_byte
{
out.insert(node_text(source, &left).to_string());
}
}
}
}
_ => {}
}
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
collect_declarations_in_range(&cursor.node(), source, start_byte, end_byte, lang, out);
if !cursor.goto_next_sibling() {
break;
}
}
}
}
fn collect_function_params(fn_node: &Node, source: &str, lang: LangId, out: &mut HashSet<String>) {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if let Some(params) = fn_node.child_by_field_name("parameters") {
collect_param_identifiers(¶ms, source, lang, out);
}
let mut cursor = fn_node.walk();
if cursor.goto_first_child() {
loop {
let child = cursor.node();
if child.kind() == "variable_declarator" {
if let Some(value) = child.child_by_field_name("value") {
if value.kind() == "arrow_function" {
if let Some(params) = value.child_by_field_name("parameters") {
collect_param_identifiers(¶ms, source, lang, out);
}
}
}
}
if !cursor.goto_next_sibling() {
break;
}
}
}
}
LangId::Python => {
if let Some(params) = fn_node.child_by_field_name("parameters") {
collect_param_identifiers(¶ms, source, lang, out);
}
}
_ => {}
}
}
fn collect_param_identifiers(
params_node: &Node,
source: &str,
lang: LangId,
out: &mut HashSet<String>,
) {
let mut cursor = params_node.walk();
if cursor.goto_first_child() {
loop {
let child = cursor.node();
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if child.kind() == "required_parameter" || child.kind() == "optional_parameter"
{
if let Some(pattern) = child.child_by_field_name("pattern") {
if pattern.kind() == "identifier" {
out.insert(node_text(source, &pattern).to_string());
}
}
} else if child.kind() == "identifier" {
out.insert(node_text(source, &child).to_string());
}
}
LangId::Python => {
if child.kind() == "identifier" {
let name = node_text(source, &child).to_string();
if name != "self" {
out.insert(name);
}
}
}
_ => {}
}
if !cursor.goto_next_sibling() {
break;
}
}
}
}
fn find_enclosing_function<'a>(root: &Node<'a>, byte_pos: usize, lang: LangId) -> Option<Node<'a>> {
let fn_kinds: &[&str] = match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
&[
"function_declaration",
"method_definition",
"arrow_function",
"lexical_declaration", ]
}
LangId::Python => &["function_definition"],
_ => &[],
};
find_deepest_ancestor(root, byte_pos, fn_kinds)
}
fn find_deepest_ancestor<'a>(node: &Node<'a>, byte_pos: usize, kinds: &[&str]) -> Option<Node<'a>> {
let mut result: Option<Node<'a>> = None;
if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
result = Some(*node);
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
if let Some(deeper) = find_deepest_ancestor(&child, byte_pos, kinds) {
result = Some(deeper);
}
}
}
}
result
}
fn check_this_or_self(
node: &Node,
source: &str,
start_byte: usize,
end_byte: usize,
lang: LangId,
) -> bool {
if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
return false;
}
if node.start_byte() >= start_byte && node.end_byte() <= end_byte {
let kind = node.kind();
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if kind == "this" {
return true;
}
}
LangId::Python => {
if kind == "identifier" && node_text(source, node) == "self" {
if let Some(parent) = node.parent() {
if parent.kind() == "parameters" {
return false;
}
}
return true;
}
}
_ => {}
}
}
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
if check_this_or_self(&cursor.node(), source, start_byte, end_byte, lang) {
return true;
}
if !cursor.goto_next_sibling() {
break;
}
}
}
false
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ReturnKind {
Expression(String),
Variable(String),
Void,
}
pub fn detect_return_value(
source: &str,
tree: &Tree,
start_byte: usize,
end_byte: usize,
enclosing_fn_end_byte: Option<usize>,
lang: LangId,
) -> ReturnKind {
let root = tree.root_node();
if let Some(expr) = find_return_in_range(&root, source, start_byte, end_byte) {
return ReturnKind::Expression(expr);
}
let mut in_range_decls: HashSet<String> = HashSet::new();
collect_declarations_in_range(
&root,
source,
start_byte,
end_byte,
lang,
&mut in_range_decls,
);
if let Some(fn_end) = enclosing_fn_end_byte {
let post_range_end = fn_end.min(source.len());
if end_byte < post_range_end {
let mut post_refs: Vec<String> = Vec::new();
collect_identifier_refs(
&root,
source,
end_byte,
post_range_end,
lang,
&mut post_refs,
);
for decl in &in_range_decls {
if post_refs.contains(decl) {
return ReturnKind::Variable(decl.clone());
}
}
}
}
ReturnKind::Void
}
fn find_return_in_range(
node: &Node,
source: &str,
start_byte: usize,
end_byte: usize,
) -> Option<String> {
if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
return None;
}
if node.kind() == "return_statement"
&& node.start_byte() >= start_byte
&& node.end_byte() <= end_byte
{
let text = node_text(source, node).trim().to_string();
let expr = text
.strip_prefix("return")
.unwrap_or("")
.trim()
.trim_end_matches(';')
.trim()
.to_string();
if !expr.is_empty() {
return Some(expr);
}
}
let mut cursor = node.walk();
if cursor.goto_first_child() {
loop {
if let Some(result) = find_return_in_range(&cursor.node(), source, start_byte, end_byte)
{
return Some(result);
}
if !cursor.goto_next_sibling() {
break;
}
}
}
None
}
pub fn generate_extracted_function(
name: &str,
params: &[String],
return_kind: &ReturnKind,
body_text: &str,
base_indent: &str,
lang: LangId,
indent_style: IndentStyle,
) -> String {
let indent_unit = indent_style.as_str();
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => generate_ts_function(
name,
params,
return_kind,
body_text,
base_indent,
indent_unit,
),
LangId::Python => generate_py_function(
name,
params,
return_kind,
body_text,
base_indent,
indent_unit,
),
_ => {
generate_ts_function(
name,
params,
return_kind,
body_text,
base_indent,
indent_unit,
)
}
}
}
fn generate_ts_function(
name: &str,
params: &[String],
return_kind: &ReturnKind,
body_text: &str,
base_indent: &str,
indent_unit: &str,
) -> String {
let params_str = params.join(", ");
let mut lines = Vec::new();
lines.push(format!(
"{}function {}({}) {{",
base_indent, name, params_str
));
for line in body_text.lines() {
if line.trim().is_empty() {
lines.push(String::new());
} else {
lines.push(format!("{}{}{}", base_indent, indent_unit, line.trim()));
}
}
match return_kind {
ReturnKind::Variable(var) => {
lines.push(format!("{}{}return {};", base_indent, indent_unit, var));
}
ReturnKind::Expression(_) => {
}
ReturnKind::Void => {}
}
lines.push(format!("{}}}", base_indent));
lines.join("\n")
}
fn generate_py_function(
name: &str,
params: &[String],
return_kind: &ReturnKind,
body_text: &str,
base_indent: &str,
indent_unit: &str,
) -> String {
let params_str = params.join(", ");
let mut lines = Vec::new();
lines.push(format!("{}def {}({}):", base_indent, name, params_str));
for line in body_text.lines() {
if line.trim().is_empty() {
lines.push(String::new());
} else {
lines.push(format!("{}{}{}", base_indent, indent_unit, line.trim()));
}
}
match return_kind {
ReturnKind::Variable(var) => {
lines.push(format!("{}{}return {}", base_indent, indent_unit, var));
}
ReturnKind::Expression(_) => {
}
ReturnKind::Void => {}
}
lines.join("\n")
}
pub fn generate_call_site(
name: &str,
params: &[String],
return_kind: &ReturnKind,
indent: &str,
lang: LangId,
) -> String {
let args_str = params.join(", ");
match return_kind {
ReturnKind::Variable(var) => match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
format!("{}const {} = {}({});", indent, var, name, args_str)
}
LangId::Python => {
format!("{}{} = {}({})", indent, var, name, args_str)
}
_ => format!("{}const {} = {}({});", indent, var, name, args_str),
},
ReturnKind::Expression(_expr) => match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
format!("{}return {}({});", indent, name, args_str)
}
LangId::Python => {
format!("{}return {}({})", indent, name, args_str)
}
_ => format!("{}return {}({});", indent, name, args_str),
},
ReturnKind::Void => match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
format!("{}{}({});", indent, name, args_str)
}
LangId::Python => {
format!("{}{}({})", indent, name, args_str)
}
_ => format!("{}{}({});", indent, name, args_str),
},
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ScopeConflict {
pub name: String,
pub suggested: String,
}
pub fn detect_scope_conflicts(
source: &str,
tree: &Tree,
insertion_byte: usize,
_param_names: &[String],
body_text: &str,
lang: LangId,
) -> Vec<ScopeConflict> {
let root = tree.root_node();
let enclosing_fn = find_enclosing_function(&root, insertion_byte, lang);
let mut scope_decls: HashSet<String> = HashSet::new();
if let Some(fn_node) = enclosing_fn {
collect_declarations_in_range(
&fn_node,
source,
fn_node.start_byte(),
fn_node.end_byte(),
lang,
&mut scope_decls,
);
collect_function_params(&fn_node, source, lang, &mut scope_decls);
} else {
collect_declarations_in_range(
&root,
source,
root.start_byte(),
root.end_byte(),
lang,
&mut scope_decls,
);
}
let mut body_decls: HashSet<String> = HashSet::new();
let body_grammar = grammar_for(lang);
let mut body_parser = tree_sitter::Parser::new();
if body_parser.set_language(&body_grammar).is_ok() {
if let Some(body_tree) = body_parser.parse(body_text.as_bytes(), None) {
let body_root = body_tree.root_node();
collect_declarations_in_range(
&body_root,
body_text,
0,
body_text.len(),
lang,
&mut body_decls,
);
}
}
let mut conflicts = Vec::new();
for decl in &body_decls {
if scope_decls.contains(decl) {
conflicts.push(ScopeConflict {
name: decl.clone(),
suggested: format!("{}_inlined", decl),
});
}
}
conflicts.sort_by(|a, b| a.name.cmp(&b.name));
conflicts
}
pub fn validate_single_return(
source: &str,
_tree: &Tree,
fn_node: &Node,
lang: LangId,
) -> Result<(), usize> {
if lang != LangId::Python && fn_node.kind() == "arrow_function" {
if let Some(body) = fn_node.child_by_field_name("body") {
if body.kind() != "statement_block" {
return Ok(());
}
}
}
let count = count_return_statements(fn_node, source);
if count > 1 {
Err(count)
} else {
Ok(())
}
}
fn count_return_statements(node: &Node, source: &str) -> usize {
let _ = source;
let mut count = 0;
let nested_fn_kinds = [
"function_declaration",
"function_definition",
"arrow_function",
"method_definition",
];
let kind = node.kind();
if kind == "return_statement" {
return 1;
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if nested_fn_kinds.contains(&child.kind()) {
continue;
}
count += count_return_statements(&child, source);
}
}
count
}
pub fn substitute_params(
body_text: &str,
param_to_arg: &std::collections::HashMap<String, String>,
lang: LangId,
) -> String {
if param_to_arg.is_empty() {
return body_text.to_string();
}
let grammar = grammar_for(lang);
let mut parser = tree_sitter::Parser::new();
if parser.set_language(&grammar).is_err() {
return body_text.to_string();
}
let tree = match parser.parse(body_text.as_bytes(), None) {
Some(t) => t,
None => return body_text.to_string(),
};
let mut replacements: Vec<(usize, usize, String)> = Vec::new();
collect_param_replacements(
&tree.root_node(),
body_text,
param_to_arg,
lang,
&mut replacements,
);
replacements.sort_by(|a, b| b.0.cmp(&a.0));
let mut result = body_text.to_string();
for (start, end, replacement) in replacements {
result = format!("{}{}{}", &result[..start], replacement, &result[end..]);
}
result
}
fn collect_param_replacements(
node: &Node,
source: &str,
param_to_arg: &std::collections::HashMap<String, String>,
lang: LangId,
out: &mut Vec<(usize, usize, String)>,
) {
let kind = node.kind();
if kind == "identifier" {
if !is_property_access(node, lang) {
let name = node_text(source, node);
if let Some(replacement) = param_to_arg.get(name) {
out.push((node.start_byte(), node.end_byte(), replacement.clone()));
}
}
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
collect_param_replacements(&child, source, param_to_arg, lang, out);
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parser::grammar_for;
use std::path::PathBuf;
use tree_sitter::Parser;
fn fixture_path(name: &str) -> PathBuf {
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures")
.join("extract_function")
.join(name)
}
fn parse_source(source: &str, lang: LangId) -> Tree {
let grammar = grammar_for(lang);
let mut parser = Parser::new();
parser.set_language(&grammar).unwrap();
parser.parse(source.as_bytes(), None).unwrap()
}
#[test]
fn free_vars_detects_enclosing_function_params() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
let result =
detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
assert!(
result.parameters.contains(&"items".to_string()),
"should detect 'items' as parameter, got: {:?}",
result.parameters
);
assert!(
result.parameters.contains(&"prefix".to_string()),
"should detect 'prefix' as parameter, got: {:?}",
result.parameters
);
assert!(!result.has_this_or_self);
}
#[test]
fn free_vars_filters_property_identifiers() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let line5_start = crate::edit::line_col_to_byte(&source, 5, 0);
let line6_end = crate::edit::line_col_to_byte(&source, 7, 0);
let result =
detect_free_variables(&source, &tree, line5_start, line6_end, LangId::TypeScript);
assert!(
!result.parameters.contains(&"filter".to_string()),
"property 'filter' should not be a free variable"
);
assert!(
!result.parameters.contains(&"length".to_string()),
"property 'length' should not be a free variable"
);
assert!(
!result.parameters.contains(&"map".to_string()),
"property 'map' should not be a free variable"
);
}
#[test]
fn free_vars_skips_module_level_refs() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let start = crate::edit::line_col_to_byte(&source, 5, 0);
let end = crate::edit::line_col_to_byte(&source, 10, 0);
let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
assert!(
!result.parameters.contains(&"BASE_URL".to_string()),
"module-level 'BASE_URL' should not be a parameter, got: {:?}",
result.parameters
);
assert!(
!result.parameters.contains(&"console".to_string()),
"'console' should not be a parameter, got: {:?}",
result.parameters
);
}
#[test]
fn free_vars_detects_this_in_ts() {
let source = std::fs::read_to_string(fixture_path("sample_this.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let start = crate::edit::line_col_to_byte(&source, 4, 0);
let end = crate::edit::line_col_to_byte(&source, 7, 0);
let result = detect_free_variables(&source, &tree, start, end, LangId::TypeScript);
assert!(result.has_this_or_self, "should detect 'this' reference");
}
#[test]
fn free_vars_detects_self_in_python() {
let source = r#"
class UserService:
def get_user(self, id):
key = id.lower()
user = self.users.get(key)
return user
"#;
let tree = parse_source(source, LangId::Python);
let start = crate::edit::line_col_to_byte(source, 4, 0);
let end = crate::edit::line_col_to_byte(source, 5, 0);
let result = detect_free_variables(source, &tree, start, end, LangId::Python);
assert!(result.has_this_or_self, "should detect 'self' reference");
}
#[test]
fn return_value_explicit_return() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let start = crate::edit::line_col_to_byte(&source, 14, 0);
let end = crate::edit::line_col_to_byte(&source, 17, 0);
let result = detect_return_value(&source, &tree, start, end, None, LangId::TypeScript);
assert_eq!(result, ReturnKind::Expression("added".to_string()));
}
#[test]
fn return_value_post_range_usage() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let start = crate::edit::line_col_to_byte(&source, 5, 0);
let end = crate::edit::line_col_to_byte(&source, 6, 0);
let fn_end = crate::edit::line_col_to_byte(&source, 10, 0);
let result =
detect_return_value(&source, &tree, start, end, Some(fn_end), LangId::TypeScript);
assert_eq!(result, ReturnKind::Variable("filtered".to_string()));
}
#[test]
fn return_value_void() {
let source = std::fs::read_to_string(fixture_path("sample.ts")).unwrap();
let tree = parse_source(&source, LangId::TypeScript);
let start = crate::edit::line_col_to_byte(&source, 20, 0);
let end = crate::edit::line_col_to_byte(&source, 22, 0);
let result = detect_return_value(
&source,
&tree,
start,
end,
Some(crate::edit::line_col_to_byte(&source, 23, 0)),
LangId::TypeScript,
);
assert_eq!(result, ReturnKind::Void);
}
#[test]
fn generate_ts_function_with_params() {
let body = "const doubled = x * 2;\nconst added = doubled + 10;";
let result = generate_extracted_function(
"compute",
&["x".to_string()],
&ReturnKind::Variable("added".to_string()),
body,
"",
LangId::TypeScript,
IndentStyle::Spaces(2),
);
assert!(result.contains("function compute(x)"));
assert!(result.contains("return added;"));
assert!(result.contains("}"));
}
#[test]
fn generate_py_function_with_params() {
let body = "doubled = x * 2\nadded = doubled + 10";
let result = generate_extracted_function(
"compute",
&["x".to_string()],
&ReturnKind::Variable("added".to_string()),
body,
"",
LangId::Python,
IndentStyle::Spaces(4),
);
assert!(result.contains("def compute(x):"));
assert!(result.contains("return added"));
}
#[test]
fn generate_call_site_with_return_var() {
let call = generate_call_site(
"compute",
&["x".to_string()],
&ReturnKind::Variable("result".to_string()),
" ",
LangId::TypeScript,
);
assert_eq!(call, " const result = compute(x);");
}
#[test]
fn generate_call_site_void() {
let call = generate_call_site(
"doWork",
&["a".to_string(), "b".to_string()],
&ReturnKind::Void,
" ",
LangId::TypeScript,
);
assert_eq!(call, " doWork(a, b);");
}
#[test]
fn generate_call_site_return_expression() {
let call = generate_call_site(
"compute",
&["x".to_string()],
&ReturnKind::Expression("x * 2".to_string()),
" ",
LangId::TypeScript,
);
assert_eq!(call, " return compute(x);");
}
#[test]
fn free_vars_python_function_params() {
let source = std::fs::read_to_string(fixture_path("sample.py")).unwrap();
let tree = parse_source(&source, LangId::Python);
let start = crate::edit::line_col_to_byte(&source, 5, 0);
let end = crate::edit::line_col_to_byte(&source, 7, 0);
let result = detect_free_variables(&source, &tree, start, end, LangId::Python);
assert!(
result.parameters.contains(&"items".to_string()),
"should detect 'items': {:?}",
result.parameters
);
assert!(
result.parameters.contains(&"prefix".to_string()),
"should detect 'prefix': {:?}",
result.parameters
);
assert!(!result.has_this_or_self);
}
#[test]
fn validate_single_return_single() {
let source =
"function add(a: number, b: number): number {\n const sum = a + b;\n return sum;\n}";
let tree = parse_source(source, LangId::TypeScript);
let root = tree.root_node();
let fn_node = root.child(0).unwrap(); assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
}
#[test]
fn validate_single_return_void() {
let source = "function greet(name: string): void {\n console.log(name);\n}";
let tree = parse_source(source, LangId::TypeScript);
let root = tree.root_node();
let fn_node = root.child(0).unwrap();
assert!(validate_single_return(source, &tree, &fn_node, LangId::TypeScript).is_ok());
}
#[test]
fn validate_single_return_expression_body() {
let source = "const double = (n: number): number => n * 2;";
let tree = parse_source(source, LangId::TypeScript);
let root = tree.root_node();
let lex_decl = root.child(0).unwrap();
let var_decl = lex_decl.child(1).unwrap(); let arrow = var_decl.child_by_field_name("value").unwrap();
assert_eq!(arrow.kind(), "arrow_function");
assert!(validate_single_return(source, &tree, &arrow, LangId::TypeScript).is_ok());
}
#[test]
fn validate_single_return_multiple() {
let source = "function abs(x: number): number {\n if (x > 0) {\n return x;\n }\n return -x;\n}";
let tree = parse_source(source, LangId::TypeScript);
let root = tree.root_node();
let fn_node = root.child(0).unwrap();
let result = validate_single_return(source, &tree, &fn_node, LangId::TypeScript);
assert!(result.is_err());
assert_eq!(result.unwrap_err(), 2);
}
#[test]
fn scope_conflicts_none() {
let source = "function main() {\n const x = 10;\n const y = add(x, 5);\n}";
let tree = parse_source(source, LangId::TypeScript);
let body_text = "const sum = a + b;";
let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
let conflicts =
detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
assert!(
conflicts.is_empty(),
"expected no conflicts, got: {:?}",
conflicts
);
}
#[test]
fn scope_conflicts_detected() {
let source = "function main() {\n const temp = 99;\n const result = compute(5);\n}";
let tree = parse_source(source, LangId::TypeScript);
let body_text = "const temp = x * 2;\nconst result2 = temp + 10;";
let call_byte = crate::edit::line_col_to_byte(source, 2, 0);
let conflicts =
detect_scope_conflicts(source, &tree, call_byte, &[], body_text, LangId::TypeScript);
assert!(!conflicts.is_empty(), "expected conflict for 'temp'");
assert!(
conflicts.iter().any(|c| c.name == "temp"),
"conflicts: {:?}",
conflicts
);
assert!(
conflicts.iter().any(|c| c.suggested == "temp_inlined"),
"should suggest temp_inlined"
);
}
#[test]
fn substitute_params_basic() {
let body = "const sum = a + b;";
let mut map = std::collections::HashMap::new();
map.insert("a".to_string(), "x".to_string());
map.insert("b".to_string(), "y".to_string());
let result = substitute_params(body, &map, LangId::TypeScript);
assert_eq!(result, "const sum = x + y;");
}
#[test]
fn substitute_params_whole_word() {
let body = "const result = items.filter(i => i > 0);";
let mut map = std::collections::HashMap::new();
map.insert("i".to_string(), "index".to_string());
let result = substitute_params(body, &map, LangId::TypeScript);
assert!(
!result.contains("items") || result.contains("items"),
"items should be preserved"
);
assert!(
result.contains("index"),
"should contain 'index': {}",
result
);
}
#[test]
fn substitute_params_noop_same_name() {
let body = "const sum = x + y;";
let mut map = std::collections::HashMap::new();
map.insert("x".to_string(), "x".to_string());
let result = substitute_params(body, &map, LangId::TypeScript);
assert_eq!(result, "const sum = x + y;");
}
#[test]
fn substitute_params_empty_map() {
let body = "const sum = a + b;";
let map = std::collections::HashMap::new();
let result = substitute_params(body, &map, LangId::TypeScript);
assert_eq!(result, body);
}
}