Skip to main content

deagle_parse/
cpp_parser.rs

1//! C++ language parser using tree-sitter-cpp.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5use crate::ParseResult;
6
7pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
8    parse_with_edges(path, content).map(|r| r.nodes)
9}
10
11pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
12    let mut parser = tree_sitter::Parser::new();
13    let language = tree_sitter_cpp::LANGUAGE;
14    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
15        file: path.display().to_string(),
16        message: format!("Failed to set language: {}", e),
17    })?;
18
19    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
20        file: path.display().to_string(),
21        message: "Failed to parse file".into(),
22    })?;
23
24    let mut nodes = Vec::new();
25    let file_path = path.to_string_lossy().to_string();
26
27    nodes.push(Node {
28        id: 0,
29        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
30        kind: NodeKind::File,
31        language: Language::Cpp,
32        file_path: file_path.clone(),
33        line_start: 1,
34        line_end: content.lines().count() as u32,
35        content: None,
36    });
37
38    extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
39
40    let mut edges = Vec::new();
41    for i in 1..nodes.len() {
42        edges.push((0, i, EdgeKind::Contains));
43    }
44    Ok(ParseResult { nodes, edges })
45}
46
47fn extract_definitions(node: tree_sitter::Node, source: &str, file_path: &str, results: &mut Vec<Node>) {
48    let kind = match node.kind() {
49        "function_definition" => Some(NodeKind::Function),
50        "declaration" => {
51            if node.child_by_field_name("declarator")
52                .map(|d| d.kind() == "function_declarator")
53                .unwrap_or(false)
54            {
55                Some(NodeKind::Function)
56            } else {
57                None
58            }
59        }
60        "class_specifier" => {
61            if node.child_by_field_name("body").is_some() {
62                Some(NodeKind::Class)
63            } else {
64                None
65            }
66        }
67        "struct_specifier" => {
68            if node.child_by_field_name("body").is_some() {
69                Some(NodeKind::Struct)
70            } else {
71                None
72            }
73        }
74        "enum_specifier" => {
75            if node.child_by_field_name("body").is_some() {
76                Some(NodeKind::Enum)
77            } else {
78                None
79            }
80        }
81        "namespace_definition" => Some(NodeKind::Module),
82        "type_definition" => Some(NodeKind::TypeAlias),
83        "preproc_include" => Some(NodeKind::Import),
84        "preproc_def" => Some(NodeKind::Constant),
85        "template_declaration" => {
86            // Look inside for class or function
87            let mut cursor = node.walk();
88            for child in node.children(&mut cursor) {
89                match child.kind() {
90                    "class_specifier" | "struct_specifier" => return extract_template(node, child, source, file_path, results, NodeKind::Class),
91                    "function_definition" => return extract_template(node, child, source, file_path, results, NodeKind::Function),
92                    "declaration" => return extract_template(node, child, source, file_path, results, NodeKind::Function),
93                    _ => {}
94                }
95            }
96            None
97        }
98        _ => None,
99    };
100
101    if let Some(kind) = kind {
102        if let Some(name) = extract_name(node, source, kind) {
103            let start = node.start_position();
104            let end = node.end_position();
105            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
106                crate::truncate_content(s, 500)
107            });
108            results.push(Node {
109                id: 0, name, kind, language: Language::Cpp,
110                file_path: file_path.to_string(),
111                line_start: (start.row + 1) as u32,
112                line_end: (end.row + 1) as u32,
113                content,
114            });
115        }
116    }
117
118    // Recurse into children (but skip template_declaration children since handled above)
119    if node.kind() != "template_declaration" {
120        let mut cursor = node.walk();
121        for child in node.children(&mut cursor) {
122            extract_definitions(child, source, file_path, results);
123        }
124    }
125}
126
127fn extract_template(
128    template_node: tree_sitter::Node,
129    inner_node: tree_sitter::Node,
130    source: &str,
131    file_path: &str,
132    results: &mut Vec<Node>,
133    kind: NodeKind,
134) {
135    if let Some(name) = extract_name(inner_node, source, kind) {
136        let start = template_node.start_position();
137        let end = template_node.end_position();
138        let content = template_node.utf8_text(source.as_bytes()).ok().map(|s| {
139            crate::truncate_content(s, 500)
140        });
141        results.push(Node {
142            id: 0, name, kind, language: Language::Cpp,
143            file_path: file_path.to_string(),
144            line_start: (start.row + 1) as u32,
145            line_end: (end.row + 1) as u32,
146            content,
147        });
148    }
149    // Also recurse into the inner node for nested definitions (e.g., methods inside template class)
150    let mut cursor = inner_node.walk();
151    for child in inner_node.children(&mut cursor) {
152        extract_definitions(child, source, file_path, results);
153    }
154}
155
156fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
157    match kind {
158        NodeKind::Import => node.utf8_text(source.as_bytes()).ok().map(|s| s.trim().to_string()),
159        NodeKind::Constant => {
160            node.child_by_field_name("name")
161                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
162                .map(|s| s.to_string())
163        }
164        NodeKind::Function => {
165            fn find_fn_name(n: tree_sitter::Node, src: &str) -> Option<String> {
166                if n.kind() == "identifier" || n.kind() == "field_identifier" || n.kind() == "destructor_name" {
167                    return n.utf8_text(src.as_bytes()).ok().map(|s| s.to_string());
168                }
169                // Handle qualified names like ClassName::method
170                if n.kind() == "qualified_identifier" || n.kind() == "scoped_identifier" {
171                    return n.utf8_text(src.as_bytes()).ok().map(|s| s.to_string());
172                }
173                if let Some(d) = n.child_by_field_name("declarator") {
174                    return find_fn_name(d, src);
175                }
176                let mut c = n.walk();
177                for child in n.children(&mut c) {
178                    if let Some(name) = find_fn_name(child, src) {
179                        return Some(name);
180                    }
181                }
182                None
183            }
184            find_fn_name(node, source)
185        }
186        NodeKind::Class | NodeKind::Struct | NodeKind::Enum | NodeKind::Module => {
187            node.child_by_field_name("name")
188                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
189                .map(|s| s.to_string())
190        }
191        NodeKind::TypeAlias => {
192            node.child_by_field_name("declarator")
193                .and_then(|n| {
194                    if n.kind() == "type_identifier" {
195                        n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
196                    } else {
197                        None
198                    }
199                })
200        }
201        _ => node.child_by_field_name("name")
202            .and_then(|n| n.utf8_text(source.as_bytes()).ok())
203            .map(|s| s.to_string()),
204    }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Fixture exercising includes, a macro, a namespace, a class, a struct, a
    /// scoped enum, a class template, a function template and a free function.
    const SAMPLE_CPP: &str = r#"
#include <iostream>
#include <vector>

#define MAX_SIZE 1024

namespace math {

class Vector {
public:
    double x, y, z;

    Vector(double x, double y, double z) : x(x), y(y), z(z) {}

    double magnitude() const {
        return std::sqrt(x*x + y*y + z*z);
    }

    Vector operator+(const Vector& other) const {
        return Vector(x + other.x, y + other.y, z + other.z);
    }
};

struct Point {
    int x;
    int y;
};

enum class Color {
    Red,
    Green,
    Blue
};

template<typename T>
class Container {
    T value;
public:
    Container(T v) : value(v) {}
    T get() const { return value; }
};

template<typename T>
T add(T a, T b) {
    return a + b;
}

} // namespace math

int main(int argc, char* argv[]) {
    math::Vector v(1, 2, 3);
    std::cout << v.magnitude() << std::endl;
    return 0;
}
"#;

    /// Parses the shared fixture under the file name `main.cpp`.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("main.cpp"), SAMPLE_CPP).unwrap()
    }

    /// Names of all nodes of `kind`, in the order the parser emitted them.
    fn names_of(nodes: &[Node], kind: NodeKind) -> Vec<&str> {
        nodes
            .iter()
            .filter(|n| n.kind == kind)
            .map(|n| n.name.as_str())
            .collect()
    }

    #[test]
    fn test_parse_cpp_finds_all() {
        let nodes = sample_nodes();
        let expectations = [
            (NodeKind::Import, "should find #include"),
            (NodeKind::Constant, "should find #define"),
            (NodeKind::Class, "should find class"),
            (NodeKind::Struct, "should find struct"),
            (NodeKind::Enum, "should find enum"),
            (NodeKind::Function, "should find function"),
            (NodeKind::Module, "should find namespace"),
        ];
        for (kind, why) in expectations {
            assert!(nodes.iter().any(|n| n.kind == kind), "{}", why);
        }
    }

    #[test]
    fn test_parse_cpp_class() {
        let nodes = sample_nodes();
        let class_names = names_of(&nodes, NodeKind::Class);
        assert!(class_names.contains(&"Vector"), "should find Vector class");
        assert!(
            class_names.contains(&"Container"),
            "should find Container template class"
        );
    }

    #[test]
    fn test_parse_cpp_namespace() {
        let nodes = sample_nodes();
        let namespaces = names_of(&nodes, NodeKind::Module);
        assert_eq!(namespaces.len(), 1);
        assert_eq!(namespaces[0], "math");
    }

    #[test]
    fn test_parse_cpp_functions() {
        let nodes = sample_nodes();
        let function_names = names_of(&nodes, NodeKind::Function);
        assert!(function_names.contains(&"main"), "should find main");
    }

    #[test]
    fn test_parse_cpp_edges() {
        let parsed = parse_with_edges(&PathBuf::from("main.cpp"), SAMPLE_CPP).unwrap();
        assert!(!parsed.edges.is_empty());
    }

    #[test]
    fn test_parse_empty_cpp() {
        // An empty file should yield at most the file node itself.
        let nodes = parse(&PathBuf::from("empty.cpp"), "").unwrap();
        assert!(nodes.len() <= 1);
    }

    #[test]
    fn test_parse_cpp_enum_class() {
        let nodes = sample_nodes();
        let enum_names = names_of(&nodes, NodeKind::Enum);
        assert!(enum_names.contains(&"Color"), "should find enum class Color");
    }
}