// text_splitter/splitter/code.rs

1use std::{cmp::Ordering, ops::Range};
2
3use thiserror::Error;
4use tree_sitter::{Language, LanguageError, Parser, TreeCursor, MIN_COMPATIBLE_LANGUAGE_VERSION};
5
6use crate::{
7    splitter::{SemanticLevel, Splitter},
8    trim::Trim,
9    ChunkConfig, ChunkSizer,
10};
11
12/// Indicates there was an error with creating a `CodeSplitter`.
13/// The `Display` implementation will provide a human-readable error message to
14/// help debug the issue that caused the error.
15#[derive(Error, Debug)]
16#[error(transparent)]
17#[allow(clippy::module_name_repetitions)]
18pub struct CodeSplitterError(#[from] CodeSplitterErrorRepr);
19
20/// Private error and free to change across minor version of the crate.
21#[derive(Error, Debug)]
22enum CodeSplitterErrorRepr {
23    #[error(
24        "Language version {0:?} is too old. Expected at least version {min_version}",
25        min_version=MIN_COMPATIBLE_LANGUAGE_VERSION,
26    )]
27    LanguageError(LanguageError),
28}
29
30/// Source code splitter. Recursively splits chunks into the largest
31/// semantic units that fit within the chunk size. Also will attempt to merge
32/// neighboring chunks if they can fit within the given chunk size.
33#[derive(Debug)]
34#[allow(clippy::module_name_repetitions)]
35pub struct CodeSplitter<Sizer>
36where
37    Sizer: ChunkSizer,
38{
39    /// Method of determining chunk sizes.
40    chunk_config: ChunkConfig<Sizer>,
41    /// Language to use for parsing the code.
42    language: Language,
43}
44
45impl<Sizer> CodeSplitter<Sizer>
46where
47    Sizer: ChunkSizer,
48{
49    /// Creates a new [`CodeSplitter`].
50    ///
51    /// ```
52    /// use text_splitter::CodeSplitter;
53    ///
54    /// // By default, the chunk sizer is based on characters.
55    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 512).expect("Invalid language");
56    /// ```
57    ///
58    /// # Errors
59    ///
60    /// Will return an error if the language version is too old to be compatible
61    /// with the current version of the tree-sitter crate.
62    pub fn new(
63        language: impl Into<Language>,
64        chunk_config: impl Into<ChunkConfig<Sizer>>,
65    ) -> Result<Self, CodeSplitterError> {
66        // Verify that this is a valid language so we can rely on that later.
67        let mut parser = Parser::new();
68        let language = language.into();
69        parser
70            .set_language(&language)
71            .map_err(CodeSplitterErrorRepr::LanguageError)?;
72        Ok(Self {
73            chunk_config: chunk_config.into(),
74            language,
75        })
76    }
77
78    /// Generate a list of chunks from a given text. Each chunk will be up to the `chunk_capacity`.
79    ///
80    /// ## Method
81    ///
82    /// To preserve as much semantic meaning within a chunk as possible, each chunk is composed of the largest semantic units that can fit in the next given chunk. For each splitter type, there is a defined set of semantic levels. Here is an example of the steps used:
83    //
84    // 1. Split the text by a increasing semantic levels.
85    // 2. Check the first item for each level and select the highest level whose first item still fits within the chunk size.
86    // 3. Merge as many of these neighboring sections of this level or above into a chunk to maximize chunk length.
87    //    Boundaries of higher semantic levels are always included when merging, so that the chunk doesn't inadvertantly cross semantic boundaries.
88    //
89    // The boundaries used to split the text if using the `chunks` method, in ascending order:
90    //
91    // 1. Characters
92    // 2. [Unicode Grapheme Cluster Boundaries](https://www.unicode.org/reports/tr29/#Grapheme_Cluster_Boundaries)
93    // 3. [Unicode Word Boundaries](https://www.unicode.org/reports/tr29/#Word_Boundaries)
94    // 4. [Unicode Sentence Boundaries](https://www.unicode.org/reports/tr29/#Sentence_Boundaries)
95    // 5. Ascending depth of the syntax tree. So function would have a higher level than a statement inside of the function, and so on.
96    //
97    // Splitting doesn't occur below the character level, otherwise you could get partial bytes of a char, which may not be a valid unicode str.
98    ///
99    /// ```
100    /// use text_splitter::CodeSplitter;
101    ///
102    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
103    /// let text = "Some text\n\nfrom a\ndocument";
104    /// let chunks = splitter.chunks(text).collect::<Vec<_>>();
105    ///
106    /// assert_eq!(vec!["Some text", "from a", "document"], chunks);
107    /// ```
108    pub fn chunks<'splitter, 'text: 'splitter>(
109        &'splitter self,
110        text: &'text str,
111    ) -> impl Iterator<Item = &'text str> + 'splitter {
112        Splitter::<_>::chunks(self, text)
113    }
114
115    /// Returns an iterator over chunks of the text and their byte offsets.
116    /// Each chunk will be up to the `chunk_capacity`.
117    ///
118    /// See [`CodeSplitter::chunks`] for more information.
119    ///
120    /// ```
121    /// use text_splitter::CodeSplitter;
122    ///
123    /// let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 10).expect("Invalid language");
124    /// let text = "Some text\n\nfrom a\ndocument";
125    /// let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();
126    ///
127    /// assert_eq!(vec![(0, "Some text"), (11, "from a"), (18, "document")], chunks);
128    pub fn chunk_indices<'splitter, 'text: 'splitter>(
129        &'splitter self,
130        text: &'text str,
131    ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
132        Splitter::<_>::chunk_indices(self, text)
133    }
134}
135
136impl<Sizer> Splitter<Sizer> for CodeSplitter<Sizer>
137where
138    Sizer: ChunkSizer,
139{
140    type Level = Depth;
141
142    const TRIM: Trim = Trim::PreserveIndentation;
143
144    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
145        &self.chunk_config
146    }
147
148    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
149        let mut parser = Parser::new();
150        parser
151            .set_language(&self.language)
152            // We verify at initialization that the language is valid, so this should be safe.
153            .expect("Error loading language");
154        // The only reason the tree would be None is:
155        // - No language was set (we do that)
156        // - There was a timeout or cancellation option set (we don't)
157        // - So it should be safe to unwrap here
158        let tree = parser.parse(text, None).expect("Error parsing source code");
159
160        CursorOffsets::new(tree.walk()).collect()
161    }
162}
163
/// Depth of a node in the syntax tree; the semantic level used by `CodeSplitter`.
///
/// The ordering is intentionally the reverse of the wrapped number's. A shallower node
/// (e.g. a whole function) compares as *greater* than a deeper one (e.g. a statement inside
/// that function), so nodes of lower depth are treated as the higher-priority level.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct Depth(usize);

impl Ord for Depth {
    fn cmp(&self, other: &Self) -> Ordering {
        // Flip the natural ordering: the smaller the depth, the higher the level.
        self.0.cmp(&other.0).reverse()
    }
}

impl PartialOrd for Depth {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Depth is totally ordered, so defer to `Ord`.
        Some(self.cmp(other))
    }
}
181
182/// New type around a tree-sitter cursor to allow for implementing an iterator.
183/// Each call to `next()` will return the next node in the tree in a depth-first
184/// order.
185struct CursorOffsets<'cursor> {
186    cursor: TreeCursor<'cursor>,
187}
188
189impl<'cursor> CursorOffsets<'cursor> {
190    fn new(cursor: TreeCursor<'cursor>) -> Self {
191        Self { cursor }
192    }
193}
194
195impl Iterator for CursorOffsets<'_> {
196    type Item = (Depth, Range<usize>);
197
198    fn next(&mut self) -> Option<Self::Item> {
199        // There are children (can call this initially because we don't want the root node)
200        if self.cursor.goto_first_child() {
201            return Some((
202                Depth(self.cursor.depth() as usize),
203                self.cursor.node().byte_range(),
204            ));
205        }
206
207        loop {
208            // There are sibling elements to grab
209            if self.cursor.goto_next_sibling() {
210                return Some((
211                    Depth(self.cursor.depth() as usize),
212                    self.cursor.node().byte_range(),
213                ));
214            // Start going back up the tree and check for next sibling on next iteration.
215            } else if self.cursor.goto_parent() {
216                continue;
217            }
218
219            // We have no more siblings or parents, so we're done.
220            return None;
221        }
222    }
223}
224
// Lets `Depth` serve as `CodeSplitter`'s semantic level (its `Splitter::Level`). The impl
// body is empty. NOTE(review): presumably `SemanticLevel` relies on `Depth`'s reversed
// `Ord` (shallower node = higher level); confirm against the trait definition.
impl SemanticLevel for Depth {}
226
#[cfg(test)]
mod tests {
    use tree_sitter::{Node, Tree};

    use super::*;

    /// Parses `source_code` with the Rust grammar, panicking if setup or parsing fails.
    fn parse_rust(source_code: &str) -> Tree {
        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::LANGUAGE.into())
            .expect("Error loading Rust grammar");
        parser
            .parse(source_code, None)
            .expect("Error parsing source code")
    }

    /// Asserts that `CursorOffsets` visits exactly the same nodes as the naive recursive walk.
    fn assert_matches_naive(source_code: &str) {
        let tree = parse_rust(source_code);
        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();
        assert_eq!(offsets, naive_offsets(&tree));
    }

    #[test]
    fn rust_splitter() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunks(text).collect::<Vec<_>>();

        assert_eq!(chunks, vec!["fn main()", "{\n    let x = 5;", "}"]);
    }

    #[test]
    fn rust_splitter_indices() {
        let splitter = CodeSplitter::new(tree_sitter_rust::LANGUAGE, 16).unwrap();
        let text = "fn main() {\n    let x = 5;\n}";
        let chunks = splitter.chunk_indices(text).collect::<Vec<_>>();

        assert_eq!(
            chunks,
            vec![(0, "fn main()"), (10, "{\n    let x = 5;"), (27, "}")]
        );
    }

    #[test]
    fn depth_partialord() {
        assert_eq!(Depth(0).partial_cmp(&Depth(1)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(2)), Some(Ordering::Greater));
        assert_eq!(Depth(1).partial_cmp(&Depth(1)), Some(Ordering::Equal));
        assert_eq!(Depth(2).partial_cmp(&Depth(1)), Some(Ordering::Less));
    }

    #[test]
    fn depth_ord() {
        assert_eq!(Depth(0).cmp(&Depth(1)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(2)), Ordering::Greater);
        assert_eq!(Depth(1).cmp(&Depth(1)), Ordering::Equal);
        assert_eq!(Depth(2).cmp(&Depth(1)), Ordering::Less);
    }

    #[test]
    fn depth_sorting() {
        let mut depths = vec![Depth(0), Depth(1), Depth(2)];
        depths.sort();
        assert_eq!(depths, [Depth(2), Depth(1), Depth(0)]);
    }

    /// Checks that the optimized version of the code produces the same results as the naive version.
    #[test]
    fn optimized_code_offsets() {
        assert_matches_naive("fn test() {\n    let x = 1;\n}");
    }

    #[test]
    fn multiple_top_siblings() {
        assert_matches_naive("\nfn fn1() {}\nfn fn2() {}\nfn fn3() {}\nfn fn4() {}");
    }

    /// Several levels of nesting exercise climbing back up more than one parent at a time.
    #[test]
    fn deeply_nested_offsets() {
        assert_matches_naive(
            "mod a {\n    fn b() {\n        if true {\n            let c = [1, 2];\n        }\n    }\n}",
        );
    }

    /// An empty source has a root with no children, so nothing should be yielded.
    #[test]
    fn empty_source_has_no_offsets() {
        let tree = parse_rust("");
        let offsets = CursorOffsets::new(tree.walk()).collect::<Vec<_>>();

        assert!(offsets.is_empty());
        assert_eq!(offsets, naive_offsets(&tree));
    }

    /// The root node is never yielded; iteration starts at its first child.
    #[test]
    fn root_node_is_skipped() {
        let tree = parse_rust("fn main() {}");
        let first = CursorOffsets::new(tree.walk()).next();

        assert_eq!(first.map(|(depth, _)| depth), Some(Depth(1)));
    }

    fn naive_offsets(tree: &Tree) -> Vec<(Depth, Range<usize>)> {
        let root_node = tree.root_node();
        let mut offsets = vec![];
        recursive_naive_offsets(&mut offsets, root_node, 0);
        offsets
    }

    // Basic version to compare an optimized version against. According to the tree-sitter
    // documentation, this is not efficient for large trees. But because it is the easiest
    // to reason about, it is a good check for correctness.
    fn recursive_naive_offsets(
        collection: &mut Vec<(Depth, Range<usize>)>,
        node: Node<'_>,
        depth: usize,
    ) {
        // We can skip the root node
        if depth > 0 {
            collection.push((Depth(depth), node.byte_range()));
        }

        for child in node.children(&mut node.walk()) {
            recursive_naive_offsets(collection, child, depth + 1);
        }
    }
}