1use crate::project::ProjectRoot;
2use crate::rename::{RenameEdit, apply_edits, find_all_word_matches};
3use crate::symbols::{find_symbol, find_symbol_range};
4use anyhow::{Result, bail};
5use serde::Serialize;
6use std::fs;
7
8#[derive(Debug, Clone, Serialize)]
9pub struct InlineResult {
10 pub success: bool,
11 pub message: String,
12 pub call_sites_inlined: usize,
13 pub definition_removed: bool,
14 pub modified_files: Vec<String>,
15 pub edits: Vec<RenameEdit>,
16}
17
18pub fn inline_function(
23 project: &ProjectRoot,
24 file_path: &str,
25 function_name: &str,
26 name_path: Option<&str>,
27 dry_run: bool,
28) -> Result<InlineResult> {
29 let symbols = find_symbol(project, function_name, Some(file_path), true, true, 1)?;
31 let sym = symbols.first().ok_or_else(|| {
32 anyhow::anyhow!("Function '{}' not found in '{}'", function_name, file_path)
33 })?;
34
35 let kind_str = format!("{:?}", sym.kind).to_lowercase();
36 if kind_str != "function" && kind_str != "method" {
37 bail!(
38 "'{}' is a {}, not a function/method",
39 function_name,
40 kind_str
41 );
42 }
43
44 let resolved = project.resolve(file_path)?;
45 let source = fs::read_to_string(&resolved)?;
46
47 let (start_byte, end_byte) = find_symbol_range(project, file_path, function_name, name_path)?;
49 let full_def = &source[start_byte..end_byte];
50
51 let (params, body) = parse_function_parts(full_def, file_path)?;
53
54 let matches = find_all_word_matches(project, function_name)?;
56
57 let mut call_sites = Vec::new();
59 for (rel_path, line, col) in &matches {
60 if rel_path == file_path && *line == sym.line {
62 continue;
63 }
64 let call_file = project.resolve(rel_path)?;
65 let call_source = match fs::read_to_string(&call_file) {
66 Ok(s) => s,
67 Err(_) => continue,
68 };
69 let lines: Vec<&str> = call_source.lines().collect();
70 if *line == 0 || *line > lines.len() {
71 continue;
72 }
73 let line_text = lines[*line - 1];
74 let after_name = *col - 1 + function_name.len();
75 let rest = &line_text[after_name..].trim_start();
76 if rest.starts_with('(') {
77 if let Some(args) = extract_call_args(line_text, *col - 1) {
79 call_sites.push((rel_path.clone(), *line, *col, args));
80 }
81 }
82 }
83
84 if call_sites.is_empty() {
85 return Ok(InlineResult {
86 success: true,
87 message: format!(
88 "No call sites found for '{}'. Definition kept.",
89 function_name
90 ),
91 call_sites_inlined: 0,
92 definition_removed: false,
93 modified_files: vec![],
94 edits: vec![],
95 });
96 }
97
98 let body_lines: Vec<&str> = body.lines().collect();
100 let is_single_expression = body_lines.len() <= 1;
101
102 if !is_single_expression && call_sites.len() > 1 {
103 bail!(
104 "Cannot inline multi-statement function '{}' with {} call sites. \
105 Inline manually or reduce to a single expression.",
106 function_name,
107 call_sites.len()
108 );
109 }
110
111 let mut edits = Vec::new();
112
113 for (rel_path, line, col, args) in &call_sites {
114 let call_file = project.resolve(rel_path)?;
115 let call_source = fs::read_to_string(&call_file)?;
116 let lines_vec: Vec<&str> = call_source.lines().collect();
117 let line_text = lines_vec[*line - 1];
118
119 let call_start = *col - 1;
121 let call_end = find_call_end(line_text, call_start)?;
122 let call_text = &line_text[call_start..call_end];
123
124 let mut inlined_body = body.trim().to_string();
126 for (i, param) in params.iter().enumerate() {
127 if let Some(arg) = args.get(i) {
128 let param_re = regex::Regex::new(&format!(r"\b{}\b", regex::escape(param)))?;
129 inlined_body = param_re.replace_all(&inlined_body, arg.trim()).to_string();
130 }
131 }
132
133 let inlined_body = strip_return_keyword(&inlined_body);
135
136 edits.push(RenameEdit {
137 file_path: rel_path.clone(),
138 line: *line,
139 column: *col,
140 old_text: call_text.to_string(),
141 new_text: inlined_body,
142 });
143 }
144
145 let (start_byte_2, end_byte_2) = (start_byte, end_byte);
147 let def_start_line = source[..start_byte_2].lines().count();
148 let def_end_line = source[..end_byte_2].lines().count();
149
150 let mut modified_files: Vec<String> = edits.iter().map(|e| e.file_path.clone()).collect();
151 if !modified_files.contains(&file_path.to_string()) {
152 modified_files.push(file_path.to_string());
153 }
154 modified_files.sort();
155 modified_files.dedup();
156
157 let result = InlineResult {
158 success: true,
159 message: format!(
160 "Inlined '{}' at {} call site(s) and removed definition",
161 function_name,
162 call_sites.len()
163 ),
164 call_sites_inlined: call_sites.len(),
165 definition_removed: true,
166 modified_files,
167 edits: edits.clone(),
168 };
169
170 if !dry_run {
171 apply_edits(project, &edits)?;
173
174 let resolved = project.resolve(file_path)?;
176 let content = fs::read_to_string(&resolved)?;
177 let mut lines: Vec<String> = content.lines().map(String::from).collect();
178
179 let start_line_idx = if def_start_line > 0 {
181 def_start_line - 1
182 } else {
183 0
184 };
185 let end_line_idx = def_end_line.min(lines.len());
186
187 let drain_start = if start_line_idx > 0 && lines[start_line_idx - 1].trim().is_empty() {
189 start_line_idx - 1
190 } else {
191 start_line_idx
192 };
193 lines.drain(drain_start..end_line_idx);
194
195 let mut result_text = lines.join("\n");
196 if content.ends_with('\n') {
197 result_text.push('\n');
198 }
199 fs::write(&resolved, &result_text)?;
200 }
201
202 Ok(result)
203}
204
205fn parse_function_parts(def: &str, file_path: &str) -> Result<(Vec<String>, String)> {
207 let paren_start = def
209 .find('(')
210 .ok_or_else(|| anyhow::anyhow!("No parameter list found"))?;
211 let paren_end = find_matching_paren(def, paren_start)?;
212
213 let params_str = &def[paren_start + 1..paren_end];
214 let params: Vec<String> = if params_str.trim().is_empty() {
215 vec![]
216 } else {
217 parse_param_names(params_str, file_path)
218 };
219
220 let ext = std::path::Path::new(file_path)
222 .extension()
223 .and_then(|e| e.to_str())
224 .unwrap_or("");
225
226 let body = if ext == "py" {
227 let colon_pos = def[paren_end..].find(':').map(|p| p + paren_end);
229 if let Some(cp) = colon_pos {
230 let after_colon = &def[cp + 1..];
231 dedent_body(after_colon.trim_start_matches([' ', '\t']))
232 } else {
233 String::new()
234 }
235 } else {
236 let brace_start = def[paren_end..].find('{').map(|p| p + paren_end);
238 let brace_end = def.rfind('}');
239 match (brace_start, brace_end) {
240 (Some(bs), Some(be)) if be > bs => dedent_body(&def[bs + 1..be]),
241 _ => String::new(),
242 }
243 };
244
245 Ok((params, body))
246}
247
/// Extracts parameter names, in order, from the text between a definition's
/// parentheses, using `file_path`'s extension to pick the syntax.
///
/// Parameters are split on top-level commas only, so commas inside `()`, `[]`
/// and `{}` do not split a parameter. Outside Python, commas inside `<>`
/// (generics) do not split either; `->` and `=>` are not treated as closing
/// brackets. Default values (`= ...`) are dropped.
///
/// Name extraction by language:
/// - Go: the first word is the name.
/// - Typed `name: Type` parameters: the last word before the type colon is the
///   name. This strips `mut`, Swift labels and modifiers, and trims a TS
///   optional `?`. A `::` path separator is not a type colon.
/// - Python without a colon: the text is kept as-is.
/// - Otherwise (C-like, Java): the last word is the name, with leading `&`/`*`
///   and any `[...]` suffix removed.
///
/// Receivers (`self`, `&self`, `mut self`, `self: T`, `this`) are omitted. So
/// are Python's bare `*` and `/` separators, which are not parameters.
fn parse_param_names(params_str: &str, file_path: &str) -> Vec<String> {
    let ext = std::path::Path::new(file_path)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("");

    // Python has no `<...>` generics in parameter lists, and `<` may appear in
    // its default-value expressions, so angle brackets only nest elsewhere.
    split_params(params_str, ext != "py")
        .into_iter()
        .filter_map(|raw| {
            let p = raw.trim();
            if p.is_empty() || p == "self" || p == "&self" || p == "&mut self" || p == "this" {
                return None;
            }
            if ext == "py" && (p == "*" || p == "/") {
                return None;
            }
            let p = p.split('=').next().unwrap_or(p).trim();
            let name = if ext == "go" {
                p.split_whitespace().next().unwrap_or(p)
            } else if let Some(colon) = find_type_colon(p) {
                p[..colon]
                    .split_whitespace()
                    .last()
                    .unwrap_or("")
                    .trim_end_matches('?')
            } else if ext == "py" {
                p
            } else {
                let last = p.split_whitespace().last().unwrap_or(p);
                let last = last.trim_start_matches(['&', '*']);
                last.split('[').next().unwrap_or(last)
            };
            if name.is_empty() || (ext != "go" && (name == "self" || name == "this")) {
                None
            } else {
                Some(name.to_string())
            }
        })
        .collect()
}

/// Splits `params` on commas that are not nested inside brackets.
///
/// `()`, `[]` and `{}` always nest. `<>` nests only when `angle_brackets` is
/// true, and a `>` that is part of `->` or `=>` does not close.
fn split_params(params: &str, angle_brackets: bool) -> Vec<&str> {
    let mut parts = Vec::new();
    let mut depth: i32 = 0;
    let mut start = 0;
    let mut prev = '\0';
    for (i, ch) in params.char_indices() {
        match ch {
            '(' | '[' | '{' => depth += 1,
            ')' | ']' | '}' => depth -= 1,
            '<' if angle_brackets => depth += 1,
            '>' if angle_brackets && prev != '-' && prev != '=' => depth -= 1,
            ',' if depth <= 0 => {
                parts.push(&params[start..i]);
                start = i + 1;
            }
            _ => {}
        }
        prev = ch;
    }
    parts.push(&params[start..]);
    parts
}

/// Byte index of the first `:` in `param` that is not part of a `::` path separator.
fn find_type_colon(param: &str) -> Option<usize> {
    let bytes = param.as_bytes();
    (0..bytes.len()).find(|&i| {
        bytes[i] == b':' && bytes.get(i + 1) != Some(&b':') && (i == 0 || bytes[i - 1] != b':')
    })
}
295
296fn find_matching_paren(s: &str, open_pos: usize) -> Result<usize> {
298 let mut depth = 0;
299 for (i, ch) in s[open_pos..].char_indices() {
300 match ch {
301 '(' => depth += 1,
302 ')' => {
303 depth -= 1;
304 if depth == 0 {
305 return Ok(open_pos + i);
306 }
307 }
308 _ => {}
309 }
310 }
311 bail!("Unmatched parenthesis")
312}
313
314fn extract_call_args(line: &str, name_start: usize) -> Option<Vec<String>> {
316 let rest = &line[name_start..];
318 let paren_start = rest.find('(')?;
319 let paren_end = find_matching_paren(rest, paren_start).ok()?;
320 let args_str = &rest[paren_start + 1..paren_end];
321 if args_str.trim().is_empty() {
322 return Some(vec![]);
323 }
324 Some(split_args(args_str))
325}
326
/// Splits a call's argument text on commas that are not nested inside
/// brackets or string/char literals.
///
/// `()`, `[]` and `{}` nest. Text inside `"..."`, `'...'` or `` `...` `` is
/// copied verbatim, with backslash escapes respected. Each argument is
/// trimmed, and a trailing empty argument (as after a trailing comma) is
/// dropped.
fn split_args(s: &str) -> Vec<String> {
    let mut args = Vec::new();
    let mut depth: i32 = 0;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut current = String::new();
    for ch in s.chars() {
        if let Some(q) = quote {
            current.push(ch);
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '"' | '\'' | '`' => {
                quote = Some(ch);
                current.push(ch);
            }
            '(' | '[' | '{' => {
                depth += 1;
                current.push(ch);
            }
            ')' | ']' | '}' => {
                depth -= 1;
                current.push(ch);
            }
            ',' if depth == 0 => {
                args.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        args.push(current.trim().to_string());
    }
    args
}
354
355fn find_call_end(line: &str, name_start: usize) -> Result<usize> {
357 let rest = &line[name_start..];
358 let paren_start = rest
359 .find('(')
360 .ok_or_else(|| anyhow::anyhow!("No opening paren"))?;
361 let paren_end = find_matching_paren(rest, paren_start)?;
362 Ok(name_start + paren_end + 1)
363}
364
/// Turns a trimmed function body into the expression that replaces a call.
///
/// Drops a leading `return` keyword (when followed by whitespace or `(`),
/// then removes trailing semicolons and surrounding whitespace. Identifiers
/// that merely start with `return` (e.g. `returned`) are left intact.
fn strip_return_keyword(body: &str) -> String {
    let trimmed = body.trim();
    let expr = match trimmed.strip_prefix("return") {
        Some(rest) if rest.starts_with(|c: char| c.is_whitespace() || c == '(') => {
            rest.trim_start()
        }
        _ => trimmed,
    };
    expr.trim_end_matches(';').trim_end().to_string()
}
374
/// Removes the indentation shared by all non-blank lines of `body` and trims
/// the result. Returns an empty string when every line is blank.
///
/// Indentation is measured in bytes of leading whitespace. Blank lines shorter
/// than that count lose all their whitespace. A line whose indentation cannot
/// be cut at that byte count without splitting a multi-byte whitespace
/// character is stripped of all leading whitespace instead of panicking.
fn dedent_body(body: &str) -> String {
    let lines: Vec<&str> = body.lines().collect();
    let min_indent = lines
        .iter()
        .filter(|l| !l.trim().is_empty())
        .map(|l| l.len() - l.trim_start().len())
        .min();
    let min_indent = match min_indent {
        Some(n) => n,
        None => return String::new(),
    };
    lines
        .iter()
        .map(|l| l.get(min_indent..).unwrap_or_else(|| l.trim_start()))
        .collect::<Vec<_>>()
        .join("\n")
        .trim()
        .to_string()
}
401
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::fs;
    use std::path::PathBuf;

    /// Scratch project directory that is deleted on drop, so a failing
    /// assertion does not leave it behind in the system temp directory.
    struct Fixture {
        dir: PathBuf,
        project: ProjectRoot,
    }

    impl Fixture {
        fn new() -> Self {
            let nanos = std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos();
            // The process id keeps concurrently running test binaries apart.
            let dir = std::env::temp_dir().join(format!(
                "codelens-inline-fixture-{}-{}",
                std::process::id(),
                nanos
            ));
            fs::create_dir_all(&dir).unwrap();
            let project = ProjectRoot::new(dir.clone()).unwrap();
            Fixture { dir, project }
        }
    }

    impl Drop for Fixture {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.dir);
        }
    }

    #[test]
    fn test_parse_function_parts_js() {
        let def = "function add(a, b) {\n    return a + b;\n}";
        let (params, body) = parse_function_parts(def, "test.js").unwrap();
        assert_eq!(params, vec!["a", "b"]);
        assert!(body.contains("return a + b"));
    }

    #[test]
    fn test_parse_function_parts_python() {
        let def = "def add(x, y):\n    return x + y";
        let (params, body) = parse_function_parts(def, "test.py").unwrap();
        assert_eq!(params, vec!["x", "y"]);
        assert!(body.contains("return x + y"));
    }

    #[test]
    fn test_parse_function_parts_rust() {
        let def = "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}";
        let (params, body) = parse_function_parts(def, "test.rs").unwrap();
        assert_eq!(params, vec!["a", "b"]);
        assert!(body.contains("a + b"));
    }

    #[test]
    fn test_extract_call_args() {
        let line = "let result = add(1, 2);";
        let args = extract_call_args(line, 13).unwrap();
        assert_eq!(args, vec!["1", "2"]);
    }

    #[test]
    fn test_extract_call_args_nested() {
        let line = "let result = add(foo(1), bar(2, 3));";
        let args = extract_call_args(line, 13).unwrap();
        assert_eq!(args, vec!["foo(1)", "bar(2, 3)"]);
    }

    #[test]
    fn test_strip_return_keyword() {
        assert_eq!(strip_return_keyword("return x + y;"), "x + y");
        assert_eq!(strip_return_keyword("x + y"), "x + y");
    }

    #[test]
    fn test_dedent_body() {
        let body = "    let x = 1;\n    let y = 2;\n    x + y";
        let result = dedent_body(body);
        assert_eq!(result, "let x = 1;\nlet y = 2;\nx + y");
    }

    #[test]
    fn test_inline_dry_run() {
        let fixture = Fixture::new();
        let main_path = fixture.dir.join("main.js");

        let main_content = r#"function greet(name) {
    return "Hello, " + name;
}

let msg = greet("World");
console.log(greet("Rust"));
"#;
        fs::write(&main_path, main_content).unwrap();

        let result = inline_function(&fixture.project, "main.js", "greet", None, true).unwrap();
        assert!(result.success);
        assert_eq!(result.call_sites_inlined, 2);
        assert!(result.definition_removed);

        // A dry run must leave the file untouched.
        let after = fs::read_to_string(&main_path).unwrap();
        assert_eq!(after, main_content);
    }
}