cs/trace/
call_extractor.rs

1use crate::error::Result;
2use crate::search::TextSearcher;
3use regex::Regex;
4use std::collections::HashSet;
5use std::fs;
6use std::path::PathBuf;
7
8use super::FunctionDef;
9
/// A call site of some function, attributed to the function that contains it.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CallerInfo {
    /// Name of the enclosing function, or `"<top-level>"` when none was found.
    pub caller_name: String,
    /// File containing the call.
    pub file: PathBuf,
    /// Line number of the call, as reported by the text searcher.
    pub line: usize,
}
17
18/// Extracts function calls from code
19pub struct CallExtractor {
20    searcher: TextSearcher,
21    call_patterns: Vec<Regex>,
22    pub keywords: HashSet<String>,
23}
24
25impl CallExtractor {
26    /// Create a new CallExtractor
27    ///
28    /// # Arguments
29    /// * `base_dir` - The base directory of the project to search in
30    pub fn new(base_dir: PathBuf) -> Self {
31        Self {
32            searcher: TextSearcher::new(base_dir),
33            call_patterns: Self::default_call_patterns(),
34            keywords: Self::common_keywords(),
35        }
36    }
37
38    /// Default patterns for finding function calls across languages
39    fn default_call_patterns() -> Vec<Regex> {
40        vec![
41            // JavaScript/TypeScript - direct function calls
42            Regex::new(r"\b(\w+)\s*\(").unwrap(),
43            // JavaScript/TypeScript - method calls
44            Regex::new(r"\.(\w+)\s*\(").unwrap(),
45            // JavaScript/TypeScript - chained calls
46            Regex::new(r"\.(\w+)\s*\([^)]*\)\.(\w+)").unwrap(),
47            // Ruby - method calls
48            Regex::new(r"\b(\w+)\s*\(").unwrap(),
49            // Ruby - method calls without parentheses
50            Regex::new(r"\b(\w+)\s+\w+").unwrap(),
51        ]
52    }
53
54    /// Common language keywords to filter out (not function calls)
55    fn common_keywords() -> HashSet<String> {
56        let keywords = vec![
57            // JavaScript/TypeScript keywords
58            "if",
59            "for",
60            "while",
61            "switch",
62            "catch",
63            "typeof",
64            "instanceof",
65            "const",
66            "let",
67            "var",
68            "function",
69            "class",
70            "extends",
71            "import",
72            "export",
73            "from",
74            "async",
75            "await",
76            "try",
77            "finally",
78            "else",
79            "break",
80            "continue",
81            "case",
82            "default",
83            "do",
84            "in",
85            "of",
86            // JavaScript/TypeScript built-ins
87            "console",
88            "window",
89            "document",
90            "setTimeout",
91            "setInterval",
92            "parseInt",
93            "parseFloat",
94            "isNaN",
95            "Object",
96            "Array",
97            "String",
98            "Number",
99            "Boolean",
100            "Date",
101            "Math",
102            "JSON",
103            "Promise",
104            // TypeScript specific
105            "interface",
106            "type",
107            "enum",
108            "namespace",
109            "declare",
110            "abstract",
111            "implements",
112            "public",
113            "private",
114            "protected",
115            "readonly",
116            // Ruby keywords
117            "if",
118            "unless",
119            "case",
120            "when",
121            "while",
122            "until",
123            "for",
124            "in",
125            "begin",
126            "rescue",
127            "ensure",
128            "end",
129            "class",
130            "module",
131            "def",
132            "puts",
133            "print",
134            "p",
135            "require",
136            "include",
137            "extend",
138            "attr_reader",
139            "attr_writer",
140            "attr_accessor",
141            "private",
142            "protected",
143            "public",
144            // Ruby built-ins
145            "Array",
146            "Hash",
147            "String",
148            "Integer",
149            "Float",
150            "Numeric",
151            "File",
152            // Common programming constructs
153            "return",
154            "new",
155            "delete",
156            "throw",
157            "raise",
158            "yield",
159            "super",
160        ];
161        keywords.into_iter().map(String::from).collect()
162    }
163
164    /// Extract function calls from a function body
165    ///
166    /// Reads the function definition and extracts all function calls within its body.
167    /// Filters out language keywords and built-in functions.
168    pub fn extract_calls(&self, func: &FunctionDef) -> Result<Vec<String>> {
169        // Read the file
170        let content = fs::read_to_string(&func.file)?;
171        let lines: Vec<&str> = content.lines().collect();
172
173        // Find the function body - be smarter about detecting function boundaries
174        let start_line = func.line.saturating_sub(1);
175        let end_line = self.find_function_end(&lines, start_line).min(lines.len());
176
177        let mut calls = HashSet::new();
178
179        for line in &lines[start_line..end_line] {
180            // Skip comments and strings
181            if self.is_comment_or_string(line) {
182                continue;
183            }
184
185            // Find all function calls using multiple patterns
186            for pattern in &self.call_patterns {
187                for cap in pattern.captures_iter(line) {
188                    // Try each capture group (patterns may have different group structures)
189                    for i in 1..cap.len() {
190                        if let Some(name_match) = cap.get(i) {
191                            let name = name_match.as_str();
192
193                            // Filter out invalid function names
194                            if self.is_valid_function_name(name)
195                                && !self.keywords.contains(name)
196                                && name != func.name
197                            {
198                                calls.insert(name.to_string());
199                            }
200                        }
201                    }
202                }
203            }
204        }
205
206        Ok(calls.into_iter().collect())
207    }
208
209    /// Find the end of a function definition
210    fn find_function_end(&self, lines: &[&str], start_line: usize) -> usize {
211        if start_line >= lines.len() {
212            return lines.len();
213        }
214
215        let start_content = lines[start_line].trim();
216
217        // Check for brace-based languages (JS, Rust, etc.)
218        if start_content.contains('{')
219            || (start_line + 1 < lines.len() && lines[start_line + 1].trim().contains('{'))
220        {
221            return self.find_brace_end(lines, start_line);
222        }
223
224        // Check for Python (indentation) - must come before Ruby check
225        if start_content.starts_with("def ") && start_content.ends_with(':') {
226            return self.find_python_end(lines, start_line);
227        }
228
229        // Check for Ruby (def ... end)
230        if start_content.starts_with("def ") {
231            return self.find_ruby_end(lines, start_line);
232        }
233        // Default fallback
234        (start_line + 30).min(lines.len())
235    }
236
237    fn find_brace_end(&self, lines: &[&str], start_line: usize) -> usize {
238        let mut brace_count = 0;
239        let mut found_opening = false;
240
241        for (i, line) in lines.iter().enumerate().skip(start_line) {
242            for ch in line.chars() {
243                match ch {
244                    '{' => {
245                        brace_count += 1;
246                        found_opening = true;
247                    }
248                    '}' => {
249                        brace_count -= 1;
250                        if found_opening && brace_count == 0 {
251                            return i + 1;
252                        }
253                    }
254                    _ => {}
255                }
256            }
257        }
258        (start_line + 30).min(lines.len())
259    }
260
261    fn find_ruby_end(&self, lines: &[&str], start_line: usize) -> usize {
262        let mut depth = 0;
263        let mut found_start = false;
264
265        for (i, line) in lines.iter().enumerate().skip(start_line) {
266            let trimmed = line.trim();
267            // Simple heuristic for Ruby blocks
268            if trimmed.starts_with("def ")
269                || trimmed.starts_with("class ")
270                || trimmed.starts_with("module ")
271                || trimmed.starts_with("if ")
272                || trimmed.starts_with("do ")
273                || trimmed.starts_with("begin ")
274            {
275                depth += 1;
276                found_start = true;
277            }
278
279            if trimmed == "end" || trimmed.starts_with("end ") {
280                depth -= 1;
281                if found_start && depth == 0 {
282                    return i + 1;
283                }
284            }
285        }
286        (start_line + 30).min(lines.len())
287    }
288
289    fn find_python_end(&self, lines: &[&str], start_line: usize) -> usize {
290        // Get indentation of the function definition
291        let def_indent = lines[start_line]
292            .chars()
293            .take_while(|c| c.is_whitespace())
294            .count();
295
296        for (i, line) in lines.iter().enumerate().skip(start_line + 1) {
297            let trimmed = line.trim();
298            if trimmed.is_empty() || trimmed.starts_with('#') {
299                continue;
300            }
301
302            let current_indent = line.chars().take_while(|c| c.is_whitespace()).count();
303            if current_indent <= def_indent {
304                return i;
305            }
306        }
307        lines.len()
308    }
309
310    /// Check if a line is a comment or inside a string literal
311    fn is_comment_or_string(&self, line: &str) -> bool {
312        let trimmed = line.trim();
313        // JavaScript/TypeScript comments
314        trimmed.starts_with("//") || trimmed.starts_with("/*") ||
315        // Ruby/Python comments
316        trimmed.starts_with("#")
317    }
318
319    /// Check if a string is a valid function name
320    fn is_valid_function_name(&self, name: &str) -> bool {
321        // Must be a valid identifier
322        !name.is_empty()
323            && name.chars().all(|c| c.is_alphanumeric() || c == '_')
324            && !name.chars().next().unwrap().is_numeric()
325    }
326
327    /// Find all functions that call the given function
328    ///
329    /// Searches the codebase for all calls to `func_name` and identifies
330    /// the calling function for each occurrence. Uses case variants for cross-language support.
331    pub fn find_callers(&self, func_name: &str) -> Result<Vec<CallerInfo>> {
332        let mut callers = Vec::new();
333
334        // Generate case variants for cross-case searching
335        let variants = Self::generate_case_variants(func_name);
336
337        // Search for each variant
338        for variant in variants {
339            let matches = self.searcher.search(&variant)?;
340
341            for m in matches {
342                // Skip comment lines (JavaScript //, Ruby/Python #)
343                let trimmed = m.content.trim();
344                if trimmed.starts_with("//") || trimmed.starts_with("#") {
345                    continue;
346                }
347
348                // Ensure it's a function call (variant followed by '(') with word boundary
349                let call_regex =
350                    Regex::new(&format!(r"\b{}\s*\(", regex::escape(&variant))).unwrap();
351                if !call_regex.is_match(&m.content) {
352                    continue;
353                }
354
355                // Skip function definition lines where the variant is being defined
356                if (trimmed.starts_with("function ")
357                    || trimmed.starts_with("def ")
358                    || trimmed.starts_with("fn "))
359                    && trimmed.contains(&variant)
360                {
361                    continue;
362                }
363
364                // Determine the calling function
365                let caller_name = self.find_containing_function(&m.file, m.line)?;
366
367                // Avoid duplicates (same caller, file, line)
368                if !callers.iter().any(|existing: &CallerInfo| {
369                    existing.caller_name == caller_name
370                        && existing.file == m.file
371                        && existing.line == m.line
372                }) {
373                    callers.push(CallerInfo {
374                        caller_name,
375                        file: m.file.clone(),
376                        line: m.line,
377                    });
378                }
379            }
380        }
381
382        Ok(callers)
383    }
384
385    /// Find the function that contains a given line (simplified implementation)
386    ///
387    /// Searches backwards from the given line to find the most recent function definition.
388    fn find_containing_function(&self, file: &PathBuf, line: usize) -> Result<String> {
389        let content = fs::read_to_string(file)?;
390
391        let lines: Vec<&str> = content.lines().collect();
392
393        // Search backwards from the line to find a function definition
394        let function_patterns = vec![
395            // JavaScript/TypeScript patterns
396            Regex::new(r"function\s+(\w+)").unwrap(),
397            Regex::new(r"(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>").unwrap(),
398            Regex::new(r"export\s+(?:const|let|var)\s+(\w+)\s*=\s*(?:async\s+)?\([^)]*\)\s*=>")
399                .unwrap(),
400            Regex::new(r"export\s+function\s+(\w+)").unwrap(),
401            Regex::new(
402                r"^\s*(?:public|private|protected|static)?\s*(?:async\s+)?(\w+)\s*\([^)]*\)\s*[:{]",
403            )
404            .unwrap(),
405            // Ruby patterns
406            Regex::new(r"def\s+(\w+)").unwrap(),
407            Regex::new(r"def\s+self\.(\w+)").unwrap(),
408            // Generic method pattern
409            Regex::new(r"^\s*(\w+)\s*\([^)]*\)\s*\{").unwrap(),
410            // Rust pattern (for completeness)
411            Regex::new(r"fn\s+(\w+)").unwrap(),
412        ];
413
414        // Search backwards up to 100 lines or start of file
415        let start = line.saturating_sub(100);
416        for i in (start..line.saturating_sub(1)).rev() {
417            if i >= lines.len() {
418                continue;
419            }
420
421            let line_content = lines[i];
422            for pattern in &function_patterns {
423                if let Some(captures) = pattern.captures(line_content) {
424                    if let Some(name_match) = captures.get(1) {
425                        return Ok(name_match.as_str().to_string());
426                    }
427                }
428            }
429        }
430
431        // If no containing function found, it might be top-level code
432        Ok("<top-level>".to_string())
433    }
434
435    /// Generate case variants of a function name for cross-case searching
436    ///
437    /// For input "createUser" generates: ["createUser", "create_user", "CreateUser"]
438    /// For input "user_profile" generates: ["user_profile", "userProfile", "UserProfile"]
439    fn generate_case_variants(func_name: &str) -> Vec<String> {
440        let mut variants = std::collections::HashSet::new();
441
442        // Always include the original
443        variants.insert(func_name.to_string());
444
445        // Generate snake_case variant
446        let snake_case = Self::to_snake_case(func_name);
447        variants.insert(snake_case.clone());
448
449        // Generate camelCase variant
450        let camel_case = Self::to_camel_case(&snake_case);
451        variants.insert(camel_case.clone());
452
453        // Generate PascalCase variant
454        let pascal_case = Self::to_pascal_case(&snake_case);
455        variants.insert(pascal_case);
456
457        variants.into_iter().collect()
458    }
459
460    /// Convert to snake_case
461    fn to_snake_case(input: &str) -> String {
462        let mut result = String::new();
463
464        for (i, ch) in input.chars().enumerate() {
465            if ch.is_uppercase() && i > 0 {
466                result.push('_');
467            }
468            result.push(ch.to_lowercase().next().unwrap());
469        }
470
471        result
472    }
473
474    /// Convert snake_case to camelCase
475    fn to_camel_case(input: &str) -> String {
476        let parts: Vec<&str> = input.split('_').collect();
477        if parts.is_empty() {
478            return String::new();
479        }
480
481        let mut result = parts[0].to_lowercase();
482        for part in parts.iter().skip(1) {
483            if !part.is_empty() {
484                let mut chars = part.chars();
485                if let Some(first) = chars.next() {
486                    result.push(first.to_uppercase().next().unwrap());
487                    result.push_str(&chars.as_str().to_lowercase());
488                }
489            }
490        }
491
492        result
493    }
494
495    /// Convert snake_case to PascalCase
496    fn to_pascal_case(input: &str) -> String {
497        let parts: Vec<&str> = input.split('_').collect();
498        let mut result = String::new();
499
500        for part in parts {
501            if !part.is_empty() {
502                let mut chars = part.chars();
503                if let Some(first) = chars.next() {
504                    result.push(first.to_uppercase().next().unwrap());
505                    result.push_str(&chars.as_str().to_lowercase());
506                }
507            }
508        }
509
510        result
511    }
512}
513
514impl Default for CallExtractor {
515    fn default() -> Self {
516        Self::new(std::env::current_dir().unwrap())
517    }
518}
519
#[cfg(test)]
mod tests {
    use super::*;

    /// Extractor rooted at the current directory.
    fn extractor() -> CallExtractor {
        CallExtractor::new(std::env::current_dir().unwrap())
    }

    #[test]
    fn test_call_extractor_creation() {
        assert!(!extractor().keywords.is_empty());
    }

    #[test]
    fn test_call_patterns() {
        let extractor = extractor();
        let test_line = "result = processData(x, y);";

        assert!(extractor
            .call_patterns
            .iter()
            .any(|pattern| pattern.is_match(test_line)));

        // The callee name itself must be captured, not merely matched.
        let captures_callee = extractor.call_patterns.iter().any(|pattern| {
            pattern.captures_iter(test_line).any(|captures| {
                captures
                    .iter()
                    .skip(1)
                    .flatten()
                    .any(|found| found.as_str() == "processData")
            })
        });
        assert!(captures_callee);
    }

    #[test]
    fn test_keywords_filter() {
        let extractor = extractor();
        for keyword in ["if", "for", "while"] {
            assert!(extractor.keywords.contains(keyword), "missing keyword {}", keyword);
        }
    }

    #[test]
    fn test_is_valid_function_name() {
        let extractor = extractor();
        assert!(extractor.is_valid_function_name("process_data2"));
        assert!(extractor.is_valid_function_name("_private"));
        assert!(!extractor.is_valid_function_name(""));
        assert!(!extractor.is_valid_function_name("2fast"));
        assert!(!extractor.is_valid_function_name("foo-bar"));
    }

    #[test]
    fn test_is_comment_or_string() {
        let extractor = extractor();
        assert!(extractor.is_comment_or_string("  // note"));
        assert!(extractor.is_comment_or_string("/* block"));
        assert!(extractor.is_comment_or_string("\t# ruby or python"));
        assert!(!extractor.is_comment_or_string("    doWork(); // trailing"));
    }

    #[test]
    fn test_case_conversions() {
        assert_eq!(CallExtractor::to_snake_case("createUser"), "create_user");
        assert_eq!(CallExtractor::to_snake_case("CreateUser"), "create_user");
        assert_eq!(CallExtractor::to_camel_case("user_profile"), "userProfile");
        assert_eq!(CallExtractor::to_pascal_case("user_profile"), "UserProfile");
    }

    #[test]
    fn test_generate_case_variants() {
        let mut variants = CallExtractor::generate_case_variants("createUser");
        variants.sort();
        assert_eq!(variants, ["CreateUser", "createUser", "create_user"]);

        let mut variants = CallExtractor::generate_case_variants("user_profile");
        variants.sort();
        assert_eq!(variants, ["UserProfile", "userProfile", "user_profile"]);
    }

    #[test]
    fn test_function_end_detection() {
        let extractor = extractor();

        let js = ["function handler(req) {", "  validate(req);", "}", "after();"];
        assert_eq!(extractor.find_function_end(&js, 0), 3);

        let python = ["def handler(req):", "    validate(req)", "", "after()"];
        assert_eq!(extractor.find_function_end(&python, 0), 3);

        let ruby = ["def handler(req)", "  validate(req)", "end", "after"];
        assert_eq!(extractor.find_function_end(&ruby, 0), 3);

        // A start line past the end of the input yields the input length.
        assert_eq!(extractor.find_function_end(&js, 10), 4);
    }
}