Skip to main content

editor_core_treesitter/
processor.rs

1use editor_core::delta::TextDelta;
2use editor_core::processing::{DocumentProcessor, ProcessingEdit};
3use editor_core::{EditorStateManager, FoldRegion, Interval, LineIndex, StyleId, StyleLayerId};
4use std::collections::BTreeMap;
5use streaming_iterator::StreamingIterator;
6use tree_sitter::{InputEdit, Parser, Point, Query, QueryCursor, Tree};
7
/// Failure modes reported by [`TreeSitterProcessor`] and its configuration helpers.
///
/// The `String` payloads hold the underlying error rendered as text.
#[derive(Debug)]
pub enum TreeSitterError {
    /// Creating a WASM store or loading a WASM grammar failed.
    Wasm(String),
    /// I/O failed (reading WASM or query files).
    Io(String),
    /// The parser rejected the language (or its WASM store).
    Language(String),
    /// A Tree-sitter query could not be compiled.
    Query(String),
    /// The processor's internal text could not be reconciled with the supplied delta, or a
    /// full resync was required but no full text was provided.
    DeltaMismatch,
}

impl std::fmt::Display for TreeSitterError {
    /// Render a short, human-readable message prefixed with `tree-sitter`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // The only variant without a payload: a fixed message.
            Self::DeltaMismatch => f.write_fmt(format_args!("tree-sitter delta mismatch")),
            Self::Query(detail) => f.write_fmt(format_args!("tree-sitter query error: {detail}")),
            Self::Language(detail) => {
                f.write_fmt(format_args!("tree-sitter language error: {detail}"))
            }
            Self::Io(detail) => f.write_fmt(format_args!("tree-sitter io error: {detail}")),
            Self::Wasm(detail) => f.write_fmt(format_args!("tree-sitter wasm error: {detail}")),
        }
    }
}

// No `source()` override: payloads are already flattened to strings.
impl std::error::Error for TreeSitterError {}
36
/// Describes what the processor did to bring its parse tree up to date on the most recent
/// `process()` / `sync_to()` call.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TreeSitterUpdateMode {
    /// The processor had no tree yet and parsed the full text for the first time.
    Initial,
    /// `TextDelta` edits were applied to the existing tree; the re-parse is incremental.
    Incremental,
    /// The processor re-synchronized from the full text and parsed from scratch.
    FullReparse,
    /// Nothing was done because this editor version had already been handled.
    Skipped,
}
49
50/// Language source for a Tree-sitter processor.
51#[derive(Debug, Clone)]
52pub enum TreeSitterLanguage {
53    /// A native (in-process) Tree-sitter grammar.
54    Native(tree_sitter::Language),
55    /// A WASM Tree-sitter grammar module to be loaded at runtime.
56    Wasm {
57        /// Stable Tree-sitter language id (e.g. `"rust"`).
58        language_id: String,
59        /// Raw WASM bytes (typically `language.wasm` on disk).
60        wasm_bytes: Vec<u8>,
61    },
62}
63
64impl TreeSitterLanguage {
65    /// Create a native language source.
66    pub fn native(language: tree_sitter::Language) -> Self {
67        Self::Native(language)
68    }
69
70    /// Create a WASM language source.
71    pub fn wasm(language_id: String, wasm_bytes: Vec<u8>) -> Self {
72        Self::Wasm {
73            language_id,
74            wasm_bytes,
75        }
76    }
77}
78
79/// Configuration for [`TreeSitterProcessor`].
80#[derive(Debug, Clone)]
81pub struct TreeSitterProcessorConfig {
82    /// Tree-sitter language.
83    pub language: TreeSitterLanguage,
84    /// Syntax highlighting query (`.scm`).
85    pub highlights_query: String,
86    /// Optional folding query (`.scm`). Each capture becomes a fold candidate.
87    pub folds_query: Option<String>,
88    /// Mapping from capture name (e.g. `"comment"`) to an `editor-core` `StyleId`.
89    pub capture_styles: BTreeMap<String, StyleId>,
90    /// Target style layer id to replace.
91    pub style_layer: StyleLayerId,
92    /// Whether to preserve the collapsed state for existing fold regions on replacement.
93    pub preserve_collapsed_folds: bool,
94}
95
96impl TreeSitterProcessorConfig {
97    /// Create a config with a language + highlights query.
98    ///
99    /// By default:
100    /// - `style_layer` is [`StyleLayerId::TREE_SITTER`]
101    /// - `preserve_collapsed_folds` is `true`
102    pub fn new(language: TreeSitterLanguage, highlights_query: impl Into<String>) -> Self {
103        Self {
104            language,
105            highlights_query: highlights_query.into(),
106            folds_query: None,
107            capture_styles: BTreeMap::new(),
108            style_layer: StyleLayerId::TREE_SITTER,
109            preserve_collapsed_folds: true,
110        }
111    }
112
113    /// Set a folding query.
114    pub fn with_folds_query(mut self, folds_query: impl Into<String>) -> Self {
115        self.folds_query = Some(folds_query.into());
116        self
117    }
118
119    /// Add a set of capture name → style id mappings.
120    pub fn with_simple_capture_styles<const N: usize>(
121        mut self,
122        styles: [(&'static str, StyleId); N],
123    ) -> Self {
124        for (name, style_id) in styles {
125            self.capture_styles.insert(name.to_string(), style_id);
126        }
127        self
128    }
129
130    /// Control whether fold replacement preserves collapsed state.
131    pub fn set_preserve_collapsed_folds(&mut self, preserve: bool) {
132        self.preserve_collapsed_folds = preserve;
133    }
134
135    /// Compile `highlights_query` and return its capture names in the query's declaration order.
136    ///
137    /// Hosts can use this to pre-allocate a stable capture → `StyleId` mapping before running the
138    /// processor.
139    pub fn highlights_capture_names(&self) -> Result<Vec<String>, TreeSitterError> {
140        let query = match &self.language {
141            TreeSitterLanguage::Native(language) => Query::new(language, &self.highlights_query)
142                .map_err(|e| TreeSitterError::Query(e.to_string()))?,
143            TreeSitterLanguage::Wasm {
144                language_id,
145                wasm_bytes,
146            } => {
147                let engine = tree_sitter::wasmtime::Engine::default();
148                let mut store = tree_sitter::WasmStore::new(&engine)
149                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
150                let language = store
151                    .load_language(language_id, wasm_bytes)
152                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
153                Query::new(&language, &self.highlights_query)
154                    .map_err(|e| TreeSitterError::Query(e.to_string()))?
155            }
156        };
157
158        Ok(query
159            .capture_names()
160            .iter()
161            .map(|name| (*name).to_string())
162            .collect())
163    }
164}
165
166/// An incremental Tree-sitter based document processor.
167///
168/// This processor tracks a parse tree and updates it based on `TextDelta` edits when available.
169/// It then produces highlighting and folding edits in `editor-core`'s derived-state format.
170pub struct TreeSitterProcessor {
171    config: TreeSitterProcessorConfig,
172    parser: Parser,
173    highlight_query: Query,
174    highlight_capture_styles: Vec<Option<StyleId>>,
175    fold_query: Option<Query>,
176    tree: Option<Tree>,
177    needs_parse: bool,
178    text: String,
179    line_index: LineIndex,
180    last_synced_version: Option<u64>,
181    last_processed_version: Option<u64>,
182    last_update_mode: TreeSitterUpdateMode,
183    // Keep the Wasmtime engine alive for the lifetime of the parser's Wasm store.
184    #[allow(dead_code)]
185    wasm_engine: Option<tree_sitter::wasmtime::Engine>,
186}
187
188impl TreeSitterProcessor {
189    /// Create a new processor from the given config.
190    pub fn new(config: TreeSitterProcessorConfig) -> Result<Self, TreeSitterError> {
191        let mut parser = Parser::new();
192        let (language, wasm_engine) = match &config.language {
193            TreeSitterLanguage::Native(language) => {
194                parser
195                    .set_language(language)
196                    .map_err(|e| TreeSitterError::Language(e.to_string()))?;
197                (language.clone(), None)
198            }
199            TreeSitterLanguage::Wasm {
200                language_id,
201                wasm_bytes,
202            } => {
203                let engine = tree_sitter::wasmtime::Engine::default();
204
205                // Parser store (must use the same engine as the language).
206                let store = tree_sitter::WasmStore::new(&engine)
207                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
208                parser
209                    .set_wasm_store(store)
210                    .map_err(|e| TreeSitterError::Language(e.to_string()))?;
211
212                // Load the language (can use a separate store, but must share the same engine).
213                let mut store = tree_sitter::WasmStore::new(&engine)
214                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
215                let language = store
216                    .load_language(language_id, wasm_bytes)
217                    .map_err(|e| TreeSitterError::Wasm(e.to_string()))?;
218
219                parser
220                    .set_language(&language)
221                    .map_err(|e| TreeSitterError::Language(e.to_string()))?;
222
223                (language, Some(engine))
224            }
225        };
226
227        let highlight_query = Query::new(&language, &config.highlights_query)
228            .map_err(|e| TreeSitterError::Query(e.to_string()))?;
229        let highlight_capture_styles = highlight_query
230            .capture_names()
231            .iter()
232            .map(|name| config.capture_styles.get(*name).copied())
233            .collect::<Vec<_>>();
234
235        let fold_query = match config.folds_query.as_deref() {
236            Some(q) if !q.trim().is_empty() => {
237                Some(Query::new(&language, q).map_err(|e| TreeSitterError::Query(e.to_string()))?)
238            }
239            _ => None,
240        };
241
242        Ok(Self {
243            config,
244            parser,
245            highlight_query,
246            highlight_capture_styles,
247            fold_query,
248            tree: None,
249            needs_parse: false,
250            text: String::new(),
251            line_index: LineIndex::new(),
252            last_synced_version: None,
253            last_processed_version: None,
254            last_update_mode: TreeSitterUpdateMode::FullReparse,
255            wasm_engine,
256        })
257    }
258
259    /// Get the last update mode (useful for tests and instrumentation).
260    pub fn last_update_mode(&self) -> TreeSitterUpdateMode {
261        self.last_update_mode
262    }
263
264    /// Expand a `(start, end)` selection to the next enclosing syntax node.
265    ///
266    /// Returns `None` if the processor has no parsed tree yet (call `process()`/`sync_to()` first),
267    /// or if the selection is already at the root node.
268    ///
269    /// Notes:
270    /// - Offsets are Unicode scalar indices (Rust `char` offsets), matching editor-core APIs.
271    /// - The returned range is best-effort and is based on Tree-sitter node byte ranges mapped
272    ///   through the processor's internal `LineIndex`.
273    pub fn expand_selection_syntax(&self, start: usize, end: usize) -> Option<(usize, usize)> {
274        let tree = self.tree.as_ref()?;
275        let root = tree.root_node();
276
277        let (sel_start, sel_end) = if start <= end {
278            (start, end)
279        } else {
280            (end, start)
281        };
282        let start_byte = self.line_index.char_offset_to_byte_offset(sel_start);
283        let end_byte = self.line_index.char_offset_to_byte_offset(sel_end);
284
285        let mut node = root.descendant_for_byte_range(start_byte, end_byte)?;
286
287        loop {
288            let node_start = self
289                .line_index
290                .byte_offset_to_char_offset(node.start_byte());
291            let node_end = self.line_index.byte_offset_to_char_offset(node.end_byte());
292
293            // If the selection already matches this node exactly, expand to its parent (if any).
294            if node_start == sel_start && node_end == sel_end {
295                if let Some(parent) = node.parent() {
296                    node = parent;
297                    continue;
298                }
299                return None;
300            }
301
302            return Some((node_start, node_end));
303        }
304    }
305
306    fn sync_from_text_full(&mut self, text: &str) {
307        self.text.clear();
308        self.text.push_str(text);
309        self.line_index = LineIndex::from_text(&self.text);
310    }
311
312    fn point_for_char_offset(&self, char_offset: usize) -> Point {
313        let (row, col) = self.line_index.char_offset_to_line_byte_column(char_offset);
314        Point { row, column: col }
315    }
316
317    fn advance_point(mut point: Point, text: &str) -> Point {
318        let mut parts = text.split('\n');
319        let Some(first) = parts.next() else {
320            return point;
321        };
322
323        point.column = point.column.saturating_add(first.len());
324        for part in parts {
325            point.row = point.row.saturating_add(1);
326            point.column = part.len();
327        }
328
329        point
330    }
331
332    fn apply_text_delta_incremental(&mut self, delta: &TextDelta) -> Result<(), TreeSitterError> {
333        if self.line_index.char_count() != delta.before_char_count {
334            return Err(TreeSitterError::DeltaMismatch);
335        }
336        if self.tree.is_none() {
337            return Err(TreeSitterError::DeltaMismatch);
338        }
339
340        for edit in &delta.edits {
341            let start_char = edit.start;
342            let deleted_chars = edit.deleted_text.chars().count();
343
344            let start_byte = self.line_index.char_offset_to_byte_offset(start_char);
345            let old_end_byte = start_byte.saturating_add(edit.deleted_text.len());
346            let new_end_byte = start_byte.saturating_add(edit.inserted_text.len());
347
348            let Some(old_slice) = self.text.get(start_byte..old_end_byte) else {
349                return Err(TreeSitterError::DeltaMismatch);
350            };
351            if old_slice != edit.deleted_text {
352                return Err(TreeSitterError::DeltaMismatch);
353            }
354
355            let start_position = self.point_for_char_offset(start_char);
356            let old_end_position = Self::advance_point(start_position, &edit.deleted_text);
357            let new_end_position = Self::advance_point(start_position, &edit.inserted_text);
358
359            if let Some(tree) = self.tree.as_mut() {
360                tree.edit(&InputEdit {
361                    start_byte,
362                    old_end_byte,
363                    new_end_byte,
364                    start_position,
365                    old_end_position,
366                    new_end_position,
367                });
368            }
369
370            self.text
371                .replace_range(start_byte..old_end_byte, &edit.inserted_text);
372            self.line_index.delete(start_char, deleted_chars);
373            self.line_index.insert(start_char, &edit.inserted_text);
374        }
375
376        if self.line_index.char_count() != delta.after_char_count {
377            return Err(TreeSitterError::DeltaMismatch);
378        }
379
380        Ok(())
381    }
382
383    fn parse(&mut self) -> Option<Tree> {
384        self.parser.parse(&self.text, self.tree.as_ref())
385    }
386
387    fn collect_highlight_intervals_in_byte_range(
388        &self,
389        tree: &Tree,
390        byte_range: Option<(usize, usize)>,
391    ) -> Vec<Interval> {
392        let mut cursor = QueryCursor::new();
393        if let Some((start, end)) = byte_range {
394            let start = start.min(self.text.len());
395            let end = end.min(self.text.len());
396            if end > start {
397                cursor.set_byte_range(start..end);
398            }
399        }
400        let root = tree.root_node();
401        let mut out = Vec::<Interval>::new();
402
403        let mut matches = cursor.matches(&self.highlight_query, root, self.text.as_bytes());
404        while let Some(m) = matches.next() {
405            for capture in m.captures {
406                let idx = capture.index as usize;
407                let Some(style_id) = self.highlight_capture_styles.get(idx).and_then(|x| *x) else {
408                    continue;
409                };
410
411                let node = capture.node;
412                let start_byte = node.start_byte();
413                let end_byte = node.end_byte();
414                if end_byte <= start_byte {
415                    continue;
416                }
417
418                let start = self.line_index.byte_offset_to_char_offset(start_byte);
419                let end = self.line_index.byte_offset_to_char_offset(end_byte);
420                if end <= start {
421                    continue;
422                }
423
424                out.push(Interval::new(start, end, style_id));
425            }
426        }
427
428        out.sort_by_key(|i| (i.start, i.end, i.style_id));
429        out.dedup_by(|a, b| a.start == b.start && a.end == b.end && a.style_id == b.style_id);
430        out
431    }
432
433    fn collect_fold_regions_in_byte_range(
434        &self,
435        tree: &Tree,
436        byte_range: Option<(usize, usize)>,
437    ) -> Vec<FoldRegion> {
438        let Some(query) = self.fold_query.as_ref() else {
439            return Vec::new();
440        };
441
442        let mut cursor = QueryCursor::new();
443        if let Some((start, end)) = byte_range {
444            let start = start.min(self.text.len());
445            let end = end.min(self.text.len());
446            if end > start {
447                cursor.set_byte_range(start..end);
448            }
449        }
450        let root = tree.root_node();
451        let mut regions = Vec::<FoldRegion>::new();
452
453        let mut matches = cursor.matches(query, root, self.text.as_bytes());
454        while let Some(m) = matches.next() {
455            for capture in m.captures {
456                let node = capture.node;
457                let start_line = node.start_position().row;
458                let end_line = node.end_position().row;
459                if end_line > start_line {
460                    regions.push(FoldRegion::new(start_line, end_line));
461                }
462            }
463        }
464
465        regions.sort_by_key(|r| (r.start_line, r.end_line));
466        regions.dedup_by(|a, b| a.start_line == b.start_line && a.end_line == b.end_line);
467        regions
468    }
469}
470
471impl DocumentProcessor for TreeSitterProcessor {
472    type Error = TreeSitterError;
473
474    fn process(&mut self, state: &EditorStateManager) -> Result<Vec<ProcessingEdit>, Self::Error> {
475        let version = state.version();
476        if self.last_processed_version == Some(version) {
477            self.last_update_mode = TreeSitterUpdateMode::Skipped;
478            return Ok(Vec::new());
479        }
480
481        if self.tree.is_none() {
482            // Initial parse always needs a full sync from the editor.
483            let full = state.editor().get_text();
484            return self.process_text(version, None, Some(&full));
485        }
486
487        if let Some(delta) = state.last_text_delta() {
488            match self.process_text(version, Some(delta), None) {
489                Ok(edits) => Ok(edits),
490                Err(TreeSitterError::DeltaMismatch) => {
491                    // Fall back to a full resync from the current editor text.
492                    let full = state.editor().get_text();
493                    self.process_text(version, Some(delta), Some(&full))
494                }
495                Err(e) => Err(e),
496            }
497        } else {
498            // No structured delta available; re-sync from the full current text.
499            let full = state.editor().get_text();
500            self.process_text(version, None, Some(&full))
501        }
502    }
503}
504
505impl TreeSitterProcessor {
506    /// Synchronize the processor's internal text/tree to the given `version`.
507    ///
508    /// This updates the parse tree (incrementally when possible), but does **not** run any
509    /// Tree-sitter queries. Call [`Self::compute_processing_edits`] afterwards to produce
510    /// `editor-core` derived-state edits (highlighting + folding).
511    ///
512    /// Notes:
513    /// - If `full_text` is `None` and a full resync is required, this returns
514    ///   [`TreeSitterError::DeltaMismatch`].
515    pub fn sync_to(
516        &mut self,
517        version: u64,
518        delta: Option<&TextDelta>,
519        full_text: Option<&str>,
520    ) -> Result<TreeSitterUpdateMode, TreeSitterError> {
521        if self.last_synced_version == Some(version) {
522            self.last_update_mode = TreeSitterUpdateMode::Skipped;
523            return Ok(TreeSitterUpdateMode::Skipped);
524        }
525
526        let update_mode = if self.tree.is_none() {
527            let Some(text) = full_text else {
528                return Err(TreeSitterError::DeltaMismatch);
529            };
530            self.sync_from_text_full(text);
531            self.tree = self.parse();
532            self.needs_parse = false;
533            TreeSitterUpdateMode::Initial
534        } else if let Some(delta) = delta {
535            match self.apply_text_delta_incremental(delta) {
536                Ok(()) => {
537                    // Defer parsing until we actually need to run queries. This allows callers
538                    // to coalesce bursts of edits (debounce) without re-parsing on every keystroke.
539                    self.needs_parse = true;
540                    TreeSitterUpdateMode::Incremental
541                }
542                Err(_) => {
543                    let Some(text) = full_text else {
544                        return Err(TreeSitterError::DeltaMismatch);
545                    };
546                    self.sync_from_text_full(text);
547                    self.tree = self.parser.parse(&self.text, None);
548                    self.needs_parse = false;
549                    TreeSitterUpdateMode::FullReparse
550                }
551            }
552        } else {
553            let Some(text) = full_text else {
554                return Err(TreeSitterError::DeltaMismatch);
555            };
556            self.sync_from_text_full(text);
557            self.tree = self.parser.parse(&self.text, None);
558            self.needs_parse = false;
559            TreeSitterUpdateMode::FullReparse
560        };
561
562        self.last_synced_version = Some(version);
563        self.last_update_mode = update_mode;
564        Ok(update_mode)
565    }
566
567    /// Compute highlighting/folding edits from the current synchronized parse tree.
568    ///
569    /// - `char_range` can be used to limit Tree-sitter query execution to a subset of the
570    ///   document (useful as a performance degradation mode for huge files).
571    ///
572    /// Notes:
573    /// - When `char_range` is specified, the returned style intervals and fold regions will be
574    ///   **partial** (only within the range). Consumers that replace whole layers/regions should
575    ///   treat this as a "best effort visible range" optimization.
576    pub fn compute_processing_edits(
577        &mut self,
578        char_range: Option<(usize, usize)>,
579    ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
580        let Some(version) = self.last_synced_version else {
581            return Ok(Vec::new());
582        };
583        if self.last_processed_version == Some(version) {
584            return Ok(Vec::new());
585        }
586
587        if self.needs_parse {
588            self.tree = self.parse();
589            self.needs_parse = false;
590        }
591
592        let Some(tree) = self.tree.as_ref() else {
593            self.last_processed_version = Some(version);
594            return Ok(Vec::new());
595        };
596
597        let byte_range = char_range.and_then(|(start_char, end_char)| {
598            let start = self.line_index.char_offset_to_byte_offset(start_char);
599            let end = self.line_index.char_offset_to_byte_offset(end_char);
600            if end > start {
601                Some((start, end))
602            } else {
603                None
604            }
605        });
606
607        let intervals = self.collect_highlight_intervals_in_byte_range(tree, byte_range);
608        let fold_regions = self.collect_fold_regions_in_byte_range(tree, byte_range);
609
610        let mut edits = vec![ProcessingEdit::ReplaceStyleLayer {
611            layer: self.config.style_layer,
612            intervals,
613        }];
614
615        if self.fold_query.is_some() {
616            edits.push(ProcessingEdit::ReplaceFoldingRegions {
617                regions: fold_regions,
618                preserve_collapsed: self.config.preserve_collapsed_folds,
619            });
620        }
621
622        self.last_processed_version = Some(version);
623        Ok(edits)
624    }
625
626    /// Process a document snapshot represented as:
627    /// - a monotonically increasing `version`,
628    /// - an optional `TextDelta` describing the change from the previous version,
629    /// - an optional `full_text` for (re-)synchronization.
630    ///
631    /// This method is useful for running Tree-sitter processing on a background thread where
632    /// a full `EditorStateManager` is not available. Callers can pass `full_text` only when
633    /// performing an initial parse or when a delta mismatch requires a full re-sync.
634    ///
635    /// Notes:
636    /// - If `full_text` is `None` and a full sync is required, this returns
637    ///   [`TreeSitterError::DeltaMismatch`].
638    pub fn process_text(
639        &mut self,
640        version: u64,
641        delta: Option<&TextDelta>,
642        full_text: Option<&str>,
643    ) -> Result<Vec<ProcessingEdit>, TreeSitterError> {
644        // Keep backwards-compatible semantics: `process_text` syncs + queries.
645        let _ = self.sync_to(version, delta, full_text)?;
646        self.compute_processing_edits(None)
647    }
648}