tree_sitter_cli/
query.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
use std::{
    fs,
    io::{self, Write},
    ops::Range,
    path::Path,
    time::Instant,
};

use anyhow::{Context, Result};
use tree_sitter::{Language, Parser, Point, Query, QueryCursor};

use crate::query_testing::{self, to_utf8_point};

/// Runs the query stored at `query_path` against each file in `paths`, printing the
/// resulting captures to standard output.
///
/// For every path the file name is printed first, then the file is read, parsed with
/// `language`, and queried. The collected captures are optionally handed to
/// [`query_testing::assert_expected_captures`].
///
/// # Arguments
///
/// * `language` - Language used both to compile the query and to parse each file.
/// * `paths` - Source files to query, processed in order.
/// * `query_path` - File containing the query source.
/// * `ordered_captures` - When `true`, iterate with `QueryCursor::captures` and print one
///   line per capture; when `false`, iterate with `QueryCursor::matches` and print captures
///   grouped under a `pattern:` line per match.
/// * `byte_range` - If set, passed to `QueryCursor::set_byte_range` before querying.
/// * `point_range` - If set, passed to `QueryCursor::set_point_range` before querying.
/// * `should_test` - When `true`, the captures gathered for each file are passed to
///   `query_testing::assert_expected_captures` (presumably checking them against
///   expectations embedded in the file; see that function for the exact contract).
/// * `quiet` - Suppresses the per-match / per-capture output. The file name, the
///   match-limit warning and the timing line are still printed.
/// * `print_time` - When `true`, prints the time elapsed since querying of the file began.
///   Note that this is measured after the optional test assertion has run.
///
/// # Errors
///
/// Returns an error if the query file or a source file cannot be read, the query fails to
/// compile, the language cannot be assigned to the parser, the parser produces no tree for
/// a file, writing to standard output fails, or `assert_expected_captures` reports an error.
#[allow(clippy::too_many_arguments)]
pub fn query_files_at_paths(
    language: &Language,
    paths: Vec<String>,
    query_path: &Path,
    ordered_captures: bool,
    byte_range: Option<Range<usize>>,
    point_range: Option<Range<Point>>,
    should_test: bool,
    quiet: bool,
    print_time: bool,
) -> Result<()> {
    // Lock stdout once; many lines may be written per file.
    let stdout = io::stdout();
    let mut stdout = stdout.lock();

    let query_source = fs::read_to_string(query_path)
        .with_context(|| format!("Error reading query file {query_path:?}"))?;
    let query = Query::new(language, &query_source).with_context(|| "Query compilation failed")?;

    // The cursor (and any range restriction on it) is shared across all files.
    let mut query_cursor = QueryCursor::new();
    if let Some(range) = byte_range {
        query_cursor.set_byte_range(range);
    }
    if let Some(range) = point_range {
        query_cursor.set_point_range(range);
    }

    let mut parser = Parser::new();
    parser.set_language(language)?;

    for path in paths {
        let mut results = Vec::new();

        writeln!(&mut stdout, "{path}")?;

        let source_code =
            fs::read(&path).with_context(|| format!("Error reading source file {path:?}"))?;
        // `Parser::parse` returns an `Option`; report a missing tree as an error that names
        // the file instead of panicking.
        let tree = parser
            .parse(&source_code, None)
            .with_context(|| format!("Failed to parse source file {path:?}"))?;

        let start = Instant::now();
        if ordered_captures {
            for (mat, capture_index) in
                query_cursor.captures(&query, tree.root_node(), source_code.as_slice())
            {
                let capture = mat.captures[capture_index];
                let capture_name = &query.capture_names()[capture.index as usize];
                if !quiet {
                    writeln!(
                        &mut stdout,
                        "    pattern: {:>2}, capture: {} - {capture_name}, start: {}, end: {}, text: `{}`",
                        mat.pattern_index,
                        capture.index,
                        capture.node.start_position(),
                        capture.node.end_position(),
                        // Text that is not valid UTF-8 is printed as an empty string.
                        capture.node.utf8_text(&source_code).unwrap_or("")
                    )?;
                }
                results.push(query_testing::CaptureInfo {
                    name: (*capture_name).to_string(),
                    start: to_utf8_point(capture.node.start_position(), source_code.as_slice()),
                    end: to_utf8_point(capture.node.end_position(), source_code.as_slice()),
                });
            }
        } else {
            for m in query_cursor.matches(&query, tree.root_node(), source_code.as_slice()) {
                if !quiet {
                    writeln!(&mut stdout, "  pattern: {}", m.pattern_index)?;
                }
                for capture in m.captures {
                    let start_point = capture.node.start_position();
                    let end_point = capture.node.end_position();
                    let capture_name = &query.capture_names()[capture.index as usize];
                    if !quiet {
                        // Only single-line captures have their text (and index) printed;
                        // captures spanning several rows get the short form.
                        if end_point.row == start_point.row {
                            writeln!(
                                &mut stdout,
                                "    capture: {} - {capture_name}, start: {start_point}, end: {end_point}, text: `{}`",
                                capture.index,
                                capture.node.utf8_text(&source_code).unwrap_or("")
                            )?;
                        } else {
                            writeln!(
                                &mut stdout,
                                "    capture: {capture_name}, start: {start_point}, end: {end_point}",
                            )?;
                        }
                    }
                    results.push(query_testing::CaptureInfo {
                        name: (*capture_name).to_string(),
                        start: to_utf8_point(start_point, source_code.as_slice()),
                        end: to_utf8_point(end_point, source_code.as_slice()),
                    });
                }
            }
        }
        if query_cursor.did_exceed_match_limit() {
            writeln!(
                &mut stdout,
                "  WARNING: Query exceeded maximum number of in-progress captures!"
            )?;
        }
        if should_test {
            query_testing::assert_expected_captures(&results, path, &mut parser, language)?;
        }
        if print_time {
            writeln!(&mut stdout, "{:?}", start.elapsed())?;
        }
    }

    Ok(())
}