1use std::{
2 fs,
3 io::{self, Write},
4 ops::Range,
5 path::Path,
6 time::Instant,
7};
8
9use anyhow::{Context, Result};
10use log::warn;
11use streaming_iterator::StreamingIterator;
12use tree_sitter::{Language, Parser, Point, Query, QueryCursor};
13
14use crate::{
15 query_testing::{self, to_utf8_point},
16 test::{TestInfo, TestOutcome, TestResult, TestSummary},
17};
18
19#[derive(Default)]
20pub struct QueryFileOptions {
21 pub ordered_captures: bool,
22 pub byte_range: Option<Range<usize>>,
23 pub point_range: Option<Range<Point>>,
24 pub containing_byte_range: Option<Range<usize>>,
25 pub containing_point_range: Option<Range<Point>>,
26 pub quiet: bool,
27 pub print_time: bool,
28 pub stdin: bool,
29}
30
31pub fn query_file_at_path(
32 language: &Language,
33 path: &Path,
34 name: &str,
35 query_path: &Path,
36 opts: &QueryFileOptions,
37 test_summary: Option<&mut TestSummary>,
38) -> Result<()> {
39 let stdout = io::stdout();
40 let mut stdout = stdout.lock();
41
42 let query_source = fs::read_to_string(query_path)
43 .with_context(|| format!("Error reading query file {}", query_path.display()))?;
44 let query = Query::new(language, &query_source).with_context(|| "Query compilation failed")?;
45
46 let mut query_cursor = QueryCursor::new();
47 if let Some(ref range) = opts.byte_range {
48 query_cursor.set_byte_range(range.clone());
49 }
50 if let Some(ref range) = opts.point_range {
51 query_cursor.set_point_range(range.clone());
52 }
53 if let Some(ref range) = opts.containing_byte_range {
54 query_cursor.set_containing_byte_range(range.clone());
55 }
56 if let Some(ref range) = opts.containing_point_range {
57 query_cursor.set_containing_point_range(range.clone());
58 }
59
60 let mut parser = Parser::new();
61 parser.set_language(language)?;
62
63 let mut results = Vec::new();
64 let should_test = test_summary.is_some();
65
66 if !should_test && !opts.stdin {
67 writeln!(&mut stdout, "{name}")?;
68 }
69
70 let source_code =
71 fs::read(path).with_context(|| format!("Error reading source file {}", path.display()))?;
72 let tree = parser.parse(&source_code, None).unwrap();
73
74 let start = Instant::now();
75 if opts.ordered_captures {
76 let mut captures = query_cursor.captures(&query, tree.root_node(), source_code.as_slice());
77 while let Some((mat, capture_index)) = captures.next() {
78 let capture = mat.captures[*capture_index];
79 let capture_name = &query.capture_names()[capture.index as usize];
80 if !opts.quiet && !should_test {
81 writeln!(
82 &mut stdout,
83 " pattern: {:>2}, capture: {} - {capture_name}, start: {}, end: {}, text: `{}`",
84 mat.pattern_index,
85 capture.index,
86 capture.node.start_position(),
87 capture.node.end_position(),
88 capture.node.utf8_text(&source_code).unwrap_or("")
89 )?;
90 }
91 if should_test {
92 results.push(query_testing::CaptureInfo {
93 name: (*capture_name).to_string(),
94 start: to_utf8_point(capture.node.start_position(), source_code.as_slice()),
95 end: to_utf8_point(capture.node.end_position(), source_code.as_slice()),
96 });
97 }
98 }
99 } else {
100 let mut matches = query_cursor.matches(&query, tree.root_node(), source_code.as_slice());
101 while let Some(m) = matches.next() {
102 if !opts.quiet && !should_test {
103 writeln!(&mut stdout, " pattern: {}", m.pattern_index)?;
104 }
105 for capture in m.captures {
106 let start = capture.node.start_position();
107 let end = capture.node.end_position();
108 let capture_name = &query.capture_names()[capture.index as usize];
109 if !opts.quiet && !should_test {
110 if end.row == start.row {
111 writeln!(
112 &mut stdout,
113 " capture: {} - {capture_name}, start: {start}, end: {end}, text: `{}`",
114 capture.index,
115 capture.node.utf8_text(&source_code).unwrap_or("")
116 )?;
117 } else {
118 writeln!(
119 &mut stdout,
120 " capture: {capture_name}, start: {start}, end: {end}",
121 )?;
122 }
123 }
124 if should_test {
125 results.push(query_testing::CaptureInfo {
126 name: (*capture_name).to_string(),
127 start: to_utf8_point(capture.node.start_position(), source_code.as_slice()),
128 end: to_utf8_point(capture.node.end_position(), source_code.as_slice()),
129 });
130 }
131 }
132 }
133 }
134 if query_cursor.did_exceed_match_limit() {
135 warn!("Query exceeded maximum number of in-progress captures!");
136 }
137 if should_test {
138 let path_name = if opts.stdin {
139 "stdin"
140 } else {
141 Path::new(&path).file_name().unwrap().to_str().unwrap()
142 };
143 let test_summary = test_summary.unwrap();
145 match query_testing::assert_expected_captures(&results, path, &mut parser, language) {
146 Ok(assertion_count) => {
147 test_summary.query_results.add_case(TestResult {
148 name: path_name.to_string(),
149 info: TestInfo::AssertionTest {
150 outcome: TestOutcome::AssertionPassed { assertion_count },
151 test_num: test_summary.test_num,
152 },
153 });
154 }
155 Err(e) => {
156 test_summary.query_results.add_case(TestResult {
157 name: path_name.to_string(),
158 info: TestInfo::AssertionTest {
159 outcome: TestOutcome::AssertionFailed {
160 error: e.to_string(),
161 },
162 test_num: test_summary.test_num,
163 },
164 });
165 return Err(e);
166 }
167 }
168 }
169 if opts.print_time {
170 writeln!(&mut stdout, "{:?}", start.elapsed())?;
171 }
172
173 Ok(())
174}