Skip to main content

deagle_parse/
ruby_parser.rs

//! Ruby language parser using tree-sitter-ruby.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5use crate::ParseResult;
6
7pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
8    parse_with_edges(path, content).map(|r| r.nodes)
9}
10
11pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
12    let mut parser = tree_sitter::Parser::new();
13    let language = tree_sitter_ruby::LANGUAGE;
14    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
15        file: path.display().to_string(),
16        message: format!("Failed to set language: {}", e),
17    })?;
18
19    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
20        file: path.display().to_string(),
21        message: "Failed to parse file".into(),
22    })?;
23
24    let mut nodes = Vec::new();
25    let file_path = path.to_string_lossy().to_string();
26
27    nodes.push(Node {
28        id: 0,
29        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
30        kind: NodeKind::File,
31        language: Language::Ruby,
32        file_path: file_path.clone(),
33        line_start: 1,
34        line_end: content.lines().count() as u32,
35        content: None,
36    });
37
38    extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
39
40    let mut edges = Vec::new();
41    for i in 1..nodes.len() {
42        edges.push((0, i, EdgeKind::Contains));
43    }
44    Ok(ParseResult { nodes, edges })
45}
46
47fn extract_definitions(node: tree_sitter::Node, source: &str, file_path: &str, results: &mut Vec<Node>) {
48    let kind = match node.kind() {
49        "method" | "singleton_method" => Some(NodeKind::Method),
50        "class" => Some(NodeKind::Class),
51        "module" => Some(NodeKind::Module),
52        "constant_assignment" | "casgn" => Some(NodeKind::Constant),
53        "call" => {
54            // Detect require/require_relative/include/extend
55            if let Some(method) = node.child_by_field_name("method") {
56                let method_name = method.utf8_text(source.as_bytes()).unwrap_or("");
57                match method_name {
58                    "require" | "require_relative" | "include" | "extend" | "attr_accessor"
59                    | "attr_reader" | "attr_writer" => Some(NodeKind::Import),
60                    _ => None,
61                }
62            } else {
63                None
64            }
65        }
66        _ => None,
67    };
68
69    if let Some(kind) = kind {
70        if let Some(name) = extract_name(node, source, kind) {
71            let start = node.start_position();
72            let end = node.end_position();
73            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
74                crate::truncate_content(s, 500)
75            });
76            results.push(Node {
77                id: 0, name, kind, language: Language::Ruby,
78                file_path: file_path.to_string(),
79                line_start: (start.row + 1) as u32,
80                line_end: (end.row + 1) as u32,
81                content,
82            });
83        }
84    }
85
86    let mut cursor = node.walk();
87    for child in node.children(&mut cursor) {
88        extract_definitions(child, source, file_path, results);
89    }
90}
91
92fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
93    match kind {
94        NodeKind::Class | NodeKind::Module => {
95            node.child_by_field_name("name")
96                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
97                .map(|s| s.to_string())
98        }
99        NodeKind::Method => {
100            node.child_by_field_name("name")
101                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
102                .map(|s| s.to_string())
103        }
104        NodeKind::Constant => {
105            // constant_assignment: NAME = value
106            let mut cursor = node.walk();
107            for child in node.children(&mut cursor) {
108                if child.kind() == "constant" {
109                    return child.utf8_text(source.as_bytes()).ok().map(|s| s.to_string());
110                }
111            }
112            None
113        }
114        NodeKind::Import => {
115            // require "name" or require_relative "name"
116            node.utf8_text(source.as_bytes()).ok().map(|s| s.trim().to_string())
117        }
118        _ => node.child_by_field_name("name")
119            .and_then(|n| n.utf8_text(source.as_bytes()).ok())
120            .map(|s| s.to_string()),
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    const SAMPLE_RUBY: &str = r#"
require 'json'
require_relative 'helpers'

MAX_SIZE = 1024
VERSION = "1.0.0"

module Animals
  class Dog
    attr_accessor :name, :breed

    def initialize(name, breed)
      @name = name
      @breed = breed
    end

    def bark
      "Woof! I'm #{@name}"
    end

    def self.species
      "Canis familiaris"
    end
  end

  class Cat
    def initialize(name)
      @name = name
    end

    def meow
      "Meow!"
    end
  end
end

def greet(name)
  puts "Hello, #{name}!"
end
"#;

    /// Parses [`SAMPLE_RUBY`] as though it were the file `app.rb`.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("app.rb"), SAMPLE_RUBY).unwrap()
    }

    /// Names of every node in `nodes` whose kind equals `kind`, in order.
    fn names_of(nodes: &[Node], kind: NodeKind) -> Vec<&str> {
        nodes
            .iter()
            .filter(|n| n.kind == kind)
            .map(|n| n.name.as_str())
            .collect()
    }

    #[test]
    fn test_parse_ruby_finds_all() {
        let nodes = sample_nodes();
        let has = |kind: NodeKind| nodes.iter().any(|n| n.kind == kind);
        assert!(has(NodeKind::Import), "should find require");
        assert!(has(NodeKind::Class), "should find class");
        assert!(has(NodeKind::Module), "should find module");
        assert!(has(NodeKind::Method), "should find method");
    }

    #[test]
    fn test_parse_ruby_classes() {
        let nodes = sample_nodes();
        let classes = names_of(&nodes, NodeKind::Class);
        assert!(classes.contains(&"Dog"), "should find Dog class");
        assert!(classes.contains(&"Cat"), "should find Cat class");
    }

    #[test]
    fn test_parse_ruby_module() {
        let nodes = sample_nodes();
        let modules = names_of(&nodes, NodeKind::Module);
        assert_eq!(modules.len(), 1);
        assert_eq!(modules[0], "Animals");
    }

    #[test]
    fn test_parse_ruby_methods() {
        let nodes = sample_nodes();
        let methods = names_of(&nodes, NodeKind::Method);
        assert!(methods.contains(&"initialize"), "should find initialize");
        assert!(methods.contains(&"bark"), "should find bark");
        assert!(methods.contains(&"greet"), "should find greet");
        assert!(methods.contains(&"species"), "should find singleton method species");
    }

    #[test]
    fn test_parse_ruby_edges() {
        let result = parse_with_edges(&PathBuf::from("app.rb"), SAMPLE_RUBY).unwrap();
        assert!(!result.edges.is_empty());
    }

    #[test]
    fn test_parse_empty_ruby() {
        let nodes = parse(&PathBuf::from("empty.rb"), "").unwrap();
        assert!(nodes.len() <= 1);
    }

    #[test]
    fn test_parse_ruby_requires() {
        let nodes = sample_nodes();
        let imports = names_of(&nodes, NodeKind::Import);
        assert!(imports.len() >= 2, "should find require and require_relative, got {}", imports.len());
    }
}