1use streaming_iterator::StreamingIterator;
12use tree_sitter::{Language, Tree};
13
/// One node captured by a tree-sitter query match, copied out into owned data
/// so it can outlive the tree, query and cursor that produced it.
///
/// Line numbers are 1-based; column numbers are 0-based (see [`node_to_match`]).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct QueryMatch {
    /// Name of the query capture that selected this node.
    pub capture_name: String,
    /// Grammar kind of the captured node (e.g. `"function_item"`).
    pub node_kind: String,
    /// Source text covered by the node.
    pub text: String,
    /// 1-based line on which the node starts.
    pub start_line: u32,
    /// 0-based column at which the node starts.
    pub start_col: u32,
    /// 1-based line on which the node ends.
    pub end_line: u32,
    /// 0-based column at which the node ends (exclusive).
    pub end_col: u32,
}
24
25pub fn run_query(
29 tree: &Tree,
30 lang: &Language,
31 source: &[u8],
32 pattern: &str,
33) -> Vec<Vec<QueryMatch>> {
34 let query = match tree_sitter::Query::new(lang, pattern) {
35 Ok(q) => q,
36 Err(_) => return vec![],
37 };
38 let capture_names: Vec<&str> = query.capture_names().to_vec();
39
40 let mut cursor = tree_sitter::QueryCursor::new();
41 let mut matches = cursor.matches(&query, tree.root_node(), source);
42 let mut results = vec![];
43 while let Some(m) = StreamingIterator::next(&mut matches) {
44 let captures: Vec<QueryMatch> = m
45 .captures
46 .iter()
47 .map(|c| {
48 let name: &str = capture_names.get(c.index as usize).copied().unwrap_or("");
49 node_to_match(&c.node, source, name)
50 })
51 .collect();
52 results.push(captures);
53 }
54 results
55}
56
57pub fn run_queries(
58 tree: &Tree,
59 lang: &Language,
60 source: &[u8],
61 patterns: &[&str],
62) -> Vec<Vec<Vec<QueryMatch>>> {
63 patterns
64 .iter()
65 .map(|p| run_query(tree, lang, source, p))
66 .collect()
67}
68
69pub fn node_to_match(node: &tree_sitter::Node, source: &[u8], capture_name: &str) -> QueryMatch {
71 let text = node.utf8_text(source).unwrap_or("").to_string();
72 QueryMatch {
73 capture_name: capture_name.to_string(),
74 node_kind: node.kind().to_string(),
75 text,
76 start_line: (node.start_position().row as u32) + 1,
77 start_col: node.start_position().column as u32,
78 end_line: (node.end_position().row as u32) + 1,
79 end_col: node.end_position().column as u32,
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `src` with the Rust grammar, returning the tree and the language used.
    fn parse_rust(src: &str) -> (Tree, Language) {
        let lang: Language = tree_sitter_rust::LANGUAGE.into();
        let mut parser = tree_sitter::Parser::new();
        parser.set_language(&lang).unwrap();
        let tree = parser.parse(src, None).unwrap();
        (tree, lang)
    }

    #[test]
    fn finds_unsafe_blocks() {
        let src = "fn main() { unsafe { let _ = 1; } }";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(unsafe_block) @b");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0][0].node_kind, "unsafe_block");
        assert_eq!(matches[0][0].start_line, 1);
    }

    #[test]
    fn empty_for_invalid_pattern() {
        let src = "fn main() {}";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(no_such_node_kind) @x");
        assert!(matches.is_empty());
    }

    #[test]
    fn captures_are_1_based() {
        let src = "// line 1\nfn foo() {}\n";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(function_item) @f");
        assert_eq!(matches[0][0].start_line, 2);
    }

    #[test]
    fn reports_capture_names_text_and_columns() {
        let src = "fn foo() {}";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(
            &tree,
            &lang,
            src.as_bytes(),
            "(function_item name: (identifier) @name) @fn",
        );
        assert_eq!(matches.len(), 1);

        let name = matches[0]
            .iter()
            .find(|c| c.capture_name == "name")
            .expect("@name capture present");
        assert_eq!(name.node_kind, "identifier");
        assert_eq!(name.text, "foo");
        assert_eq!((name.start_line, name.start_col), (1, 3));
        assert_eq!((name.end_line, name.end_col), (1, 6));

        let func = matches[0]
            .iter()
            .find(|c| c.capture_name == "fn")
            .expect("@fn capture present");
        assert_eq!(func.node_kind, "function_item");
        assert_eq!(func.text, src);
    }

    #[test]
    fn end_position_spans_lines() {
        let src = "fn foo() {\n}\n";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(&tree, &lang, src.as_bytes(), "(function_item) @f");
        assert_eq!(matches.len(), 1);
        let func = &matches[0][0];
        assert_eq!((func.start_line, func.start_col), (1, 0));
        assert_eq!((func.end_line, func.end_col), (2, 1));
    }

    #[test]
    fn one_match_per_occurrence_in_document_order() {
        let src = "fn foo() {}\nfn bar() {}\n";
        let (tree, lang) = parse_rust(src);
        let matches = run_query(
            &tree,
            &lang,
            src.as_bytes(),
            "(function_item name: (identifier) @name)",
        );
        let names: Vec<&str> = matches.iter().map(|m| m[0].text.as_str()).collect();
        assert_eq!(names, vec!["foo", "bar"]);
    }

    #[test]
    fn run_queries_is_parallel_to_patterns() {
        let src = "fn foo() {}";
        let (tree, lang) = parse_rust(src);
        let results = run_queries(
            &tree,
            &lang,
            src.as_bytes(),
            &["(function_item) @f", "(unsafe_block) @u", "(no_such_node_kind) @x"],
        );
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].len(), 1);
        assert_eq!(results[0][0].capture_name, "f");
        assert!(results[1].is_empty());
        assert!(results[2].is_empty());
    }
}