cs/trace/
call_extractor.rs

1use crate::error::Result;
2use crate::search::TextSearcher;
3use regex::Regex;
4use std::collections::HashSet;
5use std::fs;
6use std::path::PathBuf;
7
8use super::FunctionDef;
9
/// A single call site of a traced function, attributed to the function that contains it.
///
/// Equality and hashing cover all three fields, so values can be compared directly,
/// de-duplicated, or stored in hash-based collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallerInfo {
    /// Name of the function that contains the call.
    pub caller_name: String,
    /// File in which the call was found.
    pub file: PathBuf,
    /// Line of the call, as reported by the text search.
    pub line: usize,
}
17
18/// Extracts function calls from code
19pub struct CallExtractor {
20    searcher: TextSearcher,
21    call_patterns: Vec<Regex>,
22    pub keywords: HashSet<String>,
23}
24
25impl CallExtractor {
26    /// Create a new CallExtractor
27    ///
28    /// # Arguments
29    /// * `base_dir` - The base directory of the project to search in
30    pub fn new(base_dir: PathBuf) -> Self {
31        Self {
32            searcher: TextSearcher::new(base_dir),
33            call_patterns: Self::default_call_patterns(),
34            keywords: Self::common_keywords(),
35        }
36    }
37
38    /// Default patterns for finding function calls across languages
39    fn default_call_patterns() -> Vec<Regex> {
40        vec![
41            // JavaScript/TypeScript - direct function calls
42            Regex::new(r"\b(\w+)\s*\(").unwrap(),
43            // JavaScript/TypeScript - method calls
44            Regex::new(r"\.(\w+)\s*\(").unwrap(),
45            // JavaScript/TypeScript - chained calls
46            Regex::new(r"\.(\w+)\s*\([^)]*\)\.(\w+)").unwrap(),
47            // Ruby - method calls
48            Regex::new(r"\b(\w+)\s*\(").unwrap(),
49            // Ruby - method calls without parentheses
50            Regex::new(r"\b(\w+)\s+\w+").unwrap(),
51        ]
52    }
53
54    /// Common language keywords to filter out (not function calls)
55    fn common_keywords() -> HashSet<String> {
56        let keywords = vec![
57            // JavaScript/TypeScript keywords
58            "if",
59            "for",
60            "while",
61            "switch",
62            "catch",
63            "typeof",
64            "instanceof",
65            "const",
66            "let",
67            "var",
68            "function",
69            "class",
70            "extends",
71            "import",
72            "export",
73            "from",
74            "async",
75            "await",
76            "try",
77            "finally",
78            "else",
79            "break",
80            "continue",
81            "case",
82            "default",
83            "do",
84            "in",
85            "of",
86            // JavaScript/TypeScript built-ins
87            "console",
88            "window",
89            "document",
90            "setTimeout",
91            "setInterval",
92            "parseInt",
93            "parseFloat",
94            "isNaN",
95            "Object",
96            "Array",
97            "String",
98            "Number",
99            "Boolean",
100            "Date",
101            "Math",
102            "JSON",
103            "Promise",
104            // TypeScript specific
105            "interface",
106            "type",
107            "enum",
108            "namespace",
109            "declare",
110            "abstract",
111            "implements",
112            "public",
113            "private",
114            "protected",
115            "readonly",
116            // Ruby keywords
117            "if",
118            "unless",
119            "case",
120            "when",
121            "while",
122            "until",
123            "for",
124            "in",
125            "begin",
126            "rescue",
127            "ensure",
128            "end",
129            "class",
130            "module",
131            "def",
132            "puts",
133            "print",
134            "p",
135            "require",
136            "include",
137            "extend",
138            "attr_reader",
139            "attr_writer",
140            "attr_accessor",
141            "private",
142            "protected",
143            "public",
144            // Ruby built-ins
145            "Array",
146            "Hash",
147            "String",
148            "Integer",
149            "Float",
150            "Numeric",
151            "File",
152            // Common programming constructs
153            "return",
154            "new",
155            "delete",
156            "throw",
157            "raise",
158            "yield",
159            "super",
160        ];
161        keywords.into_iter().map(String::from).collect()
162    }
163
164    /// Extract function calls from a function body
165    ///
166    /// Reads the function definition and extracts all function calls within its body.
167    /// Filters out language keywords and built-in functions.
168    pub fn extract_calls(&self, func: &FunctionDef) -> Result<Vec<String>> {
169        // Read the file
170        let content = fs::read_to_string(&func.file)?;
171        let lines: Vec<&str> = content.lines().collect();
172
173        // Find the function body - be smarter about detecting function boundaries
174        let start_line = func.line.saturating_sub(1);
175        let end_line = self.find_function_end(&lines, start_line).min(lines.len());
176
177        let mut calls = HashSet::new();
178
179        for line in &lines[start_line..end_line] {
180            // Skip comments and strings
181            if self.is_comment_or_string(line) {
182                continue;
183            }
184
185            // Find all function calls using multiple patterns
186            for pattern in &self.call_patterns {
187                for cap in pattern.captures_iter(line) {
188                    // Try each capture group (patterns may have different group structures)
189                    for i in 1..cap.len() {
190                        if let Some(name_match) = cap.get(i) {
191                            let name = name_match.as_str();
192
193                            // Filter out invalid function names
194                            if self.is_valid_function_name(name)
195                                && !self.keywords.contains(name)
196                                && name != func.name
197                            {
198                                calls.insert(name.to_string());
199                            }
200                        }
201                    }
202                }
203            }
204        }
205
206        Ok(calls.into_iter().collect())
207    }
208
209    /// Find the end of a function definition
210    fn find_function_end(&self, lines: &[&str], start_line: usize) -> usize {
211        if start_line >= lines.len() {
212            return lines.len();
213        }
214
215        let start_content = lines[start_line].trim();
216
217        // Check for brace-based languages (JS, Rust, etc.)
218        if start_content.contains('{')
219            || (start_line + 1 < lines.len() && lines[start_line + 1].trim().contains('{'))
220        {
221            return self.find_brace_end(lines, start_line);
222        }
223
224        // Check for Python (indentation) - must come before Ruby check
225        if start_content.starts_with("def ") && start_content.ends_with(':') {
226            return self.find_python_end(lines, start_line);
227        }
228
229        // Check for Ruby (def ... end)
230        if start_content.starts_with("def ") {
231            return self.find_ruby_end(lines, start_line);
232        }
233        // Default fallback
234        (start_line + 30).min(lines.len())
235    }
236
237    fn find_brace_end(&self, lines: &[&str], start_line: usize) -> usize {
238        let mut brace_count = 0;
239        let mut found_opening = false;
240
241        for (i, line) in lines.iter().enumerate().skip(start_line) {
242            for ch in line.chars() {
243                match ch {
244                    '{' => {
245                        brace_count += 1;
246                        found_opening = true;
247                    }
248                    '}' => {
249                        brace_count -= 1;
250                        if found_opening && brace_count == 0 {
251                            return i + 1;
252                        }
253                    }
254                    _ => {}
255                }
256            }
257        }
258        (start_line + 30).min(lines.len())
259    }
260
261    fn find_ruby_end(&self, lines: &[&str], start_line: usize) -> usize {
262        let mut depth = 0;
263        let mut found_start = false;
264
265        for (i, line) in lines.iter().enumerate().skip(start_line) {
266            let trimmed = line.trim();
267            // Simple heuristic for Ruby blocks
268            if trimmed.starts_with("def ")
269                || trimmed.starts_with("class ")
270                || trimmed.starts_with("module ")
271                || trimmed.starts_with("if ")
272                || trimmed.starts_with("do ")
273                || trimmed.starts_with("begin ")
274            {
275                depth += 1;
276                found_start = true;
277            }
278
279            if trimmed == "end" || trimmed.starts_with("end ") {
280                depth -= 1;
281                if found_start && depth == 0 {
282                    return i + 1;
283                }
284            }
285        }
286        (start_line + 30).min(lines.len())
287    }
288
289    fn find_python_end(&self, lines: &[&str], start_line: usize) -> usize {
290        // Get indentation of the function definition
291        let def_indent = lines[start_line]
292            .chars()
293            .take_while(|c| c.is_whitespace())
294            .count();
295
296        for (i, line) in lines.iter().enumerate().skip(start_line + 1) {
297            let trimmed = line.trim();
298            if trimmed.is_empty() || trimmed.starts_with('#') {
299                continue;
300            }
301
302            let current_indent = line.chars().take_while(|c| c.is_whitespace()).count();
303            if current_indent <= def_indent {
304                return i;
305            }
306        }
307        lines.len()
308    }
309
310    /// Check if a line is a comment or inside a string literal
311    fn is_comment_or_string(&self, line: &str) -> bool {
312        let trimmed = line.trim();
313        // JavaScript/TypeScript comments
314        trimmed.starts_with("//") || trimmed.starts_with("/*") ||
315        // Ruby/Python comments
316        trimmed.starts_with("#")
317    }
318
319    /// Check if a string is a valid function name
320    fn is_valid_function_name(&self, name: &str) -> bool {
321        // Must be a valid identifier
322        !name.is_empty()
323            && name.chars().all(|c| c.is_alphanumeric() || c == '_')
324            && !name.chars().next().unwrap().is_numeric()
325    }
326
327    /// Find all functions that call the given function
328    ///
329    /// Searches the codebase for all calls to `func_name` and identifies
330    /// the calling function for each occurrence. Uses case variants for cross-language support.
331    pub fn find_callers(&self, func_name: &str) -> Result<Vec<CallerInfo>> {
332        let mut callers = Vec::new();
333
334        // Generate case variants for cross-case searching
335        let variants = Self::generate_case_variants(func_name);
336
337        // Search for each variant
338        for variant in variants {
339            let matches = self.searcher.search(&variant)?;
340
341            for m in matches {
342                // Skip comment lines (JavaScript //, Ruby/Python #)
343                let trimmed = m.content.trim();
344                if trimmed.starts_with("//") || trimmed.starts_with("#") {
345                    continue;
346                }
347
348                // Ensure it's a function call (variant followed by '(') with word boundary
349                let call_regex =
350                    Regex::new(&format!(r"\b{}\s*\(", regex::escape(&variant))).unwrap();
351                if !call_regex.is_match(&m.content) {
352                    continue;
353                }
354
355                // Skip function definition lines where the variant is being defined
356                if trimmed.starts_with("function ")
357                    || trimmed.starts_with("def ")
358                    || trimmed.starts_with("fn ")
359                {
360                    if trimmed.contains(&variant) {
361                        continue;
362                    }
363                }
364
365                // Determine the calling function
366                let caller_name = self.find_containing_function(&m.file, m.line)?;
367
368                // Avoid duplicates (same caller, file, line)
369                if !callers.iter().any(|existing: &CallerInfo| {
370                    existing.caller_name == caller_name
371                        && existing.file == m.file
372                        && existing.line == m.line
373                }) {
374                    callers.push(CallerInfo {
375                        caller_name,
376                        file: m.file.clone(),
377                        line: m.line,
378                    });
379                }
380            }
381        }
382
383        Ok(callers)
384    }
385
386    /// Find the function that contains a given line (simplified implementation)
387    ///
388    /// Searches backwards from the given line to find the most recent function definition.
389    fn find_containing_function(&self, file: &PathBuf, line: usize) -> Result<String> {
390        let content = fs::read_to_string(file)?;
391
392        let lines: Vec<&str> = content.lines().collect();
393
394        // Search backwards from the line to find a function definition
395        let function_patterns = vec![
396            // JavaScript/TypeScript patterns
397            Regex::new(r"function\s+(\w+)").unwrap(),
398            Regex::new(r"(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>").unwrap(),
399            Regex::new(r"export\s+(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>")
400                .unwrap(),
401            Regex::new(r"export\s+function\s+(\w+)").unwrap(),
402            Regex::new(
403                r"^\s*(?:public|private|protected|static)?\s*(?:async\s+)?(\w+)\s*\([^)]*\)\s*[:{]",
404            )
405            .unwrap(),
406            // Ruby patterns
407            Regex::new(r"def\s+(\w+)").unwrap(),
408            Regex::new(r"def\s+self\.(\w+)").unwrap(),
409            // Generic method pattern
410            Regex::new(r"^\s*(\w+)\s*\([^)]*\)\s*\{").unwrap(),
411            // Rust pattern (for completeness)
412            Regex::new(r"fn\s+(\w+)").unwrap(),
413        ];
414
415        // Search backwards up to 100 lines or start of file
416        let start = line.saturating_sub(100);
417        for i in (start..line.saturating_sub(1)).rev() {
418            if i >= lines.len() {
419                continue;
420            }
421
422            let line_content = lines[i];
423            for pattern in &function_patterns {
424                if let Some(captures) = pattern.captures(line_content) {
425                    if let Some(name_match) = captures.get(1) {
426                        return Ok(name_match.as_str().to_string());
427                    }
428                }
429            }
430        }
431
432        // If no containing function found, it might be top-level code
433        Ok("<top-level>".to_string())
434    }
435
436    /// Generate case variants of a function name for cross-case searching
437    ///
438    /// For input "createUser" generates: ["createUser", "create_user", "CreateUser"]
439    /// For input "user_profile" generates: ["user_profile", "userProfile", "UserProfile"]
440    fn generate_case_variants(func_name: &str) -> Vec<String> {
441        let mut variants = std::collections::HashSet::new();
442
443        // Always include the original
444        variants.insert(func_name.to_string());
445
446        // Generate snake_case variant
447        let snake_case = Self::to_snake_case(func_name);
448        variants.insert(snake_case.clone());
449
450        // Generate camelCase variant
451        let camel_case = Self::to_camel_case(&snake_case);
452        variants.insert(camel_case.clone());
453
454        // Generate PascalCase variant
455        let pascal_case = Self::to_pascal_case(&snake_case);
456        variants.insert(pascal_case);
457
458        variants.into_iter().collect()
459    }
460
461    /// Convert to snake_case
462    fn to_snake_case(input: &str) -> String {
463        let mut result = String::new();
464
465        for (i, ch) in input.chars().enumerate() {
466            if ch.is_uppercase() && i > 0 {
467                result.push('_');
468            }
469            result.push(ch.to_lowercase().next().unwrap());
470        }
471
472        result
473    }
474
475    /// Convert snake_case to camelCase
476    fn to_camel_case(input: &str) -> String {
477        let parts: Vec<&str> = input.split('_').collect();
478        if parts.is_empty() {
479            return String::new();
480        }
481
482        let mut result = parts[0].to_lowercase();
483        for part in parts.iter().skip(1) {
484            if !part.is_empty() {
485                let mut chars = part.chars();
486                if let Some(first) = chars.next() {
487                    result.push(first.to_uppercase().next().unwrap());
488                    result.push_str(&chars.as_str().to_lowercase());
489                }
490            }
491        }
492
493        result
494    }
495
496    /// Convert snake_case to PascalCase
497    fn to_pascal_case(input: &str) -> String {
498        let parts: Vec<&str> = input.split('_').collect();
499        let mut result = String::new();
500
501        for part in parts {
502            if !part.is_empty() {
503                let mut chars = part.chars();
504                if let Some(first) = chars.next() {
505                    result.push(first.to_uppercase().next().unwrap());
506                    result.push_str(&chars.as_str().to_lowercase());
507                }
508            }
509        }
510
511        result
512    }
513}
514
515impl Default for CallExtractor {
516    fn default() -> Self {
517        Self::new(std::env::current_dir().unwrap())
518    }
519}
520
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an extractor based at the current working directory.
    fn extractor_in_cwd() -> CallExtractor {
        let cwd = std::env::current_dir().unwrap();
        CallExtractor::new(cwd)
    }

    /// A fresh extractor comes with a non-empty keyword filter.
    #[test]
    fn test_call_extractor_creation() {
        assert!(!extractor_in_cwd().keywords.is_empty());
    }

    /// At least one default pattern recognises a plain call expression.
    #[test]
    fn test_call_patterns() {
        let extractor = extractor_in_cwd();
        let test_line = "result = processData(x, y);";

        let found_calls = extractor
            .call_patterns
            .iter()
            .any(|pattern| pattern.is_match(test_line));
        assert!(found_calls);
    }

    /// Basic control-flow keywords are part of the filter.
    #[test]
    fn test_keywords_filter() {
        let extractor = extractor_in_cwd();
        for keyword in ["if", "for", "while"] {
            assert!(extractor.keywords.contains(keyword));
        }
    }
}
553}