Skip to main content

tree_house_bindings/
parser.rs

1use std::cell::Cell;
2use std::ops::ControlFlow;
3use std::os::raw::c_void;
4use std::panic::{catch_unwind, AssertUnwindSafe};
5use std::ptr::NonNull;
6use std::time::{Duration, Instant};
7use std::{fmt, mem, ptr};
8
9use regex_cursor::Cursor;
10
11use crate::grammar::IncompatibleGrammarError;
12use crate::tree::{SyntaxTreeData, Tree};
13use crate::{Grammar, Input, IntoInput, Point, Range};
14
// Opaque stand-in for the parser object that lives on the C side of the FFI boundary. The enum
// has no variants, so a value of this type can never be constructed in Rust; it only ever
// appears behind `NonNull<ParserData>` pointers.
enum ParserData {}
17
// Per-thread slot holding at most one idle parser. `Parser::new` takes a parser out of this
// slot when one is available, and dropping a `Parser` hands its pointer back to the slot, so a
// thread that repeatedly creates and drops parsers can reuse a single allocation.
#[clippy::msrv = "1.76.0"]
thread_local! {
    static PARSER_CACHE: Cell<Option<RawParser>> = const { Cell::new(None) };
}
22
/// Owning handle for a parser that is sitting idle in `PARSER_CACHE`.
///
/// Dropping a `RawParser` frees the underlying parser via `ts_parser_delete`.
struct RawParser {
    // Pointer to the parser that this handle will free when dropped.
    ptr: NonNull<ParserData>,
}

impl Drop for RawParser {
    fn drop(&mut self) {
        // SAFETY: assumes `ptr` still refers to a live parser that nothing else will free or
        // use afterwards. The only construction site visible in this file is `Parser`'s `Drop`
        // impl, which gives up its own pointer when it builds a `RawParser`.
        unsafe { ts_parser_delete(self.ptr) }
    }
}
32
/// A stateful object that is used to produce a [`Tree`] based on some
/// source code.
pub struct Parser {
    // Pointer to the underlying parser; passed to every `ts_parser_*` call made on behalf of
    // this value.
    ptr: NonNull<ParserData>,
}
38
39impl Parser {
40    /// Create a new parser.
41    #[must_use]
42    pub fn new() -> Parser {
43        let ptr = match PARSER_CACHE.take() {
44            Some(cached) => {
45                let ptr = cached.ptr;
46                mem::forget(cached);
47                ptr
48            }
49            None => unsafe { ts_parser_new() },
50        };
51        Parser { ptr }
52    }
53
54    /// Set the language that the parser should use for parsing.
55    pub fn set_grammar(&mut self, grammar: Grammar) -> Result<(), IncompatibleGrammarError> {
56        if unsafe { ts_parser_set_language(self.ptr, grammar) } {
57            Ok(())
58        } else {
59            Err(IncompatibleGrammarError {
60                abi_version: grammar.abi_version(),
61            })
62        }
63    }
64
65    /// Set the ranges of text that the parser should include when parsing. By default, the parser
66    /// will always include entire documents. This function allows you to parse only a *portion*
67    /// of a document but still return a syntax tree whose ranges match up with the document as a
68    /// whole. You can also pass multiple disjoint ranges.
69    ///
70    /// `ranges` must be non-overlapping and sorted.
71    pub fn set_included_ranges(&mut self, ranges: &[Range]) -> Result<(), InvalidRangesError> {
72        // TODO: save some memory by only storing byte ranges and converting them to TS ranges in an
73        // internal buffer here. Points are not used by TS. Alternatively we can patch the TS C code
74        // to accept a simple pair (struct with two fields) of byte positions here instead of a full
75        // tree sitter range
76        let success = unsafe {
77            ts_parser_set_included_ranges(self.ptr, ranges.as_ptr(), ranges.len() as u32)
78        };
79        if success {
80            Ok(())
81        } else {
82            Err(InvalidRangesError)
83        }
84    }
85
86    #[must_use]
87    pub fn parse<I: Input>(
88        &mut self,
89        input: impl IntoInput<Input = I>,
90        old_tree: Option<&Tree>,
91    ) -> Option<Tree> {
92        let mut input = input.into_input();
93        unsafe extern "C" fn read<C: Input>(
94            payload: NonNull<c_void>,
95            byte_index: u32,
96            _position: Point,
97            bytes_read: *mut u32,
98        ) -> *const u8 {
99            let cursor = catch_unwind(AssertUnwindSafe(move || {
100                let input: &mut C = payload.cast().as_mut();
101                let cursor = input.cursor_at(byte_index);
102                let slice = cursor.chunk();
103                let offset: u32 = cursor.offset().try_into().unwrap();
104                let len: u32 = slice.len().try_into().unwrap();
105                (byte_index - offset, slice.as_ptr(), len)
106            }));
107            match cursor {
108                Ok((chunk_offset, ptr, len)) if chunk_offset < len => {
109                    *bytes_read = len - chunk_offset;
110                    ptr.add(chunk_offset as usize)
111                }
112                _ => {
113                    *bytes_read = 0;
114                    ptr::null()
115                }
116            }
117        }
118        let raw_input = ParserInputRaw {
119            payload: NonNull::from(&mut input).cast(),
120            read: read::<I>,
121            encoding: InputEncoding::Utf8,
122            decode: None,
123        };
124
125        unsafe {
126            let old_tree = old_tree.map(|tree| tree.as_raw());
127            ts_parser_parse(self.ptr, old_tree, raw_input).map(|raw| Tree::from_raw(raw))
128        }
129    }
130
131    /// Parse with a progress/cancellation callback. The callback receives the current
132    /// [`ParseState`] and returns [`ControlFlow::Break`] to cancel parsing.
133    #[must_use]
134    pub fn parse_with_options<I: Input>(
135        &mut self,
136        input: impl IntoInput<Input = I>,
137        old_tree: Option<&Tree>,
138        mut options: ParseOptions<'_>,
139    ) -> Option<Tree> {
140        let mut input = input.into_input();
141        unsafe extern "C" fn read<C: Input>(
142            payload: NonNull<c_void>,
143            byte_index: u32,
144            _position: Point,
145            bytes_read: *mut u32,
146        ) -> *const u8 {
147            let cursor = catch_unwind(AssertUnwindSafe(move || {
148                let input: &mut C = payload.cast().as_mut();
149                let cursor = input.cursor_at(byte_index);
150                let slice = cursor.chunk();
151                let offset: u32 = cursor.offset().try_into().unwrap();
152                let len: u32 = slice.len().try_into().unwrap();
153                (byte_index - offset, slice.as_ptr(), len)
154            }));
155            match cursor {
156                Ok((chunk_offset, ptr, len)) if chunk_offset < len => {
157                    *bytes_read = len - chunk_offset;
158                    ptr.add(chunk_offset as usize)
159                }
160                _ => {
161                    *bytes_read = 0;
162                    ptr::null()
163                }
164            }
165        }
166        let raw_input = ParserInputRaw {
167            payload: NonNull::from(&mut input).cast(),
168            read: read::<I>,
169            encoding: InputEncoding::Utf8,
170            decode: None,
171        };
172
173        // The payload is a thin pointer to the fat pointer stored in options.callback.
174        // The callback reconstructs the fat pointer and calls through it.
175        unsafe extern "C" fn progress_cb(raw_state: NonNull<RawParseState>) -> bool {
176            let raw_ref = raw_state.as_ref();
177            let cb: *mut &mut dyn FnMut(&ParseState) -> ControlFlow<()> =
178                raw_ref.payload.as_ptr().cast();
179            let public_state = ParseState {
180                current_byte_offset: raw_ref.current_byte_offset,
181                has_error: raw_ref.has_error,
182            };
183            (*cb)(&public_state).is_break()
184        }
185
186        let raw_options = RawParseOptions {
187            payload: unsafe {
188                Some(NonNull::new_unchecked(
189                    ptr::addr_of_mut!(options.callback).cast(),
190                ))
191            },
192            progress_callback: Some(progress_cb),
193        };
194
195        unsafe {
196            let old_tree = old_tree.map(|tree| tree.as_raw());
197            ts_parser_parse_with_options(self.ptr, old_tree, raw_input, raw_options)
198                .map(|raw| Tree::from_raw(raw))
199        }
200    }
201
202    /// Parse with a timeout. Returns `None` if parsing is not completed within `timeout`.
203    #[must_use]
204    pub fn parse_with_timeout<I: Input>(
205        &mut self,
206        input: impl IntoInput<Input = I>,
207        old_tree: Option<&Tree>,
208        timeout: Duration,
209    ) -> Option<Tree> {
210        let deadline = Instant::now() + timeout;
211        let mut check = |_: &ParseState| {
212            if Instant::now() >= deadline {
213                ControlFlow::Break(())
214            } else {
215                ControlFlow::Continue(())
216            }
217        };
218        self.parse_with_options(input, old_tree, ParseOptions::new(&mut check))
219    }
220}
221
222impl Default for Parser {
223    fn default() -> Self {
224        Self::new()
225    }
226}
227
// NOTE(review): these impls assert that a `Parser` may be moved to, and referenced from, other
// threads. Every method in this file that touches the underlying parser takes `&mut self`, so a
// shared `&Parser` exposes no operations. That the C parser itself has no thread affinity is
// assumed here — confirm against the C library.
unsafe impl Sync for Parser {}
unsafe impl Send for Parser {}
230
231impl Drop for Parser {
232    fn drop(&mut self) {
233        PARSER_CACHE.set(Some(RawParser { ptr: self.ptr }));
234    }
235}
236
/// State passed to the progress callback during parsing.
///
/// A fresh value is built for every callback invocation from the state reported by the
/// underlying parser.
#[derive(Debug, Clone, Copy)]
pub struct ParseState {
    /// Byte offset reported by the parser for this invocation — presumably how far into the
    /// input parsing has progressed.
    pub current_byte_offset: u32,
    /// Error flag reported by the parser for this invocation — presumably whether a syntax
    /// error has been encountered so far.
    pub has_error: bool,
}
243
244/// Options for [`Parser::parse_with_options`].
245///
246/// The callback receives the current [`ParseState`] and returns [`ControlFlow::Break`] to cancel parsing.
247pub struct ParseOptions<'a> {
248    callback: &'a mut dyn FnMut(&ParseState) -> ControlFlow<()>,
249}
250
251impl<'a> ParseOptions<'a> {
252    pub fn new(callback: &'a mut impl FnMut(&ParseState) -> ControlFlow<()>) -> ParseOptions<'a> {
253        ParseOptions { callback }
254    }
255}
256
/// An error returned by [`Parser::set_included_ranges`] when the supplied ranges are rejected,
/// i.e. they overlap or are not sorted.
#[derive(Debug, PartialEq, Eq)]
pub struct InvalidRangesError;

impl fmt::Display for InvalidRangesError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("include ranges overlap or are not sorted")
    }
}

impl std::error::Error for InvalidRangesError {}
268
/// Signature of the callback through which the parser pulls input text.
///
/// The implementations in this file treat `payload` as a pointer to the `Input` being parsed,
/// return a pointer to the text starting at `byte_index`, and store the number of readable
/// bytes in `*bytes_read`. When nothing can be read they return a null pointer and store 0.
/// `position` is ignored by those implementations.
type TreeSitterReadFn = unsafe extern "C" fn(
    payload: NonNull<c_void>,
    byte_index: u32,
    position: Point,
    bytes_read: *mut u32,
) -> *const u8;
275
/// A function that reads one code point from the given string, returning the number of bytes
/// consumed.
///
/// NOTE(review): `code_point` looks like an out-parameter that receives the decoded code point,
/// yet it is declared `*const i32` here. The C typedef presumably takes a mutable `int32_t *`;
/// confirm against the C header before implementing a decoder against this alias.
type DecodeInputFn =
    unsafe extern "C" fn(string: *const u8, length: u32, code_point: *const i32) -> u32;
280
/// Description of an input source in the layout expected by the C parser (`#[repr(C)]`, passed
/// by value to `ts_parser_parse` and `ts_parser_parse_with_options`).
#[repr(C)]
#[derive(Debug)]
pub struct ParserInputRaw {
    /// Opaque pointer handed back to `read` on every call; the parse methods in this file point
    /// it at the `Input` being parsed.
    pub payload: NonNull<c_void>,
    /// Callback the parser calls to fetch a chunk of text at a given byte index.
    pub read: TreeSitterReadFn,
    /// Encoding of the bytes returned by `read`.
    pub encoding: InputEncoding,
    /// A function to decode the input.
    ///
    /// This function is only used if the encoding is `InputEncoding::Custom`.
    pub decode: Option<DecodeInputFn>,
}
292
// `TSInputEncoding`
//
// `#[repr(u32)]` with implicit discriminants: the variants are numbered 0..=3 in declaration
// order, so the order must stay in sync with the C enum.
#[repr(u32)]
#[derive(Debug, Clone, Copy)]
pub enum InputEncoding {
    // The only encoding the parse methods in this file request.
    Utf8,
    Utf16LE,
    Utf16BE,
    // Input is decoded by the `decode` function of `ParserInputRaw`.
    Custom,
}
302
/// Parse state in the layout the C parser hands to the progress callback.
///
/// `Parser::parse_with_options` copies `current_byte_offset` and `has_error` into the public
/// `ParseState` before invoking the user callback.
#[repr(C)]
#[derive(Debug)]
struct RawParseState {
    /// The payload passed via `RawParseOptions`' `payload` field.
    payload: NonNull<c_void>,
    /// Copied verbatim into `ParseState::current_byte_offset`.
    current_byte_offset: u32,
    /// Copied verbatim into `ParseState::has_error`.
    has_error: bool,
}
311
/// A function that accepts the current parser state and returns `true` when the parse should be
/// cancelled.
///
/// The state's `payload` field carries the pointer supplied through `RawParseOptions`, which is
/// how the callback finds its way back to Rust-side data.
type ProgressCallback = unsafe extern "C" fn(state: NonNull<RawParseState>) -> bool;
315
/// Parse options in the layout expected by `ts_parser_parse_with_options` (`#[repr(C)]`, passed
/// by value). The derived `Default` leaves both fields `None`.
#[repr(C)]
#[derive(Debug, Default)]
struct RawParseOptions {
    /// Opaque pointer made available to the progress callback through `RawParseState::payload`.
    payload: Option<NonNull<c_void>>,
    /// Callback for progress reporting and cancellation, if any.
    progress_callback: Option<ProgressCallback>,
}
322
// Raw bindings to the parser functions of the C library. All of them take the parser as a
// non-null pointer; nullable tree pointers are expressed as `Option<NonNull<_>>`.
extern "C" {
    /// Create a new parser
    fn ts_parser_new() -> NonNull<ParserData>;
    /// Delete the parser, freeing all of the memory that it used.
    fn ts_parser_delete(parser: NonNull<ParserData>);
    /// Set the language that the parser should use for parsing. Returns a boolean indicating
    /// whether or not the language was successfully assigned. True means assignment
    /// succeeded. False means there was a version mismatch: the language was generated with
    /// an incompatible version of the Tree-sitter CLI. Check the language's version using
    /// `ts_language_version` and compare it to this library's `TREE_SITTER_LANGUAGE_VERSION`
    /// and `TREE_SITTER_MIN_COMPATIBLE_LANGUAGE_VERSION` constants.
    fn ts_parser_set_language(parser: NonNull<ParserData>, language: Grammar) -> bool;
    /// Set the ranges of text that the parser should include when parsing. By default, the parser
    /// will always include entire documents. This function allows you to parse only a *portion*
    /// of a document but still return a syntax tree whose ranges match up with the document as a
    /// whole. You can also pass multiple disjoint ranges. The second and third parameters specify
    /// the location and length of an array of ranges. The parser does *not* take ownership of
    /// these ranges; it copies the data, so it doesn't matter how these ranges are allocated.
    /// If `count` is zero, then the entire document will be parsed. Otherwise, the given ranges
    /// must be ordered from earliest to latest in the document, and they must not overlap. That
    /// is, the following must hold for all: `i < count - 1`: `ranges[i].end_byte <= ranges[i +
    /// 1].start_byte` If this requirement is not satisfied, the operation will fail, the ranges
    /// will not be assigned, and this function will return `false`. On success, this function
    /// returns `true`
    fn ts_parser_set_included_ranges(
        parser: NonNull<ParserData>,
        ranges: *const Range,
        count: u32,
    ) -> bool;

    /// Use the parser to parse the text supplied through `input` and create a syntax tree.
    ///
    /// `old_tree` is optional — presumably the tree from a previous parse of an earlier version
    /// of the same document, enabling incremental parsing; confirm against the C API docs.
    ///
    /// Returns `None` (a null pointer on the C side) when no tree was produced.
    fn ts_parser_parse(
        parser: NonNull<ParserData>,
        old_tree: Option<NonNull<SyntaxTreeData>>,
        input: ParserInputRaw,
    ) -> Option<NonNull<SyntaxTreeData>>;

    /// Use the parser to parse some source code and create a syntax tree, with some options.
    ///
    /// See `ts_parser_parse` for more details.
    ///
    /// See `TSParseOptions` for more details on the options.
    fn ts_parser_parse_with_options(
        parser: NonNull<ParserData>,
        old_tree: Option<NonNull<SyntaxTreeData>>,
        input: ParserInputRaw,
        parse_options: RawParseOptions,
    ) -> Option<NonNull<SyntaxTreeData>>;
}