code_splitter/
splitter.rs

1use crate::chunk::Chunk;
2use crate::error::Result;
3use crate::sizer::Sizer;
4
5use std::str;
6use tree_sitter::{Language, Node, Parser};
7
/// Upper bound on a chunk's size used until `with_max_size` overrides it, measured in
/// whatever unit the configured sizer reports.
const DEFAULT_MAX_SIZE: usize = 512;
10
11/// A struct for splitting code into chunks.
12pub struct Splitter<T: Sizer> {
13    /// Language of the code.
14    language: Language,
15    /// Sizer for counting the size of code chunks.
16    sizer: T,
17    /// Maximum size of a code chunk.
18    max_size: usize,
19}
20
21impl<T> Splitter<T>
22where
23    T: Sizer,
24{
25    /// Create a new `Splitter` that counts the size of code chunks with the given sizer.
26    ///
27    /// # Example: split by characters
28    /// ```
29    /// use code_splitter::{CharCounter, Splitter};
30    ///
31    /// let lang = tree_sitter_md::language();
32    /// let splitter = Splitter::new(lang, CharCounter).unwrap();
33    /// let chunks = splitter.split(b"hello, world!").unwrap();
34    /// ```
35    ///
36    /// # Example: split by words
37    /// ```
38    /// use code_splitter::{Splitter, WordCounter};
39    ///
40    /// let lang = tree_sitter_md::language();
41    /// let splitter = Splitter::new(lang, WordCounter).unwrap();
42    /// let chunks = splitter.split(b"hello, world!").unwrap();
43    /// ```
44    ///
45    /// # Example: split by tokens with huggingface tokenizer
46    /// ```
47    /// # #[cfg(feature = "tokenizers")]
48    /// # {
49    /// use code_splitter::Splitter;
50    /// use tokenizers::Tokenizer;
51    ///
52    /// let lang = tree_sitter_md::language();
53    /// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
54    /// let splitter = Splitter::new(lang, tokenizer).unwrap();
55    /// let chunks = splitter.split(b"hello, world!").unwrap();
56    /// # }
57    /// ```
58    ///
59    /// # Example: split by tokens with tiktoken core BPE
60    /// ```
61    /// # #[cfg(feature = "tiktoken-rs")]
62    /// # {
63    /// use code_splitter::Splitter;
64    /// use tiktoken_rs::cl100k_base;
65    ///
66    /// let lang = tree_sitter_md::language();
67    /// let bpe = cl100k_base().unwrap();
68    /// let splitter = Splitter::new(lang, bpe).unwrap();
69    /// let chunks = splitter.split(b"hello, world!").unwrap();
70    /// # }
71    /// ```
72    pub fn new(language: Language, sizer: T) -> Result<Self> {
73        // Ensure tree-sitter-<language> crate can be loaded
74        Parser::new().set_language(&language)?;
75
76        Ok(Self {
77            language,
78            sizer,
79            max_size: DEFAULT_MAX_SIZE,
80        })
81    }
82
83    /// Set the maximum size of a chunk. The default is 512.
84    ///
85    /// # Example: set the maximum size to 256
86    /// ```
87    /// use code_splitter::{CharCounter, Splitter};
88    ///
89    /// let lang = tree_sitter_md::language();
90    /// let splitter = Splitter::new(lang, CharCounter)
91    ///   .unwrap()
92    ///   .with_max_size(256);
93    /// let chunks = splitter.split(b"hello, world!").unwrap();
94    /// ```
95    pub fn with_max_size(mut self, max_size: usize) -> Self {
96        self.max_size = max_size;
97        self
98    }
99
100    /// Split the code into chunks with no larger than `max_size`.
101    pub fn split(&self, code: &[u8]) -> Result<Vec<Chunk>> {
102        if code.is_empty() {
103            return Ok(vec![]);
104        }
105
106        let mut parser = Parser::new();
107        parser
108            .set_language(&self.language)
109            .expect("Error loading tree-sitter language");
110        let tree = parser.parse(code, None).ok_or("Error parsing code")?;
111        let root_node = tree.root_node();
112
113        let chunks = self.split_node(&root_node, 0, code)?;
114
115        Ok(chunks)
116    }
117
118    fn split_node(&self, node: &Node, depth: usize, code: &[u8]) -> Result<Vec<Chunk>> {
119        let text = node.utf8_text(code)?;
120        let chunk_size = self.sizer.size(text)?;
121
122        if chunk_size == 0 {
123            return Ok(vec![]);
124        }
125
126        if chunk_size <= self.max_size {
127            return Ok(vec![Chunk {
128                subtree: format!("{}: {}", format_node(node, depth), chunk_size),
129                range: node.range(),
130                size: chunk_size,
131            }]);
132        }
133
134        let chunks = node
135            // Traverse the children in depth-first order
136            .children(&mut node.walk())
137            .map(|child| self.split_node(&child, depth + 1, code))
138            .collect::<Result<Vec<_>>>()?
139            .into_iter()
140            // Join the tail and head of neighboring chunks if possible
141            .try_fold(Vec::new(), |mut acc, mut next| -> Result<Vec<Chunk>> {
142                if let Some(tail) = acc.pop() {
143                    if let Some(head) = next.first_mut() {
144                        let joined_size = self.joined_size(&tail, head, code)?;
145                        if joined_size <= self.max_size {
146                            // Concatenate the tail and head names
147                            head.subtree = format!("{}\n{}", tail.subtree, head.subtree);
148                            head.range.start_byte = tail.range.start_byte;
149                            head.range.start_point = tail.range.start_point;
150                            head.size = joined_size;
151                        } else {
152                            acc.push(tail);
153                        }
154                    } else {
155                        // Push the tail back if next is empty
156                        acc.push(tail);
157                    }
158                }
159                acc.append(&mut next);
160                Ok(acc)
161            })?;
162
163        Ok(chunks)
164    }
165
166    fn joined_size(&self, chunk: &Chunk, next: &Chunk, code: &[u8]) -> Result<usize> {
167        let joined_bytes = &code[chunk.range.start_byte..next.range.end_byte];
168        let joined_text = str::from_utf8(joined_bytes)?;
169        self.sizer.size(joined_text)
170    }
171}
172
173fn format_node(node: &Node, depth: usize) -> String {
174    format!(
175        "{indent}{branch} {kind:<32} [{start}..{end}]",
176        indent = "│  ".repeat(depth.saturating_sub(1)),
177        branch = if depth > 0 { "├─" } else { "" },
178        kind = node.kind(),
179        start = node.start_position().row,
180        end = node.end_position().row
181    )
182}