1use crate::{
23 buffer::{GapBuffer, Slice, SliceIter},
24 dot::Range,
25};
26use libloading::{Library, Symbol};
27use std::{
28 cmp::{max, min, Ord, Ordering, PartialOrd},
29 collections::HashSet,
30 fmt, fs,
31 iter::{repeat_n, Peekable},
32 ops::{Deref, DerefMut},
33 path::Path,
34 slice,
35};
36use streaming_iterator::StreamingIterator;
37use tracing::{error, info};
38use tree_sitter::{self as ts, ffi::TSLanguage};
39
/// Tag used for text that is not covered by any syntax capture or selection.
pub const TK_DEFAULT: &str = "default";
/// Tag used for the part of a line that falls inside the dot range.
pub const TK_DOT: &str = "dot";
/// Tag used for a load/exec range whose `is_load` flag is `true`.
pub const TK_LOAD: &str = "load";
/// Tag used for a load/exec range whose `is_load` flag is `false`.
pub const TK_EXEC: &str = "exec";
/// Operators of custom (general) query predicates accepted when building a tokenizer.
///
/// Currently empty: any general predicate found in a query is rejected.
pub const SUPPORTED_PREDICATES: [&str; 0] = [];
45
46#[derive(Debug)]
48pub struct TsState {
49 tree: ts::Tree,
50 p: Parser,
51 t: Tokenizer,
52}
53
54impl TsState {
55 pub fn try_new(
56 lang: &str,
57 so_dir: &str,
58 query_dir: &str,
59 gb: &GapBuffer,
60 ) -> Result<Self, String> {
61 let query_path = Path::new(query_dir).join(lang).join("highlights.scm");
62 let query = match fs::read_to_string(query_path) {
63 Ok(s) => s,
64 Err(e) => return Err(format!("unable to read tree-sitter query file: {e}")),
65 };
66
67 let p = Parser::try_new(so_dir, lang)?;
68
69 Self::try_new_explicit(p, &query, gb)
70 }
71
/// Test-only constructor that takes an already constructed language instead of loading a
/// shared library from disk.
#[cfg(test)]
pub(crate) fn try_new_from_language(
    lang_name: &str,
    lang: ts::Language,
    query: &str,
    gb: &GapBuffer,
) -> Result<Self, String> {
    Parser::try_new_from_language(lang_name, lang)
        .and_then(|parser| Self::try_new_explicit(parser, query, gb))
}
83
84 fn try_new_explicit(mut p: Parser, query: &str, gb: &GapBuffer) -> Result<Self, String> {
85 let tree = p.parse_with(
86 &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
87 None,
88 );
89
90 match tree {
91 Some(tree) => {
92 let mut t = p.new_tokenizer(query)?;
93 t.update(tree.root_node(), gb, 0, usize::MAX - 1);
94 info!("TS loaded for {}", p.lang_name);
95
96 Ok(Self { p, t, tree })
97 }
98 None => Err("failed to parse file".to_owned()),
99 }
100 }
101
102 pub fn edit(&mut self, ch_start: usize, ch_old_end: usize, ch_new_end: usize, gb: &GapBuffer) {
103 self.tree.edit(&ts::InputEdit {
104 start_byte: gb.char_to_byte(ch_start),
105 old_end_byte: gb.char_to_byte(ch_old_end),
106 new_end_byte: gb.char_to_byte(ch_new_end),
107 start_position: ts::Point::new(0, 0),
109 old_end_position: ts::Point::new(0, 0),
110 new_end_position: ts::Point::new(0, 0),
111 });
112
113 let new_tree = self.p.parse_with(
114 &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
115 Some(&self.tree),
116 );
117
118 if let Some(tree) = new_tree {
119 self.tree = tree;
122 }
123
124 self.t.ranges.clear();
125 }
126
127 pub fn update(&mut self, gb: &GapBuffer, from: usize, n_rows: usize) {
128 let byte_from = gb.char_to_byte(gb.line_to_char(from));
129 let byte_to = if from + n_rows + 1 < gb.len_lines() {
130 gb.char_to_byte(gb.line_to_char(from + n_rows + 1))
131 } else {
132 gb.len()
133 };
134 let need_tokens = self.t.ranges.is_empty()
135 || self.t.ranges.first().unwrap().r.from > byte_from
136 || self.t.ranges.last().unwrap().r.to < byte_to;
137
138 if need_tokens {
139 self.t.update(self.tree.root_node(), gb, from, n_rows);
140 }
141 }
142
143 #[inline]
144 pub fn iter_tokenized_lines_from(
145 &self,
146 line: usize,
147 gb: &GapBuffer,
148 dot_range: Range,
149 load_exec_range: Option<(bool, Range)>,
150 ) -> LineIter<'_> {
151 self.t
152 .iter_tokenized_lines_from(line, gb, dot_range, load_exec_range)
153 }
154
155 pub fn pretty_print_tree(&self) -> String {
156 let sexp = self.tree.root_node().to_sexp();
157 let mut buf = String::with_capacity(sexp.len()); let mut has_field = false;
159 let mut indent = 0;
160
161 for s in sexp.split([' ', ')']) {
162 if s.is_empty() {
163 indent -= 1;
164 buf.push(')');
165 } else if s.starts_with('(') {
166 if has_field {
167 has_field = false;
168 } else {
169 if indent > 0 {
170 buf.push('\n');
171 buf.extend(repeat_n(' ', indent * 2));
172 }
173 indent += 1;
174 }
175
176 buf.push_str(s); } else if s.ends_with(':') {
178 buf.push('\n');
179 buf.extend(repeat_n(' ', indent * 2));
180 buf.push_str(s); buf.push(' ');
182 has_field = true;
183 indent += 1;
184 }
185 }
186
187 buf
188 }
189}
190
191impl<'a> ts::TextProvider<&'a [u8]> for &'a GapBuffer {
193 type I = SliceIter<'a>;
194
195 fn text(&mut self, node: ts::Node<'_>) -> Self::I {
196 let ts::Range {
197 start_byte,
198 end_byte,
199 ..
200 } = node.range();
201 let char_from = self.raw_byte_to_char(self.byte_to_raw_byte(start_byte));
202 let char_to = self.raw_byte_to_char(self.byte_to_raw_byte(end_byte));
203
204 self.slice(char_from, char_to).slice_iter()
205 }
206}
207
208pub struct Parser {
210 lang_name: String,
211 inner: ts::Parser,
212 lang: ts::Language,
213 _lib: Option<Library>,
216}
217
218impl Deref for Parser {
219 type Target = ts::Parser;
220
221 fn deref(&self) -> &Self::Target {
222 &self.inner
223 }
224}
225
226impl DerefMut for Parser {
227 fn deref_mut(&mut self) -> &mut Self::Target {
228 &mut self.inner
229 }
230}
231
232impl fmt::Debug for Parser {
233 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
234 write!(f, "Parser({})", self.lang_name)
235 }
236}
237
238impl Parser {
239 pub fn try_new<P: AsRef<Path>>(so_dir: P, lang_name: &str) -> Result<Self, String> {
242 let p = so_dir.as_ref().join(format!("{lang_name}.so"));
243 let lang_fn = format!("tree_sitter_{lang_name}");
244
245 unsafe {
248 let lib = Library::new(p).map_err(|e| e.to_string())?;
249 let func: Symbol<'_, unsafe extern "C" fn() -> *const TSLanguage> =
250 lib.get(lang_fn.as_bytes()).map_err(|e| e.to_string())?;
251
252 let lang = ts::Language::from_raw(func());
253 if lang.version() < ts::MIN_COMPATIBLE_LANGUAGE_VERSION {
254 return Err(format!(
255 "incompatible .so tree-sitter parser version: {} < {}",
256 lang.version(),
257 ts::MIN_COMPATIBLE_LANGUAGE_VERSION
258 ));
259 }
260
261 let mut inner = ts::Parser::new();
262 inner.set_language(&lang).map_err(|e| e.to_string())?;
263
264 Ok(Self {
265 lang_name: lang_name.to_owned(),
266 inner,
267 lang,
268 _lib: Some(lib),
269 })
270 }
271 }
272
/// Test-only constructor wrapping an already available language; no shared library is held.
///
/// # Errors
/// Returns the stringified error if the parser rejects the language.
#[cfg(test)]
fn try_new_from_language(lang_name: &str, lang: ts::Language) -> Result<Self, String> {
    let mut inner = ts::Parser::new();
    if let Err(e) = inner.set_language(&lang) {
        return Err(e.to_string());
    }

    Ok(Self {
        lang_name: String::from(lang_name),
        inner,
        lang,
        _lib: None,
    })
}
286
287 pub fn new_tokenizer(&self, query: &str) -> Result<Tokenizer, String> {
288 let q = ts::Query::new(&self.lang, query).map_err(|e| format!("{e:?}"))?;
289 let cur = ts::QueryCursor::new();
290
291 let mut unsupported_predicates = HashSet::new();
296 for i in 0..q.pattern_count() {
297 for p in q.general_predicates(i) {
298 if !SUPPORTED_PREDICATES.contains(&p.operator.as_ref()) {
299 unsupported_predicates.insert(p.operator.clone());
300 }
301 }
302 }
303
304 if !unsupported_predicates.is_empty() {
305 error!("Unsupported custom tree-sitter predicates found: {unsupported_predicates:?}");
306 info!("Supported custom tree-sitter predicates: {SUPPORTED_PREDICATES:?}");
307 info!("Please modify the highlights.scm file to remove the unsupported predicates");
308
309 return Err(format!(
310 "{} highlights query contained unsupported custom predicates",
311 self.lang_name
312 ));
313 }
314
315 Ok(Tokenizer {
316 q,
317 cur,
318 ranges: Vec::new(),
319 })
320 }
321}
322
323pub struct Tokenizer {
324 q: ts::Query,
326 cur: ts::QueryCursor,
327 ranges: Vec<SyntaxRange>,
329}
330
331impl fmt::Debug for Tokenizer {
332 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
333 write!(f, "Tokenizer")
334 }
335}
336
337impl Tokenizer {
338 pub fn update(&mut self, root: ts::Node<'_>, gb: &GapBuffer, from: usize, n_rows: usize) {
339 self.cur.set_point_range(
340 ts::Point {
341 row: from,
342 column: 0,
343 }..ts::Point {
344 row: from + n_rows + 1,
345 column: 0,
346 },
347 );
348
349 let mut it = self.cur.captures(&self.q, root, gb);
351
352 while let Some((m, idx)) = it.next() {
353 let cap = m.captures[*idx];
354 let r = ByteRange::from(cap.node.range());
355 if let Some(prev) = self.ranges.last_mut() {
356 if r == prev.r {
357 prev.cap_idx = Some(cap.index as usize);
360 continue;
361 } else if r.from < prev.r.to && prev.r.from < r.to {
362 continue;
363 }
364 }
365 self.ranges.push(SyntaxRange {
366 r,
367 cap_idx: Some(cap.index as usize),
368 });
369 }
370
371 self.ranges.sort_unstable();
372 self.ranges.dedup();
373 }
374
375 #[inline]
376 pub fn iter_tokenized_lines_from(
377 &self,
378 line: usize,
379 gb: &GapBuffer,
380 dot_range: Range,
381 load_exec_range: Option<(bool, Range)>,
382 ) -> LineIter<'_> {
383 LineIter::new(
384 line,
385 gb,
386 dot_range,
387 load_exec_range,
388 self.q.capture_names(),
389 &self.ranges,
390 )
391 }
392
/// Test-only view of the cached ranges with capture indices resolved to capture names.
#[cfg(test)]
fn range_tokens(&self) -> Vec<RangeToken<'_>> {
    let names = self.q.capture_names();
    let mut tokens = Vec::with_capacity(self.ranges.len());

    for SyntaxRange { cap_idx, r } in self.ranges.iter().copied() {
        let tag = match cap_idx {
            Some(i) => names[i],
            None => TK_DEFAULT,
        };
        tokens.push(RangeToken { tag, r });
    }

    tokens
}
405}
406
/// A pair of byte offsets.
///
/// The derived ordering compares `from` first and then `to`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
pub(crate) struct ByteRange {
    /// Start byte offset.
    pub(crate) from: usize,
    /// End byte offset.
    pub(crate) to: usize,
}
413
414impl ByteRange {
415 fn from_range(r: Range, gb: &GapBuffer) -> Self {
416 let Range { start, end, .. } = r;
417
418 Self {
419 from: gb.char_to_byte(start.idx),
420 to: gb.char_to_byte(end.idx),
421 }
422 }
423
424 #[inline]
425 fn intersects(&self, start_byte: usize, end_byte: usize) -> bool {
426 self.from <= end_byte && start_byte <= self.to
427 }
428
429 #[inline]
430 fn contains(&self, start_byte: usize, end_byte: usize) -> bool {
431 self.from <= start_byte && self.to >= end_byte
432 }
433
434 fn try_as_token<'a>(
437 &self,
438 ty: &'a str,
439 start_byte: usize,
440 end_byte: usize,
441 ) -> Option<RangeToken<'a>> {
442 if self.intersects(start_byte, end_byte) {
443 Some(RangeToken {
444 tag: ty,
445 r: ByteRange {
446 from: max(self.from, start_byte),
447 to: min(self.to, end_byte),
448 },
449 })
450 } else {
451 None
452 }
453 }
454}
455
456impl From<ts::Range> for ByteRange {
457 fn from(r: ts::Range) -> Self {
458 Self {
459 from: r.start_byte,
460 to: r.end_byte,
461 }
462 }
463}
464
465#[derive(Debug, Clone, Copy, PartialEq, Eq)]
469pub(crate) struct SyntaxRange {
470 cap_idx: Option<usize>,
471 r: ByteRange,
472}
473
474#[derive(Debug, Clone, Copy, PartialEq, Eq)]
475pub struct RangeToken<'a> {
476 pub(crate) tag: &'a str,
477 pub(crate) r: ByteRange,
478}
479
480impl RangeToken<'_> {
481 pub fn tag(&self) -> &str {
482 self.tag
483 }
484
485 pub fn as_slice<'a>(&self, gb: &'a GapBuffer) -> Slice<'a> {
486 gb.slice_from_byte_offsets(self.r.from, self.r.to)
487 }
488
489 #[inline]
490 fn split(self, at: usize) -> (Self, Self) {
491 (
492 RangeToken {
493 tag: self.tag,
494 r: ByteRange {
495 from: self.r.from,
496 to: at,
497 },
498 },
499 RangeToken {
500 tag: self.tag,
501 r: ByteRange {
502 from: at,
503 to: self.r.to,
504 },
505 },
506 )
507 }
508}
509
510impl PartialOrd for SyntaxRange {
511 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
512 Some(self.cmp(other))
513 }
514}
515
516impl Ord for SyntaxRange {
517 fn cmp(&self, other: &Self) -> Ordering {
518 self.r.cmp(&other.r)
519 }
520}
521
522#[derive(Debug)]
528pub struct LineIter<'a> {
529 names: &'a [&'a str],
531 line_endings: Vec<usize>,
533 ranges: &'a [SyntaxRange],
535 start_byte: usize,
536 line: usize,
538 dot_range: ByteRange,
539 load_exec_range: Option<(bool, ByteRange)>,
540}
541
542impl<'a> LineIter<'a> {
543 pub(crate) fn new(
544 line: usize,
545 gb: &GapBuffer,
546 dot_range: Range,
547 load_exec_range: Option<(bool, Range)>,
548 names: &'a [&'a str],
549 ranges: &'a [SyntaxRange],
550 ) -> LineIter<'a> {
551 let line_endings = gb.byte_line_endings();
552 let start_byte = if line == 0 {
553 0
554 } else {
555 line_endings[line - 1] + 1
556 };
557
558 let dot_range = ByteRange::from_range(dot_range, gb);
559 let load_exec_range =
560 load_exec_range.map(|(is_load, r)| (is_load, ByteRange::from_range(r, gb)));
561
562 LineIter {
563 names,
564 line_endings,
565 ranges,
566 start_byte,
567 line,
568 dot_range,
569 load_exec_range,
570 }
571 }
572}
573
574impl<'a> Iterator for LineIter<'a> {
575 type Item = TokenIter<'a>;
576
577 fn next(&mut self) -> Option<Self::Item> {
578 if self.line == self.line_endings.len() {
579 return None;
580 }
581
582 let start_byte = self.start_byte;
583 let end_byte = self.line_endings[self.line];
584
585 self.line += 1;
586 self.start_byte = end_byte + 1;
587
588 let held: Option<RangeToken<'_>>;
590 let ranges: Peekable<slice::Iter<'_, SyntaxRange>>;
591
592 let dot_range = self.dot_range.try_as_token(TK_DOT, start_byte, end_byte);
593 let load_exec_range = self.load_exec_range.and_then(|(is_load, br)| {
594 let ty = if is_load { TK_LOAD } else { TK_EXEC };
595 br.try_as_token(ty, start_byte, end_byte)
596 });
597
598 loop {
599 match self.ranges.first() {
600 Some(sr) if sr.r.to < start_byte => {
602 self.ranges = &self.ranges[1..];
603 }
604
605 None => {
607 held = Some(RangeToken {
608 tag: TK_DEFAULT,
609 r: ByteRange {
610 from: start_byte,
611 to: end_byte,
612 },
613 });
614 ranges = [].iter().peekable();
615 break;
616 }
617
618 Some(sr) if sr.r.from >= end_byte => {
620 held = Some(RangeToken {
621 tag: TK_DEFAULT,
622 r: ByteRange {
623 from: start_byte,
624 to: end_byte,
625 },
626 });
627 ranges = [].iter().peekable();
628 break;
629 }
630
631 Some(sr) if sr.r.contains(start_byte, end_byte) => {
633 held = Some(RangeToken {
634 tag: sr.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
635 r: ByteRange {
636 from: start_byte,
637 to: end_byte,
638 },
639 });
640 ranges = [].iter().peekable();
641 break;
642 }
643
644 Some(sr) => {
646 assert!(sr.r.from < end_byte);
647 if sr.r.from > start_byte {
648 held = Some(RangeToken {
649 tag: TK_DEFAULT,
650 r: ByteRange {
651 from: start_byte,
652 to: sr.r.from,
653 },
654 });
655 } else {
656 held = None;
657 }
658 ranges = self.ranges.iter().peekable();
659 break;
660 }
661 }
662 }
663
664 Some(TokenIter {
665 start_byte,
666 end_byte,
667 names: self.names,
668 ranges,
669 held,
670 dot_held: None,
671 dot_range,
672 load_exec_range,
673 })
674 }
675}
676
677type Rt<'a> = RangeToken<'a>;
678
679#[derive(Debug, PartialEq, Eq)]
680enum Held<'a> {
681 One(Rt<'a>),
682 Two(Rt<'a>, Rt<'a>),
683 Three(Rt<'a>, Rt<'a>, Rt<'a>),
684 Four(Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>),
685 Five(Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>, Rt<'a>),
686}
687
688impl Held<'_> {
689 fn byte_from_to(&self) -> (usize, usize) {
690 match self {
691 Held::One(a) => (a.r.from, a.r.to),
692 Held::Two(a, b) => (a.r.from, b.r.to),
693 Held::Three(a, _, b) => (a.r.from, b.r.to),
694 Held::Four(a, _, _, b) => (a.r.from, b.r.to),
695 Held::Five(a, _, _, _, b) => (a.r.from, b.r.to),
696 }
697 }
698
699 fn split(self, at: usize) -> (Self, Self) {
700 use Held::*;
701
702 match self {
703 One(a) => {
704 let (l, r) = a.split(at);
705 (One(l), One(r))
706 }
707
708 Two(a, b) => {
709 if at == a.r.to {
710 (One(a), One(b))
711 } else if a.r.contains(at, at) {
712 let (l, r) = a.split(at);
713 (One(l), Two(r, b))
714 } else {
715 let (l, r) = b.split(at);
716 (Two(a, l), One(r))
717 }
718 }
719
720 Three(a, b, c) => {
721 if at == a.r.to {
722 (One(a), Two(b, c))
723 } else if at == b.r.to {
724 (Two(a, b), One(c))
725 } else if a.r.contains(at, at) {
726 let (l, r) = a.split(at);
727 (One(l), Three(r, b, c))
728 } else if b.r.contains(at, at) {
729 let (l, r) = b.split(at);
730 (Two(a, l), Two(r, c))
731 } else {
732 let (l, r) = c.split(at);
733 (Three(a, b, l), One(r))
734 }
735 }
736
737 Four(_, _, _, _) => unreachable!("only called for 1-3"),
738 Five(_, _, _, _, _) => unreachable!("only called for 1-3"),
739 }
740 }
741
742 fn join(self, other: Self) -> Self {
743 use Held::*;
744
745 match (self, other) {
746 (One(a), One(b)) => Two(a, b),
747 (One(a), Two(b, c)) => Three(a, b, c),
748 (One(a), Three(b, c, d)) => Four(a, b, c, d),
749 (One(a), Four(b, c, d, e)) => Five(a, b, c, d, e),
750
751 (Two(a, b), One(c)) => Three(a, b, c),
752 (Two(a, b), Two(c, d)) => Four(a, b, c, d),
753 (Two(a, b), Three(c, d, e)) => Five(a, b, c, d, e),
754
755 (Three(a, b, c), One(d)) => Four(a, b, c, d),
756 (Three(a, b, c), Two(d, e)) => Five(a, b, c, d, e),
757
758 (Four(a, b, c, d), One(e)) => Five(a, b, c, d, e),
759
760 _ => unreachable!("only have a max of 5 held"),
761 }
762 }
763}
764
765#[derive(Debug)]
772pub struct TokenIter<'a> {
773 start_byte: usize,
775 end_byte: usize,
777 names: &'a [&'a str],
779 ranges: Peekable<slice::Iter<'a, SyntaxRange>>,
781 held: Option<RangeToken<'a>>,
785 dot_held: Option<Held<'a>>,
786 dot_range: Option<RangeToken<'a>>,
787 load_exec_range: Option<RangeToken<'a>>,
788}
789
790impl<'a> TokenIter<'a> {
791 fn next_without_selections(&mut self) -> Option<RangeToken<'a>> {
792 let held = self.held.take();
793 if held.is_some() {
794 return held;
795 }
796
797 let next = self.ranges.next()?;
798
799 if next.r.from > self.end_byte {
800 return None;
803 } else if next.r.to >= self.end_byte {
804 self.ranges = [].iter().peekable();
807
808 return Some(RangeToken {
809 tag: next.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
810 r: ByteRange {
811 from: max(next.r.from, self.start_byte),
812 to: self.end_byte,
813 },
814 });
815 }
816
817 match self.ranges.peek() {
818 Some(sr) if sr.r.from > self.end_byte => {
819 self.ranges = [].iter().peekable();
820
821 self.held = Some(RangeToken {
822 tag: TK_DEFAULT,
823 r: ByteRange {
824 from: next.r.to,
825 to: self.end_byte,
826 },
827 });
828 }
829
830 Some(sr) if sr.r.from > next.r.to => {
831 self.held = Some(RangeToken {
832 tag: TK_DEFAULT,
833 r: ByteRange {
834 from: next.r.to,
835 to: sr.r.from,
836 },
837 });
838 }
839
840 None if next.r.to < self.end_byte => {
841 self.held = Some(RangeToken {
842 tag: TK_DEFAULT,
843 r: ByteRange {
844 from: next.r.to,
845 to: self.end_byte,
846 },
847 });
848 }
849
850 _ => (),
851 }
852
853 Some(RangeToken {
854 tag: next.cap_idx.map(|i| self.names[i]).unwrap_or(TK_DEFAULT),
855 r: ByteRange {
856 from: max(next.r.from, self.start_byte),
857 to: next.r.to,
858 },
859 })
860 }
861
862 fn update_held(&mut self, mut held: Held<'a>, rt: RangeToken<'a>) -> Held<'a> {
863 let (self_from, self_to) = held.byte_from_to();
864 let (from, to) = (rt.r.from, rt.r.to);
865
866 match (from.cmp(&self_from), to.cmp(&self_to)) {
867 (Ordering::Less, _) => unreachable!("only called when rt >= self"),
868
869 (Ordering::Equal, Ordering::Less) => {
870 let (_, r) = held.split(to);
872 held = Held::One(rt).join(r);
873 }
874
875 (Ordering::Greater, Ordering::Less) => {
876 let (l, r) = held.split(from);
878 let (_, r) = r.split(to);
879 held = l.join(Held::One(rt)).join(r);
880 }
881
882 (Ordering::Equal, Ordering::Equal) => {
883 held = Held::One(rt);
885 }
886
887 (Ordering::Greater, Ordering::Equal) => {
888 let (l, _) = held.split(from);
890 held = l.join(Held::One(rt));
891 }
892
893 (Ordering::Equal, Ordering::Greater) => {
894 held = self.find_end_of_selection(Held::One(rt), to);
896 }
897
898 (Ordering::Greater, Ordering::Greater) => {
899 let (l, _) = held.split(from);
901 held = self.find_end_of_selection(l.join(Held::One(rt)), to);
902 }
903 }
904
905 held
906 }
907
908 fn find_end_of_selection(&mut self, mut held: Held<'a>, to: usize) -> Held<'a> {
909 loop {
910 let mut next = match self.next_without_selections() {
911 None => break,
912 Some(next) => next,
913 };
914 if next.r.to <= to {
915 continue; }
917 next.r.from = to;
918 held = held.join(Held::One(next));
919 break;
920 }
921
922 held
923 }
924
925 fn pop(&mut self) -> Option<RangeToken<'a>> {
926 match self.dot_held {
927 None => None,
928 Some(Held::One(a)) => {
929 self.dot_held = None;
930 Some(a)
931 }
932 Some(Held::Two(a, b)) => {
933 self.dot_held = Some(Held::One(b));
934 Some(a)
935 }
936 Some(Held::Three(a, b, c)) => {
937 self.dot_held = Some(Held::Two(b, c));
938 Some(a)
939 }
940 Some(Held::Four(a, b, c, d)) => {
941 self.dot_held = Some(Held::Three(b, c, d));
942 Some(a)
943 }
944 Some(Held::Five(a, b, c, d, e)) => {
945 self.dot_held = Some(Held::Four(b, c, d, e));
946 Some(a)
947 }
948 }
949 }
950}
951
952impl<'a> Iterator for TokenIter<'a> {
953 type Item = RangeToken<'a>;
954
955 fn next(&mut self) -> Option<Self::Item> {
956 let next = self.pop();
958 if next.is_some() {
959 return next;
960 }
961
962 #[inline]
967 fn intersects(opt: &Option<RangeToken<'_>>, from: usize, to: usize) -> bool {
968 opt.as_ref()
969 .map(|rt| rt.r.intersects(from, to))
970 .unwrap_or(false)
971 }
972
973 let next = self.next_without_selections()?;
974 let (from, to) = (next.r.from, next.r.to);
975 let mut held = Held::One(next);
976
977 if intersects(&self.dot_range, from, to) {
978 let r = self.dot_range.take().unwrap();
979 held = self.update_held(held, r);
980 }
981
982 let (from, to) = held.byte_from_to();
983 if intersects(&self.load_exec_range, from, to) {
984 let r = self.load_exec_range.take().unwrap();
985 held = self.update_held(held, r);
986 }
987
988 if let Held::One(rt) = held {
989 Some(rt) } else {
991 self.dot_held = Some(held);
992 self.pop()
993 }
994 }
995}
996
997#[cfg(test)]
998mod tests {
999 use super::*;
1000 use crate::{
1001 buffer::Buffer,
1002 dot::{Cur, Dot},
1003 editor::Action,
1004 };
1005 use ad_event::Source;
1006 use simple_test_case::test_case;
1007
1008 fn sr(from: usize, to: usize) -> SyntaxRange {
1009 SyntaxRange {
1010 cap_idx: Some(0),
1011 r: ByteRange { from, to },
1012 }
1013 }
1014
1015 fn rt_def(from: usize, to: usize) -> RangeToken<'static> {
1016 RangeToken {
1017 tag: TK_DEFAULT,
1018 r: ByteRange { from, to },
1019 }
1020 }
1021
1022 fn rt_dot(from: usize, to: usize) -> RangeToken<'static> {
1023 RangeToken {
1024 tag: TK_DOT,
1025 r: ByteRange { from, to },
1026 }
1027 }
1028
1029 fn rt_exe(from: usize, to: usize) -> RangeToken<'static> {
1030 RangeToken {
1031 tag: TK_EXEC,
1032 r: ByteRange { from, to },
1033 }
1034 }
1035
1036 fn rt_str(from: usize, to: usize) -> RangeToken<'static> {
1037 RangeToken {
1038 tag: "string",
1039 r: ByteRange { from, to },
1040 }
1041 }
1042
1043 #[test_case(
1045 Held::One(rt_str(0, 5)),
1046 None,
1047 rt_dot(0, 5),
1048 &[sr(10, 15)],
1049 Held::One(rt_dot(0, 5));
1050 "held one range matches held"
1051 )]
1052 #[test_case(
1053 Held::One(rt_str(0, 5)),
1054 None,
1055 rt_dot(0, 3),
1056 &[sr(10, 15)],
1057 Held::Two(rt_dot(0, 3), rt_str(3, 5));
1058 "held one range start to within held"
1059 )]
1060 #[test_case(
1061 Held::One(rt_str(0, 5)),
1062 Some(rt_def(5, 10)),
1063 rt_dot(0, 7),
1064 &[sr(10, 15), sr(20, 30)],
1065 Held::Two(rt_dot(0, 7), rt_def(7, 10));
1066 "held one range start to past held but before next token"
1067 )]
1068 #[test_case(
1069 Held::One(rt_str(0, 5)),
1070 Some(rt_def(5, 10)),
1071 rt_dot(0, 13),
1072 &[sr(10, 15), sr(20, 30)],
1073 Held::Two(rt_dot(0, 13), rt_str(13, 15));
1074 "held one range start to into next token"
1075 )]
1076 #[test_case(
1077 Held::One(rt_str(0, 5)),
1078 Some(rt_def(5, 10)),
1079 rt_dot(0, 16),
1080 &[sr(10, 15), sr(20, 30)],
1081 Held::Two(rt_dot(0, 16), rt_def(16, 20));
1082 "held one range start to past next token"
1083 )]
1084 #[test_case(
1086 Held::One(rt_str(0, 5)),
1087 None,
1088 rt_dot(3, 5),
1089 &[sr(10, 15)],
1090 Held::Two(rt_str(0, 3), rt_dot(3, 5));
1091 "held one range from within to end of held"
1092 )]
1093 #[test_case(
1094 Held::One(rt_str(0, 5)),
1095 None,
1096 rt_dot(2, 4),
1097 &[sr(10, 15)],
1098 Held::Three(rt_str(0, 2), rt_dot(2, 4), rt_str(4, 5));
1099 "held one range with to within held"
1100 )]
1101 #[test_case(
1102 Held::One(rt_str(0, 5)),
1103 Some(rt_def(5, 10)),
1104 rt_dot(3, 7),
1105 &[sr(10, 15), sr(20, 30)],
1106 Held::Three(rt_str(0, 3), rt_dot(3, 7), rt_def(7, 10));
1107 "held one range within to past held but before next token"
1108 )]
1109 #[test_case(
1110 Held::One(rt_str(0, 5)),
1111 Some(rt_def(5, 10)),
1112 rt_dot(3, 13),
1113 &[sr(10, 15), sr(20, 30)],
1114 Held::Three(rt_str(0, 3), rt_dot(3, 13), rt_str(13, 15));
1115 "held one range within to into next token"
1116 )]
1117 #[test_case(
1118 Held::One(rt_str(0, 5)),
1119 Some(rt_def(5, 10)),
1120 rt_dot(3, 16),
1121 &[sr(10, 15), sr(20, 30)],
1122 Held::Three(rt_str(0, 3), rt_dot(3, 16), rt_def(16, 20));
1123 "held one range within to past next token"
1124 )]
1125 #[test_case(
1127 Held::Two(rt_str(0, 3), rt_dot(3, 5)),
1128 None,
1129 rt_exe(0, 5),
1130 &[sr(10, 15)],
1131 Held::One(rt_exe(0, 5));
1132 "held two range matches all held"
1133 )]
1134 #[test_case(
1135 Held::Two(rt_str(0, 3), rt_dot(3, 5)),
1136 None,
1137 rt_exe(2, 5),
1138 &[sr(10, 15)],
1139 Held::Two(rt_str(0, 2), rt_exe(2, 5));
1140 "held two range from within first to end of held"
1141 )]
1142 #[test_case(
1143 Held::Two(rt_str(0, 3), rt_dot(3, 5)),
1144 None,
1145 rt_exe(4, 5),
1146 &[sr(10, 15)],
1147 Held::Three(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 5));
1148 "held two range from within second to end of held"
1149 )]
1150 #[test_case(
1151 Held::Two(rt_str(0, 3), rt_dot(3, 5)),
1152 Some(rt_def(5, 10)),
1153 rt_exe(4, 8),
1154 &[sr(10, 15)],
1155 Held::Four(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 8), rt_def(8, 10));
1156 "held two range from within second past end of held"
1157 )]
1158 #[test_case(
1160 Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
1161 None,
1162 rt_exe(0, 8),
1163 &[sr(10, 15)],
1164 Held::One(rt_exe(0, 8));
1165 "held three range matches all held"
1166 )]
1167 #[test_case(
1168 Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
1169 None,
1170 rt_exe(2, 8),
1171 &[sr(10, 15)],
1172 Held::Two(rt_str(0, 2), rt_exe(2, 8));
1173 "held three range from within first to end of held"
1174 )]
1175 #[test_case(
1176 Held::Three(rt_str(0, 3), rt_dot(3, 5), rt_str(5, 8)),
1177 None,
1178 rt_exe(4, 8),
1179 &[sr(10, 15)],
1180 Held::Three(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 8));
1181 "held three range from within second to end of held"
1182 )]
1183 #[test_case(
1184 Held::Three(rt_str(0, 3), rt_dot(3, 6), rt_str(6, 9)),
1185 None,
1186 rt_exe(4, 5),
1187 &[sr(10, 15)],
1188 Held::Five(rt_str(0, 3), rt_dot(3, 4), rt_exe(4, 5), rt_dot(5, 6), rt_str(6, 9));
1189 "held three range from within second"
1190 )]
1191 #[test]
1192 fn update_held(
1193 initial: Held<'static>,
1194 held: Option<RangeToken<'static>>,
1195 r: RangeToken<'static>,
1196 ranges: &[SyntaxRange],
1197 expected: Held<'static>,
1198 ) {
1199 let mut it = TokenIter {
1200 start_byte: 0,
1201 end_byte: 42,
1202 names: &["string"],
1203 ranges: ranges.iter().peekable(),
1204 held,
1205 dot_held: None,
1206 dot_range: None,
1207 load_exec_range: None,
1208 };
1209
1210 let held = it.update_held(initial, r);
1211
1212 assert_eq!(held, expected);
1213 }
1214
1215 fn rt(tag: &str, from: usize, to: usize) -> RangeToken<'_> {
1216 RangeToken {
1217 tag,
1218 r: ByteRange { from, to },
1219 }
1220 }
1221
/// Deleting a character must shift the byte ranges of the tokens that follow it.
#[test]
fn char_delete_correctly_update_state() {
    let query = r#"
"fn" @keyword

[ "(" ")" "{" "}" ] @punctuation"#;

    let mut b = Buffer::new_unnamed(0, "fn main() {}");
    let state =
        TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, &b.txt)
            .unwrap();
    b.ts_state = Some(state);

    assert_eq!(b.str_contents(), "fn main() {}\n");
    let before = b.ts_state.as_ref().unwrap().t.range_tokens();
    assert_eq!(
        before,
        vec![
            rt("keyword", 0, 2),
            rt("punctuation", 7, 8),
            rt("punctuation", 8, 9),
            rt("punctuation", 10, 11),
            rt("punctuation", 11, 12),
        ]
    );

    // Delete the space between `)` and `{`.
    b.dot = Dot::Cur { c: Cur { idx: 9 } };
    b.handle_action(Action::Delete, Source::Fsys);
    b.ts_state
        .as_mut()
        .unwrap()
        .update(&b.txt, 0, usize::MAX - 1);
    let ranges = b.ts_state.as_ref().unwrap().t.range_tokens();

    assert_eq!(b.str_contents(), "fn main(){}\n");
    assert_eq!(ranges.len(), 5);

    assert_eq!(ranges[3], rt("punctuation", 9, 10), "opening curly");
    assert_eq!(ranges[4], rt("punctuation", 10, 11), "closing curly");
}
1265
/// Identifiers captured both as `@variable` and as `@module` must each yield a single token.
#[test]
fn overlapping_tokens_prefer_previous_matches() {
    let query = r#"
(identifier) @variable

(import_statement
  name: (dotted_name
    (identifier) @module))

(import_statement
  name: (aliased_import
    name: (dotted_name
      (identifier) @module)
    alias: (identifier) @module))

(import_from_statement
  module_name: (dotted_name
    (identifier) @module))"#;

    let b = Buffer::new_unnamed(0, "import builtins as _builtins");
    let ts = TsState::try_new_from_language(
        "python",
        tree_sitter_python::LANGUAGE.into(),
        query,
        &b.txt,
    )
    .unwrap();

    let expected = vec![rt("module", 7, 15), rt("module", 19, 28)];
    assert_eq!(ts.t.range_tokens(), expected);
}
1306
/// Queries using the `#match?` and `#any-of?` predicates must be accepted and applied.
#[test]
fn built_in_predicates_work() {
    let query = r#"
(identifier) @variable

; Assume all-caps names are constants
((identifier) @constant
  (#match? @constant "^[A-Z][A-Z%d_]*$"))

((identifier) @constant.builtin
  (#any-of? @constant.builtin "Some" "None" "Ok" "Err"))

[ "(" ")" "{" "}" ] @punctuation"#;

    let b = Buffer::new_unnamed(0, "Ok(Some(42)) foo BAR");
    let ts =
        TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, &b.txt)
            .unwrap();

    let expected = vec![
        rt("constant.builtin", 0, 2),
        rt("punctuation", 2, 3),
        rt("constant.builtin", 3, 7),
        rt("punctuation", 7, 8),
        rt("punctuation", 10, 11),
        rt("punctuation", 11, 12),
        rt("variable", 13, 16),
        rt("constant", 17, 20),
    ];
    assert_eq!(ts.t.range_tokens(), expected);
}
1342}