// codelens_engine/inline.rs

use crate::project::ProjectRoot;
use crate::rename::{RenameEdit, apply_edits, find_all_word_matches};
use crate::symbols::{find_symbol, find_symbol_range};
use anyhow::{Result, bail};
use serde::Serialize;
use std::collections::HashMap;
use std::fs;

8#[derive(Debug, Clone, Serialize)]
9pub struct InlineResult {
10    pub success: bool,
11    pub message: String,
12    pub call_sites_inlined: usize,
13    pub definition_removed: bool,
14    pub modified_files: Vec<String>,
15    pub edits: Vec<RenameEdit>,
16}
17
18/// Inline a function: replace all call sites with the function body, then remove the definition.
19///
20/// Supports single-expression and multi-statement bodies. For multi-statement bodies,
21/// only single call-site inlining is supported (otherwise ambiguous).
22pub fn inline_function(
23    project: &ProjectRoot,
24    file_path: &str,
25    function_name: &str,
26    name_path: Option<&str>,
27    dry_run: bool,
28) -> Result<InlineResult> {
29    // 1. Find the function definition
30    let symbols = find_symbol(project, function_name, Some(file_path), true, true, 1)?;
31    let sym = symbols.first().ok_or_else(|| {
32        anyhow::anyhow!("Function '{}' not found in '{}'", function_name, file_path)
33    })?;
34
35    let kind_str = format!("{:?}", sym.kind).to_lowercase();
36    if kind_str != "function" && kind_str != "method" {
37        bail!(
38            "'{}' is a {}, not a function/method",
39            function_name,
40            kind_str
41        );
42    }
43
44    let resolved = project.resolve(file_path)?;
45    let source = fs::read_to_string(&resolved)?;
46
47    // 2. Extract function body (between the symbol range)
48    let (start_byte, end_byte) = find_symbol_range(project, file_path, function_name, name_path)?;
49    let full_def = &source[start_byte..end_byte];
50
51    // 3. Parse parameters and body from the definition
52    let (params, body) = parse_function_parts(full_def, file_path)?;
53
54    // 4. Find all call sites across the project
55    let matches = find_all_word_matches(project, function_name)?;
56
57    // Filter to actual call sites (followed by '(')
58    let mut call_sites = Vec::new();
59    for (rel_path, line, col) in &matches {
60        // Skip the definition itself
61        if rel_path == file_path && *line == sym.line {
62            continue;
63        }
64        let call_file = project.resolve(rel_path)?;
65        let call_source = match fs::read_to_string(&call_file) {
66            Ok(s) => s,
67            Err(_) => continue,
68        };
69        let lines: Vec<&str> = call_source.lines().collect();
70        if *line == 0 || *line > lines.len() {
71            continue;
72        }
73        let line_text = lines[*line - 1];
74        let after_name = *col - 1 + function_name.len();
75        let rest = &line_text[after_name..].trim_start();
76        if rest.starts_with('(') {
77            // Extract the arguments
78            if let Some(args) = extract_call_args(line_text, *col - 1) {
79                call_sites.push((rel_path.clone(), *line, *col, args));
80            }
81        }
82    }
83
84    if call_sites.is_empty() {
85        return Ok(InlineResult {
86            success: true,
87            message: format!(
88                "No call sites found for '{}'. Definition kept.",
89                function_name
90            ),
91            call_sites_inlined: 0,
92            definition_removed: false,
93            modified_files: vec![],
94            edits: vec![],
95        });
96    }
97
98    // 5. Build edits for each call site
99    let body_lines: Vec<&str> = body.lines().collect();
100    let is_single_expression = body_lines.len() <= 1;
101
102    if !is_single_expression && call_sites.len() > 1 {
103        bail!(
104            "Cannot inline multi-statement function '{}' with {} call sites. \
105             Inline manually or reduce to a single expression.",
106            function_name,
107            call_sites.len()
108        );
109    }
110
111    let mut edits = Vec::new();
112
113    for (rel_path, line, col, args) in &call_sites {
114        let call_file = project.resolve(rel_path)?;
115        let call_source = fs::read_to_string(&call_file)?;
116        let lines_vec: Vec<&str> = call_source.lines().collect();
117        let line_text = lines_vec[*line - 1];
118
119        // Find the full call expression span (name + args including parens)
120        let call_start = *col - 1;
121        let call_end = find_call_end(line_text, call_start)?;
122        let call_text = &line_text[call_start..call_end];
123
124        // Substitute parameters with arguments in the body
125        let mut inlined_body = body.trim().to_string();
126        for (i, param) in params.iter().enumerate() {
127            if let Some(arg) = args.get(i) {
128                let param_re = regex::Regex::new(&format!(r"\b{}\b", regex::escape(param)))?;
129                inlined_body = param_re.replace_all(&inlined_body, arg.trim()).to_string();
130            }
131        }
132
133        // For single-expression: strip return keyword if present
134        let inlined_body = strip_return_keyword(&inlined_body);
135
136        edits.push(RenameEdit {
137            file_path: rel_path.clone(),
138            line: *line,
139            column: *col,
140            old_text: call_text.to_string(),
141            new_text: inlined_body,
142        });
143    }
144
145    // 6. Add edit to remove the function definition
146    let (start_byte_2, end_byte_2) = (start_byte, end_byte);
147    let def_start_line = source[..start_byte_2].lines().count();
148    let def_end_line = source[..end_byte_2].lines().count();
149
150    let mut modified_files: Vec<String> = edits.iter().map(|e| e.file_path.clone()).collect();
151    if !modified_files.contains(&file_path.to_string()) {
152        modified_files.push(file_path.to_string());
153    }
154    modified_files.sort();
155    modified_files.dedup();
156
157    let result = InlineResult {
158        success: true,
159        message: format!(
160            "Inlined '{}' at {} call site(s) and removed definition",
161            function_name,
162            call_sites.len()
163        ),
164        call_sites_inlined: call_sites.len(),
165        definition_removed: true,
166        modified_files,
167        edits: edits.clone(),
168    };
169
170    if !dry_run {
171        // Apply call site edits first
172        apply_edits(project, &edits)?;
173
174        // Remove the function definition lines
175        let resolved = project.resolve(file_path)?;
176        let content = fs::read_to_string(&resolved)?;
177        let mut lines: Vec<String> = content.lines().map(String::from).collect();
178
179        // Recalculate definition line range from bytes
180        let start_line_idx = if def_start_line > 0 {
181            def_start_line - 1
182        } else {
183            0
184        };
185        let end_line_idx = def_end_line.min(lines.len());
186
187        // Remove preceding blank line if any
188        let drain_start = if start_line_idx > 0 && lines[start_line_idx - 1].trim().is_empty() {
189            start_line_idx - 1
190        } else {
191            start_line_idx
192        };
193        lines.drain(drain_start..end_line_idx);
194
195        let mut result_text = lines.join("\n");
196        if content.ends_with('\n') {
197            result_text.push('\n');
198        }
199        fs::write(&resolved, &result_text)?;
200    }
201
202    Ok(result)
203}
204
205/// Parse function parameters and body from a function definition string.
206fn parse_function_parts(def: &str, file_path: &str) -> Result<(Vec<String>, String)> {
207    // Find parameter list between first ( and matching )
208    let paren_start = def
209        .find('(')
210        .ok_or_else(|| anyhow::anyhow!("No parameter list found"))?;
211    let paren_end = find_matching_paren(def, paren_start)?;
212
213    let params_str = &def[paren_start + 1..paren_end];
214    let params: Vec<String> = if params_str.trim().is_empty() {
215        vec![]
216    } else {
217        parse_param_names(params_str, file_path)
218    };
219
220    // Find body: after '{' for brace languages, after ':' for Python
221    let ext = std::path::Path::new(file_path)
222        .extension()
223        .and_then(|e| e.to_str())
224        .unwrap_or("");
225
226    let body = if ext == "py" {
227        // Python: body is everything after the first colon+newline, de-indented
228        let colon_pos = def[paren_end..].find(':').map(|p| p + paren_end);
229        if let Some(cp) = colon_pos {
230            let after_colon = &def[cp + 1..];
231            dedent_body(after_colon.trim_start_matches([' ', '\t']))
232        } else {
233            String::new()
234        }
235    } else {
236        // Brace languages: body is between first { and last }
237        let brace_start = def[paren_end..].find('{').map(|p| p + paren_end);
238        let brace_end = def.rfind('}');
239        match (brace_start, brace_end) {
240            (Some(bs), Some(be)) if be > bs => dedent_body(&def[bs + 1..be]),
241            _ => String::new(),
242        }
243    };
244
245    Ok((params, body))
246}
247
/// Extract just the parameter names from a parameter-list string, handling typed params.
///
/// The list is split on top-level commas only, so generic, tuple, and function types such
/// as `HashMap<K, V>` or `Fn(i32, i32) -> i32`, and string defaults such as `sep=", "`, stay
/// attached to their parameter. Default values are dropped. Receiver parameters (`self`,
/// `&self`, `&mut self`, `mut self`, `&'a self`, `this`, and a leading Python `cls`) are
/// skipped because call sites never pass them explicitly.
///
/// Name extraction depends on the extension of `file_path`:
/// - `rs`: the last token of the pattern before `:` (`mut x` gives `x`), without `&`;
/// - `go`: the first token (`a int`);
/// - `py`: the text before an optional `:` annotation (`*args` keeps its stars);
/// - others: the last token before a `:` annotation (`name: T`, Swift `label name: T`)
///   or of the whole declaration (`Type name`). C-style `*`/`&` prefixes, a `[]` suffix
///   and a TS optional `?` are removed. `::` path separators are not treated as annotations.
fn parse_param_names(params_str: &str, file_path: &str) -> Vec<String> {
    let ext = std::path::Path::new(file_path)
        .extension()
        .and_then(|e| e.to_str())
        .unwrap_or("");

    split_top_level_params(params_str, ext)
        .into_iter()
        .enumerate()
        .filter_map(|(index, raw)| {
            let p = raw.trim();
            if p.is_empty() || p == "self" || p == "&self" || p == "&mut self" || p == "this" {
                return None;
            }
            // Remove default values
            let p = p.split('=').next().unwrap_or(p).trim();
            // Extract name based on language
            let name = match ext {
                "rs" => before_type_annotation(p)
                    .split_whitespace()
                    .last()
                    .unwrap_or("")
                    .trim_start_matches('&'),
                "go" => p.split_whitespace().next().unwrap_or(p),
                "py" => before_type_annotation(p).trim(),
                _ => before_type_annotation(p)
                    .split_whitespace()
                    .last()
                    .unwrap_or("")
                    .trim_start_matches(['*', '&'])
                    .trim_end_matches("[]")
                    .trim_end_matches('?'),
            };
            let is_receiver = name == "self"
                || name == "this"
                || (ext == "py" && index == 0 && name == "cls");
            if name.is_empty() || is_receiver {
                None
            } else {
                Some(name.to_string())
            }
        })
        .collect()
}

/// Split a parameter list on commas that are not nested inside brackets or string literals.
///
/// `()`, `[]`, `{}` always nest. `<>` nests except in Python, and the `>` of `->`/`=>`
/// never closes a bracket. `"`/`` ` `` strings are always skipped; `'` strings are skipped
/// except in Rust, where `'` starts a lifetime.
fn split_top_level_params<'a>(s: &'a str, ext: &str) -> Vec<&'a str> {
    let angle_brackets = ext != "py";
    let single_quote_strings = ext != "rs";
    let mut parts = Vec::new();
    let mut depth: usize = 0;
    let mut start = 0;
    let mut prev = '\0';
    let mut quote: Option<char> = None;
    let mut escaped = false;

    for (i, ch) in s.char_indices() {
        if let Some(q) = quote {
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            prev = ch;
            continue;
        }
        match ch {
            '"' | '`' => quote = Some(ch),
            '\'' if single_quote_strings => quote = Some(ch),
            '(' | '[' | '{' => depth += 1,
            '<' if angle_brackets => depth += 1,
            ')' | ']' | '}' => depth = depth.saturating_sub(1),
            '>' if angle_brackets && prev != '-' && prev != '=' => {
                depth = depth.saturating_sub(1)
            }
            ',' if depth == 0 => {
                parts.push(&s[start..i]);
                start = i + 1;
            }
            _ => {}
        }
        prev = ch;
    }
    parts.push(&s[start..]);
    parts
}

/// Return the part of `p` before its first single `:` (a type annotation), ignoring `::`
/// path separators; the whole string when there is no annotation.
fn before_type_annotation(p: &str) -> &str {
    let bytes = p.as_bytes();
    let mut i = 0;
    while i < bytes.len() {
        if bytes[i] == b':' {
            if bytes.get(i + 1) == Some(&b':') {
                i += 2;
                continue;
            }
            return &p[..i];
        }
        i += 1;
    }
    p
}

296/// Find matching closing parenthesis, handling nesting.
297fn find_matching_paren(s: &str, open_pos: usize) -> Result<usize> {
298    let mut depth = 0;
299    for (i, ch) in s[open_pos..].char_indices() {
300        match ch {
301            '(' => depth += 1,
302            ')' => {
303                depth -= 1;
304                if depth == 0 {
305                    return Ok(open_pos + i);
306                }
307            }
308            _ => {}
309        }
310    }
311    bail!("Unmatched parenthesis")
312}
313
314/// Extract arguments from a function call at the given position.
315fn extract_call_args(line: &str, name_start: usize) -> Option<Vec<String>> {
316    // Find the opening paren after the function name
317    let rest = &line[name_start..];
318    let paren_start = rest.find('(')?;
319    let paren_end = find_matching_paren(rest, paren_start).ok()?;
320    let args_str = &rest[paren_start + 1..paren_end];
321    if args_str.trim().is_empty() {
322        return Some(vec![]);
323    }
324    Some(split_args(args_str))
325}
326
/// Split an argument string on top-level commas, respecting nested brackets and string
/// literals (`"…"`, `'…'`, `` `…` ``, with backslash escapes).
///
/// Each argument is trimmed. Empty arguments between commas are kept, but a trailing
/// empty argument (e.g. from a trailing comma) is dropped.
fn split_args(s: &str) -> Vec<String> {
    let mut args = Vec::new();
    let mut depth = 0i32;
    let mut quote: Option<char> = None;
    let mut escaped = false;
    let mut current = String::new();

    for ch in s.chars() {
        if let Some(q) = quote {
            // Inside a string literal every character (commas included) is kept verbatim.
            current.push(ch);
            if escaped {
                escaped = false;
            } else if ch == '\\' {
                escaped = true;
            } else if ch == q {
                quote = None;
            }
            continue;
        }
        match ch {
            '"' | '\'' | '`' => {
                quote = Some(ch);
                current.push(ch);
            }
            '(' | '[' | '{' => {
                depth += 1;
                current.push(ch);
            }
            ')' | ']' | '}' => {
                depth -= 1;
                current.push(ch);
            }
            ',' if depth == 0 => {
                args.push(current.trim().to_string());
                current.clear();
            }
            _ => current.push(ch),
        }
    }
    if !current.trim().is_empty() {
        args.push(current.trim().to_string());
    }
    args
}

355/// Find the end of a function call expression (past the closing paren).
356fn find_call_end(line: &str, name_start: usize) -> Result<usize> {
357    let rest = &line[name_start..];
358    let paren_start = rest
359        .find('(')
360        .ok_or_else(|| anyhow::anyhow!("No opening paren"))?;
361    let paren_end = find_matching_paren(rest, paren_start)?;
362    Ok(name_start + paren_end + 1)
363}
364
/// Strip a leading `return` keyword and any trailing `;` from a body string.
///
/// The keyword is removed only when it stands alone, meaning it is followed by whitespace,
/// `(`, `;`, or the end of the text. Identifiers such as `returned` are left intact. A bare
/// `return;` yields an empty string rather than a control-flow statement at the call site.
fn strip_return_keyword(body: &str) -> String {
    let trimmed = body.trim();
    let expr = match trimmed.strip_prefix("return") {
        Some(rest)
            if !rest.starts_with(|c: char| c.is_alphanumeric() || c == '_' || c == '$') =>
        {
            rest
        }
        _ => trimmed,
    };
    expr.trim().trim_end_matches(';').trim_end().to_string()
}

/// Remove the common leading indentation from every line of `body`.
///
/// Indentation is measured over ASCII spaces and tabs only, so slicing never lands inside a
/// multi-byte character. Whitespace-only lines become empty and do not affect the computed
/// indentation. The joined result is trimmed as a whole; an all-blank body yields `""`.
fn dedent_body(body: &str) -> String {
    let indent_of = |line: &str| line.len() - line.trim_start_matches([' ', '\t']).len();
    let min_indent = match body
        .lines()
        .filter(|line| !line.trim().is_empty())
        .map(indent_of)
        .min()
    {
        Some(n) => n,
        None => return String::new(),
    };
    body.lines()
        .map(|line| {
            if line.trim().is_empty() {
                ""
            } else {
                // Every non-blank line starts with at least `min_indent` ASCII bytes of indent.
                &line[min_indent..]
            }
        })
        .collect::<Vec<_>>()
        .join("\n")
        .trim()
        .to_string()
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ProjectRoot;
    use std::fs;

    /// A scratch project directory that is deleted on drop, so fixtures are cleaned up
    /// even when an assertion fails part-way through a test.
    struct Fixture {
        dir: std::path::PathBuf,
        project: ProjectRoot,
    }

    impl Drop for Fixture {
        fn drop(&mut self) {
            fs::remove_dir_all(&self.dir).ok();
        }
    }

    /// Create a uniquely named (process id + timestamp), empty project directory under the
    /// system temp dir.
    fn make_fixture() -> Fixture {
        let dir = std::env::temp_dir().join(format!(
            "codelens-inline-fixture-{}-{}",
            std::process::id(),
            std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap()
                .as_nanos()
        ));
        fs::create_dir_all(&dir).unwrap();
        let project = ProjectRoot::new(dir.clone()).unwrap();
        Fixture { dir, project }
    }

    #[test]
    fn test_parse_function_parts_js() {
        let def = "function add(a, b) {\n  return a + b;\n}";
        let (params, body) = parse_function_parts(def, "test.js").unwrap();
        assert_eq!(params, vec!["a", "b"]);
        assert!(body.contains("return a + b"));
    }

    #[test]
    fn test_parse_function_parts_python() {
        let def = "def add(x, y):\n    return x + y";
        let (params, body) = parse_function_parts(def, "test.py").unwrap();
        assert_eq!(params, vec!["x", "y"]);
        assert!(body.contains("return x + y"));
    }

    #[test]
    fn test_parse_function_parts_rust() {
        let def = "fn add(a: i32, b: i32) -> i32 {\n    a + b\n}";
        let (params, body) = parse_function_parts(def, "test.rs").unwrap();
        assert_eq!(params, vec!["a", "b"]);
        assert!(body.contains("a + b"));
    }

    #[test]
    fn test_parse_param_names_typed() {
        assert_eq!(
            parse_param_names("&self, a: i32, b: &str", "test.rs"),
            vec!["a", "b"]
        );
        assert_eq!(parse_param_names("a int, b string", "test.go"), vec!["a", "b"]);
    }

    #[test]
    fn test_extract_call_args() {
        let line = "let result = add(1, 2);";
        let args = extract_call_args(line, 13).unwrap();
        assert_eq!(args, vec!["1", "2"]);
    }

    #[test]
    fn test_extract_call_args_nested() {
        let line = "let result = add(foo(1), bar(2, 3));";
        let args = extract_call_args(line, 13).unwrap();
        assert_eq!(args, vec!["foo(1)", "bar(2, 3)"]);
    }

    #[test]
    fn test_find_call_end_nested() {
        let line = "x = f(g(1), 2);";
        let end = find_call_end(line, 4).unwrap();
        assert_eq!(&line[4..end], "f(g(1), 2)");
    }

    #[test]
    fn test_strip_return_keyword() {
        assert_eq!(strip_return_keyword("return x + y;"), "x + y");
        assert_eq!(strip_return_keyword("x + y"), "x + y");
    }

    #[test]
    fn test_dedent_body() {
        let body = "    let x = 1;\n    let y = 2;\n    x + y";
        let result = dedent_body(body);
        assert_eq!(result, "let x = 1;\nlet y = 2;\nx + y");
    }

    #[test]
    fn test_inline_dry_run() {
        let fixture = make_fixture();

        let main_content = r#"function greet(name) {
    return "Hello, " + name;
}

let msg = greet("World");
console.log(greet("Rust"));
"#;
        fs::write(fixture.dir.join("main.js"), main_content).unwrap();

        let result = inline_function(&fixture.project, "main.js", "greet", None, true).unwrap();
        assert!(result.success);
        assert_eq!(result.call_sites_inlined, 2);
        assert!(result.definition_removed);

        // Dry run: file should be unchanged
        let after = fs::read_to_string(fixture.dir.join("main.js")).unwrap();
        assert_eq!(after, main_content);
    }
}