Skip to main content

cgx_engine/parsers/
rust.rs

1use tree_sitter::{Parser, Query, QueryCursor};
2
3use crate::parser::{EdgeDef, EdgeKind, LanguageParser, NodeDef, NodeKind, ParseResult};
4use crate::walker::SourceFile;
5
6pub struct RustParser {
7    language: tree_sitter::Language,
8}
9
10impl RustParser {
11    pub fn new() -> Self {
12        Self {
13            language: tree_sitter_rust::language(),
14        }
15    }
16}
17
18impl Default for RustParser {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24impl LanguageParser for RustParser {
25    fn extensions(&self) -> &[&str] {
26        &["rs"]
27    }
28
29    fn extract(&self, file: &SourceFile) -> anyhow::Result<ParseResult> {
30        let mut parser = Parser::new();
31        parser.set_language(&self.language)?;
32
33        let tree = parser
34            .parse(&file.content, None)
35            .ok_or_else(|| anyhow::anyhow!("failed to parse {}", file.relative_path))?;
36
37        let source_bytes = file.content.as_bytes();
38        let root = tree.root_node();
39        let mut nodes = Vec::new();
40        let mut edges = Vec::new();
41
42        let fp = format!("file:{}", file.relative_path);
43
44        // Function definitions
45        if let Ok(query) = Query::new(
46            &self.language,
47            "(function_item name: (identifier) @name) @fn",
48        ) {
49            let mut cursor = QueryCursor::new();
50            for m in cursor.matches(&query, root, source_bytes) {
51                let Some(name_capture) = m
52                    .captures
53                    .iter()
54                    .find(|c| query.capture_names()[c.index as usize] == "name")
55                else {
56                    continue;
57                };
58                let name = node_text(name_capture.node, source_bytes);
59                let start = name_capture.node.start_position();
60                let body_end = m
61                    .captures
62                    .iter()
63                    .find(|c| query.capture_names()[c.index as usize] == "fn")
64                    .map(|c| c.node.end_position())
65                    .unwrap_or_else(|| name_capture.node.end_position());
66                let id = format!("fn:{}:{}", file.relative_path, name);
67
68                nodes.push(NodeDef {
69                    id: id.clone(),
70                    kind: NodeKind::Function,
71                    name,
72                    path: file.relative_path.clone(),
73                    line_start: start.row as u32 + 1,
74                    line_end: body_end.row as u32 + 1,
75                    ..Default::default()
76                });
77
78                edges.push(EdgeDef {
79                    src: fp.clone(),
80                    dst: id,
81                    kind: EdgeKind::Exports,
82                    ..Default::default()
83                });
84            }
85        }
86
87        // Struct definitions
88        if let Ok(query) = Query::new(
89            &self.language,
90            "(struct_item name: (type_identifier) @name) @s",
91        ) {
92            extract_type_nodes(
93                &mut nodes,
94                &mut edges,
95                &fp,
96                file,
97                &query,
98                root,
99                source_bytes,
100                NodeKind::Class,
101                "cls",
102            );
103        }
104
105        // Enum definitions
106        if let Ok(query) = Query::new(
107            &self.language,
108            "(enum_item name: (type_identifier) @name) @e",
109        ) {
110            extract_type_nodes(
111                &mut nodes,
112                &mut edges,
113                &fp,
114                file,
115                &query,
116                root,
117                source_bytes,
118                NodeKind::Class,
119                "cls",
120            );
121        }
122
123        // Trait definitions
124        if let Ok(query) = Query::new(
125            &self.language,
126            "(trait_item name: (type_identifier) @name) @t",
127        ) {
128            extract_type_nodes(
129                &mut nodes,
130                &mut edges,
131                &fp,
132                file,
133                &query,
134                root,
135                source_bytes,
136                NodeKind::Class,
137                "cls",
138            );
139        }
140
141        // Impl blocks — add edges for impl'd struct/trait methods
142        if let Ok(query) = Query::new(
143            &self.language,
144            "(impl_item type: (type_identifier) @type body: (_) @body)",
145        ) {
146            let mut cursor = QueryCursor::new();
147            for m in cursor.matches(&query, root, source_bytes) {
148                if let Some(type_cap) = m
149                    .captures
150                    .iter()
151                    .find(|c| query.capture_names()[c.index as usize] == "type")
152                {
153                    let type_name = node_text(type_cap.node, source_bytes);
154                    edges.push(EdgeDef {
155                        src: fp.clone(),
156                        dst: format!("cls:{}:{}", file.relative_path, type_name),
157                        kind: EdgeKind::Exports,
158                        ..Default::default()
159                    });
160                }
161            }
162        }
163
164        // Use statements
165        if let Ok(query) = Query::new(
166            &self.language,
167            "(use_declaration argument: (scoped_identifier path: (_) @path name: (_)?))",
168        ) {
169            let mut cursor = QueryCursor::new();
170            for m in cursor.matches(&query, root, source_bytes) {
171                if let Some(path_cap) = m
172                    .captures
173                    .iter()
174                    .find(|c| query.capture_names()[c.index as usize] == "path")
175                {
176                    let full_path = node_text(path_cap.node, source_bytes);
177                    // Simple case: use crate::foo::bar -> file path is src/foo/bar.rs
178                    let import_path = if full_path.starts_with("crate::") {
179                        format!(
180                            "src/{}.rs",
181                            full_path.trim_start_matches("crate::").replace("::", "/")
182                        )
183                    } else {
184                        continue;
185                    };
186                    edges.push(EdgeDef {
187                        src: fp.clone(),
188                        dst: format!("file:{}", import_path),
189                        kind: EdgeKind::Imports,
190                        ..Default::default()
191                    });
192                }
193            }
194        }
195
196        // Simpler use declarations (use foo::Bar)
197        if let Ok(query) = Query::new(
198            &self.language,
199            "(use_declaration argument: (identifier) @name)",
200        ) {
201            let mut cursor = QueryCursor::new();
202            for m in cursor.matches(&query, root, source_bytes) {
203                if let Some(name_cap) = m
204                    .captures
205                    .iter()
206                    .find(|c| query.capture_names()[c.index as usize] == "name")
207                {
208                    let mod_name = node_text(name_cap.node, source_bytes);
209                    let import_path = mod_name;
210                    edges.push(EdgeDef {
211                        src: fp.clone(),
212                        dst: format!("file:{}.rs", import_path),
213                        kind: EdgeKind::Imports,
214                        ..Default::default()
215                    });
216                }
217            }
218        }
219
220        // Mark pub items as exported
221        mark_pub_exported(&mut nodes, root, source_bytes);
222
223        Ok(ParseResult {
224            nodes,
225            edges,
226            ..Default::default()
227        })
228    }
229}
230
231fn is_pub_item(node: tree_sitter::Node, source_bytes: &[u8]) -> bool {
232    for i in 0..node.child_count() {
233        if let Some(child) = node.child(i) {
234            if child.kind() == "visibility_modifier" {
235                let text = node_text(child, source_bytes);
236                if text == "pub" || text.starts_with("pub(") {
237                    return true;
238                }
239            }
240        }
241    }
242    false
243}
244
245fn mark_pub_exported(
246    nodes: &mut Vec<crate::parser::NodeDef>,
247    root: tree_sitter::Node,
248    source_bytes: &[u8],
249) {
250    walk_pub(nodes, root, source_bytes);
251}
252
253fn walk_pub(nodes: &mut Vec<crate::parser::NodeDef>, node: tree_sitter::Node, source_bytes: &[u8]) {
254    let kind = node.kind();
255    if matches!(
256        kind,
257        "function_item" | "struct_item" | "enum_item" | "trait_item" | "type_item"
258    ) && is_pub_item(node, source_bytes)
259    {
260        // Get the name of this item
261        if let Some(name_node) = node.child_by_field_name("name") {
262            let item_name = node_text(name_node, source_bytes);
263            // Mark the matching node as exported
264            for n in nodes.iter_mut() {
265                if n.name == item_name {
266                    n.metadata = serde_json::json!({"exported": true});
267                }
268            }
269        }
270    }
271
272    let mut cursor = node.walk();
273    if cursor.goto_first_child() {
274        loop {
275            walk_pub(nodes, cursor.node(), source_bytes);
276            if !cursor.goto_next_sibling() {
277                break;
278            }
279        }
280    }
281}
282
283#[allow(clippy::too_many_arguments)]
284fn extract_type_nodes(
285    nodes: &mut Vec<NodeDef>,
286    edges: &mut Vec<EdgeDef>,
287    file_id: &str,
288    file: &SourceFile,
289    query: &Query,
290    root: tree_sitter::Node,
291    source_bytes: &[u8],
292    kind: NodeKind,
293    prefix: &str,
294) {
295    let mut cursor = QueryCursor::new();
296    for m in cursor.matches(query, root, source_bytes) {
297        let Some(name_capture) = m
298            .captures
299            .iter()
300            .find(|c| query.capture_names()[c.index as usize] == "name")
301        else {
302            continue;
303        };
304        let name = node_text(name_capture.node, source_bytes);
305        let start = name_capture.node.start_position();
306        // Use the body node for end position; fall back to name node if no body capture
307        let body_end = m
308            .captures
309            .iter()
310            .find(|c| query.capture_names()[c.index as usize] != "name")
311            .map(|c| c.node.end_position())
312            .unwrap_or_else(|| name_capture.node.end_position());
313        let id = format!("{}:{}:{}", prefix, file.relative_path, name);
314
315        nodes.push(NodeDef {
316            id: id.clone(),
317            kind: kind.clone(),
318            name,
319            path: file.relative_path.clone(),
320            line_start: start.row as u32 + 1,
321            line_end: body_end.row as u32 + 1,
322            ..Default::default()
323        });
324
325        edges.push(EdgeDef {
326            src: file_id.to_string(),
327            dst: id,
328            kind: EdgeKind::Exports,
329            ..Default::default()
330        });
331    }
332}
333
334fn node_text(node: tree_sitter::Node, source: &[u8]) -> String {
335    node.utf8_text(source).unwrap_or("").to_string()
336}