use std::collections::VecDeque;
use std::sync::Arc;
use regex::Regex;
use crate::character::{compile_separator_regex, validate_chunk_config};
use crate::chunk::{MeasuredSpan, TextChunk, TextChunkIter, TextSpan};
use crate::error::ChunkError;
use crate::merge::merge_spans;
use crate::sizing::{CharSizer, ChunkConfig, ChunkSizer, FunctionSizer};
use crate::split::KeepSeparator;
/// Separators used when none are configured, listed in the order they are tried.
///
/// The final entry is the empty string; separator selection accepts an empty
/// separator without requiring a match in the text.
pub const DEFAULT_SEPARATORS: &[&str] = &[
    "\n\n", // blank line
    "\n",   // line break
    " ",    // single space
    "",     // empty separator
];
/// Text splitter that walks a prioritized list of separators, re-splitting any
/// piece that is not smaller than the configured chunk size with the separators
/// that follow the one just used.
///
/// `S` is the sizer type stored in the chunk configuration; it defaults to
/// `CharSizer`.
#[derive(Clone)]
pub struct RecursiveCharacterTextSplitter<S = CharSizer> {
/// Separator strings in the order they are tried.
pub(crate) separators: Vec<String>,
/// Compiled form of each entry in `separators`, index-aligned with it.
/// Separator selection expects `Some` for every non-empty separator and
/// does not consult this list for an empty one.
separator_regexes: Vec<Option<Regex>>,
/// Chunk size, chunk overlap and sizer.
pub(crate) config: ChunkConfig<S>,
/// When `true`, text is split with `KeepSeparator::Start`; when `false`, no
/// keep mode is passed to the split helper.
pub(crate) keep_separator: bool,
/// Forwarded to the span-merging step.
pub(crate) strip_whitespace: bool,
/// Measures a piece of text; the result is compared against `config.chunk_size`.
length_fn: crate::LengthFn,
}
impl RecursiveCharacterTextSplitter<CharSizer> {
    /// Creates a splitter over [`DEFAULT_SEPARATORS`] that measures text with
    /// `crate::char_len`.
    ///
    /// Separators are kept and whitespace stripping is enabled. `chunk_size` and
    /// `chunk_overlap` are stored as given: this constructor does not run the
    /// configuration validation that the builder's `build` performs.
    ///
    /// # Panics
    ///
    /// Panics if one of the default separators fails to compile.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        let mut separators = Vec::with_capacity(DEFAULT_SEPARATORS.len());
        let mut separator_regexes = Vec::with_capacity(DEFAULT_SEPARATORS.len());
        for separator in DEFAULT_SEPARATORS {
            separators.push(separator.to_string());
            // Default separators are treated as literals, never as regex syntax.
            let compiled = compile_separator_regex(separator, false)
                .expect("escaped default separator must compile");
            separator_regexes.push(compiled);
        }

        Self {
            separators,
            separator_regexes,
            config: ChunkConfig::new(chunk_size, chunk_overlap, CharSizer),
            keep_separator: true,
            strip_whitespace: true,
            length_fn: Arc::new(crate::char_len),
        }
    }

    /// Returns a builder pre-populated with the default configuration.
    pub fn builder() -> RecursiveCharacterTextSplitterBuilder<CharSizer> {
        RecursiveCharacterTextSplitterBuilder::default()
    }
}
impl<S> RecursiveCharacterTextSplitter<S>
where
    S: ChunkSizer,
{
    /// Splits `text` and returns every chunk as an owned `String`.
    pub fn split_text(&self, text: &str) -> Vec<String> {
        let chunks = self.split_chunks(text);
        let mut owned = Vec::with_capacity(chunks.len());
        for chunk in chunks {
            owned.push(chunk.text.to_string());
        }
        owned
    }

    /// Returns an iterator over the chunks of `text`.
    ///
    /// The chunks borrow from `text`, and the iterator borrows this splitter.
    pub fn chunks<'a>(&'a self, text: &'a str) -> impl Iterator<Item = TextChunk<'a>> + 'a {
        let spans = RecursiveSpanIter::new(self, text, self.length_fn.as_ref());
        TextChunkIter::new(text, spans)
    }

    /// Collects everything produced by [`Self::chunks`] into a vector.
    pub fn split_chunks<'a>(&'a self, text: &'a str) -> Vec<TextChunk<'a>> {
        let mut collected = Vec::new();
        collected.extend(self.chunks(text));
        collected
    }
}
/// Iterator state for recursive splitting: yields measured spans of `input`
/// while tracking nested split levels on an explicit stack of frames.
struct RecursiveSpanIter<'a, S = CharSizer> {
/// Splitter whose separators and chunk settings drive the iteration.
splitter: &'a RecursiveCharacterTextSplitter<S>,
/// Full text being split; every span is resolved against it.
input: &'a str,
/// Measures a piece of text for comparison against the chunk size.
length_fn: &'a dyn Fn(&str) -> usize,
/// Stack of in-progress split levels; the last entry is the active one.
frames: Vec<RecursiveFrame>,
/// Spans that are ready to be yielded, in first-in-first-out order.
output: VecDeque<TextSpan>,
}
/// One level of recursive splitting: the pieces of a single span together with
/// the progress made through them.
struct RecursiveFrame {
/// Pieces of this frame's span, stored as offsets into the full input.
splits: Vec<TextSpan>,
/// Index of the next entry of `splits` to examine.
next_split: usize,
/// Index into the splitter's separator list at which a child frame starts
/// its separator search (one past the separator used for this frame).
next_separator: usize,
/// Pieces that measured below the chunk size and are waiting to be merged.
good_splits: Vec<TextSpan>,
}
impl<'a, S> RecursiveSpanIter<'a, S>
where
    S: ChunkSizer,
{
    /// Builds the iterator with a root frame covering the whole of `input`.
    ///
    /// When no root frame can be created (the input is empty) the frame stack
    /// starts empty and the iterator yields nothing.
    fn new(
        splitter: &'a RecursiveCharacterTextSplitter<S>,
        input: &'a str,
        length_fn: &'a dyn Fn(&str) -> usize,
    ) -> Self {
        let whole_input = TextSpan::new(0, input.len());
        let mut frames = Vec::new();
        if let Some(root) = RecursiveFrame::new(splitter, input, whole_input, 0) {
            frames.push(root);
        }

        Self {
            splitter,
            input,
            length_fn,
            frames,
            output: VecDeque::new(),
        }
    }

    /// Merges the active frame's pending pieces and queues the merged spans.
    ///
    /// Does nothing when there is no active frame or nothing is pending.
    fn flush_good_splits(&mut self) {
        let pending = match self.frames.last_mut() {
            Some(frame) if !frame.good_splits.is_empty() => {
                std::mem::take(&mut frame.good_splits)
            }
            _ => return,
        };

        let splitter = self.splitter;
        let merged = merge_structured_spans(
            self.input,
            &pending,
            splitter.config.chunk_size,
            splitter.config.chunk_overlap,
            splitter.strip_whitespace,
            splitter.keep_separator,
            self.length_fn,
        );
        for span in merged {
            self.output.push_back(span);
        }
    }
}
impl<S> Iterator for RecursiveSpanIter<'_, S>
where
S: ChunkSizer,
{
type Item = MeasuredSpan;
fn next(&mut self) -> Option<Self::Item> {
loop {
if let Some(span) = self.output.pop_front() {
return Some(MeasuredSpan::new(self.input, span, self.length_fn));
}
let frame = self.frames.last_mut()?;
if frame.next_split >= frame.splits.len() {
self.flush_good_splits();
if let Some(span) = self.output.pop_front() {
return Some(MeasuredSpan::new(self.input, span, self.length_fn));
}
self.frames.pop();
continue;
}
let split = frame.splits[frame.next_split];
frame.next_split += 1;
if (self.length_fn)(split.text(self.input)) < self.splitter.config.chunk_size {
frame.good_splits.push(split);
continue;
}
self.flush_good_splits();
let next_separator = self.frames.last().map(|frame| frame.next_separator)?;
if next_separator >= self.splitter.separators.len() {
self.output.push_back(split);
continue;
}
if let Some(child) =
RecursiveFrame::new(self.splitter, self.input, split, next_separator)
{
self.frames.push(child);
} else {
self.output.push_back(split);
}
}
}
}
impl RecursiveFrame {
    /// Splits `span` of `input` with the first usable separator at or after
    /// `separator_start` and returns a frame over the resulting pieces.
    ///
    /// Returns `None` for an empty span. When no separator from
    /// `separator_start` onward is usable, the last configured separator is
    /// used instead.
    fn new<S>(
        splitter: &RecursiveCharacterTextSplitter<S>,
        input: &str,
        span: TextSpan,
        separator_start: usize,
    ) -> Option<Self>
    where
        S: ChunkSizer,
    {
        if span.start == span.end {
            return None;
        }

        let text = span.text(input);
        let (separator_index, _, separator_regex) =
            match choose_separator(splitter, text, separator_start) {
                Some(choice) => choice,
                None => {
                    let last = splitter.separators.len().saturating_sub(1);
                    (
                        last,
                        splitter.separators[last].as_str(),
                        splitter.separator_regexes[last].as_ref(),
                    )
                }
            };

        let keep = splitter.keep_separator.then_some(KeepSeparator::Start);

        // The split helper reports offsets relative to `text`; shift them so the
        // stored spans index into the full input.
        let base = span.start;
        let mut splits = Vec::new();
        for piece in crate::split::split_spans_with_compiled_regex(text, separator_regex, keep) {
            splits.push(TextSpan::new(base + piece.start, base + piece.end));
        }

        Some(Self {
            splits,
            next_split: 0,
            next_separator: separator_index + 1,
            good_splits: Vec::new(),
        })
    }
}
/// Picks the first separator at or after `separator_start` that can split `text`.
///
/// An empty separator is accepted as soon as it is reached; any other separator
/// is accepted only if its compiled regex matches `text`. Returns the
/// separator's index, its string and its regex (`None` for the empty
/// separator), or `None` when no candidate qualifies.
///
/// # Panics
///
/// Panics if `separator_start` is past the end of the separator list, or if a
/// non-empty separator has no compiled regex.
fn choose_separator<'a, S>(
    splitter: &'a RecursiveCharacterTextSplitter<S>,
    text: &str,
    separator_start: usize,
) -> Option<(usize, &'a str, Option<&'a Regex>)>
where
    S: ChunkSizer,
{
    let candidates = &splitter.separators[separator_start..];
    candidates
        .iter()
        .enumerate()
        .find_map(|(offset, separator)| {
            let index = separator_start + offset;
            if separator.is_empty() {
                return Some((index, separator.as_str(), None));
            }
            let regex = splitter.separator_regexes[index]
                .as_ref()
                .expect("non-empty separator must have compiled regex");
            regex
                .is_match(text)
                .then_some((index, separator.as_str(), Some(regex)))
        })
}
/// Merges `splits` into chunk spans.
///
/// With `keep_separator` set, all splits are merged in a single pass. Otherwise
/// the splits are first partitioned into runs of contiguous spans (each span
/// starting exactly where the previous one ends) and every run is merged on its
/// own, so a merged span never bridges a gap between two splits.
fn merge_structured_spans(
    input: &str,
    splits: &[TextSpan],
    chunk_size: usize,
    chunk_overlap: usize,
    strip_whitespace: bool,
    keep_separator: bool,
    length_fn: &dyn Fn(&str) -> usize,
) -> Vec<TextSpan> {
    let merge_run = |run: &[TextSpan]| {
        merge_spans(
            input,
            run,
            chunk_size,
            chunk_overlap,
            strip_whitespace,
            length_fn,
        )
    };

    if keep_separator {
        return merge_run(splits);
    }

    let mut merged = Vec::new();
    let mut remaining = splits;
    while !remaining.is_empty() {
        // A run is its first span plus every following span that touches its
        // predecessor.
        let touching = remaining
            .windows(2)
            .take_while(|pair| pair[0].end == pair[1].start)
            .count();
        let (run, rest) = remaining.split_at(touching + 1);
        merged.extend(merge_run(run));
        remaining = rest;
    }
    merged
}
/// Builder for [`RecursiveCharacterTextSplitter`]; the configuration is validated
/// when `build` is called.
#[derive(Clone)]
pub struct RecursiveCharacterTextSplitterBuilder<S = CharSizer> {
/// Separator strings in the order they will be tried.
separators: Vec<String>,
/// Passed to separator compilation together with each separator at build time.
is_separator_regex: bool,
/// Chunk size, chunk overlap and sizer.
config: ChunkConfig<S>,
/// Copied into the splitter's `keep_separator` field.
keep_separator: bool,
/// Copied into the splitter's `strip_whitespace` field.
strip_whitespace: bool,
/// Length measurement handed to the built splitter.
length_fn: crate::LengthFn,
}
impl Default for RecursiveCharacterTextSplitterBuilder<CharSizer> {
    /// Starts from [`DEFAULT_SEPARATORS`] treated as literals, the default chunk
    /// configuration, separators kept, whitespace stripping enabled and
    /// `crate::char_len` as the length measurement.
    fn default() -> Self {
        let separators: Vec<String> = DEFAULT_SEPARATORS
            .iter()
            .map(|separator| separator.to_string())
            .collect();

        Self {
            separators,
            is_separator_regex: false,
            config: ChunkConfig::default(),
            keep_separator: true,
            strip_whitespace: true,
            length_fn: Arc::new(crate::char_len),
        }
    }
}
impl<S> RecursiveCharacterTextSplitterBuilder<S>
where
    S: ChunkSizer,
{
    /// Replaces the separator list; separators are tried in the order given.
    pub fn separators(self, separators: impl IntoIterator<Item = impl Into<String>>) -> Self {
        let separators: Vec<String> = separators
            .into_iter()
            .map(|separator| separator.into())
            .collect();
        Self { separators, ..self }
    }

    /// Sets the flag that is passed to separator compilation at build time.
    pub fn separators_are_regex(self, is_separator_regex: bool) -> Self {
        Self {
            is_separator_regex,
            ..self
        }
    }

    /// Sets the chunk size.
    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.config.chunk_size = chunk_size;
        self
    }

    /// Sets the chunk overlap.
    pub fn chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.config.chunk_overlap = chunk_overlap;
        self
    }

    /// Sets whether separators are kept when text is split.
    pub fn keep_separator(self, keep_separator: bool) -> Self {
        Self {
            keep_separator,
            ..self
        }
    }

    /// Sets the whitespace-stripping flag forwarded to the merge step.
    pub fn strip_whitespace(self, strip_whitespace: bool) -> Self {
        Self {
            strip_whitespace,
            ..self
        }
    }

    /// Switches the builder to a different sizer.
    ///
    /// The chunk size and overlap carry over. The length measurement is replaced
    /// by a closure that calls `size` on a clone of `sizer`.
    pub fn sizer<T>(self, sizer: T) -> RecursiveCharacterTextSplitterBuilder<T>
    where
        T: ChunkSizer,
    {
        let Self {
            separators,
            is_separator_regex,
            config,
            keep_separator,
            strip_whitespace,
            ..
        } = self;
        let measuring_sizer = sizer.clone();

        RecursiveCharacterTextSplitterBuilder {
            separators,
            is_separator_regex,
            config: ChunkConfig::new(config.chunk_size, config.chunk_overlap, sizer),
            keep_separator,
            strip_whitespace,
            length_fn: Arc::new(move |value: &str| measuring_sizer.size(value)),
        }
    }

    /// Uses `length_fn` for measurement by wrapping it in a `FunctionSizer`.
    pub fn length_fn(
        self,
        length_fn: crate::LengthFn,
    ) -> RecursiveCharacterTextSplitterBuilder<FunctionSizer> {
        let function_sizer = FunctionSizer::new(length_fn);
        self.sizer(function_sizer)
    }

    /// Validates the configuration and builds the splitter.
    ///
    /// # Errors
    ///
    /// Returns an error when the chunk size / overlap validation fails, when the
    /// separator list is empty, or when a separator cannot be compiled.
    pub fn build(self) -> Result<RecursiveCharacterTextSplitter<S>, ChunkError> {
        let Self {
            separators,
            is_separator_regex,
            config,
            keep_separator,
            strip_whitespace,
            length_fn,
        } = self;

        validate_chunk_config(config.chunk_size, config.chunk_overlap)?;
        if separators.is_empty() {
            return Err(ChunkError::invalid_configuration(
                "recursive splitter requires at least one separator",
            ));
        }

        // Compile in order and stop at the first separator that fails.
        let mut separator_regexes = Vec::with_capacity(separators.len());
        for separator in &separators {
            let compiled =
                compile_separator_regex(separator, is_separator_regex).map_err(|err| {
                    ChunkError::invalid_configuration(format!(
                        "invalid separator regex {separator:?}: {err}"
                    ))
                })?;
            separator_regexes.push(compiled);
        }

        Ok(RecursiveCharacterTextSplitter {
            separators,
            separator_regexes,
            config,
            keep_separator,
            strip_whitespace,
            length_fn,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a character-sized splitter directly from its fields, compiling
    /// each separator the same way the production constructors do.
    /// Whitespace stripping is always enabled.
    fn splitter_from_parts(
        separators: &[&str],
        is_separator_regex: bool,
        chunk_size: usize,
        chunk_overlap: usize,
        keep_separator: bool,
    ) -> RecursiveCharacterTextSplitter {
        let separators: Vec<String> = separators.iter().map(|s| s.to_string()).collect();
        let separator_regexes: Vec<Option<Regex>> = separators
            .iter()
            .map(|separator| compile_separator_regex(separator, is_separator_regex).unwrap())
            .collect();
        RecursiveCharacterTextSplitter {
            separators,
            separator_regexes,
            config: ChunkConfig::new(chunk_size, chunk_overlap, CharSizer),
            keep_separator,
            strip_whitespace: true,
            length_fn: Arc::new(crate::char_len),
        }
    }

    #[test]
    fn test_recursive_basic() {
        let splitter = splitter_from_parts(DEFAULT_SEPARATORS, false, 10, 1, true);
        let text =
            "Hi.\n\nI'm Harrison.\n\nHow? Are? You?\nOkay then f f f f.\nThis is a weird text to write, but gotta test the splittting am I right?";
        let result = splitter.split_text(text);
        for chunk in &result {
            let length = chunk.chars().count();
            assert!(
                length <= 10,
                "Chunk exceeded chunk_size: {:?} ({})",
                chunk,
                length
            );
        }
        assert!(result.len() > 1, "Expected multiple chunks");
    }

    #[test]
    fn test_recursive_custom_separators() {
        let splitter = splitter_from_parts(&["X", "Y"], false, 5, 0, false);
        let result = splitter.split_text("aaXbbYcc");
        assert!(result.contains(&"aa".to_string()));
    }

    #[test]
    fn test_recursive_default_separators() {
        let splitter = RecursiveCharacterTextSplitter::new(20, 0);
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let result = splitter.split_text(text);
        assert!(result.len() >= 2);
        for chunk in &result {
            assert!(chunk.chars().count() <= 20, "Chunk too long: {:?}", chunk);
        }
    }

    #[test]
    fn test_recursive_empty_text() {
        let splitter = RecursiveCharacterTextSplitter::new(10, 0);
        let result = splitter.split_text("");
        assert!(result.is_empty());
    }

    #[test]
    fn test_recursive_single_chunk_fits() {
        let splitter = RecursiveCharacterTextSplitter::new(100, 0);
        let result = splitter.split_text("Short text");
        assert_eq!(result, vec!["Short text"]);
    }

    #[test]
    fn test_recursive_with_overlap() {
        let splitter = RecursiveCharacterTextSplitter::new(10, 3);
        let text = "aaaa\n\nbbbb\n\ncccc\n\ndddd";
        let result = splitter.split_text(text);
        assert!(result.len() >= 2);
        for chunk in &result {
            assert!(chunk.chars().count() <= 10, "Chunk too long: {:?}", chunk);
        }
    }

    #[test]
    fn test_recursive_keep_separator_false() {
        let splitter = splitter_from_parts(DEFAULT_SEPARATORS, false, 10, 0, false);
        let text = "Hello\n\nWorld\n\nFoo";
        let result = splitter.split_text(text);
        for chunk in &result {
            assert!(
                !chunk.starts_with("\n\n"),
                "Separator should not be at start"
            );
        }
    }

    #[test]
    fn test_recursive_structured_chunks_do_not_reintroduce_removed_separators() {
        let splitter = splitter_from_parts(&["\n\n", " ", ""], false, 100, 0, false);
        let source = "Hello\n\nWorld\n\nFoo";
        let chunks = splitter.split_chunks(source);

        let texts: Vec<&str> = chunks.iter().map(|chunk| chunk.text).collect();
        assert_eq!(texts, vec!["Hello", "World", "Foo"]);

        // Each chunk must be an exact slice of the source at its reported byte range.
        for chunk in &chunks {
            assert_eq!(chunk.text, &source[chunk.start_byte..chunk.end_byte]);
            assert!(!chunk.text.contains("\n\n"));
        }
    }

    #[test]
    fn test_recursive_is_separator_regex() {
        let splitter = splitter_from_parts(&[r"\d+", ""], true, 10, 0, false);
        let result = splitter.split_text("abc123def456ghi");
        assert_eq!(result, vec!["abc", "def", "ghi"]);
    }

    #[test]
    fn test_recursive_respects_chunk_size_on_long_text() {
        let text = "a ".repeat(200);
        let splitter = RecursiveCharacterTextSplitter::new(50, 5);
        let result = splitter.split_text(&text);
        assert!(result.len() > 1);
        for chunk in &result {
            let length = chunk.chars().count();
            assert!(length <= 50, "Chunk exceeded 50 chars: {} chars", length);
        }
    }
}