Skip to main content

deagle_parse/
rust_parser.rs

1//! Rust language parser using tree-sitter-rust.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5
6/// Result of parsing a file — nodes and their relationships.
7pub struct ParseResult {
8    pub nodes: Vec<Node>,
9    pub edges: Vec<(usize, usize, EdgeKind)>, // (from_idx, to_idx, kind) — indexes into nodes vec
10}
11
12/// Parse a Rust source file and extract definitions + relationships.
13pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
14    parse_with_edges(path, content).map(|r| r.nodes)
15}
16
17/// Parse with edge extraction — returns nodes and relationship tuples.
18pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
19    let mut parser = tree_sitter::Parser::new();
20    let language = tree_sitter_rust::LANGUAGE;
21    parser.set_language(&language.into()).map_err(|e| {
22        DeagleError::Parse {
23            file: path.display().to_string(),
24            message: format!("Failed to set language: {}", e),
25        }
26    })?;
27
28    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
29        file: path.display().to_string(),
30        message: "Failed to parse file".into(),
31    })?;
32
33    let mut nodes = Vec::new();
34    let file_path = path.to_string_lossy().to_string();
35
36    // Insert file node as index 0
37    nodes.push(Node {
38        id: 0,
39        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
40        kind: NodeKind::File,
41        language: Language::Rust,
42        file_path: file_path.clone(),
43        line_start: 1,
44        line_end: content.lines().count() as u32,
45        content: None,
46    });
47
48    extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
49
50    // Build CONTAINS edges: file (idx 0) → each top-level entity
51    let mut edges = Vec::new();
52    for i in 1..nodes.len() {
53        edges.push((0, i, EdgeKind::Contains));
54    }
55
56    Ok(ParseResult { nodes, edges })
57}
58
59fn extract_definitions(
60    node: tree_sitter::Node,
61    source: &str,
62    file_path: &str,
63    results: &mut Vec<Node>,
64) {
65    let kind = match node.kind() {
66        "function_item" => Some(NodeKind::Function),
67        "struct_item" => Some(NodeKind::Struct),
68        "enum_item" => Some(NodeKind::Enum),
69        "trait_item" => Some(NodeKind::Trait),
70        "impl_item" => None, // We extract methods inside
71        "const_item" | "static_item" => Some(NodeKind::Constant),
72        "type_item" => Some(NodeKind::TypeAlias),
73        "mod_item" => Some(NodeKind::Module),
74        "use_declaration" => Some(NodeKind::Import),
75        _ => None,
76    };
77
78    if let Some(kind) = kind {
79        if let Some(name) = extract_name(node, source, kind) {
80            let start = node.start_position();
81            let end = node.end_position();
82            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
83                // Truncate long content
84                if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
85            });
86
87            results.push(Node {
88                id: 0,
89                name,
90                kind,
91                language: Language::Rust,
92                file_path: file_path.to_string(),
93                line_start: (start.row + 1) as u32,
94                line_end: (end.row + 1) as u32,
95                content,
96            });
97        }
98    }
99
100    // Recurse into children — extract methods from impl blocks
101    if node.kind() == "impl_item" {
102        extract_impl_methods(node, source, file_path, results);
103    }
104
105    let mut cursor = node.walk();
106    for child in node.children(&mut cursor) {
107        if child.kind() != "impl_item" || node.kind() != "impl_item" {
108            extract_definitions(child, source, file_path, results);
109        }
110    }
111}
112
113fn extract_impl_methods(
114    impl_node: tree_sitter::Node,
115    source: &str,
116    file_path: &str,
117    results: &mut Vec<Node>,
118) {
119    let mut cursor = impl_node.walk();
120    for child in impl_node.children(&mut cursor) {
121        if child.kind() == "declaration_list" {
122            let mut inner = child.walk();
123            for item in child.children(&mut inner) {
124                if item.kind() == "function_item" {
125                    if let Some(name) = extract_name(item, source, NodeKind::Method) {
126                        let start = item.start_position();
127                        let end = item.end_position();
128                        let content = item.utf8_text(source.as_bytes()).ok().map(|s| {
129                            if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
130                        });
131                        results.push(Node {
132                            id: 0,
133                            name,
134                            kind: NodeKind::Method,
135                            language: Language::Rust,
136                            file_path: file_path.to_string(),
137                            line_start: (start.row + 1) as u32,
138                            line_end: (end.row + 1) as u32,
139                            content,
140                        });
141                    }
142                }
143            }
144        }
145    }
146}
147
148fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
149    match kind {
150        NodeKind::Import => {
151            // For use declarations, take the full text
152            node.utf8_text(source.as_bytes())
153                .ok()
154                .map(|s| s.trim().to_string())
155        }
156        _ => {
157            // Find the name/identifier child
158            let mut cursor = node.walk();
159            for child in node.children(&mut cursor) {
160                if child.kind() == "identifier" || child.kind() == "type_identifier" {
161                    return child.utf8_text(source.as_bytes())
162                        .ok()
163                        .map(|s| s.to_string());
164                }
165            }
166            None
167        }
168    }
169}
170
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    const SAMPLE_RUST: &str = r#"
use std::collections::HashMap;

const MAX_SIZE: usize = 1024;

pub struct Config {
    name: String,
    values: HashMap<String, String>,
}

pub enum Status {
    Active,
    Inactive,
}

pub trait Processor {
    fn process(&self, input: &str) -> String;
}

impl Config {
    pub fn new(name: &str) -> Self {
        Self { name: name.to_string(), values: HashMap::new() }
    }

    pub fn get(&self, key: &str) -> Option<&String> {
        self.values.get(key)
    }
}

pub fn main() {
    let config = Config::new("test");
    println!("{:?}", config.get("key"));
}
"#;

    /// Parses `SAMPLE_RUST` as if it had been read from `test.rs`.
    fn parse_sample() -> Vec<Node> {
        parse(&PathBuf::from("test.rs"), SAMPLE_RUST).expect("sample source should parse")
    }

    #[test]
    fn test_parse_rust_finds_all_definitions() {
        let nodes = parse_sample();

        let kinds: Vec<_> = nodes.iter().map(|n| n.kind).collect();
        assert!(kinds.contains(&NodeKind::Import), "should find use declaration");
        assert!(kinds.contains(&NodeKind::Constant), "should find const");
        assert!(kinds.contains(&NodeKind::Struct), "should find struct");
        assert!(kinds.contains(&NodeKind::Enum), "should find enum");
        assert!(kinds.contains(&NodeKind::Trait), "should find trait");
        assert!(kinds.contains(&NodeKind::Function), "should find function");
    }

    #[test]
    fn test_parse_rust_finds_methods() {
        let nodes = parse_sample();

        let methods: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Method).collect();
        assert!(methods.len() >= 2, "should find impl methods (new, get), got {}", methods.len());
        assert!(methods.iter().any(|m| m.name == "new"));
        assert!(methods.iter().any(|m| m.name == "get"));
    }

    #[test]
    fn test_parse_rust_struct_name() {
        let nodes = parse_sample();

        let structs: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Struct).collect();
        assert_eq!(structs.len(), 1);
        assert_eq!(structs[0].name, "Config");
        assert_eq!(structs[0].language, Language::Rust);
    }

    #[test]
    fn test_parse_rust_line_numbers() {
        let nodes = parse_sample();

        let main_fn = nodes
            .iter()
            .find(|n| n.name == "main" && n.kind == NodeKind::Function)
            .expect("should find main function");
        // SAMPLE_RUST starts with a newline, so `pub fn main() {` is on line 30 and its
        // closing brace on line 33 (1-indexed).
        assert_eq!(main_fn.line_start, 30, "line numbers should be 1-indexed");
        assert_eq!(main_fn.line_end, 33);
    }

    #[test]
    fn test_parse_empty_file() {
        let path = PathBuf::from("empty.rs");
        let nodes = parse(&path, "").unwrap();
        // The file node is always created; an empty file has nothing else.
        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].kind == NodeKind::File);
        assert_eq!(nodes[0].name, "empty.rs");
    }

    #[test]
    fn test_parse_rust_trait_name() {
        let nodes = parse_sample();

        let traits: Vec<_> = nodes.iter().filter(|n| n.kind == NodeKind::Trait).collect();
        assert_eq!(traits.len(), 1);
        assert_eq!(traits[0].name, "Processor");
    }

    #[test]
    fn test_parse_with_edges_links_file_to_every_entity() {
        let path = PathBuf::from("test.rs");
        let result = parse_with_edges(&path, SAMPLE_RUST).unwrap();

        assert!(result.nodes[0].kind == NodeKind::File, "file node must be at index 0");
        assert_eq!(result.edges.len(), result.nodes.len() - 1);
        for (i, (from, to, kind)) in result.edges.iter().enumerate() {
            assert_eq!(*from, 0);
            assert_eq!(*to, i + 1);
            assert!(matches!(kind, EdgeKind::Contains));
        }
    }
}