Skip to main content

deagle_parse/
typescript_parser.rs

1//! TypeScript/TSX language parser using tree-sitter-typescript.
2
3use deagle_core::{DeagleError, EdgeKind, Language, Node, NodeKind, Result};
4use std::path::Path;
5
6use crate::ParseResult;
7
8/// Parse a TypeScript source file and extract definitions.
9pub fn parse(path: &Path, content: &str) -> Result<Vec<Node>> {
10    parse_with_edges(path, content).map(|r| r.nodes)
11}
12
13/// Parse with edge extraction.
14pub fn parse_with_edges(path: &Path, content: &str) -> Result<ParseResult> {
15    let mut parser = tree_sitter::Parser::new();
16
17    // Use TSX parser (superset of TS — handles both .ts and .tsx)
18    let language = tree_sitter_typescript::LANGUAGE_TSX;
19    parser.set_language(&language.into()).map_err(|e| DeagleError::Parse {
20        file: path.display().to_string(),
21        message: format!("Failed to set language: {}", e),
22    })?;
23
24    let tree = parser.parse(content, None).ok_or_else(|| DeagleError::Parse {
25        file: path.display().to_string(),
26        message: "Failed to parse file".into(),
27    })?;
28
29    let mut nodes = Vec::new();
30    let file_path = path.to_string_lossy().to_string();
31    let lang = Language::TypeScript;
32
33    nodes.push(Node {
34        id: 0,
35        name: path.file_name().and_then(|n| n.to_str()).unwrap_or("unknown").to_string(),
36        kind: NodeKind::File,
37        language: lang,
38        file_path: file_path.clone(),
39        line_start: 1,
40        line_end: content.lines().count() as u32,
41        content: None,
42    });
43
44    extract_definitions(tree.root_node(), content, &file_path, lang, &mut nodes, false);
45
46    let mut edges = Vec::new();
47    for i in 1..nodes.len() {
48        edges.push((0, i, EdgeKind::Contains));
49    }
50
51    Ok(ParseResult { nodes, edges })
52}
53
54fn extract_definitions(
55    node: tree_sitter::Node,
56    source: &str,
57    file_path: &str,
58    lang: Language,
59    results: &mut Vec<Node>,
60    inside_class: bool,
61) {
62    let kind = match node.kind() {
63        "function_declaration" => Some(NodeKind::Function),
64        "method_definition" => Some(NodeKind::Method),
65        "class_declaration" => Some(NodeKind::Class),
66        "interface_declaration" => Some(NodeKind::Interface),
67        "type_alias_declaration" => Some(NodeKind::TypeAlias),
68        "enum_declaration" => Some(NodeKind::Enum),
69        "import_statement" => Some(NodeKind::Import),
70        "export_statement" => None, // recurse into children
71        "lexical_declaration" => {
72            // const/let/var — check for arrow functions or UPPER_CASE constants
73            if !inside_class {
74                extract_lexical(node, source, file_path, lang, results);
75            }
76            None
77        }
78        _ => None,
79    };
80
81    if let Some(kind) = kind {
82        if let Some(name) = extract_name(node, source, kind) {
83            let start = node.start_position();
84            let end = node.end_position();
85            let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
86                if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
87            });
88
89            results.push(Node {
90                id: 0,
91                name,
92                kind,
93                language: lang,
94                file_path: file_path.to_string(),
95                line_start: (start.row + 1) as u32,
96                line_end: (end.row + 1) as u32,
97                content,
98            });
99
100            // Recurse into class body for methods
101            if kind == NodeKind::Class {
102                if let Some(body) = node.child_by_field_name("body") {
103                    let mut cursor = body.walk();
104                    for child in body.children(&mut cursor) {
105                        extract_definitions(child, source, file_path, lang, results, true);
106                    }
107                }
108                return;
109            }
110        }
111    }
112
113    // Recurse
114    if node.kind() != "class_declaration" {
115        let mut cursor = node.walk();
116        for child in node.children(&mut cursor) {
117            extract_definitions(child, source, file_path, lang, results, inside_class);
118        }
119    }
120}
121
122fn extract_lexical(
123    node: tree_sitter::Node,
124    source: &str,
125    file_path: &str,
126    lang: Language,
127    results: &mut Vec<Node>,
128) {
129    let mut cursor = node.walk();
130    for child in node.children(&mut cursor) {
131        if child.kind() == "variable_declarator" {
132            if let Some(name_node) = child.child_by_field_name("name") {
133                let name = name_node.utf8_text(source.as_bytes()).unwrap_or_default().to_string();
134                // Check if value is an arrow function
135                let is_arrow = child.child_by_field_name("value")
136                    .map(|v| v.kind() == "arrow_function")
137                    .unwrap_or(false);
138
139                let kind = if is_arrow {
140                    NodeKind::Function
141                } else if name.chars().all(|c| c.is_uppercase() || c == '_' || c.is_ascii_digit()) && !name.is_empty() {
142                    NodeKind::Constant
143                } else {
144                    return; // skip regular variables
145                };
146
147                let start = node.start_position();
148                let end = node.end_position();
149                let content = node.utf8_text(source.as_bytes()).ok().map(|s| {
150                    if s.len() > 500 { format!("{}...", &s[..500]) } else { s.to_string() }
151                });
152
153                results.push(Node {
154                    id: 0,
155                    name,
156                    kind,
157                    language: lang,
158                    file_path: file_path.to_string(),
159                    line_start: (start.row + 1) as u32,
160                    line_end: (end.row + 1) as u32,
161                    content,
162                });
163            }
164        }
165    }
166}
167
168fn extract_name(node: tree_sitter::Node, source: &str, kind: NodeKind) -> Option<String> {
169    match kind {
170        NodeKind::Import => {
171            node.utf8_text(source.as_bytes())
172                .ok()
173                .map(|s| s.trim().to_string())
174        }
175        _ => {
176            node.child_by_field_name("name")
177                .and_then(|n| n.utf8_text(source.as_bytes()).ok())
178                .map(|s| s.to_string())
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    const SAMPLE_TS: &str = r#"
import { Router } from 'express';
import type { Request, Response } from 'express';

const MAX_SIZE = 1024;

interface Config {
    name: string;
    values: Record<string, string>;
}

type Status = 'active' | 'inactive';

enum Direction {
    Up,
    Down,
    Left,
    Right,
}

class Server {
    private config: Config;

    constructor(config: Config) {
        this.config = config;
    }

    start(): void {
        console.log('starting');
    }

    getConfig(): Config {
        return this.config;
    }
}

function createServer(name: string): Server {
    return new Server({ name, values: {} });
}

const handler = (req: Request, res: Response) => {
    res.send('ok');
};

export function main() {
    const server = createServer('test');
    server.start();
}
"#;

    /// Parse [`SAMPLE_TS`] as `app.ts`.
    fn sample_nodes() -> Vec<Node> {
        parse(&PathBuf::from("app.ts"), SAMPLE_TS).expect("sample should parse")
    }

    /// Names of every node of `kind`, in the order they were extracted.
    fn names_of(nodes: &[Node], kind: NodeKind) -> Vec<&str> {
        nodes
            .iter()
            .filter(|n| n.kind == kind)
            .map(|n| n.name.as_str())
            .collect()
    }

    #[test]
    fn test_parse_ts_finds_all_definitions() {
        let nodes = sample_nodes();
        assert!(nodes[0].kind == NodeKind::File, "first node must be the file node");
        assert_eq!(nodes[0].name, "app.ts");

        let kinds: Vec<_> = nodes.iter().map(|n| n.kind).collect();
        let expected = [
            (NodeKind::Import, "import"),
            (NodeKind::Constant, "constant"),
            (NodeKind::Interface, "interface"),
            (NodeKind::TypeAlias, "type alias"),
            (NodeKind::Enum, "enum"),
            (NodeKind::Class, "class"),
            (NodeKind::Function, "function"),
        ];
        for (kind, what) in expected {
            assert!(kinds.contains(&kind), "should find {}", what);
        }
    }

    #[test]
    fn test_parse_ts_imports() {
        let nodes = sample_nodes();
        let imports = names_of(&nodes, NodeKind::Import);
        assert_eq!(imports.len(), 2);
        assert!(imports.iter().all(|name| name.starts_with("import")));
    }

    #[test]
    fn test_parse_ts_constant() {
        let nodes = sample_nodes();
        assert_eq!(names_of(&nodes, NodeKind::Constant), ["MAX_SIZE"]);
    }

    #[test]
    fn test_parse_ts_class_methods() {
        let nodes = sample_nodes();
        assert_eq!(names_of(&nodes, NodeKind::Class), ["Server"]);
        assert_eq!(
            names_of(&nodes, NodeKind::Method),
            ["constructor", "start", "getConfig"]
        );
    }

    #[test]
    fn test_parse_ts_arrow_function() {
        let nodes = sample_nodes();
        // `handler` is an arrow function; the lowercase local `server` in `main` is skipped.
        assert_eq!(
            names_of(&nodes, NodeKind::Function),
            ["createServer", "handler", "main"]
        );
    }

    #[test]
    fn test_parse_ts_interface() {
        let nodes = sample_nodes();
        assert_eq!(names_of(&nodes, NodeKind::Interface), ["Config"]);
    }

    #[test]
    fn test_parse_ts_type_alias() {
        let nodes = sample_nodes();
        assert_eq!(names_of(&nodes, NodeKind::TypeAlias), ["Status"]);
    }

    #[test]
    fn test_parse_ts_enum() {
        let nodes = sample_nodes();
        assert_eq!(names_of(&nodes, NodeKind::Enum), ["Direction"]);
    }

    #[test]
    fn test_parse_ts_edges() {
        let result = parse_with_edges(&PathBuf::from("app.ts"), SAMPLE_TS).unwrap();
        assert!(!result.edges.is_empty());
        assert_eq!(
            result.edges.len(),
            result.nodes.len() - 1,
            "one Contains edge per definition"
        );
        for (expected_to, &(from, to, ref kind)) in (1..).zip(&result.edges) {
            assert_eq!(from, 0);
            assert_eq!(to, expected_to);
            assert_eq!(*kind, EdgeKind::Contains);
        }
    }

    #[test]
    fn test_parse_empty_ts() {
        let nodes = parse(&PathBuf::from("empty.ts"), "").unwrap();
        assert_eq!(nodes.len(), 1, "an empty file yields only the file node");
        assert!(nodes[0].kind == NodeKind::File);
        assert_eq!(nodes[0].name, "empty.ts");
    }
}
307}