neural_shared/parser/typescript.rs

1//! TypeScript/JavaScript parser using tree-sitter
2
3use super::{Location, ParsedFile, Parser, Symbol, SymbolKind};
4use crate::Result;
5use std::path::Path;
6use tree_sitter::{Node, Parser as TSParser, Tree};
7
8pub struct TypeScriptParser;
9
10impl TypeScriptParser {
11    pub fn new() -> Result<Self> {
12        Ok(Self)
13    }
14
15    fn extract_definitions(&self, tree: &Tree, source: &str, file_path: &str) -> Vec<Symbol> {
16        let mut definitions = Vec::new();
17        let root = tree.root_node();
18
19        self.traverse_for_definitions(root, source, file_path, &mut definitions, None);
20
21        definitions
22    }
23
24    fn traverse_for_definitions(
25        &self,
26        node: Node,
27        source: &str,
28        file_path: &str,
29        definitions: &mut Vec<Symbol>,
30        current_class: Option<String>,
31    ) {
32        let kind = node.kind();
33
34        match kind {
35            "function_declaration" | "function" => {
36                // Extract function name
37                if let Some(name_node) = node.child_by_field_name("name") {
38                    let name = name_node
39                        .utf8_text(source.as_bytes())
40                        .unwrap_or("")
41                        .to_string();
42                    if !name.is_empty() {
43                        let pos = name_node.start_position();
44
45                        definitions.push(Symbol::new(
46                            name,
47                            SymbolKind::Function,
48                            Location {
49                                file: file_path.to_string(),
50                                line: pos.row + 1,
51                                column: pos.column,
52                            },
53                        ));
54                    }
55                }
56            }
57            "arrow_function" => {
58                // Arrow functions assigned to variables
59                // We'll handle this when we process variable declarations
60            }
61            "method_definition" => {
62                // Extract method name (inside class)
63                if let Some(name_node) = node.child_by_field_name("name") {
64                    let name = name_node
65                        .utf8_text(source.as_bytes())
66                        .unwrap_or("")
67                        .to_string();
68                    if !name.is_empty() {
69                        let pos = name_node.start_position();
70
71                        let symbol_kind = if let Some(ref class_name) = current_class {
72                            SymbolKind::Method {
73                                class_name: class_name.clone(),
74                            }
75                        } else {
76                            SymbolKind::Function
77                        };
78
79                        definitions.push(Symbol::new(
80                            name,
81                            symbol_kind,
82                            Location {
83                                file: file_path.to_string(),
84                                line: pos.row + 1,
85                                column: pos.column,
86                            },
87                        ));
88                    }
89                }
90            }
91            "class_declaration" | "class" => {
92                // Extract class name
93                if let Some(name_node) = node.child_by_field_name("name") {
94                    let name = name_node
95                        .utf8_text(source.as_bytes())
96                        .unwrap_or("")
97                        .to_string();
98                    if !name.is_empty() {
99                        let pos = name_node.start_position();
100
101                        definitions.push(Symbol::new(
102                            name.clone(),
103                            SymbolKind::Class,
104                            Location {
105                                file: file_path.to_string(),
106                                line: pos.row + 1,
107                                column: pos.column,
108                            },
109                        ));
110
111                        // Traverse class body with class context
112                        let mut cursor = node.walk();
113                        for child in node.children(&mut cursor) {
114                            self.traverse_for_definitions(
115                                child,
116                                source,
117                                file_path,
118                                definitions,
119                                Some(name.clone()),
120                            );
121                        }
122                        return; // Don't traverse children again below
123                    }
124                }
125            }
126            "variable_declarator" => {
127                // Handle const foo = function() {} or const foo = () => {}
128                if let Some(name_node) = node.child_by_field_name("name") {
129                    if let Some(value_node) = node.child_by_field_name("value") {
130                        let value_kind = value_node.kind();
131                        if value_kind == "function" || value_kind == "arrow_function" {
132                            let name = name_node
133                                .utf8_text(source.as_bytes())
134                                .unwrap_or("")
135                                .to_string();
136                            if !name.is_empty() {
137                                let pos = name_node.start_position();
138
139                                definitions.push(Symbol::new(
140                                    name,
141                                    SymbolKind::Function,
142                                    Location {
143                                        file: file_path.to_string(),
144                                        line: pos.row + 1,
145                                        column: pos.column,
146                                    },
147                                ));
148                            }
149                        }
150                    }
151                }
152            }
153            _ => {}
154        }
155
156        // Traverse children
157        let mut cursor = node.walk();
158        for child in node.children(&mut cursor) {
159            self.traverse_for_definitions(
160                child,
161                source,
162                file_path,
163                definitions,
164                current_class.clone(),
165            );
166        }
167    }
168
169    fn extract_usages(&self, tree: &Tree, source: &str, file_path: &str) -> Vec<Symbol> {
170        let mut usages = Vec::new();
171        let root = tree.root_node();
172
173        self.traverse_for_usages(root, source, file_path, &mut usages);
174
175        usages
176    }
177
178    fn traverse_for_usages(
179        &self,
180        node: Node,
181        source: &str,
182        file_path: &str,
183        usages: &mut Vec<Symbol>,
184    ) {
185        let kind = node.kind();
186
187        match kind {
188            "call_expression" => {
189                // Extract function name being called
190                if let Some(func_node) = node.child_by_field_name("function") {
191                    let name = self.extract_call_name(func_node, source);
192                    if !name.is_empty() {
193                        let pos = func_node.start_position();
194                        usages.push(Symbol::new(
195                            name,
196                            SymbolKind::Function,
197                            Location {
198                                file: file_path.to_string(),
199                                line: pos.row + 1,
200                                column: pos.column,
201                            },
202                        ));
203                    }
204                }
205            }
206            "new_expression" => {
207                // Track class instantiation
208                if let Some(class_node) = node.child_by_field_name("constructor") {
209                    let name = class_node
210                        .utf8_text(source.as_bytes())
211                        .unwrap_or("")
212                        .to_string();
213                    if !name.is_empty() {
214                        let pos = class_node.start_position();
215                        usages.push(Symbol::new(
216                            name,
217                            SymbolKind::Class,
218                            Location {
219                                file: file_path.to_string(),
220                                line: pos.row + 1,
221                                column: pos.column,
222                            },
223                        ));
224                    }
225                }
226            }
227            _ => {}
228        }
229
230        // Traverse children
231        let mut cursor = node.walk();
232        for child in node.children(&mut cursor) {
233            self.traverse_for_usages(child, source, file_path, usages);
234        }
235    }
236
237    fn extract_call_name(&self, node: Node, source: &str) -> String {
238        match node.kind() {
239            "identifier" => node.utf8_text(source.as_bytes()).unwrap_or("").to_string(),
240            "member_expression" => {
241                // For obj.method() calls, extract the method name
242                if let Some(prop_node) = node.child_by_field_name("property") {
243                    prop_node
244                        .utf8_text(source.as_bytes())
245                        .unwrap_or("")
246                        .to_string()
247                } else {
248                    String::new()
249                }
250            }
251            _ => String::new(),
252        }
253    }
254
255    fn extract_entry_points(&self, tree: &Tree, source: &str) -> Vec<String> {
256        let mut entry_points = Vec::new();
257        let root = tree.root_node();
258
259        self.traverse_for_entry_points(root, source, &mut entry_points);
260
261        entry_points
262    }
263
264    fn traverse_for_entry_points(&self, node: Node, source: &str, entry_points: &mut Vec<String>) {
265        let kind = node.kind();
266
267        // Detect top-level call expressions (like main())
268        if kind == "expression_statement" {
269            if let Some(expr) = node.child(0) {
270                if expr.kind() == "call_expression" {
271                    if let Some(func_node) = expr.child_by_field_name("function") {
272                        let name = self.extract_call_name(func_node, source);
273                        if !name.is_empty() {
274                            entry_points.push(name);
275                        }
276                    }
277                }
278            }
279        }
280
281        // Detect exported functions as entry points
282        if kind == "export_statement" {
283            // Find function/class being exported
284            let mut cursor = node.walk();
285            for child in node.children(&mut cursor) {
286                if child.kind() == "function_declaration" || child.kind() == "class_declaration" {
287                    if let Some(name_node) = child.child_by_field_name("name") {
288                        let name = name_node.utf8_text(source.as_bytes()).unwrap_or("");
289                        if !name.is_empty() {
290                            entry_points.push(name.to_string());
291                        }
292                    }
293                }
294            }
295        }
296
297        // Detect test functions (describe, it, test)
298        if kind == "call_expression" {
299            if let Some(func_node) = node.child_by_field_name("function") {
300                let func_name = func_node.utf8_text(source.as_bytes()).unwrap_or("");
301                if func_name == "describe" || func_name == "it" || func_name == "test" {
302                    // Mark this as an entry point - extract the callback function
303                    if let Some(args) = node.child_by_field_name("arguments") {
304                        // The test callback is usually the second argument
305                        let mut cursor = args.walk();
306                        for child in args.children(&mut cursor) {
307                            if child.kind() == "arrow_function" || child.kind() == "function" {
308                                // This is a test entry point - for now, we'll just mark the test function itself
309                                entry_points
310                                    .push(format!("__test_callback_{}", entry_points.len()));
311                            }
312                        }
313                    }
314                }
315            }
316        }
317
318        // Traverse children (but not into function bodies to avoid recursion)
319        if kind != "statement_block" {
320            let mut cursor = node.walk();
321            for child in node.children(&mut cursor) {
322                self.traverse_for_entry_points(child, source, entry_points);
323            }
324        }
325    }
326}
327
328impl Parser for TypeScriptParser {
329    fn parse(&self, source: &str, file_path: &Path) -> Result<ParsedFile> {
330        // Parser needs to be mutable, so we need to use interior mutability
331        // For now, we'll create a new parser each time (not ideal but works for MVP)
332        let mut parser = TSParser::new();
333        parser.set_language(tree_sitter_typescript::language_typescript())?;
334
335        let tree = parser
336            .parse(source, None)
337            .ok_or_else(|| anyhow::anyhow!("Failed to parse TypeScript file"))?;
338
339        let file_path_str = file_path.to_string_lossy().to_string();
340
341        let definitions = self.extract_definitions(&tree, source, &file_path_str);
342        let usages = self.extract_usages(&tree, source, &file_path_str);
343        let entry_points = self.extract_entry_points(&tree, source);
344
345        Ok(ParsedFile {
346            path: file_path_str,
347            definitions,
348            usages,
349            entry_points,
350        })
351    }
352}
353
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the TypeScript parser over `source` as if it were `test.ts`,
    /// failing the calling test if parsing returns an error.
    fn parse_ts(source: &str) -> ParsedFile {
        let ts_parser = TypeScriptParser::new().unwrap();
        let outcome = ts_parser.parse(source, Path::new("test.ts"));
        assert!(outcome.is_ok());
        outcome.unwrap()
    }

    #[test]
    fn test_parse_simple_function() {
        let file = parse_ts(
            r#"
function hello() {
    console.log("Hello, world!");
}
"#,
        );

        assert_eq!(file.definitions.len(), 1);
        assert_eq!(file.definitions[0].name, "hello");
    }

    #[test]
    fn test_parse_arrow_function() {
        let file = parse_ts(
            r#"
const greet = () => {
    console.log("Hello!");
};
"#,
        );

        assert_eq!(file.definitions.len(), 1);
        assert_eq!(file.definitions[0].name, "greet");
    }

    #[test]
    fn test_parse_class_with_methods() {
        let file = parse_ts(
            r#"
class Calculator {
    add(a: number, b: number) {
        return a + b;
    }
    
    subtract(a: number, b: number) {
        return a - b;
    }
}
"#,
        );

        // One class definition plus its two methods.
        assert_eq!(file.definitions.len(), 3);
    }

    #[test]
    fn test_parse_function_calls() {
        let file = parse_ts(
            r#"
function foo() {
    return 42;
}

function bar() {
    foo();
    console.log("test");
}
"#,
        );

        // `foo` and `bar` are defined; the body of `bar` contains at least one call.
        assert_eq!(file.definitions.len(), 2);
        assert!(!file.usages.is_empty());
    }
}