use std::sync::Arc;
use regex::Regex;
use crate::chunk::{TextChunk, TextChunkIter};
use crate::error::ChunkError;
use crate::merge::{merge_splits, StreamingMerger};
use crate::sizing::{CharSizer, ChunkConfig, ChunkSizer, FunctionSizer};
use crate::split::{split_text_with_compiled_regex, RegexSpanIter};
/// Splits text on a single separator (a literal string or a regular
/// expression) and merges the resulting pieces into chunks governed by
/// `chunk_size` / `chunk_overlap`.
///
/// Construct it with [`CharacterTextSplitter::new`] (literal separator) or via
/// [`CharacterTextSplitter::builder`].
#[derive(Clone)]
pub struct CharacterTextSplitter<S = CharSizer> {
    /// Chunk size, chunk overlap and the sizer of type `S`.
    pub(crate) config: ChunkConfig<S>,
    /// Length function forwarded to the merge step to measure pieces.
    length_fn: crate::LengthFn,
    /// Separator text exactly as configured (a regex pattern when built with
    /// `separator_regex`); `split_text` also forwards it to `merge_splits`.
    pub(crate) separator: String,
    /// Compiled separator; `None` for an empty separator, in which case
    /// `split_text` splits the input into individual characters.
    separator_regex: Option<Regex>,
    /// Whitespace-stripping flag forwarded to the merge step.
    pub(crate) strip_whitespace: bool,
}
impl CharacterTextSplitter<CharSizer> {
pub fn new(separator: &str, chunk_size: usize, chunk_overlap: usize) -> Self {
Self {
separator: separator.to_string(),
separator_regex: compile_separator_regex(separator, false)
.expect("escaped literal separator must compile"),
config: ChunkConfig::new(chunk_size, chunk_overlap, CharSizer),
strip_whitespace: true,
length_fn: Arc::new(crate::char_len),
}
}
pub fn builder() -> CharacterTextSplitterBuilder<CharSizer> {
CharacterTextSplitterBuilder::default()
}
}
impl<S> CharacterTextSplitter<S>
where
    S: ChunkSizer,
{
    /// Splits `text` on the configured separator and merges the pieces into
    /// owned chunk strings.
    ///
    /// With no compiled separator (empty separator), every character becomes
    /// its own piece before merging.
    pub fn split_text(&self, text: &str) -> Vec<String> {
        let measure = self.length_fn.as_ref();
        let pieces = match &self.separator_regex {
            Some(pattern) => split_text_with_compiled_regex(text, pattern, None),
            None => text.chars().map(String::from).collect(),
        };
        merge_splits(
            &pieces,
            &self.separator,
            self.config.chunk_size,
            self.config.chunk_overlap,
            self.strip_whitespace,
            measure,
        )
    }

    /// Lazily yields chunks of `text` as [`TextChunk`]s that borrow from the
    /// input, using the streaming span splitter and merger.
    pub fn chunks<'a>(&'a self, text: &'a str) -> impl Iterator<Item = TextChunk<'a>> + 'a {
        let spans = RegexSpanIter::from_regex(text, self.separator_regex.clone(), None);
        let merger = StreamingMerger::new(
            text,
            spans,
            self.config.chunk_size,
            self.config.chunk_overlap,
            self.strip_whitespace,
            self.length_fn.as_ref(),
        );
        TextChunkIter::new(text, merger)
    }

    /// Eagerly collects everything [`Self::chunks`] yields into a `Vec`.
    pub fn split_chunks<'a>(&'a self, text: &'a str) -> Vec<TextChunk<'a>> {
        self.chunks(text).collect::<Vec<_>>()
    }
}
/// Builder for [`CharacterTextSplitter`]. The configuration is validated and
/// the separator compiled when `build` is called.
#[derive(Clone)]
pub struct CharacterTextSplitterBuilder<S = CharSizer> {
    /// Chunk size, chunk overlap and the sizer of type `S`.
    config: ChunkConfig<S>,
    /// Length function handed to the built splitter.
    length_fn: crate::LengthFn,
    /// Raw separator text; see `is_separator_regex` for how it is compiled.
    separator: String,
    /// `true` to compile `separator` as a regex, `false` to escape it first.
    is_separator_regex: bool,
    /// Copied into the built splitter's `strip_whitespace`.
    strip_whitespace: bool,
}
impl Default for CharacterTextSplitterBuilder<CharSizer> {
fn default() -> Self {
Self {
separator: "\n\n".to_string(),
is_separator_regex: false,
config: ChunkConfig::default(),
strip_whitespace: true,
length_fn: Arc::new(crate::char_len),
}
}
}
impl<S> CharacterTextSplitterBuilder<S>
where
    S: ChunkSizer,
{
    /// Sets a literal separator; regex metacharacters are escaped in `build`.
    ///
    /// Overrides any earlier [`separator_regex`](Self::separator_regex) call:
    /// the most recent separator setter decides how the text is interpreted.
    pub fn separator(mut self, separator: impl Into<String>) -> Self {
        self.separator = separator.into();
        // Without this reset, `.separator_regex(..).separator(".")` would
        // compile "." as a regex (any character) instead of a literal dot.
        self.is_separator_regex = false;
        self
    }

    /// Sets a separator that is compiled as a regular expression in `build`.
    pub fn separator_regex(mut self, separator: impl Into<String>) -> Self {
        self.separator = separator.into();
        self.is_separator_regex = true;
        self
    }

    /// Sets `chunk_size`; it must be non-zero (checked in `build`).
    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.config.chunk_size = chunk_size;
        self
    }

    /// Sets `chunk_overlap`; it must be smaller than `chunk_size` (checked in `build`).
    pub fn chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.config.chunk_overlap = chunk_overlap;
        self
    }

    /// Sets whether the built splitter strips whitespace from chunks.
    pub fn strip_whitespace(mut self, strip_whitespace: bool) -> Self {
        self.strip_whitespace = strip_whitespace;
        self
    }

    /// Swaps in a new sizer and keeps the separator, chunk size/overlap and
    /// whitespace setting.
    ///
    /// The length function is replaced by one that delegates to
    /// `sizer.size`, so any previously configured length function is dropped.
    pub fn sizer<T>(self, sizer: T) -> CharacterTextSplitterBuilder<T>
    where
        T: ChunkSizer,
    {
        // One copy lives in the config, the other is captured by the closure.
        let length_sizer = sizer.clone();
        CharacterTextSplitterBuilder {
            separator: self.separator,
            is_separator_regex: self.is_separator_regex,
            config: ChunkConfig::new(self.config.chunk_size, self.config.chunk_overlap, sizer),
            strip_whitespace: self.strip_whitespace,
            length_fn: Arc::new(move |value: &str| length_sizer.size(value)),
        }
    }

    /// Uses `length_fn` for sizing by wrapping it in a [`FunctionSizer`].
    pub fn length_fn(
        self,
        length_fn: crate::LengthFn,
    ) -> CharacterTextSplitterBuilder<FunctionSizer> {
        self.sizer(FunctionSizer::new(length_fn))
    }

    /// Validates the sizing, compiles the separator and produces the splitter.
    ///
    /// # Errors
    ///
    /// Returns a [`ChunkError`] in three cases:
    /// - `chunk_size` is zero;
    /// - `chunk_overlap` is not smaller than `chunk_size`;
    /// - the separator fails to compile as a regular expression.
    pub fn build(self) -> Result<CharacterTextSplitter<S>, ChunkError> {
        validate_chunk_config(self.config.chunk_size, self.config.chunk_overlap)?;
        let separator_regex = compile_separator_regex(&self.separator, self.is_separator_regex)
            .map_err(|err| {
                ChunkError::invalid_configuration(format!(
                    "invalid separator regex {:?}: {err}",
                    self.separator
                ))
            })?;
        Ok(CharacterTextSplitter {
            separator: self.separator,
            separator_regex,
            config: self.config,
            strip_whitespace: self.strip_whitespace,
            length_fn: self.length_fn,
        })
    }
}
/// Compiles `separator` into a [`Regex`].
///
/// Returns `Ok(None)` for an empty separator. Otherwise the separator is used
/// as a pattern verbatim when `is_separator_regex` is set, or escaped so that
/// it matches literally.
///
/// # Errors
///
/// Propagates the `regex::Error` if the final pattern fails to compile.
pub(crate) fn compile_separator_regex(
    separator: &str,
    is_separator_regex: bool,
) -> Result<Option<Regex>, regex::Error> {
    match (separator.is_empty(), is_separator_regex) {
        (true, _) => Ok(None),
        (false, true) => Regex::new(separator).map(Some),
        (false, false) => Regex::new(&regex::escape(separator)).map(Some),
    }
}
/// Checks that a chunk sizing is usable.
///
/// # Errors
///
/// Returns an invalid-configuration [`ChunkError`] in two cases:
/// - `chunk_size` is zero;
/// - `chunk_overlap` is not strictly smaller than `chunk_size`.
pub(crate) fn validate_chunk_config(
    chunk_size: usize,
    chunk_overlap: usize,
) -> Result<(), ChunkError> {
    // The zero-size check must win, so it is evaluated first.
    let problem = if chunk_size == 0 {
        Some("chunk_size must be greater than zero")
    } else if chunk_overlap >= chunk_size {
        Some("chunk_overlap must be smaller than chunk_size")
    } else {
        None
    };
    match problem {
        Some(message) => Err(ChunkError::invalid_configuration(message)),
        None => Ok(()),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a splitter directly from its fields, with whitespace stripping
    /// on and `crate::char_len` as the length function, so each test only
    /// states the separator mode and sizing it cares about.
    fn splitter_with(
        separator: &str,
        is_separator_regex: bool,
        chunk_size: usize,
        chunk_overlap: usize,
    ) -> CharacterTextSplitter {
        CharacterTextSplitter {
            separator: separator.to_string(),
            separator_regex: compile_separator_regex(separator, is_separator_regex).unwrap(),
            config: crate::sizing::ChunkConfig::new(
                chunk_size,
                chunk_overlap,
                crate::sizing::CharSizer,
            ),
            strip_whitespace: true,
            length_fn: std::sync::Arc::new(crate::char_len),
        }
    }

    #[test]
    fn test_character_splitter_basic() {
        let splitter = splitter_with(" ", false, 7, 3);
        let result = splitter.split_text("foo bar baz 123");
        assert_eq!(result, vec!["foo bar", "bar baz", "baz 123"]);
    }

    #[test]
    fn test_character_splitter_empty_doc_filtering() {
        let splitter = splitter_with(" ", false, 9, 0);
        let result = splitter.split_text("foo bar");
        assert_eq!(result, vec!["foo bar"]);
    }

    #[test]
    fn test_character_splitter_small_chunks() {
        let splitter = splitter_with(" ", false, 3, 1);
        let result = splitter.split_text("foo bar baz a a");
        assert_eq!(result, vec!["foo", "bar", "baz", "a a"]);
    }

    #[test]
    fn test_character_splitter_empty_input() {
        let splitter = CharacterTextSplitter::new("\n", 100, 0);
        let result = splitter.split_text("");
        assert!(result.is_empty());
    }

    #[test]
    fn test_character_splitter_whitespace_only() {
        let splitter = CharacterTextSplitter::new(" ", 100, 0);
        let result = splitter.split_text(" ");
        assert!(result.is_empty());
    }

    #[test]
    fn test_character_splitter_no_separator_match() {
        let splitter = CharacterTextSplitter::new("X", 100, 0);
        let result = splitter.split_text("hello world");
        assert_eq!(result, vec!["hello world"]);
    }

    #[test]
    fn test_character_splitter_single_word() {
        let splitter = CharacterTextSplitter::new(" ", 100, 0);
        let result = splitter.split_text("hello");
        assert_eq!(result, vec!["hello"]);
    }

    #[test]
    fn test_character_splitter_is_separator_regex() {
        let splitter = splitter_with(r"\s+", true, 7, 0);
        let result = splitter.split_text("foo bar\tbaz");
        assert_eq!(result, vec!["foo", "bar", "baz"]);
    }

    #[test]
    fn test_character_splitter_regex_special_chars() {
        let splitter = splitter_with(".", false, 10, 0);
        let result = splitter.split_text("hello.world.test");
        assert_eq!(result, vec!["hello", "world.test"]);
    }

    #[test]
    fn test_character_splitter_newline() {
        let splitter = CharacterTextSplitter::new("\n\n", 20, 0);
        let result = splitter.split_text("Hello World\n\nFoo Bar\n\nBaz");
        assert_eq!(result, vec!["Hello World\n\nFoo Bar", "Baz"]);
    }

    #[test]
    fn test_character_splitter_multichar_sep() {
        let splitter = CharacterTextSplitter::new("<SEP>", 20, 0);
        let result = splitter.split_text("part one<SEP>part two<SEP>part three");
        assert!(result.len() >= 2);
        for chunk in &result {
            assert!(chunk.chars().count() <= 20, "Chunk too long: {:?}", chunk);
        }
    }

    #[test]
    fn test_compile_separator_regex_modes() {
        // An empty separator never compiles to a regex, in either mode.
        assert!(matches!(compile_separator_regex("", false), Ok(None)));
        assert!(matches!(compile_separator_regex("", true), Ok(None)));
        // "(" is an invalid pattern as a regex but a valid escaped literal.
        assert!(compile_separator_regex("(", true).is_err());
        assert!(matches!(compile_separator_regex("(", false), Ok(Some(_))));
    }

    #[test]
    fn test_validate_chunk_config_bounds() {
        assert!(validate_chunk_config(0, 0).is_err());
        assert!(validate_chunk_config(10, 10).is_err());
        assert!(validate_chunk_config(10, 11).is_err());
        assert!(validate_chunk_config(10, 9).is_ok());
        assert!(validate_chunk_config(1, 0).is_ok());
    }

    #[test]
    fn test_builder_rejects_invalid_configuration() {
        assert!(CharacterTextSplitter::builder()
            .chunk_size(0)
            .chunk_overlap(0)
            .build()
            .is_err());
        assert!(CharacterTextSplitter::builder()
            .chunk_size(10)
            .chunk_overlap(10)
            .build()
            .is_err());
        assert!(CharacterTextSplitter::builder()
            .separator_regex("(")
            .chunk_size(10)
            .chunk_overlap(0)
            .build()
            .is_err());
    }

    #[test]
    fn test_builder_escapes_literal_separator() {
        let built = CharacterTextSplitter::builder()
            .separator(".")
            .chunk_size(10)
            .chunk_overlap(0)
            .build();
        let splitter = match built {
            Ok(splitter) => splitter,
            Err(_) => panic!("a literal separator with a valid config must build"),
        };
        assert_eq!(
            splitter.split_text("hello.world.test"),
            vec!["hello", "world.test"]
        );
    }
}