Skip to main content

recoco_splitters/
recursive.rs

1// Recoco is a Rust-only fork of CocoIndex, by [CocoIndex](https://CocoIndex)
2// Original code from CocoIndex is copyrighted by CocoIndex
3// SPDX-FileCopyrightText: 2025-2026 CocoIndex (upstream)
4// SPDX-FileContributor: CocoIndex Contributors
5//
6// All modifications from the upstream for Recoco are copyrighted by Knitli Inc.
7// SPDX-FileCopyrightText: 2026 Knitli Inc. (Recoco)
8// SPDX-FileContributor: Adam Poulemanos <adam@knit.li>
9//
10// Both the upstream CocoIndex code and the Recoco modifications are licensed under the Apache-2.0 License.
11// SPDX-License-Identifier: Apache-2.0
12
13//! Recursive text chunking with syntax awareness.
14
15use regex::{Matches, Regex};
16use std::collections::HashMap;
17use std::sync::{Arc, LazyLock};
18use unicase::UniCase;
19
20use crate::output_positions::{Position, set_output_positions};
21use crate::prog_langs::{self, TreeSitterLanguageInfo};
22use crate::split::{Chunk, TextRange};
23
/// Base cost for a chunk boundary that cuts through syntax nesting; the k-th
/// extra level of depth adds `SYNTAX_LEVEL_GAP_COST / k` (see `merge_atom_chunks`).
const SYNTAX_LEVEL_GAP_COST: usize = 512;
/// Cost of a chunk boundary with no overlap to the previous chunk. It is reduced
/// in proportion to the overlap achieved, and halved when overlap is disabled.
const MISSING_OVERLAP_COST: usize = 512;
/// Cost per line-break level by which a chunk's boundary is weaker than the
/// strongest line break inside the chunk.
const PER_LINE_BREAK_LEVEL_GAP_COST: usize = 64;
/// Penalty (2^20) added to any chunk shorter than `min_chunk_size`.
const TOO_SMALL_CHUNK_COST: usize = 1 << 20;
28
/// Configuration for a custom language that is split purely by separator regexes.
///
/// Registered through [`RecursiveSplitConfig::custom_languages`]. It is selected
/// when [`RecursiveChunkConfig::language`] matches its name or one of its aliases
/// (case-insensitively). Custom languages take precedence over built-in ones.
#[derive(Debug, Clone)]
pub struct CustomLanguageConfig {
    /// The primary name of the language.
    pub language_name: String,
    /// Additional names that also select this language. Names and aliases must be
    /// unique (case-insensitively) across all custom languages, or
    /// `RecursiveChunker::new` returns an error.
    pub aliases: Vec<String>,
    /// Separator regex patterns, coarsest first.
    ///
    /// Text is first split on the first pattern. Pieces that are still too large
    /// are split on the next pattern, and a piece is kept whole once all patterns
    /// are used up. Matched separator text lies between pieces, not inside one,
    /// though it can still fall inside a merged output chunk. An invalid pattern
    /// makes `RecursiveChunker::new` return an error.
    pub separators_regex: Vec<String>,
}

/// Configuration used to construct a [`RecursiveChunker`].
#[derive(Debug, Clone, Default)]
pub struct RecursiveSplitConfig {
    /// Custom languages, available in addition to the built-in ones.
    pub custom_languages: Vec<CustomLanguageConfig>,
}

/// Configuration for a single chunking operation ([`RecursiveChunker::split`]).
#[derive(Debug, Clone)]
pub struct RecursiveChunkConfig {
    /// Target maximum chunk size in bytes. A single piece that cannot be split
    /// any further may still exceed it.
    pub chunk_size: usize,
    /// Minimum chunk size in bytes. Smaller chunks are heavily penalized rather
    /// than forbidden. Defaults to `chunk_size / 2`.
    pub min_chunk_size: Option<usize>,
    /// Desired overlap between consecutive chunks in bytes. Defaults to 0 (no
    /// overlap) and is clamped to the effective `min_chunk_size`.
    pub chunk_overlap: Option<usize>,
    /// Language name or file extension for syntax-aware splitting, matched
    /// case-insensitively. `None`, or a language with no custom/tree-sitter
    /// support, falls back to the default paragraph/line/sentence/whitespace
    /// separators.
    pub language: Option<String>,
}
59
60struct SimpleLanguageConfig {
61    name: String,
62    aliases: Vec<String>,
63    separator_regex: Vec<Regex>,
64}
65
66static DEFAULT_LANGUAGE_CONFIG: LazyLock<SimpleLanguageConfig> =
67    LazyLock::new(|| SimpleLanguageConfig {
68        name: "_DEFAULT".to_string(),
69        aliases: vec![],
70        separator_regex: [
71            r"\n\n+",
72            r"\n",
73            r"[\.\?!]\s+|。|?|!",
74            r"[;:\-—]\s+|;|:|—+",
75            r",\s+|,",
76            r"\s+",
77        ]
78        .into_iter()
79        .map(|s| Regex::new(s).unwrap())
80        .collect(),
81    });
82
/// Describes how an [`InternalChunk`] can be split further.
enum ChunkKind<'t> {
    /// A tree-sitter syntax node. It is split along its children, with the text
    /// between children handled by the default separators.
    TreeSitterNode {
        tree_sitter_info: &'t TreeSitterLanguageInfo,
        node: tree_sitter::Node<'t>,
    },
    /// Plain text. It is split with `lang_config.separator_regex[next_regexp_sep_id]`
    /// and kept whole once that index runs past the last separator.
    RegexpSepChunk {
        lang_config: &'t SimpleLanguageConfig,
        next_regexp_sep_id: usize,
    },
}

/// A byte range of the text being split, together with how to split it further.
struct InternalChunk<'t, 's: 't> {
    full_text: &'s str,
    range: TextRange,
    kind: ChunkKind<'t>,
}
99
100struct TextChunksIter<'t, 's: 't> {
101    lang_config: &'t SimpleLanguageConfig,
102    full_text: &'s str,
103    range: TextRange,
104    matches_iter: Matches<'t, 's>,
105    regexp_sep_id: usize,
106    next_start_pos: Option<usize>,
107}
108
109impl<'t, 's: 't> TextChunksIter<'t, 's> {
110    fn new(
111        lang_config: &'t SimpleLanguageConfig,
112        full_text: &'s str,
113        range: TextRange,
114        regexp_sep_id: usize,
115    ) -> Self {
116        let std_range = range.start..range.end;
117        Self {
118            lang_config,
119            full_text,
120            range,
121            matches_iter: lang_config.separator_regex[regexp_sep_id]
122                .find_iter(&full_text[std_range.clone()]),
123            regexp_sep_id,
124            next_start_pos: Some(std_range.start),
125        }
126    }
127}
128
129impl<'t, 's: 't> Iterator for TextChunksIter<'t, 's> {
130    type Item = InternalChunk<'t, 's>;
131
132    fn next(&mut self) -> Option<Self::Item> {
133        let start_pos = self.next_start_pos?;
134        let end_pos = match self.matches_iter.next() {
135            Some(grp) => {
136                self.next_start_pos = Some(self.range.start + grp.end());
137                self.range.start + grp.start()
138            }
139            None => {
140                self.next_start_pos = None;
141                if start_pos >= self.range.end {
142                    return None;
143                }
144                self.range.end
145            }
146        };
147        Some(InternalChunk {
148            full_text: self.full_text,
149            range: TextRange::new(start_pos, end_pos),
150            kind: ChunkKind::RegexpSepChunk {
151                lang_config: self.lang_config,
152                next_regexp_sep_id: self.regexp_sep_id + 1,
153            },
154        })
155    }
156}
157
158struct TreeSitterNodeIter<'t, 's: 't> {
159    lang_config: &'t TreeSitterLanguageInfo,
160    full_text: &'s str,
161    cursor: Option<tree_sitter::TreeCursor<'t>>,
162    next_start_pos: usize,
163    end_pos: usize,
164}
165
166impl<'t, 's: 't> TreeSitterNodeIter<'t, 's> {
167    fn fill_gap(
168        next_start_pos: &mut usize,
169        gap_end_pos: usize,
170        full_text: &'s str,
171    ) -> Option<InternalChunk<'t, 's>> {
172        let start_pos = *next_start_pos;
173        if start_pos < gap_end_pos {
174            *next_start_pos = gap_end_pos;
175            Some(InternalChunk {
176                full_text,
177                range: TextRange::new(start_pos, gap_end_pos),
178                kind: ChunkKind::RegexpSepChunk {
179                    lang_config: &DEFAULT_LANGUAGE_CONFIG,
180                    next_regexp_sep_id: 0,
181                },
182            })
183        } else {
184            None
185        }
186    }
187}
188
189impl<'t, 's: 't> Iterator for TreeSitterNodeIter<'t, 's> {
190    type Item = InternalChunk<'t, 's>;
191
192    fn next(&mut self) -> Option<Self::Item> {
193        let cursor = if let Some(cursor) = &mut self.cursor {
194            cursor
195        } else {
196            return Self::fill_gap(&mut self.next_start_pos, self.end_pos, self.full_text);
197        };
198        let node = cursor.node();
199        if let Some(gap) =
200            Self::fill_gap(&mut self.next_start_pos, node.start_byte(), self.full_text)
201        {
202            return Some(gap);
203        }
204        if !cursor.goto_next_sibling() {
205            self.cursor = None;
206        }
207        self.next_start_pos = node.end_byte();
208        Some(InternalChunk {
209            full_text: self.full_text,
210            range: TextRange::new(node.start_byte(), node.end_byte()),
211            kind: ChunkKind::TreeSitterNode {
212                tree_sitter_info: self.lang_config,
213                node,
214            },
215        })
216    }
217}
218
219enum ChunkIterator<'t, 's: 't> {
220    TreeSitter(TreeSitterNodeIter<'t, 's>),
221    Text(TextChunksIter<'t, 's>),
222    Once(std::iter::Once<InternalChunk<'t, 's>>),
223}
224
225impl<'t, 's: 't> Iterator for ChunkIterator<'t, 's> {
226    type Item = InternalChunk<'t, 's>;
227
228    fn next(&mut self) -> Option<Self::Item> {
229        match self {
230            ChunkIterator::TreeSitter(iter) => iter.next(),
231            ChunkIterator::Text(iter) => iter.next(),
232            ChunkIterator::Once(iter) => iter.next(),
233        }
234    }
235}
236
/// Strength of a line break, ordered from weakest to strongest.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
enum LineBreakLevel {
    /// No line break.
    Inline,
    /// At least one line break.
    Newline,
    /// A line-break run in which its first character repeats (e.g. a blank line).
    DoubleNewline,
}

impl LineBreakLevel {
    /// Numeric rank of the level (0, 1, 2), used in cost arithmetic.
    fn ord(self) -> usize {
        // Fieldless enum: implicit discriminants 0, 1, 2 in declaration order.
        self as usize
    }
}

/// Returns the strongest line break found in `text`.
///
/// A "run" is a maximal sequence of `\n`/`\r` characters. Any run yields at least
/// [`LineBreakLevel::Newline`]. A run in which the character that started it
/// occurs again yields [`LineBreakLevel::DoubleNewline`]: `"\n\n"` and
/// `"\r\n\r\n"` qualify, `"\r\n"` does not. Any other character, including spaces
/// between two line breaks, ends the run.
fn line_break_level(text: &str) -> LineBreakLevel {
    let mut level = LineBreakLevel::Inline;
    // First character of the line-break run currently being scanned, if any.
    let mut run_head: Option<char> = None;
    for ch in text.chars() {
        if !matches!(ch, '\n' | '\r') {
            run_head = None;
            continue;
        }
        match run_head {
            Some(head) if head == ch => return LineBreakLevel::DoubleNewline,
            Some(_) => {}
            None => {
                run_head = Some(ch);
                level = LineBreakLevel::Newline;
            }
        }
    }
    level
}
273
274const INLINE_SPACE_CHARS: [char; 2] = [' ', '\t'];
275
276struct AtomChunk {
277    range: TextRange,
278    boundary_syntax_level: usize,
279    internal_lb_level: LineBreakLevel,
280    boundary_lb_level: LineBreakLevel,
281}
282
283struct AtomChunksCollector<'s> {
284    full_text: &'s str,
285    curr_level: usize,
286    min_level: usize,
287    atom_chunks: Vec<AtomChunk>,
288}
289
290impl<'s> AtomChunksCollector<'s> {
291    fn collect(&mut self, range: TextRange) {
292        // Trim trailing whitespaces.
293        let end_trimmed_text = &self.full_text[range.start..range.end].trim_end();
294        if end_trimmed_text.is_empty() {
295            return;
296        }
297
298        // Trim leading whitespaces.
299        let trimmed_text = end_trimmed_text.trim_start();
300        let new_start = range.start + (end_trimmed_text.len() - trimmed_text.len());
301        let new_end = new_start + trimmed_text.len();
302
303        // Align to beginning of the line if possible.
304        let prev_end = self.atom_chunks.last().map_or(0, |chunk| chunk.range.end);
305        let gap = &self.full_text[prev_end..new_start];
306        let boundary_lb_level = line_break_level(gap);
307        let range = if boundary_lb_level != LineBreakLevel::Inline {
308            let trimmed_gap = gap.trim_end_matches(INLINE_SPACE_CHARS);
309            TextRange::new(prev_end + trimmed_gap.len(), new_end)
310        } else {
311            TextRange::new(new_start, new_end)
312        };
313
314        self.atom_chunks.push(AtomChunk {
315            range,
316            boundary_syntax_level: self.min_level,
317            internal_lb_level: line_break_level(trimmed_text),
318            boundary_lb_level,
319        });
320        self.min_level = self.curr_level;
321    }
322
323    fn into_atom_chunks(mut self) -> Vec<AtomChunk> {
324        self.atom_chunks.push(AtomChunk {
325            range: TextRange::new(self.full_text.len(), self.full_text.len()),
326            boundary_syntax_level: self.min_level,
327            internal_lb_level: LineBreakLevel::Inline,
328            boundary_lb_level: LineBreakLevel::DoubleNewline,
329        });
330        self.atom_chunks
331    }
332}
333
/// Start and end of one merged output chunk. Both are created from byte offsets
/// and completed later by `set_output_positions`.
struct ChunkOutput {
    start_pos: Position,
    end_pos: Position,
}

/// Per-call state for splitting one text: the text and the resolved size limits.
struct InternalRecursiveChunker<'s> {
    full_text: &'s str,
    /// Target maximum chunk size in bytes.
    chunk_size: usize,
    /// Desired overlap in bytes; 0 disables overlap handling.
    chunk_overlap: usize,
    /// Chunks shorter than this are charged `TOO_SMALL_CHUNK_COST`.
    min_chunk_size: usize,
    /// Pieces no longer than this are not split further during atom collection.
    min_atom_chunk_size: usize,
}
346
impl<'t, 's: 't> InternalRecursiveChunker<'s> {
    /// Walks `chunk` depth-first using an explicit stack of child iterators.
    /// Every piece that is no longer than `min_atom_chunk_size`, or that cannot
    /// be split further, is fed to `atom_collector` in text order.
    ///
    /// Larger pieces are split by their tree-sitter children when possible, and
    /// otherwise by the next separator regex of their language config. The stack
    /// depth is mirrored into `atom_collector.curr_level` / `min_level`, so each
    /// atom records how deep in the structure its leading boundary lies.
    fn collect_atom_chunks(
        &self,
        chunk: InternalChunk<'t, 's>,
        atom_collector: &mut AtomChunksCollector<'s>,
    ) {
        let mut iter_stack: Vec<ChunkIterator<'t, 's>> =
            vec![ChunkIterator::Once(std::iter::once(chunk))];

        while !iter_stack.is_empty() {
            atom_collector.curr_level = iter_stack.len();

            if let Some(current_chunk) = iter_stack.last_mut().unwrap().next() {
                if current_chunk.range.len() <= self.min_atom_chunk_size {
                    // Small enough: keep it as a single atom.
                    atom_collector.collect(current_chunk.range);
                } else {
                    match current_chunk.kind {
                        ChunkKind::TreeSitterNode {
                            tree_sitter_info: lang_config,
                            node,
                        } => {
                            // Descend into the node's children unless its kind is
                            // terminal or it has no children.
                            if !lang_config.terminal_node_kind_ids.contains(&node.kind_id()) {
                                let mut cursor = node.walk();
                                if cursor.goto_first_child() {
                                    iter_stack.push(ChunkIterator::TreeSitter(
                                        TreeSitterNodeIter {
                                            lang_config,
                                            full_text: self.full_text,
                                            cursor: Some(cursor),
                                            next_start_pos: node.start_byte(),
                                            end_pos: node.end_byte(),
                                        },
                                    ));
                                    continue;
                                }
                            }
                            // Otherwise re-queue the same range for splitting with
                            // the default separators.
                            iter_stack.push(ChunkIterator::Once(std::iter::once(InternalChunk {
                                full_text: self.full_text,
                                range: current_chunk.range,
                                kind: ChunkKind::RegexpSepChunk {
                                    lang_config: &DEFAULT_LANGUAGE_CONFIG,
                                    next_regexp_sep_id: 0,
                                },
                            })));
                        }
                        ChunkKind::RegexpSepChunk {
                            lang_config,
                            next_regexp_sep_id,
                        } => {
                            if next_regexp_sep_id >= lang_config.separator_regex.len() {
                                // No finer separator left: keep the piece whole.
                                atom_collector.collect(current_chunk.range);
                            } else {
                                iter_stack.push(ChunkIterator::Text(TextChunksIter::new(
                                    lang_config,
                                    current_chunk.full_text,
                                    current_chunk.range,
                                    next_regexp_sep_id,
                                )));
                            }
                        }
                    }
                }
            } else {
                // Top iterator exhausted: pop it, and remember the shallowest depth
                // reached before the next atom is collected.
                iter_stack.pop();
                let level_after_pop = iter_stack.len();
                atom_collector.curr_level = level_after_pop;
                if level_after_pop < atom_collector.min_level {
                    atom_collector.min_level = level_after_pop;
                }
            }
        }
        atom_collector.curr_level = 0;
    }

    /// Overlap cost basis for a chunk boundary at byte `offset`. It is
    /// proportional to the bytes remaining after `offset`, scaled so that
    /// `chunk_overlap` bytes correspond to `MISSING_OVERLAP_COST`. Returns 0
    /// when overlap is disabled.
    ///
    /// The difference between the bases of two offsets therefore measures, in
    /// cost units, how far apart the offsets are.
    fn get_overlap_cost_base(&self, offset: usize) -> usize {
        if self.chunk_overlap == 0 {
            0
        } else {
            (self.full_text.len() - offset) * MISSING_OVERLAP_COST / self.chunk_overlap
        }
    }

    /// Merges consecutive atoms into output chunks by dynamic programming.
    ///
    /// `plans[k]` is the cheapest way found to cover atoms `[0, k)`. Its last chunk
    /// spans atoms `start_idx..k` and follows `plans[prev_plan_idx]`, which may
    /// end after `start_idx` when chunks overlap. For each atom `i`, every start
    /// `start_idx <= i` is tried, walking backwards until the chunk exceeds
    /// `chunk_size` or index 0 is reached. A candidate chunk's cost is the sum of:
    /// - syntax-level gaps between its boundaries and the shallowest level inside it;
    /// - line-break-level gaps between its boundaries and the strongest break inside it;
    /// - `TOO_SMALL_CHUNK_COST` if it is shorter than `min_chunk_size`;
    /// - an overlap cost: `MISSING_OVERLAP_COST` reduced by the overlap achieved
    ///   with the previous chunk, or half of it when overlap is disabled;
    /// - the cost of the plan it follows.
    ///
    /// The last element of `atom_chunks` is the end-of-text sentinel, used only for
    /// its boundary levels. Returns the chunks in text order.
    fn merge_atom_chunks(&self, atom_chunks: Vec<AtomChunk>) -> Vec<ChunkOutput> {
        struct AtomRoutingPlan {
            start_idx: usize,
            prev_plan_idx: usize,
            cost: usize,
            overlap_cost_base: usize,
        }
        // `Reverse` turns the max-heap into a min-heap on (cost + overlap base).
        type PrevPlanCandidate = (std::cmp::Reverse<usize>, usize);

        let mut plans = Vec::with_capacity(atom_chunks.len());
        plans.push(AtomRoutingPlan {
            start_idx: 0,
            prev_plan_idx: 0,
            cost: 0,
            overlap_cost_base: self.get_overlap_cost_base(0),
        });
        let mut prev_plan_candidates = std::collections::BinaryHeap::<PrevPlanCandidate>::new();

        // gap_cost_cache[g] == sum over k in 1..=g of SYNTAX_LEVEL_GAP_COST / k,
        // grown lazily as larger gaps are seen.
        let mut gap_cost_cache = vec![0];
        let mut syntax_level_gap_cost = |boundary: usize, internal: usize| -> usize {
            if boundary > internal {
                let gap = boundary - internal;
                for i in gap_cost_cache.len()..=gap {
                    gap_cost_cache.push(gap_cost_cache[i - 1] + SYNTAX_LEVEL_GAP_COST / i);
                }
                gap_cost_cache[gap]
            } else {
                0
            }
        };

        for (i, chunk) in atom_chunks[0..atom_chunks.len() - 1].iter().enumerate() {
            let mut min_cost = usize::MAX;
            let mut arg_min_start_idx: usize = 0;
            let mut arg_min_prev_plan_idx: usize = 0;
            let mut start_idx = i;

            // The boundary after atom `i` is the leading boundary of atom `i + 1`.
            let end_syntax_level = atom_chunks[i + 1].boundary_syntax_level;
            let end_lb_level = atom_chunks[i + 1].boundary_lb_level;

            // Shallowest syntax level / strongest line break inside the candidate
            // chunk. Both are updated as the start moves backwards.
            let mut internal_syntax_level = usize::MAX;
            let mut internal_lb_level = LineBreakLevel::Inline;

            // Number of levels by which `internal` exceeds `boundary` (0 if it does not).
            fn lb_level_gap(boundary: LineBreakLevel, internal: LineBreakLevel) -> usize {
                if boundary.ord() < internal.ord() {
                    internal.ord() - boundary.ord()
                } else {
                    0
                }
            }
            loop {
                let start_chunk = &atom_chunks[start_idx];
                let chunk_size = chunk.range.end - start_chunk.range.start;

                let mut cost = 0;
                cost +=
                    syntax_level_gap_cost(start_chunk.boundary_syntax_level, internal_syntax_level);
                cost += syntax_level_gap_cost(end_syntax_level, internal_syntax_level);
                cost += (lb_level_gap(start_chunk.boundary_lb_level, internal_lb_level)
                    + lb_level_gap(end_lb_level, internal_lb_level))
                    * PER_LINE_BREAK_LEVEL_GAP_COST;
                if chunk_size < self.min_chunk_size {
                    cost += TOO_SMALL_CHUNK_COST;
                }

                if chunk_size > self.chunk_size {
                    // Too large. If not even atom `i` alone fits, accept it as its
                    // own chunk rather than leaving the plan empty.
                    if min_cost == usize::MAX {
                        min_cost = cost + plans[start_idx].cost;
                        arg_min_start_idx = start_idx;
                        arg_min_prev_plan_idx = start_idx;
                    }
                    break;
                }

                // Choose the plan this chunk follows. With overlap enabled, take the
                // cheapest earlier plan whose end does not reach more than
                // `chunk_overlap` bytes into this chunk.
                // NOTE(review): plan `q` ends at atom `q - 1`, but the check uses
                // `atom_chunks[q].range.end`. That makes it conservative; confirm
                // this is intended.
                let prev_plan_idx = if self.chunk_overlap > 0 {
                    while let Some(top_prev_plan) = prev_plan_candidates.peek() {
                        let overlap_size =
                            atom_chunks[top_prev_plan.1].range.end - start_chunk.range.start;
                        if overlap_size <= self.chunk_overlap {
                            break;
                        }
                        prev_plan_candidates.pop();
                    }
                    prev_plan_candidates.push((
                        std::cmp::Reverse(
                            plans[start_idx].cost + plans[start_idx].overlap_cost_base,
                        ),
                        start_idx,
                    ));
                    prev_plan_candidates.peek().unwrap().1
                } else {
                    start_idx
                };
                let prev_plan = &plans[prev_plan_idx];
                cost += prev_plan.cost;
                if self.chunk_overlap == 0 {
                    cost += MISSING_OVERLAP_COST / 2;
                } else {
                    // A previous plan ending past this chunk's start has a smaller
                    // base. The difference is a credit for the achieved overlap.
                    let start_cost_base = self.get_overlap_cost_base(start_chunk.range.start);
                    cost += if prev_plan.overlap_cost_base < start_cost_base {
                        MISSING_OVERLAP_COST + prev_plan.overlap_cost_base - start_cost_base
                    } else {
                        MISSING_OVERLAP_COST
                    };
                }
                if cost < min_cost {
                    min_cost = cost;
                    arg_min_start_idx = start_idx;
                    arg_min_prev_plan_idx = prev_plan_idx;
                }

                if start_idx == 0 {
                    break;
                }

                // Extend the chunk one atom to the left: the old start's leading
                // boundary and internal breaks now lie inside the chunk.
                // NOTE(review): boundary line-break levels and the new first atom's
                // own internal breaks are not folded in. Confirm this asymmetry is
                // intended.
                start_idx -= 1;
                internal_syntax_level =
                    internal_syntax_level.min(start_chunk.boundary_syntax_level);
                internal_lb_level = internal_lb_level.max(start_chunk.internal_lb_level);
            }
            plans.push(AtomRoutingPlan {
                start_idx: arg_min_start_idx,
                prev_plan_idx: arg_min_prev_plan_idx,
                cost: min_cost,
                overlap_cost_base: self.get_overlap_cost_base(chunk.range.end),
            });
            prev_plan_candidates.clear();
        }

        // Backtrack from the plan covering every real atom (the sentinel excluded)
        // and emit one chunk per plan, then restore text order.
        let mut output = Vec::new();
        let mut plan_idx = plans.len() - 1;
        while plan_idx > 0 {
            let plan = &plans[plan_idx];
            let start_chunk = &atom_chunks[plan.start_idx];
            let end_chunk = &atom_chunks[plan_idx - 1];
            output.push(ChunkOutput {
                start_pos: Position::new(start_chunk.range.start),
                end_pos: Position::new(end_chunk.range.end),
            });
            plan_idx = plan.prev_plan_idx;
        }
        output.reverse();
        output
    }

    /// Splits the whole text, starting from a root chunk of the given `kind`.
    /// Returns the merged chunks in text order.
    fn split_root_chunk(&self, kind: ChunkKind<'t>) -> Vec<ChunkOutput> {
        let mut atom_collector = AtomChunksCollector {
            full_text: self.full_text,
            min_level: 0,
            curr_level: 0,
            atom_chunks: Vec::new(),
        };
        self.collect_atom_chunks(
            InternalChunk {
                full_text: self.full_text,
                range: TextRange::new(0, self.full_text.len()),
                kind,
            },
            &mut atom_collector,
        );
        let atom_chunks = atom_collector.into_atom_chunks();
        self.merge_atom_chunks(atom_chunks)
    }
}
593
594/// A recursive text chunker with syntax awareness.
595pub struct RecursiveChunker {
596    custom_languages: HashMap<UniCase<String>, Arc<SimpleLanguageConfig>>,
597}
598
599impl RecursiveChunker {
600    /// Create a new recursive chunker with the given configuration.
601    ///
602    /// Returns an error if any regex pattern is invalid or if there are duplicate language names.
603    pub fn new(config: RecursiveSplitConfig) -> Result<Self, String> {
604        let mut custom_languages = HashMap::new();
605        for lang in config.custom_languages {
606            let separator_regex = lang
607                .separators_regex
608                .iter()
609                .map(|s| Regex::new(s))
610                .collect::<Result<Vec<_>, _>>()
611                .map_err(|e| {
612                    format!(
613                        "failed in parsing regexp for language `{}`: {}",
614                        lang.language_name, e
615                    )
616                })?;
617            let language_config = Arc::new(SimpleLanguageConfig {
618                name: lang.language_name,
619                aliases: lang.aliases,
620                separator_regex,
621            });
622            if custom_languages
623                .insert(
624                    UniCase::new(language_config.name.clone()),
625                    language_config.clone(),
626                )
627                .is_some()
628            {
629                return Err(format!(
630                    "duplicate language name / alias: `{}`",
631                    language_config.name
632                ));
633            }
634            for alias in &language_config.aliases {
635                if custom_languages
636                    .insert(UniCase::new(alias.clone()), language_config.clone())
637                    .is_some()
638                {
639                    return Err(format!("duplicate language name / alias: `{}`", alias));
640                }
641            }
642        }
643        Ok(Self { custom_languages })
644    }
645
646    /// Split the text into chunks according to the configuration.
647    pub fn split(&self, text: &str, config: RecursiveChunkConfig) -> Vec<Chunk> {
648        let min_chunk_size = config.min_chunk_size.unwrap_or(config.chunk_size / 2);
649        let chunk_overlap = std::cmp::min(config.chunk_overlap.unwrap_or(0), min_chunk_size);
650
651        let internal_chunker = InternalRecursiveChunker {
652            full_text: text,
653            chunk_size: config.chunk_size,
654            chunk_overlap,
655            min_chunk_size,
656            min_atom_chunk_size: if chunk_overlap > 0 {
657                chunk_overlap
658            } else {
659                min_chunk_size
660            },
661        };
662
663        let language = UniCase::new(config.language.unwrap_or_default());
664        let mut output = if let Some(lang_config) = self.custom_languages.get(&language) {
665            internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
666                lang_config,
667                next_regexp_sep_id: 0,
668            })
669        } else if let Some(lang_info) = prog_langs::get_language_info(&language)
670            && let Some(tree_sitter_info) = lang_info.treesitter_info.as_ref()
671        {
672            let mut parser = tree_sitter::Parser::new();
673            if parser
674                .set_language(&tree_sitter_info.tree_sitter_lang)
675                .is_err()
676            {
677                // Fall back to default if language setup fails
678                internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
679                    lang_config: &DEFAULT_LANGUAGE_CONFIG,
680                    next_regexp_sep_id: 0,
681                })
682            } else if let Some(tree) = parser.parse(text, None) {
683                internal_chunker.split_root_chunk(ChunkKind::TreeSitterNode {
684                    tree_sitter_info,
685                    node: tree.root_node(),
686                })
687            } else {
688                // Fall back to default if parsing fails
689                internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
690                    lang_config: &DEFAULT_LANGUAGE_CONFIG,
691                    next_regexp_sep_id: 0,
692                })
693            }
694        } else {
695            internal_chunker.split_root_chunk(ChunkKind::RegexpSepChunk {
696                lang_config: &DEFAULT_LANGUAGE_CONFIG,
697                next_regexp_sep_id: 0,
698            })
699        };
700
701        // Compute positions
702        set_output_positions(
703            text,
704            output.iter_mut().flat_map(|chunk_output| {
705                std::iter::once(&mut chunk_output.start_pos)
706                    .chain(std::iter::once(&mut chunk_output.end_pos))
707            }),
708        );
709
710        // Convert to final output
711        output
712            .into_iter()
713            .map(|chunk_output| {
714                let start = chunk_output.start_pos.output.unwrap();
715                let end = chunk_output.end_pos.output.unwrap();
716                Chunk {
717                    range: TextRange::new(
718                        chunk_output.start_pos.byte_offset,
719                        chunk_output.end_pos.byte_offset,
720                    ),
721                    start,
722                    end,
723                }
724            })
725            .collect()
726    }
727}
728
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunker with no custom languages.
    fn default_chunker() -> RecursiveChunker {
        RecursiveChunker::new(RecursiveSplitConfig::default()).unwrap()
    }

    /// Config that uses the default separators (no language).
    fn plain_config(
        chunk_size: usize,
        min_chunk_size: usize,
        chunk_overlap: usize,
    ) -> RecursiveChunkConfig {
        RecursiveChunkConfig {
            chunk_size,
            min_chunk_size: Some(min_chunk_size),
            chunk_overlap: Some(chunk_overlap),
            language: None,
        }
    }

    /// The text covered by each chunk, in order.
    fn chunk_texts<'a>(text: &'a str, chunks: &[Chunk]) -> Vec<&'a str> {
        chunks
            .iter()
            .map(|chunk| &text[chunk.range.start..chunk.range.end])
            .collect()
    }

    #[test]
    fn test_split_basic() {
        let text = "Linea 1.\nLinea 2.\n\nLinea 3.";
        let chunks = default_chunker().split(text, plain_config(15, 5, 0));

        assert_eq!(
            chunk_texts(text, &chunks),
            ["Linea 1.", "Linea 2.", "Linea 3."]
        );
    }

    #[test]
    fn test_split_long_text() {
        let text = "A very very long text that needs to be split.";
        let chunks = default_chunker().split(text, plain_config(20, 12, 0));

        assert!(chunks.len() > 1);
        assert!(chunk_texts(text, &chunks).iter().all(|piece| piece.len() <= 20));
    }

    #[test]
    fn test_split_with_overlap() {
        let text = "This is a test text that is a bit longer to see how the overlap works.";
        let chunks = default_chunker().split(text, plain_config(20, 10, 5));

        assert!(chunks.len() > 1);
        for piece in chunk_texts(text, &chunks) {
            assert!(piece.len() <= 25, "Chunk was too long: '{}'", piece);
        }
    }

    #[test]
    fn test_split_trims_whitespace() {
        let text = "  \n First chunk  \n\n  Second chunk with spaces at the end    \n";
        let chunks = default_chunker().split(text, plain_config(30, 10, 0));

        assert_eq!(chunks.len(), 3);
        // The leading "  \n" prefix of the input must not start the first chunk.
        let first = &text[chunks[0].range.start..chunks[0].range.end];
        assert!(!first.starts_with("  "));
    }

    #[test]
    fn test_split_with_rust_language() {
        let text = r#"
fn main() {
    println!("Hello");
}

fn other() {
    let x = 1;
}
"#;
        let config = RecursiveChunkConfig {
            chunk_size: 50,
            min_chunk_size: Some(20),
            chunk_overlap: Some(0),
            language: Some("rust".to_string()),
        };

        assert!(!default_chunker().split(text, config).is_empty());
    }

    #[test]
    fn test_split_positions() {
        let chunks = default_chunker().split("Chunk1\n\nChunk2", plain_config(10, 5, 0));

        assert_eq!(chunks.len(), 2);
        // Lines and columns are 1-based.
        assert_eq!(chunks[0].start.line, 1);
        assert_eq!(chunks[0].start.column, 1);
        assert_eq!(chunks[1].start.line, 3);
        assert_eq!(chunks[1].start.column, 1);
    }

    #[test]
    fn test_custom_language() {
        let chunker = RecursiveChunker::new(RecursiveSplitConfig {
            custom_languages: vec![CustomLanguageConfig {
                language_name: "myformat".to_string(),
                aliases: vec!["mf".to_string()],
                separators_regex: vec![r"---".to_string()],
            }],
        })
        .unwrap();
        let text = "Part1---Part2---Part3";
        let chunk_config = RecursiveChunkConfig {
            chunk_size: 10,
            min_chunk_size: Some(4),
            chunk_overlap: Some(0),
            language: Some("myformat".to_string()),
        };
        let chunks = chunker.split(text, chunk_config);

        assert_eq!(chunk_texts(text, &chunks), ["Part1", "Part2", "Part3"]);
    }
}