use std::collections::HashMap;
use std::path::Path;
use tree_sitter::Parser;
use crate::context::AppContext;
use crate::edit;
use crate::extract::{detect_scope_conflicts, substitute_params, validate_single_return};
use crate::lsp_hints;
use crate::parser::{detect_language, grammar_for, node_text, LangId};
use crate::protocol::{RawRequest, Response};
pub fn handle_inline_symbol(req: &RawRequest, ctx: &AppContext) -> Response {
let op_id = crate::backup::new_op_id();
let file = match req.params.get("file").and_then(|v| v.as_str()) {
Some(f) => f,
None => {
return Response::error(
&req.id,
"invalid_request",
"inline_symbol: missing required param 'file'",
);
}
};
let symbol = match req.params.get("symbol").and_then(|v| v.as_str()) {
Some(s) => s,
None => {
return Response::error(
&req.id,
"invalid_request",
"inline_symbol: missing required param 'symbol'",
);
}
};
let call_site_line = match req.params.get("call_site_line").and_then(|v| v.as_u64()) {
Some(l) if l >= 1 => (l - 1) as u32,
Some(_) => {
return Response::error(
&req.id,
"invalid_request",
"inline_symbol: 'call_site_line' must be >= 1 (1-based)",
);
}
None => {
return Response::error(
&req.id,
"invalid_request",
"inline_symbol: missing required param 'call_site_line'",
);
}
};
let path = match ctx.validate_path(&req.id, Path::new(file)) {
Ok(path) => path,
Err(resp) => return resp,
};
if !path.exists() {
return Response::error(
&req.id,
"file_not_found",
format!("inline_symbol: file not found: {}", file),
);
}
let lang = match detect_language(&path) {
Some(l) => l,
None => {
return Response::error(
&req.id,
"unsupported_language",
"inline_symbol: unsupported file type",
);
}
};
if !matches!(
lang,
LangId::TypeScript | LangId::Tsx | LangId::JavaScript | LangId::Python
) {
return Response::error(
&req.id,
"unsupported_language",
format!(
"inline_symbol: only TypeScript/JavaScript/Python files are supported, got {:?}",
lang
),
);
}
let source = match std::fs::read_to_string(&path) {
Ok(s) => s,
Err(e) => {
return Response::error(
&req.id,
"file_not_found",
format!("inline_symbol: {}: {}", file, e),
);
}
};
let grammar = grammar_for(lang);
let mut parser = Parser::new();
if parser.set_language(&grammar).is_err() {
return Response::error(
&req.id,
"parse_error",
"inline_symbol: failed to initialize parser",
);
}
let tree = match parser.parse(source.as_bytes(), None) {
Some(t) => t,
None => {
return Response::error(
&req.id,
"parse_error",
"inline_symbol: failed to parse file",
);
}
};
let matches = match ctx.provider().resolve_symbol(&path, symbol) {
Ok(m) => m,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
let matches = if let Some(hints) = lsp_hints::parse_lsp_hints(req) {
lsp_hints::apply_lsp_disambiguation(matches, &hints)
} else {
matches
};
let sym = match matches.iter().find(|m| {
matches!(
m.symbol.kind,
crate::symbols::SymbolKind::Function | crate::symbols::SymbolKind::Method
)
}) {
Some(m) => &m.symbol,
None => {
return Response::error(
&req.id,
"symbol_not_found",
format!("inline_symbol: no function '{}' found in {}", symbol, file),
);
}
};
let fn_start_byte = edit::line_col_to_byte(&source, sym.range.start_line, sym.range.start_col);
let fn_node = match find_function_node_at(&tree.root_node(), fn_start_byte, lang) {
Some(n) => n,
None => {
return Response::error(
&req.id,
"symbol_not_found",
format!(
"inline_symbol: could not locate function node for '{}' in AST",
symbol
),
);
}
};
if let Err(count) = validate_single_return(&source, &tree, &fn_node, lang) {
return Response::error_with_data(
&req.id,
"multiple_returns",
format!(
"inline_symbol: function '{}' has {} return statements (max 1 for inlining)",
symbol, count
),
serde_json::json!({
"return_count": count,
"symbol": symbol,
}),
);
}
let (param_patterns, body_text) = extract_fn_params_and_body(&fn_node, &source, lang);
let param_names = param_patterns
.iter()
.flat_map(InlineParam::binding_names)
.collect::<Vec<_>>();
let call_line_start = edit::line_col_to_byte(&source, call_site_line, 0);
let call_line_end = if (call_site_line + 1) as usize <= source.lines().count() {
edit::line_col_to_byte(&source, call_site_line + 1, 0)
} else {
source.len()
};
let call_node = match find_call_node_at_line(
&tree.root_node(),
symbol,
&source,
call_line_start,
call_line_end,
lang,
) {
Some(n) => n,
None => {
return Response::error(
&req.id,
"call_not_found",
format!(
"inline_symbol: no call to '{}' found at line {} in {}",
symbol, call_site_line, file
),
);
}
};
let (call_context, replacement_node, assignment_var) =
detect_call_context(&call_node, &source, lang);
let args = extract_call_arguments(&call_node, &source, lang);
let param_to_arg = match build_param_to_arg_map(¶m_patterns, &args) {
Ok(map) => map,
Err(reason) => {
return Response::error_with_data(
&req.id,
"param_mismatch",
format!("inline_symbol: cannot inline '{}': {}", symbol, reason),
serde_json::json!({
"symbol": symbol,
"reason": reason,
}),
);
}
};
let conflicts = detect_scope_conflicts(
&source,
&tree,
replacement_node.start_byte(),
¶m_names,
&body_text,
lang,
);
if !conflicts.is_empty() {
let conflicting_names: Vec<&str> = conflicts.iter().map(|c| c.name.as_str()).collect();
let suggestions: Vec<serde_json::Value> = conflicts
.iter()
.map(|c| {
serde_json::json!({
"original": c.name,
"suggested": c.suggested,
})
})
.collect();
return Response::error_with_data(
&req.id,
"scope_conflict",
format!(
"inline_symbol: scope conflicts detected when inlining '{}': [{}]",
symbol,
conflicting_names.join(", ")
),
serde_json::json!({
"conflicting_names": conflicting_names,
"suggestions": suggestions,
"symbol": symbol,
}),
);
}
let substituted_body = substitute_params(&body_text, ¶m_to_arg, lang);
let substitution_count = param_to_arg.len();
let replacement_indent = get_line_indent(&source, call_site_line as usize);
let replacement_text = build_inline_replacement(
&substituted_body,
&call_context,
&replacement_indent,
lang,
assignment_var.as_deref(),
);
let replace_start = {
let bytes = source.as_bytes();
let mut s = replacement_node.start_byte();
while s > 0 && matches!(bytes[s - 1], b' ' | b'\t') {
s -= 1;
}
s
};
let new_source = match edit::replace_byte_range(
&source,
replace_start,
replacement_node.end_byte(),
&replacement_text,
) {
Ok(s) => s,
Err(e) => return Response::error(&req.id, e.code(), e.to_string()),
};
let backup_id = match edit::auto_backup(
ctx,
req.session(),
&path,
&format!("inline_symbol: {}", symbol),
Some(&op_id),
) {
Ok(id) => id,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
let mut write_result =
match edit::write_format_validate(&path, &new_source, &ctx.config(), &req.params) {
Ok(r) => r,
Err(e) => {
return Response::error(&req.id, e.code(), e.to_string());
}
};
if let Ok(final_content) = std::fs::read_to_string(&path) {
write_result.lsp_outcome = ctx.lsp_post_write(&path, &final_content, &req.params);
}
log::debug!("inline_symbol: {} at {}:{}", symbol, file, call_site_line);
let mut result = serde_json::json!({
"file": file,
"symbol": symbol,
"call_context": call_context,
"substitutions": substitution_count,
"conflicts": [],
"formatted": write_result.formatted,
});
if let Some(valid) = write_result.syntax_valid {
result["syntax_valid"] = serde_json::json!(valid);
}
if let Some(ref reason) = write_result.format_skipped_reason {
result["format_skipped_reason"] = serde_json::json!(reason);
}
if write_result.validate_requested {
result["validation_errors"] = serde_json::json!(write_result.validation_errors);
}
if let Some(ref reason) = write_result.validate_skipped_reason {
result["validate_skipped_reason"] = serde_json::json!(reason);
}
if let Some(ref id) = backup_id {
result["backup_id"] = serde_json::json!(id);
}
write_result.append_lsp_diagnostics_to(&mut result);
Response::success(&req.id, result)
}
fn find_function_node_at<'a>(
root: &tree_sitter::Node<'a>,
byte_pos: usize,
lang: LangId,
) -> Option<tree_sitter::Node<'a>> {
let fn_kinds: &[&str] = match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => &[
"function_declaration",
"method_definition",
"arrow_function",
],
LangId::Python => &["function_definition"],
_ => &[],
};
find_node_at(root, byte_pos, fn_kinds)
}
fn find_node_at<'a>(
node: &tree_sitter::Node<'a>,
byte_pos: usize,
kinds: &[&str],
) -> Option<tree_sitter::Node<'a>> {
if node.end_byte() <= byte_pos {
return None;
}
if kinds.contains(&node.kind()) && node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
return Some(*node);
}
if node.kind() == "export_statement"
&& node.start_byte() <= byte_pos
&& byte_pos < node.end_byte()
{
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if kinds.contains(&child.kind()) {
return Some(child);
}
if kinds.contains(&"arrow_function") && child.kind() == "lexical_declaration" {
let grand_count = child.child_count();
for j in 0..grand_count {
if let Some(grand) = child.child(j as u32) {
if grand.kind() == "variable_declarator" {
if let Some(value) = grand.child_by_field_name("value") {
if value.kind() == "arrow_function" {
return Some(value);
}
}
}
}
}
}
}
}
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if child.start_byte() <= byte_pos && byte_pos < child.end_byte() {
if let Some(found) = find_node_at(&child, byte_pos, kinds) {
return Some(found);
}
}
}
}
if kinds.contains(&"arrow_function") && node.kind() == "lexical_declaration" {
if node.start_byte() <= byte_pos && byte_pos < node.end_byte() {
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if child.kind() == "variable_declarator" {
if let Some(value) = child.child_by_field_name("value") {
if value.kind() == "arrow_function" {
return Some(value);
}
}
}
}
}
}
}
None
}
#[derive(Debug, Clone)]
enum InlineParam {
Simple {
name: String,
default_value: Option<String>,
},
Destructured {
bindings: Vec<DestructuredBinding>,
},
Rest {
name: String,
},
Unsupported {
description: String,
},
}
#[derive(Debug, Clone)]
struct DestructuredBinding {
name: String,
access_path: String,
}
impl InlineParam {
fn binding_names(&self) -> Vec<String> {
match self {
InlineParam::Simple { name, .. } | InlineParam::Rest { name } => vec![name.clone()],
InlineParam::Destructured { bindings } => bindings
.iter()
.map(|binding| binding.name.clone())
.collect(),
InlineParam::Unsupported { .. } => Vec::new(),
}
}
}
fn extract_fn_params_and_body(
fn_node: &tree_sitter::Node,
source: &str,
lang: LangId,
) -> (Vec<InlineParam>, String) {
let mut param_patterns = Vec::new();
let params_node = fn_node.child_by_field_name("parameters");
if let Some(params) = params_node {
let child_count = params.child_count();
for i in 0..child_count {
if let Some(child) = params.child(i as u32) {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if let Some(param) = extract_ts_param_pattern(&child, source) {
param_patterns.push(param);
}
}
LangId::Python => {
if child.kind() == "identifier" {
let name = node_text(source, &child).to_string();
if name != "self" {
param_patterns.push(InlineParam::Simple {
name,
default_value: None,
});
}
}
}
_ => {}
}
}
}
}
let body_text = if let Some(body) = fn_node.child_by_field_name("body") {
let raw = node_text(source, &body);
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
if body.kind() == "statement_block" {
strip_braces(raw)
} else {
raw.to_string()
}
}
LangId::Python => {
raw.to_string()
}
_ => raw.to_string(),
}
} else {
String::new()
};
(param_patterns, body_text)
}
fn extract_ts_param_pattern(node: &tree_sitter::Node, source: &str) -> Option<InlineParam> {
match node.kind() {
"required_parameter" | "optional_parameter" => {
let raw = node_text(source, node).trim_start();
if raw.starts_with("...") {
return node
.child_by_field_name("pattern")
.and_then(|pattern| first_identifier_name(&pattern, source))
.map(|name| InlineParam::Rest { name });
}
if let Some(default_value) = node.child_by_field_name("value") {
if let Some(pattern) = node.child_by_field_name("pattern") {
if pattern.kind() == "identifier" {
return Some(InlineParam::Simple {
name: node_text(source, &pattern).to_string(),
default_value: Some(
node_text(source, &default_value).trim().to_string(),
),
});
}
}
}
node.child_by_field_name("pattern")
.and_then(|pattern| extract_ts_param_pattern(&pattern, source))
}
"identifier" => Some(InlineParam::Simple {
name: node_text(source, node).to_string(),
default_value: None,
}),
"assignment_pattern" => {
let left = node
.child_by_field_name("left")
.or_else(|| node.child_by_field_name("pattern"));
let right = node
.child_by_field_name("right")
.or_else(|| node.child_by_field_name("value"));
match (left, right) {
(Some(left), Some(right)) if left.kind() == "identifier" => {
Some(InlineParam::Simple {
name: node_text(source, &left).to_string(),
default_value: Some(node_text(source, &right).trim().to_string()),
})
}
(Some(left), _) => extract_ts_param_pattern(&left, source),
_ => Some(InlineParam::Unsupported {
description: format!(
"unsupported default parameter `{}`",
node_text(source, node)
),
}),
}
}
"rest_pattern" => {
first_identifier_name(node, source).map(|name| InlineParam::Rest { name })
}
"object_pattern" | "array_pattern" => {
let mut bindings = Vec::new();
if collect_destructured_bindings(node, source, "", &mut bindings).is_some() {
Some(InlineParam::Destructured { bindings })
} else {
Some(InlineParam::Unsupported {
description: format!(
"unsupported destructured parameter `{}`",
node_text(source, node)
),
})
}
}
"(" | ")" | "," => None,
_ if !node.is_named() => None,
_ => None,
}
}
fn first_identifier_name(node: &tree_sitter::Node, source: &str) -> Option<String> {
if matches!(
node.kind(),
"identifier" | "shorthand_property_identifier_pattern"
) {
return Some(node_text(source, node).to_string());
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if let Some(name) = first_identifier_name(&child, source) {
return Some(name);
}
}
}
None
}
fn collect_destructured_bindings(
node: &tree_sitter::Node,
source: &str,
access_prefix: &str,
out: &mut Vec<DestructuredBinding>,
) -> Option<()> {
match node.kind() {
"identifier" | "shorthand_property_identifier_pattern" => {
out.push(DestructuredBinding {
name: node_text(source, node).to_string(),
access_path: access_prefix.to_string(),
});
Some(())
}
"assignment_pattern" => {
let left = node
.child_by_field_name("left")
.or_else(|| node.child_by_field_name("pattern"))?;
collect_destructured_bindings(&left, source, access_prefix, out)
}
"object_pattern" => collect_object_pattern_bindings(node, source, access_prefix, out),
"array_pattern" => collect_array_pattern_bindings(node, source, access_prefix, out),
_ => None,
}
}
fn collect_object_pattern_bindings(
node: &tree_sitter::Node,
source: &str,
access_prefix: &str,
out: &mut Vec<DestructuredBinding>,
) -> Option<()> {
let child_count = node.child_count();
for i in 0..child_count {
let Some(child) = node.child(i as u32) else {
continue;
};
match child.kind() {
"{" | "}" | "," => {}
"shorthand_property_identifier_pattern" | "identifier" => {
let name = node_text(source, &child);
out.push(DestructuredBinding {
name: name.to_string(),
access_path: format!("{}.{name}", access_prefix),
});
}
"pair_pattern" | "pair" => {
let key = child.child_by_field_name("key")?;
let value = child
.child_by_field_name("value")
.or_else(|| child.child_by_field_name("pattern"))?;
let property = property_access_for_key(&key, source)?;
let nested_access = format!("{}{}", access_prefix, property);
collect_destructured_bindings(&value, source, &nested_access, out)?;
}
"object_assignment_pattern" => {
let name = first_identifier_name(&child, source)?;
out.push(DestructuredBinding {
access_path: format!("{}.{name}", access_prefix),
name,
});
}
"rest_pattern" => return None,
_ if child.is_named() => return None,
_ => {}
}
}
Some(())
}
fn collect_array_pattern_bindings(
node: &tree_sitter::Node,
source: &str,
access_prefix: &str,
out: &mut Vec<DestructuredBinding>,
) -> Option<()> {
let mut element_index = 0usize;
let child_count = node.child_count();
for i in 0..child_count {
let Some(child) = node.child(i as u32) else {
continue;
};
match child.kind() {
"[" | "]" | "," => {}
"rest_pattern" => return None,
_ if child.is_named() => {
let nested_access = format!("{}[{element_index}]", access_prefix);
collect_destructured_bindings(&child, source, &nested_access, out)?;
element_index += 1;
}
_ => {}
}
}
Some(())
}
fn property_access_for_key(node: &tree_sitter::Node, source: &str) -> Option<String> {
let key = node_text(source, node).trim();
if is_simple_identifier(key) {
Some(format!(".{key}"))
} else {
None
}
}
fn build_param_to_arg_map(
params: &[InlineParam],
args: &[String],
) -> Result<HashMap<String, String>, String> {
let mut param_to_arg = HashMap::new();
for (i, param) in params.iter().enumerate() {
match param {
InlineParam::Simple {
name,
default_value,
} => {
if let Some(arg) = args.get(i) {
param_to_arg.insert(name.clone(), arg.clone());
} else if let Some(default_value) = default_value {
param_to_arg.insert(name.clone(), default_value.clone());
}
}
InlineParam::Destructured { bindings } => {
let Some(arg) = args.get(i) else {
return Err(format!(
"missing argument for destructured parameter {}",
i + 1
));
};
if !is_simple_identifier(arg.trim()) {
return Err(format!(
"destructured parameter {} requires a simple variable argument, got `{}`",
i + 1,
arg.trim()
));
}
for binding in bindings {
param_to_arg.insert(
binding.name.clone(),
format!("{}{}", arg.trim(), binding.access_path),
);
}
}
InlineParam::Rest { name } => {
let rest_args = if i < args.len() { &args[i..] } else { &[] };
param_to_arg.insert(name.clone(), format!("[{}]", rest_args.join(", ")));
}
InlineParam::Unsupported { description } => return Err(description.clone()),
}
}
Ok(param_to_arg)
}
fn is_simple_identifier(value: &str) -> bool {
let mut chars = value.chars();
let Some(first) = chars.next() else {
return false;
};
if !(first == '_' || first == '$' || first.is_ascii_alphabetic()) {
return false;
}
chars.all(|c| c == '_' || c == '$' || c.is_ascii_alphanumeric())
}
fn strip_braces(text: &str) -> String {
let trimmed = text.trim();
if trimmed.starts_with('{') && trimmed.ends_with('}') {
let inner = &trimmed[1..trimmed.len() - 1];
let inner = inner.trim_start_matches('\n').trim_end_matches('\n');
let min_indent = inner
.lines()
.filter(|l| !l.trim().is_empty())
.map(|l| l.len() - l.trim_start().len())
.min()
.unwrap_or(0);
inner
.lines()
.map(|l| {
if l.trim().is_empty() {
String::new()
} else if l.len() >= min_indent {
l[min_indent..].to_string()
} else {
l.to_string()
}
})
.collect::<Vec<_>>()
.join("\n")
} else {
text.to_string()
}
}
fn find_call_node_at_line<'a>(
root: &tree_sitter::Node<'a>,
symbol: &str,
source: &str,
start_byte: usize,
end_byte: usize,
lang: LangId,
) -> Option<tree_sitter::Node<'a>> {
let call_kinds: Vec<&str> = match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => vec!["call_expression"],
LangId::Python => vec!["call"],
_ => vec![],
};
find_call_recursive(root, symbol, source, start_byte, end_byte, &call_kinds)
}
fn find_call_recursive<'a>(
node: &tree_sitter::Node<'a>,
symbol: &str,
source: &str,
start_byte: usize,
end_byte: usize,
call_kinds: &[&str],
) -> Option<tree_sitter::Node<'a>> {
if node.end_byte() <= start_byte || node.start_byte() >= end_byte {
return None;
}
if call_kinds.contains(&node.kind())
&& node.start_byte() >= start_byte
&& node.start_byte() < end_byte
&& node.end_byte() <= source.len()
{
if let Some(callee_name) = crate::calls::extract_callee_name(node, source) {
if callee_name == symbol {
return Some(*node);
}
}
}
let child_count = node.child_count();
for i in 0..child_count {
if let Some(child) = node.child(i as u32) {
if let Some(found) =
find_call_recursive(&child, symbol, source, start_byte, end_byte, call_kinds)
{
return Some(found);
}
}
}
None
}
fn detect_call_context<'a>(
call_node: &tree_sitter::Node<'a>,
source: &str,
_lang: LangId,
) -> (String, tree_sitter::Node<'a>, Option<String>) {
if let Some(parent) = call_node.parent() {
let pk = parent.kind();
if pk == "variable_declarator" || pk == "assignment" || pk == "assignment_expression" {
let var_name = if pk == "variable_declarator" {
parent
.child_by_field_name("name")
.map(|n| node_text(source, &n).to_string())
} else {
parent
.child_by_field_name("left")
.map(|n| node_text(source, &n).to_string())
};
if let Some(grandparent) = parent.parent() {
let gpk = grandparent.kind();
if gpk == "lexical_declaration"
|| gpk == "variable_declaration"
|| gpk == "expression_statement"
{
return ("assignment".to_string(), grandparent, var_name);
}
}
return ("assignment".to_string(), parent, var_name);
}
if pk == "expression_statement" {
return ("standalone".to_string(), parent, None);
}
if pk == "return_statement" {
return ("return".to_string(), parent, None);
}
}
("standalone".to_string(), *call_node, None)
}
fn extract_call_arguments(
call_node: &tree_sitter::Node,
source: &str,
lang: LangId,
) -> Vec<String> {
let mut args = Vec::new();
let args_node = match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
call_node.child_by_field_name("arguments")
}
LangId::Python => call_node.child_by_field_name("arguments"),
_ => None,
};
if let Some(args_parent) = args_node {
let child_count = args_parent.child_count();
for i in 0..child_count {
if let Some(child) = args_parent.child(i as u32) {
if child.kind() != "("
&& child.kind() != ")"
&& child.kind() != ","
&& !child.kind().is_empty()
{
args.push(node_text(source, &child).to_string());
}
}
}
}
args
}
fn get_line_indent(source: &str, line: usize) -> String {
source
.lines()
.nth(line)
.map(|l| {
let trimmed = l.trim_start();
l[..l.len() - trimmed.len()].to_string()
})
.unwrap_or_default()
}
fn build_inline_replacement(
body: &str,
call_context: &str,
indent: &str,
lang: LangId,
assignment_var: Option<&str>,
) -> String {
let body_trimmed = body.trim();
let lines: Vec<&str> = body_trimmed.lines().collect();
match call_context {
"assignment" => {
if lines.len() == 1 {
let line = lines[0].trim();
if let Some(expr) = strip_return_prefix(line) {
if let Some(var) = assignment_var {
build_assignment_line(var, expr, indent, lang)
} else {
format!("{}{}", indent, expr)
}
} else {
if let Some(var) = assignment_var {
build_assignment_line(var, line, indent, lang)
} else {
format!("{}{}", indent, line)
}
}
} else {
build_multiline_assignment(&lines, indent, lang, assignment_var)
}
}
"standalone" => {
build_multiline_standalone(&lines, indent, lang)
}
"return" => {
build_multiline_replacement(&lines, indent, lang)
}
_ => build_multiline_replacement(&lines, indent, lang),
}
}
fn strip_return_prefix(line: &str) -> Option<&str> {
parse_return_statement(line).and_then(|expr| expr)
}
fn parse_return_statement(line: &str) -> Option<Option<&str>> {
let trimmed = line.trim_start();
if trimmed == "return" || trimmed == "return;" {
return Some(None);
}
let expr = trimmed.strip_prefix("return ")?;
let expr = expr.trim_end_matches(';').trim();
if expr.is_empty() {
Some(None)
} else {
Some(Some(expr))
}
}
fn build_assignment_line(var: &str, expr: &str, indent: &str, lang: LangId) -> String {
match lang {
LangId::TypeScript | LangId::Tsx | LangId::JavaScript => {
format!("{}const {} = {};", indent, var, expr)
}
LangId::Python => {
format!("{}{} = {}", indent, var, expr)
}
_ => format!("{}const {} = {};", indent, var, expr),
}
}
fn build_multiline_assignment(
lines: &[&str],
indent: &str,
lang: LangId,
assignment_var: Option<&str>,
) -> String {
let normalized = strip_common_indent(lines);
let mut result = Vec::new();
for line in &normalized {
let trimmed_start = line.trim_start();
if trimmed_start.is_empty() {
result.push(String::new());
} else if let Some(Some(expr)) = parse_return_statement(trimmed_start) {
if let Some(var) = assignment_var {
let line_indent = leading_whitespace(line);
let assignment_indent = format!("{}{}", indent, line_indent);
result.push(build_assignment_line(var, expr, &assignment_indent, lang));
}
} else if parse_return_statement(trimmed_start).is_some() {
} else {
result.push(format!("{}{}", indent, line));
}
}
result.join("\n")
}
fn build_multiline_standalone(lines: &[&str], indent: &str, lang: LangId) -> String {
let normalized = strip_common_indent(lines);
let mut result = Vec::new();
for line in &normalized {
let trimmed_start = line.trim_start();
if trimmed_start.is_empty() {
result.push(String::new());
} else if let Some(return_expr) = parse_return_statement(trimmed_start) {
if let Some(expr) = return_expr {
let line_indent = leading_whitespace(line);
match lang {
LangId::Python => result.push(format!("{}{}{}", indent, line_indent, expr)),
_ => result.push(format!("{}{}{};", indent, line_indent, expr)),
}
}
} else {
result.push(format!("{}{}", indent, line));
}
}
result.join("\n")
}
fn build_multiline_replacement(lines: &[&str], indent: &str, _lang: LangId) -> String {
strip_common_indent(lines)
.iter()
.map(|line| {
if line.trim().is_empty() {
String::new()
} else {
format!("{}{}", indent, line)
}
})
.collect::<Vec<_>>()
.join("\n")
}
fn strip_common_indent(lines: &[&str]) -> Vec<String> {
let common_indent = lines
.iter()
.filter(|line| !line.trim().is_empty())
.map(|line| leading_whitespace(line).len())
.min()
.unwrap_or(0);
lines
.iter()
.map(|line| {
if line.trim().is_empty() {
String::new()
} else {
line.get(common_indent..).unwrap_or(line).to_string()
}
})
.collect()
}
fn leading_whitespace(line: &str) -> &str {
let trimmed = line.trim_start();
&line[..line.len() - trimmed.len()]
}
#[cfg(test)]
mod tests {
use super::*;
use crate::protocol::RawRequest;
fn make_request(id: &str, command: &str, params: serde_json::Value) -> RawRequest {
RawRequest {
id: id.to_string(),
command: command.to_string(),
params,
lsp_hints: None,
session_id: None,
}
}
fn fixture_path(name: &str) -> std::path::PathBuf {
std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
.join("tests")
.join("fixtures")
.join("inline_symbol")
.join(name)
}
#[test]
fn inline_symbol_missing_file() {
let req = make_request("1", "inline_symbol", serde_json::json!({}));
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], false);
assert_eq!(json["code"], "invalid_request");
let msg = json["message"].as_str().unwrap();
assert!(
msg.contains("file"),
"message should mention 'file': {}",
msg
);
}
#[test]
fn inline_symbol_missing_symbol() {
let req = make_request(
"2",
"inline_symbol",
serde_json::json!({"file": "/tmp/test.ts"}),
);
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], false);
assert_eq!(json["code"], "invalid_request");
let msg = json["message"].as_str().unwrap();
assert!(
msg.contains("symbol"),
"message should mention 'symbol': {}",
msg
);
}
#[test]
fn inline_symbol_missing_call_site_line() {
let req = make_request(
"3",
"inline_symbol",
serde_json::json!({"file": "/tmp/test.ts", "symbol": "foo"}),
);
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], false);
assert_eq!(json["code"], "invalid_request");
}
#[test]
fn inline_symbol_unsupported_language() {
let dir = std::env::temp_dir().join("aft_test_inline_lang");
std::fs::create_dir_all(&dir).ok();
let file = dir.join("test.rs");
std::fs::write(&file, "fn main() {}").unwrap();
let req = make_request(
"4",
"inline_symbol",
serde_json::json!({
"file": file.display().to_string(),
"symbol": "main",
"call_site_line": 1,
}),
);
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], false);
assert_eq!(json["code"], "unsupported_language");
std::fs::remove_dir_all(&dir).ok();
}
#[test]
fn inline_symbol_multiple_returns_rejected() {
let fixture = fixture_path("sample_multi.ts");
let req = make_request(
"5",
"inline_symbol",
serde_json::json!({
"file": fixture.display().to_string(),
"symbol": "multiReturn",
"call_site_line": 9,
}),
);
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(json["success"], false);
assert_eq!(json["code"], "multiple_returns");
assert!(json["return_count"].as_u64().unwrap() >= 2);
}
#[test]
fn inline_symbol_scope_conflict_reported() {
let fixture = fixture_path("sample_conflict.ts");
let req = make_request(
"6",
"inline_symbol",
serde_json::json!({
"file": fixture.display().to_string(),
"symbol": "compute",
"call_site_line": 9,
}),
);
let ctx = crate::context::AppContext::new(
Box::new(crate::parser::TreeSitterProvider::new()),
crate::config::Config::default(),
);
let resp = handle_inline_symbol(&req, &ctx);
let json = serde_json::to_value(&resp).unwrap();
assert_eq!(
json["success"], false,
"should fail with scope_conflict: {:?}",
json
);
assert_eq!(json["code"], "scope_conflict");
let conflicting = json["conflicting_names"].as_array().unwrap();
assert!(
!conflicting.is_empty(),
"should report at least one conflict: {:?}",
conflicting
);
}
}