tree_sitter_cli/
query.rs

1use std::{
2    fs,
3    io::{self, Write},
4    ops::Range,
5    path::Path,
6    time::Instant,
7};
8
9use anyhow::{Context, Result};
10use log::warn;
11use streaming_iterator::StreamingIterator;
12use tree_sitter::{Language, Parser, Point, Query, QueryCursor};
13
14use crate::{
15    query_testing::{self, to_utf8_point},
16    test::{TestInfo, TestOutcome, TestResult, TestSummary},
17};
18
19#[derive(Default)]
20pub struct QueryFileOptions {
21    pub ordered_captures: bool,
22    pub byte_range: Option<Range<usize>>,
23    pub point_range: Option<Range<Point>>,
24    pub containing_byte_range: Option<Range<usize>>,
25    pub containing_point_range: Option<Range<Point>>,
26    pub quiet: bool,
27    pub print_time: bool,
28    pub stdin: bool,
29}
30
31pub fn query_file_at_path(
32    language: &Language,
33    path: &Path,
34    name: &str,
35    query_path: &Path,
36    opts: &QueryFileOptions,
37    test_summary: Option<&mut TestSummary>,
38) -> Result<()> {
39    let stdout = io::stdout();
40    let mut stdout = stdout.lock();
41
42    let query_source = fs::read_to_string(query_path)
43        .with_context(|| format!("Error reading query file {}", query_path.display()))?;
44    let query = Query::new(language, &query_source).with_context(|| "Query compilation failed")?;
45
46    let mut query_cursor = QueryCursor::new();
47    if let Some(ref range) = opts.byte_range {
48        query_cursor.set_byte_range(range.clone());
49    }
50    if let Some(ref range) = opts.point_range {
51        query_cursor.set_point_range(range.clone());
52    }
53    if let Some(ref range) = opts.containing_byte_range {
54        query_cursor.set_containing_byte_range(range.clone());
55    }
56    if let Some(ref range) = opts.containing_point_range {
57        query_cursor.set_containing_point_range(range.clone());
58    }
59
60    let mut parser = Parser::new();
61    parser.set_language(language)?;
62
63    let mut results = Vec::new();
64    let should_test = test_summary.is_some();
65
66    if !should_test && !opts.stdin {
67        writeln!(&mut stdout, "{name}")?;
68    }
69
70    let source_code =
71        fs::read(path).with_context(|| format!("Error reading source file {}", path.display()))?;
72    let tree = parser.parse(&source_code, None).unwrap();
73
74    let start = Instant::now();
75    if opts.ordered_captures {
76        let mut captures = query_cursor.captures(&query, tree.root_node(), source_code.as_slice());
77        while let Some((mat, capture_index)) = captures.next() {
78            let capture = mat.captures[*capture_index];
79            let capture_name = &query.capture_names()[capture.index as usize];
80            if !opts.quiet && !should_test {
81                writeln!(
82                        &mut stdout,
83                        "    pattern: {:>2}, capture: {} - {capture_name}, start: {}, end: {}, text: `{}`",
84                        mat.pattern_index,
85                        capture.index,
86                        capture.node.start_position(),
87                        capture.node.end_position(),
88                        capture.node.utf8_text(&source_code).unwrap_or("")
89                    )?;
90            }
91            if should_test {
92                results.push(query_testing::CaptureInfo {
93                    name: (*capture_name).to_string(),
94                    start: to_utf8_point(capture.node.start_position(), source_code.as_slice()),
95                    end: to_utf8_point(capture.node.end_position(), source_code.as_slice()),
96                });
97            }
98        }
99    } else {
100        let mut matches = query_cursor.matches(&query, tree.root_node(), source_code.as_slice());
101        while let Some(m) = matches.next() {
102            if !opts.quiet && !should_test {
103                writeln!(&mut stdout, "  pattern: {}", m.pattern_index)?;
104            }
105            for capture in m.captures {
106                let start = capture.node.start_position();
107                let end = capture.node.end_position();
108                let capture_name = &query.capture_names()[capture.index as usize];
109                if !opts.quiet && !should_test {
110                    if end.row == start.row {
111                        writeln!(
112                                &mut stdout,
113                                "    capture: {} - {capture_name}, start: {start}, end: {end}, text: `{}`",
114                                capture.index,
115                                capture.node.utf8_text(&source_code).unwrap_or("")
116                            )?;
117                    } else {
118                        writeln!(
119                            &mut stdout,
120                            "    capture: {capture_name}, start: {start}, end: {end}",
121                        )?;
122                    }
123                }
124                if should_test {
125                    results.push(query_testing::CaptureInfo {
126                        name: (*capture_name).to_string(),
127                        start: to_utf8_point(capture.node.start_position(), source_code.as_slice()),
128                        end: to_utf8_point(capture.node.end_position(), source_code.as_slice()),
129                    });
130                }
131            }
132        }
133    }
134    if query_cursor.did_exceed_match_limit() {
135        warn!("Query exceeded maximum number of in-progress captures!");
136    }
137    if should_test {
138        let path_name = if opts.stdin {
139            "stdin"
140        } else {
141            Path::new(&path).file_name().unwrap().to_str().unwrap()
142        };
143        // Invariant: `test_summary` will always be `Some` when `should_test` is true
144        let test_summary = test_summary.unwrap();
145        match query_testing::assert_expected_captures(&results, path, &mut parser, language) {
146            Ok(assertion_count) => {
147                test_summary.query_results.add_case(TestResult {
148                    name: path_name.to_string(),
149                    info: TestInfo::AssertionTest {
150                        outcome: TestOutcome::AssertionPassed { assertion_count },
151                        test_num: test_summary.test_num,
152                    },
153                });
154            }
155            Err(e) => {
156                test_summary.query_results.add_case(TestResult {
157                    name: path_name.to_string(),
158                    info: TestInfo::AssertionTest {
159                        outcome: TestOutcome::AssertionFailed {
160                            error: e.to_string(),
161                        },
162                        test_num: test_summary.test_num,
163                    },
164                });
165                return Err(e);
166            }
167        }
168    }
169    if opts.print_time {
170        writeln!(&mut stdout, "{:?}", start.elapsed())?;
171    }
172
173    Ok(())
174}