// libdiffsitter/input_processing.rs

1//! Utilities for processing the ASTs provided by `tree_sitter`
2//!
3//! These methods handle preprocessing the input data so it can be fed into the diff engines to
4//! compute diff data.
5
6use logging_timer::time;
7use serde::{Deserialize, Serialize};
8use std::borrow::Cow;
9use std::collections::HashSet;
10use std::hash::{Hash, Hasher};
11use std::ops::{Deref, DerefMut};
12use std::{cell::RefCell, ops::Index, path::PathBuf};
13use tree_sitter::Node as TSNode;
14use tree_sitter::Point;
15use tree_sitter::Tree as TSTree;
16use unicode_segmentation as us;
17
18#[cfg(test)]
19use mockall::{automock, predicate::str};
20
/// A wrapper trait that exists so we can mock TS nodes.
///
/// `tree_sitter::Node` is produced by the FFI layer and cannot be constructed
/// directly in unit tests, so tests use the generated `MockTSNodeTrait`
/// instead (see `test_should_filter_node`).
#[cfg_attr(test, automock)]
trait TSNodeTrait {
    /// Return the kind string that corresponds to a node.
    fn kind(&self) -> &str;
}
27
/// The configuration options for processing tree-sitter output.
///
/// Deserialized with `kebab-case` field names and per-field defaults, so a
/// config file may specify any subset of these options.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "kebab-case", default)]
pub struct TreeSitterProcessor {
    /// Whether we should split the nodes graphemes.
    ///
    /// If this is disabled, then the direct tree-sitter nodes will be used and diffs will be less
    /// granular. This has the advantage of being faster and using less memory.
    pub split_graphemes: bool,

    /// The kinds of nodes to exclude from processing. This takes precedence over `include_kinds`.
    ///
    /// This is a set of strings that correspond to the tree sitter node types.
    pub exclude_kinds: Option<HashSet<String>>,

    /// The kinds of nodes to explicitly include when processing. The nodes specified here will be overridden by the
    /// nodes specified in `exclude_kinds`.
    ///
    /// This is a set of strings that correspond to the tree sitter node types.
    pub include_kinds: Option<HashSet<String>>,

    /// Whether to strip whitespace when processing node text.
    ///
    /// Whitespace includes whitespace characters and newlines. This can provide much more accurate
    /// diffs that do not account for line breaks. This is useful especially for more text heavy
    /// documents like markdown files.
    pub strip_whitespace: bool,
}
56
57// TODO: if we want to do any string transformations we need to store Cow strings.
58// Most strings won't be modified so it's fine to use a pointer. For the few we do
59// modify we'll need to store the direct string.
60// We should add some abstractions to do input processing.
61
62impl Default for TreeSitterProcessor {
63    fn default() -> Self {
64        Self {
65            split_graphemes: true,
66            exclude_kinds: None,
67            include_kinds: None,
68            strip_whitespace: true,
69        }
70    }
71}
72
/// A newtype over [`TSNode`] so that the real tree-sitter node can be used
/// through the mockable [`TSNodeTrait`] interface.
#[derive(Debug)]
struct TSNodeWrapper<'a>(TSNode<'a>);
75
76impl<'a> TSNodeTrait for TSNodeWrapper<'a> {
77    fn kind(&self) -> &str {
78        self.0.kind()
79    }
80}
81
82impl TreeSitterProcessor {
83    #[time("info", "ast::{}")]
84    pub fn process<'a>(&self, tree: &'a TSTree, text: &'a str) -> Vec<Entry<'a>> {
85        let ast_vector = from_ts_tree(tree, text);
86        let iter = ast_vector
87            .leaves
88            .iter()
89            .filter(|leaf| self.should_include_node(&TSNodeWrapper(leaf.reference)));
90        // Splitting on graphemes generates a vector of entries instead of a direct mapping, which
91        // is why we have the branching here
92        if self.split_graphemes {
93            iter.flat_map(|leaf| leaf.split_on_graphemes(self.strip_whitespace))
94                .collect()
95        } else {
96            iter.map(|&x| self.process_leaf(x)).collect()
97        }
98    }
99
100    /// Process a vector leaf and turn it into an [Entry].
101    ///
102    /// This applies input processing according to the user provided options.
103    fn process_leaf<'a>(&self, leaf: VectorLeaf<'a>) -> Entry<'a> {
104        let new_text = if self.strip_whitespace {
105            // This includes newlines
106            Cow::from(leaf.text.trim())
107        } else {
108            Cow::from(leaf.text)
109        };
110
111        Entry {
112            reference: leaf.reference,
113            text: new_text,
114            start_position: leaf.reference.start_position(),
115            end_position: leaf.reference.start_position(),
116            kind_id: leaf.reference.kind_id(),
117        }
118    }
119
120    /// A helper method to determine whether a node type should be filtered out based on the user's filtering
121    /// preferences.
122    ///
123    /// This method will first check if the node has been specified for exclusion, which takes precedence. Then it will
124    /// check if the node kind is explicitly included. If either the exclusion or inclusion sets aren't specified,
125    /// then the filter will not be applied.
126    fn should_include_node(&self, node: &dyn TSNodeTrait) -> bool {
127        let should_exclude = self
128            .exclude_kinds
129            .as_ref()
130            .is_some_and(|x| x.contains(node.kind()))
131            || self
132                .include_kinds
133                .as_ref()
134                .is_some_and(|x| !x.contains(node.kind()));
135        !should_exclude
136    }
137}
138
139/// Create a `DiffVector` from a `tree_sitter` tree
140///
141/// This method calls a helper function that does an in-order traversal of the tree and adds
142/// leaf nodes to a vector
143#[time("info", "ast::{}")]
144fn from_ts_tree<'a>(tree: &'a TSTree, text: &'a str) -> Vector<'a> {
145    let leaves = RefCell::new(Vec::new());
146    build(&leaves, tree.root_node(), text);
147    Vector {
148        leaves: leaves.into_inner(),
149        source_text: text,
150    }
151}
152
/// The leaves of an AST vector
///
/// This is used as an intermediate struct for flattening the tree structure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct VectorLeaf<'a> {
    // The tree-sitter node this leaf was built from.
    pub reference: TSNode<'a>,
    // The slice of source text that `reference` covers.
    pub text: &'a str,
}
161
/// A proxy for [`Point`](tree_sitter::Point) for [serde].
///
/// This is a copy of an external struct that we use with serde so we can create json objects with
/// serde. The `remote` attribute maps this definition onto the real `Point`.
#[derive(Serialize, Deserialize)]
#[serde(remote = "Point")]
struct PointWrapper {
    // Mirrors `Point::row`.
    pub row: usize,
    // Mirrors `Point::column`.
    pub column: usize,
}
172
/// A mapping between a tree-sitter node and the text it corresponds to
///
/// This is also all of the metadata the diff rendering interface has access to, and also defines
/// the data that will be output by the JSON serializer.
#[derive(Debug, Clone, Serialize)]
pub struct Entry<'node> {
    /// The node an entry in the diff vector refers to
    ///
    /// We keep a reference to the leaf node so that we can easily grab the text and other metadata
    /// surrounding the syntax
    // Skipped during serialization: the raw FFI node is not serializable.
    #[serde(skip_serializing)]
    pub reference: TSNode<'node>,

    /// A reference to the text the node refers to
    ///
    /// This is different from the `source_text` that the [AstVector] refers to, as the
    /// entry only holds a reference to the specific range of text that the node covers.
    ///
    /// We use a [Cow] here instead of a direct string reference because we might have to rewrite
    /// the text based on input processing settings, but if we don't have to there's no need to
    /// allocate an extra string.
    pub text: Cow<'node, str>,

    /// The entry's start position in the document.
    #[serde(with = "PointWrapper")]
    pub start_position: Point,

    /// The entry's end position in the document.
    #[serde(with = "PointWrapper")]
    pub end_position: Point,

    /// The cached kind_id from the TSNode reference.
    ///
    /// Caching it here saves some time because it is queried repeatedly later. If we don't store
    /// it inline then we have to cross the FFI boundary which incurs some overhead.
    // PERF: Use cross language LTO to see if LLVM can optimize across the FFI boundary.
    pub kind_id: u16,
}
211
212impl<'a> VectorLeaf<'a> {
213    /// Split an entry into a vector of entries per grapheme.
214    ///
215    /// Each grapheme will get its own [Entry] struct. This method will resolve the
216    /// indices/positioning of each grapheme from the `self.text` field.
217    ///
218    /// This effectively maps out the byte position for each node from the unicode text, accounting
219    /// for both newlines and grapheme splits.
220    fn split_on_graphemes(self, strip_whitespace: bool) -> Vec<Entry<'a>> {
221        let mut entries: Vec<Entry<'a>> = Vec::new();
222
223        // We have to split lines because newline characters might be in the text for a tree sitter
224        // node. We try to split up each unicode grapheme and assign them a location in the text
225        // with a row and column, so we need to make sure that we are properly resetting the column
226        // offset for and offsetting the row for each new line in a tree sitter node's text.
227        let lines = self.text.lines();
228
229        for (line_offset, line) in lines.enumerate() {
230            let indices: Vec<(usize, &str)> =
231                us::UnicodeSegmentation::grapheme_indices(line, true).collect();
232            entries.reserve(entries.len() + indices.len());
233
234            for (idx, grapheme) in indices {
235                // Every grapheme has to be at least one byte
236                debug_assert!(!grapheme.is_empty());
237
238                if strip_whitespace && grapheme.chars().all(char::is_whitespace) {
239                    continue;
240                }
241
242                // We simply offset from the start position of the node if we are on the first
243                // line, which implies no newline offset needs to be applied. If the line_offset is
244                // more than 0, we know we've hit a newline so the starting position for the column
245                // is 0, shifted over for the grapheme index.
246                let start_column = if line_offset == 0 {
247                    self.reference.start_position().column + idx
248                } else {
249                    idx
250                };
251                let row = self.reference.start_position().row + line_offset;
252                let new_start_pos = Point {
253                    row,
254                    column: start_column,
255                };
256                let new_end_pos = Point {
257                    row,
258                    column: new_start_pos.column + grapheme.len(),
259                };
260                debug_assert!(new_start_pos.row <= new_end_pos.row);
261                let entry = Entry {
262                    reference: self.reference,
263                    text: Cow::from(&line[idx..idx + grapheme.len()]),
264                    start_position: new_start_pos,
265                    end_position: new_end_pos,
266                    kind_id: self.reference.kind_id(),
267                };
268                // We add the debug assert config here because there's no need to even get a
269                // reference to the last element if we're not in debug mode.
270                #[cfg(debug_assertions)]
271                if let Some(last_entry) = entries.last() {
272                    // Our invariant is that one of the following must hold true:
273                    // 1. The last entry ended on a previous line (now we don't need to check the
274                    //    column offset).
275                    // 2. The last entry is on the same line, so the column offset for the entry we
276                    //    are about to insert must be greater than or equal to the end column of
277                    //    the last entry. It's valid for them to be equal because the end position
278                    //    is not inclusive.
279                    debug_assert!(
280                        last_entry.end_position().row < entry.start_position().row
281                            || (last_entry.end_position.row == entry.start_position().row
282                                && last_entry.end_position.column <= entry.start_position().column)
283                    );
284                }
285                entries.push(entry);
286            }
287        }
288        entries
289    }
290}
291
292impl<'a> From<VectorLeaf<'a>> for Entry<'a> {
293    fn from(leaf: VectorLeaf<'a>) -> Self {
294        Self {
295            reference: leaf.reference,
296            text: Cow::from(leaf.text),
297            start_position: leaf.reference.start_position(),
298            end_position: leaf.reference.start_position(),
299            kind_id: leaf.reference.kind_id(),
300        }
301    }
302}
303
impl<'a> Entry<'a> {
    /// Get the start position of an entry
    // `Point` is returned by value; it is a small row/column pair.
    #[must_use]
    pub fn start_position(&self) -> Point {
        self.start_position
    }

    /// Get the end position of an entry
    #[must_use]
    pub fn end_position(&self) -> Point {
        self.end_position
    }
}
317
318impl<'a> From<&'a Vector<'a>> for Vec<Entry<'a>> {
319    fn from(ast_vector: &'a Vector<'a>) -> Self {
320        ast_vector
321            .leaves
322            .iter()
323            .flat_map(|entry| entry.split_on_graphemes(true))
324            .collect()
325    }
326}
327
/// A vector that allows for linear traversal through the leaves of an AST.
///
/// This representation of the tree leaves is much more convenient for things like dynamic
/// programming, and proves useful for formatting.
#[derive(Debug)]
pub struct Vector<'a> {
    /// The leaves of the AST, built with an in-order traversal
    pub leaves: Vec<VectorLeaf<'a>>,

    /// The full source text that the AST refers to
    pub source_text: &'a str,
}
340
// `Entry`'s `PartialEq` (kind id + text, defined below) is reflexive, so the
// marker `Eq` impl is sound.
impl<'a> Eq for Entry<'a> {}
342
/// A wrapper struct for AST vector data that owns the data that the AST vector references
///
/// Ideally we would just have the AST vector own the actual string and tree, but it makes things
/// extremely messy with the borrow checker, so we have this wrapper struct that holds the owned
/// data that the vector references. This gets tricky because the tree sitter library uses FFI so
/// the lifetime references get even more mangled.
#[derive(Debug)]
pub struct VectorData {
    /// The text in the file
    pub text: String,

    /// The tree that was parsed using the text
    pub tree: TSTree,

    /// The file path that the text corresponds to
    pub path: PathBuf,
}
360
361impl<'a> Vector<'a> {
362    /// Create a `DiffVector` from a `tree_sitter` tree
363    ///
364    /// This method calls a helper function that does an in-order traversal of the tree and adds
365    /// leaf nodes to a vector
366    #[time("info", "ast::{}")]
367    pub fn from_ts_tree(tree: &'a TSTree, text: &'a str) -> Self {
368        let leaves = RefCell::new(Vec::new());
369        build(&leaves, tree.root_node(), text);
370        Vector {
371            leaves: leaves.into_inner(),
372            source_text: text,
373        }
374    }
375
376    /// Return the number of nodes in the diff vector
377    #[must_use]
378    pub fn len(&self) -> usize {
379        self.leaves.len()
380    }
381
382    /// Return whether there are any leaves in the diff vector.
383    #[must_use]
384    pub fn is_empty(&self) -> bool {
385        self.leaves.is_empty()
386    }
387}
388
389impl<'a> Index<usize> for Vector<'a> {
390    type Output = VectorLeaf<'a>;
391
392    fn index(&self, index: usize) -> &Self::Output {
393        &self.leaves[index]
394    }
395}
396
impl<'a> Hash for VectorLeaf<'a> {
    /// Hash a leaf by its node kind id and its text, in that order.
    ///
    /// This mirrors the fields compared by `Entry`'s `PartialEq` (kind id and
    /// text), so leaves that would compare equal as entries hash identically.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.reference.kind_id().hash(state);
        self.text.hash(state);
    }
}
403
404impl<'a> PartialEq for Entry<'a> {
405    fn eq(&self, other: &Entry) -> bool {
406        self.kind_id == other.kind_id && self.text == other.text
407    }
408}
409
410impl<'a> PartialEq for Vector<'a> {
411    fn eq(&self, other: &Vector) -> bool {
412        if self.leaves.len() != other.leaves.len() {
413            return false;
414        }
415
416        for i in 0..self.leaves.len() {
417            let leaf = self.leaves[i];
418            let other_leaf = other.leaves[i];
419
420            if leaf != other_leaf {
421                return false;
422            }
423        }
424        true
425    }
426}
427
428/// Recursively build a vector from a given node
429///
430/// This is a helper function that simply walks the tree and collects leaves in an in-order manner.
431/// Every time it encounters a leaf node, it stores the metadata and reference to the node in an
432/// `Entry` struct.
433fn build<'a>(vector: &RefCell<Vec<VectorLeaf<'a>>>, node: tree_sitter::Node<'a>, text: &'a str) {
434    // If the node is a leaf, we can stop traversing
435    if node.child_count() == 0 {
436        // We only push an entry if the referenced text range isn't empty, since there's no point
437        // in having an empty text range. This also fixes a bug where the program would panic
438        // because it would attempt to access the 0th index in an empty text range.
439        if !node.byte_range().is_empty() {
440            let node_text: &'a str = &text[node.byte_range()];
441            // HACK: this is a workaround that was put in place to work around the Go parser which
442            // puts newlines into their own nodes, which later causes errors when trying to print
443            // these nodes. We just ignore those nodes.
444            if node_text
445                .replace("\r\n", "")
446                .replace(['\n', '\r'], "")
447                .is_empty()
448            {
449                return;
450            }
451
452            vector.borrow_mut().push(VectorLeaf {
453                reference: node,
454                text: node_text,
455            });
456        }
457        return;
458    }
459
460    let mut cursor = node.walk();
461
462    for child in node.children(&mut cursor) {
463        build(vector, child, text);
464    }
465}
466
/// The different types of elements that can be in an edit script
#[derive(Debug, Eq, PartialEq)]
pub enum EditType<T> {
    /// An element that was added in the edit script
    Addition(T),

    /// An element that was deleted in the edit script
    Deletion(T),
}
476
477impl<T> AsRef<T> for EditType<T> {
478    fn as_ref(&self) -> &T {
479        match self {
480            Self::Addition(x) | Self::Deletion(x) => x,
481        }
482    }
483}
484
485impl<T> Deref for EditType<T> {
486    type Target = T;
487
488    fn deref(&self) -> &Self::Target {
489        match self {
490            Self::Addition(x) | Self::Deletion(x) => x,
491        }
492    }
493}
494
495impl<T> DerefMut for EditType<T> {
496    fn deref_mut(&mut self) -> &mut Self::Target {
497        match self {
498            Self::Addition(x) | Self::Deletion(x) => x,
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::GrammarConfig;
    use tree_sitter::Parser;

    #[cfg(feature = "static-grammar-libs")]
    use crate::parse::generate_language;

    /// Exercise `should_include_node` against every combination of
    /// exclude/include filter settings using a mocked node kind.
    #[test]
    fn test_should_filter_node() {
        let exclude_kinds: HashSet<String> = HashSet::from(["comment".to_string()]);
        let mut mock_node = MockTSNodeTrait::new();
        mock_node.expect_kind().return_const("comment".to_owned());

        // basic scenario - expect that the excluded kind is ignored
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: Some(exclude_kinds.clone()),
            include_kinds: None,
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // expect that it's still excluded if the included list also has an element that was excluded
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: Some(exclude_kinds.clone()),
            include_kinds: Some(exclude_kinds),
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // Don't exclude anything, but only include types that our node is not
        let include_kinds: HashSet<String> = HashSet::from([
            "some_other_type".to_string(),
            "yet another type".to_string(),
        ]);
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: Some(include_kinds),
            ..Default::default()
        };
        assert!(!processor.should_include_node(&mock_node));

        // include our node type
        let include_kinds: HashSet<String> = HashSet::from(["comment".to_string()]);
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: Some(include_kinds),
            ..Default::default()
        };
        assert!(processor.should_include_node(&mock_node));

        // don't filter anything
        let processor = TreeSitterProcessor {
            split_graphemes: false,
            exclude_kinds: None,
            include_kinds: None,
            ..Default::default()
        };
        assert!(processor.should_include_node(&mock_node));
    }

    // NOTE: this has to be gated behind the 'static-grammar-libs' cargo feature, otherwise the
    // crate won't be built with the grammars bundled into the binary which means this won't be
    // able to load the markdown parser. It's possible that the markdown dynamic library is
    // available even if we don't compile the grammars statically but there's no guarantees of
    // which grammars are available dynamically, and we don't enforce that certain grammars have to
    // be available.
    #[cfg(feature = "static-grammar-libs")]
    #[test]
    fn test_strip_whitespace() {
        // NOTE(review): the variable is named `md_parser` but the grammar
        // requested is "python" (the inputs are Python string literals) —
        // confirm which grammar was intended and rename accordingly.
        let md_parser = generate_language("python", &GrammarConfig::default()).unwrap();
        let mut parser = Parser::new();
        parser.set_language(&md_parser).unwrap();
        let text_a = "'''# A heading\nThis has no diff.'''";
        let text_b = "'''# A heading\nThis\nhas\r\nno diff.'''";
        let tree_a = parser.parse(text_a, None).unwrap();
        let tree_b = parser.parse(text_b, None).unwrap();
        // With whitespace stripped, the two texts should produce identical entries.
        {
            let processor = TreeSitterProcessor {
                strip_whitespace: true,
                ..Default::default()
            };
            let entries_a = processor.process(&tree_a, text_a);
            let entries_b = processor.process(&tree_b, text_b);
            assert_eq!(entries_a, entries_b);
        }
        // Without stripping, the extra newlines must make the entries differ.
        {
            let processor = TreeSitterProcessor {
                strip_whitespace: false,
                ..Default::default()
            };
            let entries_a = processor.process(&tree_a, text_a);
            let entries_b = processor.process(&tree_b, text_b);
            assert_ne!(entries_a, entries_b);
        }
    }
}