Skip to main content

deagle_parse/
c_parser.rs

1//! C language parser using tree-sitter-c.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5use crate::ParseResult;
6
7pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
8    parse_with_edges(path, content).map(|r| r.nodes)
9}
10
11pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
12    let mut parser = tree_sitter::Parser::new();
13    let language = tree_sitter_c::LANGUAGE;
14    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
15        file: path.display().to_string(),
16        message: format!("Failed to set language: {}", e),
17    })?;
18
19    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
20        file: path.display().to_string(),
21        message: "Failed to parse file".into(),
22    })?;
23
24    let mut nodes = Vec::new();
25    let file_path = path.to_string_lossy().to_string();
26    let lang = Language::C;
27
28    nodes.push(Node {
29        id: 0,
30        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
31        kind: NodeKind::File,
32        language: lang,
33        file_path: file_path.clone(),
34        line_start: 1,
35        line_end: content.lines().count() as u32,
36        content: None,
37    });
38
39    extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
40
41    let mut edges = Vec::new();
42    for i in 1..nodes.len() {
43        edges.push((0, i, EdgeKind::Contains));
44    }
45    Ok(ParseResult { nodes, edges })
46}
47
48fn extract_definitions(node: tree_sitter::Node, source: &str, file_path: &str, results: &mut Vec<Node>) {
49    let kind = match node.kind() {
50        "function_definition" => Some(NodeKind::Function),
51        "declaration" => {
52            // Check if it's a function declaration (prototype) or variable
53            if node.child_by_field_name("declarator").map(|d| d.kind() == "function_declarator").unwrap_or(false) {
54                Some(NodeKind::Function)
55            } else {
56                None
57            }
58        }
59        "struct_specifier" => {
60            if node.child_by_field_name("body").is_some() {
61                Some(NodeKind::Struct)
62            } else {
63                None
64            }
65        }
66        "enum_specifier" => {
67            if node.child_by_field_name("body").is_some() {
68                Some(NodeKind::Enum)
69            } else {
70                None
71            }
72        }
73        "type_definition" => Some(NodeKind::TypeAlias),
74        "preproc_include" => Some(NodeKind::Import),
75        "preproc_def" => Some(NodeKind::Constant),
76        _ => None,
77    };
78
79    if let Some(kind) = kind {
80        if let Some(name) = extract_name(node, source, kind) {
81            let start = node.start_position();
82            let end = node.end_position();
83            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
84                crate::truncate_content(s, 500)
85            });
86            results.push(Node {
87                id: 0, name, kind, language: Language::C,
88                file_path: file_path.to_string(),
89                line_start: (start.row + 1) as u32,
90                line_end: (end.row + 1) as u32,
91                content,
92            });
93        }
94    }
95
96    let mut cursor = node.walk();
97    for child in node.children(&mut cursor) {
98        extract_definitions(child, source, file_path, results);
99    }
100}
101
102fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
103    match kind {
104        NodeKind::Import => node.utf8_text(source.as_bytes()).ok().map(|s| s.trim().to_string()),
105        NodeKind::Constant => {
106            // #define NAME ...
107            node.child_by_field_name("name")
108                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
109                .map(|s| s.to_string())
110        }
111        NodeKind::Function => {
112            // function_definition → declarator → function_declarator → declarator → identifier
113            fn find_fn_name(n: tree_sitter::Node, src: &str) -> Option<String> {
114                if n.kind() == "identifier" {
115                    return n.utf8_text(src.as_bytes()).ok().map(|s| s.to_string());
116                }
117                if let Some(d) = n.child_by_field_name("declarator") {
118                    return find_fn_name(d, src);
119                }
120                let mut c = n.walk();
121                for child in n.children(&mut c) {
122                    if let Some(name) = find_fn_name(child, src) {
123                        return Some(name);
124                    }
125                }
126                None
127            }
128            find_fn_name(node, source)
129        }
130        NodeKind::Struct | NodeKind::Enum => {
131            node.child_by_field_name("name")
132                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
133                .map(|s| s.to_string())
134        }
135        NodeKind::TypeAlias => {
136            // typedef ... name;  — last identifier before semicolon
137            node.child_by_field_name("declarator")
138                .and_then(|n| {
139                    if n.kind() == "type_identifier" {
140                        n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
141                    } else {
142                        None
143                    }
144                })
145        }
146        _ => node.child_by_field_name("name")
147            .and_then(|n| n.utf8_text(source.as_bytes()).ok())
148            .map(|s| s.to_string()),
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Small C translation unit exercising includes, defines, a typedef, a
    /// struct, an enum and three function definitions.
    const SAMPLE_C: &str = r#"
#include <stdio.h>
#include <stdlib.h>

#define MAX_SIZE 1024
#define VERSION "1.0"

typedef unsigned int uint;

struct Point {
    int x;
    int y;
};

enum Color {
    RED,
    GREEN,
    BLUE
};

int add(int a, int b) {
    return a + b;
}

void print_point(struct Point p) {
    printf("(%d, %d)\n", p.x, p.y);
}

int main(int argc, char *argv[]) {
    struct Point p = {1, 2};
    print_point(p);
    printf("%d\n", add(p.x, p.y));
    return 0;
}
"#;

    /// Parses the shared fixture as `main.c` and returns the extracted nodes.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("main.c"), SAMPLE_C).unwrap()
    }

    /// Collects the names of all fixture nodes of the requested kind, in order.
    fn names_of(kind: NodeKind) -> Vec<String> {
        sample_nodes()
            .into_iter()
            .filter(|n| n.kind == kind)
            .map(|n| n.name)
            .collect()
    }

    #[test]
    fn test_parse_c_finds_all() {
        let nodes = sample_nodes();
        let has = |kind: NodeKind| nodes.iter().any(|n| n.kind == kind);
        assert!(has(NodeKind::Import), "should find #include");
        assert!(has(NodeKind::Constant), "should find #define");
        assert!(has(NodeKind::Struct), "should find struct");
        assert!(has(NodeKind::Enum), "should find enum");
        assert!(has(NodeKind::Function), "should find function");
    }

    #[test]
    fn test_parse_c_functions() {
        let functions = names_of(NodeKind::Function);
        for expected in ["add", "print_point", "main"] {
            assert!(functions.iter().any(|name| name == expected));
        }
    }

    #[test]
    fn test_parse_c_struct() {
        let structs = names_of(NodeKind::Struct);
        assert_eq!(structs.len(), 1);
        assert_eq!(structs[0], "Point");
    }

    #[test]
    fn test_parse_c_defines() {
        let constants = names_of(NodeKind::Constant);
        for expected in ["MAX_SIZE", "VERSION"] {
            assert!(constants.iter().any(|name| name == expected));
        }
    }

    #[test]
    fn test_parse_c_edges() {
        let parsed = parse_with_edges(&PathBuf::from("main.c"), SAMPLE_C).unwrap();
        assert!(!parsed.edges.is_empty());
    }

    #[test]
    fn test_parse_empty_c() {
        let nodes = parse(&PathBuf::from("empty.c"), "").unwrap();
        assert!(nodes.len() <= 1);
    }
}
246}