Skip to main content

amql_engine/resolver/
typescript.rs

1//! TypeScript/JavaScript source file resolver using tree-sitter.
2//!
3//! Parses `.ts`, `.tsx`, `.js`, `.jsx`, `.mts`, `.mjs` files into `CodeElement`
4//! trees, extracting functions, classes, methods, interfaces, type aliases,
5//! enums, modules, and constants.
6
7use super::{CodeElement, SourceLocation};
8use crate::error::AqlError;
9use crate::types::{AttrName, CodeElementName, RelativePath, TagName};
10use rustc_hash::FxHashMap;
11use std::cell::RefCell;
12use std::path::Path;
13
/// TypeScript/JavaScript source file resolver using tree-sitter.
///
/// This is a field-less unit type, so it is free to copy, compare by debug
/// output, and construct via `Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct TypeScriptResolver;
16
17impl super::CodeResolver for TypeScriptResolver {
18    fn resolve(&self, file_path: &Path) -> Result<CodeElement, AqlError> {
19        let source =
20            std::fs::read_to_string(file_path).map_err(|e| format!("Failed to read file: {e}"))?;
21        let is_tsx = file_path
22            .extension()
23            .is_some_and(|ext| ext == "tsx" || ext == "jsx");
24        let root = parse_typescript_source(&source, file_path, is_tsx)?;
25        Ok(root)
26    }
27
28    fn extensions(&self) -> &[&str] {
29        &[".ts", ".tsx", ".js", ".jsx", ".mts", ".mjs"]
30    }
31
32    fn code_tags(&self) -> &[&str] {
33        &[
34            "function",
35            "class",
36            "method",
37            "interface",
38            "type",
39            "enum",
40            "module",
41            "const",
42        ]
43    }
44}
45
46// Thread-local cached tree-sitter parsers to avoid re-creating on each call.
47thread_local! {
48    static TS_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
49    static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
50}
51
52fn with_ts_parser<F, R>(is_tsx: bool, f: F) -> Result<R, String>
53where
54    F: FnOnce(&mut tree_sitter::Parser) -> Result<R, String>,
55{
56    if is_tsx {
57        TSX_PARSER.with(|cell| {
58            let mut opt = cell.borrow_mut();
59            let parser = opt.get_or_insert_with(|| {
60                let mut p = tree_sitter::Parser::new();
61                p.set_language(&tree_sitter_typescript::LANGUAGE_TSX.into())
62                    .expect("Failed to set TSX language for tree-sitter");
63                p
64            });
65            f(parser)
66        })
67    } else {
68        TS_PARSER.with(|cell| {
69            let mut opt = cell.borrow_mut();
70            let parser = opt.get_or_insert_with(|| {
71                let mut p = tree_sitter::Parser::new();
72                p.set_language(&tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())
73                    .expect("Failed to set TypeScript language for tree-sitter");
74                p
75            });
76            f(parser)
77        })
78    }
79}
80
81/// Parse a TypeScript/JavaScript source string into a CodeElement tree.
82fn parse_typescript_source(
83    source: &str,
84    file_path: &Path,
85    is_tsx: bool,
86) -> Result<CodeElement, String> {
87    let tree = with_ts_parser(is_tsx, |parser| {
88        parser
89            .parse(source, None)
90            .ok_or_else(|| "Failed to parse source".to_string())
91    })?;
92
93    let root_node = tree.root_node();
94    let src = source.as_bytes();
95    let file_str = file_path.to_string_lossy().to_string();
96
97    let mut children = Vec::new();
98    let mut cursor = root_node.walk();
99    for child in root_node.named_children(&mut cursor) {
100        extract_elements(&child, src, &file_str, &mut children);
101    }
102
103    let filename = file_path
104        .file_name()
105        .map(|f| f.to_string_lossy().to_string())
106        .unwrap_or_else(|| file_str.clone());
107
108    Ok(CodeElement {
109        tag: TagName::from("module"),
110        name: CodeElementName::from(filename),
111        attrs: FxHashMap::default(),
112        children,
113        source: SourceLocation {
114            file: RelativePath::from(file_str),
115            line: 1,
116            column: 0,
117            end_line: Some(root_node.end_position().row + 1),
118            end_column: Some(root_node.end_position().column),
119            start_byte: root_node.start_byte(),
120            end_byte: root_node.end_byte(),
121        },
122    })
123}
124
125/// Extract CodeElement(s) from a tree-sitter node.
126/// Some nodes (like export_statement) are wrappers that delegate to their child.
127fn extract_elements(
128    node: &tree_sitter::Node,
129    src: &[u8],
130    file: &str,
131    result: &mut Vec<CodeElement>,
132) {
133    match node.kind() {
134        "function_declaration" | "generator_function_declaration" => {
135            result.push(extract_function(node, src, file));
136        }
137        "class_declaration" | "abstract_class_declaration" => {
138            result.push(extract_class(node, src, file));
139        }
140        "interface_declaration" => {
141            result.push(extract_interface(node, src, file));
142        }
143        "type_alias_declaration" => {
144            result.push(extract_type_alias(node, src, file));
145        }
146        "enum_declaration" => {
147            result.push(extract_enum(node, src, file));
148        }
149        "internal_module" | "module" => {
150            result.push(extract_module(node, src, file));
151        }
152        "lexical_declaration" | "variable_declaration" => {
153            extract_variable_declaration(node, src, file, result);
154        }
155        "export_statement" => {
156            extract_export_statement(node, src, file, result);
157        }
158        _ => {}
159    }
160}
161
162fn node_text<'a>(node: &tree_sitter::Node, src: &'a [u8]) -> &'a str {
163    node.utf8_text(src).unwrap_or("")
164}
165
166fn get_name(node: &tree_sitter::Node, src: &[u8]) -> CodeElementName {
167    CodeElementName::from(
168        node.child_by_field_name("name")
169            .map(|n| node_text(&n, src).to_string())
170            .unwrap_or_default(),
171    )
172}
173
174fn make_source_location(node: &tree_sitter::Node, file: &str) -> SourceLocation {
175    let start = node.start_position();
176    let end = node.end_position();
177    SourceLocation {
178        file: RelativePath::from(file),
179        line: start.row + 1,
180        column: start.column,
181        end_line: Some(end.row + 1),
182        end_column: Some(end.column),
183        start_byte: node.start_byte(),
184        end_byte: node.end_byte(),
185    }
186}
187
188fn is_async(node: &tree_sitter::Node, src: &[u8]) -> bool {
189    let mut cursor = node.walk();
190    for child in node.children(&mut cursor) {
191        if node_text(&child, src) == "async" {
192            return true;
193        }
194    }
195    false
196}
197
198fn has_keyword(node: &tree_sitter::Node, src: &[u8], keyword: &str) -> bool {
199    let mut cursor = node.walk();
200    let result = node
201        .children(&mut cursor)
202        .any(|c| node_text(&c, src) == keyword);
203    result
204}
205
206fn is_generator(node: &tree_sitter::Node) -> bool {
207    node.kind() == "generator_function_declaration" || node.kind() == "generator_function"
208}
209
210fn extract_function(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
211    let name = get_name(node, src);
212    let mut attrs = FxHashMap::default();
213    attrs.insert(
214        AttrName::from("name"),
215        serde_json::Value::String(name.to_string()),
216    );
217    if is_async(node, src) {
218        attrs.insert(AttrName::from("async"), serde_json::Value::Bool(true));
219    }
220    if is_generator(node) {
221        attrs.insert(AttrName::from("generator"), serde_json::Value::Bool(true));
222    }
223    if let Some(ret) = node.child_by_field_name("return_type") {
224        let text = node_text(&ret, src).trim_start_matches(':').trim();
225        if !text.is_empty() {
226            attrs.insert(
227                AttrName::from("returnType"),
228                serde_json::Value::String(text.to_string()),
229            );
230        }
231    }
232    CodeElement {
233        tag: TagName::from("function"),
234        name,
235        attrs,
236        children: vec![],
237        source: make_source_location(node, file),
238    }
239}
240
241fn extract_class(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
242    let name = get_name(node, src);
243    let mut attrs = FxHashMap::default();
244    attrs.insert(
245        AttrName::from("name"),
246        serde_json::Value::String(name.to_string()),
247    );
248    if node.kind() == "abstract_class_declaration" {
249        attrs.insert(AttrName::from("abstract"), serde_json::Value::Bool(true));
250    }
251
252    // Extract extends clause
253    let mut cursor = node.walk();
254    for child in node.named_children(&mut cursor) {
255        if child.kind() == "class_heritage" {
256            let heritage_text = node_text(&child, src);
257            attrs.insert(
258                AttrName::from("extends"),
259                serde_json::Value::String(heritage_text.to_string()),
260            );
261        }
262    }
263
264    // Extract methods from class body
265    let mut children = Vec::new();
266    if let Some(body) = node.child_by_field_name("body") {
267        let mut body_cursor = body.walk();
268        for child in body.named_children(&mut body_cursor) {
269            match child.kind() {
270                "method_definition" => {
271                    children.push(extract_method(&child, src, file));
272                }
273                "public_field_definition" | "property_definition" => {
274                    // Check if it's an arrow function property
275                    if let Some(value) = child.child_by_field_name("value") {
276                        if value.kind() == "arrow_function" {
277                            children.push(extract_method_from_property(&child, &value, src, file));
278                        }
279                    }
280                }
281                _ => {}
282            }
283        }
284    }
285
286    CodeElement {
287        tag: TagName::from("class"),
288        name,
289        attrs,
290        children,
291        source: make_source_location(node, file),
292    }
293}
294
295fn extract_method(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
296    let name = get_name(node, src);
297    let mut attrs = FxHashMap::default();
298    attrs.insert(
299        AttrName::from("name"),
300        serde_json::Value::String(name.to_string()),
301    );
302    if is_async(node, src) {
303        attrs.insert(AttrName::from("async"), serde_json::Value::Bool(true));
304    }
305
306    let mut cursor = node.walk();
307    for child in node.children(&mut cursor) {
308        let text = node_text(&child, src);
309        match text {
310            "static" => {
311                attrs.insert(AttrName::from("static"), serde_json::Value::Bool(true));
312            }
313            "get" if child.kind() == "property_identifier" || !child.is_named() => {
314                attrs.insert(AttrName::from("getter"), serde_json::Value::Bool(true));
315            }
316            "set" if child.kind() == "property_identifier" || !child.is_named() => {
317                attrs.insert(AttrName::from("setter"), serde_json::Value::Bool(true));
318            }
319            _ => {}
320        }
321        if child.kind() == "accessibility_modifier" {
322            attrs.insert(
323                AttrName::from("visibility"),
324                serde_json::Value::String(text.to_string()),
325            );
326        }
327    }
328
329    if let Some(ret) = node.child_by_field_name("return_type") {
330        let text = node_text(&ret, src).trim_start_matches(':').trim();
331        if !text.is_empty() {
332            attrs.insert(
333                AttrName::from("returnType"),
334                serde_json::Value::String(text.to_string()),
335            );
336        }
337    }
338
339    CodeElement {
340        tag: TagName::from("method"),
341        name,
342        attrs,
343        children: vec![],
344        source: make_source_location(node, file),
345    }
346}
347
348fn extract_method_from_property(
349    prop: &tree_sitter::Node,
350    _arrow: &tree_sitter::Node,
351    src: &[u8],
352    file: &str,
353) -> CodeElement {
354    let name = get_name(prop, src);
355    let mut attrs = FxHashMap::default();
356    attrs.insert(
357        AttrName::from("name"),
358        serde_json::Value::String(name.to_string()),
359    );
360    attrs.insert(AttrName::from("arrow"), serde_json::Value::Bool(true));
361
362    CodeElement {
363        tag: TagName::from("method"),
364        name,
365        attrs,
366        children: vec![],
367        source: make_source_location(prop, file),
368    }
369}
370
371fn extract_interface(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
372    let name = get_name(node, src);
373    let mut attrs = FxHashMap::default();
374    attrs.insert(
375        AttrName::from("name"),
376        serde_json::Value::String(name.to_string()),
377    );
378
379    // Count members
380    if let Some(body) = node.child_by_field_name("body") {
381        let mut cursor = body.walk();
382        let members = body.named_children(&mut cursor).count();
383        attrs.insert(
384            AttrName::from("members"),
385            serde_json::Value::Number(members.into()),
386        );
387    }
388
389    CodeElement {
390        tag: TagName::from("interface"),
391        name,
392        attrs,
393        children: vec![],
394        source: make_source_location(node, file),
395    }
396}
397
398fn extract_type_alias(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
399    let name = get_name(node, src);
400    let mut attrs = FxHashMap::default();
401    attrs.insert(
402        AttrName::from("name"),
403        serde_json::Value::String(name.to_string()),
404    );
405
406    CodeElement {
407        tag: TagName::from("type"),
408        name,
409        attrs,
410        children: vec![],
411        source: make_source_location(node, file),
412    }
413}
414
415fn extract_enum(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
416    let name = get_name(node, src);
417    let mut attrs = FxHashMap::default();
418    attrs.insert(
419        AttrName::from("name"),
420        serde_json::Value::String(name.to_string()),
421    );
422
423    // Count members
424    if let Some(body) = node.child_by_field_name("body") {
425        let mut cursor = body.walk();
426        let members = body.named_children(&mut cursor).count();
427        attrs.insert(
428            AttrName::from("members"),
429            serde_json::Value::Number(members.into()),
430        );
431    }
432
433    CodeElement {
434        tag: TagName::from("enum"),
435        name,
436        attrs,
437        children: vec![],
438        source: make_source_location(node, file),
439    }
440}
441
442fn extract_module(node: &tree_sitter::Node, src: &[u8], file: &str) -> CodeElement {
443    let name = get_name(node, src);
444    let mut attrs = FxHashMap::default();
445    attrs.insert(
446        AttrName::from("name"),
447        serde_json::Value::String(name.to_string()),
448    );
449
450    CodeElement {
451        tag: TagName::from("module"),
452        name,
453        attrs,
454        children: vec![],
455        source: make_source_location(node, file),
456    }
457}
458
459/// Extract from `const x = ...` or `let x = ...`.
460/// Arrow functions and function expressions become function elements.
461/// Plain constants become const elements.
462fn extract_variable_declaration(
463    node: &tree_sitter::Node,
464    src: &[u8],
465    file: &str,
466    result: &mut Vec<CodeElement>,
467) {
468    let is_const = has_keyword(node, src, "const");
469
470    let mut cursor = node.walk();
471    for child in node.named_children(&mut cursor) {
472        if child.kind() == "variable_declarator" {
473            let name = get_name(&child, src);
474            if name.is_empty() {
475                continue;
476            }
477
478            if let Some(value) = child.child_by_field_name("value") {
479                match value.kind() {
480                    "arrow_function" | "function_expression" | "generator_function" => {
481                        let mut attrs = FxHashMap::default();
482                        attrs.insert(
483                            AttrName::from("name"),
484                            serde_json::Value::String(name.to_string()),
485                        );
486                        if value.kind() == "arrow_function" {
487                            attrs.insert(AttrName::from("arrow"), serde_json::Value::Bool(true));
488                        }
489                        if is_async(&value, src) {
490                            attrs.insert(AttrName::from("async"), serde_json::Value::Bool(true));
491                        }
492                        if value.kind() == "generator_function" {
493                            attrs
494                                .insert(AttrName::from("generator"), serde_json::Value::Bool(true));
495                        }
496                        result.push(CodeElement {
497                            tag: TagName::from("function"),
498                            name,
499                            attrs,
500                            children: vec![],
501                            source: make_source_location(node, file),
502                        });
503                        return;
504                    }
505                    _ => {}
506                }
507            }
508
509            if is_const {
510                let mut attrs = FxHashMap::default();
511                attrs.insert(
512                    AttrName::from("name"),
513                    serde_json::Value::String(name.to_string()),
514                );
515                result.push(CodeElement {
516                    tag: TagName::from("const"),
517                    name,
518                    attrs,
519                    children: vec![],
520                    source: make_source_location(node, file),
521                });
522            }
523        }
524    }
525}
526
527/// Handle `export` wrapper: propagate `export` attr to the inner declaration.
528fn extract_export_statement(
529    node: &tree_sitter::Node,
530    src: &[u8],
531    file: &str,
532    result: &mut Vec<CodeElement>,
533) {
534    let has_default = has_keyword(node, src, "default");
535
536    let before = result.len();
537    let mut cursor = node.walk();
538    for child in node.named_children(&mut cursor) {
539        extract_elements(&child, src, file, result);
540    }
541
542    // Tag all newly added elements with export (and default if applicable)
543    for el in &mut result[before..] {
544        el.attrs
545            .insert(AttrName::from("export"), serde_json::Value::Bool(true));
546        if has_default {
547            el.attrs
548                .insert(AttrName::from("default"), serde_json::Value::Bool(true));
549        }
550    }
551}
552
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` with the plain TypeScript grammar.
    fn parse_snippet(source: &str) -> CodeElement {
        parse_typescript_source(source, Path::new("test.ts"), false).unwrap()
    }

    /// Parses `source` with the TSX grammar.
    fn parse_tsx_snippet(source: &str) -> CodeElement {
        parse_typescript_source(source, Path::new("test.tsx"), true).unwrap()
    }

    /// Asserts that `element` carries the boolean attribute `key` set to `true`.
    #[track_caller]
    fn assert_flag(element: &CodeElement, key: &str, message: &str) {
        assert_eq!(
            element.attrs.get(key),
            Some(&serde_json::Value::Bool(true)),
            "{message}"
        );
    }

    /// Asserts that `element` reports exactly `expected` members.
    #[track_caller]
    fn assert_members(element: &CodeElement, expected: usize, message: &str) {
        assert_eq!(
            element.attrs.get("members"),
            Some(&serde_json::Value::Number(expected.into())),
            "{message}"
        );
    }

    #[test]
    fn parses_function_declarations() {
        // Plain, async and generator declarations, each in its own file.
        let root = parse_snippet("function foo() {}");
        let async_root = parse_snippet("async function bar() {}");
        let gen_root = parse_snippet("function* gen() {}");

        let foo = &root.children[0];
        assert_eq!(foo.tag, "function", "should be function");
        assert_eq!(foo.name, "foo", "should be named foo");

        let bar = &async_root.children[0];
        assert_eq!(bar.tag, "function", "should be function");
        assert_eq!(bar.name, "bar", "should be named bar");
        assert_flag(bar, "async", "should be async");

        let gen = &gen_root.children[0];
        assert_eq!(gen.tag, "function", "should be function");
        assert_eq!(gen.name, "gen", "should be named gen");
        assert_flag(gen, "generator", "should be generator");
    }

    #[test]
    fn parses_arrow_functions() {
        let root = parse_snippet("const handler = async (req: Request) => { return 1; };");

        let handler = &root.children[0];

        assert_eq!(handler.tag, "function", "arrow fn should be function");
        assert_eq!(handler.name, "handler", "should be named handler");
        assert_flag(handler, "arrow", "should be arrow");
        assert_flag(handler, "async", "should be async");
    }

    #[test]
    fn parses_classes_with_methods() {
        let root = parse_snippet(
            r#"class UserService {
                async getById(id: string): Promise<User> { return null; }
                static create() { return new UserService(); }
            }"#,
        );

        let cls = &root.children[0];
        assert_eq!(cls.tag, "class", "should be class");
        assert_eq!(cls.name, "UserService", "should be named UserService");
        assert_eq!(cls.children.len(), 2, "should have two methods");

        let get_by_id = &cls.children[0];
        assert_eq!(get_by_id.tag, "method", "should be method");
        assert_eq!(get_by_id.name, "getById", "should be named getById");
        assert_flag(get_by_id, "async", "should be async");

        let create = &cls.children[1];
        assert_eq!(create.tag, "method", "should be method");
        assert_eq!(create.name, "create", "should be named create");
        assert_flag(create, "static", "should be static");
    }

    #[test]
    fn parses_interfaces_and_types() {
        let root = parse_snippet(
            r#"
            interface User { id: string; name: string; }
            type UserId = string;
        "#,
        );

        let iface = &root.children[0];
        assert_eq!(iface.tag, "interface", "should be interface");
        assert_eq!(iface.name, "User", "should be named User");
        assert_members(iface, 2, "should have 2 members");

        let alias = &root.children[1];
        assert_eq!(alias.tag, "type", "should be type");
        assert_eq!(alias.name, "UserId", "should be named UserId");
    }

    #[test]
    fn parses_enums() {
        let root = parse_snippet("enum Role { Admin, User, Guest }");

        let e = &root.children[0];

        assert_eq!(e.tag, "enum", "should be enum");
        assert_eq!(e.name, "Role", "should be named Role");
        assert_members(e, 3, "should have 3 members");
    }

    #[test]
    fn parses_exports() {
        let root = parse_snippet(
            r#"
            export function fetchUser() {}
            export default class App {}
            export const MAX = 3;
        "#,
        );

        let func = &root.children[0];
        assert_eq!(func.tag, "function", "should be function");
        assert_flag(func, "export", "should be exported");

        let cls = &root.children[1];
        assert_eq!(cls.tag, "class", "should be class");
        assert_flag(cls, "export", "should be exported");
        assert_flag(cls, "default", "should be default export");

        let konst = &root.children[2];
        assert_eq!(konst.tag, "const", "should be const");
        assert_eq!(konst.name, "MAX", "should be named MAX");
        assert_flag(konst, "export", "should be exported");
    }

    #[test]
    fn parses_tsx_files() {
        let root = parse_tsx_snippet(
            r#"
            export function App(): JSX.Element {
                return <div>Hello</div>;
            }
        "#,
        );

        let func = &root.children[0];

        assert_eq!(func.tag, "function", "should be function");
        assert_eq!(func.name, "App", "should be named App");
        assert_flag(func, "export", "should be exported");
    }

    #[test]
    fn parses_constants() {
        let root = parse_snippet("const MAX_RETRIES = 3;");

        let konst = &root.children[0];

        assert_eq!(konst.tag, "const", "should be const");
        assert_eq!(konst.name, "MAX_RETRIES", "should be named MAX_RETRIES");
    }
}
784}