// swiftide_integrations/treesitter/splitter.rs
use std::ops::Range;

use anyhow::{Context as _, Result};
use derive_builder::Builder;
use tree_sitter::{Node, Parser};

use super::supported_languages::SupportedLanguages;

/// Maximum chunk size, in bytes, used by `ChunkSize::default()`.
const DEFAULT_MAX_BYTES: usize = 1_500;

12#[derive(Debug, Builder, Clone)]
13#[builder(setter(into), build_fn(error = "anyhow::Error"))]
17pub struct CodeSplitter {
18 #[builder(default, setter(into))]
20 chunk_size: ChunkSize,
21 #[builder(setter(custom))]
22 language: SupportedLanguages,
23}
24
25impl CodeSplitterBuilder {
26 pub fn try_language(mut self, language: impl TryInto<SupportedLanguages>) -> Result<Self> {
41 self.language = Some(
42 language
43 .try_into()
44 .ok()
45 .context("Treesitter language not supported")?,
46 );
47 Ok(self)
48 }
49}
50
/// Size constraint for the chunks produced by the code splitter.
///
/// `Bytes` only bounds the maximum size; `Range` bounds both the minimum (`start`) and the
/// maximum (`end`). Equality is derived so configurations can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkSize {
    /// Maximum chunk size in bytes, with no minimum.
    Bytes(usize),
    /// `start` is the minimum and `end` the maximum chunk size, both in bytes.
    Range(Range<usize>),
}

58impl From<usize> for ChunkSize {
59 fn from(size: usize) -> Self {
61 ChunkSize::Bytes(size)
62 }
63}
64
65impl From<Range<usize>> for ChunkSize {
66 fn from(range: Range<usize>) -> Self {
68 ChunkSize::Range(range)
69 }
70}
71
72impl Default for ChunkSize {
73 fn default() -> Self {
75 ChunkSize::Bytes(DEFAULT_MAX_BYTES)
76 }
77}
78
79impl CodeSplitter {
80 pub fn new(language: SupportedLanguages) -> Self {
90 Self {
91 chunk_size: ChunkSize::default(),
92 language,
93 }
94 }
95
96 pub fn builder() -> CodeSplitterBuilder {
102 CodeSplitterBuilder::default()
103 }
104
105 fn chunk_node(
117 &self,
118 node: Node,
119 source: &str,
120 mut last_end: usize,
121 current_chunk: Option<String>,
122 ) -> Vec<String> {
123 let mut new_chunks: Vec<String> = Vec::new();
124 let mut current_chunk = current_chunk.unwrap_or_default();
125
126 for child in node.children(&mut node.walk()) {
127 debug_assert!(
128 current_chunk.len() <= self.max_bytes(),
129 "Chunk too big: {} > {}",
130 current_chunk.len(),
131 self.max_bytes()
132 );
133
134 let next_child_size = child.end_byte() - last_end;
141 if current_chunk.len() + next_child_size >= self.max_bytes() {
142 if next_child_size > self.max_bytes() {
143 let mut sub_chunks =
144 self.chunk_node(child, source, last_end, Some(current_chunk));
145 current_chunk = sub_chunks.pop().unwrap_or_default();
146 new_chunks.extend(sub_chunks);
147 } else {
148 if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
151 new_chunks.push(current_chunk);
152 }
153 current_chunk = source[last_end..child.end_byte()].to_string();
154 }
155 } else {
156 current_chunk += &source[last_end..child.end_byte()];
157 }
158
159 last_end = child.end_byte();
160 }
161
162 if !current_chunk.is_empty() && current_chunk.len() > self.min_bytes() {
163 new_chunks.push(current_chunk);
164 }
165
166 new_chunks
167 }
168
169 pub fn split(&self, code: &str) -> Result<Vec<String>> {
184 let mut parser = Parser::new();
185 parser.set_language(&self.language.into())?;
186 let tree = parser.parse(code, None).context("No nodes found")?;
187 let root_node = tree.root_node();
188
189 if root_node.has_error() {
190 tracing::error!("Syntax error parsing code: {:?}", code);
191 return Ok(vec![code.to_string()]);
192 }
193
194 Ok(self.chunk_node(root_node, code, 0, None))
195 }
196
197 fn max_bytes(&self) -> usize {
203 match &self.chunk_size {
204 ChunkSize::Bytes(size) => *size,
205 ChunkSize::Range(range) => range.end,
206 }
207 }
208
209 fn min_bytes(&self) -> usize {
215 if let ChunkSize::Range(range) = &self.chunk_size {
216 range.start
217 } else {
218 0
219 }
220 }
221}
222
#[cfg(test)]
mod test {
    use super::*;
    use indoc::indoc;

    #[test]
    fn test_split_single_chunk() {
        let code = "fn hello_world() {}";

        let splitter = CodeSplitter::new(SupportedLanguages::Rust);

        let chunks = splitter.split(code);

        assert_eq!(chunks.unwrap(), vec!["fn hello_world() {}"]);
    }

    #[test]
    fn test_chunk_lines() {
        let splitter = CodeSplitter::new(SupportedLanguages::Rust);

        let text = indoc! {r#"
            fn main() {
                println!("Hello");
                println!("World");
                println!("!");
            }
        "#};

        let chunks = splitter.split(text).unwrap();

        assert_eq!(chunks.len(), 1, "{:?}", chunks);
        assert_eq!(
            chunks[0],
            "fn main() {\n    println!(\"Hello\");\n    println!(\"World\");\n    println!(\"!\");\n}"
        );
    }

    #[test]
    fn test_max_bytes_limit() {
        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(50)
            .build()
            .unwrap();

        let text = indoc! {r#"
            fn main() {
                println!("Hello, World!");
                println!("Goodbye, World!");
            }
        "#};
        let chunks = splitter.split(text).unwrap();

        assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
        assert!(chunks
            .windows(2)
            .all(|pair| pair.iter().map(String::len).sum::<usize>() >= 50));

        assert_eq!(
            chunks,
            vec![
                "fn main() {\n    println!(\"Hello, World!\");",
                "\n    println!(\"Goodbye, World!\");\n}",
            ]
        );
    }

    #[test]
    fn test_empty_text() {
        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(50)
            .build()
            .unwrap();

        let chunks = splitter.split("").unwrap();

        assert!(chunks.is_empty(), "{:?}", chunks);
    }

    #[test]
    fn test_range_max() {
        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(0..50)
            .build()
            .unwrap();

        let text = indoc! {r#"
            fn main() {
                println!("Hello, World!");
                println!("Goodbye, World!");
            }
        "#};
        let chunks = splitter.split(text).unwrap();
        assert_eq!(
            chunks,
            vec![
                "fn main() {\n    println!(\"Hello, World!\");",
                "\n    println!(\"Goodbye, World!\");\n}",
            ]
        );
    }

    #[test]
    fn test_range_min_and_max() {
        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(20..50)
            .build()
            .unwrap();
        let text = indoc! {r#"
            fn main() {
                println!("Hello, World!");
                println!("Goodbye, World!");
            }
        "#};
        let chunks = splitter.split(text).unwrap();

        assert!(chunks.iter().all(|chunk| chunk.len() <= 50));
        assert!(chunks
            .windows(2)
            .all(|pair| pair.iter().map(String::len).sum::<usize>() > 50));
        assert!(chunks.iter().all(|chunk| chunk.len() >= 20));

        assert_eq!(
            chunks,
            vec![
                "fn main() {\n    println!(\"Hello, World!\");",
                "\n    println!(\"Goodbye, World!\");\n}"
            ]
        );
    }

    #[test]
    fn test_chunk_of_exactly_min_bytes_is_kept() {
        let text = indoc! {r#"
            fn main() {
                println!("Hello, World!");
                println!("Goodbye, World!");
            }
        "#};
        let expected = vec![
            "fn main() {\n    println!(\"Hello, World!\");",
            "\n    println!(\"Goodbye, World!\");\n}",
        ];
        // The minimum is inclusive: a chunk of exactly `min` bytes must survive.
        let min = expected[1].len();

        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(min..50)
            .build()
            .unwrap();

        assert_eq!(splitter.split(text).unwrap(), expected);
    }

    #[test]
    fn test_oversized_leaf_is_split_not_dropped() {
        let splitter = CodeSplitter::builder()
            .try_language(SupportedLanguages::Rust)
            .unwrap()
            .chunk_size(20)
            .build()
            .unwrap();

        // A comment cannot be split along syntax boundaries and exceeds the limit on its own.
        let text = "// this single comment is much longer than twenty bytes\nfn a() {}";
        let chunks = splitter.split(text).unwrap();

        assert!(chunks.iter().all(|chunk| chunk.len() <= 20), "{:?}", chunks);
        assert_eq!(chunks.concat(), text);
    }

    #[test]
    fn test_on_self() {
        let code = include_str!("splitter.rs");
        let ranges = [
            10..200,
            50..100,
            100..150,
            150..200,
            200..250,
            250..300,
            300..350,
            350..400,
            400..450,
            450..500,
        ];

        for range in ranges {
            let (min, max) = (range.start, range.end);
            let splitter = CodeSplitter::builder()
                .try_language("rust")
                .unwrap()
                .chunk_size(range)
                .build()
                .unwrap();

            assert_eq!(splitter.min_bytes(), min);
            assert_eq!(splitter.max_bytes(), max);

            let chunks = splitter.split(code).unwrap();

            let too_small: Vec<_> = chunks.iter().filter(|chunk| chunk.len() < min).collect();
            assert!(too_small.is_empty(), "min = {}, chunks = {:?}", min, too_small);

            let too_big: Vec<_> = chunks.iter().filter(|chunk| chunk.len() > max).collect();
            assert!(too_big.is_empty(), "max = {}, chunks = {:?}", max, too_big);

            // Neighbouring chunks that would fit together should have been merged.
            if let Some(pair) = chunks
                .windows(2)
                .find(|pair| pair[0].len() + pair[1].len() < max)
            {
                panic!(
                    "max: {}, {} + {}, {:?}",
                    max,
                    pair[0].len(),
                    pair[1].len(),
                    pair
                );
            }
        }
    }
}