code_splitter/splitter.rs
1use crate::chunk::Chunk;
2use crate::error::Result;
3use crate::sizer::Sizer;
4
5use std::str;
6use tree_sitter::{Language, Node, Parser};
7
/// Default maximum size of a chunk, measured in whatever unit the splitter's `Sizer` counts.
///
/// `Splitter::new` starts with this value; `Splitter::with_max_size` overrides it.
const DEFAULT_MAX_SIZE: usize = 512;
10
11/// A struct for splitting code into chunks.
12pub struct Splitter<T: Sizer> {
13 /// Language of the code.
14 language: Language,
15 /// Sizer for counting the size of code chunks.
16 sizer: T,
17 /// Maximum size of a code chunk.
18 max_size: usize,
19}
20
21impl<T> Splitter<T>
22where
23 T: Sizer,
24{
25 /// Create a new `Splitter` that counts the size of code chunks with the given sizer.
26 ///
27 /// # Example: split by characters
28 /// ```
29 /// use code_splitter::{CharCounter, Splitter};
30 ///
31 /// let lang = tree_sitter_md::language();
32 /// let splitter = Splitter::new(lang, CharCounter).unwrap();
33 /// let chunks = splitter.split(b"hello, world!").unwrap();
34 /// ```
35 ///
36 /// # Example: split by words
37 /// ```
38 /// use code_splitter::{Splitter, WordCounter};
39 ///
40 /// let lang = tree_sitter_md::language();
41 /// let splitter = Splitter::new(lang, WordCounter).unwrap();
42 /// let chunks = splitter.split(b"hello, world!").unwrap();
43 /// ```
44 ///
45 /// # Example: split by tokens with huggingface tokenizer
46 /// ```
47 /// # #[cfg(feature = "tokenizers")]
48 /// # {
49 /// use code_splitter::Splitter;
50 /// use tokenizers::Tokenizer;
51 ///
52 /// let lang = tree_sitter_md::language();
53 /// let tokenizer = Tokenizer::from_pretrained("bert-base-cased", None).unwrap();
54 /// let splitter = Splitter::new(lang, tokenizer).unwrap();
55 /// let chunks = splitter.split(b"hello, world!").unwrap();
56 /// # }
57 /// ```
58 ///
59 /// # Example: split by tokens with tiktoken core BPE
60 /// ```
61 /// # #[cfg(feature = "tiktoken-rs")]
62 /// # {
63 /// use code_splitter::Splitter;
64 /// use tiktoken_rs::cl100k_base;
65 ///
66 /// let lang = tree_sitter_md::language();
67 /// let bpe = cl100k_base().unwrap();
68 /// let splitter = Splitter::new(lang, bpe).unwrap();
69 /// let chunks = splitter.split(b"hello, world!").unwrap();
70 /// # }
71 /// ```
72 pub fn new(language: Language, sizer: T) -> Result<Self> {
73 // Ensure tree-sitter-<language> crate can be loaded
74 Parser::new().set_language(&language)?;
75
76 Ok(Self {
77 language,
78 sizer,
79 max_size: DEFAULT_MAX_SIZE,
80 })
81 }
82
83 /// Set the maximum size of a chunk. The default is 512.
84 ///
85 /// # Example: set the maximum size to 256
86 /// ```
87 /// use code_splitter::{CharCounter, Splitter};
88 ///
89 /// let lang = tree_sitter_md::language();
90 /// let splitter = Splitter::new(lang, CharCounter)
91 /// .unwrap()
92 /// .with_max_size(256);
93 /// let chunks = splitter.split(b"hello, world!").unwrap();
94 /// ```
95 pub fn with_max_size(mut self, max_size: usize) -> Self {
96 self.max_size = max_size;
97 self
98 }
99
100 /// Split the code into chunks with no larger than `max_size`.
101 pub fn split(&self, code: &[u8]) -> Result<Vec<Chunk>> {
102 if code.is_empty() {
103 return Ok(vec![]);
104 }
105
106 let mut parser = Parser::new();
107 parser
108 .set_language(&self.language)
109 .expect("Error loading tree-sitter language");
110 let tree = parser.parse(code, None).ok_or("Error parsing code")?;
111 let root_node = tree.root_node();
112
113 let chunks = self.split_node(&root_node, 0, code)?;
114
115 Ok(chunks)
116 }
117
118 fn split_node(&self, node: &Node, depth: usize, code: &[u8]) -> Result<Vec<Chunk>> {
119 let text = node.utf8_text(code)?;
120 let chunk_size = self.sizer.size(text)?;
121
122 if chunk_size == 0 {
123 return Ok(vec![]);
124 }
125
126 if chunk_size <= self.max_size {
127 return Ok(vec![Chunk {
128 subtree: format!("{}: {}", format_node(node, depth), chunk_size),
129 range: node.range(),
130 size: chunk_size,
131 }]);
132 }
133
134 let chunks = node
135 // Traverse the children in depth-first order
136 .children(&mut node.walk())
137 .map(|child| self.split_node(&child, depth + 1, code))
138 .collect::<Result<Vec<_>>>()?
139 .into_iter()
140 // Join the tail and head of neighboring chunks if possible
141 .try_fold(Vec::new(), |mut acc, mut next| -> Result<Vec<Chunk>> {
142 if let Some(tail) = acc.pop() {
143 if let Some(head) = next.first_mut() {
144 let joined_size = self.joined_size(&tail, head, code)?;
145 if joined_size <= self.max_size {
146 // Concatenate the tail and head names
147 head.subtree = format!("{}\n{}", tail.subtree, head.subtree);
148 head.range.start_byte = tail.range.start_byte;
149 head.range.start_point = tail.range.start_point;
150 head.size = joined_size;
151 } else {
152 acc.push(tail);
153 }
154 } else {
155 // Push the tail back if next is empty
156 acc.push(tail);
157 }
158 }
159 acc.append(&mut next);
160 Ok(acc)
161 })?;
162
163 Ok(chunks)
164 }
165
166 fn joined_size(&self, chunk: &Chunk, next: &Chunk, code: &[u8]) -> Result<usize> {
167 let joined_bytes = &code[chunk.range.start_byte..next.range.end_byte];
168 let joined_text = str::from_utf8(joined_bytes)?;
169 self.sizer.size(joined_text)
170 }
171}
172
173fn format_node(node: &Node, depth: usize) -> String {
174 format!(
175 "{indent}{branch} {kind:<32} [{start}..{end}]",
176 indent = "│ ".repeat(depth.saturating_sub(1)),
177 branch = if depth > 0 { "├─" } else { "" },
178 kind = node.kind(),
179 start = node.start_position().row,
180 end = node.end_position().row
181 )
182}