1use regex::{Matches, Regex};
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock};
18use unicase::UniCase;
19
20use super::output_positions::{Position, set_output_positions};
21use super::{Chunk, TextRange};
22use crate::prog_langs::{self, TreeSitterLanguageInfo};
23
/// Base unit of the syntax-level penalty in `merge_atom_chunks`: when a chunk
/// boundary sits at a larger syntax level than the smallest level found inside
/// the chunk, the penalty grows by `SYNTAX_LEVEL_GAP_COST / i` for the `i`-th
/// level of difference.
const SYNTAX_LEVEL_GAP_COST: usize = 512;
/// Base unit of the overlap penalty: scales `get_overlap_cost_base` and is the
/// per-chunk charge from which any achieved overlap is credited.
const MISSING_OVERLAP_COST: usize = 512;
/// Penalty per level by which the strongest line break inside a chunk exceeds
/// the line break at one of its boundaries.
const PER_LINE_BREAK_LEVEL_GAP_COST: usize = 64;
/// Penalty added to a candidate chunk that is smaller than `min_chunk_size`.
const TOO_SMALL_CHUNK_COST: usize = 1048576;
28
/// User-supplied definition of a language that is split purely by regular
/// expressions.
#[derive(Debug, Clone)]
pub struct CustomLanguageConfig {
    /// Primary name; `RecursiveChunker::new` registers it case-insensitively.
    pub language_name: String,
    /// Extra names registered for the same configuration (also case-insensitive).
    pub aliases: Vec<String>,
    /// Separator patterns, compiled with `Regex::new`. Pattern `0` splits the
    /// whole text; each resulting piece that is still too large is split by
    /// the next pattern in the list.
    pub separators_regex: Vec<String>,
}
39
/// Construction-time settings for [`RecursiveChunker`].
#[derive(Debug, Clone, Default)]
pub struct RecursiveSplitConfig {
    /// Regex-based languages to register in addition to the built-in ones.
    pub custom_languages: Vec<CustomLanguageConfig>,
}
46
/// Per-call settings for [`RecursiveChunker::split`]. Sizes are compared
/// against byte-offset differences in the input text.
#[derive(Debug, Clone)]
pub struct RecursiveChunkConfig {
    /// Size above which a merged chunk is not extended any further.
    pub chunk_size: usize,
    /// Chunks smaller than this are penalized; defaults to `chunk_size / 2`.
    pub min_chunk_size: Option<usize>,
    /// Desired overlap between consecutive chunks; defaults to `0` and is
    /// capped at the effective `min_chunk_size`.
    pub chunk_overlap: Option<usize>,
    /// Language name, looked up case-insensitively: custom languages first,
    /// then `prog_langs::get_language_info`. When missing or unknown, the
    /// default separators are used.
    pub language: Option<String>,
}
59
/// Compiled form of a regex-only language: the separators are tried in list
/// order, one level of splitting per entry.
struct SimpleLanguageConfig {
    /// Primary language name.
    name: String,
    /// Alternative names that resolve to this configuration.
    aliases: Vec<String>,
    /// Compiled separator patterns, indexed by splitting level.
    separator_regex: Vec<Regex>,
}
65
66static DEFAULT_LANGUAGE_CONFIG: LazyLock<SimpleLanguageConfig> =
67 LazyLock::new(|| SimpleLanguageConfig {
68 name: "_DEFAULT".to_string(),
69 aliases: vec![],
70 separator_regex: [
71 r"\n\n+",
72 r"\n",
73 r"[\.\?!]\s+|。|?|!",
74 r"[;:\-—]\s+|;|:|—+",
75 r",\s+|,",
76 r"\s+",
77 ]
78 .into_iter()
79 .map(|s| Regex::new(s).unwrap())
80 .collect(),
81 });
82
/// How a chunk that is still too large gets broken down further.
enum ChunkKind<'t> {
    /// The chunk is a tree-sitter node; it is split along its child nodes.
    TreeSitterNode {
        /// Language information, including the set of node kinds treated as terminal.
        tree_sitter_info: &'t TreeSitterLanguageInfo,
        /// The syntax node covering the chunk.
        node: tree_sitter::Node<'t>,
    },
    /// The chunk is plain text; it is split by a separator regex.
    RegexpSepChunk {
        /// Language whose separator list is used.
        lang_config: &'t SimpleLanguageConfig,
        /// Index into `lang_config.separator_regex` of the separator to apply next.
        next_regexp_sep_id: usize,
    },
}
93
/// A span of the input text together with the strategy for splitting it.
struct InternalChunk<'t, 's: 't> {
    /// The complete input text; `range` indexes into it.
    full_text: &'s str,
    /// Byte range of this chunk within `full_text`.
    range: TextRange,
    /// How this chunk is split when it is too large.
    kind: ChunkKind<'t>,
}
99
100struct TextChunksIter<'t, 's: 't> {
101 lang_config: &'t SimpleLanguageConfig,
102 full_text: &'s str,
103 range: TextRange,
104 matches_iter: Matches<'t, 's>,
105 regexp_sep_id: usize,
106 next_start_pos: Option<usize>,
107}
108
109impl<'t, 's: 't> TextChunksIter<'t, 's> {
110 fn new(
111 lang_config: &'t SimpleLanguageConfig,
112 full_text: &'s str,
113 range: TextRange,
114 regexp_sep_id: usize,
115 ) -> Self {
116 let std_range = range.start..range.end;
117 Self {
118 lang_config,
119 full_text,
120 range,
121 matches_iter: lang_config.separator_regex[regexp_sep_id]
122 .find_iter(&full_text[std_range.clone()]),
123 regexp_sep_id,
124 next_start_pos: Some(std_range.start),
125 }
126 }
127}
128
129impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> {
130 type Item = InternalChunk<'t, 's>;
131
132 fn next(&mut self) -> Option<Self::Item> {
133 let start_pos = self.next_start_pos?;
134 let end_pos = match self.matches_iter.next() {
135 Some(grp) => {
136 self.next_start_pos = Some(self.range.start + grp.end());
137 self.range.start + grp.start()
138 }
139 None => {
140 self.next_start_pos = None;
141 if start_pos >= self.range.end {
142 return None;
143 }
144 self.range.end
145 }
146 };
147 Some(InternalChunk {
148 full_text: self.full_text,
149 range: TextRange::new(start_pos, end_pos),
150 kind: ChunkKind::RegexpSepChunk {
151 lang_config: self.lang_config,
152 next_regexp_sep_id: self.regexp_sep_id + 1,
153 },
154 })
155 }
156}
157
158struct TreeSitterNodeIter<'t, 's: 't> {
159 lang_config: &'t TreeSitterLanguageInfo,
160 full_text: &'s str,
161 cursor: Option<tree_sitter::TreeCursor<'t>>,
162 next_start_pos: usize,
163 end_pos: usize,
164}
165
166impl<'t, 's: 't> TreeSitterNodeIter<'t, 's> {
167 fn fill_gap(
168 next_start_pos: &mut usize,
169 gap_end_pos: usize,
170 full_text: &'s str,
171 ) -> Option<InternalChunk<'t, 's>> {
172 let start_pos = *next_start_pos;
173 if start_pos < gap_end_pos {
174 *next_start_pos = gap_end_pos;
175 Some(InternalChunk {
176 full_text,
177 range: TextRange::new(start_pos, gap_end_pos),
178 kind: ChunkKind::RegexpSepChunk {
179 lang_config: &DEFAULT_LANGUAGE_CONFIG,
180 next_regexp_sep_id: 0,
181 },
182 })
183 } else {
184 None
185 }
186 }
187}
188
189impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
190 type Item = InternalChunk<'t, 's>;
191
192 fn next(&mut self) -> Option<Self::Item> {
193 let cursor = if let Some(cursor) = &mut self.cursor {
194 cursor
195 } else {
196 return Self::fill_gap(&mut self.next_start_pos, self.end_pos, self.full_text);
197 };
198 let node = cursor.node();
199 if let Some(gap) =
200 Self::fill_gap(&mut self.next_start_pos, node.start_byte(), self.full_text)
201 {
202 return Some(gap);
203 }
204 if !cursor.goto_next_sibling() {
205 self.cursor = None;
206 }
207 self.next_start_pos = node.end_byte();
208 Some(InternalChunk {
209 full_text: self.full_text,
210 range: TextRange::new(node.start_byte(), node.end_byte()),
211 kind: ChunkKind::TreeSitterNode {
212 tree_sitter_info: self.lang_config,
213 node,
214 },
215 })
216 }
217}
218
219enum ChunkIterator<'t, 's: 't> {
220 TreeSitter(TreeSitterNodeIter<'t, 's>),
221 Text(TextChunksIter<'t, 's>),
222 Once(std::iter::Once<InternalChunk<'t, 's>>),
223}
224
225impl<'t, 's: 't> Iterator for ChunkIterator<'t, 's> {
226 type Item = InternalChunk<'t, 's>;
227
228 fn next(&mut self) -> Option<Self::Item> {
229 match self {
230 ChunkIterator::TreeSitter(iter) => iter.next(),
231 ChunkIterator::Text(iter) => iter.next(),
232 ChunkIterator::Once(iter) => iter.next(),
233 }
234 }
235}
236
/// Strength of the line break found in a piece of text, ordered from weakest
/// to strongest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum LineBreakLevel {
    /// No line break.
    Inline,
    /// At least one line break.
    Newline,
    /// A repeated line-break character within one run of line breaks.
    DoubleNewline,
}

impl LineBreakLevel {
    /// Numeric rank of the level: `0`, `1` or `2` in declaration order.
    fn ord(self) -> usize {
        // The variants carry no explicit discriminants, so the cast yields
        // their declaration index.
        self as usize
    }
}
253
254fn line_break_level(c: &str) -> LineBreakLevel {
255 let mut lb_level = LineBreakLevel::Inline;
256 let mut iter = c.chars();
257 while let Some(c) = iter.next() {
258 if c == '\n' || c == '\r' {
259 lb_level = LineBreakLevel::Newline;
260 for c2 in iter.by_ref() {
261 if c2 == '\n' || c2 == '\r' {
262 if c == c2 {
263 return LineBreakLevel::DoubleNewline;
264 }
265 } else {
266 break;
267 }
268 }
269 }
270 }
271 lb_level
272}
273
/// Whitespace characters that do not break a line; `AtomChunksCollector::collect`
/// strips them from the end of the gap in front of an atom.
const INLINE_SPACE_CHARS: [char; 2] = [' ', '\t'];
275
/// Smallest unit of text that `merge_atom_chunks` combines into output chunks.
struct AtomChunk {
    /// Byte range of the atom in the full text.
    range: TextRange,
    /// Smallest iterator-stack depth seen between the previous atom and this one.
    boundary_syntax_level: usize,
    /// Strongest line break inside the atom's trimmed text.
    internal_lb_level: LineBreakLevel,
    /// Strongest line break in the gap between the previous atom and this one.
    boundary_lb_level: LineBreakLevel,
}
282
283struct AtomChunksCollector<'s> {
284 full_text: &'s str,
285 curr_level: usize,
286 min_level: usize,
287 atom_chunks: Vec<AtomChunk>,
288}
289
290impl<'s> AtomChunksCollector<'s> {
291 fn collect(&mut self, range: TextRange) {
292 let end_trimmed_text = &self.full_text[range.start..range.end].trim_end();
294 if end_trimmed_text.is_empty() {
295 return;
296 }
297
298 let trimmed_text = end_trimmed_text.trim_start();
300 let new_start = range.start + (end_trimmed_text.len() - trimmed_text.len());
301 let new_end = new_start + trimmed_text.len();
302
303 let prev_end = self.atom_chunks.last().map_or(0, |chunk| chunk.range.end);
305 let gap = &self.full_text[prev_end..new_start];
306 let boundary_lb_level = line_break_level(gap);
307 let range = if boundary_lb_level != LineBreakLevel::Inline {
308 let trimmed_gap = gap.trim_end_matches(INLINE_SPACE_CHARS);
309 TextRange::new(prev_end + trimmed_gap.len(), new_end)
310 } else {
311 TextRange::new(new_start, new_end)
312 };
313
314 self.atom_chunks.push(AtomChunk {
315 range,
316 boundary_syntax_level: self.min_level,
317 internal_lb_level: line_break_level(trimmed_text),
318 boundary_lb_level,
319 });
320 self.min_level = self.curr_level;
321 }
322
323 fn into_atom_chunks(mut self) -> Vec<AtomChunk> {
324 self.atom_chunks.push(AtomChunk {
325 range: TextRange::new(self.full_text.len(), self.full_text.len()),
326 boundary_syntax_level: self.min_level,
327 internal_lb_level: LineBreakLevel::Inline,
328 boundary_lb_level: LineBreakLevel::DoubleNewline,
329 });
330 self.atom_chunks
331 }
332}
333
/// Start and end of one merged chunk. `merge_atom_chunks` creates both
/// positions from byte offsets; `RecursiveChunker::split` then passes them to
/// `set_output_positions` before reading their `output` values.
struct ChunkOutput {
    /// Position where the chunk starts.
    start_pos: Position,
    /// Position where the chunk ends.
    end_pos: Position,
}
338
/// Resolved per-call state of the splitting algorithm.
struct InternalRecursiveChunker<'s> {
    /// The complete input text.
    full_text: &'s str,
    /// Size above which a merged chunk is not extended further.
    chunk_size: usize,
    /// Desired overlap between consecutive chunks; `0` disables overlap handling.
    chunk_overlap: usize,
    /// Chunks smaller than this are penalized when merging.
    min_chunk_size: usize,
    /// Pieces of at most this size are collected as atoms without further splitting.
    min_atom_chunk_size: usize,
}
346
347impl<'t, 's: 't> InternalRecursiveChunker<'s> {
348 fn collect_atom_chunks(
349 &self,
350 chunk: InternalChunk<'t, 's>,
351 atom_collector: &mut AtomChunksCollector<'s>,
352 ) {
353 let mut iter_stack: Vec<ChunkIterator<'t, 's>> =
354 vec![ChunkIterator::Once(std::iter::once(chunk))];
355
356 while !iter_stack.is_empty() {
357 atom_collector.curr_level = iter_stack.len();
358
359 if let Some(current_chunk) = iter_stack.last_mut().unwrap().next() {
360 if current_chunk.range.len() <= self.min_atom_chunk_size {
361 atom_collector.collect(current_chunk.range);
362 } else {
363 match current_chunk.kind {
364 ChunkKind::TreeSitterNode {
365 tree_sitter_info: lang_config,
366 node,
367 } => {
368 if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) {
369 let mut cursor = node.walk();
370 if cursor.goto_first_child() {
371 iter_stack.push(ChunkIterator::TreeSitter(
372 TreeSitterNodeIter {
373 lang_config,
374 full_text: self.full_text,
375 cursor: Some(cursor),
376 next_start_pos: node.start_byte(),
377 end_pos: node.end_byte(),
378 },
379 ));
380 continue;
381 }
382 }
383 iter_stack.push(ChunkIterator::Once(std::iter::once(InternalChunk {
384 full_text: self.full_text,
385 range: current_chunk.range,
386 kind: ChunkKind::RegexpSepChunk {
387 lang_config: &DEFAULT_LANGUAGE_CONFIG,
388 next_regexp_sep_id: 0,
389 },
390 })));
391 }
392 ChunkKind::RegexpSepChunk {
393 lang_config,
394 next_regexp_sep_id,
395 } => {
396 if next_regexp_sep_id >= lang_config.separator_regex.len() {
397 atom_collector.collect(current_chunk.range);
398 } else {
399 iter_stack.push(ChunkIterator::Text(TextChunksIter::new(
400 lang_config,
401 current_chunk.full_text,
402 current_chunk.range,
403 next_regexp_sep_id,
404 )));
405 }
406 }
407 }
408 }
409 } else {
410 iter_stack.pop();
411 let level_after_pop = iter_stack.len();
412 atom_collector.curr_level = level_after_pop;
413 if level_after_pop < atom_collector.min_level {
414 atom_collector.min_level = level_after_pop;
415 }
416 }
417 }
418 atom_collector.curr_level = 0;
419 }
420
421 fn get_overlap_cost_base(&self, offset: usize) -> usize {
422 if self.chunk_overlap == 0 {
423 0
424 } else {
425 (self.full_text.len() - offset) * MISSING_OVERLAP_COST / self.chunk_overlap
426 }
427 }
428
    /// Merges atoms into output chunks by dynamic programming.
    ///
    /// `atom_chunks` must end with the sentinel appended by
    /// `AtomChunksCollector::into_atom_chunks`; the sentinel is never part of
    /// a chunk and only supplies the boundary information after the last real
    /// atom. The vector must therefore not be empty.
    ///
    /// `plans[k]` describes the cheapest way found to cover the first `k`
    /// atoms: the atom its last chunk starts at, the plan it continues from,
    /// and its accumulated cost. The returned chunks are in text order.
    fn merge_atom_chunks(&self, atom_chunks: Vec<AtomChunk>) -> Vec<ChunkOutput> {
        struct AtomRoutingPlan {
            // Index of the first atom of the plan's last chunk.
            start_idx: usize,
            // Index of the plan this one continues from.
            prev_plan_idx: usize,
            // Accumulated cost of the plan.
            cost: usize,
            // `get_overlap_cost_base` at the end of the plan's last chunk.
            overlap_cost_base: usize,
        }
        // (cost key, plan index); `Reverse` makes the max-heap yield the
        // candidate with the smallest key first.
        type PrevPlanCandidate = (std::cmp::Reverse<usize>, usize);

        let mut plans = Vec::with_capacity(atom_chunks.len());
        // The empty plan, covering no atoms.
        plans.push(AtomRoutingPlan {
            start_idx: 0,
            prev_plan_idx: 0,
            cost: 0,
            overlap_cost_base: self.get_overlap_cost_base(0),
        });
        let mut prev_plan_candidates = std::collections::BinaryHeap::<PrevPlanCandidate>::new();

        // `gap_cost_cache[g]` is the penalty for a syntax-level gap of `g`,
        // i.e. the sum of `SYNTAX_LEVEL_GAP_COST / i` for `i` in `1..=g`;
        // it is extended on demand.
        let mut gap_cost_cache = vec![0];
        let mut syntax_level_gap_cost = |boundary: usize, internal: usize| -> usize {
            if boundary > internal {
                let gap = boundary - internal;
                for i in gap_cost_cache.len()..=gap {
                    gap_cost_cache.push(gap_cost_cache[i - 1] + SYNTAX_LEVEL_GAP_COST / i);
                }
                gap_cost_cache[gap]
            } else {
                0
            }
        };

        // `i` is the last atom of the candidate chunk; the sentinel is skipped.
        for (i, chunk) in atom_chunks[0..atom_chunks.len() - 1].iter().enumerate() {
            let mut min_cost = usize::MAX;
            let mut arg_min_start_idx: usize = 0;
            let mut arg_min_prev_plan_idx: usize = 0;
            let mut start_idx = i;

            // The boundary after atom `i` is described by the atom that follows it.
            let end_syntax_level = atom_chunks[i + 1].boundary_syntax_level;
            let end_lb_level = atom_chunks[i + 1].boundary_lb_level;

            // Smallest syntax level and strongest line break strictly inside
            // the candidate chunk; both are updated as the start moves back.
            let mut internal_syntax_level = usize::MAX;
            let mut internal_lb_level = LineBreakLevel::Inline;

            // Number of levels by which an internal line break is stronger
            // than the one at a boundary.
            fn lb_level_gap(boundary: LineBreakLevel, internal: LineBreakLevel) -> usize {
                if boundary.ord() < internal.ord() {
                    internal.ord() - boundary.ord()
                } else {
                    0
                }
            }
            // Try every start atom from `i` backwards while the chunk fits.
            loop {
                let start_chunk = &atom_chunks[start_idx];
                let chunk_size = chunk.range.end - start_chunk.range.start;

                let mut cost = 0;
                cost +=
                    syntax_level_gap_cost(start_chunk.boundary_syntax_level, internal_syntax_level);
                cost += syntax_level_gap_cost(end_syntax_level, internal_syntax_level);
                cost += (lb_level_gap(start_chunk.boundary_lb_level, internal_lb_level)
                    + lb_level_gap(end_lb_level, internal_lb_level))
                    * PER_LINE_BREAK_LEVEL_GAP_COST;
                if chunk_size < self.min_chunk_size {
                    cost += TOO_SMALL_CHUNK_COST;
                }

                if chunk_size > self.chunk_size {
                    // Too large. It is accepted only when nothing smaller was
                    // found, which is the case when atom `i` alone exceeds
                    // the limit.
                    if min_cost == usize::MAX {
                        min_cost = cost + plans[start_idx].cost;
                        arg_min_start_idx = start_idx;
                        arg_min_prev_plan_idx = start_idx;
                    }
                    break;
                }

                // Choose the plan to continue from. Without overlap it is the
                // plan ending right before `start_idx`; with overlap it is the
                // cheapest candidate whose measured overlap does not exceed
                // `chunk_overlap`.
                let prev_plan_idx = if self.chunk_overlap > 0 {
                    while let Some(top_prev_plan) = prev_plan_candidates.peek() {
                        // NOTE(review): plan `k` ends at atom `k - 1` (see the
                        // backtracking below), while this measures up to the
                        // end of atom `k`; confirm that this is intended.
                        let overlap_size =
                            atom_chunks[top_prev_plan.1].range.end - start_chunk.range.start;
                        if overlap_size <= self.chunk_overlap {
                            break;
                        }
                        prev_plan_candidates.pop();
                    }
                    prev_plan_candidates.push((
                        std::cmp::Reverse(
                            plans[start_idx].cost + plans[start_idx].overlap_cost_base,
                        ),
                        start_idx,
                    ));
                    prev_plan_candidates.peek().unwrap().1
                } else {
                    start_idx
                };
                let prev_plan = &plans[prev_plan_idx];
                cost += prev_plan.cost;
                if self.chunk_overlap == 0 {
                    cost += MISSING_OVERLAP_COST / 2;
                } else {
                    // Charge `MISSING_OVERLAP_COST`, reduced by the
                    // difference of the two cost bases when the previous plan
                    // ends after this chunk starts.
                    let start_cost_base = self.get_overlap_cost_base(start_chunk.range.start);
                    cost += if prev_plan.overlap_cost_base < start_cost_base {
                        MISSING_OVERLAP_COST + prev_plan.overlap_cost_base - start_cost_base
                    } else {
                        MISSING_OVERLAP_COST
                    };
                }
                if cost < min_cost {
                    min_cost = cost;
                    arg_min_start_idx = start_idx;
                    arg_min_prev_plan_idx = prev_plan_idx;
                }

                if start_idx == 0 {
                    break;
                }

                // Move the start one atom back; the boundary in front of the
                // old start atom, and that atom's own line breaks, are now
                // inside the candidate chunk.
                start_idx -= 1;
                internal_syntax_level =
                    internal_syntax_level.min(start_chunk.boundary_syntax_level);
                internal_lb_level = internal_lb_level.max(start_chunk.internal_lb_level);
            }
            plans.push(AtomRoutingPlan {
                start_idx: arg_min_start_idx,
                prev_plan_idx: arg_min_prev_plan_idx,
                cost: min_cost,
                overlap_cost_base: self.get_overlap_cost_base(chunk.range.end),
            });
            prev_plan_candidates.clear();
        }

        // Follow the plan links back from the plan covering all atoms; this
        // produces the chunks last to first.
        let mut output = Vec::new();
        let mut plan_idx = plans.len() - 1;
        while plan_idx > 0 {
            let plan = &plans[plan_idx];
            let start_chunk = &atom_chunks[plan.start_idx];
            // Plan `plan_idx` covers the first `plan_idx` atoms, so its last
            // atom has index `plan_idx - 1`.
            let end_chunk = &atom_chunks[plan_idx - 1];
            output.push(ChunkOutput {
                start_pos: Position::new(start_chunk.range.start),
                end_pos: Position::new(end_chunk.range.end),
            });
            plan_idx = plan.prev_plan_idx;
        }
        output.reverse();
        output
    }
573
574 fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Vec<ChunkOutput> {
575 let mut atom_collector = AtomChunksCollector {
576 full_text: self.full_text,
577 min_level: 0,
578 curr_level: 0,
579 atom_chunks: Vec::new(),
580 };
581 self.collect_atom_chunks(
582 InternalChunk {
583 full_text: self.full_text,
584 range: TextRange::new(0, self.full_text.len()),
585 kind,
586 },
587 &mut atom_collector,
588 );
589 let atom_chunks = atom_collector.into_atom_chunks();
590 self.merge_atom_chunks(atom_chunks)
591 }
592}
593
/// Splits text into chunks along syntax nodes or separator patterns, depending
/// on the language requested for each call.
pub struct RecursiveChunker {
    /// Custom languages, keyed case-insensitively by every name and alias.
    custom_languages: HashMap<UniCase<String>, Arc<SimpleLanguageConfig>>,
}
598
599impl RecursiveChunker {
600 pub fn new(config: RecursiveSplitConfig) -> Result<Self, String> {
604 let mut custom_languages = HashMap::new();
605 for lang in config.custom_languages {
606 let separator_regex = lang
607 .separators_regex
608 .iter()
609 .map(|s| Regex::new(s))
610 .collect::<Result<Vec<_>, _>>()
611 .map_err(|e| {
612 format!(
613 "failed in parsing regexp for language `{}`: {}",
614 lang.language_name, e
615 )
616 })?;
617 let language_config = Arc::new(SimpleLanguageConfig {
618 name: lang.language_name,
619 aliases: lang.aliases,
620 separator_regex,
621 });
622 if custom_languages
623 .insert(
624 UniCase::new(language_config.name.clone()),
625 language_config.clone(),
626 )
627 .is_some()
628 {
629 return Err(format!(
630 "duplicate language name / alias: `{}`",
631 language_config.name
632 ));
633 }
634 for alias in &language_config.aliases {
635 if custom_languages
636 .insert(UniCase::new(alias.clone()), language_config.clone())
637 .is_some()
638 {
639 return Err(format!("duplicate language name / alias: `{}`", alias));
640 }
641 }
642 }
643 Ok(Self { custom_languages })
644 }
645
646 pub fn split(&self, text: &str, config: RecursiveChunkConfig) -> Vec<Chunk> {
648 let min_chunk_size = config.min_chunk_size.unwrap_or(config.chunk_size / 2);
649 let chunk_overlap = std::cmp::min(config.chunk_overlap.unwrap_or(0), min_chunk_size);
650
651 let internal_chunker = InternalRecursiveChunker {
652 full_text: text,
653 chunk_size: config.chunk_size,
654 chunk_overlap,
655 min_chunk_size,
656 min_atom_chunk_size: if chunk_overlap > 0 {
657 chunk_overlap
658 } else {
659 min_chunk_size
660 },
661 };
662
663 let language = UniCase::new(config.language.unwrap_or_default());
664 let mut output = if let Some(lang_config) = self.custom_languages.get(&language) {
665 internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
666 lang_config,
667 next_regexp_sep_id: 0,
668 })
669 } else if let Some(lang_info) = prog_langs::get_language_info(&language)
670 && let Some(tree_sitter_info) = lang_info.treesitter_info.as_ref()
671 {
672 let mut parser = tree_sitter::Parser::new();
673 if parser
674 .set_language(&tree_sitter_info.tree_sitter_lang)
675 .is_err()
676 {
677 internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
679 lang_config: &DEFAULT_LANGUAGE_CONFIG,
680 next_regexp_sep_id: 0,
681 })
682 } else if let Some(tree) = parser.parse(text, None) {
683 internal_chunker.split_root_chunk(ChunkKind::TreeSitterNode {
684 tree_sitter_info,
685 node: tree.root_node(),
686 })
687 } else {
688 internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
690 lang_config: &DEFAULT_LANGUAGE_CONFIG,
691 next_regexp_sep_id: 0,
692 })
693 }
694 } else {
695 internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
696 lang_config: &DEFAULT_LANGUAGE_CONFIG,
697 next_regexp_sep_id: 0,
698 })
699 };
700
701 set_output_positions(
703 text,
704 output.iter_mut().flat_map(|chunk_output| {
705 std::iter::once(&mut chunk_output.start_pos)
706 .chain(std::iter::once(&mut chunk_output.end_pos))
707 }),
708 );
709
710 output
712 .into_iter()
713 .map(|chunk_output| {
714 let start = chunk_output.start_pos.output.unwrap();
715 let end = chunk_output.end_pos.output.unwrap();
716 Chunk {
717 range: TextRange::new(
718 chunk_output.start_pos.byte_offset,
719 chunk_output.end_pos.byte_offset,
720 ),
721 start,
722 end,
723 }
724 })
725 .collect()
726 }
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunker without any custom language.
    fn default_chunker() -> RecursiveChunker {
        RecursiveChunker::new(RecursiveSplitConfig::default()).unwrap()
    }

    /// Per-call configuration with every optional size given explicitly.
    fn chunk_config(
        chunk_size: usize,
        min_chunk_size: usize,
        chunk_overlap: usize,
        language: Option<&str>,
    ) -> RecursiveChunkConfig {
        RecursiveChunkConfig {
            chunk_size,
            min_chunk_size: Some(min_chunk_size),
            chunk_overlap: Some(chunk_overlap),
            language: language.map(str::to_string),
        }
    }

    /// The part of `text` covered by `chunk`.
    fn chunk_str<'a>(text: &'a str, chunk: &Chunk) -> &'a str {
        &text[chunk.range.start..chunk.range.end]
    }

    /// Line breaks separate the three sentences into three chunks.
    #[test]
    fn test_split_basic() {
        let text = "Linea 1.\nLinea 2.\n\nLinea 3.";
        let chunks = default_chunker().split(text, chunk_config(15, 5, 0, None));

        assert_eq!(chunks.len(), 3);
        let expected = ["Linea 1.", "Linea 2.", "Linea 3."];
        for (chunk, want) in chunks.iter().zip(expected) {
            assert_eq!(chunk_str(text, chunk), want);
        }
    }

    /// A single long sentence is cut into several chunks within the size limit.
    #[test]
    fn test_split_long_text() {
        let text = "A very very long text that needs to be split.";
        let chunks = default_chunker().split(text, chunk_config(20, 12, 0, None));

        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| chunk_str(text, chunk).len() <= 20));
    }

    /// With overlap enabled, chunks stay within a relaxed size limit.
    #[test]
    fn test_split_with_overlap() {
        let text = "This is a test text that is a bit longer to see how the overlap works.";
        let chunks = default_chunker().split(text, chunk_config(20, 10, 5, None));

        assert!(chunks.len() > 1);
        for piece in chunks.iter().map(|chunk| chunk_str(text, chunk)) {
            assert!(piece.len() <= 25, "Chunk was too long: '{}'", piece);
        }
    }

    /// Whitespace around chunks is trimmed.
    ///
    /// NOTE(review): `AtomChunksCollector::collect` extends an atom back over
    /// the spaces at the start of its line, so leading indentation may remain
    /// in the chunk; confirm that the `starts_with` assertion holds.
    #[test]
    fn test_split_trims_whitespace() {
        let text = " \n First chunk \n\n Second chunk with spaces at the end \n";
        let chunks = default_chunker().split(text, chunk_config(30, 10, 0, None));

        assert_eq!(chunks.len(), 3);
        let first = chunk_str(text, &chunks[0]);
        assert!(!first.starts_with(" "));
    }

    /// Splitting Rust code through the tree-sitter path yields some chunks.
    #[test]
    fn test_split_with_rust_language() {
        let text = r#"
fn main() {
 println!("Hello");
}

fn other() {
 let x = 1;
}
"#;
        let chunks = default_chunker().split(text, chunk_config(50, 20, 0, Some("rust")));

        assert!(!chunks.is_empty());
    }

    /// Start positions carry 1-based line and column numbers.
    #[test]
    fn test_split_positions() {
        let text = "Chunk1\n\nChunk2";
        let chunks = default_chunker().split(text, chunk_config(10, 5, 0, None));

        assert_eq!(chunks.len(), 2);
        let starts: Vec<_> = chunks
            .iter()
            .map(|chunk| (chunk.start.line, chunk.start.column))
            .collect();
        assert_eq!(starts[0].0, 1);
        assert_eq!(starts[0].1, 1);
        assert_eq!(starts[1].0, 3);
        assert_eq!(starts[1].1, 1);
    }

    /// A custom language splits at its own separator.
    #[test]
    fn test_custom_language() {
        let myformat = CustomLanguageConfig {
            language_name: "myformat".to_string(),
            aliases: vec!["mf".to_string()],
            separators_regex: vec![r"---".to_string()],
        };
        let chunker = RecursiveChunker::new(RecursiveSplitConfig {
            custom_languages: vec![myformat],
        })
        .unwrap();
        let text = "Part1---Part2---Part3";
        let chunks = chunker.split(text, chunk_config(10, 4, 0, Some("myformat")));

        assert_eq!(chunks.len(), 3);
        let expected = ["Part1", "Part2", "Part3"];
        for (chunk, want) in chunks.iter().zip(expected) {
            assert_eq!(chunk_str(text, chunk), want);
        }
    }
}