use crate::vector_stores::Document;
/// Strategy for cutting text into smaller pieces.
///
/// Implementors supply [`TextSplitter::split_text`]; splitting a whole
/// [`Document`] is derived from it by the provided default method.
pub trait TextSplitter: Send + Sync {
    /// Splits `text` into an ordered list of chunks.
    fn split_text(&self, text: &str) -> Vec<String>;

    /// Splits the content of `document` and wraps every chunk in a new
    /// [`Document`].
    ///
    /// Each produced document gets a copy of the source metadata plus a
    /// `"chunk"` entry holding the zero-based chunk index as a string. The
    /// `id` of every produced document is `None`.
    fn split_document(&self, document: &Document) -> Vec<Document> {
        let mut pieces = Vec::new();
        for (index, content) in self.split_text(&document.content).into_iter().enumerate() {
            // Every chunk carries the parent's metadata plus its own position.
            let mut metadata = document.metadata.clone();
            metadata.insert("chunk".to_string(), index.to_string());
            pieces.push(Document {
                content,
                metadata,
                id: None,
            });
        }
        pieces
    }
}
/// Splitter that tries a prioritized list of separators, falling back to
/// finer-grained ones for pieces that are still too large.
pub struct RecursiveCharacterSplitter {
    /// Maximum chunk size, measured in bytes (`str::len`).
    chunk_size: usize,
    /// Overlap setting; only stored here, the recursive splitting itself
    /// does not read it.
    chunk_overlap: usize,
    /// Separators in priority order. An empty string means "split between
    /// individual characters".
    separators: Vec<String>,
}

impl RecursiveCharacterSplitter {
    /// Creates a splitter with the default separator list: blank line,
    /// newline, `"。"`, `"."`, space, and finally the empty string
    /// (character-level splitting).
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        Self {
            chunk_size,
            chunk_overlap,
            separators: ["\n\n", "\n", "。", ".", " ", ""]
                .iter()
                .map(|s| s.to_string())
                .collect(),
        }
    }

    /// Creates a splitter with a chunk size of 1000 and an overlap of 200.
    pub fn with_defaults() -> Self {
        Self::new(1000, 200)
    }

    /// Replaces the separator list, consuming and returning the splitter.
    pub fn with_separators(mut self, separators: Vec<String>) -> Self {
        self.separators = separators;
        self
    }

    /// Splits `text` into chunks of at most `chunk_size` bytes where possible.
    ///
    /// The first entry of `separators` that occurs in `text` is used to cut
    /// it; when none occurs, the text is cut between characters. Pieces are
    /// merged greedily, and every piece after the first in a chunk keeps its
    /// leading separator. A piece that is still too large is split again with
    /// `separators[1..]`.
    ///
    /// A single character whose UTF-8 encoding is longer than `chunk_size`
    /// cannot be divided any further and is emitted as its own (oversized)
    /// chunk.
    fn split_text_recursive(&self, text: &str, separators: &[String]) -> Vec<String> {
        let mut chunks = Vec::new();
        if text.is_empty() {
            return chunks;
        }
        if text.len() <= self.chunk_size {
            chunks.push(text.to_string());
            return chunks;
        }

        // Highest-priority separator present in the text; "" selects the
        // character-level fallback.
        let separator = separators
            .iter()
            .find(|s| text.contains(s.as_str()))
            .map(String::as_str)
            .unwrap_or("");
        let char_level = separator.is_empty();

        let splits: Vec<String> = if char_level {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.split(separator).map(str::to_string).collect()
        };

        let mut current_chunk = String::new();
        for split in splits {
            // The separator is re-attached in front of every piece except the
            // one that opens a chunk.
            let piece = if char_level || current_chunk.is_empty() {
                split
            } else {
                format!("{}{}", separator, split)
            };

            if piece.len() > self.chunk_size {
                if !current_chunk.is_empty() {
                    chunks.push(std::mem::take(&mut current_chunk));
                }
                if char_level {
                    // `piece` is one character and is indivisible. Recursing
                    // on it would see the same text again and never
                    // terminate, so emit it as-is.
                    chunks.push(piece);
                } else {
                    let next_separators = separators.get(1..).unwrap_or(&[]);
                    chunks.extend(self.split_text_recursive(&piece, next_separators));
                }
            } else if current_chunk.len() + piece.len() > self.chunk_size {
                // Flush the full chunk and start the next one with this piece.
                chunks.push(std::mem::replace(&mut current_chunk, piece));
            } else {
                current_chunk.push_str(&piece);
            }
        }

        if !current_chunk.is_empty() {
            chunks.push(current_chunk);
        }
        chunks
    }
}
impl TextSplitter for RecursiveCharacterSplitter {
    /// Splits `text` recursively and then, when an overlap is configured and
    /// more than one chunk was produced, prefixes every chunk after the first
    /// with the last `chunk_overlap` characters of the chunk emitted before it.
    ///
    /// The preceding chunk is taken in its already-prefixed form, and the
    /// overlap is counted in characters, not bytes.
    fn split_text(&self, text: &str) -> Vec<String> {
        let pieces = self.split_text_recursive(text, &self.separators);
        if self.chunk_overlap == 0 || pieces.len() <= 1 {
            return pieces;
        }

        let mut result: Vec<String> = Vec::with_capacity(pieces.len());
        for piece in pieces {
            let merged = match result.last() {
                // The first chunk has nothing before it to overlap with.
                None => piece,
                Some(prev) => {
                    // Byte offset where the last `chunk_overlap` characters of
                    // `prev` begin; 0 when `prev` has fewer characters.
                    let tail_start = prev
                        .char_indices()
                        .rev()
                        .nth(self.chunk_overlap - 1)
                        .map_or(0, |(at, _)| at);
                    let mut joined = String::new();
                    joined.push_str(&prev[tail_start..]);
                    joined.push_str(&piece);
                    joined
                }
            };
            result.push(merged);
        }
        result
    }
}
/// Configuration for a splitter that uses one fixed separator: a chunk size,
/// a chunk overlap, and the separator string.
// `dead_code` is silenced on this type; presumably it is not constructed
// outside of tests yet.
#[allow(dead_code)]
pub struct CharacterTextSplitter {
    chunk_size: usize,
    chunk_overlap: usize,
    separator: String,
}

#[allow(dead_code)]
impl CharacterTextSplitter {
    /// Builds a splitter from the given chunk size, overlap and separator.
    /// The separator is copied into an owned `String`.
    pub fn new(chunk_size: usize, chunk_overlap: usize, separator: &str) -> Self {
        let separator = String::from(separator);
        Self {
            chunk_size,
            chunk_overlap,
            separator,
        }
    }
}
impl TextSplitter for CharacterTextSplitter {
    /// Splits `text` on the configured separator and greedily re-joins the
    /// pieces into chunks.
    ///
    /// A chunk is flushed when appending the next piece (plus one separator)
    /// would push it past `chunk_size` bytes. When `chunk_overlap` is
    /// non-zero, the next chunk starts with the trailing bytes of the flushed
    /// one: at most `chunk_overlap` bytes, shortened if necessary so the cut
    /// never lands inside a multibyte character.
    fn split_text(&self, text: &str) -> Vec<String> {
        let mut chunks = Vec::new();
        let mut current = String::new();

        for split in text.split(self.separator.as_str()) {
            let would_overflow =
                current.len() + split.len() + self.separator.len() > self.chunk_size;
            if would_overflow && !current.is_empty() {
                if self.chunk_overlap > 0 {
                    let mut overlap_start = current.len().saturating_sub(self.chunk_overlap);
                    // Slicing at a byte offset inside a UTF-8 sequence panics,
                    // so move forward to the next character boundary.
                    // `current.len()` is always a boundary, which bounds the loop.
                    while !current.is_char_boundary(overlap_start) {
                        overlap_start += 1;
                    }
                    let carried = current[overlap_start..].to_string();
                    chunks.push(std::mem::replace(&mut current, carried));
                } else {
                    chunks.push(std::mem::take(&mut current));
                }
            }

            // Re-insert the separator between pieces that share a chunk.
            if !current.is_empty() {
                current.push_str(&self.separator);
            }
            current.push_str(split);
        }

        if !current.is_empty() {
            chunks.push(current);
        }
        chunks
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Recursive splitting yields at least one chunk, and no chunk exceeds
    /// 60 bytes.
    #[test]
    fn test_recursive_splitter() {
        let input = "This is a sentence. This is another sentence. And a third one.";
        let pieces = RecursiveCharacterSplitter::new(50, 10).split_text(input);
        assert!(!pieces.is_empty());
        // 60 is presumably the chunk size (50) plus the overlap (10).
        for piece in pieces.iter() {
            assert!(piece.len() <= 60);
        }
    }

    /// Documents produced by `split_document` keep the source metadata and
    /// carry their own index under the `"chunk"` key.
    #[test]
    fn test_split_document() {
        let source = Document::new("First paragraph.\n\nSecond paragraph.\n\nThird paragraph.")
            .with_metadata("source", "test");
        let pieces = RecursiveCharacterSplitter::new(100, 20).split_document(&source);
        assert!(!pieces.is_empty());
        for (index, piece) in pieces.iter().enumerate() {
            assert!(piece.metadata.contains_key("chunk"));
            assert_eq!(piece.metadata.get("chunk"), Some(&index.to_string()));
            assert_eq!(piece.metadata.get("source"), Some(&"test".to_string()));
        }
    }

    /// The single-separator splitter produces output for a plain sentence.
    #[test]
    fn test_character_splitter() {
        let input = "This is a test sentence with multiple words";
        let pieces = CharacterTextSplitter::new(20, 5, " ").split_text(input);
        assert!(!pieces.is_empty());
    }

    /// Empty input produces no chunks at all.
    #[test]
    fn test_empty_text() {
        let pieces = RecursiveCharacterSplitter::new(100, 20).split_text("");
        assert!(pieces.is_empty());
    }

    /// Text shorter than the chunk size comes back as one unmodified chunk.
    #[test]
    fn test_small_text() {
        let input = "Short text";
        let pieces = RecursiveCharacterSplitter::new(1000, 200).split_text(input);
        assert_eq!(pieces.len(), 1);
        assert_eq!(pieces[0], "Short text");
    }
}