1use std::{collections::HashMap, ops::Range, path::Path, sync::LazyLock};
9
10use gapbuf::GapBuffer;
11use parking_lot::Mutex;
12use streaming_iterator::StreamingIterator;
13use tree_sitter::{
14 InputEdit, Language, Node, Parser, Point as TSPoint, Query, QueryCursor, TextProvider, Tree,
15};
16
17use super::{Change, Key, Point, Reader, Tag, Text};
18use crate::{
19 cfg::PrintCfg,
20 form::{self, FormId},
21 text::Matcheable,
22};
23
24pub struct TsParser {
25 parser: Parser,
26 queries: [Query; 2],
27 tree: Tree,
28 forms: &'static [(FormId, Key, Key)],
29 name: &'static str,
30 keys: Range<Key>,
31 old_tree: Option<Tree>,
32}
33
34impl TsParser {
35 pub fn new(text: &mut Text, path: impl AsRef<Path>) -> Option<Self> {
36 let (name, lang, queries) = lang_from_path(path)?;
37
38 let mut parser = Parser::new();
39 parser.set_language(lang).unwrap();
40 let queries = queries.map(|q| Query::new(lang, q).unwrap());
41
42 let tree = parser
43 .parse_with_options(&mut buf_parse(text), None, None)
44 .unwrap();
45 let mut cursor = QueryCursor::new();
46 let buf = TsBuf(&text.buf);
47
48 let (keys, forms) = forms_from_query(name, &queries[0]);
49
50 let mut captures = cursor.captures(&queries[0], tree.root_node(), buf);
51
52 while let Some((captures, _)) = captures.next() {
53 for cap in captures.captures.iter() {
54 let range = cap.node.range();
55 let (start, end) = (range.start_byte, range.end_byte);
56 let (form, start_key, end_key) = forms[cap.index as usize];
57 if start != end {
58 text.tags.insert(start, Tag::PushForm(form), start_key);
59 text.tags.insert(end, Tag::PopForm(form), end_key);
60 }
61 }
62 }
63
64 Some(TsParser {
65 parser,
66 queries,
67 tree,
68 forms,
69 name,
70 keys,
71 old_tree: None,
72 })
73 }
74
75 pub fn lang(&self) -> &'static str {
76 self.name
77 }
78
79 pub fn indent_on(&self, text: &mut Text, p: Point, cfg: PrintCfg) -> Option<usize> {
82 let query = &self.queries[1];
83 if query.pattern_count() == 0 {
84 return None;
85 }
86 let tab = cfg.tab_stops.size() as i32;
87 let (start, _) = text.points_of_line(p.line());
88 let indented_start = text
89 .chars_fwd(start)
90 .take_while(|(p, _)| p.line() == start.line())
91 .find_map(|(p, c)| (!c.is_whitespace()).then_some(p));
92 let root = self.tree.root_node();
94
95 type Captures<'a> = HashMap<&'a str, HashMap<usize, HashMap<&'a str, Option<&'a str>>>>;
96 let mut caps: Captures = HashMap::new();
97 let q = {
98 let mut cursor = QueryCursor::new();
99 let buf = TsBuf(&text.buf);
100 cursor.matches(query, root, buf).for_each(|qm| {
101 for cap in qm.captures.iter() {
102 let cap_end = query.capture_names()[cap.index as usize]
103 .strip_prefix("indent.")
104 .unwrap();
105 let nodes = if let Some(nodes) = caps.get_mut(cap_end) {
106 nodes
107 } else {
108 caps.insert(cap_end, HashMap::new());
109 caps.get_mut(cap_end).unwrap()
110 };
111 let props = query.property_settings(qm.pattern_index).iter();
112 nodes.insert(
113 cap.node.id(),
114 props
115 .map(|p| {
116 let key = p.key.strip_prefix("indent.").unwrap();
117 (key, p.value.as_deref())
118 })
119 .collect(),
120 );
121 }
122 });
123 |caps: &Captures, node: Node, queries: &[&str]| {
124 caps.get(queries[0])
125 .and_then(|nodes| nodes.get(&node.id()))
126 .is_some_and(|props| {
127 let key = queries.get(1);
128 key.is_none_or(|key| props.iter().any(|(k, _)| k == key))
129 })
130 }
131 };
132
133 let mut opt_node = if let Some(indented_start) = indented_start {
134 Some(descendant_in(root, indented_start.byte()))
135 } else {
137 let Some((prev_l, line)) = text
139 .lines_in((Point::default(), start))
140 .rev()
141 .find(|(_, line)| !(line.matches(r"^\s*$", ..).unwrap()))
142 else {
143 return Some(0);
145 };
146 let trail = line.chars().rev().take_while(|c| c.is_whitespace()).count();
147 let (prev_start, prev_end) = text.points_of_line(prev_l);
148 let mut node = descendant_in(root, prev_end.byte() - (trail + 1));
149 if node.kind().contains("comment") {
150 let first_node = descendant_in(root, prev_start.byte());
154 if first_node.id() != node.id() {
155 node = descendant_in(root, node.start_byte() - 1)
156 }
157 }
158
159 Some(if q(&caps, node, &["end"]) {
160 descendant_in(root, start.byte())
161 } else {
162 node
163 })
164 };
165
166 if q(&caps, opt_node.unwrap(), &["zero"]) {
167 return Some(0);
168 }
169
170 let mut indent = 0;
171 let mut processed_lines = Vec::new();
172 while let Some(node) = opt_node {
173 if !q(&caps, node, &["begin"])
176 && node.start_position().row < p.line()
177 && p.line() <= node.end_position().row
178 {
179 if !q(&caps, node, &["align"]) && q(&caps, node, &["auto"]) {
180 return None;
181 } else if q(&caps, node, &["ignore"]) {
182 return Some(0);
183 }
184 }
185
186 let s_line = node.range().start_point.row;
187 let e_line = node.range().end_point.row;
188 let should_process = !processed_lines.contains(&s_line);
189
190 let mut is_processed = false;
191
192 if should_process
193 && ((s_line == p.line() && q(&caps, node, &["branch"]))
194 || (s_line != p.line() && q(&caps, node, &["dedent"])))
195 {
196 indent -= tab;
197 is_processed = true;
198 }
199
200 let is_in_err = should_process && node.parent().is_some_and(|p| p.is_error());
201 if should_process
204 && q(&caps, node, &["begin"])
205 && (s_line != e_line || is_in_err || q(&caps, node, &["begin", "immediate"]))
206 && (s_line != p.line() || q(&caps, node, &["begin", "start_at_same_line"]))
207 {
208 is_processed = true;
209 indent += tab;
210 }
211
212 if is_in_err && !q(&caps, node, &["align"]) {
213 let mut cursor = node.walk();
214 for child in node.children(&mut cursor) {
215 if q(&caps, child, &["align"]) {
216 let props = caps["align"][&child.id()].clone();
217 caps.get_mut("align").unwrap().insert(node.id(), props);
218 }
219 }
220 }
221
222 type FoundDelim<'a> = (Option<Node<'a>>, bool);
223 fn find_delim<'a>(text: &mut Text, node: Node<'a>, delim: &str) -> FoundDelim<'a> {
224 let mut c = node.walk();
225 let child = node.children(&mut c).find(|child| child.kind() == delim);
226 let ret = child.map(|child| {
227 let (_, end) = text.points_of_line(child.range().start_point.row);
228 let range = child.range().start_byte..end.byte();
229 text.make_contiguous_in(range.clone());
230 let line = unsafe { text.continuous_in_unchecked(range) };
231 let is_last_in_line = line.split_whitespace().any(|w| w != delim);
232 (child, is_last_in_line)
233 });
234 let (child, is_last_in_line) = ret.unzip();
235 (child, is_last_in_line.unwrap_or(false))
236 }
237
238 if should_process
239 && q(&caps, node, &["align"])
240 && (s_line != e_line || is_in_err)
241 && s_line != p.line()
242 {
243 let props = &caps["align"][&node.id()];
244 let (o_delim_node, o_is_last_in_line) = props
245 .get(&"open_delimiter")
246 .and_then(|delim| delim.map(|d| find_delim(text, node, d)))
247 .unwrap_or((Some(node), false));
248 let (c_delim_node, c_is_last_in_line) = props
249 .get(&"close_delimiter")
250 .and_then(|delim| delim.map(|d| find_delim(text, node, d)))
251 .unwrap_or((Some(node), false));
252
253 if let Some(o_delim_node) = o_delim_node {
254 let o_s_line = o_delim_node.start_position().row;
255 let o_s_col = o_delim_node.start_position().row;
256 let c_s_line = c_delim_node.map(|n| n.start_position().row);
257
258 let indent_is_absolute = if o_is_last_in_line && should_process {
261 indent += tab;
262 if c_is_last_in_line && c_s_line.is_some_and(|l| l < p.line()) {
265 indent = (indent - 1).max(0);
266 }
267 false
268 } else if c_is_last_in_line
270 && let Some(c_s_line) = c_s_line
271 && (o_s_line != c_s_line && c_s_line < p.line())
274 {
275 indent = (indent - 1).max(0);
276 false
277 } else {
278 let inc = props.get("increment").cloned().flatten();
279 indent = o_s_col as i32 + inc.map(str::parse::<i32>).unwrap().unwrap();
280 true
281 };
282
283 let avoid_last_matching_next = c_s_line
287 .is_some_and(|c_s_line| c_s_line != o_s_line && c_s_line == p.line())
288 && props.contains_key("avoid_last_matching_next");
289 if avoid_last_matching_next {
290 indent += tab;
291 }
292 is_processed = true;
293 if indent_is_absolute {
294 return Some(indent as usize);
295 }
296 }
297 }
298
299 if should_process && is_processed {
300 processed_lines.push(s_line);
301 }
302 opt_node = node.parent();
303 }
304
305 Some(indent as usize)
306 }
307
308 pub fn parse_with<T: AsRef<[u8]>, F: FnMut(usize, tree_sitter::Point) -> T>(
309 &mut self,
310 callback: &mut F,
311 old_tree: Option<&Tree>,
312 ) -> Option<Tree> {
313 self.parser.parse_with_options(callback, old_tree, None)
314 }
315}
316
317fn descendant_in(node: Node, byte: usize) -> Node {
318 node.descendant_for_byte_range(byte, byte + 1).unwrap()
319}
320
321impl Reader for TsParser {
322 fn apply_changes(&mut self, text: &Text, changes: &[Change<&str>]) {
323 for change in changes {
324 let start = change.start();
325 let added = change.added_end();
326 let taken = change.taken_end();
327
328 let ts_start = ts_point(start, text);
329 let ts_taken_end = ts_point_from(taken, (ts_start.column, start), text);
330 let ts_added_end = ts_point_from(added, (ts_start.column, start), text);
331 self.tree.edit(&InputEdit {
332 start_byte: start.byte(),
333 old_end_byte: taken.byte(),
334 new_end_byte: added.byte(),
335 start_position: ts_start,
336 old_end_position: ts_taken_end,
337 new_end_position: ts_added_end,
338 });
339 }
340
341 let tree = self
342 .parser
343 .parse_with_options(&mut buf_parse(text), Some(&self.tree), None)
344 .unwrap();
345 self.old_tree = Some(std::mem::replace(&mut self.tree, tree));
346 }
347
348 fn ranges_to_update(&mut self, text: &Text, changes: &[Change<&str>]) -> Vec<Range<usize>> {
349 let old_tree = self.old_tree.as_ref().unwrap();
350 let mut cursor = QueryCursor::new();
351 let buf = TsBuf(&text.buf);
352
353 let mut to_update = Vec::new();
354
355 for change in changes {
356 let start = change.start();
357 let added = change.added_end();
358 let start = text.point_at_line(start.line()).byte();
359 let end = text.point_at_line(added.line() + 1).byte();
360
361 let ts_start = start.saturating_sub(1);
364 let ts_end = (end + 1).min(text.len().byte());
365 cursor.set_byte_range(ts_start..ts_end);
366
367 let mut this_to_update = Vec::new();
368
369 let mut query_matches = cursor.captures(&self.queries[0], old_tree.root_node(), buf);
375 while let Some((query_match, _)) = query_matches.next() {
376 for cap in query_match.captures.iter() {
377 let range = cap.node.range();
378 if range.start_point.row != range.end_point.row {
379 this_to_update.push((range, cap.index));
380 }
381 }
382 }
383
384 let mut query_matches = cursor.captures(&self.queries[0], self.tree.root_node(), buf);
385 'ml_range_edited: while let Some((query_match, _)) = query_matches.next() {
386 for cap in query_match.captures.iter() {
387 let range = cap.node.range();
388 let entry = (range, cap.index);
389
390 if range.start_point.row != range.end_point.row
396 && this_to_update
397 .extract_if(.., |r| *r == entry)
398 .next()
399 .is_none()
400 {
401 this_to_update.push(entry);
402 break 'ml_range_edited;
403 }
404 }
405 }
406 if this_to_update.is_empty() {
407 super::merge_range_in(&mut to_update, start..end)
408 } else {
409 super::merge_range_in(&mut to_update, start..text.len().byte());
413 break;
414 }
415 }
416
417 to_update
418 }
419
420 fn update_range(&mut self, text: &mut Text, range: Range<usize>) {
421 let buf = TsBuf(&text.buf);
422
423 text.tags.remove_from(range.clone(), self.keys.clone());
424
425 let start = range.start.saturating_sub(1);
426 let end = (range.end + 1).min(text.len().byte());
427
428 let mut cursor = QueryCursor::new();
429
430 cursor.set_byte_range(start..end);
431 let mut query_matches = cursor.captures(&self.queries[0], self.tree.root_node(), buf);
432 while let Some((query_match, _)) = query_matches.next() {
433 for cap in query_match.captures.iter() {
434 let bytes = cap.node.byte_range();
435 let (form, start_key, end_key) = self.forms[cap.index as usize];
436 if bytes.start != bytes.end {
437 text.tags
438 .insert(bytes.start, Tag::PushForm(form), start_key);
439 text.tags.insert(bytes.end, Tag::PopForm(form), end_key);
440 }
441 }
442 }
443 }
444}
445
446fn buf_parse<'a>(text: &'a Text) -> impl FnMut(usize, TSPoint) -> &'a [u8] {
447 let [s0, s1] = text.strs();
448 |byte, _point| {
449 if byte < s0.len() {
450 &s0.as_bytes()[byte..]
451 } else {
452 &s1.as_bytes()[byte - s0.len()..]
453 }
454 }
455}
456
457fn ts_point(point: Point, text: &Text) -> TSPoint {
458 let strs = text.strs_in(..point.byte());
459 let iter = strs.into_iter().flat_map(str::chars).rev();
460 let col = iter.take_while(|&b| b != '\n').count();
461
462 TSPoint::new(point.line(), col)
463}
464
465fn ts_point_from(to: Point, from: (usize, Point), text: &Text) -> TSPoint {
466 let (col, from) = from;
467 let strs = text.strs_in((from, to));
468 let iter = strs.into_iter().flat_map(str::chars).rev();
469
470 let col = if to.line() == from.line() {
471 col + iter.count()
472 } else {
473 iter.take_while(|&b| b != '\n').count()
474 };
475
476 TSPoint::new(to.line(), col)
477}
478
479fn forms_from_query(
480 lang: &'static str,
481 query: &Query,
482) -> (Range<Key>, &'static [(FormId, Key, Key)]) {
483 static LISTS: Mutex<Vec<(&str, (Range<Key>, &[(FormId, Key, Key)]))>> = Mutex::new(Vec::new());
484 let mut lists = LISTS.lock();
485
486 if let Some((_, (keys, forms))) = lists.iter().find(|(l, _)| *l == lang) {
487 (keys.clone(), forms)
488 } else {
489 let mut forms = Vec::new();
490 let capture_names = query.capture_names();
491 let keys = Key::new_many(capture_names.len() * 2);
492
493 let mut iter_keys = keys.clone();
494 for name in capture_names {
495 let start = iter_keys.next().unwrap();
496 let end = iter_keys.next().unwrap();
497 forms.push(if name.contains('.') {
498 let (refed, _) = name.rsplit_once('.').unwrap();
499 (form::set_weak(name, refed), start, end)
500 } else {
501 (form::set_weak(name, "Default"), start, end)
502 });
503 }
504
505 lists.push((lang, (keys, forms.leak())));
506 let (_, (keys, forms)) = lists.last().unwrap();
507 (keys.clone(), forms)
508 }
509}
510
511impl<'a> TextProvider<&'a [u8]> for TsBuf<'a> {
512 type I = std::array::IntoIter<&'a [u8], 2>;
513
514 fn text(&mut self, node: tree_sitter::Node) -> Self::I {
515 let range = node.range();
516 let (s0, s1) = self.0.range(range.start_byte..range.end_byte).as_slices();
517
518 [s0, s1].into_iter()
519 }
520}
521
522#[derive(Clone, Copy)]
523struct TsBuf<'a>(&'a GapBuffer<u8>);
524
525#[allow(unused)]
526#[cfg(debug_assertions)]
527fn log_node(node: tree_sitter::Node) {
528 use std::fmt::Write;
529
530 let mut cursor = node.walk();
531 let mut node = Some(cursor.node());
532 let mut log = String::new();
533 while let Some(no) = node {
534 let indent = " ".repeat(cursor.depth() as usize);
535 writeln!(log, "{indent}{no:?}").unwrap();
536 let mut next_exists = cursor.goto_first_child() || cursor.goto_next_sibling();
537 while !next_exists && cursor.goto_parent() {
538 next_exists = cursor.goto_next_sibling();
539 }
540 node = next_exists.then_some(cursor.node());
541 }
542
543 crate::log_file!("{log}");
544}
545
546fn lang_from_path(
547 path: impl AsRef<Path>,
548) -> Option<(&'static str, &'static Language, [&'static str; 2])> {
549 type Lang<'a> = ((&'a str, &'a str, &'a str), &'a Language, [&'a str; 2]);
550 static LANGUAGES: LazyLock<Mutex<Vec<Lang>>> = LazyLock::new(|| {
551 macro l($lang:ident) {
552 Box::leak(Box::new($lang::LANGUAGE.into()))
553 }
554 macro lf($lang:ident) {
555 Box::leak(Box::new($lang::language()))
556 }
557 macro h($lang:ident) {{
558 let hi = include_str!(concat!(
559 "../../../ts-queries/",
560 stringify!($lang),
561 "/highlights.scm"
562 ));
563 [hi, ""]
564 }}
565 macro i($lang:ident) {
566 include_str!(concat!(
567 "../../../ts-queries/",
568 stringify!($lang),
569 "/indents.scm"
570 ))
571 }
572 macro h_i($lang:ident) {
573 [h!($lang)[0], i!($lang)]
574 }
575
576 let lang_ocaml = Box::leak(Box::new(ts_ocaml::LANGUAGE_OCAML.into()));
577 let lang_php = Box::leak(Box::new(ts_php::LANGUAGE_PHP_ONLY.into()));
578 let lang_ts = Box::leak(Box::new(ts_ts::LANGUAGE_TYPESCRIPT.into()));
579 let lang_xml = Box::leak(Box::new(ts_xml::LANGUAGE_XML.into()));
580
581 Mutex::new(vec![
582 (("asm", "Assembly", "assembly"), l!(ts_asm), h!(asm)),
606 (("c", "C", "c"), l!(ts_c), h_i!(c)),
607 (("cc", "C++", "cpp"), l!(ts_cpp), h_i!(cpp)),
608 (("cpp", "C++", "cpp"), l!(ts_cpp), h_i!(cpp)),
609 (("cs", "C#", "csharp"), l!(ts_c_sharp), h!(c_sharp)),
610 (("css", "CSS", "css"), l!(ts_css), h_i!(css)),
611 (("cxx", "C++", "cpp"), l!(ts_cpp), h_i!(cpp)),
612 (("dart", "Dart", "dart"), lf!(ts_dart), h_i!(dart)),
613 (("erl", "Erlang", "erlang"), l!(ts_erlang), h!(erlang)),
614 (("ex", "Elixir", "elixir"), l!(ts_elixir), h_i!(elixir)),
615 (("exs", "Elixir", "elixir"), l!(ts_elixir), h_i!(elixir)),
616 (("for", "Fortran", "fortran"), l!(ts_fortran), h_i!(fortran)),
617 (("fpp", "Fortran", "fortran"), l!(ts_fortran), h_i!(fortran)),
618 (("gleam", "Gleam", "gleam"), l!(ts_gleam), h_i!(gleam)),
619 (("go", "Go", "go"), l!(ts_go), h_i!(go)),
620 (("groovy", "Groovy", "groovy"), l!(ts_groovy), h_i!(groovy)),
621 (("gvy", "Groovy", "groovy"), l!(ts_groovy), h_i!(groovy)),
622 (("h", "C", "c"), l!(ts_c), h_i!(c)),
623 (("hpp", "C++", "cpp"), l!(ts_cpp), h_i!(cpp)),
624 (("hrl", "Erlang", "erlang"), l!(ts_erlang), h!(erlang)),
625 (("hs", "Haskell", "haskell"), l!(ts_haskell), h!(haskell)),
626 (("hsc", "Haskell", "haskell"), l!(ts_haskell), h!(haskell)),
627 (("htm", "HTML", "html"), l!(ts_html), h_i!(html)),
628 (("html", "HTML", "html"), l!(ts_html), h_i!(html)),
629 (("hxx", "C++", "cpp"), l!(ts_cpp), h_i!(cpp)),
630 (("java", "Java", "java"), l!(ts_java), h_i!(java)),
631 (("jl", "Julia", "julia"), l!(ts_julia), h_i!(julia)),
632 (("js", "JavaScript", "javascript"), l!(ts_js), h_i!(js)),
633 (("json", "JSON", "json"), l!(ts_json), h_i!(json)),
634 (("jsonc", "JSON", "jsonc"), l!(ts_json), h_i!(json)),
635 (("lua", "Lua", "lua"), l!(ts_lua), h_i!(lua)),
636 (("m", "Objective-C", "objc"), l!(ts_objc), h_i!(objc)),
637 (("md", "Markdown", "markdown"), l!(ts_md), h_i!(markdown)),
638 (("ml", "OCaml", "ocaml"), lang_ocaml, h_i!(ocaml)),
639 (("nix", "Nix", "nix"), l!(ts_nix), h_i!(nix)),
640 (("php", "PHP", "php"), lang_php, h_i!(php)),
641 (("py", "Python", "python"), l!(ts_python), h_i!(python)),
642 (("pyc", "Python", "python"), l!(ts_python), h_i!(python)),
643 (("pyo", "Python", "python"), l!(ts_python), h_i!(python)),
644 (("r", "R", "r"), l!(ts_r), h_i!(r)),
645 (("rb", "Ruby", "ruby"), l!(ts_ruby), h_i!(ruby)),
646 (("rs", "Rust", "rust"), l!(ts_rust), h_i!(rust)),
647 (("sc", "Scala", "scala"), l!(ts_scala), h!(scala)),
648 (("scala", "Scala", "scala"), l!(ts_scala), h!(scala)),
649 (("scss", "SCSS", "scss"), lf!(ts_scss), h_i!(scss)),
650 (("sh", "Shell", "shell"), l!(ts_bash), h!(bash)),
651 (("sql", "SQL", "sql"), l!(ts_sequel), h_i!(sql)),
652 (("swift", "Swift", "swift"), l!(ts_swift), h_i!(swift)),
653 (("ts", "TypeScript", "typescript"), lang_ts, h!(ts)),
654 (("vim", "Viml", "viml"), lf!(ts_vim), h!(vim)),
655 (("xml", "XML", "xml"), lang_xml, h_i!(xml)),
656 (("xrl", "Erlang", "erlang"), l!(ts_erlang), h!(erlang)),
657 (("yaml", "YAML", "yaml"), l!(ts_yaml), h_i!(yaml)),
658 (("yml", "YAML", "yaml"), l!(ts_yaml), h_i!(yaml)),
659 (("yrl", "Erlang", "erlang"), l!(ts_erlang), h!(erlang)),
660 (("zig", "Zig", "zig"), l!(ts_zig), h_i!(zig)),
661 ])
662 });
663
664 let ext = path.as_ref().extension()?.to_str()?;
665 let langs = LANGUAGES.lock();
666 langs
667 .binary_search_by(|((rhs, ..), ..)| rhs.cmp(&ext))
668 .ok()
669 .map(|i| {
670 let ((_, name, _), lang, hl) = langs.get(i).unwrap();
671 (*name, *lang, *hl)
672 })
673}