Skip to main content

dk_engine/parser/
rust_parser.rs

1use super::LanguageParser;
2use dk_core::{CallKind, Import, RawCallEdge, Result, Span, Symbol, SymbolKind, TypeInfo, Visibility};
3use std::path::Path;
4use tree_sitter::{Node, Parser, TreeCursor};
5use uuid::Uuid;
6
/// Rust parser backed by tree-sitter.
///
/// Extracts symbols, call edges, imports, and (stub) type information from
/// Rust source files. The type carries no state, so it is cheap to copy and
/// share.
#[derive(Debug, Clone, Copy)]
pub struct RustParser;
12
13impl RustParser {
14    pub fn new() -> Self {
15        Self
16    }
17
18    /// Create a configured tree-sitter parser for Rust.
19    fn create_parser() -> Result<Parser> {
20        let mut parser = Parser::new();
21        parser
22            .set_language(&tree_sitter_rust::LANGUAGE.into())
23            .map_err(|e| dk_core::Error::ParseError(format!("Failed to load Rust grammar: {e}")))?;
24        Ok(parser)
25    }
26
27    /// Parse source bytes into a tree-sitter tree.
28    fn parse_tree(source: &[u8]) -> Result<tree_sitter::Tree> {
29        let mut parser = Self::create_parser()?;
30        parser
31            .parse(source, None)
32            .ok_or_else(|| dk_core::Error::ParseError("tree-sitter parse returned None".into()))
33    }
34
35    /// Determine the visibility of a node by checking for a `visibility_modifier` child.
36    fn node_visibility(node: &Node, source: &[u8]) -> Visibility {
37        let mut cursor = node.walk();
38        for child in node.children(&mut cursor) {
39            if child.kind() == "visibility_modifier" {
40                let text = &source[child.start_byte()..child.end_byte()];
41                let text_str = std::str::from_utf8(text).unwrap_or("");
42                if text_str.contains("crate") {
43                    return Visibility::Crate;
44                }
45                if text_str.contains("super") {
46                    return Visibility::Super;
47                }
48                return Visibility::Public;
49            }
50        }
51        Visibility::Private
52    }
53
54    /// Extract the name from a node by looking for the `name` field first,
55    /// then falling back to looking for specific identifier children.
56    fn node_name(node: &Node, source: &[u8]) -> Option<String> {
57        // For impl_item, construct a name from type + trait
58        if node.kind() == "impl_item" {
59            return Self::impl_name(node, source);
60        }
61
62        // Try the "name" field (works for function_item, struct_item, enum_item, trait_item, etc.)
63        if let Some(name_node) = node.child_by_field_name("name") {
64            let text = &source[name_node.start_byte()..name_node.end_byte()];
65            return std::str::from_utf8(text).ok().map(|s| s.to_string());
66        }
67
68        None
69    }
70
71    /// Construct a name for an impl block: "impl Trait for Type" or "impl Type".
72    fn impl_name(node: &Node, source: &[u8]) -> Option<String> {
73        let mut type_name = None;
74        let mut trait_name = None;
75
76        // Look for the type being implemented and optional trait
77        if let Some(ty) = node.child_by_field_name("type") {
78            let text = &source[ty.start_byte()..ty.end_byte()];
79            type_name = std::str::from_utf8(text).ok().map(|s| s.to_string());
80        }
81
82        if let Some(tr) = node.child_by_field_name("trait") {
83            let text = &source[tr.start_byte()..tr.end_byte()];
84            trait_name = std::str::from_utf8(text).ok().map(|s| s.to_string());
85        }
86
87        match (trait_name, type_name) {
88            (Some(tr), Some(ty)) => Some(format!("impl {tr} for {ty}")),
89            (None, Some(ty)) => Some(format!("impl {ty}")),
90            _ => Some("impl".to_string()),
91        }
92    }
93
94    /// Extract the first line of the node's source text as the signature.
95    fn node_signature(node: &Node, source: &[u8]) -> Option<String> {
96        let text = &source[node.start_byte()..node.end_byte()];
97        let text_str = std::str::from_utf8(text).ok()?;
98        let first_line = text_str.lines().next()?;
99        Some(first_line.trim().to_string())
100    }
101
102    /// Collect preceding `///` doc comments for a node.
103    fn doc_comments(node: &Node, source: &[u8]) -> Option<String> {
104        let mut comments = Vec::new();
105        let mut sibling = node.prev_sibling();
106
107        while let Some(prev) = sibling {
108            if prev.kind() == "line_comment" {
109                let text = &source[prev.start_byte()..prev.end_byte()];
110                if let Ok(s) = std::str::from_utf8(text) {
111                    let trimmed = s.trim();
112                    if trimmed.starts_with("///") {
113                        // Strip the `/// ` prefix
114                        let content = trimmed.strip_prefix("/// ").unwrap_or(
115                            trimmed.strip_prefix("///").unwrap_or(trimmed),
116                        );
117                        comments.push(content.to_string());
118                        sibling = prev.prev_sibling();
119                        continue;
120                    }
121                }
122            }
123            break;
124        }
125
126        if comments.is_empty() {
127            None
128        } else {
129            comments.reverse();
130            Some(comments.join("\n"))
131        }
132    }
133
134    /// Map a tree-sitter node kind to our SymbolKind, if applicable.
135    fn map_symbol_kind(kind: &str) -> Option<SymbolKind> {
136        match kind {
137            "function_item" => Some(SymbolKind::Function),
138            "struct_item" => Some(SymbolKind::Struct),
139            "enum_item" => Some(SymbolKind::Enum),
140            "trait_item" => Some(SymbolKind::Trait),
141            "impl_item" => Some(SymbolKind::Impl),
142            "type_item" => Some(SymbolKind::TypeAlias),
143            "const_item" => Some(SymbolKind::Const),
144            "static_item" => Some(SymbolKind::Static),
145            "mod_item" => Some(SymbolKind::Module),
146            _ => None,
147        }
148    }
149
150    /// Find the name of the enclosing function for a given node, if any.
151    fn enclosing_function_name(node: &Node, source: &[u8]) -> String {
152        let mut current = node.parent();
153        while let Some(parent) = current {
154            if parent.kind() == "function_item" {
155                if let Some(name_node) = parent.child_by_field_name("name") {
156                    let text = &source[name_node.start_byte()..name_node.end_byte()];
157                    if let Ok(name) = std::str::from_utf8(text) {
158                        return name.to_string();
159                    }
160                }
161            }
162            current = parent.parent();
163        }
164        "<module>".to_string()
165    }
166
167    /// Recursively walk the tree to extract call edges.
168    fn walk_calls(cursor: &mut TreeCursor, source: &[u8], calls: &mut Vec<RawCallEdge>) {
169        let node = cursor.node();
170
171        match node.kind() {
172            "call_expression" => {
173                // Direct function call: get the function name from "function" field
174                if let Some(func_node) = node.child_by_field_name("function") {
175                    let callee = Self::extract_callee_name(&func_node, source);
176                    if !callee.is_empty() {
177                        let caller = Self::enclosing_function_name(&node, source);
178                        calls.push(RawCallEdge {
179                            caller_name: caller,
180                            callee_name: callee,
181                            call_site: Span {
182                                start_byte: node.start_byte() as u32,
183                                end_byte: node.end_byte() as u32,
184                            },
185                            kind: CallKind::DirectCall,
186                        });
187                    }
188                }
189            }
190            "method_call_expression" => {
191                // method_call_expression has a "name" field for the method name
192                // In tree-sitter-rust, the method name is in the "name" field
193                // but it might also be the last identifier child. Let's try field first.
194                let method_name = if let Some(name_node) = node.child_by_field_name("name") {
195                    let text = &source[name_node.start_byte()..name_node.end_byte()];
196                    std::str::from_utf8(text).unwrap_or("").to_string()
197                } else {
198                    // fallback: scan for identifier children
199                    Self::last_identifier_child(&node, source)
200                };
201
202                if !method_name.is_empty() {
203                    let caller = Self::enclosing_function_name(&node, source);
204                    calls.push(RawCallEdge {
205                        caller_name: caller,
206                        callee_name: method_name,
207                        call_site: Span {
208                            start_byte: node.start_byte() as u32,
209                            end_byte: node.end_byte() as u32,
210                        },
211                        kind: CallKind::MethodCall,
212                    });
213                }
214            }
215            "macro_invocation" => {
216                // macro_invocation has a "macro" field for the macro name
217                if let Some(macro_node) = node.child_by_field_name("macro") {
218                    let text = &source[macro_node.start_byte()..macro_node.end_byte()];
219                    if let Ok(name) = std::str::from_utf8(text) {
220                        let caller = Self::enclosing_function_name(&node, source);
221                        calls.push(RawCallEdge {
222                            caller_name: caller,
223                            callee_name: name.to_string(),
224                            call_site: Span {
225                                start_byte: node.start_byte() as u32,
226                                end_byte: node.end_byte() as u32,
227                            },
228                            kind: CallKind::MacroInvocation,
229                        });
230                    }
231                }
232            }
233            _ => {}
234        }
235
236        // Recurse into children
237        if cursor.goto_first_child() {
238            loop {
239                Self::walk_calls(cursor, source, calls);
240                if !cursor.goto_next_sibling() {
241                    break;
242                }
243            }
244            cursor.goto_parent();
245        }
246    }
247
248    /// Extract callee name from a call expression's function node.
249    /// Handles identifiers, field expressions (e.g. `module::func`), and scoped identifiers.
250    fn extract_callee_name(node: &Node, source: &[u8]) -> String {
251        let text = &source[node.start_byte()..node.end_byte()];
252        std::str::from_utf8(text).unwrap_or("").to_string()
253    }
254
255    /// Get the last identifier child of a node (fallback for method names).
256    fn last_identifier_child(node: &Node, source: &[u8]) -> String {
257        let mut cursor = node.walk();
258        let mut last_ident = String::new();
259        for child in node.children(&mut cursor) {
260            if child.kind() == "identifier" || child.kind() == "field_identifier" {
261                let text = &source[child.start_byte()..child.end_byte()];
262                if let Ok(name) = std::str::from_utf8(text) {
263                    last_ident = name.to_string();
264                }
265            }
266        }
267        last_ident
268    }
269
270    /// Extract the full path from a use_declaration node.
271    fn extract_use_path(node: &Node, source: &[u8]) -> Vec<Import> {
272        let mut imports = Vec::new();
273
274        // Get the full text of the use declaration (minus `use` keyword and semicolon)
275        // We need to find the use_path/scoped_use_list within the use_declaration
276        let mut cursor = node.walk();
277        for child in node.children(&mut cursor) {
278            match child.kind() {
279                "use_as_clause" | "scoped_identifier" | "use_wildcard" | "identifier"
280                | "scoped_use_list" | "use_list" => {
281                    Self::collect_imports_from_node(&child, source, "", &mut imports);
282                }
283                _ => {}
284            }
285        }
286
287        // If we didn't extract any imports from structured children, fall back
288        // to extracting the full text path
289        if imports.is_empty() {
290            let text = &source[node.start_byte()..node.end_byte()];
291            if let Ok(full_text) = std::str::from_utf8(text) {
292                // Strip `use ` prefix and `;` suffix
293                let path = full_text
294                    .trim()
295                    .strip_prefix("use ")
296                    .unwrap_or(full_text.trim())
297                    .strip_suffix(';')
298                    .unwrap_or(full_text.trim())
299                    .trim();
300
301                if !path.is_empty() {
302                    let is_external = Self::is_external_path(path);
303                    let imported_name = path.rsplit("::").next().unwrap_or(path).to_string();
304                    imports.push(Import {
305                        module_path: path.to_string(),
306                        imported_name,
307                        alias: None,
308                        is_external,
309                    });
310                }
311            }
312        }
313
314        imports
315    }
316
317    /// Recursively collect imports from a use tree node.
318    fn collect_imports_from_node(
319        node: &Node,
320        source: &[u8],
321        prefix: &str,
322        imports: &mut Vec<Import>,
323    ) {
324        let text = &source[node.start_byte()..node.end_byte()];
325        let text_str = std::str::from_utf8(text).unwrap_or("");
326
327        match node.kind() {
328            "scoped_identifier" | "identifier" | "use_as_clause" | "use_wildcard" => {
329                let full_path = if prefix.is_empty() {
330                    text_str.to_string()
331                } else {
332                    format!("{prefix}::{text_str}")
333                };
334
335                let is_external = Self::is_external_path(&full_path);
336                let imported_name = full_path.rsplit("::").next().unwrap_or(&full_path).to_string();
337
338                // Check for alias in use_as_clause
339                let alias = if node.kind() == "use_as_clause" {
340                    node.child_by_field_name("alias").and_then(|a| {
341                        let a_text = &source[a.start_byte()..a.end_byte()];
342                        std::str::from_utf8(a_text).ok().map(|s| s.to_string())
343                    })
344                } else {
345                    None
346                };
347
348                imports.push(Import {
349                    module_path: full_path,
350                    imported_name,
351                    alias,
352                    is_external,
353                });
354            }
355            "scoped_use_list" => {
356                // Has a "path" field (the prefix) and a "list" field (the use_list)
357                let path_prefix = node.child_by_field_name("path").map(|p| {
358                    let p_text = &source[p.start_byte()..p.end_byte()];
359                    std::str::from_utf8(p_text).unwrap_or("").to_string()
360                });
361
362                let combined_prefix = match (prefix, path_prefix.as_deref()) {
363                    ("", Some(p)) => p.to_string(),
364                    (pfx, Some(p)) => format!("{pfx}::{p}"),
365                    (pfx, None) => pfx.to_string(),
366                };
367
368                if let Some(list) = node.child_by_field_name("list") {
369                    let mut cursor = list.walk();
370                    for child in list.children(&mut cursor) {
371                        Self::collect_imports_from_node(&child, source, &combined_prefix, imports);
372                    }
373                }
374            }
375            "use_list" => {
376                let mut cursor = node.walk();
377                for child in node.children(&mut cursor) {
378                    Self::collect_imports_from_node(&child, source, prefix, imports);
379                }
380            }
381            _ => {}
382        }
383    }
384
385    /// Determine if an import path is external (not starting with crate::, super::, self::).
386    fn is_external_path(path: &str) -> bool {
387        !path.starts_with("crate::")
388            && !path.starts_with("crate")
389            && !path.starts_with("super::")
390            && !path.starts_with("super")
391            && !path.starts_with("self::")
392            && !path.starts_with("self")
393    }
394}
395
396impl Default for RustParser {
397    fn default() -> Self {
398        Self::new()
399    }
400}
401
402impl LanguageParser for RustParser {
403    fn extensions(&self) -> &[&str] {
404        &["rs"]
405    }
406
407    fn extract_symbols(&self, source: &[u8], file_path: &Path) -> Result<Vec<Symbol>> {
408        if source.is_empty() {
409            return Ok(vec![]);
410        }
411
412        let tree = Self::parse_tree(source)?;
413        let root = tree.root_node();
414        let mut symbols = Vec::new();
415        let mut cursor = root.walk();
416
417        for node in root.children(&mut cursor) {
418            if let Some(kind) = Self::map_symbol_kind(node.kind()) {
419                let name = Self::node_name(&node, source).unwrap_or_default();
420                if name.is_empty() {
421                    continue;
422                }
423
424                let visibility = Self::node_visibility(&node, source);
425                let signature = Self::node_signature(&node, source);
426                let doc_comment = Self::doc_comments(&node, source);
427
428                symbols.push(Symbol {
429                    id: Uuid::new_v4(),
430                    name: name.clone(),
431                    qualified_name: name,
432                    kind,
433                    visibility,
434                    file_path: file_path.to_path_buf(),
435                    span: Span {
436                        start_byte: node.start_byte() as u32,
437                        end_byte: node.end_byte() as u32,
438                    },
439                    signature,
440                    doc_comment,
441                    parent: None,
442                    last_modified_by: None,
443                    last_modified_intent: None,
444                });
445            }
446        }
447
448        Ok(symbols)
449    }
450
451    fn extract_calls(&self, source: &[u8], _file_path: &Path) -> Result<Vec<RawCallEdge>> {
452        if source.is_empty() {
453            return Ok(vec![]);
454        }
455
456        let tree = Self::parse_tree(source)?;
457        let root = tree.root_node();
458        let mut calls = Vec::new();
459        let mut cursor = root.walk();
460
461        Self::walk_calls(&mut cursor, source, &mut calls);
462
463        Ok(calls)
464    }
465
466    fn extract_types(&self, _source: &[u8], _file_path: &Path) -> Result<Vec<TypeInfo>> {
467        // Stub: will be enhanced later
468        Ok(vec![])
469    }
470
471    fn extract_imports(&self, source: &[u8], _file_path: &Path) -> Result<Vec<Import>> {
472        if source.is_empty() {
473            return Ok(vec![]);
474        }
475
476        let tree = Self::parse_tree(source)?;
477        let root = tree.root_node();
478        let mut imports = Vec::new();
479        let mut cursor = root.walk();
480
481        for node in root.children(&mut cursor) {
482            if node.kind() == "use_declaration" {
483                imports.extend(Self::extract_use_path(&node, source));
484            }
485        }
486
487        Ok(imports)
488    }
489}