use std::path::Path;
use normalize_languages::parsers::parse_with_grammar;
use normalize_languages::support_for_path;
use crate::{PlannedEdit, RefactoringContext, RefactoringPlan};
/// Result of planning an inline-function refactoring with [`plan_inline_function`].
pub struct InlineFunctionOutcome {
/// The refactoring plan: a single edit that rewrites the whole file.
pub plan: RefactoringPlan,
/// Name of the function that was inlined.
pub function_name: String,
/// 1-based line number of the call site that was replaced.
pub call_site_line: usize,
}
/// Plans an inline-function refactoring for the function at `line`:`col` in `content`.
///
/// The cursor may sit inside the function's definition or on a call to it; the
/// definition must live in the same file. One call site is replaced by the function's
/// body with the call's arguments substituted for the parameters:
///
/// - With exactly one call site, that call is inlined and the definition is removed.
/// - With several call sites, `force` is required. The innermost call enclosing the
///   cursor is inlined, or the first call in the file when the cursor is not on one.
///   The definition is **kept**, because the remaining calls still reference it.
///
/// `line` and `col` are 1-based. `_ctx` is currently unused.
///
/// # Errors
/// Returns a human-readable message when:
/// - the file's language or grammar is unavailable, or the position is invalid;
/// - no function can be resolved at the cursor, or its definition is not in this file;
/// - the function has no call sites, or is called more than once without `force`;
/// - the body or arguments cannot be inlined safely.
pub fn plan_inline_function(
    _ctx: &RefactoringContext,
    file_abs: &Path,
    content: &str,
    line: usize,
    col: usize,
    force: bool,
) -> Result<InlineFunctionOutcome, String> {
    let support = support_for_path(file_abs).ok_or_else(|| {
        let ext = file_abs
            .extension()
            .and_then(|e| e.to_str())
            .unwrap_or("<unknown>");
        format!("inline-function: no language support for .{ext} files")
    })?;
    let grammar = support.grammar_name();
    let tree = parse_with_grammar(grammar, content).ok_or_else(|| {
        format!(
            "inline-function: grammar for {grammar} not loaded — run `normalize grammars install`"
        )
    })?;
    let cursor_byte = line_col_to_byte(content, line, col)?;
    let root = tree.root_node();

    let function_name =
        resolve_function_name_at(&root, content, cursor_byte, grammar).ok_or_else(|| {
            format!("inline-function: no function definition or call found at {line}:{col}")
        })?;
    let def = find_function_def(&root, content, &function_name, grammar).ok_or_else(|| {
        format!("inline-function: definition of '{function_name}' not found in this file")
    })?;
    let body_text = extract_body_text(content, &def)?;

    let call_sites = find_call_sites(&root, content, &function_name, grammar);
    match call_sites.len() {
        0 => {
            return Err(format!(
                "inline-function: '{function_name}' has no call sites in this file"
            ));
        }
        n if n > 1 && !force => {
            return Err(format!(
                "inline-function: '{function_name}' is called {n} times; use --force to inline anyway (or inline the specific call manually)"
            ));
        }
        _ => {}
    }

    // Prefer the call the user is pointing at. Call sites are collected in pre-order
    // (outer call before the calls nested in its arguments), so the last enclosing
    // site is the innermost one.
    let call_site = call_sites
        .iter()
        .rev()
        .find(|site| site.call_start_byte <= cursor_byte && cursor_byte < site.call_end_byte)
        .unwrap_or(&call_sites[0]);
    let call_site_line = call_site.line;
    let inlined = substitute_call(content, &def, call_site, &body_text)?;

    // Deleting the definition while other calls remain would leave them dangling.
    let final_content = if call_sites.len() == 1 {
        remove_function_def(&inlined, &def, content)?
    } else {
        inlined
    };

    let plan = RefactoringPlan {
        operation: "inline-function".to_string(),
        edits: vec![PlannedEdit {
            file: file_abs.to_path_buf(),
            original: content.to_string(),
            new_content: final_content,
            description: format!("inline {function_name}"),
        }],
        warnings: vec![],
    };
    Ok(InlineFunctionOutcome {
        plan,
        function_name,
        call_site_line,
    })
}
/// A function definition located in the parsed file, expressed as byte offsets into
/// the source text.
struct FunctionDef {
/// Function name as written at the definition.
name: String,
/// Parameter names in declaration order, as produced by `extract_parameter_names`.
params: Vec<String>,
/// Start of the span deleted when the definition is removed; normally the start of
/// the line the definition begins on, so its indentation is removed too.
def_start_byte: usize,
/// End (exclusive) of that span; includes the newline directly following the
/// definition, if any.
def_end_byte: usize,
/// Byte range of the body text that gets inlined; for brace-delimited block bodies
/// the braces themselves are excluded.
body_start_byte: usize,
body_end_byte: usize,
}
/// One call of the target function found in the file.
struct CallSite {
/// Source text of each argument, in call order.
args: Vec<String>,
/// Byte range of the entire call expression (callee plus argument list).
call_start_byte: usize,
call_end_byte: usize,
/// 1-based line on which the call starts.
line: usize,
}
/// Converts a 1-based `line` / `col` position into a byte offset into `content`.
///
/// `col` counts bytes from the start of the line (1-based; 0 is treated like 1). It is
/// clamped to the end of the line, so a column past the last character still lands on
/// that line. A `line` that starts right after a trailing newline is valid and maps to
/// the end of the file.
///
/// # Errors
/// Returns an error when `line` is 0 or lies past the last line of `content`.
fn line_col_to_byte(content: &str, line: usize, col: usize) -> Result<usize, String> {
    if line == 0 {
        return Err("inline-function: line is 1-based; 0 is invalid".to_string());
    }
    // Line N starts immediately after the (N-1)th newline.
    let line_start = if line == 1 {
        0
    } else {
        content
            .match_indices('\n')
            .nth(line - 2)
            .map(|(i, _)| i + 1)
            .ok_or_else(|| format!("inline-function: line {line} is beyond end of file"))?
    };
    let line_len = content[line_start..]
        .find('\n')
        .unwrap_or(content.len() - line_start);
    Ok(line_start + col.saturating_sub(1).min(line_len))
}
/// Returns the 1-based line number containing byte offset `byte` of `content`.
///
/// Offsets past the end are clamped to the end of `content`.
fn byte_to_line(content: &str, byte: usize) -> usize {
    let end = byte.min(content.len());
    // '\n' is ASCII and never occurs inside a multi-byte UTF-8 sequence, so counting raw
    // bytes is exact and cannot panic when `byte` is not on a char boundary.
    1 + content.as_bytes()[..end]
        .iter()
        .filter(|&&b| b == b'\n')
        .count()
}
/// Works out which function the cursor at `cursor_byte` refers to.
///
/// Starts from the smallest node covering `cursor_byte` and climbs its ancestors,
/// stopping below `root`. The root itself is examined only when it is the covering
/// node. The first node that matches one of these rules wins:
/// - a function-definition node with a `name` field yields that name;
/// - a `lexical_declaration` / `variable_declaration` that binds a function or arrow
///   value yields the bound name;
/// - a `call_expression` / `call` whose callee text has no `.` or `:` yields the
///   callee text.
///
/// Returns `None` when no visited node matches. `_grammar` is currently unused.
fn resolve_function_name_at<'a>(
    root: &tree_sitter::Node<'a>,
    content: &str,
    cursor_byte: usize,
    _grammar: &str,
) -> Option<String> {
    let start = root.descendant_for_byte_range(cursor_byte, cursor_byte + 1)?;
    let ancestors = std::iter::successors(Some(start), |node| {
        node.parent().filter(|parent| parent.id() != root.id())
    });
    for node in ancestors {
        let kind = node.kind();
        if is_function_def_kind(kind)
            && let Some(name) = find_name_child(&node, content)
        {
            return Some(name);
        }
        if matches!(kind, "lexical_declaration" | "variable_declaration")
            && let Some(name) = extract_arrow_def_name(&node, content)
        {
            return Some(name);
        }
        if matches!(kind, "call_expression" | "call")
            && let Some(callee) = node
                .child_by_field_name("function")
                .or_else(|| node.child_by_field_name("callee"))
        {
            let callee_text = &content[callee.start_byte()..callee.end_byte()];
            // Method calls (`a.f`) and paths (`m::f`) are not plain function names.
            if !callee_text.contains(['.', ':']) {
                return Some(callee_text.to_string());
            }
        }
    }
    None
}
/// True for tree-sitter node kinds that are named function or method definitions
/// across the supported grammars.
fn is_function_def_kind(kind: &str) -> bool {
    const DEFINITION_KINDS: [&str; 5] = [
        "function_declaration",
        "function_definition",
        "function_item",
        "method_definition",
        "generator_function_declaration",
    ];
    DEFINITION_KINDS.contains(&kind)
}
/// True for node kinds that are anonymous function values (arrow functions, function
/// expressions, generator expressions), e.g. the value in `const f = () => ...`.
fn is_arrow_or_func_expr_kind(kind: &str) -> bool {
    const VALUE_KINDS: [&str; 3] = ["arrow_function", "function_expression", "generator_function"];
    VALUE_KINDS.contains(&kind)
}
/// Returns the first direct child of `node` that is a parameter list
/// (`formal_parameters`, `parameters` or `parameter_list`).
fn find_params_child<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
    let mut cursor = node.walk();
    let mut children = node.children(&mut cursor);
    children.find(|child| {
        matches!(
            child.kind(),
            "formal_parameters" | "parameters" | "parameter_list"
        )
    })
}
/// Returns the body node of a function-like `node`.
///
/// Prefers the first direct child that is a block (`statement_block`, `block` or
/// `function_body`). Otherwise it returns the first named child after a `=>` token,
/// which is the expression body of an arrow function such as `x => x * 2`.
fn find_body_child<'a>(node: &tree_sitter::Node<'a>) -> Option<tree_sitter::Node<'a>> {
    let mut cursor = node.walk();
    let block = {
        let mut children = node.children(&mut cursor);
        children.find(|child| matches!(child.kind(), "statement_block" | "block" | "function_body"))
    };
    if block.is_some() {
        return block;
    }
    let mut after_arrow = node
        .children(&mut cursor)
        .skip_while(|child| child.kind() != "=>")
        .skip(1);
    after_arrow.find(|child| child.is_named())
}
/// Returns the source text of `node`'s `name` field, or `None` if it has none.
fn find_name_child(node: &tree_sitter::Node<'_>, content: &str) -> Option<String> {
    let name = node.child_by_field_name("name")?;
    Some(content[name.start_byte()..name.end_byte()].to_owned())
}
/// For a `lexical_declaration` / `variable_declaration`, returns the name bound by its
/// first `variable_declarator` whose value is a function or arrow.
fn extract_arrow_def_name(node: &tree_sitter::Node<'_>, content: &str) -> Option<String> {
    let mut cursor = node.walk();
    let mut declarators = node
        .children(&mut cursor)
        .filter(|child| child.kind() == "variable_declarator");
    declarators.find_map(|decl| arrow_declarator_name(&decl, content))
}
/// Returns the bound name of a `variable_declarator` when its value is a function or
/// arrow, and `None` otherwise.
///
/// The name comes from the `name` field, or else from the first `identifier` child.
/// The value is checked through the `value` field when present. Without that field,
/// any direct child of a function/arrow kind counts.
fn arrow_declarator_name(decl: &tree_sitter::Node<'_>, content: &str) -> Option<String> {
    let name_node = match decl.child_by_field_name("name") {
        Some(named) => named,
        None => {
            let mut cursor = decl.walk();
            let mut children = decl.children(&mut cursor);
            children.find(|child| child.kind() == "identifier")?
        }
    };
    let binds_function = match decl.child_by_field_name("value") {
        Some(value) => is_arrow_or_func_expr_kind(value.kind()),
        None => {
            let mut cursor = decl.walk();
            let mut children = decl.children(&mut cursor);
            children.any(|child| is_arrow_or_func_expr_kind(child.kind()))
        }
    };
    binds_function.then(|| content[name_node.start_byte()..name_node.end_byte()].to_string())
}
/// Finds the definition of `name` anywhere under `root`.
///
/// The search is depth-first in document order and returns the first definition that
/// can be extracted. It recognizes named function-definition nodes, and declarations
/// that bind `name` to a function or arrow value. `_grammar` is currently unused.
fn find_function_def(
    root: &tree_sitter::Node<'_>,
    content: &str,
    name: &str,
    _grammar: &str,
) -> Option<FunctionDef> {
    let start = *root;
    let mut walker = start.walk();
    find_function_def_recursive(&mut walker, start, content, name)
}
/// Recursive worker for [`find_function_def`].
///
/// `cursor` must be positioned on `node`. When the search falls through without a
/// match, the cursor is back on `node` on return.
///
/// If `node` itself is a definition of `name`, the outcome of extracting it is returned
/// and its subtree is not searched further. Otherwise the children are searched in
/// order and the first hit wins.
fn find_function_def_recursive(
    cursor: &mut tree_sitter::TreeCursor<'_>,
    node: tree_sitter::Node<'_>,
    content: &str,
    name: &str,
) -> Option<FunctionDef> {
    let kind = node.kind();
    let defines_name =
        is_function_def_kind(kind) && find_name_child(&node, content).as_deref() == Some(name);
    if defines_name {
        return extract_function_def(&node, content, name, true);
    }
    if matches!(kind, "lexical_declaration" | "variable_declaration")
        && let Some(def) = try_extract_arrow_def(&node, content, name)
    {
        return Some(def);
    }

    let mut has_child = cursor.goto_first_child();
    let descended = has_child;
    while has_child {
        let child = cursor.node();
        if let Some(def) = find_function_def_recursive(cursor, child, content, name) {
            return Some(def);
        }
        has_child = cursor.goto_next_sibling();
    }
    if descended {
        cursor.goto_parent();
    }
    None
}
/// If the declaration `decl_node` has a `variable_declarator` binding `name` to a
/// function or arrow value, extracts that definition. Otherwise returns `None`.
fn try_extract_arrow_def(
    decl_node: &tree_sitter::Node<'_>,
    content: &str,
    name: &str,
) -> Option<FunctionDef> {
    let mut cursor = decl_node.walk();
    let mut children = decl_node.children(&mut cursor);
    let binds_name = children.any(|child| {
        child.kind() == "variable_declarator"
            && arrow_declarator_name(&child, content).as_deref() == Some(name)
    });
    if !binds_name {
        return None;
    }
    extract_function_def(decl_node, content, name, true)
}
fn extract_function_def(
node: &tree_sitter::Node<'_>,
content: &str,
name: &str,
_is_statement: bool,
) -> Option<FunctionDef> {
let (param_node, body_node) =
if node.kind() == "lexical_declaration" || node.kind() == "variable_declaration" {
let mut c = node.walk();
let mut found_decl: Option<tree_sitter::Node<'_>> = None;
if c.goto_first_child() {
loop {
let child = c.node();
if child.kind() == "variable_declarator" {
let vname = child
.child_by_field_name("name")
.map(|n| &content[n.start_byte()..n.end_byte()]);
if vname == Some(name) {
found_decl = Some(child);
break;
}
}
if !c.goto_next_sibling() {
break;
}
}
}
let decl = found_decl?;
let value = decl.child_by_field_name("value").or_else(|| {
let mut cc = decl.walk();
let mut found = None;
if cc.goto_first_child() {
loop {
let n = cc.node();
if is_arrow_or_func_expr_kind(n.kind()) {
found = Some(n);
break;
}
if !cc.goto_next_sibling() {
break;
}
}
}
found
})?;
let params = value
.child_by_field_name("parameters")
.or_else(|| value.child_by_field_name("formal_parameters"))
.or_else(|| find_params_child(&value))?;
let body = value
.child_by_field_name("body")
.or_else(|| find_body_child(&value))?;
(params, body)
} else {
let params = node
.child_by_field_name("parameters")
.or_else(|| node.child_by_field_name("formal_parameters"))
.or_else(|| find_params_child(node))?;
let body = node
.child_by_field_name("body")
.or_else(|| find_body_child(node))?;
(params, body)
};
let params = extract_parameter_names(¶m_node, content);
let (body_start_byte, body_end_byte) =
if body_node.kind() == "statement_block" || body_node.kind() == "block" {
let inner_start = body_node.start_byte() + 1;
let inner_end = body_node.end_byte() - 1;
(inner_start, inner_end)
} else {
(body_node.start_byte(), body_node.end_byte())
};
let def_start_byte = {
let raw = node.start_byte();
content[..raw].rfind('\n').map(|i| i + 1).unwrap_or(0)
};
let def_end_byte = {
let raw = node.end_byte();
if raw < content.len() && content.as_bytes()[raw] == b'\n' {
raw + 1
} else {
raw
}
};
Some(FunctionDef {
name: name.to_string(),
params,
def_start_byte,
def_end_byte,
body_start_byte,
body_end_byte,
})
}
fn extract_parameter_names(params_node: &tree_sitter::Node<'_>, content: &str) -> Vec<String> {
let mut names = vec![];
let mut c = params_node.walk();
if c.goto_first_child() {
loop {
let child = c.node();
let kind = child.kind();
let param_name = match kind {
"identifier" => Some(&content[child.start_byte()..child.end_byte()]),
"required_parameter" | "optional_parameter" => child
.child_by_field_name("pattern")
.or_else(|| {
let mut cc = child.walk();
if cc.goto_first_child() {
loop {
let n = cc.node();
if n.kind() == "identifier" {
return Some(n);
}
if !cc.goto_next_sibling() {
break;
}
}
}
None
})
.map(|n| &content[n.start_byte()..n.end_byte()]),
"typed_parameter" | "default_parameter" => {
let mut cc = child.walk();
let mut found = None;
if cc.goto_first_child() {
loop {
let n = cc.node();
if n.kind() == "identifier" {
found = Some(&content[n.start_byte()..n.end_byte()]);
break;
}
if !cc.goto_next_sibling() {
break;
}
}
}
found
}
"parameter" => child
.child_by_field_name("pattern")
.map(|n| &content[n.start_byte()..n.end_byte()]),
_ => None,
};
if let Some(n) = param_name
&& !n.is_empty()
{
names.push(n.to_string());
}
if !c.goto_next_sibling() {
break;
}
}
}
names
}
/// Returns the raw body text of `def` (the byte range `body_start_byte..body_end_byte`).
///
/// # Errors
/// Fails when the body contains more than one `return` keyword, because inlining it
/// would need control-flow analysis.
fn extract_body_text(content: &str, def: &FunctionDef) -> Result<String, String> {
    let body = &content[def.body_start_byte..def.body_end_byte];
    match count_return_statements(body) {
        many if many > 1 => Err(format!(
            "inline-function: '{}' has {} return statements; inlining would require control-flow analysis — aborting (too complex)",
            def.name, many
        )),
        _ => Ok(body.to_owned()),
    }
}
/// Counts occurrences of the word `return` in `text`.
///
/// A match counts only when the bytes on both sides are not ASCII alphanumerics or
/// `_`. This is a purely textual scan, so occurrences inside strings or comments are
/// counted too.
fn count_return_statements(text: &str) -> usize {
    let bytes = text.as_bytes();
    let is_word_byte = |b: u8| b.is_ascii_alphanumeric() || b == b'_';
    text.match_indices("return")
        .filter(|&(start, keyword)| {
            let end = start + keyword.len();
            let clear_before = start == 0 || !is_word_byte(bytes[start - 1]);
            let clear_after = bytes.get(end).is_none_or(|&b| !is_word_byte(b));
            clear_before && clear_after
        })
        .count()
}
/// Collects every call whose callee text is exactly `name`, in pre-order (an outer call
/// comes before the calls nested inside it). `_grammar` is currently unused.
fn find_call_sites(
    root: &tree_sitter::Node<'_>,
    content: &str,
    name: &str,
    _grammar: &str,
) -> Vec<CallSite> {
    let mut found = Vec::new();
    find_call_sites_recursive(&mut root.walk(), *root, content, name, &mut found);
    found
}
fn find_call_sites_recursive(
cursor: &mut tree_sitter::TreeCursor<'_>,
node: tree_sitter::Node<'_>,
content: &str,
name: &str,
sites: &mut Vec<CallSite>,
) {
let kind = node.kind();
if kind == "call_expression" || kind == "call" {
let callee = node
.child_by_field_name("function")
.or_else(|| node.child_by_field_name("callee"));
if let Some(callee_node) = callee {
let callee_text = &content[callee_node.start_byte()..callee_node.end_byte()];
if callee_text == name {
let args = extract_call_args(&node, content);
let line = byte_to_line(content, node.start_byte());
sites.push(CallSite {
args,
call_start_byte: node.start_byte(),
call_end_byte: node.end_byte(),
line,
});
}
}
}
if cursor.goto_first_child() {
loop {
let child = cursor.node();
find_call_sites_recursive(cursor, child, content, name, sites);
if !cursor.goto_next_sibling() {
break;
}
}
cursor.goto_parent();
}
}
/// Returns the source text of each argument of `call_node`, in order.
///
/// The argument list is the `arguments` field, else the first `argument_list` /
/// `arguments` child. If neither exists, no arguments are returned. Every named child
/// of the list is an argument, except comments. Comments are named nodes in
/// tree-sitter, so they must be skipped or they would be miscounted as arguments.
fn extract_call_args(call_node: &tree_sitter::Node<'_>, content: &str) -> Vec<String> {
    let args_node = call_node.child_by_field_name("arguments").or_else(|| {
        let mut cursor = call_node.walk();
        let mut children = call_node.children(&mut cursor);
        children.find(|n| matches!(n.kind(), "argument_list" | "arguments"))
    });
    let Some(args_node) = args_node else {
        return Vec::new();
    };
    let mut cursor = args_node.walk();
    args_node
        .children(&mut cursor)
        .filter(|child| {
            child.is_named()
                && !matches!(child.kind(), "comment" | "line_comment" | "block_comment")
        })
        .map(|child| content[child.start_byte()..child.end_byte()].to_string())
        .collect()
}
fn substitute_call(
content: &str,
def: &FunctionDef,
call: &CallSite,
body_text: &str,
) -> Result<String, String> {
if call.args.len() != def.params.len() {
return Err(format!(
"inline-function: '{}' expects {} arguments but call site provides {} — aborting",
def.name,
def.params.len(),
call.args.len()
));
}
let trimmed = body_text.trim();
let stripped = strip_single_return(trimmed);
let mut result = stripped.to_string();
for (param, arg) in def.params.iter().zip(call.args.iter()) {
result = normalize_edit::replace_all_words(&result, param, arg);
}
let replacement = result.trim().to_string();
let mut new_content = String::new();
new_content.push_str(&content[..call.call_start_byte]);
new_content.push_str(&replacement);
new_content.push_str(&content[call.call_end_byte..]);
Ok(new_content)
}
/// Unwraps a leading `return` statement so the body can be used as an expression.
///
/// After trimming leading whitespace, if `s` starts with the keyword `return`, the
/// function drops the keyword, leading spaces/tabs, one trailing `;` and any remaining
/// surrounding whitespace. The keyword must be followed by a non-identifier character.
/// Without that check, `returnValue * 2` would be mangled into `Value * 2`. Any other
/// input is returned with only its leading whitespace trimmed.
fn strip_single_return(s: &str) -> &str {
    let s = s.trim_start();
    let Some(rest) = s.strip_prefix("return") else {
        return s;
    };
    if rest.starts_with(|c: char| c.is_alphanumeric() || c == '_' || c == '$') {
        return s;
    }
    let after = rest.trim_start_matches([' ', '\t']);
    after.strip_suffix(';').unwrap_or(after).trim()
}
/// Removes the original text of `def` from `inlined`, the file after the call site was
/// replaced.
///
/// The definition is located by searching `inlined` for the exact text of
/// `original[def_start_byte..def_end_byte]`; the first occurrence is removed. Only the
/// run of newlines left at the removal seam is collapsed, to at most one blank line,
/// via `collapse_triple_newlines`. Blank lines and string literals elsewhere in the
/// file are left untouched.
///
/// # Errors
/// Fails when the definition text no longer occurs verbatim in `inlined`, for example
/// because the inlined call site was inside the definition itself.
fn remove_function_def(inlined: &str, def: &FunctionDef, original: &str) -> Result<String, String> {
    let orig_def_text = &original[def.def_start_byte..def.def_end_byte];
    let Some(pos) = inlined.find(orig_def_text) else {
        return Err(format!(
            "inline-function: could not locate definition of '{}' in modified content — aborting",
            def.name
        ));
    };
    let head = inlined[..pos].trim_end_matches('\n');
    let tail = inlined[pos + orig_def_text.len()..].trim_start_matches('\n');
    let seam_newlines = inlined.len() - orig_def_text.len() - head.len() - tail.len();
    let seam = collapse_triple_newlines("\n".repeat(seam_newlines));

    let mut result = String::with_capacity(head.len() + seam.len() + tail.len());
    result.push_str(head);
    result.push_str(&seam);
    result.push_str(tail);
    Ok(result)
}
/// Rewrites every run of three or more consecutive `\n` characters down to exactly two,
/// so consecutive blank lines become a single blank line.
fn collapse_triple_newlines(s: String) -> String {
    let mut text = s;
    while text.contains("\n\n\n") {
        text = text.replace("\n\n\n", "\n\n");
    }
    text
}