qlty_analysis/code/
filter.rs1use crate::code::{File, QUERY_MATCH_LIMIT};
2use crate::Language;
3use anyhow::Context;
4use std::sync::Arc;
5use tree_sitter::{Node, Query, Range, Tree};
6
7#[derive(Debug, Default, Clone)]
8pub struct NodeFilter {
9 ranges: Vec<Range>,
10}
11
12impl NodeFilter {
13 pub fn empty() -> Self {
14 Self::default()
15 }
16
17 pub fn add_range(&mut self, range: &Range) {
18 if !self.contains_range(range) {
19 self.ranges.push(range.to_owned());
20 }
21 }
22
23 pub fn exclude(&self, node: &Node) -> bool {
24 let range = node.range();
25 self.contains_range(&range)
26 }
27
28 fn contains_range(&self, candidate: &Range) -> bool {
29 self.ranges
30 .iter()
31 .any(|each| self.is_subrange(each, candidate))
32 }
33
34 fn is_subrange(&self, range: &Range, subrange: &Range) -> bool {
35 subrange.start_byte >= range.start_byte && subrange.end_byte <= range.end_byte
36 }
37}
38
39#[derive(Debug, Clone)]
40pub struct NodeFilterBuilder {
41 queries: Vec<Arc<Query>>,
42}
43
44impl NodeFilterBuilder {
45 #[allow(clippy::borrowed_box)]
46 pub fn for_patterns(language: &Box<dyn Language + Sync>, patterns: Vec<String>) -> Self {
47 let queries = patterns
48 .iter()
49 .map(|pattern| {
50 let query_with_capture = format!("{} @the-capture", pattern);
51 let query = Query::new(&language.tree_sitter_language(), &query_with_capture)
52 .with_context(|| {
53 format!(
54 "Failed to parse {} query: {}",
55 language.name(),
56 &query_with_capture
57 )
58 });
59
60 if query.is_err() {
61 query.as_ref().unwrap();
62 }
63
64 Arc::new(query.unwrap())
65 })
66 .collect();
67
68 Self { queries }
69 }
70
71 pub fn build(&self, source_file: &File, tree: &Tree) -> NodeFilter {
72 let mut filter = NodeFilter::empty();
73 let node = tree.root_node();
74
75 for query in self.queries.iter() {
76 let mut cursor = tree_sitter::QueryCursor::new();
77 cursor.set_match_limit(QUERY_MATCH_LIMIT as u32);
78
79 let ranges: Vec<_> = cursor
80 .matches(query, node, source_file.contents.as_bytes())
81 .flat_map(|each_match| each_match.captures)
82 .map(|capture| capture.node.range())
83 .collect();
84
85 for range in &ranges {
86 filter.add_range(range);
87 }
88 }
89
90 filter
91 }
92}
93
#[cfg(test)]
mod test {
    use super::*;

    /// An empty filter excludes nothing, not even the root node.
    #[test]
    fn none() {
        let source_file = File::from_string("javascript", "function foo() {}");
        let tree = source_file.parse();
        let filter = NodeFilter::empty();
        assert_eq!(false, filter.exclude(&tree.root_node()))
    }

    /// A single top-level `function_declaration` match yields exactly one excluded
    /// range spanning the whole declaration.
    #[test]
    fn basic() {
        let language = crate::lang::from_str("javascript").unwrap();
        let patterns = vec!["(function_declaration)".to_string()];
        let builder = NodeFilterBuilder::for_patterns(language, patterns);

        let source_file = File::from_string("javascript", "function foo() {}");
        let tree = source_file.parse();

        let filter = builder.build(&source_file, &tree);
        insta::assert_debug_snapshot!(filter, @r"
        NodeFilter {
            ranges: [
                Range {
                    start_byte: 0,
                    end_byte: 17,
                    start_point: Point {
                        row: 0,
                        column: 0,
                    },
                    end_point: Point {
                        row: 0,
                        column: 17,
                    },
                },
            ],
        }
        ");
    }

    /// The inner `bar` declaration also matches the pattern, but its range lies
    /// inside `foo`'s, so only the outer range is recorded.
    #[test]
    fn nested() {
        let language = crate::lang::from_str("javascript").unwrap();
        let patterns = vec!["(function_declaration)".to_string()];
        let builder = NodeFilterBuilder::for_patterns(language, patterns);

        let source_file = File::from_string("javascript", "function foo() { function bar() {}}");
        let tree = source_file.parse();

        let filter = builder.build(&source_file, &tree);
        insta::assert_debug_snapshot!(filter, @r"
        NodeFilter {
            ranges: [
                Range {
                    start_byte: 0,
                    end_byte: 35,
                    start_point: Point {
                        row: 0,
                        column: 0,
                    },
                    end_point: Point {
                        row: 0,
                        column: 35,
                    },
                },
            ],
        }
        ");
    }
}