1use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5use crate::ParseResult;
6
7pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
8 parse_with_edges(path, content).map(|r| r.nodes)
9}
10
11pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
12 let mut parser = tree_sitter::Parser::new();
13 let language = tree_sitter_cpp::LANGUAGE;
14 parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
15 file: path.display().to_string(),
16 message: format!("Failed to set language: {}", e),
17 })?;
18
19 let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
20 file: path.display().to_string(),
21 message: "Failed to parse file".into(),
22 })?;
23
24 let mut nodes = Vec::new();
25 let file_path = path.to_string_lossy().to_string();
26
27 nodes.push(Node {
28 id: 0,
29 name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
30 kind: NodeKind::File,
31 language: Language::Cpp,
32 file_path: file_path.clone(),
33 line_start: 1,
34 line_end: content.lines().count() as u32,
35 content: None,
36 });
37
38 extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
39
40 let mut edges = Vec::new();
41 for i in 1..nodes.len() {
42 edges.push((0, i, EdgeKind::Contains));
43 }
44 Ok(ParseResult { nodes, edges })
45}
46
47fn extract_definitions(node: tree_sitter::Node, source: &str, file_path: &str, results: &mut Vec<Node>) {
48 let kind = match node.kind() {
49 "function_definition" => Some(NodeKind::Function),
50 "declaration" => {
51 if node.child_by_field_name("declarator")
52 .map(|d| d.kind() == "function_declarator")
53 .unwrap_or(false)
54 {
55 Some(NodeKind::Function)
56 } else {
57 None
58 }
59 }
60 "class_specifier" => {
61 if node.child_by_field_name("body").is_some() {
62 Some(NodeKind::Class)
63 } else {
64 None
65 }
66 }
67 "struct_specifier" => {
68 if node.child_by_field_name("body").is_some() {
69 Some(NodeKind::Struct)
70 } else {
71 None
72 }
73 }
74 "enum_specifier" => {
75 if node.child_by_field_name("body").is_some() {
76 Some(NodeKind::Enum)
77 } else {
78 None
79 }
80 }
81 "namespace_definition" => Some(NodeKind::Module),
82 "type_definition" => Some(NodeKind::TypeAlias),
83 "preproc_include" => Some(NodeKind::Import),
84 "preproc_def" => Some(NodeKind::Constant),
85 "template_declaration" => {
86 let mut cursor = node.walk();
88 for child in node.children(&mut cursor) {
89 match child.kind() {
90 "class_specifier" | "struct_specifier" => return extract_template(node, child, source, file_path, results, NodeKind::Class),
91 "function_definition" => return extract_template(node, child, source, file_path, results, NodeKind::Function),
92 "declaration" => return extract_template(node, child, source, file_path, results, NodeKind::Function),
93 _ => {}
94 }
95 }
96 None
97 }
98 _ => None,
99 };
100
101 if let Some(kind) = kind {
102 if let Some(name) = extract_name(node, source, kind) {
103 let start = node.start_position();
104 let end = node.end_position();
105 let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
106 crate::truncate_content(s, 500)
107 });
108 results.push(Node {
109 id: 0, name, kind, language: Language::Cpp,
110 file_path: file_path.to_string(),
111 line_start: (start.row + 1) as u32,
112 line_end: (end.row + 1) as u32,
113 content,
114 });
115 }
116 }
117
118 if node.kind() != "template_declaration" {
120 let mut cursor = node.walk();
121 for child in node.children(&mut cursor) {
122 extract_definitions(child, source, file_path, results);
123 }
124 }
125}
126
127fn extract_template(
128 template_node: tree_sitter::Node,
129 inner_node: tree_sitter::Node,
130 source: &str,
131 file_path: &str,
132 results: &mut Vec<Node>,
133 kind: NodeKind,
134) {
135 if let Some(name) = extract_name(inner_node, source, kind) {
136 let start = template_node.start_position();
137 let end = template_node.end_position();
138 let content = template_node.utf8_text(source.as_bytes()).ok().map(|s| {
139 crate::truncate_content(s, 500)
140 });
141 results.push(Node {
142 id: 0, name, kind, language: Language::Cpp,
143 file_path: file_path.to_string(),
144 line_start: (start.row + 1) as u32,
145 line_end: (end.row + 1) as u32,
146 content,
147 });
148 }
149 let mut cursor = inner_node.walk();
151 for child in inner_node.children(&mut cursor) {
152 extract_definitions(child, source, file_path, results);
153 }
154}
155
156fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
157 match kind {
158 NodeKind::Import => node.utf8_text(source.as_bytes()).ok().map(|s| s.trim().to_string()),
159 NodeKind::Constant => {
160 node.child_by_field_name("name")
161 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
162 .map(|s| s.to_string())
163 }
164 NodeKind::Function => {
165 fn find_fn_name(n: tree_sitter::Node, src: &str) -> Option<String> {
166 if n.kind() == "identifier" || n.kind() == "field_identifier" || n.kind() == "destructor_name" {
167 return n.utf8_text(src.as_bytes()).ok().map(|s| s.to_string());
168 }
169 if n.kind() == "qualified_identifier" || n.kind() == "scoped_identifier" {
171 return n.utf8_text(src.as_bytes()).ok().map(|s| s.to_string());
172 }
173 if let Some(d) = n.child_by_field_name("declarator") {
174 return find_fn_name(d, src);
175 }
176 let mut c = n.walk();
177 for child in n.children(&mut c) {
178 if let Some(name) = find_fn_name(child, src) {
179 return Some(name);
180 }
181 }
182 None
183 }
184 find_fn_name(node, source)
185 }
186 NodeKind::Class | NodeKind::Struct | NodeKind::Enum | NodeKind::Module => {
187 node.child_by_field_name("name")
188 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
189 .map(|s| s.to_string())
190 }
191 NodeKind::TypeAlias => {
192 node.child_by_field_name("declarator")
193 .and_then(|n| {
194 if n.kind() == "type_identifier" {
195 n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string())
196 } else {
197 None
198 }
199 })
200 }
201 _ => node.child_by_field_name("name")
202 .and_then(|n| n.utf8_text(source.as_bytes()).ok())
203 .map(|s| s.to_string()),
204 }
205}
206
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    // Sample covering includes, a macro, a namespace, class/struct/enum definitions,
    // class and function templates, and a free function.
    const SAMPLE_CPP: &str = r#"
#include <iostream>
#include <vector>

#define MAX_SIZE 1024

namespace math {

class Vector {
public:
    double x, y, z;

    Vector(double x, double y, double z) : x(x), y(y), z(z) {}

    double magnitude() const {
        return std::sqrt(x*x + y*y + z*z);
    }

    Vector operator+(const Vector& other) const {
        return Vector(x + other.x, y + other.y, z + other.z);
    }
};

struct Point {
    int x;
    int y;
};

enum class Color {
    Red,
    Green,
    Blue
};

template<typename T>
class Container {
    T value;
public:
    Container(T v) : value(v) {}
    T get() const { return value; }
};

template<typename T>
T add(T a, T b) {
    return a + b;
}

} // namespace math

int main(int argc, char* argv[]) {
    math::Vector v(1, 2, 3);
    std::cout << v.magnitude() << std::endl;
    return 0;
}
"#;

    /// Parses `SAMPLE_CPP` as `main.cpp` and returns the extracted nodes.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("main.cpp"), SAMPLE_CPP).unwrap()
    }

    /// Names of all nodes of the given kind, in extraction order.
    fn names_of(nodes: &[Node], kind: NodeKind) -> Vec<&str> {
        nodes
            .iter()
            .filter(|n| n.kind == kind)
            .map(|n| n.name.as_str())
            .collect()
    }

    #[test]
    fn test_parse_cpp_finds_all() {
        let nodes = sample_nodes();
        let kinds: Vec<_> = nodes.iter().map(|n| n.kind).collect();
        assert!(kinds.contains(&NodeKind::Import), "should find #include");
        assert!(kinds.contains(&NodeKind::Constant), "should find #define");
        assert!(kinds.contains(&NodeKind::Class), "should find class");
        assert!(kinds.contains(&NodeKind::Struct), "should find struct");
        assert!(kinds.contains(&NodeKind::Enum), "should find enum");
        assert!(kinds.contains(&NodeKind::Function), "should find function");
        assert!(kinds.contains(&NodeKind::Module), "should find namespace");
    }

    #[test]
    fn test_parse_cpp_class() {
        let nodes = sample_nodes();
        let classes = names_of(&nodes, NodeKind::Class);
        assert!(classes.contains(&"Vector"), "should find Vector class");
        assert!(classes.contains(&"Container"), "should find Container template class");
    }

    #[test]
    fn test_parse_cpp_namespace() {
        let nodes = sample_nodes();
        let namespaces = names_of(&nodes, NodeKind::Module);
        assert_eq!(namespaces.len(), 1);
        assert_eq!(namespaces[0], "math");
    }

    #[test]
    fn test_parse_cpp_functions() {
        let nodes = sample_nodes();
        let functions = names_of(&nodes, NodeKind::Function);
        assert!(functions.contains(&"main"), "should find main");
        assert!(functions.contains(&"add"), "should find templated function add");
    }

    #[test]
    fn test_parse_cpp_edges() {
        let path = PathBuf::from("main.cpp");
        let result = parse_with_edges(&path, SAMPLE_CPP).unwrap();
        assert!(!result.edges.is_empty());
        // The file node at index 0 contains every other node exactly once.
        assert_eq!(result.edges.len(), result.nodes.len() - 1);
    }

    #[test]
    fn test_parse_empty_cpp() {
        let path = PathBuf::from("empty.cpp");
        let nodes = parse(&path, "").unwrap();
        // An empty file still yields its file node, and nothing else.
        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].kind == NodeKind::File, "the only node should be the file node");
        assert_eq!(nodes[0].name, "empty.cpp");
    }

    #[test]
    fn test_parse_cpp_enum_class() {
        let nodes = sample_nodes();
        let enums = names_of(&nodes, NodeKind::Enum);
        assert!(enums.contains(&"Color"), "should find enum class Color");
    }
}