Skip to main content

forgekit_core/treesitter/
c.rs

1use crate::cfg::TestCfg;
2use crate::error::{ForgeError, Result};
3use crate::types::BlockId;
4
5use super::{CfgExtractor, FunctionInfo, SupportedLanguage};
6
7impl CfgExtractor {
8    /// Extract CFG from C source code
9    pub fn extract_c(source: &str) -> Result<Vec<FunctionInfo>> {
10        use tree_sitter::Parser;
11        use tree_sitter_c;
12
13        let mut parser = Parser::new();
14        parser
15            .set_language(&tree_sitter_c::LANGUAGE.into())
16            .map_err(|e| ForgeError::DatabaseError(format!("Failed to set C language: {:?}", e)))?;
17
18        let tree = parser
19            .parse(source, None)
20            .ok_or_else(|| ForgeError::DatabaseError("Failed to parse C code".to_string()))?;
21
22        let root = tree.root_node();
23        let mut functions = Vec::new();
24
25        Self::extract_c_functions(source, &root, &mut functions)?;
26
27        Ok(functions)
28    }
29
30    fn extract_c_functions(
31        source: &str,
32        node: &tree_sitter::Node,
33        functions: &mut Vec<FunctionInfo>,
34    ) -> Result<()> {
35        let kind = node.kind();
36
37        // Look for function definitions
38        if kind == "function_definition" {
39            if let Some(func) = Self::parse_c_function(source, node)? {
40                functions.push(func);
41            }
42        }
43
44        // Recurse into children
45        let mut cursor = node.walk();
46        for child in node.children(&mut cursor) {
47            Self::extract_c_functions(source, &child, functions)?;
48        }
49
50        Ok(())
51    }
52
53    fn parse_c_function(source: &str, node: &tree_sitter::Node) -> Result<Option<FunctionInfo>> {
54        let start_byte = node.start_byte();
55        let end_byte = node.end_byte();
56
57        // Find function name - look for identifier within function_declarator
58        let mut name = "unknown".to_string();
59        let mut cursor = node.walk();
60        for child in node.children(&mut cursor) {
61            // Direct identifier (for simple cases)
62            if child.kind() == "identifier" {
63                name = Self::node_text(source, &child);
64                break;
65            }
66            // For function declarator, look inside for the identifier
67            if child.kind() == "function_declarator" {
68                let mut inner_cursor = child.walk();
69                for inner in child.children(&mut inner_cursor) {
70                    if inner.kind() == "identifier" {
71                        name = Self::node_text(source, &inner);
72                        break;
73                    }
74                    // Handle pointer declarator
75                    if inner.kind() == "pointer_declarator" || inner.kind() == "function_declarator"
76                    {
77                        let mut ptr_cursor = inner.walk();
78                        for ptr_child in inner.children(&mut ptr_cursor) {
79                            if ptr_child.kind() == "identifier" {
80                                name = Self::node_text(source, &ptr_child);
81                                break;
82                            }
83                        }
84                    }
85                }
86                break;
87            }
88            // For pointer functions at top level
89            if child.kind() == "pointer_declarator" {
90                let mut inner_cursor = child.walk();
91                for inner in child.children(&mut inner_cursor) {
92                    if inner.kind() == "function_declarator" {
93                        let mut fn_cursor = inner.walk();
94                        for fn_child in inner.children(&mut fn_cursor) {
95                            if fn_child.kind() == "identifier" {
96                                name = Self::node_text(source, &fn_child);
97                                break;
98                            }
99                        }
100                    }
101                }
102                break;
103            }
104        }
105
106        // Find compound_statement (function body)
107        let mut body = None;
108        let mut cursor = node.walk();
109        for child in node.children(&mut cursor) {
110            if child.kind() == "compound_statement" {
111                body = Some(child);
112                break;
113            }
114        }
115
116        let cfg = if let Some(body) = body {
117            Self::build_cfg_from_body(source, &body, SupportedLanguage::C)?
118        } else {
119            // Function declaration without body
120            TestCfg::new(BlockId(0))
121        };
122
123        Ok(Some(FunctionInfo {
124            name,
125            start_byte,
126            end_byte,
127            cfg,
128        }))
129    }
130}