1use editor_core::delta::TextDelta;
2use editor_core::processing::{DocumentProcessor, ProcessingEdit};
3use editor_core::{EditorStateManager, FoldRegion, Interval, LineIndex, StyleId, StyleLayerId};
4use std::collections::BTreeMap;
5use streaming_iterator::StreamingIterator;
6use tree_sitter::{InputEdit, Parser, Point, Query, QueryCursor, Tree};
7
/// Errors reported by the tree-sitter integration.
#[derive(Debug)]
pub enum TreeSitterError {
    /// Creating a WASM store or loading a WASM grammar failed.
    Wasm(String),
    /// An I/O failure described by the contained message.
    ///
    /// NOTE(review): no code visible in this part of the file constructs
    /// this variant; confirm where it is produced.
    Io(String),
    /// The parser rejected the language, or the WASM store could not be
    /// attached to the parser.
    Language(String),
    /// A query source string failed to compile.
    Query(String),
    /// A text delta did not match the cached document state, or a full
    /// resynchronisation was required but no full text was supplied.
    DeltaMismatch,
}

impl std::fmt::Display for TreeSitterError {
    /// Writes a short, human-readable description prefixed with `tree-sitter`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The only variant without a payload: emit the fixed text directly.
            Self::DeltaMismatch => f.write_str("tree-sitter delta mismatch"),
            Self::Query(msg) => write!(f, "tree-sitter query error: {msg}"),
            Self::Language(msg) => write!(f, "tree-sitter language error: {msg}"),
            Self::Io(msg) => write!(f, "tree-sitter io error: {msg}"),
            Self::Wasm(msg) => write!(f, "tree-sitter wasm error: {msg}"),
        }
    }
}

impl std::error::Error for TreeSitterError {}
36
/// How the processor brought its cached state up to date on the last call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TreeSitterUpdateMode {
    /// No tree existed yet; the full text was loaded and parsed.
    Initial,
    /// A text delta was applied on top of the existing tree.
    Incremental,
    /// The cached text was replaced wholesale and parsed from scratch.
    FullReparse,
    /// The requested version had already been handled; nothing was done.
    Skipped,
}
49
50#[derive(Debug, Clone)]
52pub enum TreeSitterLanguage {
53 Native(tree_sitter::Language),
55 Wasm {
57 language_id: String,
59 wasm_bytes: Vec<u8>,
61 },
62}
63
64impl TreeSitterLanguage {
65 pub fn native(language: tree_sitter::Language) -> Self {
67 Self::Native(language)
68 }
69
70 pub fn wasm(language_id: String, wasm_bytes: Vec<u8>) -> Self {
72 Self::Wasm {
73 language_id,
74 wasm_bytes,
75 }
76 }
77}
78
79#[derive(Debug, Clone)]
81pub struct TreeSitterProcessorConfig {
82 pub language: TreeSitterLanguage,
84 pub highlights_query: String,
86 pub folds_query: Option<String>,
88 pub capture_styles: BTreeMap<String, StyleId>,
90 pub style_layer: StyleLayerId,
92 pub preserve_collapsed_folds: bool,
94}
95
96impl TreeSitterProcessorConfig {
97 pub fn new(language: TreeSitterLanguage, highlights_query: impl Into<String>) -> Self {
103 Self {
104 language,
105 highlights_query: highlights_query.into(),
106 folds_query: None,
107 capture_styles: BTreeMap::new(),
108 style_layer: StyleLayerId::TREE_SITTER,
109 preserve_collapsed_folds: true,
110 }
111 }
112
113 pub fn with_folds_query(mut self, folds_query: impl Into<String>) -> Self {
115 self.folds_query = Some(folds_query.into());
116 self
117 }
118
119 pub fn with_simple_capture_styles<const N: usize>(
121 mut self,
122 styles: [(&'static str, StyleId); N],
123 ) -> Self {
124 for (name, style_id) in styles {
125 self.capture_styles.insert(name.to_string(), style_id);
126 }
127 self
128 }
129
130 pub fn set_preserve_collapsed_folds(&mut self, preserve: bool) {
132 self.preserve_collapsed_folds = preserve;
133 }
134
135 pub fn highlights_capture_names(&self) -> Result<Vec<String>, TreeSitterError> {
140 let query = match &self.language {
141 TreeSitterLanguage::Native(language) => Query::new(language, &self.highlights_query)
142 .map_err(|e| TreeSitterError::Query(e.to_string()))?,
143 TreeSitterLanguage::Wasm {
144 language_id,
145 wasm_bytes,
146 } => {
147 let engine = tree_sitter::wasmtime::Engine::default();
148 let mut store = tree_sitter::WasmStore::new(&engine)
149 .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
150 let language = store
151 .load_language(language_id, wasm_bytes)
152 .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
153 Query::new(&language, &self.highlights_query)
154 .map_err(|e| TreeSitterError::Query(e.to_string()))?
155 }
156 };
157
158 Ok(query
159 .capture_names()
160 .iter()
161 .map(|name| (*name).to_string())
162 .collect())
163 }
164}
165
166pub struct TreeSitterProcessor {
171 config: TreeSitterProcessorConfig,
172 parser: Parser,
173 highlight_query: Query,
174 highlight_capture_styles: Vec<Option<StyleId>>,
175 fold_query: Option<Query>,
176 tree: Option<Tree>,
177 needs_parse: bool,
178 text: String,
179 line_index: LineIndex,
180 last_synced_version: Option<u64>,
181 last_processed_version: Option<u64>,
182 last_update_mode: TreeSitterUpdateMode,
183 #[allow(dead_code)]
185 wasm_engine: Option<tree_sitter::wasmtime::Engine>,
186}
187
188impl TreeSitterProcessor {
189 pub fn new(config: TreeSitterProcessorConfig) -> Result<Self, TreeSitterError> {
191 let mut parser = Parser::new();
192 let (language, wasm_engine) = match &config.language {
193 TreeSitterLanguage::Native(language) => {
194 parser
195 .set_language(language)
196 .map_err(|e| TreeSitterError::Language(e.to_string()))?;
197 (language.clone(), None)
198 }
199 TreeSitterLanguage::Wasm {
200 language_id,
201 wasm_bytes,
202 } => {
203 let engine = tree_sitter::wasmtime::Engine::default();
204
205 let store = tree_sitter::WasmStore::new(&engine)
207 .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
208 parser
209 .set_wasm_store(store)
210 .map_err(|e| TreeSitterError::Language(e.to_string()))?;
211
212 let mut store = tree_sitter::WasmStore::new(&engine)
214 .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
215 let language = store
216 .load_language(language_id, wasm_bytes)
217 .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
218
219 parser
220 .set_language(&language)
221 .map_err(|e| TreeSitterError::Language(e.to_string()))?;
222
223 (language, Some(engine))
224 }
225 };
226
227 let highlight_query = Query::new(&language, &config.highlights_query)
228 .map_err(|e| TreeSitterError::Query(e.to_string()))?;
229 let highlight_capture_styles = highlight_query
230 .capture_names()
231 .iter()
232 .map(|name| config.capture_styles.get(*name).copied())
233 .collect::<Vec<_>>();
234
235 let fold_query = match config.folds_query.as_deref() {
236 Some(q) if !q.trim().is_empty() => {
237 Some(Query::new(&language, q).map_err(|e| TreeSitterError::Query(e.to_string()))?)
238 }
239 _ => None,
240 };
241
242 Ok(Self {
243 config,
244 parser,
245 highlight_query,
246 highlight_capture_styles,
247 fold_query,
248 tree: None,
249 needs_parse: false,
250 text: String::new(),
251 line_index: LineIndex::new(),
252 last_synced_version: None,
253 last_processed_version: None,
254 last_update_mode: TreeSitterUpdateMode::FullReparse,
255 wasm_engine,
256 })
257 }
258
259 pub fn last_update_mode(&self) -> TreeSitterUpdateMode {
261 self.last_update_mode
262 }
263
264 pub fn expand_selection_syntax(&self, start: usize, end: usize) -> Option<(usize, usize)> {
274 let tree = self.tree.as_ref()?;
275 let root = tree.root_node();
276
277 let (sel_start, sel_end) = if start <= end {
278 (start, end)
279 } else {
280 (end, start)
281 };
282 let start_byte = self.line_index.char_offset_to_byte_offset(sel_start);
283 let end_byte = self.line_index.char_offset_to_byte_offset(sel_end);
284
285 let mut node = root.descendant_for_byte_range(start_byte, end_byte)?;
286
287 loop {
288 let node_start = self
289 .line_index
290 .byte_offset_to_char_offset(node.start_byte());
291 let node_end = self.line_index.byte_offset_to_char_offset(node.end_byte());
292
293 if node_start == sel_start && node_end == sel_end {
295 if let Some(parent) = node.parent() {
296 node = parent;
297 continue;
298 }
299 return None;
300 }
301
302 return Some((node_start, node_end));
303 }
304 }
305
306 fn sync_from_text_full(&mut self, text: &str) {
307 self.text.clear();
308 self.text.push_str(text);
309 self.line_index = LineIndex::from_text(&self.text);
310 }
311
312 fn point_for_char_offset(&self, char_offset: usize) -> Point {
313 let (row, col) = self.line_index.char_offset_to_line_byte_column(char_offset);
314 Point { row, column: col }
315 }
316
317 fn advance_point(mut point: Point, text: &str) -> Point {
318 let mut parts = text.split('\n');
319 let Some(first) = parts.next() else {
320 return point;
321 };
322
323 point.column = point.column.saturating_add(first.len());
324 for part in parts {
325 point.row = point.row.saturating_add(1);
326 point.column = part.len();
327 }
328
329 point
330 }
331
332 fn apply_text_delta_incremental(&mut self, delta: &TextDelta) -> Result<(), TreeSitterError> {
333 if self.line_index.char_count() != delta.before_char_count {
334 return Err(TreeSitterError::DeltaMismatch);
335 }
336 if self.tree.is_none() {
337 return Err(TreeSitterError::DeltaMismatch);
338 }
339
340 for edit in &delta.edits {
341 let start_char = edit.start;
342 let deleted_chars = edit.deleted_text.chars().count();
343
344 let start_byte = self.line_index.char_offset_to_byte_offset(start_char);
345 let old_end_byte = start_byte.saturating_add(edit.deleted_text.len());
346 let new_end_byte = start_byte.saturating_add(edit.inserted_text.len());
347
348 let Some(old_slice) = self.text.get(start_byte..old_end_byte) else {
349 return Err(TreeSitterError::DeltaMismatch);
350 };
351 if old_slice != edit.deleted_text {
352 return Err(TreeSitterError::DeltaMismatch);
353 }
354
355 let start_position = self.point_for_char_offset(start_char);
356 let old_end_position = Self::advance_point(start_position, &edit.deleted_text);
357 let new_end_position = Self::advance_point(start_position, &edit.inserted_text);
358
359 if let Some(tree) = self.tree.as_mut() {
360 tree.edit(&InputEdit {
361 start_byte,
362 old_end_byte,
363 new_end_byte,
364 start_position,
365 old_end_position,
366 new_end_position,
367 });
368 }
369
370 self.text
371 .replace_range(start_byte..old_end_byte, &edit.inserted_text);
372 self.line_index.delete(start_char, deleted_chars);
373 self.line_index.insert(start_char, &edit.inserted_text);
374 }
375
376 if self.line_index.char_count() != delta.after_char_count {
377 return Err(TreeSitterError::DeltaMismatch);
378 }
379
380 Ok(())
381 }
382
383 fn parse(&mut self) -> Option<Tree> {
384 self.parser.parse(&self.text, self.tree.as_ref())
385 }
386
387 fn collect_highlight_intervals_in_byte_range(
388 &self,
389 tree: &Tree,
390 byte_range: Option<(usize, usize)>,
391 ) -> Vec<Interval> {
392 let mut cursor = QueryCursor::new();
393 if let Some((start, end)) = byte_range {
394 let start = start.min(self.text.len());
395 let end = end.min(self.text.len());
396 if end > start {
397 cursor.set_byte_range(start..end);
398 }
399 }
400 let root = tree.root_node();
401 let mut out = Vec::<Interval>::new();
402
403 let mut matches = cursor.matches(&self.highlight_query, root, self.text.as_bytes());
404 while let Some(m) = matches.next() {
405 for capture in m.captures {
406 let idx = capture.index as usize;
407 let Some(style_id) = self.highlight_capture_styles.get(idx).and_then(|x| *x) else {
408 continue;
409 };
410
411 let node = capture.node;
412 let start_byte = node.start_byte();
413 let end_byte = node.end_byte();
414 if end_byte <= start_byte {
415 continue;
416 }
417
418 let start = self.line_index.byte_offset_to_char_offset(start_byte);
419 let end = self.line_index.byte_offset_to_char_offset(end_byte);
420 if end <= start {
421 continue;
422 }
423
424 out.push(Interval::new(start, end, style_id));
425 }
426 }
427
428 out.sort_by_key(|i| (i.start, i.end, i.style_id));
429 out.dedup_by(|a, b| a.start == b.start && a.end == b.end && a.style_id == b.style_id);
430 out
431 }
432
433 fn collect_fold_regions_in_byte_range(
434 &self,
435 tree: &Tree,
436 byte_range: Option<(usize, usize)>,
437 ) -> Vec<FoldRegion> {
438 let Some(query) = self.fold_query.as_ref() else {
439 return Vec::new();
440 };
441
442 let mut cursor = QueryCursor::new();
443 if let Some((start, end)) = byte_range {
444 let start = start.min(self.text.len());
445 let end = end.min(self.text.len());
446 if end > start {
447 cursor.set_byte_range(start..end);
448 }
449 }
450 let root = tree.root_node();
451 let mut regions = Vec::<FoldRegion>::new();
452
453 let mut matches = cursor.matches(query, root, self.text.as_bytes());
454 while let Some(m) = matches.next() {
455 for capture in m.captures {
456 let node = capture.node;
457 let start_line = node.start_position().row;
458 let end_line = node.end_position().row;
459 if end_line > start_line {
460 regions.push(FoldRegion::new(start_line, end_line));
461 }
462 }
463 }
464
465 regions.sort_by_key(|r| (r.start_line, r.end_line));
466 regions.dedup_by(|a, b| a.start_line == b.start_line && a.end_line == b.end_line);
467 regions
468 }
469}
470
471impl DocumentProcessor for TreeSitterProcessor {
472 type Error = TreeSitterError;
473
474 fn process(&mut self, state: &EditorStateManager) -> Result<Vec<ProcessingEdit>, Self::Error> {
475 let version = state.version();
476 if self.last_processed_version == Some(version) {
477 self.last_update_mode = TreeSitterUpdateMode::Skipped;
478 return Ok(Vec::new());
479 }
480
481 if self.tree.is_none() {
482 let full = state.editor().get_text();
484 return self.process_text(version, None, Some(&full));
485 }
486
487 if let Some(delta) = state.last_text_delta() {
488 match self.process_text(version, Some(delta), None) {
489 Ok(edits) => Ok(edits),
490 Err(TreeSitterError::DeltaMismatch) => {
491 let full = state.editor().get_text();
493 self.process_text(version, Some(delta), Some(&full))
494 }
495 Err(e) => Err(e),
496 }
497 } else {
498 let full = state.editor().get_text();
500 self.process_text(version, None, Some(&full))
501 }
502 }
503}
504
505impl TreeSitterProcessor {
506 pub fn sync_to(
516 &mut self,
517 version: u64,
518 delta: Option<&TextDelta>,
519 full_text: Option<&str>,
520 ) -> Result<TreeSitterUpdateMode, TreeSitterError> {
521 if self.last_synced_version == Some(version) {
522 self.last_update_mode = TreeSitterUpdateMode::Skipped;
523 return Ok(TreeSitterUpdateMode::Skipped);
524 }
525
526 let update_mode = if self.tree.is_none() {
527 let Some(text) = full_text else {
528 return Err(TreeSitterError::DeltaMismatch);
529 };
530 self.sync_from_text_full(text);
531 self.tree = self.parse();
532 self.needs_parse = false;
533 TreeSitterUpdateMode::Initial
534 } else if let Some(delta) = delta {
535 match self.apply_text_delta_incremental(delta) {
536 Ok(()) => {
537 self.needs_parse = true;
540 TreeSitterUpdateMode::Incremental
541 }
542 Err(_) => {
543 let Some(text) = full_text else {
544 return Err(TreeSitterError::DeltaMismatch);
545 };
546 self.sync_from_text_full(text);
547 self.tree = self.parser.parse(&self.text, None);
548 self.needs_parse = false;
549 TreeSitterUpdateMode::FullReparse
550 }
551 }
552 } else {
553 let Some(text) = full_text else {
554 return Err(TreeSitterError::DeltaMismatch);
555 };
556 self.sync_from_text_full(text);
557 self.tree = self.parser.parse(&self.text, None);
558 self.needs_parse = false;
559 TreeSitterUpdateMode::FullReparse
560 };
561
562 self.last_synced_version = Some(version);
563 self.last_update_mode = update_mode;
564 Ok(update_mode)
565 }
566
567 pub fn compute_processing_edits(
577 &mut self,
578 char_range: Option<(usize, usize)>,
579 ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
580 let Some(version) = self.last_synced_version else {
581 return Ok(Vec::new());
582 };
583 if self.last_processed_version == Some(version) {
584 return Ok(Vec::new());
585 }
586
587 if self.needs_parse {
588 self.tree = self.parse();
589 self.needs_parse = false;
590 }
591
592 let Some(tree) = self.tree.as_ref() else {
593 self.last_processed_version = Some(version);
594 return Ok(Vec::new());
595 };
596
597 let byte_range = char_range.and_then(|(start_char, end_char)| {
598 let start = self.line_index.char_offset_to_byte_offset(start_char);
599 let end = self.line_index.char_offset_to_byte_offset(end_char);
600 if end > start {
601 Some((start, end))
602 } else {
603 None
604 }
605 });
606
607 let intervals = self.collect_highlight_intervals_in_byte_range(tree, byte_range);
608 let fold_regions = self.collect_fold_regions_in_byte_range(tree, byte_range);
609
610 let mut edits = vec![ProcessingEdit::ReplaceStyleLayer {
611 layer: self.config.style_layer,
612 intervals,
613 }];
614
615 if self.fold_query.is_some() {
616 edits.push(ProcessingEdit::ReplaceFoldingRegions {
617 regions: fold_regions,
618 preserve_collapsed: self.config.preserve_collapsed_folds,
619 });
620 }
621
622 self.last_processed_version = Some(version);
623 Ok(edits)
624 }
625
626 pub fn process_text(
639 &mut self,
640 version: u64,
641 delta: Option<&TextDelta>,
642 full_text: Option<&str>,
643 ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
644 let _ = self.sync_to(version, delta, full_text)?;
646 self.compute_processing_edits(None)
647 }
648}