Skip to main content

deagle_parse/
go_parser.rs

1//! Go language parser using tree-sitter-go.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5
6use crate::ParseResult;
7
8/// Parse a Go source file and extract definitions.
9pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
10    parse_with_edges(path, content).map(|r| r.nodes)
11}
12
13/// Parse with edge extraction.
14pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
15    let mut parser = tree_sitter::Parser::new();
16    let language = tree_sitter_go::LANGUAGE;
17    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
18        file: path.display().to_string(),
19        message: format!("Failed to set language: {}", e),
20    })?;
21
22    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
23        file: path.display().to_string(),
24        message: "Failed to parse file".into(),
25    })?;
26
27    let mut nodes = Vec::new();
28    let file_path = path.to_string_lossy().to_string();
29
30    nodes.push(Node {
31        id: 0,
32        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
33        kind: NodeKind::File,
34        language: Language::Go,
35        file_path: file_path.clone(),
36        line_start: 1,
37        line_end: content.lines().count() as u32,
38        content: None,
39    });
40
41    extract_definitions(tree.root_node(), content, &file_path, &mut nodes);
42
43    let mut edges = Vec::new();
44    for i in 1..nodes.len() {
45        edges.push((0, i, EdgeKind::Contains));
46    }
47
48    Ok(ParseResult { nodes, edges })
49}
50
51fn extract_definitions(
52    node: tree_sitter::Node,
53    source: &str,
54    file_path: &str,
55    results: &mut Vec<Node>,
56) {
57    let kind = match node.kind() {
58        "function_declaration" => Some(NodeKind::Function),
59        "method_declaration" => Some(NodeKind::Method),
60        "type_declaration" => None, // handled below — contains type_spec children
61        "type_spec" => {
62            // Check if it's a struct, interface, or type alias
63            if let Some(type_node) = node.child_by_field_name("type") {
64                match type_node.kind() {
65                    "struct_type" => Some(NodeKind::Struct),
66                    "interface_type" => Some(NodeKind::Interface),
67                    _ => Some(NodeKind::TypeAlias),
68                }
69            } else {
70                None
71            }
72        }
73        "import_declaration" => Some(NodeKind::Import),
74        "const_declaration" | "var_declaration" => None, // extract individual specs
75        "const_spec" => Some(NodeKind::Constant),
76        "package_clause" => Some(NodeKind::Module),
77        _ => None,
78    };
79
80    if let Some(kind) = kind {
81        if let Some(name) = extract_name(node, source, kind) {
82            let start = node.start_position();
83            let end = node.end_position();
84            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
85                if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
86            });
87
88            results.push(Node {
89                id: 0,
90                name,
91                kind,
92                language: Language::Go,
93                file_path: file_path.to_string(),
94                line_start: (start.row + 1) as u32,
95                line_end: (end.row + 1) as u32,
96                content,
97            });
98        }
99    }
100
101    // Recurse into children
102    let mut cursor = node.walk();
103    for child in node.children(&mut cursor) {
104        extract_definitions(child, source, file_path, results);
105    }
106}
107
108fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
109    match kind {
110        NodeKind::Import => {
111            node.utf8_text(source.as_bytes())
112                .ok()
113                .map(|s| s.trim().to_string())
114        }
115        NodeKind::Module => {
116            // package clause: "package main"
117            if let Some(n) = node.child_by_field_name("name") {
118                return n.utf8_text(source.as_bytes()).ok().map(|s| s.to_string());
119            }
120            // Fallback: find package_identifier child
121            let mut c = node.walk();
122            let children: Vec<_> = node.children(&mut c).collect();
123            children.iter()
124                .find(|n| n.kind() == "package_identifier")
125                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
126                .map(|s| s.to_string())
127        }
128        _ => {
129            node.child_by_field_name("name")
130                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
131                .map(|s| s.to_string())
132        }
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Go fixture containing one of each definition kind the tests look for.
    const SAMPLE_GO: &str = r#"
package main

import (
    "fmt"
    "net/http"
)

const MaxSize = 1024

type Config struct {
    Name   string
    Values map[string]string
}

type Handler interface {
    ServeHTTP(w http.ResponseWriter, r *http.Request)
}

func NewConfig(name string) *Config {
    return &Config{Name: name, Values: make(map[string]string)}
}

func (c *Config) Get(key string) string {
    return c.Values[key]
}

func main() {
    config := NewConfig("test")
    fmt.Println(config.Get("key"))
}
"#;

    /// Parse the shared fixture as `main.go` and return its nodes.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("main.go"), SAMPLE_GO).unwrap()
    }

    /// Select the nodes of a single kind, preserving their order.
    fn of_kind(nodes: &[Node], kind: NodeKind) -> Vec<&Node> {
        nodes.iter().filter(|n| n.kind == kind).collect()
    }

    #[test]
    fn test_parse_go_finds_all_definitions() {
        let nodes = sample_nodes();
        let expected = [
            (NodeKind::Module, "should find package"),
            (NodeKind::Import, "should find import"),
            (NodeKind::Constant, "should find const"),
            (NodeKind::Struct, "should find struct"),
            (NodeKind::Interface, "should find interface"),
            (NodeKind::Function, "should find function"),
            (NodeKind::Method, "should find method"),
        ];
        for (kind, message) in expected {
            assert!(nodes.iter().any(|n| n.kind == kind), "{}", message);
        }
    }

    #[test]
    fn test_parse_go_struct_name() {
        let nodes = sample_nodes();
        let structs = of_kind(&nodes, NodeKind::Struct);
        assert_eq!(structs.len(), 1);
        assert_eq!(structs[0].name, "Config");
        assert_eq!(structs[0].language, Language::Go);
    }

    #[test]
    fn test_parse_go_methods() {
        let nodes = sample_nodes();
        let methods = of_kind(&nodes, NodeKind::Method);
        assert!(methods.iter().any(|m| m.name == "Get"), "should find Get method");
    }

    #[test]
    fn test_parse_go_functions() {
        let nodes = sample_nodes();
        let functions = of_kind(&nodes, NodeKind::Function);
        for wanted in ["NewConfig", "main"] {
            assert!(functions.iter().any(|f| f.name == wanted));
        }
    }

    #[test]
    fn test_parse_go_interface() {
        let nodes = sample_nodes();
        let interfaces = of_kind(&nodes, NodeKind::Interface);
        assert_eq!(interfaces.len(), 1);
        assert_eq!(interfaces[0].name, "Handler");
    }

    #[test]
    fn test_parse_go_edges() {
        let parsed = parse_with_edges(&PathBuf::from("main.go"), SAMPLE_GO).unwrap();
        assert!(!parsed.edges.is_empty());
        // Every edge must start at the file node and be a containment edge.
        for (from, _, kind) in &parsed.edges {
            assert_eq!(*from, 0);
            assert_eq!(*kind, EdgeKind::Contains);
        }
    }

    #[test]
    fn test_parse_empty_go() {
        let nodes = parse(&PathBuf::from("empty.go"), "").unwrap();
        assert!(nodes.len() <= 1);
    }
}
241}