swiftide_integrations/treesitter/
splitter.rs

1use anyhow::{Context as _, Result};
2use std::ops::Range;
3use tree_sitter::{Node, Parser};
4
5use derive_builder::Builder;
6
7use super::supported_languages::SupportedLanguages;
8
/// Default maximum chunk size in bytes, used by `ChunkSize::default()`.
// TODO: Instead of counting bytes, count tokens with tiktoken
const DEFAULT_MAX_BYTES: usize = 1500;
11
#[derive(Debug, Builder, Clone)]
/// Splits code files into meaningful chunks
///
/// The code is parsed with tree-sitter for the configured language and the text of the syntax
/// nodes is grouped into chunks. The chunk size is limited either by a maximum number of bytes
/// or by a range of bytes (minimum and maximum).
///
/// Construct one with `CodeSplitter::new` or through `CodeSplitter::builder`.
// `setter(into)` lets the generated setters accept any value convertible into the field type;
// `build_fn(error = ...)` makes the generated `build()` return `anyhow::Error` on failure.
#[builder(setter(into), build_fn(error = "anyhow::Error"))]
pub struct CodeSplitter {
    /// Maximum size of a chunk in bytes or a range of bytes
    ///
    /// Falls back to `ChunkSize::default()` when it is not set on the builder.
    #[builder(default, setter(into))]
    chunk_size: ChunkSize,
    /// Language whose tree-sitter grammar is used to parse the code
    // No setter is generated for this field; it is set with
    // `CodeSplitterBuilder::try_language` instead.
    #[builder(setter(custom))]
    language: SupportedLanguages,
}
24
25impl CodeSplitterBuilder {
26    /// Attempts to set the language for the `CodeSplitter`.
27    ///
28    /// # Arguments
29    ///
30    /// * `language` - A value that can be converted into `SupportedLanguages`.
31    ///
32    /// # Returns
33    ///
34    /// * `Result<Self>` - The builder instance with the language set, or an error if the language
35    ///   is not supported.
36    ///
37    /// # Errors
38    ///
39    /// Errors if language is not supported
40    pub fn try_language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> {
41        self.language = Some(
42            language
43                .try_into()
44                .ok()
45                .context("Treesitter language not supported")?,
46        );
47        Ok(self)
48    }
49}
50
/// Represents the size of a chunk, either as a fixed number of bytes or a range of bytes.
///
/// Values can be compared with `==`, which makes them usable in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkSize {
    /// Upper limit on the chunk size in bytes; no lower limit applies.
    Bytes(usize),
    /// Lower (`start`) and upper (`end`) limit on the chunk size in bytes.
    Range(Range<usize>),
}
57
58impl From<usize> for ChunkSize {
59    /// Converts a `usize` into a `ChunkSize::Bytes` variant.
60    fn from(size: usize) -> Self {
61        ChunkSize::Bytes(size)
62    }
63}
64
65impl From<Range<usize>> for ChunkSize {
66    /// Converts a `Range<usize>` into a `ChunkSize::Range` variant.
67    fn from(range: Range<usize>) -> Self {
68        ChunkSize::Range(range)
69    }
70}
71
72impl Default for ChunkSize {
73    /// Provides a default value for `ChunkSize`, which is `ChunkSize::Bytes(DEFAULT_MAX_BYTES)`.
74    fn default() -> Self {
75        ChunkSize::Bytes(DEFAULT_MAX_BYTES)
76    }
77}
78
79impl CodeSplitter {
80    /// Creates a new `CodeSplitter` with the specified language and default chunk size.
81    ///
82    /// # Arguments
83    ///
84    /// * `language` - The programming language for which the code will be split.
85    ///
86    /// # Returns
87    ///
88    /// * `Self` - A new instance of `CodeSplitter`.
89    pub fn new(language: SupportedLanguages) -> Self {
90        Self {
91            chunk_size: ChunkSize::default(),
92            language,
93        }
94    }
95
96    /// Creates a new builder for `CodeSplitter`.
97    ///
98    /// # Returns
99    ///
100    /// * `CodeSplitterBuilder` - A new builder instance for `CodeSplitter`.
101    pub fn builder() -> CodeSplitterBuilder {
102        CodeSplitterBuilder::default()
103    }
104
105    /// Recursively chunks a syntax node into smaller pieces based on the chunk size.
106    ///
107    /// # Arguments
108    ///
109    /// * `node` - The syntax node to be chunked.
110    /// * `source` - The source code as a string.
111    /// * `last_end` - The end byte of the last chunk.
112    ///
113    /// # Returns
114    ///
115    /// * `Vec<String>` - A vector of code chunks as strings.
116    fn chunk_node(
117        &self,
118        node: Node,
119        source: &str,
120        mut last_end: usize,
121        current_chunk: Option<String>,
122    ) -> Vec<String> {
123        let mut new_chunks: Vec<String> = Vec::new();
124        let mut current_chunk = current_chunk.unwrap_or_default();
125
126        for child in node.children(&mut node.walk()) {
127            debug_assert!(
128                current_chunk.len() <= self.max_bytes(),
129                "Chunk too big: {} > {}",
130                current_chunk.len(),
131                self.max_bytes()
132            );
133
134            // if the next child will make the chunk too big then there are two options:
135            // 1. if the next child is too big to fit in a whole chunk, then recursively chunk it
136            //    one level down
137            // 2. if the next child is small enough to fit in a chunk, then add the current chunk to
138            //    the list and start a new chunk
139
140            let next_child_size = child.end_byte() - last_end;
141            if current_chunk.len() + next_child_size >= self.max_bytes() {
142                if next_child_size > self.max_bytes() {
143                    let mut sub_chunks =
144                        self.chunk_node(child, source, last_end, Some(current_chunk));
145                    current_chunk = sub_chunks.pop().unwrap_or_default();
146                    new_chunks.extend(sub_chunks);
147                } else {
148                    // NOTE: if the current chunk was smaller than then the min_bytes, then it is
149                    // discarded here
150                    if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
151                        new_chunks.push(current_chunk);
152                    }
153                    current_chunk = source[last_end..child.end_byte()].to_string();
154                }
155            } else {
156                current_chunk += &source[last_end..child.end_byte()];
157            }
158
159            last_end = child.end_byte();
160        }
161
162        if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
163            new_chunks.push(current_chunk);
164        }
165
166        new_chunks
167    }
168
169    /// Splits the given code into chunks based on the chunk size.
170    ///
171    /// # Arguments
172    ///
173    /// * `code` - The source code to be split.
174    ///
175    /// # Returns
176    ///
177    /// * `Result<Vec<String>>` - A result containing a vector of code chunks as strings, or an
178    ///   error if the code could not be parsed.
179    ///
180    /// # Errors
181    ///
182    /// Returns an error if the node cannot be found or fails to parse
183    pub fn split(&self, code: &str) -> Result<Vec<String>> {
184        let mut parser = Parser::new();
185        parser.set_language(&self.language.into())?;
186        let tree = parser.parse(code, None).context("No nodes found")?;
187        let root_node = tree.root_node();
188
189        if root_node.has_error() {
190            tracing::error!("Syntax error parsing code: {:?}", code);
191            return Ok(vec![code.to_string()]);
192        }
193
194        Ok(self.chunk_node(root_node, code, 0, None))
195    }
196
197    /// Returns the maximum number of bytes allowed in a chunk.
198    ///
199    /// # Returns
200    ///
201    /// * `usize` - The maximum number of bytes in a chunk.
202    fn max_bytes(&self) -> usize {
203        match &self.chunk_size {
204            ChunkSize::Bytes(size) => *size,
205            ChunkSize::Range(range) => range.end,
206        }
207    }
208
209    /// Returns the minimum number of bytes allowed in a chunk.
210    ///
211    /// # Returns
212    ///
213    /// * `usize` - The minimum number of bytes in a chunk.
214    fn min_bytes(&self) -> usize {
215        if let ChunkSize::Range(range) = &self.chunk_size {
216            range.start
217        } else {
218            0
219        }
220    }
221}
222
223#[cfg(test)]
224mod test {
225    use super::*;
226    use indoc::indoc;
227
228    #[test]
229    fn test_split_single_chunk() {
230        let code = "fn hello_world() {}";
231
232        let splitter = CodeSplitter::new(SupportedLanguages::Rust);
233
234        let chunks = splitter.split(code);
235
236        assert_eq!(chunks.unwrap(), vec!["fn hello_world() {}"]);
237    }
238
239    #[test]
240    fn test_chunk_lines() {
241        let splitter = CodeSplitter::new(SupportedLanguages::Rust);
242
243        let text = indoc! {r#"
244            fn main() {
245                println!("Hello");
246                println!("World");
247                println!("!");
248            }
249        "#};
250
251        let chunks = splitter.split(text).unwrap();
252
253        dbg!(&chunks);
254        assert_eq!(chunks.len(), 1);
255        assert_eq!(
256            chunks[0],
257            "fn main() {\n    println!(\"Hello\");\n    println!(\"World\");\n    println!(\"!\");\n}"
258        );
259    }
260
261    #[test]
262    fn test_max_bytes_limit() {
263        let splitter = CodeSplitter::builder()
264            .try_language(SupportedLanguages::Rust)
265            .unwrap()
266            .chunk_size(50)
267            .build()
268            .unwrap();
269
270        let text = indoc! {r#"
271            fn main() {
272                println!("Hello, World!");
273                println!("Goodbye, World!");
274            }
275        "#};
276        let chunks = splitter.split(text).unwrap();
277
278        assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
279        assert!(
280            chunks
281                .windows(2)
282                .all(|pair| pair.iter().map(String::len).sum::<usize>() >= 50)
283        );
284
285        assert_eq!(
286            chunks,
287            vec![
288                "fn main() {\n    println!(\"Hello, World!\");",
289                "\n    println!(\"Goodbye, World!\");\n}",
290            ]
291        );
292    }
293
294    #[test]
295    fn test_empty_text() {
296        let splitter = CodeSplitter::builder()
297            .try_language(SupportedLanguages::Rust)
298            .unwrap()
299            .chunk_size(50)
300            .build()
301            .unwrap();
302
303        let text = "";
304        let chunks = splitter.split(text).unwrap();
305
306        dbg!(&chunks);
307        assert_eq!(chunks.len(), 0);
308    }
309
310    #[test]
311    fn test_range_max() {
312        let splitter = CodeSplitter::builder()
313            .try_language(SupportedLanguages::Rust)
314            .unwrap()
315            .chunk_size(0..50)
316            .build()
317            .unwrap();
318
319        let text = indoc! {r#"
320            fn main() {
321                println!("Hello, World!");
322                println!("Goodbye, World!");
323            }
324        "#};
325        let chunks = splitter.split(text).unwrap();
326        assert_eq!(
327            chunks,
328            vec![
329                "fn main() {\n    println!(\"Hello, World!\");",
330                "\n    println!(\"Goodbye, World!\");\n}",
331            ]
332        );
333    }
334
335    #[test]
336    fn test_range_min_and_max() {
337        let splitter = CodeSplitter::builder()
338            .try_language(SupportedLanguages::Rust)
339            .unwrap()
340            .chunk_size(20..50)
341            .build()
342            .unwrap();
343        let text = indoc! {r#"
344            fn main() {
345                println!("Hello, World!");
346                println!("Goodbye, World!");
347            }
348        "#};
349        let chunks = splitter.split(text).unwrap();
350
351        assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
352        assert!(
353            chunks
354                .windows(2)
355                .all(|pair| pair.iter().map(String::len).sum::<usize>() > 50)
356        );
357        assert!(chunks.iter().all(|chunk| chunk.len() >= 20));
358
359        assert_eq!(
360            chunks,
361            vec![
362                "fn main() {\n    println!(\"Hello, World!\");",
363                "\n    println!(\"Goodbye, World!\");\n}"
364            ]
365        );
366    }
367
368    #[test]
369    fn test_on_self() {
370        // read the current file
371        let code = include_str!("splitter.rs");
372        // try chunking with varying ranges of bytes, give me ten with different min and max
373        let ranges = vec![
374            10..200,
375            50..100,
376            100..150,
377            150..200,
378            200..250,
379            250..300,
380            300..350,
381            350..400,
382            400..450,
383            450..500,
384        ];
385
386        for range in ranges {
387            let min = range.start;
388            let max = range.end;
389            let splitter = CodeSplitter::builder()
390                .try_language("rust")
391                .unwrap()
392                .chunk_size(range)
393                .build()
394                .unwrap();
395
396            assert_eq!(splitter.min_bytes(), min);
397            assert_eq!(splitter.max_bytes(), max);
398
399            let chunks = splitter.split(code).unwrap();
400
401            assert!(chunks.iter().all(|chunk| chunk.len() <= max));
402            let chunk_pairs_that_are_smaller_than_max = chunks
403                .windows(2)
404                .filter(|pair| pair.iter().map(String::len).sum::<usize>() < max);
405            assert!(
406                chunk_pairs_that_are_smaller_than_max.clone().count() == 0,
407                "max: {}, {} + {}, {:?}",
408                max,
409                chunk_pairs_that_are_smaller_than_max
410                    .clone()
411                    .next()
412                    .unwrap()[0]
413                    .len(),
414                chunk_pairs_that_are_smaller_than_max
415                    .clone()
416                    .next()
417                    .unwrap()[1]
418                    .len(),
419                chunk_pairs_that_are_smaller_than_max
420                    .collect::<Vec<_>>()
421                    .first()
422            );
423            assert!(chunks.iter().all(|chunk| chunk.len() >= min));
424
425            assert!(
426                chunks.iter().all(|chunk| chunk.len() >= min),
427                "{:?}",
428                chunks
429                    .iter()
430                    .filter(|chunk| chunk.len() < min)
431                    .collect::<Vec<_>>()
432            );
433            assert!(
434                chunks.iter().all(|chunk| chunk.len() <= max),
435                "max = {}, chunks = {:?}",
436                max,
437                chunks
438                    .iter()
439                    .filter(|chunk| chunk.len() > max)
440                    .collect::<Vec<_>>()
441            );
442        }
443
444        // assert there are no nodes smaller than 10
445    }
446}