codegraph_c/
extractor.rs

1//! AST extraction for C source code
2//!
3//! This module provides two parsing modes:
4//! - Strict mode: Fails on syntax errors (default, for clean code)
5//! - Tolerant mode: Extracts what it can even with errors (for real-world code)
6
7use codegraph_parser_api::{CodeIR, ModuleEntity, ParserConfig, ParserError};
8use std::path::Path;
9use tree_sitter::Parser;
10
11use crate::preprocessor::CPreprocessor;
12use crate::visitor::CVisitor;
13
/// Extraction options for controlling parser behavior.
///
/// The derived [`Default`] leaves every flag `false`: strict parsing, no
/// preprocessing and no call extraction. `PartialEq`/`Eq` are derived so two
/// option sets can be compared directly (e.g. in assertions).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ExtractionOptions {
    /// If true, extract entities even when the AST has errors
    pub tolerant_mode: bool,
    /// If true, apply preprocessing to help parse kernel/system code
    pub preprocess: bool,
    /// If true, extract function calls for call graph
    pub extract_calls: bool,
}
24
25impl ExtractionOptions {
26    /// Create options optimized for kernel/system code
27    pub fn for_kernel_code() -> Self {
28        Self {
29            tolerant_mode: true,
30            preprocess: true,
31            extract_calls: true,
32        }
33    }
34
35    /// Create options for tolerant parsing of any code
36    pub fn tolerant() -> Self {
37        Self {
38            tolerant_mode: true,
39            preprocess: false,
40            extract_calls: true,
41        }
42    }
43}
44
/// Result of extraction with additional metadata
#[derive(Debug)]
pub struct ExtractionResult {
    /// Intermediate representation holding the extracted module info,
    /// functions, structs (stored as classes) and imports
    pub ir: CodeIR,
    /// Number of syntax errors encountered (0 = clean parse)
    pub error_count: usize,
    /// Whether the file was fully parsed or partially; `true` when the parse
    /// tree contains at least one syntax error
    pub is_partial: bool,
    /// Names of the macros detected in the original (un-preprocessed) source
    pub detected_macros: Vec<String>,
}
56
57/// Extract code entities and relationships from C source code (strict mode)
58pub fn extract(
59    source: &str,
60    file_path: &Path,
61    config: &ParserConfig,
62) -> Result<CodeIR, ParserError> {
63    let result = extract_with_options(source, file_path, config, &ExtractionOptions::default())?;
64
65    if result.is_partial {
66        return Err(ParserError::SyntaxError(
67            file_path.to_path_buf(),
68            0,
69            0,
70            "Syntax error".to_string(),
71        ));
72    }
73
74    Ok(result.ir)
75}
76
77/// Extract with custom options (supports tolerant mode)
78pub fn extract_with_options(
79    source: &str,
80    file_path: &Path,
81    config: &ParserConfig,
82    options: &ExtractionOptions,
83) -> Result<ExtractionResult, ParserError> {
84    // Detect macros from original source (before preprocessing)
85    let preprocessor = CPreprocessor::new();
86    let detected_macros: Vec<String> = preprocessor
87        .analyze_macros(source)
88        .iter()
89        .map(|m| m.name.clone())
90        .collect();
91
92    // Optionally preprocess the source
93    let processed_source = if options.preprocess {
94        preprocessor.preprocess(source)
95    } else {
96        source.to_string()
97    };
98
99    let mut parser = Parser::new();
100    let language = tree_sitter_c::language();
101    parser
102        .set_language(language)
103        .map_err(|e| ParserError::ParseError(file_path.to_path_buf(), e.to_string()))?;
104
105    let tree = parser.parse(&processed_source, None).ok_or_else(|| {
106        ParserError::ParseError(file_path.to_path_buf(), "Failed to parse".to_string())
107    })?;
108
109    let root_node = tree.root_node();
110    let has_error = root_node.has_error();
111    let error_count = if has_error {
112        count_errors(root_node)
113    } else {
114        0
115    };
116
117    // In strict mode, fail on errors
118    if has_error && !options.tolerant_mode {
119        return Err(ParserError::SyntaxError(
120            file_path.to_path_buf(),
121            0,
122            0,
123            format!("Syntax error ({error_count} error nodes)"),
124        ));
125    }
126
127    let mut ir = CodeIR::new(file_path.to_path_buf());
128
129    let module_name = file_path
130        .file_stem()
131        .and_then(|s| s.to_str())
132        .unwrap_or("unknown")
133        .to_string();
134    ir.module = Some(ModuleEntity {
135        name: module_name,
136        path: file_path.display().to_string(),
137        language: "c".to_string(),
138        line_count: source.lines().count(),
139        doc_comment: None,
140        attributes: Vec::new(),
141    });
142
143    // Visit the AST - the visitor will skip ERROR nodes gracefully
144    let mut visitor = CVisitor::new(processed_source.as_bytes(), config.clone());
145    visitor.set_extract_calls(options.extract_calls);
146    visitor.visit_node(root_node);
147
148    ir.functions = visitor.functions;
149    ir.classes = visitor.structs;
150    ir.imports = visitor.imports;
151
152    // Store call relationships in a custom way (we'll add this to IR later)
153    // For now, calls are stored as part of function entities
154
155    Ok(ExtractionResult {
156        ir,
157        error_count,
158        is_partial: has_error,
159        detected_macros,
160    })
161}
162
163/// Count ERROR nodes in the syntax tree
164fn count_errors(node: tree_sitter::Node) -> usize {
165    let mut count = 0;
166
167    if node.is_error() || node.is_missing() {
168        count += 1;
169    }
170
171    let mut cursor = node.walk();
172    for child in node.children(&mut cursor) {
173        count += count_errors(child);
174    }
175
176    count
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Strictly extract `source` as `file_name` with the default parser
    /// config, asserting success and returning the IR.
    fn extract_ok(source: &str, file_name: &str) -> CodeIR {
        let result = extract(source, Path::new(file_name), &ParserConfig::default());
        assert!(result.is_ok());
        result.unwrap()
    }

    /// Run `extract_with_options` on `test.c` with the default parser config,
    /// asserting success and returning the full extraction result.
    fn extract_with_ok(source: &str, options: &ExtractionOptions) -> ExtractionResult {
        let result = extract_with_options(
            source,
            Path::new("test.c"),
            &ParserConfig::default(),
            options,
        );
        assert!(result.is_ok());
        result.unwrap()
    }

    #[test]
    fn test_extract_simple_function() {
        let ir = extract_ok(
            r#"
int main() {
    return 0;
}
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].name, "main");
    }

    #[test]
    fn test_extract_function_with_params() {
        let ir = extract_ok(
            r#"
int add(int a, int b) {
    return a + b;
}
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].name, "add");
        assert_eq!(ir.functions[0].parameters.len(), 2);
    }

    #[test]
    fn test_extract_struct() {
        let ir = extract_ok(
            r#"
struct Point {
    int x;
    int y;
};
"#,
            "test.c",
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Point");
    }

    #[test]
    fn test_extract_enum() {
        let ir = extract_ok(
            r#"
enum Color {
    RED,
    GREEN,
    BLUE
};
"#,
            "test.c",
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Color");
    }

    #[test]
    fn test_extract_include() {
        let ir = extract_ok(
            r#"
#include <stdio.h>
#include "myheader.h"
"#,
            "test.c",
        );

        assert_eq!(ir.imports.len(), 2);
    }

    #[test]
    fn test_extract_multiple_functions() {
        let ir = extract_ok(
            r#"
int foo() { return 1; }
int bar() { return 2; }
int baz() { return 3; }
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 3);
    }

    #[test]
    fn test_extract_static_function() {
        let ir = extract_ok(
            r#"
static void helper() {
    // internal function
}
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].visibility, "private");
    }

    #[test]
    fn test_extract_module_info() {
        let ir = extract_ok(
            r#"
int test() {
    return 42;
}
"#,
            "module.c",
        );

        assert!(ir.module.is_some());
        let module = ir.module.unwrap();
        assert_eq!(module.name, "module");
        assert_eq!(module.language, "c");
        assert!(module.line_count > 0);
    }

    #[test]
    fn test_extract_with_syntax_error_strict() {
        let source = r#"
int broken( {
    // Missing closing brace
"#;
        let result = extract(source, Path::new("test.c"), &ParserConfig::default());

        assert!(result.is_err());
        assert!(
            matches!(result, Err(ParserError::SyntaxError(..))),
            "Expected SyntaxError"
        );
    }

    #[test]
    fn test_extract_with_syntax_error_tolerant() {
        let extraction = extract_with_ok(
            r#"
int valid_func() { return 0; }
int broken( {
int another_valid() { return 1; }
"#,
            &ExtractionOptions::tolerant(),
        );

        assert!(extraction.is_partial);
        assert!(extraction.error_count > 0);
        // Should still extract the valid functions
        assert!(!extraction.ir.functions.is_empty());
    }

    #[test]
    fn test_extract_kernel_code_simulation() {
        // With preprocessing, this should parse better
        let extraction = extract_with_ok(
            r#"
static __init int my_module_init(void) {
    return 0;
}

static __exit void my_module_exit(void) {
}

MODULE_LICENSE("GPL");
"#,
            &ExtractionOptions::for_kernel_code(),
        );

        // Check that macros were detected
        let macros = &extraction.detected_macros;
        assert!(macros.contains(&"__init".to_string()) || macros.contains(&"__exit".to_string()));
    }

    #[test]
    fn test_extract_pointer_params() {
        let ir = extract_ok(
            r#"
void process(int *arr, const char *str) {
    // pointer parameters
}
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].parameters.len(), 2);
    }

    #[test]
    fn test_extract_union() {
        let ir = extract_ok(
            r#"
union Data {
    int i;
    float f;
    char c;
};
"#,
            "test.c",
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Data");
    }

    #[test]
    fn test_extract_function_with_complexity() {
        let ir = extract_ok(
            r#"
int complex_func(int x) {
    if (x > 0) {
        for (int i = 0; i < x; i++) {
            if (i % 2 == 0) {
                continue;
            }
        }
        return 1;
    } else if (x < 0) {
        while (x < 0) {
            x++;
        }
        return -1;
    }
    return 0;
}
"#,
            "test.c",
        );

        assert_eq!(ir.functions.len(), 1);
        // Check that complexity metrics are populated
        let func = &ir.functions[0];
        assert!(func.complexity.is_some());
        let complexity = func.complexity.as_ref().unwrap();
        assert!(complexity.cyclomatic_complexity > 1);
    }
}