libdiffsitter/input_processing.rs
1//! Utilities for processing the ASTs provided by `tree_sitter`
2//!
3//! These methods handle preprocessing the input data so it can be fed into the diff engines to
4//! compute diff data.
5
6use logging_timer::time;
7use serde::{Deserialize, Serialize};
8use std::borrow::Cow;
9use std::collections::HashSet;
10use std::hash::{Hash, Hasher};
11use std::ops::{Deref, DerefMut};
12use std::{cell::RefCell, ops::Index, path::PathBuf};
13use tree_sitter::Node as TSNode;
14use tree_sitter::Point;
15use tree_sitter::Tree as TSTree;
16use unicode_segmentation as us;
17
18#[cfg(test)]
19use mockall::{automock, predicate::str};
20
/// A wrapper trait that exists so we can mock TS nodes.
///
/// A real `tree_sitter::Node` cannot be constructed without parsing an actual document, so unit
/// tests use the [automock]-generated `MockTSNodeTrait` implementation of this trait instead.
#[cfg_attr(test, automock)]
trait TSNodeTrait {
    /// Return the kind string that corresponds to a node.
    fn kind(&self) -> &str;
}
27
/// The configuration options for processing tree-sitter output.
///
/// Field defaults are supplied by this type's `Default` implementation; serde falls back to them
/// for any key missing from a config file (via `#[serde(default)]`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case", default)]
pub struct TreeSitterProcessor {
    /// Whether we should split the nodes graphemes.
    ///
    /// If this is disabled, then the direct tree-sitter nodes will be used and diffs will be less
    /// granular. This has the advantage of being faster and using less memory.
    pub split_graphemes: bool,

    /// The kinds of nodes to exclude from processing. This takes precedence over `include_kinds`.
    ///
    /// This is a set of strings that correspond to the tree sitter node types.
    pub exclude_kinds: Option<HashSet<String>>,

    /// The kinds of nodes to explicitly include when processing. The nodes specified here will be overridden by the
    /// nodes specified in `exclude_kinds`.
    ///
    /// This is a set of strings that correspond to the tree sitter node types.
    pub include_kinds: Option<HashSet<String>>,

    /// Whether to strip whitespace when processing node text.
    ///
    /// Whitespace includes whitespace characters and newlines. This can provide much more accurate
    /// diffs that do not account for line breaks. This is useful especially for more text heavy
    /// documents like markdown files.
    pub strip_whitespace: bool,
}
56
57// TODO: if we want to do any string transformations we need to store Cow strings.
58// Most strings won't be modified so it's fine to use a pointer. For the few we do
59// modify we'll need to store the direct string.
60// We should add some abstractions to do input processing.
61
62impl Default for TreeSitterProcessor {
63 fn default() -> Self {
64 Self {
65 split_graphemes: true,
66 exclude_kinds: None,
67 include_kinds: None,
68 strip_whitespace: true,
69 }
70 }
71}
72
/// A thin newtype around a borrowed tree-sitter node so it can be used through [TSNodeTrait]
/// (and therefore mocked in tests).
#[derive(Debug)]
struct TSNodeWrapper<'a>(TSNode<'a>);
75
impl<'a> TSNodeTrait for TSNodeWrapper<'a> {
    /// Delegate straight to the wrapped tree-sitter node's kind string.
    fn kind(&self) -> &str {
        self.0.kind()
    }
}
81
82impl TreeSitterProcessor {
83 #[time("info", "ast::{}")]
84 pub fn process<'a>(&self, tree: &'a TSTree, text: &'a str) -> Vec<Entry<'a>> {
85 let ast_vector = from_ts_tree(tree, text);
86 let iter = ast_vector
87 .leaves
88 .iter()
89 .filter(|leaf| self.should_include_node(&TSNodeWrapper(leaf.reference)));
90 // Splitting on graphemes generates a vector of entries instead of a direct mapping, which
91 // is why we have the branching here
92 if self.split_graphemes {
93 iter.flat_map(|leaf| leaf.split_on_graphemes(self.strip_whitespace))
94 .collect()
95 } else {
96 iter.map(|&x| self.process_leaf(x)).collect()
97 }
98 }
99
100 /// Process a vector leaf and turn it into an [Entry].
101 ///
102 /// This applies input processing according to the user provided options.
103 fn process_leaf<'a>(&self, leaf: VectorLeaf<'a>) -> Entry<'a> {
104 let new_text = if self.strip_whitespace {
105 // This includes newlines
106 Cow::from(leaf.text.trim())
107 } else {
108 Cow::from(leaf.text)
109 };
110
111 Entry {
112 reference: leaf.reference,
113 text: new_text,
114 start_position: leaf.reference.start_position(),
115 end_position: leaf.reference.start_position(),
116 kind_id: leaf.reference.kind_id(),
117 }
118 }
119
120 /// A helper method to determine whether a node type should be filtered out based on the user's filtering
121 /// preferences.
122 ///
123 /// This method will first check if the node has been specified for exclusion, which takes precedence. Then it will
124 /// check if the node kind is explicitly included. If either the exclusion or inclusion sets aren't specified,
125 /// then the filter will not be applied.
126 fn should_include_node(&self, node: &dyn TSNodeTrait) -> bool {
127 let should_exclude = self
128 .exclude_kinds
129 .as_ref()
130 .is_some_and(|x| x.contains(node.kind()))
131 || self
132 .include_kinds
133 .as_ref()
134 .is_some_and(|x| !x.contains(node.kind()));
135 !should_exclude
136 }
137}
138
139/// Create a `DiffVector` from a `tree_sitter` tree
140///
141/// This method calls a helper function that does an in-order traversal of the tree and adds
142/// leaf nodes to a vector
143#[time("info", "ast::{}")]
144fn from_ts_tree<'a>(tree: &'a TSTree, text: &'a str) -> Vector<'a> {
145 let leaves = RefCell::new(Vec::new());
146 build(&leaves, tree.root_node(), text);
147 Vector {
148 leaves: leaves.into_inner(),
149 source_text: text,
150 }
151}
152
/// The leaves of an AST vector
///
/// This is used as an intermediate struct for flattening the tree structure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VectorLeaf<'a> {
    // The tree-sitter node this leaf was created from
    pub reference: TSNode<'a>,
    // The slice of the source text covered by `reference`'s byte range
    pub text: &'a str,
}
161
/// A proxy for (Point)[`tree_sitter::Point`] for [serde].
///
/// This is a copy of an external struct that we use with serde so we can create json objects with
/// serde. It is never instantiated directly; it only drives `#[serde(with = "PointWrapper")]`.
#[derive(Serialize, Deserialize)]
#[serde(remote = "Point")]
struct PointWrapper {
    // Zero-based row of the position
    pub row: usize,
    // Zero-based column of the position
    pub column: usize,
}
172
/// A mapping between a tree-sitter node and the text it corresponds to
///
/// This is also all of the metadata the diff rendering interface has access to, and also defines
/// the data that will be output by the JSON serializer.
#[derive(Debug, Clone, Serialize)]
pub struct Entry<'node> {
    /// The node an entry in the diff vector refers to
    ///
    /// We keep a reference to the leaf node so that we can easily grab the text and other metadata
    /// surrounding the syntax
    #[serde(skip_serializing)]
    pub reference: TSNode<'node>,

    /// A reference to the text the node refers to
    ///
    /// This is different from the `source_text` that the [AstVector] refers to, as the
    /// entry only holds a reference to the specific range of text that the node covers.
    ///
    /// We use a [Cow] here instead of a direct string reference because we might have to rewrite
    /// the text based on input processing settings, but if we don't have to there's no need to
    /// allocate an extra string.
    pub text: Cow<'node, str>,

    /// The entry's start position in the document.
    #[serde(with = "PointWrapper")]
    pub start_position: Point,

    /// The entry's end position in the document.
    ///
    /// End positions are treated as exclusive by the grapheme-splitting invariant checks.
    #[serde(with = "PointWrapper")]
    pub end_position: Point,

    /// The cached kind_id from the TSNode reference.
    ///
    /// Caching it here saves some time because it is queried repeatedly later. If we don't store
    /// it inline then we have to cross the FFI boundary which incurs some overhead.
    // PERF: Use cross language LTO to see if LLVM can optimize across the FFI boundary.
    pub kind_id: u16,
}
211
212impl<'a> VectorLeaf<'a> {
213 /// Split an entry into a vector of entries per grapheme.
214 ///
215 /// Each grapheme will get its own [Entry] struct. This method will resolve the
216 /// indices/positioning of each grapheme from the `self.text` field.
217 ///
218 /// This effectively maps out the byte position for each node from the unicode text, accounting
219 /// for both newlines and grapheme splits.
220 fn split_on_graphemes(self, strip_whitespace: bool) -> Vec<Entry<'a>> {
221 let mut entries: Vec<Entry<'a>> = Vec::new();
222
223 // We have to split lines because newline characters might be in the text for a tree sitter
224 // node. We try to split up each unicode grapheme and assign them a location in the text
225 // with a row and column, so we need to make sure that we are properly resetting the column
226 // offset for and offsetting the row for each new line in a tree sitter node's text.
227 let lines = self.text.lines();
228
229 for (line_offset, line) in lines.enumerate() {
230 let indices: Vec<(usize, &str)> =
231 us::UnicodeSegmentation::grapheme_indices(line, true).collect();
232 entries.reserve(entries.len() + indices.len());
233
234 for (idx, grapheme) in indices {
235 // Every grapheme has to be at least one byte
236 debug_assert!(!grapheme.is_empty());
237
238 if strip_whitespace && grapheme.chars().all(char::is_whitespace) {
239 continue;
240 }
241
242 // We simply offset from the start position of the node if we are on the first
243 // line, which implies no newline offset needs to be applied. If the line_offset is
244 // more than 0, we know we've hit a newline so the starting position for the column
245 // is 0, shifted over for the grapheme index.
246 let start_column = if line_offset == 0 {
247 self.reference.start_position().column + idx
248 } else {
249 idx
250 };
251 let row = self.reference.start_position().row + line_offset;
252 let new_start_pos = Point {
253 row,
254 column: start_column,
255 };
256 let new_end_pos = Point {
257 row,
258 column: new_start_pos.column + grapheme.len(),
259 };
260 debug_assert!(new_start_pos.row <= new_end_pos.row);
261 let entry = Entry {
262 reference: self.reference,
263 text: Cow::from(&line[idx..idx + grapheme.len()]),
264 start_position: new_start_pos,
265 end_position: new_end_pos,
266 kind_id: self.reference.kind_id(),
267 };
268 // We add the debug assert config here because there's no need to even get a
269 // reference to the last element if we're not in debug mode.
270 #[cfg(debug_assertions)]
271 if let Some(last_entry) = entries.last() {
272 // Our invariant is that one of the following must hold true:
273 // 1. The last entry ended on a previous line (now we don't need to check the
274 // column offset).
275 // 2. The last entry is on the same line, so the column offset for the entry we
276 // are about to insert must be greater than or equal to the end column of
277 // the last entry. It's valid for them to be equal because the end position
278 // is not inclusive.
279 debug_assert!(
280 last_entry.end_position().row < entry.start_position().row
281 || (last_entry.end_position.row == entry.start_position().row
282 && last_entry.end_position.column <= entry.start_position().column)
283 );
284 }
285 entries.push(entry);
286 }
287 }
288 entries
289 }
290}
291
292impl<'a> From<VectorLeaf<'a>> for Entry<'a> {
293 fn from(leaf: VectorLeaf<'a>) -> Self {
294 Self {
295 reference: leaf.reference,
296 text: Cow::from(leaf.text),
297 start_position: leaf.reference.start_position(),
298 end_position: leaf.reference.start_position(),
299 kind_id: leaf.reference.kind_id(),
300 }
301 }
302}
303
impl<'a> Entry<'a> {
    /// Get the start position of an entry
    ///
    /// Returns a copy of the cached `start_position` field ([Point] is a plain row/column pair).
    #[must_use]
    pub fn start_position(&self) -> Point {
        self.start_position
    }

    /// Get the end position of an entry
    ///
    /// Returns a copy of the cached `end_position` field.
    #[must_use]
    pub fn end_position(&self) -> Point {
        self.end_position
    }
}
317
318impl<'a> From<&'a Vector<'a>> for Vec<Entry<'a>> {
319 fn from(ast_vector: &'a Vector<'a>) -> Self {
320 ast_vector
321 .leaves
322 .iter()
323 .flat_map(|entry| entry.split_on_graphemes(true))
324 .collect()
325 }
326}
327
/// A vector that allows for linear traversal through the leafs of an AST.
///
/// This representation of the tree leaves is much more convenient for things like dynamic
/// programming, and provides useful for formatting.
#[derive(Debug)]
pub struct Vector<'a> {
    /// The leaves of the AST, build with an in-order traversal
    pub leaves: Vec<VectorLeaf<'a>>,

    /// The full source text that the AST refers to
    pub source_text: &'a str,
}
340
341impl<'a> Eq for Entry<'a> {}
342
/// A wrapper struct for AST vector data that owns the data that the AST vector references
///
/// Ideally we would just have the AST vector own the actual string and tree, but it makes things
/// extremely messy with the borrow checker, so we have this wrapper struct that holds the owned
/// data that the vector references. This gets tricky because the tree sitter library uses FFI so
/// the lifetime references get even more mangled.
#[derive(Debug)]
pub struct VectorData {
    /// The text in the file
    pub text: String,

    /// The tree that was parsed using the text
    pub tree: TSTree,

    /// The file path that the text corresponds to
    pub path: PathBuf,
}
360
361impl<'a> Vector<'a> {
362 /// Create a `DiffVector` from a `tree_sitter` tree
363 ///
364 /// This method calls a helper function that does an in-order traversal of the tree and adds
365 /// leaf nodes to a vector
366 #[time("info", "ast::{}")]
367 pub fn from_ts_tree(tree: &'a TSTree, text: &'a str) -> Self {
368 let leaves = RefCell::new(Vec::new());
369 build(&leaves, tree.root_node(), text);
370 Vector {
371 leaves: leaves.into_inner(),
372 source_text: text,
373 }
374 }
375
376 /// Return the number of nodes in the diff vector
377 #[must_use]
378 pub fn len(&self) -> usize {
379 self.leaves.len()
380 }
381
382 /// Return whether there are any leaves in the diff vector.
383 #[must_use]
384 pub fn is_empty(&self) -> bool {
385 self.leaves.is_empty()
386 }
387}
388
impl<'a> Index<usize> for Vector<'a> {
    type Output = VectorLeaf<'a>;

    /// Index directly into the underlying leaf vector; panics if `index` is out of bounds.
    fn index(&self, index: usize) -> &Self::Output {
        &self.leaves[index]
    }
}
396
impl<'a> Hash for VectorLeaf<'a> {
    /// Hash a leaf by its node kind id and its text, mirroring the fields that `Entry`'s manual
    /// `PartialEq` compares.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.reference.kind_id().hash(state);
        self.text.hash(state);
    }
}
403
impl<'a> PartialEq for Entry<'a> {
    /// Two entries are equal when they have the same node kind and the same text; positions and
    /// the underlying node reference are deliberately ignored.
    fn eq(&self, other: &Entry) -> bool {
        self.kind_id == other.kind_id && self.text == other.text
    }
}
409
410impl<'a> PartialEq for Vector<'a> {
411 fn eq(&self, other: &Vector) -> bool {
412 if self.leaves.len() != other.leaves.len() {
413 return false;
414 }
415
416 for i in 0..self.leaves.len() {
417 let leaf = self.leaves[i];
418 let other_leaf = other.leaves[i];
419
420 if leaf != other_leaf {
421 return false;
422 }
423 }
424 true
425 }
426}
427
428/// Recursively build a vector from a given node
429///
430/// This is a helper function that simply walks the tree and collects leaves in an in-order manner.
431/// Every time it encounters a leaf node, it stores the metadata and reference to the node in an
432/// `Entry` struct.
433fn build<'a>(vector: &RefCell<Vec<VectorLeaf<'a>>>, node: tree_sitter::Node<'a>, text: &'a str) {
434 // If the node is a leaf, we can stop traversing
435 if node.child_count() == 0 {
436 // We only push an entry if the referenced text range isn't empty, since there's no point
437 // in having an empty text range. This also fixes a bug where the program would panic
438 // because it would attempt to access the 0th index in an empty text range.
439 if !node.byte_range().is_empty() {
440 let node_text: &'a str = &text[node.byte_range()];
441 // HACK: this is a workaround that was put in place to work around the Go parser which
442 // puts newlines into their own nodes, which later causes errors when trying to print
443 // these nodes. We just ignore those nodes.
444 if node_text
445 .replace("\r\n", "")
446 .replace(['\n', '\r'], "")
447 .is_empty()
448 {
449 return;
450 }
451
452 vector.borrow_mut().push(VectorLeaf {
453 reference: node,
454 text: node_text,
455 });
456 }
457 return;
458 }
459
460 let mut cursor = node.walk();
461
462 for child in node.children(&mut cursor) {
463 build(vector, child, text);
464 }
465}
466
/// The different types of elements that can be in an edit script
#[derive(Debug, Eq, PartialEq)]
pub enum EditType<T> {
    /// An element that was added in the edit script
    Addition(T),

    /// An element that was deleted in the edit script
    Deletion(T),
}
476
impl<T> AsRef<T> for EditType<T> {
    /// Borrow the wrapped element regardless of whether it is an addition or a deletion.
    fn as_ref(&self) -> &T {
        match self {
            Self::Addition(x) | Self::Deletion(x) => x,
        }
    }
}
484
impl<T> Deref for EditType<T> {
    type Target = T;

    /// Deref to the wrapped element; both variants carry the same payload type.
    fn deref(&self) -> &Self::Target {
        match self {
            Self::Addition(x) | Self::Deletion(x) => x,
        }
    }
}
494
impl<T> DerefMut for EditType<T> {
    /// Mutable counterpart of `Deref`: access the wrapped element in either variant.
    fn deref_mut(&mut self) -> &mut Self::Target {
        match self {
            Self::Addition(x) | Self::Deletion(x) => x,
        }
    }
}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GrammarConfig;
    use tree_sitter::Parser;

    #[cfg(feature = "static-grammar-libs")]
    use crate::parse::generate_language;

    /// Exercise `should_include_node` across every combination of the include/exclude sets,
    /// using a mocked node whose kind is always "comment".
    #[test]
    fn test_should_filter_node() {
        let exclude_kinds: HashSet<String> = HashSet::from(["comment".to_string()]);
        let mut mock_node = MockTSNodeTrait::new();
        mock_node.expect_kind().return_const("comment".to_owned());

        // basic scenario - expect that the excluded kind is ignored
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: Some(exclude_kinds.clone()),
            include_kinds: None,
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // expect that it's still excluded if the included list also has an element that was excluded
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: Some(exclude_kinds.clone()),
            include_kinds: Some(exclude_kinds),
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // Don't exclude anything, but only include types that our node is not
        let include_kinds: HashSet<String> = HashSet::from([
            "some_other_type".to_string(),
            "yet another type".to_string(),
        ]);
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: Some(include_kinds),
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // include our node type
        let include_kinds: HashSet<String> = HashSet::from(["comment".to_string()]);
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: Some(include_kinds),
            ..Default::default()
        };
        assert!(processor.should_include_node(&mock_node));

        // don't filter anything
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: None,
            ..Default::default()
        };
        assert!(processor.should_include_node(&mock_node));
    }

    // NOTE: this has to be gated behind the 'static-grammar-libs' cargo feature, otherwise the
    // crate won't be built with the grammars bundled into the binary which means this won't be
    // able to load the markdown parser. It's possible that the markdown dynamic library is
    // available even if we don't compile the grammars statically but there's no guarantees of
    // which grammars are available dynamically, and we don't enforce that certain grammars have to
    // be available.
    #[cfg(feature = "static-grammar-libs")]
    #[test]
    fn test_strip_whitespace() {
        // NOTE(review): the variable is named `md_parser` and the gating comment above talks
        // about markdown, but the grammar actually loaded here is "python" (the inputs are
        // Python triple-quoted strings) — confirm which grammar was intended.
        let md_parser = generate_language("python", &GrammarConfig::default()).unwrap();
        let mut parser = Parser::new();
        parser.set_language(&md_parser).unwrap();
        // The two inputs differ only in internal whitespace/newlines
        let text_a = "'''# A heading\nThis has no diff.'''";
        let text_b = "'''# A heading\nThis\nhas\r\nno diff.'''";
        let tree_a = parser.parse(text_a, None).unwrap();
        let tree_b = parser.parse(text_b, None).unwrap();
        {
            // With whitespace stripping, the whitespace-only differences should vanish
            let processor = TreeSitterProcessor {
                strip_whitespace: true,
                ..Default::default()
            };
            let entries_a = processor.process(&tree_a, text_a);
            let entries_b = processor.process(&tree_b, text_b);
            assert_eq!(entries_a, entries_b);
        }
        {
            // Without stripping, the same inputs must produce different entries
            let processor = TreeSitterProcessor {
                strip_whitespace: false,
                ..Default::default()
            };
            let entries_a = processor.process(&tree_a, text_a);
            let entries_b = processor.process(&tree_b, text_b);
            assert_ne!(entries_a, entries_b);
        }
    }
}