1use crate::{
6 buffer::{GapBuffer, SliceIter},
7 dot::Range,
8 syntax::{ByteRange, LineIter, SyntaxRange},
9};
10use libloading::{Library, Symbol};
11use std::{
12 cmp::{max, min},
13 collections::HashSet,
14 fmt, fs,
15 iter::repeat_n,
16 ops::{Deref, DerefMut},
17 path::Path,
18};
19use tracing::{error, info};
20use tree_sitter::{self as ts, StreamingIterator, ffi::TSLanguage};
21
/// Custom (non-built-in) tree-sitter query predicates that we implement.
/// Currently empty: any highlights query using a custom predicate is rejected
/// by `Parser::new_tokenizer` (built-in predicates are unaffected).
pub const SUPPORTED_PREDICATES: [&str; 0] = [];
23
24impl From<ts::Range> for ByteRange {
25 fn from(r: ts::Range) -> Self {
26 Self {
27 from: r.start_byte,
28 to: r.end_byte,
29 }
30 }
31}
32
/// Per-buffer tree-sitter state: the current parse tree plus the machinery
/// used to keep it (and the derived highlight ranges) up to date.
#[derive(Debug)]
pub struct TsState {
    // Current parse tree for the buffer contents
    tree: ts::Tree,
    // Parser that produced `tree` and incrementally re-parses after edits
    p: Parser,
    // Runs the highlights query against `tree` and caches the results
    t: Tokenizer,
}
40
impl TsState {
    /// Construct tree-sitter state for `lang` by loading the compiled grammar
    /// from `$so_dir/$lang.so` and the highlights query from
    /// `$query_dir/$lang/highlights.scm`, then parsing the full contents of `gb`.
    ///
    /// # Errors
    /// Returns an error if the query file cannot be read, the shared object
    /// cannot be loaded, the query fails to compile, or the initial parse fails.
    pub fn try_new(
        lang: &str,
        so_dir: &str,
        query_dir: &str,
        gb: &GapBuffer,
    ) -> Result<Self, String> {
        let query_path = Path::new(query_dir).join(lang).join("highlights.scm");
        let query = match fs::read_to_string(query_path) {
            Ok(s) => s,
            Err(e) => return Err(format!("unable to read tree-sitter query file: {e}")),
        };

        let p = Parser::try_new(so_dir, lang)?;

        Self::try_new_explicit(p, &query, gb)
    }

    /// Test-only constructor taking an in-process [`ts::Language`] rather than
    /// loading a grammar from a shared object.
    #[cfg(test)]
    pub(crate) fn try_new_from_language(
        lang_name: &str,
        lang: ts::Language,
        query: &str,
        gb: &GapBuffer,
    ) -> Result<Self, String> {
        let p = Parser::try_new_from_language(lang_name, lang)?;

        Self::try_new_explicit(p, query, gb)
    }

    /// Shared constructor body: run an initial full parse of `gb` and compile
    /// the highlights `query` into a [`Tokenizer`].
    fn try_new_explicit(mut p: Parser, query: &str, gb: &GapBuffer) -> Result<Self, String> {
        // The read callback hands tree-sitter the largest contiguous slice of
        // the gap buffer available at each requested byte offset.
        let tree = p.parse_with_options(
            &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
            None,
            None,
        );

        match tree {
            Some(tree) => {
                let t = p.new_tokenizer(query)?;
                info!("TS loaded for {}", p.lang_name);

                Ok(Self { p, t, tree })
            }
            None => Err("failed to parse file".to_owned()),
        }
    }

    /// Apply a byte-level edit (produced by one of the `prepare_*` methods) to
    /// the current tree and incrementally re-parse the buffer.
    ///
    /// Cached tokenizer state is cleared so affected regions are re-tokenized
    /// on the next call to [`TsState::update`].
    pub(super) fn apply_prepared_edit(
        &mut self,
        start_byte: usize,
        old_end_byte: usize,
        new_end_byte: usize,
        gb: &GapBuffer,
    ) {
        self.tree.edit(&ts::InputEdit {
            start_byte,
            old_end_byte,
            new_end_byte,
            // NOTE(review): row/column positions are zeroed rather than
            // computed. This looks like it relies on the byte-offset based
            // read callback below never needing accurate points -- confirm
            // against the tree-sitter incremental parsing docs.
            start_position: ts::Point::new(0, 0),
            old_end_position: ts::Point::new(0, 0),
            new_end_position: ts::Point::new(0, 0),
        });

        let new_tree = self.p.parse_with_options(
            &mut |byte_offset, _| gb.maximal_slice_from_offset(byte_offset),
            Some(&self.tree),
            None,
        );

        // If the re-parse fails we keep the previous (now stale) tree rather
        // than dropping syntax state entirely.
        if let Some(tree) = new_tree {
            self.tree = tree;
        }

        self.t.clear();
    }

    /// Byte triple `(start, old_end, new_end)` describing the insertion of
    /// `ch` at char index `ch_idx`.
    pub(super) fn prepare_insert_char(
        &self,
        ch_idx: usize,
        ch: char,
        gb: &GapBuffer,
    ) -> (usize, usize, usize) {
        let start_byte = gb.char_to_byte(ch_idx);

        (start_byte, start_byte, start_byte + ch.len_utf8())
    }

    /// Byte triple `(start, old_end, new_end)` describing the insertion of
    /// `s` at char index `ch_idx`.
    pub(super) fn prepare_insert_string(
        &self,
        ch_idx: usize,
        s: &str,
        gb: &GapBuffer,
    ) -> (usize, usize, usize) {
        let start_byte = gb.char_to_byte(ch_idx);

        // str::len is a byte count, matching the byte offsets used here.
        (start_byte, start_byte, start_byte + s.len())
    }

    /// Byte triple `(start, old_end, new_end)` describing the deletion of the
    /// char at `ch_idx`.
    pub(super) fn prepare_delete_char(
        &self,
        ch_idx: usize,
        gb: &GapBuffer,
    ) -> (usize, usize, usize) {
        let (start_byte, old_end_byte) = gb.char_range_to_byte_range(ch_idx, ch_idx + 1);

        (start_byte, old_end_byte, start_byte)
    }

    /// Byte triple `(start, old_end, new_end)` describing the deletion of the
    /// char range `ch_from..ch_to`.
    pub(super) fn prepare_delete_range(
        &self,
        ch_from: usize,
        ch_to: usize,
        gb: &GapBuffer,
    ) -> (usize, usize, usize) {
        let (start_byte, old_end_byte) = gb.char_range_to_byte_range(ch_from, ch_to);

        (start_byte, old_end_byte, start_byte)
    }

    /// Ensure the `n_rows` rows starting at `from_row` have up to date
    /// tokenization, running the highlights query over any missing region.
    pub fn update(&mut self, gb: &GapBuffer, from_row: usize, n_rows: usize) {
        let raw_from = gb.line_to_byte(from_row);
        // One row past the requested window, clamped to the end of the buffer.
        let raw_to = if from_row + n_rows + 1 < gb.len_lines() {
            gb.line_to_byte(from_row + n_rows + 1)
        } else {
            gb.len()
        };

        if let Some((a, b)) = self.t.missing_region(raw_from, raw_to) {
            // Tokenize some padding around the missing region so that small
            // scrolls don't immediately require another query run. Padding is
            // only applied on sides where the missing region sits strictly
            // inside the requested window.
            const PADDING: usize = 512;
            let byte_from = if b < raw_to {
                a.saturating_sub(PADDING)
            } else {
                a
            };
            let byte_to = if a > raw_from {
                min(b + PADDING, gb.len())
            } else {
                b
            };

            self.t.update(self.tree.root_node(), gb, byte_from, byte_to);
        }
    }

    /// Iterate over styled lines starting at `line`, delegating to the
    /// tokenizer's cached ranges.
    #[inline]
    pub fn iter_tokenized_lines_from<'a>(
        &'a self,
        line: usize,
        gb: &'a GapBuffer,
        dot_range: Range,
        load_exec_range: Option<(bool, Range)>,
    ) -> LineIter<'a> {
        self.t
            .iter_tokenized_lines_from(line, gb, dot_range, load_exec_range)
    }

    /// Render the current parse tree as an indented s-expression for debugging.
    pub fn pretty_print_tree(&self) -> String {
        let sexp = self.tree.root_node().to_sexp();
        let mut buf = String::with_capacity(sexp.len());
        let mut has_field = false;
        let mut indent = 0;

        // Splitting on both ' ' and ')' means every empty piece corresponds
        // to a closing paren in the original s-expression.
        for s in sexp.split([' ', ')']) {
            if s.is_empty() {
                // NOTE(review): assumes the s-expression is balanced so that
                // `indent` can never underflow here.
                indent -= 1;
                buf.push(')');
            } else if s.starts_with('(') {
                if has_field {
                    // A node directly following a `field:` label stays on the
                    // same line as the label.
                    has_field = false;
                } else {
                    if indent > 0 {
                        buf.push('\n');
                        buf.extend(repeat_n(' ', indent * 2));
                    }
                    indent += 1;
                }

                buf.push_str(s);
            } else if s.ends_with(':') {
                // Field label (e.g. `name:`): emit on its own line and flag
                // that the following node should be kept beside it.
                buf.push('\n');
                buf.extend(repeat_n(' ', indent * 2));
                buf.push_str(s);
                buf.push(' ');
                has_field = true;
                indent += 1;
            }
        }

        buf
    }
}
245
246impl<'a> ts::TextProvider<&'a [u8]> for &'a GapBuffer {
248 type I = SliceIter<'a>;
249
250 fn text(&mut self, node: ts::Node<'_>) -> Self::I {
251 let ts::Range {
252 start_byte,
253 end_byte,
254 ..
255 } = node.range();
256
257 self.slice_from_byte_offsets(start_byte, end_byte)
258 .slice_iter()
259 }
260}
261
/// A tree-sitter parser for a single language, optionally backed by a
/// dynamically loaded shared object (see [`Parser::try_new`]).
pub struct Parser {
    // Name of the language this parser handles (e.g. "rust")
    lang_name: String,
    inner: ts::Parser,
    lang: ts::Language,
    // When the grammar was loaded from a .so file the library handle is held
    // here so the code backing `lang` stays mapped for the parser's lifetime;
    // None when the language was linked in directly (test builds).
    _lib: Option<Library>,
}
271
impl Deref for Parser {
    type Target = ts::Parser;

    // Allow a Parser to be used anywhere a &ts::Parser is expected.
    fn deref(&self) -> &Self::Target {
        &self.inner
    }
}
279
impl DerefMut for Parser {
    // Allow a Parser to be used anywhere a &mut ts::Parser is expected.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.inner
    }
}
285
286impl fmt::Debug for Parser {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 write!(f, "Parser({})", self.lang_name)
289 }
290}
291
impl Parser {
    /// Load a compiled tree-sitter grammar for `lang_name` from
    /// `$so_dir/$lang_name.so` and construct a parser for it.
    ///
    /// # Errors
    /// Returns an error if the shared object cannot be opened, the expected
    /// `tree_sitter_$lang_name` symbol is missing, the grammar's ABI version
    /// is older than the minimum the linked tree-sitter runtime supports, or
    /// the language cannot be assigned to the parser.
    pub fn try_new<P: AsRef<Path>>(so_dir: P, lang_name: &str) -> Result<Self, String> {
        let p = so_dir.as_ref().join(format!("{lang_name}.so"));
        // Grammar .so files export a single `tree_sitter_$lang` constructor.
        let lang_fn = format!("tree_sitter_{lang_name}");

        // SAFETY: loading an arbitrary shared object and calling a symbol out
        // of it is inherently unsafe. We trust that any `$lang.so` placed in
        // so_dir is a genuine tree-sitter grammar exporting the conventional
        // constructor, and we validate its ABI version before use. The Library
        // handle is stored in `_lib` so it outlives every use of `lang`.
        unsafe {
            let lib = Library::new(p).map_err(|e| e.to_string())?;
            let func: Symbol<'_, unsafe extern "C" fn() -> *const TSLanguage> =
                lib.get(lang_fn.as_bytes()).map_err(|e| e.to_string())?;

            let lang = ts::Language::from_raw(func());
            if lang.abi_version() < ts::MIN_COMPATIBLE_LANGUAGE_VERSION {
                return Err(format!(
                    "incompatible .so tree-sitter parser version: {} < {}",
                    lang.abi_version(),
                    ts::MIN_COMPATIBLE_LANGUAGE_VERSION
                ));
            }

            let mut inner = ts::Parser::new();
            inner.set_language(&lang).map_err(|e| e.to_string())?;

            Ok(Self {
                lang_name: lang_name.to_owned(),
                inner,
                lang,
                _lib: Some(lib),
            })
        }
    }

    /// Test-only constructor for grammars linked into the binary directly.
    #[cfg(test)]
    fn try_new_from_language(lang_name: &str, lang: ts::Language) -> Result<Self, String> {
        let mut inner = ts::Parser::new();
        inner.set_language(&lang).map_err(|e| e.to_string())?;

        Ok(Self {
            lang_name: lang_name.to_owned(),
            inner,
            lang,
            _lib: None,
        })
    }

    /// Compile `query` against this parser's language and wrap it in a
    /// [`Tokenizer`].
    ///
    /// # Errors
    /// Returns an error if the query fails to compile, or if it uses custom
    /// predicates not listed in [`SUPPORTED_PREDICATES`].
    pub fn new_tokenizer(&self, query: &str) -> Result<Tokenizer, String> {
        let q = ts::Query::new(&self.lang, query).map_err(|e| format!("{e:?}"))?;
        let cur = ts::QueryCursor::new();

        // Reject queries making use of custom predicates we don't implement:
        // accepting them would silently produce wrong highlighting.
        let mut unsupported_predicates = HashSet::new();
        for i in 0..q.pattern_count() {
            for p in q.general_predicates(i) {
                if !SUPPORTED_PREDICATES.contains(&p.operator.as_ref()) {
                    unsupported_predicates.insert(p.operator.clone());
                }
            }
        }

        if !unsupported_predicates.is_empty() {
            error!("Unsupported custom tree-sitter predicates found: {unsupported_predicates:?}");
            info!("Supported custom tree-sitter predicates: {SUPPORTED_PREDICATES:?}");
            info!("Please modify the highlights.scm file to remove the unsupported predicates");

            return Err(format!(
                "{} highlights query contained unsupported custom predicates",
                self.lang_name
            ));
        }

        let names = q.capture_names().iter().map(|s| s.to_string()).collect();

        Ok(Tokenizer {
            q,
            cur,
            names,
            ranges: Vec::new(),
            tokenized_regions: Vec::new(),
        })
    }
}
380
/// Runs a compiled highlights query over a parse tree and caches the captured
/// byte ranges, tracking which regions of the buffer have been processed.
pub struct Tokenizer {
    // Compiled highlights query
    q: ts::Query,
    cur: ts::QueryCursor,
    // Capture names from `q`, indexed by capture index
    names: Vec<String>,
    // Cached syntax ranges produced by running `q` (kept sorted and deduped)
    ranges: Vec<SyntaxRange>,
    // Byte regions of the buffer that have already been tokenized
    tokenized_regions: Vec<ByteRange>,
}
391
392impl fmt::Debug for Tokenizer {
393 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
394 write!(f, "Tokenizer")
395 }
396}
397
398#[inline]
399fn mark_region(regions: &mut Vec<ByteRange>, from: usize, to: usize) {
400 regions.push(ByteRange { from, to });
401 if regions.len() == 1 {
402 return;
403 }
404
405 regions.sort_unstable();
406
407 let mut idx = 0;
408 for i in 1..regions.len() {
409 if regions[idx].to >= regions[i].from {
410 regions[idx].to = max(regions[idx].to, regions[i].to);
412 } else {
413 idx += 1;
415 regions.swap(idx, i);
416 }
417 }
418
419 regions.truncate(idx + 1);
422}
423
/// Return the sub-span of `from..to` (if any) that is not covered by the
/// sorted, non-overlapping set of `regions`.
///
/// Only a single contiguous missing span is reported per call: callers
/// tokenize and mark that span, after which a subsequent call reports any
/// remaining gap.
#[inline]
fn missing_region(regions: &[ByteRange], from: usize, to: usize) -> Option<(usize, usize)> {
    let mut it = regions.iter();
    while let Some(r) = it.next() {
        if to < r.from {
            // The request ends before this (and every later) region starts:
            // the whole request is missing.
            break;
        } else if from < r.from {
            // The request starts in a gap before `r`. It is missing up to the
            // start of `r` when `r` extends past the request; otherwise the
            // whole request is reported, including any bytes `r` covers
            // (see the "around an existing region" test case).
            let end = if r.to > to { r.from } else { to };
            return Some((from, end));
        } else if r.contains(from, to) {
            // Fully covered by `r`: nothing missing.
            return None;
        } else if from < r.to && to > r.to {
            // The request starts inside `r` and extends past its end: missing
            // from the end of `r` to the start of the next region, or to the
            // end of the request if no region starts before it.
            let end = match it.next() {
                Some(r) if r.from < to => r.from,
                _ => to,
            };
            return Some((r.to, end));
        }
    }

    // No region intersected the request at all.
    Some((from, to))
}
452
impl Tokenizer {
    /// Drop all cached ranges and coverage tracking, forcing a full
    /// re-tokenize on the next call to [`Tokenizer::update`].
    fn clear(&mut self) {
        self.ranges.clear();
        self.tokenized_regions.clear();
    }

    /// See the free function [`missing_region`].
    fn missing_region(&self, from: usize, to: usize) -> Option<(usize, usize)> {
        missing_region(&self.tokenized_regions, from, to)
    }

    /// Record `from..to` as tokenized (see the free function [`mark_region`]).
    fn mark_region(&mut self, from: usize, to: usize) {
        mark_region(&mut self.tokenized_regions, from, to);
    }

    /// Run the highlights query over the `from..to` byte range of `gb`
    /// beneath `root`, caching the captured ranges.
    pub fn update(&mut self, root: ts::Node<'_>, gb: &GapBuffer, from: usize, to: usize) {
        self.cur.set_byte_range(from..to);

        // `captures` returns a StreamingIterator, so this can't be a for loop.
        let mut it = self.cur.captures(&self.q, root, gb);

        while let Some((m, idx)) = it.next() {
            let cap = m.captures[*idx];
            let r = ByteRange::from(cap.node.range());
            if let Some(prev) = self.ranges.last_mut() {
                if r == prev.r {
                    // Exactly the same range captured again: the later
                    // capture replaces the earlier one.
                    prev.cap_idx = Some(cap.index as usize);
                    continue;
                } else if r.from < prev.r.to && prev.r.from < r.to {
                    // Partial overlap with the previous range: keep the
                    // earlier match and skip this one.
                    continue;
                }
            }
            self.ranges.push(SyntaxRange {
                r,
                cap_idx: Some(cap.index as usize),
            });
        }

        // Ranges from separate update calls can interleave, so re-sort and
        // dedup before recording this region as covered.
        self.ranges.sort_unstable();
        self.ranges.dedup();
        self.mark_region(from, to);
    }

    /// Iterate over the lines of `gb` starting at `line`, styled using the
    /// cached capture names and ranges.
    #[inline]
    pub fn iter_tokenized_lines_from<'a>(
        &'a self,
        line: usize,
        gb: &'a GapBuffer,
        dot_range: Range,
        load_exec_range: Option<(bool, Range)>,
    ) -> LineIter<'a> {
        LineIter::new(
            line,
            gb,
            dot_range,
            load_exec_range,
            &self.names,
            &self.ranges,
        )
    }

    // Test helper: resolve each cached range to its capture name.
    #[cfg(test)]
    fn range_tokens(&self) -> Vec<crate::syntax::RangeToken<'_>> {
        use crate::syntax::{RangeToken, TK_DEFAULT};

        let names = self.q.capture_names();

        self.ranges
            .iter()
            .map(|sr| RangeToken {
                tag: sr.cap_idx.map(|i| names[i]).unwrap_or(TK_DEFAULT),
                r: sr.r,
            })
            .collect()
    }
}
530
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        buffer::Buffer,
        dot::{Cur, Dot},
        editor::Action,
        syntax::{RangeToken, SyntaxState, SyntaxStateInner},
    };
    use ad_event::Source;
    use simple_test_case::test_case;

    // Shorthand for building an expected RangeToken
    fn rt(tag: &str, from: usize, to: usize) -> RangeToken<'_> {
        RangeToken {
            tag,
            r: ByteRange { from, to },
        }
    }

    // Deleting a character should shift all following token ranges left once
    // the tokenizer state has been updated.
    #[test]
    fn char_delete_correctly_update_state() {
        let query = r#"
"fn" @keyword

[ "(" ")" "{" "}" ] @punctuation"#;

        let s = "fn main() {}";
        let mut b = Buffer::new_unnamed(0, s, Default::default());
        let gb = &b.txt;
        let mut ts =
            TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb)
                .unwrap();
        ts.update(gb, 0, gb.len());
        b.syntax_state = Some(SyntaxState::ts(ts));

        assert_eq!(b.str_contents(), "fn main() {}");

        let ranges = match b.syntax_state.as_ref() {
            Some(SyntaxState {
                inner: SyntaxStateInner::Ts(ts),
                ..
            }) => ts.t.range_tokens(),
            _ => panic!("no ts state"),
        };
        assert_eq!(
            ranges,
            vec![
                rt("keyword", 0, 2),
                rt("punctuation", 7, 8),
                rt("punctuation", 8, 9),
                rt("punctuation", 10, 11),
                rt("punctuation", 11, 12),
            ]
        );

        // Delete the space at char index 9 (between ')' and '{') then
        // re-tokenize everything
        b.dot = Dot::Cur { c: Cur { idx: 9 } };
        b.handle_action(Action::Delete, Source::Fsys);
        b.syntax_state
            .as_mut()
            .unwrap()
            .update(&b.txt, 0, usize::MAX - 1);
        let ranges = match b.syntax_state.as_ref() {
            Some(SyntaxState {
                inner: SyntaxStateInner::Ts(ts),
                ..
            }) => ts.t.range_tokens(),
            _ => panic!("no ts state"),
        };

        assert_eq!(b.str_contents(), "fn main(){}");
        assert_eq!(ranges.len(), 5);

        assert_eq!(ranges[3], rt("punctuation", 9, 10), "opening curly");
        assert_eq!(ranges[4], rt("punctuation", 10, 11), "closing curly");
    }

    // Byte-identical captures are resolved in favour of the later capture,
    // while partially overlapping captures defer to the earlier match.
    #[test]
    fn overlapping_tokens_prefer_previous_matches() {
        let query = r#"
(identifier) @variable

(import_statement
  name: (dotted_name
    (identifier) @module))

(import_statement
  name: (aliased_import
    name: (dotted_name
      (identifier) @module)
    alias: (identifier) @module))

(import_from_statement
  module_name: (dotted_name
    (identifier) @module))"#;

        let s = "import builtins as _builtins";
        let b = Buffer::new_unnamed(0, s, Default::default());
        let gb = &b.txt;
        let mut ts = TsState::try_new_from_language(
            "python",
            tree_sitter_python::LANGUAGE.into(),
            query,
            gb,
        )
        .unwrap();
        ts.update(gb, 0, gb.len());

        assert_eq!(
            ts.t.range_tokens(),
            vec![
                rt("module", 7, 15),
                rt("module", 19, 28)
            ]
        );
    }

    // Built-in predicates (#match?, #any-of?) are handled by tree-sitter
    // itself and must not be rejected as unsupported custom predicates.
    #[test]
    fn built_in_predicates_work() {
        let query = r#"
(identifier) @variable

; Assume all-caps names are constants
((identifier) @constant
 (#match? @constant "^[A-Z][A-Z%d_]*$"))

((identifier) @constant.builtin
 (#any-of? @constant.builtin "Some" "None" "Ok" "Err"))

[ "(" ")" "{" "}" ] @punctuation"#;

        let s = "Ok(Some(42)) foo BAR";
        let b = Buffer::new_unnamed(0, s, Default::default());
        let gb = &b.txt;
        let mut ts =
            TsState::try_new_from_language("rust", tree_sitter_rust::LANGUAGE.into(), query, gb)
                .unwrap();
        ts.update(gb, 0, gb.len());

        assert_eq!(
            ts.t.range_tokens(),
            vec![
                rt("constant.builtin", 0, 2),
                rt("punctuation", 2, 3),
                rt("constant.builtin", 3, 7),
                rt("punctuation", 7, 8),
                rt("punctuation", 10, 11),
                rt("punctuation", 11, 12),
                rt("variable", 13, 16),
                rt("constant", 17, 20),
            ]
        );
    }

    // Shorthand for building a ByteRange
    fn br(from: usize, to: usize) -> ByteRange {
        ByteRange { from, to }
    }

    #[test_case(vec![], 0, 5, vec![br(0, 5)]; "no initial regions")]
    #[test_case(vec![br(0, 5)], 0, 5, vec![br(0, 5)]; "existing region idempotent")]
    #[test_case(vec![br(9, 15)], 0, 5, vec![br(0, 5), br(9, 15)]; "disjoint regions")]
    #[test_case(vec![br(0, 5)], 3, 5, vec![br(0, 5)]; "existing region contains new")]
    #[test_case(vec![br(0, 5)], 3, 9, vec![br(0, 9)]; "existing region extending past current end")]
    #[test_case(vec![br(3, 5)], 0, 3, vec![br(0, 5)]; "existing region extending before current start")]
    #[test_case(vec![br(3, 5)], 0, 9, vec![br(0, 9)]; "existing region contained within new")]
    #[test_case(vec![br(0, 5), br(7, 15)], 4, 9, vec![br(0, 15)]; "new region joins multiple existing")]
    #[test]
    fn mark_region_works(
        mut regions: Vec<ByteRange>,
        from: usize,
        to: usize,
        expected: Vec<ByteRange>,
    ) {
        mark_region(&mut regions, from, to);
        assert_eq!(regions, expected);
    }

    #[test_case(vec![br(0, 100)], 5, 20, None; "contained")]
    #[test_case(vec![br(0, 1366)], 89, 1385, Some((1366, 1385)); "scroll down")]
    #[test_case(vec![br(100, 1366)], 0, 255, Some((0, 100)); "scroll up")]
    #[test_case(vec![br(100, 1366)], 0, 80, Some((0, 80)); "before")]
    #[test_case(vec![br(100, 1366)], 1400, 1500, Some((1400, 1500)); "after")]
    #[test_case(vec![br(0, 100), br(200, 300)], 150, 180, Some((150, 180)); "in between regions")]
    #[test_case(vec![br(0, 100), br(200, 300)], 50, 180, Some((100, 180)); "from one range into gap")]
    #[test_case(vec![br(0, 100), br(200, 300)], 150, 280, Some((150, 200)); "from gap into region")]
    #[test_case(vec![br(0, 100), br(200, 300)], 50, 280, Some((100, 200)); "from one region into another")]
    #[test_case(vec![br(50, 100), br(200, 300)], 0, 150, Some((0, 150)); "around an existing region")]
    #[test]
    fn missing_region_works(
        regions: Vec<ByteRange>,
        from: usize,
        to: usize,
        expected: Option<(usize, usize)>,
    ) {
        let res = missing_region(&regions, from, to);
        assert_eq!(res, expected);
    }
}