1use editor_core::delta::TextDelta;
2use editor_core::intervals::{FoldRegion, Interval, StyleId, StyleLayerId};
3use editor_core::processing::{DocumentProcessor, ProcessingEdit};
4use editor_core::{EditorStateManager, LineIndex};
5use std::collections::BTreeMap;
6use streaming_iterator::StreamingIterator;
7use tree_sitter::{InputEdit, Parser, Point, Query, QueryCursor, Tree};
8
/// Errors produced while setting up or running the tree-sitter processor.
#[derive(Debug)]
pub enum TreeSitterError {
    /// The parser rejected the configured language; carries the underlying message.
    Language(String),
    /// A highlights or folds query failed to compile; carries the underlying message.
    Query(String),
    /// A text delta could not be applied to the processor's copy of the document.
    DeltaMismatch,
}

impl std::fmt::Display for TreeSitterError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // No payload to interpolate, so the literal is written directly.
            Self::DeltaMismatch => f.write_str("tree-sitter delta mismatch"),
            Self::Language(detail) => write!(f, "tree-sitter language error: {detail}"),
            Self::Query(detail) => write!(f, "tree-sitter query error: {detail}"),
        }
    }
}

// The default `Error` methods are sufficient: no variant wraps a source error.
impl std::error::Error for TreeSitterError {}
31
/// How the most recent `process` call brought the syntax tree up to date.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TreeSitterUpdateMode {
    /// No tree existed yet, so the document was parsed from scratch.
    Initial,
    /// The last text delta was applied to the existing tree, which was then re-parsed.
    Incremental,
    /// The existing tree could not be reused and the whole document was re-parsed.
    FullReparse,
    /// The state version had already been processed; nothing was done.
    Skipped,
}
44
45#[derive(Debug, Clone)]
47pub struct TreeSitterProcessorConfig {
48 pub language: tree_sitter::Language,
50 pub highlights_query: String,
52 pub folds_query: Option<String>,
54 pub capture_styles: BTreeMap<String, StyleId>,
56 pub style_layer: StyleLayerId,
58 pub preserve_collapsed_folds: bool,
60}
61
62impl TreeSitterProcessorConfig {
63 pub fn new(language: tree_sitter::Language, highlights_query: impl Into<String>) -> Self {
69 Self {
70 language,
71 highlights_query: highlights_query.into(),
72 folds_query: None,
73 capture_styles: BTreeMap::new(),
74 style_layer: StyleLayerId::TREE_SITTER,
75 preserve_collapsed_folds: true,
76 }
77 }
78
79 pub fn with_folds_query(mut self, folds_query: impl Into<String>) -> Self {
81 self.folds_query = Some(folds_query.into());
82 self
83 }
84
85 pub fn with_default_rust_folds(self) -> Self {
87 self.with_folds_query(
88 r#"
89 (function_item) @fold
90 (impl_item) @fold
91 (struct_item) @fold
92 (enum_item) @fold
93 (mod_item) @fold
94 (block) @fold
95 "#,
96 )
97 }
98
99 pub fn with_simple_capture_styles<const N: usize>(
101 mut self,
102 styles: [(&'static str, StyleId); N],
103 ) -> Self {
104 for (name, style_id) in styles {
105 self.capture_styles.insert(name.to_string(), style_id);
106 }
107 self
108 }
109
110 pub fn set_preserve_collapsed_folds(&mut self, preserve: bool) {
112 self.preserve_collapsed_folds = preserve;
113 }
114}
115
/// Document processor that keeps a tree-sitter syntax tree for the editor
/// text and derives highlight intervals and fold regions from it.
///
/// The processor holds its own copy of the document text and a matching
/// line index so that the last text delta can be applied to the existing tree.
pub struct TreeSitterProcessor {
    /// Configuration the processor was built from.
    config: TreeSitterProcessorConfig,
    /// Parser with the configured language already set.
    parser: Parser,
    /// Compiled highlights query.
    highlight_query: Query,
    /// Style for each capture index of `highlight_query`; `None` where no
    /// style was configured for that capture name.
    highlight_capture_styles: Vec<Option<StyleId>>,
    /// Compiled folds query, or `None` when none was configured or it was blank.
    fold_query: Option<Query>,
    /// Most recent syntax tree; `None` until a parse has succeeded.
    tree: Option<Tree>,
    /// The processor's copy of the document text that `tree` was parsed from.
    text: String,
    /// Index over `text` used for char offset, byte offset and point conversions.
    line_index: LineIndex,
    /// State version handled by the last non-skipped `process` call.
    last_processed_version: Option<u64>,
    /// How the last `process` call updated the tree.
    last_update_mode: TreeSitterUpdateMode,
}
132
133impl TreeSitterProcessor {
134 pub fn new(config: TreeSitterProcessorConfig) -> Result<Self, TreeSitterError> {
136 let mut parser = Parser::new();
137 parser
138 .set_language(&config.language)
139 .map_err(|e| TreeSitterError::Language(e.to_string()))?;
140
141 let highlight_query = Query::new(&config.language, &config.highlights_query)
142 .map_err(|e| TreeSitterError::Query(e.to_string()))?;
143 let highlight_capture_styles = highlight_query
144 .capture_names()
145 .iter()
146 .map(|name| config.capture_styles.get(*name).copied())
147 .collect::<Vec<_>>();
148
149 let fold_query = match config.folds_query.as_deref() {
150 Some(q) if !q.trim().is_empty() => Some(
151 Query::new(&config.language, q)
152 .map_err(|e| TreeSitterError::Query(e.to_string()))?,
153 ),
154 _ => None,
155 };
156
157 Ok(Self {
158 config,
159 parser,
160 highlight_query,
161 highlight_capture_styles,
162 fold_query,
163 tree: None,
164 text: String::new(),
165 line_index: LineIndex::new(),
166 last_processed_version: None,
167 last_update_mode: TreeSitterUpdateMode::FullReparse,
168 })
169 }
170
    /// Returns how the most recent `process` call updated the syntax tree.
    ///
    /// Before the first `process` call this is the value set by `new`,
    /// [`TreeSitterUpdateMode::FullReparse`].
    pub fn last_update_mode(&self) -> TreeSitterUpdateMode {
        self.last_update_mode
    }
175
176 fn sync_from_state_full(&mut self, state: &EditorStateManager) {
177 self.text = state.editor().get_text();
178 self.line_index = LineIndex::from_text(&self.text);
179 }
180
181 fn point_for_char_offset(&self, char_offset: usize) -> Point {
182 let (row, col) = self.line_index.char_offset_to_line_byte_column(char_offset);
183 Point { row, column: col }
184 }
185
186 fn advance_point(mut point: Point, text: &str) -> Point {
187 let mut parts = text.split('\n');
188 let Some(first) = parts.next() else {
189 return point;
190 };
191
192 point.column = point.column.saturating_add(first.len());
193 for part in parts {
194 point.row = point.row.saturating_add(1);
195 point.column = part.len();
196 }
197
198 point
199 }
200
201 fn apply_text_delta_incremental(&mut self, delta: &TextDelta) -> Result<(), TreeSitterError> {
202 if self.line_index.char_count() != delta.before_char_count {
203 return Err(TreeSitterError::DeltaMismatch);
204 }
205 if self.tree.is_none() {
206 return Err(TreeSitterError::DeltaMismatch);
207 }
208
209 for edit in &delta.edits {
210 let start_char = edit.start;
211 let deleted_chars = edit.deleted_text.chars().count();
212
213 let start_byte = self.line_index.char_offset_to_byte_offset(start_char);
214 let old_end_byte = start_byte.saturating_add(edit.deleted_text.len());
215 let new_end_byte = start_byte.saturating_add(edit.inserted_text.len());
216
217 let Some(old_slice) = self.text.get(start_byte..old_end_byte) else {
218 return Err(TreeSitterError::DeltaMismatch);
219 };
220 if old_slice != edit.deleted_text {
221 return Err(TreeSitterError::DeltaMismatch);
222 }
223
224 let start_position = self.point_for_char_offset(start_char);
225 let old_end_position = Self::advance_point(start_position, &edit.deleted_text);
226 let new_end_position = Self::advance_point(start_position, &edit.inserted_text);
227
228 if let Some(tree) = self.tree.as_mut() {
229 tree.edit(&InputEdit {
230 start_byte,
231 old_end_byte,
232 new_end_byte,
233 start_position,
234 old_end_position,
235 new_end_position,
236 });
237 }
238
239 self.text
240 .replace_range(start_byte..old_end_byte, &edit.inserted_text);
241 self.line_index.delete(start_char, deleted_chars);
242 self.line_index.insert(start_char, &edit.inserted_text);
243 }
244
245 if self.line_index.char_count() != delta.after_char_count {
246 return Err(TreeSitterError::DeltaMismatch);
247 }
248
249 Ok(())
250 }
251
252 fn parse(&mut self) -> Option<Tree> {
253 self.parser.parse(&self.text, self.tree.as_ref())
254 }
255
256 fn collect_highlight_intervals(&self, tree: &Tree) -> Vec<Interval> {
257 let mut cursor = QueryCursor::new();
258 let root = tree.root_node();
259 let mut out = Vec::<Interval>::new();
260
261 let mut matches = cursor.matches(&self.highlight_query, root, self.text.as_bytes());
262 while let Some(m) = matches.next() {
263 for capture in m.captures {
264 let idx = capture.index as usize;
265 let Some(style_id) = self.highlight_capture_styles.get(idx).and_then(|x| *x) else {
266 continue;
267 };
268
269 let node = capture.node;
270 let start_byte = node.start_byte();
271 let end_byte = node.end_byte();
272 if end_byte <= start_byte {
273 continue;
274 }
275
276 let start = self.line_index.byte_offset_to_char_offset(start_byte);
277 let end = self.line_index.byte_offset_to_char_offset(end_byte);
278 if end <= start {
279 continue;
280 }
281
282 out.push(Interval::new(start, end, style_id));
283 }
284 }
285
286 out.sort_by_key(|i| (i.start, i.end, i.style_id));
287 out.dedup_by(|a, b| a.start == b.start && a.end == b.end && a.style_id == b.style_id);
288 out
289 }
290
291 fn collect_fold_regions(&self, tree: &Tree) -> Vec<FoldRegion> {
292 let Some(query) = self.fold_query.as_ref() else {
293 return Vec::new();
294 };
295
296 let mut cursor = QueryCursor::new();
297 let root = tree.root_node();
298 let mut regions = Vec::<FoldRegion>::new();
299
300 let mut matches = cursor.matches(query, root, self.text.as_bytes());
301 while let Some(m) = matches.next() {
302 for capture in m.captures {
303 let node = capture.node;
304 let start_line = node.start_position().row;
305 let end_line = node.end_position().row;
306 if end_line > start_line {
307 regions.push(FoldRegion::new(start_line, end_line));
308 }
309 }
310 }
311
312 regions.sort_by_key(|r| (r.start_line, r.end_line));
313 regions.dedup_by(|a, b| a.start_line == b.start_line && a.end_line == b.end_line);
314 regions
315 }
316}
317
318impl DocumentProcessor for TreeSitterProcessor {
319 type Error = TreeSitterError;
320
321 fn process(&mut self, state: &EditorStateManager) -> Result<Vec<ProcessingEdit>, Self::Error> {
322 let version = state.version();
323 if self.last_processed_version == Some(version) {
324 self.last_update_mode = TreeSitterUpdateMode::Skipped;
325 return Ok(Vec::new());
326 }
327
328 let update_mode = if self.tree.is_none() {
329 self.sync_from_state_full(state);
330 self.tree = self.parse();
331 TreeSitterUpdateMode::Initial
332 } else if let Some(delta) = state.last_text_delta() {
333 match self.apply_text_delta_incremental(delta) {
334 Ok(()) => {
335 self.tree = self.parse();
336 TreeSitterUpdateMode::Incremental
337 }
338 Err(_) => {
339 self.sync_from_state_full(state);
340 self.tree = self.parser.parse(&self.text, None);
341 TreeSitterUpdateMode::FullReparse
342 }
343 }
344 } else {
345 self.sync_from_state_full(state);
346 self.tree = self.parser.parse(&self.text, None);
347 TreeSitterUpdateMode::FullReparse
348 };
349
350 let Some(tree) = self.tree.as_ref() else {
351 self.last_processed_version = Some(version);
352 self.last_update_mode = update_mode;
353 return Ok(Vec::new());
354 };
355
356 let intervals = self.collect_highlight_intervals(tree);
357 let fold_regions = self.collect_fold_regions(tree);
358
359 let mut edits = vec![ProcessingEdit::ReplaceStyleLayer {
360 layer: self.config.style_layer,
361 intervals,
362 }];
363
364 if self.fold_query.is_some() {
365 edits.push(ProcessingEdit::ReplaceFoldingRegions {
366 regions: fold_regions,
367 preserve_collapsed: self.config.preserve_collapsed_folds,
368 });
369 }
370
371 self.last_processed_version = Some(version);
372 self.last_update_mode = update_mode;
373 Ok(edits)
374 }
375}