Skip to main content

editor_core_treesitter/
processor.rs

1use editor_core::delta::TextDelta;
2use editor_core::intervals::{FoldRegion, Interval, StyleId, StyleLayerId};
3use editor_core::processing::{DocumentProcessor, ProcessingEdit};
4use editor_core::{EditorStateManager, LineIndex};
5use std::collections::BTreeMap;
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{InputEdit, Parser, Point, Query, QueryCursor, Tree};
8
/// Errors produced by [`TreeSitterProcessor`].
///
/// `Clone`/`PartialEq`/`Eq` are derived so callers (and tests) can compare and store errors.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TreeSitterError {
    /// Setting the Tree-sitter language on the parser failed; carries the parser's message.
    Language(String),
    /// Compiling a Tree-sitter query failed; carries the query compiler's message.
    Query(String),
    /// Internal text synchronization failed (the delta did not match the expected text).
    DeltaMismatch,
}

impl std::fmt::Display for TreeSitterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Language(msg) => write!(f, "tree-sitter language error: {msg}"),
            Self::Query(msg) => write!(f, "tree-sitter query error: {msg}"),
            Self::DeltaMismatch => write!(f, "tree-sitter delta mismatch"),
        }
    }
}

impl std::error::Error for TreeSitterError {}
31
/// Records which strategy the most recent `process()` call used to refresh the parse tree.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TreeSitterUpdateMode {
    /// The very first parse performed by this processor instance.
    Initial,
    /// The tree was edited with the state's `TextDelta` and re-parsed incrementally.
    Incremental,
    /// The full text was re-read from the editor and parsed from scratch.
    FullReparse,
    /// Nothing was done because this editor version had already been processed.
    Skipped,
}
44
45/// Configuration for [`TreeSitterProcessor`].
46#[derive(Debug, Clone)]
47pub struct TreeSitterProcessorConfig {
48    /// Tree-sitter language.
49    pub language: tree_sitter::Language,
50    /// Syntax highlighting query (`.scm`).
51    pub highlights_query: String,
52    /// Optional folding query (`.scm`). Each capture becomes a fold candidate.
53    pub folds_query: Option<String>,
54    /// Mapping from capture name (e.g. `"comment"`) to an `editor-core` `StyleId`.
55    pub capture_styles: BTreeMap<String, StyleId>,
56    /// Target style layer id to replace.
57    pub style_layer: StyleLayerId,
58    /// Whether to preserve the collapsed state for existing fold regions on replacement.
59    pub preserve_collapsed_folds: bool,
60}
61
62impl TreeSitterProcessorConfig {
63    /// Create a config with a language + highlights query.
64    ///
65    /// By default:
66    /// - `style_layer` is [`StyleLayerId::TREE_SITTER`]
67    /// - `preserve_collapsed_folds` is `true`
68    pub fn new(language: tree_sitter::Language, highlights_query: impl Into<String>) -> Self {
69        Self {
70            language,
71            highlights_query: highlights_query.into(),
72            folds_query: None,
73            capture_styles: BTreeMap::new(),
74            style_layer: StyleLayerId::TREE_SITTER,
75            preserve_collapsed_folds: true,
76        }
77    }
78
79    /// Set a folding query.
80    pub fn with_folds_query(mut self, folds_query: impl Into<String>) -> Self {
81        self.folds_query = Some(folds_query.into());
82        self
83    }
84
85    /// A small fold query that works well for Rust-like curly-brace languages.
86    pub fn with_default_rust_folds(self) -> Self {
87        self.with_folds_query(
88            r#"
89            (function_item) @fold
90            (impl_item) @fold
91            (struct_item) @fold
92            (enum_item) @fold
93            (mod_item) @fold
94            (block) @fold
95            "#,
96        )
97    }
98
99    /// Add a set of capture name → style id mappings.
100    pub fn with_simple_capture_styles<const N: usize>(
101        mut self,
102        styles: [(&'static str, StyleId); N],
103    ) -> Self {
104        for (name, style_id) in styles {
105            self.capture_styles.insert(name.to_string(), style_id);
106        }
107        self
108    }
109
110    /// Control whether fold replacement preserves collapsed state.
111    pub fn set_preserve_collapsed_folds(&mut self, preserve: bool) {
112        self.preserve_collapsed_folds = preserve;
113    }
114}
115
116/// An incremental Tree-sitter based document processor.
117///
118/// This processor tracks a parse tree and updates it based on `TextDelta` edits when available.
119/// It then produces highlighting and folding edits in `editor-core`'s derived-state format.
120pub struct TreeSitterProcessor {
121    config: TreeSitterProcessorConfig,
122    parser: Parser,
123    highlight_query: Query,
124    highlight_capture_styles: Vec<Option<StyleId>>,
125    fold_query: Option<Query>,
126    tree: Option<Tree>,
127    text: String,
128    line_index: LineIndex,
129    last_processed_version: Option<u64>,
130    last_update_mode: TreeSitterUpdateMode,
131}
132
133impl TreeSitterProcessor {
134    /// Create a new processor from the given config.
135    pub fn new(config: TreeSitterProcessorConfig) -> Result<Self, TreeSitterError> {
136        let mut parser = Parser::new();
137        parser
138            .set_language(&config.language)
139            .map_err(|e| TreeSitterError::Language(e.to_string()))?;
140
141        let highlight_query = Query::new(&config.language, &config.highlights_query)
142            .map_err(|e| TreeSitterError::Query(e.to_string()))?;
143        let highlight_capture_styles = highlight_query
144            .capture_names()
145            .iter()
146            .map(|name| config.capture_styles.get(*name).copied())
147            .collect::<Vec<_>>();
148
149        let fold_query = match config.folds_query.as_deref() {
150            Some(q) if !q.trim().is_empty() => Some(
151                Query::new(&config.language, q)
152                    .map_err(|e| TreeSitterError::Query(e.to_string()))?,
153            ),
154            _ => None,
155        };
156
157        Ok(Self {
158            config,
159            parser,
160            highlight_query,
161            highlight_capture_styles,
162            fold_query,
163            tree: None,
164            text: String::new(),
165            line_index: LineIndex::new(),
166            last_processed_version: None,
167            last_update_mode: TreeSitterUpdateMode::FullReparse,
168        })
169    }
170
171    /// Get the last update mode (useful for tests and instrumentation).
172    pub fn last_update_mode(&self) -> TreeSitterUpdateMode {
173        self.last_update_mode
174    }
175
176    fn sync_from_state_full(&mut self, state: &EditorStateManager) {
177        self.text = state.editor().get_text();
178        self.line_index = LineIndex::from_text(&self.text);
179    }
180
181    fn point_for_char_offset(&self, char_offset: usize) -> Point {
182        let (row, col) = self.line_index.char_offset_to_line_byte_column(char_offset);
183        Point { row, column: col }
184    }
185
186    fn advance_point(mut point: Point, text: &str) -> Point {
187        let mut parts = text.split('\n');
188        let Some(first) = parts.next() else {
189            return point;
190        };
191
192        point.column = point.column.saturating_add(first.len());
193        for part in parts {
194            point.row = point.row.saturating_add(1);
195            point.column = part.len();
196        }
197
198        point
199    }
200
201    fn apply_text_delta_incremental(&mut self, delta: &TextDelta) -> Result<(), TreeSitterError> {
202        if self.line_index.char_count() != delta.before_char_count {
203            return Err(TreeSitterError::DeltaMismatch);
204        }
205        if self.tree.is_none() {
206            return Err(TreeSitterError::DeltaMismatch);
207        }
208
209        for edit in &delta.edits {
210            let start_char = edit.start;
211            let deleted_chars = edit.deleted_text.chars().count();
212
213            let start_byte = self.line_index.char_offset_to_byte_offset(start_char);
214            let old_end_byte = start_byte.saturating_add(edit.deleted_text.len());
215            let new_end_byte = start_byte.saturating_add(edit.inserted_text.len());
216
217            let Some(old_slice) = self.text.get(start_byte..old_end_byte) else {
218                return Err(TreeSitterError::DeltaMismatch);
219            };
220            if old_slice != edit.deleted_text {
221                return Err(TreeSitterError::DeltaMismatch);
222            }
223
224            let start_position = self.point_for_char_offset(start_char);
225            let old_end_position = Self::advance_point(start_position, &edit.deleted_text);
226            let new_end_position = Self::advance_point(start_position, &edit.inserted_text);
227
228            if let Some(tree) = self.tree.as_mut() {
229                tree.edit(&InputEdit {
230                    start_byte,
231                    old_end_byte,
232                    new_end_byte,
233                    start_position,
234                    old_end_position,
235                    new_end_position,
236                });
237            }
238
239            self.text
240                .replace_range(start_byte..old_end_byte, &edit.inserted_text);
241            self.line_index.delete(start_char, deleted_chars);
242            self.line_index.insert(start_char, &edit.inserted_text);
243        }
244
245        if self.line_index.char_count() != delta.after_char_count {
246            return Err(TreeSitterError::DeltaMismatch);
247        }
248
249        Ok(())
250    }
251
252    fn parse(&mut self) -> Option<Tree> {
253        self.parser.parse(&self.text, self.tree.as_ref())
254    }
255
256    fn collect_highlight_intervals(&self, tree: &Tree) -> Vec<Interval> {
257        let mut cursor = QueryCursor::new();
258        let root = tree.root_node();
259        let mut out = Vec::<Interval>::new();
260
261        let mut matches = cursor.matches(&self.highlight_query, root, self.text.as_bytes());
262        while let Some(m) = matches.next() {
263            for capture in m.captures {
264                let idx = capture.index as usize;
265                let Some(style_id) = self.highlight_capture_styles.get(idx).and_then(|x| *x) else {
266                    continue;
267                };
268
269                let node = capture.node;
270                let start_byte = node.start_byte();
271                let end_byte = node.end_byte();
272                if end_byte <= start_byte {
273                    continue;
274                }
275
276                let start = self.line_index.byte_offset_to_char_offset(start_byte);
277                let end = self.line_index.byte_offset_to_char_offset(end_byte);
278                if end <= start {
279                    continue;
280                }
281
282                out.push(Interval::new(start, end, style_id));
283            }
284        }
285
286        out.sort_by_key(|i| (i.start, i.end, i.style_id));
287        out.dedup_by(|a, b| a.start == b.start && a.end == b.end && a.style_id == b.style_id);
288        out
289    }
290
291    fn collect_fold_regions(&self, tree: &Tree) -> Vec<FoldRegion> {
292        let Some(query) = self.fold_query.as_ref() else {
293            return Vec::new();
294        };
295
296        let mut cursor = QueryCursor::new();
297        let root = tree.root_node();
298        let mut regions = Vec::<FoldRegion>::new();
299
300        let mut matches = cursor.matches(query, root, self.text.as_bytes());
301        while let Some(m) = matches.next() {
302            for capture in m.captures {
303                let node = capture.node;
304                let start_line = node.start_position().row;
305                let end_line = node.end_position().row;
306                if end_line > start_line {
307                    regions.push(FoldRegion::new(start_line, end_line));
308                }
309            }
310        }
311
312        regions.sort_by_key(|r| (r.start_line, r.end_line));
313        regions.dedup_by(|a, b| a.start_line == b.start_line && a.end_line == b.end_line);
314        regions
315    }
316}
317
318impl DocumentProcessor for TreeSitterProcessor {
319    type Error = TreeSitterError;
320
321    fn process(&mut self, state: &EditorStateManager) -> Result<Vec<ProcessingEdit>, Self::Error> {
322        let version = state.version();
323        if self.last_processed_version == Some(version) {
324            self.last_update_mode = TreeSitterUpdateMode::Skipped;
325            return Ok(Vec::new());
326        }
327
328        let update_mode = if self.tree.is_none() {
329            self.sync_from_state_full(state);
330            self.tree = self.parse();
331            TreeSitterUpdateMode::Initial
332        } else if let Some(delta) = state.last_text_delta() {
333            match self.apply_text_delta_incremental(delta) {
334                Ok(()) => {
335                    self.tree = self.parse();
336                    TreeSitterUpdateMode::Incremental
337                }
338                Err(_) => {
339                    self.sync_from_state_full(state);
340                    self.tree = self.parser.parse(&self.text, None);
341                    TreeSitterUpdateMode::FullReparse
342                }
343            }
344        } else {
345            self.sync_from_state_full(state);
346            self.tree = self.parser.parse(&self.text, None);
347            TreeSitterUpdateMode::FullReparse
348        };
349
350        let Some(tree) = self.tree.as_ref() else {
351            self.last_processed_version = Some(version);
352            self.last_update_mode = update_mode;
353            return Ok(Vec::new());
354        };
355
356        let intervals = self.collect_highlight_intervals(tree);
357        let fold_regions = self.collect_fold_regions(tree);
358
359        let mut edits = vec![ProcessingEdit::ReplaceStyleLayer {
360            layer: self.config.style_layer,
361            intervals,
362        }];
363
364        if self.fold_query.is_some() {
365            edits.push(ProcessingEdit::ReplaceFoldingRegions {
366                regions: fold_regions,
367                preserve_collapsed: self.config.preserve_collapsed_folds,
368            });
369        }
370
371        self.last_processed_version = Some(version);
372        self.last_update_mode = update_mode;
373        Ok(edits)
374    }
375}