Skip to main content

amql_engine/navigate/
tree.rs

1//! tree-sitter backed navigation operations.
2//!
3//! All functions are pure: source text in, node references out.
4//! Parser instances are cached per-thread to avoid re-creation overhead.
5
6use super::{NavResult, NodeRef};
7use crate::error::AqlError;
8use crate::types::{NodeKind, RelativePath};
9use std::cell::RefCell;
10
/// A parsed node-kind selector used by the navigation API.
///
/// Selectors match tree-sitter node kinds directly (e.g. "function_declaration",
/// "class_declaration"). An optional attribute predicate `[field=value]` filters
/// by a named field's text content; `[field]` alone only requires the field to
/// be present.
#[derive(Debug, Clone, PartialEq, Eq)]
struct KindSelector {
    /// Node kind to match; an empty string places no constraint on the kind.
    kind: String,
    /// Field name taken from the bracketed predicate, if one was given.
    field: Option<String>,
    /// Expected text of `field`; `None` means only the field's presence is checked.
    value: Option<String>,
}

/// Parse a simple node-kind selector: `kind` or `kind[field=value]`.
///
/// Parsing is lenient and never fails:
/// - surrounding whitespace is ignored around the kind, field and value;
/// - a missing closing `]` is tolerated;
/// - an empty predicate (`kind[]`) is treated as "no predicate" instead of a
///   field with an empty name, which could never match any node;
/// - a value wrapped in one pair of matching quotes (`"…"` or `'…'`) has exactly
///   that pair removed, so quote characters that belong to the value itself
///   (e.g. `[source="'react'"]`) survive.
fn parse_kind_selector(selector: &str) -> KindSelector {
    let selector = selector.trim();
    let Some((kind, predicate)) = selector.split_once('[') else {
        return KindSelector {
            kind: selector.to_string(),
            field: None,
            value: None,
        };
    };
    let kind = kind.trim().to_string();
    let predicate = predicate.trim_end_matches(']').trim();

    if predicate.is_empty() {
        return KindSelector {
            kind,
            field: None,
            value: None,
        };
    }

    match predicate.split_once('=') {
        Some((field, value)) => KindSelector {
            kind,
            field: Some(field.trim().to_string()),
            value: Some(strip_matching_quotes(value.trim()).to_string()),
        },
        None => KindSelector {
            kind,
            field: Some(predicate.to_string()),
            value: None,
        },
    }
}

/// Remove a single pair of matching surrounding quotes (`"…"` or `'…'`), if present.
///
/// Text that is not wrapped in a matching pair is returned unchanged.
fn strip_matching_quotes(raw: &str) -> &str {
    for quote in ['"', '\''] {
        if let Some(inner) = raw
            .strip_prefix(quote)
            .and_then(|rest| rest.strip_suffix(quote))
        {
            return inner;
        }
    }
    raw
}
55
56/// Check if a tree-sitter node matches a kind selector.
57fn matches_selector(node: &tree_sitter::Node, src: &[u8], selector: &KindSelector) -> bool {
58    if !selector.kind.is_empty() && node.kind() != selector.kind {
59        return false;
60    }
61    match (&selector.field, &selector.value) {
62        (Some(field), Some(value)) => node
63            .child_by_field_name(field.as_str())
64            .is_some_and(|child| child.utf8_text(src).unwrap_or("") == value.as_str()),
65        (Some(field), None) => node.child_by_field_name(field.as_str()).is_some(),
66        _ => true,
67    }
68}
69
70/// Build a NodeRef from a tree-sitter node.
71fn node_to_ref(node: &tree_sitter::Node, file: &RelativePath) -> NodeRef {
72    let start = node.start_position();
73    let end = node.end_position();
74    NodeRef {
75        file: file.clone(),
76        start_byte: node.start_byte(),
77        end_byte: node.end_byte(),
78        kind: NodeKind::from(node.kind()),
79        line: start.row + 1,
80        column: start.column,
81        end_line: end.row + 1,
82        end_column: end.column,
83    }
84}
85
86/// Build a NavResult from nodes, extracting source text for each.
87fn build_nav_result(nodes: &[tree_sitter::Node], src: &str, file: &RelativePath) -> NavResult {
88    let refs: Vec<NodeRef> = nodes.iter().map(|n| node_to_ref(n, file)).collect();
89    let source: Vec<String> = nodes
90        .iter()
91        .map(|n| {
92            src.get(n.start_byte()..n.end_byte())
93                .unwrap_or("")
94                .to_string()
95        })
96        .collect();
97    NavResult {
98        nodes: refs,
99        source,
100    }
101}
102
103// ---------------------------------------------------------------------------
104// Language detection + thread-local parser cache
105// ---------------------------------------------------------------------------
106
/// Supported languages for tree-sitter parsing.
///
/// `Debug` is derived so values can appear in diagnostics and test assertions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Language {
    /// Rust.
    Rust,
    /// TypeScript.
    TypeScript,
    /// TypeScript with JSX (TSX).
    Tsx,
    /// JavaScript.
    JavaScript,
}
115
116fn detect_language(file: &RelativePath) -> Option<Language> {
117    let path: &str = file.as_ref();
118    match std::path::Path::new(path)
119        .extension()
120        .and_then(|e| e.to_str())
121    {
122        Some("rs") => Some(Language::Rust),
123        Some("ts") | Some("mts") => Some(Language::TypeScript),
124        Some("tsx") | Some("jsx") => Some(Language::Tsx),
125        Some("js") | Some("mjs") => Some(Language::JavaScript),
126        _ => None,
127    }
128}
129
130thread_local! {
131    static RUST_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
132    static TS_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
133    static TSX_PARSER: RefCell<Option<tree_sitter::Parser>> = const { RefCell::new(None) };
134}
135
136fn parse_source(source: &str, lang: Language) -> Result<tree_sitter::Tree, AqlError> {
137    let parse_with = |cell: &RefCell<Option<tree_sitter::Parser>>,
138                      make_lang: fn() -> tree_sitter::Language|
139     -> Result<tree_sitter::Tree, AqlError> {
140        let mut opt = cell.borrow_mut();
141        let parser = opt.get_or_insert_with(|| {
142            let mut p = tree_sitter::Parser::new();
143            p.set_language(&make_lang())
144                .expect("Failed to set tree-sitter language");
145            p
146        });
147        parser
148            .parse(source, None)
149            .ok_or_else(|| AqlError::new("Failed to parse source"))
150    };
151
152    match lang {
153        Language::Rust => {
154            RUST_PARSER.with(|cell| parse_with(cell, || tree_sitter_rust::LANGUAGE.into()))
155        }
156        Language::TypeScript => TS_PARSER
157            .with(|cell| parse_with(cell, || tree_sitter_typescript::LANGUAGE_TYPESCRIPT.into())),
158        Language::Tsx | Language::JavaScript => {
159            TSX_PARSER.with(|cell| parse_with(cell, || tree_sitter_typescript::LANGUAGE_TSX.into()))
160        }
161    }
162}
163
164// ---------------------------------------------------------------------------
165// Find a node by byte range in a parsed tree
166// ---------------------------------------------------------------------------
167
168/// Find the deepest node in the tree that exactly matches the given byte range.
169fn find_node_by_range<'a>(
170    root: tree_sitter::Node<'a>,
171    start_byte: usize,
172    end_byte: usize,
173) -> Option<tree_sitter::Node<'a>> {
174    if root.start_byte() == start_byte && root.end_byte() == end_byte {
175        return Some(root);
176    }
177    let mut cursor = root.walk();
178    for child in root.named_children(&mut cursor) {
179        if child.start_byte() <= start_byte && child.end_byte() >= end_byte {
180            if let Some(found) = find_node_by_range(child, start_byte, end_byte) {
181                return Some(found);
182            }
183        }
184    }
185    // Fallback: check if root itself contains the range
186    if root.start_byte() <= start_byte && root.end_byte() >= end_byte {
187        return Some(root);
188    }
189    None
190}
191
192// ---------------------------------------------------------------------------
193// Navigation operations
194// ---------------------------------------------------------------------------
195
196/// Select all named descendants matching a selector within a scope.
197pub fn select_nodes(
198    source: &str,
199    file: &RelativePath,
200    scope: Option<&NodeRef>,
201    selector: &str,
202) -> Result<NavResult, AqlError> {
203    let lang = detect_language(file)
204        .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
205    let tree = parse_source(source, lang)?;
206    let src = source.as_bytes();
207    let sel = parse_kind_selector(selector);
208
209    let search_root = match scope {
210        Some(node_ref) => {
211            find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
212                .ok_or_else(|| {
213                    AqlError::new(format!(
214                        "Could not find node at {}..{} in {file}",
215                        node_ref.start_byte, node_ref.end_byte
216                    ))
217                })?
218        }
219        None => tree.root_node(),
220    };
221
222    let mut matches = Vec::new();
223    collect_matching_descendants(&search_root, src, &sel, &mut matches);
224    Ok(build_nav_result(&matches, source, file))
225}
226
227/// Recursively collect named descendants matching a selector.
228fn collect_matching_descendants<'a>(
229    node: &tree_sitter::Node<'a>,
230    src: &[u8],
231    selector: &KindSelector,
232    result: &mut Vec<tree_sitter::Node<'a>>,
233) {
234    let mut cursor = node.walk();
235    for child in node.named_children(&mut cursor) {
236        if matches_selector(&child, src, selector) {
237            result.push(child);
238        }
239        collect_matching_descendants(&child, src, selector, result);
240    }
241}
242
243/// Expand: return parent or nearest ancestor matching selector.
244pub fn expand_node(
245    source: &str,
246    file: &RelativePath,
247    node_ref: &NodeRef,
248    selector: Option<&str>,
249) -> Result<NavResult, AqlError> {
250    let lang = detect_language(file)
251        .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
252    let tree = parse_source(source, lang)?;
253    let src = source.as_bytes();
254
255    let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
256        .ok_or_else(|| {
257            AqlError::new(format!(
258                "Could not find node at {}..{} in {file}",
259                node_ref.start_byte, node_ref.end_byte
260            ))
261        })?;
262
263    let sel = selector.map(parse_kind_selector);
264
265    let mut current = target.parent();
266    while let Some(parent) = current {
267        match &sel {
268            Some(s) if !matches_selector(&parent, src, s) => {
269                current = parent.parent();
270            }
271            _ => {
272                return Ok(build_nav_result(&[parent], source, file));
273            }
274        }
275    }
276
277    Ok(NavResult {
278        nodes: vec![],
279        source: vec![],
280    })
281}
282
283/// Shrink: return children matching selector, or all named children.
284pub fn shrink_node(
285    source: &str,
286    file: &RelativePath,
287    node_ref: &NodeRef,
288    selector: Option<&str>,
289) -> Result<NavResult, AqlError> {
290    let lang = detect_language(file)
291        .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
292    let tree = parse_source(source, lang)?;
293    let src = source.as_bytes();
294
295    let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
296        .ok_or_else(|| {
297            AqlError::new(format!(
298                "Could not find node at {}..{} in {file}",
299                node_ref.start_byte, node_ref.end_byte
300            ))
301        })?;
302
303    let mut children = Vec::new();
304    let mut cursor = target.walk();
305    match selector.map(parse_kind_selector) {
306        Some(sel) => {
307            for child in target.named_children(&mut cursor) {
308                if matches_selector(&child, src, &sel) {
309                    children.push(child);
310                }
311            }
312        }
313        None => {
314            for child in target.named_children(&mut cursor) {
315                children.push(child);
316            }
317        }
318    }
319
320    Ok(build_nav_result(&children, source, file))
321}
322
323/// Next: return next named sibling matching selector.
324pub fn next_node(
325    source: &str,
326    file: &RelativePath,
327    node_ref: &NodeRef,
328    selector: Option<&str>,
329) -> Result<NavResult, AqlError> {
330    let lang = detect_language(file)
331        .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
332    let tree = parse_source(source, lang)?;
333    let src = source.as_bytes();
334
335    let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
336        .ok_or_else(|| {
337            AqlError::new(format!(
338                "Could not find node at {}..{} in {file}",
339                node_ref.start_byte, node_ref.end_byte
340            ))
341        })?;
342
343    let sel = selector.map(parse_kind_selector);
344    let mut current = target.next_named_sibling();
345    while let Some(sibling) = current {
346        match &sel {
347            Some(s) if !matches_selector(&sibling, src, s) => {
348                current = sibling.next_named_sibling();
349            }
350            _ => {
351                return Ok(build_nav_result(&[sibling], source, file));
352            }
353        }
354    }
355
356    Ok(NavResult {
357        nodes: vec![],
358        source: vec![],
359    })
360}
361
362/// Prev: return previous named sibling matching selector.
363pub fn prev_node(
364    source: &str,
365    file: &RelativePath,
366    node_ref: &NodeRef,
367    selector: Option<&str>,
368) -> Result<NavResult, AqlError> {
369    let lang = detect_language(file)
370        .ok_or_else(|| AqlError::new(format!("No parser for file extension: {file}")))?;
371    let tree = parse_source(source, lang)?;
372    let src = source.as_bytes();
373
374    let target = find_node_by_range(tree.root_node(), node_ref.start_byte, node_ref.end_byte)
375        .ok_or_else(|| {
376            AqlError::new(format!(
377                "Could not find node at {}..{} in {file}",
378                node_ref.start_byte, node_ref.end_byte
379            ))
380        })?;
381
382    let sel = selector.map(parse_kind_selector);
383    let mut current = target.prev_named_sibling();
384    while let Some(sibling) = current {
385        match &sel {
386            Some(s) if !matches_selector(&sibling, src, s) => {
387                current = sibling.prev_named_sibling();
388            }
389            _ => {
390                return Ok(build_nav_result(&[sibling], source, file));
391            }
392        }
393    }
394
395    Ok(NavResult {
396        nodes: vec![],
397        source: vec![],
398    })
399}
400
#[cfg(test)]
mod tests {
    use super::*;

    /// TypeScript fixture: two top-level functions, one class with a
    /// constructor and two methods, and an exported constant.
    const TS_SOURCE: &str = r#"
function greet(name: string): string {
    return `Hello, ${name}!`;
}

async function fetchUser(id: number): Promise<User> {
    const response = await fetch(`/api/users/${id}`);
    return response.json();
}

class UserService {
    private baseUrl: string;

    constructor(baseUrl: string) {
        this.baseUrl = baseUrl;
    }

    async getById(id: number): Promise<User> {
        return fetchUser(id);
    }

    async create(data: UserInput): Promise<User> {
        const response = await fetch(this.baseUrl, {
            method: 'POST',
            body: JSON.stringify(data),
        });
        return response.json();
    }
}

export const MAX_RETRIES = 3;
"#;

    /// Rust fixture: a free function, a struct, an enum and an impl block
    /// containing one method.
    const RUST_SOURCE: &str = r#"
pub fn parse_selector(input: &str) -> Result<SelectorAst, AqlError> {
    let trimmed = input.trim();
    if trimmed.is_empty() {
        return Err(AqlError::new("Empty selector"));
    }
    Ok(SelectorAst { compounds: vec![] })
}

pub struct SelectorAst {
    pub compounds: Vec<CompoundSelector>,
}

pub enum Combinator {
    Child,
    Descendant,
}

impl SelectorAst {
    pub fn is_empty(&self) -> bool {
        self.compounds.is_empty()
    }
}
"#;

    fn ts_file() -> RelativePath {
        RelativePath::from("test.ts")
    }

    fn rs_file() -> RelativePath {
        RelativePath::from("test.rs")
    }

    /// File-wide selection over the TypeScript fixture.
    fn select_ts(selector: &str) -> NavResult {
        select_nodes(TS_SOURCE, &ts_file(), None, selector).unwrap()
    }

    /// File-wide selection over the Rust fixture.
    fn select_rs(selector: &str) -> NavResult {
        select_nodes(RUST_SOURCE, &rs_file(), None, selector).unwrap()
    }

    #[test]
    fn select_finds_function_declarations() {
        let found = select_ts("function_declaration");

        assert_eq!(found.nodes.len(), 2, "should find 2 function declarations");
        assert!(
            found.source[0].contains("function greet"),
            "first function should be greet"
        );
        assert!(
            found.source[1].contains("async function fetchUser"),
            "second function should be fetchUser"
        );
    }

    #[test]
    fn select_finds_class_declarations() {
        let found = select_ts("class_declaration");

        assert_eq!(found.nodes.len(), 1, "should find 1 class declaration");
        assert!(
            found.source[0].contains("class UserService"),
            "should be UserService"
        );
    }

    #[test]
    fn select_within_scope() {
        // Scope the search to the single class in the fixture.
        let classes = select_ts("class_declaration");
        let class_scope = &classes.nodes[0];

        let methods =
            select_nodes(TS_SOURCE, &ts_file(), Some(class_scope), "method_definition").unwrap();

        assert_eq!(
            methods.nodes.len(),
            3,
            "UserService has constructor + 2 methods"
        );
    }

    #[test]
    fn select_with_field_predicate() {
        let found = select_ts(r#"function_declaration[name=greet]"#);

        assert_eq!(found.nodes.len(), 1, "should find exactly greet");
        assert!(
            found.source[0].contains("function greet"),
            "should be the greet function"
        );
    }

    #[test]
    fn expand_returns_parent() {
        let methods = select_ts("method_definition");
        let first_method = &methods.nodes[0];

        let expanded = expand_node(TS_SOURCE, &ts_file(), first_method, None).unwrap();

        assert_eq!(expanded.nodes.len(), 1, "should find parent");
        // A method_definition sits directly inside the class_body.
        assert!(
            expanded.nodes[0].kind == "class_body",
            "parent should be class_body, got: {}",
            expanded.nodes[0].kind
        );
    }

    #[test]
    fn expand_with_selector_finds_ancestor() {
        let methods = select_ts("method_definition");
        let first_method = &methods.nodes[0];

        let expanded = expand_node(
            TS_SOURCE,
            &ts_file(),
            first_method,
            Some("class_declaration"),
        )
        .unwrap();

        assert_eq!(expanded.nodes.len(), 1, "should find class ancestor");
        assert_eq!(
            expanded.nodes[0].kind, "class_declaration",
            "should be class_declaration"
        );
    }

    #[test]
    fn shrink_returns_children() {
        let classes = select_ts("class_declaration");
        let class_node = &classes.nodes[0];

        let shrunk = shrink_node(TS_SOURCE, &ts_file(), class_node, None).unwrap();

        assert!(
            shrunk.nodes.len() >= 2,
            "class should have at least name and body children"
        );
    }

    #[test]
    fn shrink_with_selector() {
        let classes = select_ts("class_declaration");
        let class_node = &classes.nodes[0];

        let shrunk = shrink_node(TS_SOURCE, &ts_file(), class_node, Some("class_body")).unwrap();

        assert_eq!(shrunk.nodes.len(), 1, "should find class_body child");
        assert_eq!(shrunk.nodes[0].kind, "class_body", "should be class_body");
    }

    #[test]
    fn next_returns_sibling() {
        let funcs = select_ts("function_declaration");
        let greet = &funcs.nodes[0];

        let following = next_node(TS_SOURCE, &ts_file(), greet, None).unwrap();

        assert_eq!(following.nodes.len(), 1, "should find next sibling");
        assert!(
            following.source[0].contains("fetchUser"),
            "next function should be fetchUser"
        );
    }

    #[test]
    fn next_with_selector_skips_non_matching() {
        let funcs = select_ts("function_declaration");
        let greet = &funcs.nodes[0];

        // fetchUser lies between greet and the class and must be skipped.
        let following =
            next_node(TS_SOURCE, &ts_file(), greet, Some("class_declaration")).unwrap();

        assert_eq!(following.nodes.len(), 1, "should find class");
        assert_eq!(
            following.nodes[0].kind, "class_declaration",
            "should be class_declaration"
        );
    }

    #[test]
    fn prev_returns_sibling() {
        let funcs = select_ts("function_declaration");
        let fetch_user = &funcs.nodes[1];

        let preceding = prev_node(TS_SOURCE, &ts_file(), fetch_user, None).unwrap();

        assert_eq!(preceding.nodes.len(), 1, "should find prev sibling");
        assert!(
            preceding.source[0].contains("function greet"),
            "prev function should be greet"
        );
    }

    #[test]
    fn select_rust_functions() {
        let found = select_rs("function_item");

        assert_eq!(
            found.nodes.len(),
            2,
            "should find parse_selector and is_empty"
        );
    }

    #[test]
    fn select_rust_structs() {
        let found = select_rs("struct_item");

        assert_eq!(found.nodes.len(), 1, "should find SelectorAst");
        assert!(
            found.source[0].contains("struct SelectorAst"),
            "should be SelectorAst"
        );
    }

    #[test]
    fn empty_result_for_no_matches() {
        let found = select_ts("trait_item");

        assert_eq!(found.nodes.len(), 0, "TS has no trait_item nodes");
    }

    #[test]
    fn node_ref_byte_ranges_are_precise() {
        let found = select_ts("function_declaration");
        let first = &found.nodes[0];

        // Slicing the fixture by the reported bytes must reproduce the text.
        let sliced = &TS_SOURCE[first.start_byte..first.end_byte];

        assert_eq!(
            sliced, found.source[0],
            "byte range should produce identical source text"
        );
    }
}