Skip to main content

cha_core/
query.rs

//! Host-side tree-sitter query helper.
//!
//! Built-in plugins receive `ctx.tree` and `ctx.ts_language` and can run
//! tree-sitter S-expression queries directly via [`run_query`]. WASM plugins go
//! through the `tree_query` host import (see [`crate::wasm`]) — both paths
//! ultimately call this helper.
//!
//! Line numbers in [`QueryMatch`] are 1-based, to match `FunctionInfo` /
//! `ClassInfo` / `CommentInfo`. Columns are 0-based byte offsets.
10
11use streaming_iterator::StreamingIterator;
12use tree_sitter::{Language, Tree};
13
/// One captured node from a tree-sitter query match.
///
/// Line numbers are 1-based; columns are 0-based byte offsets within the line.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct QueryMatch {
    /// Name of the capture in the query pattern (empty if it could not be
    /// resolved from the query's capture list).
    pub capture_name: String,
    /// Grammar kind of the captured node, e.g. `"function_item"`.
    pub node_kind: String,
    /// Source text covered by the node, as produced by [`node_to_match`].
    pub text: String,
    /// 1-based line on which the node starts.
    pub start_line: u32,
    /// 0-based byte column at which the node starts.
    pub start_col: u32,
    /// 1-based line on which the node ends.
    pub end_line: u32,
    /// 0-based byte column just past the node's last byte.
    pub end_col: u32,
}
24
25/// Outer `Vec` = each pattern match, inner = captures within that match
26/// (in capture-list order). Empty on pattern compile error rather than panic
27/// — pattern strings can come from external plugins.
28pub fn run_query(
29    tree: &Tree,
30    lang: &Language,
31    source: &[u8],
32    pattern: &str,
33) -> Vec<Vec<QueryMatch>> {
34    let query = match tree_sitter::Query::new(lang, pattern) {
35        Ok(q) => q,
36        Err(_) => return vec![],
37    };
38    let capture_names: Vec<&str> = query.capture_names().to_vec();
39
40    let mut cursor = tree_sitter::QueryCursor::new();
41    let mut matches = cursor.matches(&query, tree.root_node(), source);
42    let mut results = vec![];
43    while let Some(m) = StreamingIterator::next(&mut matches) {
44        let captures: Vec<QueryMatch> = m
45            .captures
46            .iter()
47            .map(|c| {
48                let name: &str = capture_names.get(c.index as usize).copied().unwrap_or("");
49                node_to_match(&c.node, source, name)
50            })
51            .collect();
52        results.push(captures);
53    }
54    results
55}
56
57pub fn run_queries(
58    tree: &Tree,
59    lang: &Language,
60    source: &[u8],
61    patterns: &[&str],
62) -> Vec<Vec<Vec<QueryMatch>>> {
63    patterns
64        .iter()
65        .map(|p| run_query(tree, lang, source, p))
66        .collect()
67}
68
69/// 1-based lines (matches `FunctionInfo` convention).
70pub fn node_to_match(node: &tree_sitter::Node, source: &[u8], capture_name: &str) -> QueryMatch {
71    let text = node.utf8_text(source).unwrap_or("").to_string();
72    QueryMatch {
73        capture_name: capture_name.to_string(),
74        node_kind: node.kind().to_string(),
75        text,
76        start_line: (node.start_position().row as u32) + 1,
77        start_col: node.start_position().column as u32,
78        end_line: (node.end_position().row as u32) + 1,
79        end_col: node.end_position().column as u32,
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` as Rust and returns the tree together with its language.
    fn parse_rust(src: &str) -> (Tree, Language) {
        let lang: Language = tree_sitter_rust::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&lang).unwrap();
        let tree = parser.parse(src, None).unwrap();
        (tree, lang)
    }

    #[test]
    fn finds_unsafe_blocks() {
        let src = "fn main() { unsafe { let _ = 1; } }";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(unsafe_block) @b");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0][0].node_kind, "unsafe_block");
        assert_eq!(matches[0][0].start_line, 1);
    }

    #[test]
    fn empty_for_invalid_pattern() {
        let src = "fn main() {}";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(no_such_node_kind) @x");
        assert!(matches.is_empty());
    }

    #[test]
    fn captures_are_1_based() {
        let src = "// line 1\nfn foo() {}\n";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(function_item) @f");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0][0].start_line, 2);
        assert_eq!(matches[0][0].end_line, 2);
    }

    #[test]
    fn reports_capture_names_text_and_columns() {
        let src = "// line 1\nfn foo() {}\n";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(
            &tree,
            &lang,
            src.as_bytes(),
            "(function_item name: (identifier) @name) @func",
        );
        assert_eq!(matches.len(), 1);
        let captures = &matches[0];
        assert_eq!(captures.len(), 2);

        // Look captures up by name: their order within a match is tree-sitter's choice.
        let name = captures
            .iter()
            .find(|c| c.capture_name == "name")
            .expect("@name capture present");
        assert_eq!(name.node_kind, "identifier");
        assert_eq!(name.text, "foo");
        assert_eq!((name.start_line, name.start_col), (2, 3));
        assert_eq!((name.end_line, name.end_col), (2, 6));

        let func = captures
            .iter()
            .find(|c| c.capture_name == "func")
            .expect("@func capture present");
        assert_eq!(func.node_kind, "function_item");
        assert_eq!(func.text, "fn foo() {}");
        assert_eq!((func.start_col, func.end_col), (0, 11));
    }

    #[test]
    fn run_queries_keeps_one_entry_per_pattern() {
        let src = "fn a() {}\nfn b() {}\n";
        let (tree, lang) = parse_rust(src);
        let results = run_queries(
            &tree,
            &lang,
            src.as_bytes(),
            &["(function_item) @f", "(no_such_node_kind) @x"],
        );
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].len(), 2);
        assert!(results[1].is_empty());
    }
}