qlty_analysis/code/filter.rs

1use crate::code::{File, QUERY_MATCH_LIMIT};
2use crate::Language;
3use anyhow::Context;
4use std::sync::Arc;
5use tree_sitter::{Node, Query, Range, Tree};
6
7#[derive(Debug, Default, Clone)]
8pub struct NodeFilter {
9    ranges: Vec<Range>,
10}
11
12impl NodeFilter {
13    pub fn empty() -> Self {
14        Self::default()
15    }
16
17    pub fn add_range(&mut self, range: &Range) {
18        if !self.contains_range(range) {
19            self.ranges.push(range.to_owned());
20        }
21    }
22
23    pub fn exclude(&self, node: &Node) -> bool {
24        let range = node.range();
25        self.contains_range(&range)
26    }
27
28    fn contains_range(&self, candidate: &Range) -> bool {
29        self.ranges
30            .iter()
31            .any(|each| self.is_subrange(each, candidate))
32    }
33
34    fn is_subrange(&self, range: &Range, subrange: &Range) -> bool {
35        subrange.start_byte >= range.start_byte && subrange.end_byte <= range.end_byte
36    }
37}
38
39#[derive(Debug, Clone)]
40pub struct NodeFilterBuilder {
41    queries: Vec<Arc<Query>>,
42}
43
44impl NodeFilterBuilder {
45    #[allow(clippy::borrowed_box)]
46    pub fn for_patterns(language: &Box<dyn Language + Sync>, patterns: Vec<String>) -> Self {
47        let queries = patterns
48            .iter()
49            .map(|pattern| {
50                let query_with_capture = format!("{} @the-capture", pattern);
51                let query = Query::new(&language.tree_sitter_language(), &query_with_capture)
52                    .with_context(|| {
53                        format!(
54                            "Failed to parse {} query: {}",
55                            language.name(),
56                            &query_with_capture
57                        )
58                    });
59
60                if query.is_err() {
61                    query.as_ref().unwrap();
62                }
63
64                Arc::new(query.unwrap())
65            })
66            .collect();
67
68        Self { queries }
69    }
70
71    pub fn build(&self, source_file: &File, tree: &Tree) -> NodeFilter {
72        let mut filter = NodeFilter::empty();
73        let node = tree.root_node();
74
75        for query in self.queries.iter() {
76            let mut cursor = tree_sitter::QueryCursor::new();
77            cursor.set_match_limit(QUERY_MATCH_LIMIT as u32);
78
79            let ranges: Vec<_> = cursor
80                .matches(query, node, source_file.contents.as_bytes())
81                .flat_map(|each_match| each_match.captures)
82                .map(|capture| capture.node.range())
83                .collect();
84
85            for range in &ranges {
86                filter.add_range(range);
87            }
88        }
89
90        filter
91    }
92}
93
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn none() {
        let source_file = File::from_string("javascript", "function foo() {}");
        let tree = source_file.parse();
        let filter = NodeFilter::empty();
        assert!(!filter.exclude(&tree.root_node()));
    }

    #[test]
    fn basic() {
        let language = crate::lang::from_str("javascript").unwrap();
        let patterns = vec!["(function_declaration)".to_string()];
        let builder = NodeFilterBuilder::for_patterns(language, patterns);

        let source_file = File::from_string("javascript", "function foo() {}");
        let tree = source_file.parse();

        let filter = builder.build(&source_file, &tree);
        insta::assert_debug_snapshot!(filter, @r"
        NodeFilter {
            ranges: [
                Range {
                    start_byte: 0,
                    end_byte: 17,
                    start_point: Point {
                        row: 0,
                        column: 0,
                    },
                    end_point: Point {
                        row: 0,
                        column: 17,
                    },
                },
            ],
        }
        ");
    }

    #[test]
    fn nested() {
        let language = crate::lang::from_str("javascript").unwrap();
        let patterns = vec!["(function_declaration)".to_string()];
        let builder = NodeFilterBuilder::for_patterns(language, patterns);

        let source_file = File::from_string("javascript", "function foo() { function bar() {}}");
        let tree = source_file.parse();

        let filter = builder.build(&source_file, &tree);
        insta::assert_debug_snapshot!(filter, @r"
        NodeFilter {
            ranges: [
                Range {
                    start_byte: 0,
                    end_byte: 35,
                    start_point: Point {
                        row: 0,
                        column: 0,
                    },
                    end_point: Point {
                        row: 0,
                        column: 35,
                    },
                },
            ],
        }
        ");
    }

    /// The matched node and anything nested in it are excluded; siblings and
    /// the enclosing program are not.
    #[test]
    fn excludes_matched_nodes_and_their_descendants() {
        let language = crate::lang::from_str("javascript").unwrap();
        let patterns = vec!["(function_declaration)".to_string()];
        let builder = NodeFilterBuilder::for_patterns(language, patterns);

        let source_file = File::from_string("javascript", "function foo() {}\nconst x = 1;");
        let tree = source_file.parse();
        let filter = builder.build(&source_file, &tree);

        let root = tree.root_node();
        let function = root.named_child(0).unwrap();
        let function_name = function.child_by_field_name("name").unwrap();
        let declaration = root.named_child(1).unwrap();

        assert!(!filter.exclude(&root));
        assert!(filter.exclude(&function));
        assert!(filter.exclude(&function_name));
        assert!(!filter.exclude(&declaration));
    }
}