1use codegraph_parser_api::{CodeIR, ModuleEntity, ParserConfig, ParserError};
8use std::path::Path;
9use tree_sitter::Parser;
10
11use crate::preprocessor::CPreprocessor;
12use crate::visitor::CVisitor;
13
/// Switches that control how [`extract_with_options`] processes a C source file.
///
/// The derived `Default` leaves every flag `false`.
#[derive(Debug, Clone, Default)]
pub struct ExtractionOptions {
    /// When `true`, a parse tree containing errors still produces a (partial)
    /// result; when `false`, such a tree is reported as a syntax error.
    pub tolerant_mode: bool,
    /// When `true`, the source is passed through `CPreprocessor::preprocess`
    /// before it is handed to the parser.
    pub preprocess: bool,
    /// Forwarded to `CVisitor::set_extract_calls` before the tree is visited.
    pub extract_calls: bool,
}

impl ExtractionOptions {
    /// Preset with every flag enabled: tolerant parsing, preprocessing and
    /// call extraction.
    pub fn for_kernel_code() -> Self {
        Self {
            preprocess: true,
            ..Self::tolerant()
        }
    }

    /// Preset with tolerant parsing and call extraction enabled and
    /// preprocessing disabled.
    pub fn tolerant() -> Self {
        Self {
            tolerant_mode: true,
            extract_calls: true,
            ..Self::default()
        }
    }
}
44
45#[derive(Debug)]
47pub struct ExtractionResult {
48 pub ir: CodeIR,
49 pub error_count: usize,
51 pub is_partial: bool,
53 pub detected_macros: Vec<String>,
55}
56
57pub fn extract(
59 source: &str,
60 file_path: &Path,
61 config: &ParserConfig,
62) -> Result<CodeIR, ParserError> {
63 let result = extract_with_options(source, file_path, config, &ExtractionOptions::default())?;
64
65 if result.is_partial {
66 return Err(ParserError::SyntaxError(
67 file_path.to_path_buf(),
68 0,
69 0,
70 "Syntax error".to_string(),
71 ));
72 }
73
74 Ok(result.ir)
75}
76
77pub fn extract_with_options(
79 source: &str,
80 file_path: &Path,
81 config: &ParserConfig,
82 options: &ExtractionOptions,
83) -> Result<ExtractionResult, ParserError> {
84 let preprocessor = CPreprocessor::new();
86 let detected_macros: Vec<String> = preprocessor
87 .analyze_macros(source)
88 .iter()
89 .map(|m| m.name.clone())
90 .collect();
91
92 let processed_source = if options.preprocess {
94 preprocessor.preprocess(source)
95 } else {
96 source.to_string()
97 };
98
99 let mut parser = Parser::new();
100 let language = tree_sitter_c::language();
101 parser
102 .set_language(language)
103 .map_err(|e| ParserError::ParseError(file_path.to_path_buf(), e.to_string()))?;
104
105 let tree = parser.parse(&processed_source, None).ok_or_else(|| {
106 ParserError::ParseError(file_path.to_path_buf(), "Failed to parse".to_string())
107 })?;
108
109 let root_node = tree.root_node();
110 let has_error = root_node.has_error();
111 let error_count = if has_error {
112 count_errors(root_node)
113 } else {
114 0
115 };
116
117 if has_error && !options.tolerant_mode {
119 return Err(ParserError::SyntaxError(
120 file_path.to_path_buf(),
121 0,
122 0,
123 format!("Syntax error ({error_count} error nodes)"),
124 ));
125 }
126
127 let mut ir = CodeIR::new(file_path.to_path_buf());
128
129 let module_name = file_path
130 .file_stem()
131 .and_then(|s| s.to_str())
132 .unwrap_or("unknown")
133 .to_string();
134 ir.module = Some(ModuleEntity {
135 name: module_name,
136 path: file_path.display().to_string(),
137 language: "c".to_string(),
138 line_count: source.lines().count(),
139 doc_comment: None,
140 attributes: Vec::new(),
141 });
142
143 let mut visitor = CVisitor::new(processed_source.as_bytes(), config.clone());
145 visitor.set_extract_calls(options.extract_calls);
146 visitor.visit_node(root_node);
147
148 ir.functions = visitor.functions;
149 ir.classes = visitor.structs;
150 ir.imports = visitor.imports;
151
152 Ok(ExtractionResult {
156 ir,
157 error_count,
158 is_partial: has_error,
159 detected_macros,
160 })
161}
162
163fn count_errors(node: tree_sitter::Node) -> usize {
165 let mut count = 0;
166
167 if node.is_error() || node.is_missing() {
168 count += 1;
169 }
170
171 let mut cursor = node.walk();
172 for child in node.children(&mut cursor) {
173 count += count_errors(child);
174 }
175
176 count
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the strict `extract` entry point on `source`, registered under
    /// `file_name`, with a default `ParserConfig`.
    fn run_extract(file_name: &str, source: &str) -> Result<CodeIR, ParserError> {
        extract(source, Path::new(file_name), &ParserConfig::default())
    }

    /// Like `run_extract` for `test.c`, but fails the calling test when
    /// extraction does not succeed.
    fn extract_ok(source: &str) -> CodeIR {
        run_extract("test.c", source).expect("extraction should succeed")
    }

    /// Runs `extract_with_options` on `source` as `test.c` with a default
    /// `ParserConfig` and expects it to succeed.
    fn extract_ok_with(source: &str, options: &ExtractionOptions) -> ExtractionResult {
        extract_with_options(
            source,
            Path::new("test.c"),
            &ParserConfig::default(),
            options,
        )
        .expect("extraction with options should succeed")
    }

    #[test]
    fn test_extract_simple_function() {
        let ir = extract_ok(
            r#"
int main() {
    return 0;
}
"#,
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].name, "main");
    }

    #[test]
    fn test_extract_function_with_params() {
        let ir = extract_ok(
            r#"
int add(int a, int b) {
    return a + b;
}
"#,
        );

        assert_eq!(ir.functions.len(), 1);
        let add = &ir.functions[0];
        assert_eq!(add.name, "add");
        assert_eq!(add.parameters.len(), 2);
    }

    #[test]
    fn test_extract_struct() {
        let ir = extract_ok(
            r#"
struct Point {
    int x;
    int y;
};
"#,
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Point");
    }

    #[test]
    fn test_extract_enum() {
        let ir = extract_ok(
            r#"
enum Color {
    RED,
    GREEN,
    BLUE
};
"#,
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Color");
    }

    #[test]
    fn test_extract_include() {
        let ir = extract_ok(
            r#"
#include <stdio.h>
#include "myheader.h"
"#,
        );

        assert_eq!(ir.imports.len(), 2);
    }

    #[test]
    fn test_extract_multiple_functions() {
        let ir = extract_ok(
            r#"
int foo() { return 1; }
int bar() { return 2; }
int baz() { return 3; }
"#,
        );

        assert_eq!(ir.functions.len(), 3);
    }

    #[test]
    fn test_extract_static_function() {
        let ir = extract_ok(
            r#"
static void helper() {
    // internal function
}
"#,
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].visibility, "private");
    }

    #[test]
    fn test_extract_module_info() {
        let ir = run_extract(
            "module.c",
            r#"
int test() {
    return 42;
}
"#,
        )
        .expect("extraction should succeed");

        let module = ir.module.expect("module entity should be populated");
        assert_eq!(module.name, "module");
        assert_eq!(module.language, "c");
        assert!(module.line_count > 0);
    }

    #[test]
    fn test_extract_with_syntax_error_strict() {
        let outcome = run_extract(
            "test.c",
            r#"
int broken( {
    // Missing closing brace
"#,
        );

        assert!(
            matches!(outcome, Err(ParserError::SyntaxError(..))),
            "Expected SyntaxError"
        );
    }

    #[test]
    fn test_extract_with_syntax_error_tolerant() {
        let extraction = extract_ok_with(
            r#"
int valid_func() { return 0; }
int broken( {
int another_valid() { return 1; }
"#,
            &ExtractionOptions::tolerant(),
        );

        assert!(extraction.is_partial);
        assert!(extraction.error_count > 0);
        assert!(!extraction.ir.functions.is_empty());
    }

    #[test]
    fn test_extract_kernel_code_simulation() {
        let extraction = extract_ok_with(
            r#"
static __init int my_module_init(void) {
    return 0;
}

static __exit void my_module_exit(void) {
}

MODULE_LICENSE("GPL");
"#,
            &ExtractionOptions::for_kernel_code(),
        );

        // At least one of the two kernel annotations must have been detected.
        let saw_kernel_macro = extraction
            .detected_macros
            .iter()
            .any(|name| name == "__init" || name == "__exit");
        assert!(saw_kernel_macro);
    }

    #[test]
    fn test_extract_pointer_params() {
        let ir = extract_ok(
            r#"
void process(int *arr, const char *str) {
    // pointer parameters
}
"#,
        );

        assert_eq!(ir.functions.len(), 1);
        assert_eq!(ir.functions[0].parameters.len(), 2);
    }

    #[test]
    fn test_extract_union() {
        let ir = extract_ok(
            r#"
union Data {
    int i;
    float f;
    char c;
};
"#,
        );

        assert_eq!(ir.classes.len(), 1);
        assert_eq!(ir.classes[0].name, "Data");
    }

    #[test]
    fn test_extract_function_with_complexity() {
        let ir = extract_ok(
            r#"
int complex_func(int x) {
    if (x > 0) {
        for (int i = 0; i < x; i++) {
            if (i % 2 == 0) {
                continue;
            }
        }
        return 1;
    } else if (x < 0) {
        while (x < 0) {
            x++;
        }
        return -1;
    }
    return 0;
}
"#,
        );

        assert_eq!(ir.functions.len(), 1);
        let metrics = ir.functions[0]
            .complexity
            .as_ref()
            .expect("complexity metrics should be present");
        assert!(metrics.cyclomatic_complexity > 1);
    }
}