use crate::{Chunker, Slab};
/// Splits text into chunks of at most `max_size` bytes by trying a list of
/// separators in the order they were given.
///
/// The text is first split on the separator at index 0, and adjacent pieces
/// are merged greedily while the merged chunk still fits in `max_size`. A
/// piece that is too large on its own is re-split with the next separator.
/// Once every separator has been tried, the remainder is cut at UTF-8
/// character boundaries.
///
/// Each separator stays attached to the end of the piece that precedes it,
/// so concatenating the produced chunks reproduces the input exactly.
#[derive(Debug, Clone)]
pub struct RecursiveChunker {
    /// Maximum chunk size in bytes.
    max_size: usize,
    /// Bytes of preceding text a slab may repeat (0 unless set via `with_overlap`).
    overlap: usize,
    /// Separators, tried in the order given (index 0 first).
    separators: Vec<String>,
}

impl RecursiveChunker {
    /// Creates a chunker producing chunks of at most `max_size` bytes, split on
    /// `separators` in the order given. Overlap starts at 0.
    ///
    /// # Panics
    ///
    /// Panics if `max_size` is 0 or if `separators` is empty.
    #[must_use]
    pub fn new(max_size: usize, separators: &[&str]) -> Self {
        assert!(max_size > 0, "max_size must be > 0");
        assert!(!separators.is_empty(), "separators must not be empty");
        Self {
            max_size,
            overlap: 0,
            separators: separators.iter().map(|&s| s.to_string()).collect(),
        }
    }

    /// Sets how many bytes of preceding text each emitted slab may repeat.
    #[must_use]
    pub fn with_overlap(mut self, overlap: usize) -> Self {
        self.overlap = overlap;
        self
    }

    /// Splits `text` using the separator at `sep_index` and, for pieces that
    /// remain too large, the separators after it.
    ///
    /// The returned chunks concatenate back to `text`. Each is at most
    /// `max_size` bytes, except that a single character wider than `max_size`
    /// is emitted alone (see [`RecursiveChunker::force_split`]).
    fn split_recursive(&self, text: &str, sep_index: usize) -> Vec<String> {
        if text.len() <= self.max_size {
            return vec![text.to_string()];
        }
        if sep_index >= self.separators.len() {
            // Every separator has been tried; fall back to hard cuts.
            return self.force_split(text);
        }
        let sep = self.separators[sep_index].as_str();
        let parts: Vec<&str> = text.split(sep).collect();
        if parts.len() == 1 {
            // Separator does not occur; try the next one.
            return self.split_recursive(text, sep_index + 1);
        }
        let last = parts.len() - 1;
        let mut result = Vec::new();
        let mut current = String::new();
        for (i, part) in parts.into_iter().enumerate() {
            // Re-attach the separator (except after the final part) so no
            // input bytes are lost and chunk offsets stay contiguous.
            let piece = if i < last {
                format!("{part}{sep}")
            } else {
                part.to_string()
            };
            if current.is_empty() {
                current = piece;
            } else if current.len() + piece.len() <= self.max_size {
                current.push_str(&piece);
            } else {
                let full = std::mem::replace(&mut current, piece);
                self.flush(&mut result, full, sep_index);
            }
        }
        if !current.is_empty() {
            self.flush(&mut result, current, sep_index);
        }
        result
    }

    /// Appends `chunk` to `result`, re-splitting it with the separator after
    /// `sep_index` when it exceeds `max_size`.
    fn flush(&self, result: &mut Vec<String>, chunk: String, sep_index: usize) {
        if chunk.len() <= self.max_size {
            result.push(chunk);
        } else {
            result.extend(self.split_recursive(&chunk, sep_index + 1));
        }
    }

    /// Cuts `text` into pieces of at most `max_size` bytes without ever
    /// splitting a UTF-8 character.
    ///
    /// If the next character alone is wider than `max_size` (e.g. a 4-byte
    /// emoji with `max_size < 4`), that character is emitted as its own,
    /// oversized piece. Otherwise no forward progress would be possible.
    fn force_split(&self, text: &str) -> Vec<String> {
        let mut result = Vec::new();
        let mut start = 0;
        while start < text.len() {
            let mut end = (start + self.max_size).min(text.len());
            while !text.is_char_boundary(end) {
                end -= 1;
            }
            if end == start {
                // Backing off reached `start`: take the whole next character.
                // `text.len()` is a boundary, so this loop terminates.
                end = start + 1;
                while !text.is_char_boundary(end) {
                    end += 1;
                }
            }
            result.push(text[start..end].to_string());
            start = end;
        }
        result
    }
}
impl Chunker for RecursiveChunker {
    /// Splits `text` into [`Slab`]s carrying byte offsets into `text`.
    ///
    /// Each slab covers one chunk from the recursive split, extended backwards
    /// by up to `overlap` bytes of preceding text. The extension is trimmed so
    /// the slab never exceeds `max_size` bytes (unless its chunk alone does)
    /// and always starts on a UTF-8 character boundary. Empty input yields no
    /// slabs; slab indices count up from 0.
    fn chunk_bytes(&self, text: &str) -> Vec<Slab> {
        if text.is_empty() {
            return vec![];
        }
        let chunks = self.split_recursive(text, 0);
        let mut slabs = Vec::with_capacity(chunks.len());
        // Chunks concatenate back to `text`, so a running byte cursor yields
        // each chunk's offset, and every cursor position is a char boundary.
        let mut cursor = 0usize;
        for (index, chunk) in chunks.into_iter().enumerate() {
            let start = cursor;
            let end = start + chunk.len();
            cursor = end;
            // Latest of "overlap allows" and "size cap allows", but never past
            // the chunk's own start.
            let mut slab_start = start
                .saturating_sub(self.overlap)
                .max(end.saturating_sub(self.max_size))
                .min(start);
            // Round *forward*: `start` is a char boundary, so this stops at or
            // before it and cannot grow the slab beyond the cap (rounding
            // backward could).
            while !text.is_char_boundary(slab_start) {
                slab_start += 1;
            }
            slabs.push(Slab::new(
                text[slab_start..end].to_string(),
                slab_start,
                end,
                index,
            ));
        }
        slabs
    }

    /// Rough chunk count for `text_len` bytes: the length divided by the
    /// effective stride (`max_size - overlap`, at least 1), floored and never
    /// below 1.
    fn estimate_chunks(&self, text_len: usize) -> usize {
        let step = self.max_size.saturating_sub(self.overlap).max(1);
        (text_len / step).max(1)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Prose separators, tried from paragraph breaks down to single spaces.
    const SEPARATORS: &[&str] = &["\n\n", "\n", ". ", " "];

    /// Builds a chunker over [`SEPARATORS`] with the given byte limit.
    fn prose_chunker(max_size: usize) -> RecursiveChunker {
        RecursiveChunker::new(max_size, SEPARATORS)
    }

    #[test]
    fn test_paragraph_split() {
        let input =
            "Short.\n\nThis is a longer paragraph that might need splitting into smaller pieces.";
        let pieces = prose_chunker(50).chunk(input);
        assert!(pieces.len() > 1);
        let first = &pieces[0];
        assert!(first.text.contains("Short"));
    }

    #[test]
    fn test_respects_max_size() {
        let limit = 20;
        let pieces = prose_chunker(limit).chunk("The quick brown fox jumps over the lazy dog.");
        for piece in &pieces {
            let size = piece.len();
            assert!(size <= limit, "Chunk too large: {} bytes", size);
        }
    }

    #[test]
    fn test_empty_text() {
        let pieces = prose_chunker(100).chunk("");
        assert!(pieces.is_empty());
    }

    #[test]
    fn test_small_text_single_chunk() {
        let pieces = prose_chunker(100).chunk("Small text.");
        assert_eq!(1, pieces.len());
    }

    #[test]
    #[should_panic]
    fn test_zero_size_panics() {
        let _chunker = prose_chunker(0);
    }

    #[test]
    #[should_panic]
    fn test_empty_separators_panics() {
        let _chunker = RecursiveChunker::new(100, &[]);
    }
}