ad_editor/
ts.rs

1//! Support for tree-sitter incremental parsing, querying and highlighting of Buffers
2//!
3//! For a given language the user needs to provide a .so file containing the compiled
4//! tree-sitter parser and a highlights .scm file for driving the highlighting.
5//!
6//! Producing the token stream for a given buffer is handled in a multi-step process in
7//! order to support caching of tokens per-line and not baking in an explicit rendered
8//! representation (e.g. ANSI terminal escape codes) to the output.
9//!   - The file as a whole is tokenized via tree-sitter using a user provided query
10//!   - Tokens are obtained per-line using a [LineIter] which may be efficiently started
11//!     at a non-zero line offset when needed
12//!   - The [TokenIter] type returned by [LineIter] yields [RangeToken]s containing the
13//!     tags provided by the user in their query
14//!   - [TK_DEFAULT] tokens are injected between those identified by the user's query so
15//!     that the full token stream from a [TokenIter] will always contain the complete
16//!     text of the raw buffer line
17//!   - [RangeToken]s are tagged byte offsets within the parent [Buffer] which may be used
18//!     to extract and render sub-regions of text. In order to implement horizontal scrolling
19//!     and clamping of text based on the available screen columns, a UI implementation will
20//!     need to make use of [unicode_width::UnicodeWidthChar] in order to determine whether
21//!     none, part or all of any given token should be rendered.
22use crate::{
23    buffer::{GapBuffer, Slice, SliceIter},
24    dot::Range,
25};
26use libloading::{Library, Symbol};
27use std::{
28    cmp::{max, min, Ord, Ordering, PartialOrd},
29    collections::HashSet,
30    fmt, fs,
31    iter::{repeat_n, Peekable},
32    ops::{Deref, DerefMut},
33    path::Path,
34    slice,
35};
36use streaming_iterator::StreamingIterator;
37use tracing::{error, info};
38use tree_sitter::{self as ts, ffi::TSLanguage};
39
/// Token tag for text not covered by any capture in the user's highlights query
pub const TK_DEFAULT: &str = "default";
/// Token tag for text covered by the current dot (selection)
pub const TK_DOT: &str = "dot";
/// Token tag for a load selection range (see [LineIter]'s load_exec handling)
pub const TK_LOAD: &str = "load";
/// Token tag for an exec selection range (see [LineIter]'s load_exec handling)
pub const TK_EXEC: &str = "exec";
/// Custom query predicates we know how to evaluate (currently none); queries
/// using anything else are rejected in [Parser::new_tokenizer]
pub const SUPPORTED_PREDICATES: [&str; 0] = [];
45
/// Buffer level tree-sitter state for parsing and highlighting
#[derive(Debug)]
pub struct TsState {
    /// Current parse tree for the buffer, incrementally updated via [TsState::edit]
    tree: ts::Tree,
    /// Parser used to (re)parse the buffer contents
    p: Parser,
    /// Holds the compiled highlights query and the cached token ranges
    t: Tokenizer,
}
53
impl TsState {
    /// Construct tree-sitter state for `lang` by loading a compiled parser from
    /// `{so_dir}/{lang}.so` and a highlights query from
    /// `{query_dir}/{lang}/highlights.scm`, then parsing `gb` in full.
    ///
    /// Error values are intended as status messages to present to the user.
    pub fn try_new(
        lang: &str,
        so_dir: &str,
        query_dir: &str,
        gb: &GapBuffer,
    ) -> Result<Self, String> {
        let query_path = Path::new(query_dir).join(lang).join("highlights.scm");
        let query = match fs::read_to_string(query_path) {
            Ok(s) => s,
            Err(e) => return Err(format!("unable to read tree-sitter query file: {e}")),
        };

        let p = Parser::try_new(so_dir, lang)?;

        Self::try_new_explicit(p, &query, gb)
    }

    /// Test-only constructor that avoids dynamic loading by accepting a
    /// [ts::Language] provided directly by a crate.
    #[cfg(test)]
    pub(crate) fn try_new_from_language(
        lang_name: &str,
        lang: ts::Language,
        query: &str,
        gb: &GapBuffer,
    ) -> Result<Self, String> {
        let p = Parser::try_new_from_language(lang_name, lang)?;

        Self::try_new_explicit(p, query, gb)
    }

    /// Shared construction logic: parse the full buffer and run an initial
    /// tokenization pass over the resulting tree.
    fn try_new_explicit(mut p: Parser, query: &str, gb: &GapBuffer) -> Result<Self, String> {
        // Feed tree-sitter maximal contiguous chunks of the gap buffer rather
        // than copying the contents into a single String first
        let tree = p.parse_with(
            &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
            None,
        );

        match tree {
            Some(tree) => {
                let mut t = p.new_tokenizer(query)?;
                // from=0 with n_rows=usize::MAX - 1 tokenizes the whole file
                t.update(tree.root_node(), gb, 0, usize::MAX - 1);
                info!("TS loaded for {}", p.lang_name);

                Ok(Self { p, t, tree })
            }
            None => Err("failed to parse file".to_owned()),
        }
    }

    /// Record an edit (given as char offsets into `gb`) with tree-sitter and
    /// incrementally re-parse. The cached token ranges are cleared so they get
    /// recomputed on the next call to [TsState::update].
    pub fn edit(&mut self, ch_start: usize, ch_old_end: usize, ch_new_end: usize, gb: &GapBuffer) {
        self.tree.edit(&ts::InputEdit {
            start_byte: gb.char_to_byte(ch_start),
            old_end_byte: gb.char_to_byte(ch_old_end),
            new_end_byte: gb.char_to_byte(ch_new_end),
            // See https://github.com/tree-sitter/tree-sitter/discussions/1793 for why this is OK
            start_position: ts::Point::new(0, 0),
            old_end_position: ts::Point::new(0, 0),
            new_end_position: ts::Point::new(0, 0),
        });

        let new_tree = self.p.parse_with(
            &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
            Some(&self.tree),
        );

        if let Some(tree) = new_tree {
            // TODO: it might be worth looking at self.tree.changed_ranges(&tree) to optimise
            // being able to only tokenize regions we're missing
            self.tree = tree;
        }

        self.t.ranges.clear();
    }

    /// Ensure the token cache covers lines `from..=from + n_rows`, re-running
    /// the tokenizer only when the cached ranges don't already span that
    /// byte region.
    pub fn update(&mut self, gb: &GapBuffer, from: usize, n_rows: usize) {
        let byte_from = gb.char_to_byte(gb.line_to_char(from));
        let byte_to = if from + n_rows + 1 < gb.len_lines() {
            gb.char_to_byte(gb.line_to_char(from + n_rows + 1))
        } else {
            gb.len()
        };
        // Need tokens if the cache is empty or doesn't span the requested region
        let need_tokens = self.t.ranges.is_empty()
            || self.t.ranges.first().unwrap().r.from > byte_from
            || self.t.ranges.last().unwrap().r.to < byte_to;

        if need_tokens {
            self.t.update(self.tree.root_node(), gb, from, n_rows);
        }
    }

    /// Iterate per-line [TokenIter]s starting at `line`, overlaying the dot
    /// and any load/exec selection onto the cached syntax tokens.
    #[inline]
    pub fn iter_tokenized_lines_from(
        &self,
        line: usize,
        gb: &GapBuffer,
        dot_range: Range,
        load_exec_range: Option<(bool, Range)>,
    ) -> LineIter<'_> {
        self.t
            .iter_tokenized_lines_from(line, gb, dot_range, load_exec_range)
    }

    /// Render the current parse tree as an indented s-expression (for debugging).
    pub fn pretty_print_tree(&self) -> String {
        let sexp = self.tree.root_node().to_sexp();
        let mut buf = String::with_capacity(sexp.len()); // better starting point than default
        let mut has_field = false;
        let mut indent = 0;

        // Splitting on ' ' and ')' means: an empty segment marks a closing paren,
        // "(name" opens a node and "field:" labels the node that follows it
        for s in sexp.split([' ', ')']) {
            if s.is_empty() {
                indent -= 1;
                buf.push(')');
            } else if s.starts_with('(') {
                if has_field {
                    // this node stays on the "field: " line already started below
                    has_field = false;
                } else {
                    if indent > 0 {
                        buf.push('\n');
                        buf.extend(repeat_n(' ', indent * 2));
                    }
                    indent += 1;
                }

                buf.push_str(s); // "(node_name"
            } else if s.ends_with(':') {
                buf.push('\n');
                buf.extend(repeat_n(' ', indent * 2));
                buf.push_str(s); // "field:"
                buf.push(' ');
                has_field = true;
                indent += 1;
            }
        }

        buf
    }
}
190
191// Required for us to be able to pass GapBuffers to the tree-sitter API
192impl<'a> ts::TextProvider<&'a [u8]> for &'a GapBuffer {
193    type I = SliceIter<'a>;
194
195    fn text(&mut self, node: ts::Node<'_>) -> Self::I {
196        let ts::Range {
197            start_byte,
198            end_byte,
199            ..
200        } = node.range();
201        let char_from = self.raw_byte_to_char(self.byte_to_raw_byte(start_byte));
202        let char_to = self.raw_byte_to_char(self.byte_to_raw_byte(end_byte));
203
204        self.slice(char_from, char_to).slice_iter()
205    }
206}
207
/// A dynamically loaded tree-sitter parser backed by an on disk .so file
pub struct Parser {
    /// Name of the language this parser handles (used in logs and error messages)
    lang_name: String,
    /// The underlying tree-sitter parser
    inner: ts::Parser,
    /// The language grammar, retained for compiling queries in [Parser::new_tokenizer]
    lang: ts::Language,
    // Need to prevent drop while the parser is in use
    // Stored as an Option to allow for crate-based parsers that are not backed by a .so file
    _lib: Option<Library>,
}
217
// Deref/DerefMut let a Parser be used directly anywhere a ts::Parser is expected
impl Deref for Parser {
    type Target = ts::Parser;

    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}

impl DerefMut for Parser {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}

// Manual impl: ts::Parser and Library don't provide useful Debug output
impl fmt::Debug for Parser {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "Parser({})", self.lang_name)
    }
}
237
impl Parser {
    /// Load a tree-sitter parser from `{so_dir}/{lang_name}.so`, resolving the
    /// conventional `tree_sitter_{lang_name}` entry point.
    ///
    /// Error values returned by this function are intended as status messages to be
    /// presented to the user.
    pub fn try_new<P: AsRef<Path>>(so_dir: P, lang_name: &str) -> Result<Self, String> {
        let p = so_dir.as_ref().join(format!("{lang_name}.so"));
        let lang_fn = format!("tree_sitter_{lang_name}");

        // SAFETY: if the library loads and contains the target symbol we expect the
        //         given .so file to be a valid tree-sitter parser
        unsafe {
            let lib = Library::new(p).map_err(|e| e.to_string())?;
            let func: Symbol<'_, unsafe extern "C" fn() -> *const TSLanguage> =
                lib.get(lang_fn.as_bytes()).map_err(|e| e.to_string())?;

            let lang = ts::Language::from_raw(func());
            // Reject grammars older than the minimum ABI this crate supports
            if lang.version() < ts::MIN_COMPATIBLE_LANGUAGE_VERSION {
                return Err(format!(
                    "incompatible .so tree-sitter parser version: {} < {}",
                    lang.version(),
                    ts::MIN_COMPATIBLE_LANGUAGE_VERSION
                ));
            }

            let mut inner = ts::Parser::new();
            inner.set_language(&lang).map_err(|e| e.to_string())?;

            Ok(Self {
                lang_name: lang_name.to_owned(),
                inner,
                lang,
                // keep the library loaded for as long as this Parser lives
                _lib: Some(lib),
            })
        }
    }

    /// Construct a new parser directly from a ts::Language provided by a crate
    #[cfg(test)]
    fn try_new_from_language(lang_name: &str, lang: ts::Language) -> Result<Self, String> {
        let mut inner = ts::Parser::new();
        inner.set_language(&lang).map_err(|e| e.to_string())?;

        Ok(Self {
            lang_name: lang_name.to_owned(),
            inner,
            lang,
            _lib: None,
        })
    }

    /// Compile `query` against this parser's language and wrap it in a [Tokenizer].
    ///
    /// Fails if the query is malformed or uses custom predicates we don't support.
    pub fn new_tokenizer(&self, query: &str) -> Result<Tokenizer, String> {
        let q = ts::Query::new(&self.lang, query).map_err(|e| format!("{e:?}"))?;
        let cur = ts::QueryCursor::new();

        // If a query has been copied from another text editor then there is a chance that
        // it makes use of custom predicates that we don't know how to handle. The highlights
        // as a whole won't behave as the user expects in this instance so we error out the
        // setup of syntax-highlighting as a whole in this case and log an error
        let mut unsupported_predicates = HashSet::new();
        for i in 0..q.pattern_count() {
            for p in q.general_predicates(i) {
                if !SUPPORTED_PREDICATES.contains(&p.operator.as_ref()) {
                    unsupported_predicates.insert(p.operator.clone());
                }
            }
        }

        if !unsupported_predicates.is_empty() {
            error!("Unsupported custom tree-sitter predicates found: {unsupported_predicates:?}");
            info!("Supported custom tree-sitter predicates: {SUPPORTED_PREDICATES:?}");
            info!("Please modify the highlights.scm file to remove the unsupported predicates");

            return Err(format!(
                "{} highlights query contained unsupported custom predicates",
                self.lang_name
            ));
        }

        Ok(Tokenizer {
            q,
            cur,
            ranges: Vec::new(),
        })
    }
}
322
/// Produces highlight tokens for a buffer by running a compiled tree-sitter
/// query over its parse tree and caching the resulting byte ranges
pub struct Tokenizer {
    // Tree-sitter state
    q: ts::Query,
    cur: ts::QueryCursor,
    // Cache of computed syntax tokens for passing to LineIter
    ranges: Vec<SyntaxRange>,
}
330
331impl fmt::Debug for Tokenizer {
332    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333        write!(f, "Tokenizer")
334    }
335}
336
impl Tokenizer {
    /// Run the highlight query against `root` for lines `from..=from + n_rows`
    /// and merge the resulting captures into the cached, sorted set of ranges.
    pub fn update(&mut self, root: ts::Node<'_>, gb: &GapBuffer, from: usize, n_rows: usize) {
        // Restrict the cursor to the requested line region before querying
        self.cur.set_point_range(
            ts::Point {
                row: from,
                column: 0,
            }..ts::Point {
                row: from + n_rows + 1,
                column: 0,
            },
        );

        // This is a streaming-iterator not an iterator, hence the odd while-let that follows
        let mut it = self.cur.captures(&self.q, root, gb);

        while let Some((m, idx)) = it.next() {
            let cap = m.captures[*idx];
            let r = ByteRange::from(cap.node.range());
            if let Some(prev) = self.ranges.last_mut() {
                if r == prev.r {
                    // preferring the last capture found so that precedence ordering
                    // in query files matches Neovim & the treesitter-cli
                    prev.cap_idx = Some(cap.index as usize);
                    continue;
                } else if r.from < prev.r.to && prev.r.from < r.to {
                    // partial overlap with the previous range: keep the earlier capture
                    continue;
                }
            }
            self.ranges.push(SyntaxRange {
                r,
                cap_idx: Some(cap.index as usize),
            });
        }

        // Keep the cache sorted and free of exact duplicates for LineIter
        self.ranges.sort_unstable();
        self.ranges.dedup();
    }

    /// Iterate per-line [TokenIter]s starting at `line`, overlaying the dot
    /// and any load/exec selection onto the cached syntax tokens.
    #[inline]
    pub fn iter_tokenized_lines_from(
        &self,
        line: usize,
        gb: &GapBuffer,
        dot_range: Range,
        load_exec_range: Option<(bool, Range)>,
    ) -> LineIter<'_> {
        LineIter::new(
            line,
            gb,
            dot_range,
            load_exec_range,
            self.q.capture_names(),
            &self.ranges,
        )
    }

    /// Test helper: materialise the cached ranges as tagged tokens.
    #[cfg(test)]
    fn range_tokens(&self) -> Vec<RangeToken<'_>> {
        let names = self.q.capture_names();

        self.ranges
            .iter()
            .map(|sr| RangeToken {
                tag: sr.cap_idx.map(|i| names[i]).unwrap_or(TK_DEFAULT),
                r: sr.r,
            })
            .collect()
    }
}
406
/// Byte offsets within a Buffer
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct ByteRange {
    /// starting byte offset of the range
    pub(crate) from: usize,
    /// ending byte offset of the range
    pub(crate) to: usize,
}
413
414impl ByteRange {
415    fn from_range(r: Range, gb: &GapBuffer) -> Self {
416        let Range { start, end, .. } = r;
417
418        Self {
419            from: gb.char_to_byte(start.idx),
420            to: gb.char_to_byte(end.idx),
421        }
422    }
423
424    #[inline]
425    fn intersects(&self, start_byte: usize, end_byte: usize) -> bool {
426        self.from <= end_byte && start_byte <= self.to
427    }
428
429    #[inline]
430    fn contains(&self, start_byte: usize, end_byte: usize) -> bool {
431        self.from <= start_byte && self.to >= end_byte
432    }
433
434    /// Convert this [ByteRange] into a [RangeToken] if it intersects with the provided
435    /// start and end point.
436    fn try_as_token<'a>(
437        &self,
438        ty: &'a str,
439        start_byte: usize,
440        end_byte: usize,
441    ) -> Option<RangeToken<'a>> {
442        if self.intersects(start_byte, end_byte) {
443            Some(RangeToken {
444                tag: ty,
445                r: ByteRange {
446                    from: max(self.from, start_byte),
447                    to: min(self.to, end_byte),
448                },
449            })
450        } else {
451            None
452        }
453    }
454}
455
456impl From<ts::Range> for ByteRange {
457    fn from(r: ts::Range) -> Self {
458        Self {
459            from: r.start_byte,
460            to: r.end_byte,
461        }
462    }
463}
464
/// A tagged [ByteRange] denoting which tree-sitter capture index from our scheme query
/// matched this range within the buffer. A cap_idx of [None] indicates that this is a
/// default range for the purposes of syntax highlighting
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) struct SyntaxRange {
    /// index into the query's capture names, or None for a default range
    cap_idx: Option<usize>,
    /// the byte region of the buffer this capture covers
    r: ByteRange,
}
473
/// A tagged region of buffer text ready for rendering: the tag is either a
/// capture name from the user's query or one of the TK_* constants
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RangeToken<'a> {
    /// how this token should be styled when rendered
    pub(crate) tag: &'a str,
    /// byte offsets within the parent buffer
    pub(crate) r: ByteRange,
}
479
480impl RangeToken<'_> {
481    pub fn tag(&self) -> &str {
482        self.tag
483    }
484
485    pub fn as_slice<'a>(&self, gb: &'a GapBuffer) -> Slice<'a> {
486        gb.slice_from_byte_offsets(self.r.from, self.r.to)
487    }
488
489    #[inline]
490    fn split(self, at: usize) -> (Self, Self) {
491        (
492            RangeToken {
493                tag: self.tag,
494                r: ByteRange {
495                    from: self.r.from,
496                    to: at,
497                },
498            },
499            RangeToken {
500                tag: self.tag,
501                r: ByteRange {
502                    from: at,
503                    to: self.r.to,
504                },
505            },
506        )
507    }
508}
509
// Ordering is by byte range only so that Tokenizer::update can keep its cache
// sorted by position.
// NOTE(review): cap_idx participates in the derived PartialEq/Eq but not in
// Ord, so two ranges can compare Equal while being != — confirm this is
// intentional for sort_unstable/dedup
impl PartialOrd for SyntaxRange {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SyntaxRange {
    fn cmp(&self, other: &Self) -> Ordering {
        self.r.cmp(&other.r)
    }
}
521
/// Yield sub-iterators of tokens per-line in a file.
///
/// Any given [SyntaxRange] coming from the underlying [Tokenizer] may be
/// used by multiple [TokenIter]s coming from this iterator if the range
/// in question spans multiple lines
#[derive(Debug)]
pub struct LineIter<'a> {
    /// capture names to be used as the token types
    names: &'a [&'a str],
    /// byte offsets for the position of each newline in the input
    line_endings: Vec<usize>,
    /// full set of syntax ranges for the input
    ranges: &'a [SyntaxRange],
    /// byte offset of the start of the next line to yield
    start_byte: usize,
    /// the next line to yield
    line: usize,
    /// the dot (selection) as byte offsets
    dot_range: ByteRange,
    /// optional selection overlay: true tags it TK_LOAD, false TK_EXEC
    load_exec_range: Option<(bool, ByteRange)>,
}
541
542impl<'a> LineIter<'a> {
543    pub(crate) fn new(
544        line: usize,
545        gb: &GapBuffer,
546        dot_range: Range,
547        load_exec_range: Option<(bool, Range)>,
548        names: &'a [&'a str],
549        ranges: &'a [SyntaxRange],
550    ) -> LineIter<'a> {
551        let line_endings = gb.byte_line_endings();
552        let start_byte = if line == 0 {
553            0
554        } else {
555            line_endings[line - 1] + 1
556        };
557
558        let dot_range = ByteRange::from_range(dot_range, gb);
559        let load_exec_range =
560            load_exec_range.map(|(is_load, r)| (is_load, ByteRange::from_range(r, gb)));
561
562        LineIter {
563            names,
564            line_endings,
565            ranges,
566            start_byte,
567            line,
568            dot_range,
569            load_exec_range,
570        }
571    }
572}
573
impl<'a> Iterator for LineIter<'a> {
    type Item = TokenIter<'a>;

    /// Yield a [TokenIter] for the next line, or None once every recorded
    /// line ending has been consumed.
    fn next(&mut self) -> Option<Self::Item> {
        if self.line == self.line_endings.len() {
            return None;
        }

        let start_byte = self.start_byte;
        let end_byte = self.line_endings[self.line];

        self.line += 1;
        self.start_byte = end_byte + 1;

        // Determine tokens required for the next line
        let held: Option<RangeToken<'_>>;
        let ranges: Peekable<slice::Iter<'_, SyntaxRange>>;

        // Clamp the dot and load/exec selections to this line (None if outside it)
        let dot_range = self.dot_range.try_as_token(TK_DOT, start_byte, end_byte);
        let load_exec_range = self.load_exec_range.and_then(|(is_load, br)| {
            let ty = if is_load { TK_LOAD } else { TK_EXEC };
            br.try_as_token(ty, start_byte, end_byte)
        });

        loop {
            match self.ranges.first() {
                // Advance to the next range
                Some(sr) if sr.r.to < start_byte => {
                    self.ranges = &self.ranges[1..];
                }

                // End of known tokens so everything else is just TK_DEFAULT
                None => {
                    held = Some(RangeToken {
                        tag: TK_DEFAULT,
                        r: ByteRange {
                            from: start_byte,
                            to: end_byte,
                        },
                    });
                    ranges = [].iter().peekable();
                    break;
                }

                // The next range is beyond this line
                Some(sr) if sr.r.from >= end_byte => {
                    held = Some(RangeToken {
                        tag: TK_DEFAULT,
                        r: ByteRange {
                            from: start_byte,
                            to: end_byte,
                        },
                    });
                    ranges = [].iter().peekable();
                    break;
                }

                // The next range fully contains the line
                Some(sr) if sr.r.contains(start_byte, end_byte) => {
                    held = Some(RangeToken {
                        tag: sr.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
                        r: ByteRange {
                            from: start_byte,
                            to: end_byte,
                        },
                    });
                    ranges = [].iter().peekable();
                    break;
                }

                // The next range starts at the beginning of the line or ends within the line
                Some(sr) => {
                    assert!(sr.r.from < end_byte);
                    if sr.r.from > start_byte {
                        // Inject a default token covering the gap before the range starts
                        held = Some(RangeToken {
                            tag: TK_DEFAULT,
                            r: ByteRange {
                                from: start_byte,
                                to: sr.r.from,
                            },
                        });
                    } else {
                        held = None;
                    }
                    ranges = self.ranges.iter().peekable();
                    break;
                }
            }
        }

        Some(TokenIter {
            start_byte,
            end_byte,
            names: self.names,
            ranges,
            held,
            dot_held: None,
            dot_range,
            load_exec_range,
        })
    }
}
676
/// Local shorthand to keep the [Held] variants readable
type Rt<'a> = RangeToken<'a>;

/// Up to five pre-computed tokens queued for emission by a [TokenIter].
///
/// Overlaying dot and load/exec selections onto a syntax token can split it
/// into multiple pieces; five is the maximum (see the unreachable! in join)
#[derive(Debug, PartialEq, Eq)]
enum Held<'a> {
    One(Rt<'a>),
    Two(Rt<'a>, Rt<'a>),
    Three(Rt<'a>, Rt<'a>, Rt<'a>),
    Four(Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>),
    Five(Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>),
}
687
impl Held<'_> {
    /// The byte span covered by the held tokens: start of the first, end of the last.
    fn byte_from_to(&self) -> (usize, usize) {
        match self {
            Held::One(a) => (a.r.from, a.r.to),
            Held::Two(a, b) => (a.r.from, b.r.to),
            Held::Three(a, _, b) => (a.r.from, b.r.to),
            Held::Four(a, _, _, b) => (a.r.from, b.r.to),
            Held::Five(a, _, _, _, b) => (a.r.from, b.r.to),
        }
    }

    /// Split the held tokens at byte offset `at`, splitting the token that
    /// contains `at` in two when `at` is not on a token boundary.
    ///
    /// Only ever called with 1-3 tokens held; Four/Five are unreachable.
    fn split(self, at: usize) -> (Self, Self) {
        use Held::*;

        match self {
            One(a) => {
                let (l, r) = a.split(at);
                (One(l), One(r))
            }

            Two(a, b) => {
                if at == a.r.to {
                    // split falls exactly on the boundary between the two tokens
                    (One(a), One(b))
                } else if a.r.contains(at, at) {
                    let (l, r) = a.split(at);
                    (One(l), Two(r, b))
                } else {
                    let (l, r) = b.split(at);
                    (Two(a, l), One(r))
                }
            }

            Three(a, b, c) => {
                if at == a.r.to {
                    (One(a), Two(b, c))
                } else if at == b.r.to {
                    (Two(a, b), One(c))
                } else if a.r.contains(at, at) {
                    let (l, r) = a.split(at);
                    (One(l), Three(r, b, c))
                } else if b.r.contains(at, at) {
                    let (l, r) = b.split(at);
                    (Two(a, l), Two(r, c))
                } else {
                    let (l, r) = c.split(at);
                    (Three(a, b, l), One(r))
                }
            }

            Four(_, _, _, _) => unreachable!("only called for 1-3"),
            Five(_, _, _, _, _) => unreachable!("only called for 1-3"),
        }
    }

    /// Concatenate two groups of held tokens; callers never combine more
    /// than five tokens in total.
    fn join(self, other: Self) -> Self {
        use Held::*;

        match (self, other) {
            (One(a), One(b)) => Two(a, b),
            (One(a), Two(b, c)) => Three(a, b, c),
            (One(a), Three(b, c, d)) => Four(a, b, c, d),
            (One(a), Four(b, c, d, e)) => Five(a, b, c, d, e),

            (Two(a, b), One(c)) => Three(a, b, c),
            (Two(a, b), Two(c, d)) => Four(a, b, c, d),
            (Two(a, b), Three(c, d, e)) => Five(a, b, c, d, e),

            (Three(a, b, c), One(d)) => Four(a, b, c, d),
            (Three(a, b, c), Two(d, e)) => Five(a, b, c, d, e),

            (Four(a, b, c, d), One(e)) => Five(a, b, c, d, e),

            _ => unreachable!("only have a max of 5 held"),
        }
    }
}
764
/// An iterator of tokens for a single line.
///
/// "default" ranges will be injected in-between the known syntax regions
/// so a consumer may treat the output of this iterator as a continuous,
/// non-overlapping set of sub-regions spanning a single line within a
/// given buffer.
#[derive(Debug)]
pub struct TokenIter<'a> {
    /// byte offset for the start of this line
    start_byte: usize,
    /// byte offset for the end of this line
    end_byte: usize,
    /// Capture names to be used as the token types
    names: &'a [&'a str],
    /// The set of ranges applicable to this line
    ranges: Peekable<slice::Iter<'a, SyntaxRange>>,
    /// When yielding a dot range we may end up partially consuming
    /// the following range so we need to stash a Token for yielding
    /// on the next call to .next()
    held: Option<RangeToken<'a>>,
    /// Tokens pre-computed by overlaying a selection, emitted before anything else
    dot_held: Option<Held<'a>>,
    /// The dot (selection) clamped to this line, if it touches it
    dot_range: Option<RangeToken<'a>>,
    /// The load/exec selection clamped to this line, if it touches it
    load_exec_range: Option<RangeToken<'a>>,
}
789
impl<'a> TokenIter<'a> {
    /// The next token ignoring any dot/load/exec selections: emits a stashed
    /// held token first, then consumes syntax ranges, stashing TK_DEFAULT
    /// fillers for the gaps between them and truncating at the end of the line.
    fn next_without_selections(&mut self) -> Option<RangeToken<'a>> {
        let held = self.held.take();
        if held.is_some() {
            return held;
        }

        let next = self.ranges.next()?;

        if next.r.from > self.end_byte {
            // Next available token is after this line and any 'default' held token will
            // have been emitted above before we hit this point, so we're done.
            return None;
        } else if next.r.to >= self.end_byte {
            // Last token runs until at least the end of this line so we just need to truncate
            // to the end of the line and ensure that the following call to .next() returns None.
            self.ranges = [].iter().peekable();

            return Some(RangeToken {
                tag: next.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
                r: ByteRange {
                    from: max(next.r.from, self.start_byte),
                    to: self.end_byte,
                },
            });
        }

        // Work out whether a default filler needs to be stashed for the gap after `next`
        match self.ranges.peek() {
            // Following range is beyond this line: fill to the end of the line
            Some(sr) if sr.r.from > self.end_byte => {
                self.ranges = [].iter().peekable();

                self.held = Some(RangeToken {
                    tag: TK_DEFAULT,
                    r: ByteRange {
                        from: next.r.to,
                        to: self.end_byte,
                    },
                });
            }

            // Gap between this range and the following one: fill it
            Some(sr) if sr.r.from > next.r.to => {
                self.held = Some(RangeToken {
                    tag: TK_DEFAULT,
                    r: ByteRange {
                        from: next.r.to,
                        to: sr.r.from,
                    },
                });
            }

            // No more ranges: fill from the end of this one to the end of the line
            None if next.r.to < self.end_byte => {
                self.held = Some(RangeToken {
                    tag: TK_DEFAULT,
                    r: ByteRange {
                        from: next.r.to,
                        to: self.end_byte,
                    },
                });
            }

            _ => (),
        }

        Some(RangeToken {
            tag: next.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
            r: ByteRange {
                from: max(next.r.from, self.start_byte),
                to: next.r.to,
            },
        })
    }

    /// Overlay the selection token `rt` onto `held`, splitting and re-joining
    /// pieces as needed so the selection replaces the region it covers.
    /// Callers guarantee rt does not start before `held` (the Less arm is
    /// unreachable).
    fn update_held(&mut self, mut held: Held<'a>, rt: RangeToken<'a>) -> Held<'a> {
        let (self_from, self_to) = held.byte_from_to();
        let (from, to) = (rt.r.from, rt.r.to);

        match (from.cmp(&self_from), to.cmp(&self_to)) {
            (Ordering::Less, _) => unreachable!("only called when rt >= self"),

            (Ordering::Equal, Ordering::Less) => {
                // hold rt then remaining of held
                let (_, r) = held.split(to);
                held = Held::One(rt).join(r);
            }

            (Ordering::Greater, Ordering::Less) => {
                // hold held up to rt, rt & held from rt
                let (l, r) = held.split(from);
                let (_, r) = r.split(to);
                held = l.join(Held::One(rt)).join(r);
            }

            (Ordering::Equal, Ordering::Equal) => {
                // replace held with rt
                held = Held::One(rt);
            }

            (Ordering::Greater, Ordering::Equal) => {
                // hold held to rt & rt
                let (l, _) = held.split(from);
                held = l.join(Held::One(rt));
            }

            (Ordering::Equal, Ordering::Greater) => {
                // hold rt, consume to find other held tokens (if any)
                held = self.find_end_of_selection(Held::One(rt), to);
            }

            (Ordering::Greater, Ordering::Greater) => {
                // hold held to rt & rt, consume to find other held tokens (if any)
                let (l, _) = held.split(from);
                held = self.find_end_of_selection(l.join(Held::One(rt)), to);
            }
        }

        held
    }

    /// Consume underlying tokens until one extends past `to`, appending the
    /// part of that token that remains after the selection ends (if any).
    fn find_end_of_selection(&mut self, mut held: Held<'a>, to: usize) -> Held<'a> {
        loop {
            let mut next = match self.next_without_selections() {
                None => break,
                Some(next) => next,
            };
            if next.r.to <= to {
                continue; // token is entirely within rt
            }
            next.r.from = to;
            held = held.join(Held::One(next));
            break;
        }

        held
    }

    /// Emit the first token queued in `dot_held`, shifting the remainder down.
    fn pop(&mut self) -> Option<RangeToken<'a>> {
        match self.dot_held {
            None => None,
            Some(Held::One(a)) => {
                self.dot_held = None;
                Some(a)
            }
            Some(Held::Two(a, b)) => {
                self.dot_held = Some(Held::One(b));
                Some(a)
            }
            Some(Held::Three(a, b, c)) => {
                self.dot_held = Some(Held::Two(b, c));
                Some(a)
            }
            Some(Held::Four(a, b, c, d)) => {
                self.dot_held = Some(Held::Three(b, c, d));
                Some(a)
            }
            Some(Held::Five(a, b, c, d, e)) => {
                self.dot_held = Some(Held::Four(b, c, d, e));
                Some(a)
            }
        }
    }
}
951
impl<'a> Iterator for TokenIter<'a> {
    type Item = RangeToken<'a>;

    fn next(&mut self) -> Option<Self::Item> {
        // Emit pre-computed held tokens first
        let next = self.pop();
        if next.is_some() {
            return next;
        }

        // Determine the next token we would emit in the absence of any user selections and then
        // apply the selections in priority order:
        //   - dot overwrites original syntax highlighting
        //   - load/exec overwrite dot
        #[inline]
        fn intersects(opt: &Option<RangeToken<'_>>, from: usize, to: usize) -> bool {
            opt.as_ref()
                .map(|rt| rt.r.intersects(from, to))
                .unwrap_or(false)
        }

        let next = self.next_without_selections()?;
        let (from, to) = (next.r.from, next.r.to);
        let mut held = Held::One(next);

        if intersects(&self.dot_range, from, to) {
            let r = self.dot_range.take().unwrap();
            held = self.update_held(held, r);
        }

        // update_held may have widened the held range, so recompute the extents before
        // checking the load/exec selection against it
        let (from, to) = held.byte_from_to();
        if intersects(&self.load_exec_range, from, to) {
            let r = self.load_exec_range.take().unwrap();
            held = self.update_held(held, r);
        }

        if let Held::One(rt) = held {
            Some(rt) // dot_held is None so just return the token directly
        } else {
            self.dot_held = Some(held);
            self.pop()
        }
    }
}
996
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        buffer::Buffer,
        dot::{Cur, Dot},
        editor::Action,
    };
    use ad_event::Source;
    use simple_test_case::test_case;

    // Shorthand constructor for a SyntaxRange covering the given byte range
    fn sr(from: usize, to: usize) -> SyntaxRange {
        SyntaxRange {
            cap_idx: Some(0),
            r: ByteRange { from, to },
        }
    }

    // Shorthand constructor for a default (untagged) token
    fn rt_def(from: usize, to: usize) -> RangeToken<'static> {
        RangeToken {
            tag: TK_DEFAULT,
            r: ByteRange { from, to },
        }
    }

    // Shorthand constructor for a dot-selection token
    fn rt_dot(from: usize, to: usize) -> RangeToken<'static> {
        RangeToken {
            tag: TK_DOT,
            r: ByteRange { from, to },
        }
    }

    // Shorthand constructor for a load/exec-selection token
    fn rt_exe(from: usize, to: usize) -> RangeToken<'static> {
        RangeToken {
            tag: TK_EXEC,
            r: ByteRange { from, to },
        }
    }

    // Shorthand constructor for a syntax token tagged "string"
    fn rt_str(from: usize, to: usize) -> RangeToken<'static> {
        RangeToken {
            tag: "string",
            r: ByteRange { from, to },
        }
    }

    // Each case supplies: initial held state, optional already-held next token,
    // the selection token to splice in, the remaining syntax ranges available to
    // the iterator, and the expected resulting Held state.
    // range at start of single token
    #[test_case(
        Held::One(rt_str(0, 5)),
        None,
        rt_dot(0, 5),
        &[sr(10, 15)],
        Held::One(rt_dot(0, 5));
        "held one range matches held"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        None,
        rt_dot(0, 3),
        &[sr(10, 15)],
        Held::Two(rt_dot(0, 3), rt_str(3, 5));
        "held one range start to within held"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(0, 7),
        &[sr(10, 15), sr(20, 30)],
        Held::Two(rt_dot(0, 7), rt_def(7, 10));
        "held one range start to past held but before next token"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(0, 13),
        &[sr(10, 15), sr(20, 30)],
        Held::Two(rt_dot(0, 13), rt_str(13, 15));
        "held one range start to into next token"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(0, 16),
        &[sr(10, 15), sr(20, 30)],
        Held::Two(rt_dot(0, 16), rt_def(16, 20));
        "held one range start to past next token"
    )]
    // range within single token
    #[test_case(
        Held::One(rt_str(0, 5)),
        None,
        rt_dot(3, 5),
        &[sr(10, 15)],
        Held::Two(rt_str(0, 3), rt_dot(3, 5));
        "held one range from within to end of held"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        None,
        rt_dot(2, 4),
        &[sr(10, 15)],
        Held::Three(rt_str(0, 2), rt_dot(2, 4), rt_str(4, 5));
        "held one range with to within held"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(3, 7),
        &[sr(10, 15), sr(20, 30)],
        Held::Three(rt_str(0, 3), rt_dot(3, 7), rt_def(7, 10));
        "held one range within to past held but before next token"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(3, 13),
        &[sr(10, 15), sr(20, 30)],
        Held::Three(rt_str(0, 3), rt_dot(3, 13), rt_str(13, 15));
        "held one range within to into next token"
    )]
    #[test_case(
        Held::One(rt_str(0, 5)),
        Some(rt_def(5, 10)),
        rt_dot(3, 16),
        &[sr(10, 15), sr(20, 30)],
        Held::Three(rt_str(0, 3), rt_dot(3, 16), rt_def(16, 20));
        "held one range within to past next token"
    )]
    // held 2 tokens
    #[test_case(
        Held::Two(rt_str(0, 3), rt_dot(3, 5)),
        None,
        rt_exe(0, 5),
        &[sr(10, 15)],
        Held::One(rt_exe(0, 5));
        "held two range matches all held"
    )]
    #[test_case(
        Held::Two(rt_str(0, 3), rt_dot(3, 5)),
        None,
        rt_exe(2, 5),
        &[sr(10, 15)],
        Held::Two(rt_str(0, 2), rt_exe(2, 5));
        "held two range from within first to end of held"
    )]
    #[test_case(
        Held::Two(rt_str(0, 3), rt_dot(3, 5)),
        None,
        rt_exe(4, 5),
        &[sr(10, 15)],
        Held::Three(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 5));
        "held two range from within second to end of held"
    )]
    #[test_case(
        Held::Two(rt_str(0, 3), rt_dot(3, 5)),
        Some(rt_def(5, 10)),
        rt_exe(4, 8),
        &[sr(10, 15)],
        Held::Four(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 8), rt_def(8, 10));
        "held two range from within second past end of held"
    )]
    // held 3 tokens
    #[test_case(
        Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
        None,
        rt_exe(0, 8),
        &[sr(10, 15)],
        Held::One(rt_exe(0, 8));
        "held three range matches all held"
    )]
    #[test_case(
        Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
        None,
        rt_exe(2, 8),
        &[sr(10, 15)],
        Held::Two(rt_str(0, 2), rt_exe(2, 8));
        "held three range from within first to end of held"
    )]
    #[test_case(
        Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
        None,
        rt_exe(4, 8),
        &[sr(10, 15)],
        Held::Three(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 8));
        "held three range from within second to end of held"
    )]
    #[test_case(
        Held::Three(rt_str(0, 3), rt_dot(3, 6), rt_str(6, 9)),
        None,
        rt_exe(4, 5),
        &[sr(10, 15)],
        Held::Five(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 5), rt_dot(5, 6), rt_str(6, 9));
        "held three range from within second"
    )]
    #[test]
    fn update_held(
        initial: Held<'static>,
        held: Option<RangeToken<'static>>,
        r: RangeToken<'static>,
        ranges: &[SyntaxRange],
        expected: Held<'static>,
    ) {
        let mut it = TokenIter {
            start_byte: 0,
            end_byte: 42,
            names: &["string"],
            ranges: ranges.iter().peekable(),
            held,
            dot_held: None,
            dot_range: None,
            load_exec_range: None,
        };

        let held = it.update_held(initial, r);

        assert_eq!(held, expected);
    }

    // Shorthand constructor for a token with an arbitrary tag
    fn rt(tag: &str, from: usize, to: usize) -> RangeToken<'_> {
        RangeToken {
            tag,
            r: ByteRange { from, to },
        }
    }

    // Deleting a character should shift the byte ranges of all following tokens left
    #[test]
    fn char_delete_correctly_update_state() {
        // minimal query for the fn keyword and parens
        let query = r#"
"fn" @keyword

[ "(" ")" "{" "}" ] @punctuation"#;

        let s = "fn main() {}";
        let mut b = Buffer::new_unnamed(0, s);
        let gb = &b.txt;
        b.ts_state = Some(
            TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb)
                .unwrap(),
        );

        assert_eq!(b.str_contents(), "fn main() {}\n");
        assert_eq!(
            b.ts_state.as_ref().unwrap().t.range_tokens(),
            vec![
                rt("keyword", 0, 2),       // fn
                rt("punctuation", 7, 8),   // (
                rt("punctuation", 8, 9),   // )
                rt("punctuation", 10, 11), // {
                rt("punctuation", 11, 12), // }
            ]
        );

        // delete the space between ")" and "{" then re-parse the full buffer
        b.dot = Dot::Cur { c: Cur { idx: 9 } };
        b.handle_action(Action::Delete, Source::Fsys);
        b.ts_state
            .as_mut()
            .unwrap()
            .update(&b.txt, 0, usize::MAX - 1);
        let ranges = b.ts_state.as_ref().unwrap().t.range_tokens();

        assert_eq!(b.str_contents(), "fn main(){}\n");
        assert_eq!(ranges.len(), 5);

        // these two should have moved left one character
        assert_eq!(ranges[3], rt("punctuation", 9, 10), "opening curly");
        assert_eq!(ranges[4], rt("punctuation", 10, 11), "closing curly");
    }

    #[test]
    fn overlapping_tokens_prefer_previous_matches() {
        // Minimal query extracted from the full query in gh#88 that resulted in
        // overlapping tokens being produced
        let query = r#"
(identifier) @variable

(import_statement
  name: (dotted_name
    (identifier) @module))

(import_statement
  name: (aliased_import
    name: (dotted_name
      (identifier) @module)
    alias: (identifier) @module))

(import_from_statement
  module_name: (dotted_name
    (identifier) @module))"#;

        let s = "import builtins as _builtins";
        let b = Buffer::new_unnamed(0, s);
        let gb = &b.txt;
        let ts = TsState::try_new_from_language(
            "python",
            tree_sitter_python::LANGUAGE.into(),
            query,
            gb,
        )
        .unwrap();

        // the earlier @module captures win over the later @variable capture
        assert_eq!(
            ts.t.range_tokens(),
            vec![
                rt("module", 7, 15),  // builtins
                rt("module", 19, 28)  // _builtins
            ]
        );
    }

    // #match? and #any-of? predicates in queries should filter captures as expected
    #[test]
    fn built_in_predicates_work() {
        let query = r#"
(identifier) @variable

; Assume all-caps names are constants
((identifier) @constant
  (#match? @constant "^[A-Z][A-Z%d_]*$"))

((identifier) @constant.builtin
  (#any-of? @constant.builtin "Some" "None" "Ok" "Err"))

[ "(" ")" "{" "}" ] @punctuation"#;

        let s = "Ok(Some(42)) foo BAR";
        let b = Buffer::new_unnamed(0, s);
        let gb = &b.txt;
        let ts =
            TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb)
                .unwrap();

        assert_eq!(
            ts.t.range_tokens(),
            vec![
                rt("constant.builtin", 0, 2), // Ok
                rt("punctuation", 2, 3),      // (
                rt("constant.builtin", 3, 7), // Some
                rt("punctuation", 7, 8),      // (
                rt("punctuation", 10, 11),    // )
                rt("punctuation", 11, 12),    // )
                rt("variable", 13, 16),       // foo
                rt("constant", 17, 20),       // BAR
            ]
        );
    }
}