use std::collections::HashMap;
/// A piece of text produced by a [`ChunkStrategy`], plus the position the
/// producing strategy reported for it.
#[derive(Debug, Clone)]
pub struct RichChunk {
    /// The chunk's text.
    pub text: String,
    /// Start offset, in `char`s, as reported by the producing strategy.
    pub char_start: usize,
    /// Exclusive end offset, in `char`s, as reported by the producing strategy.
    pub char_end: usize,
    /// Zero-based position of this chunk in the strategy's output.
    pub chunk_index: usize,
    /// Free-form key/value annotations; empty when the chunk is constructed.
    pub metadata: HashMap<String, String>,
}

impl RichChunk {
    /// Builds a chunk with empty metadata.
    pub fn new(text: String, start: usize, end: usize, index: usize) -> Self {
        let metadata = HashMap::new();
        Self {
            chunk_index: index,
            char_start: start,
            char_end: end,
            metadata,
            text,
        }
    }

    /// Number of `char`s (Unicode scalar values) in the text, not its byte length.
    pub fn len(&self) -> usize {
        let chars = self.text.chars();
        chars.count()
    }

    /// `true` when the text contains no characters at all.
    pub fn is_empty(&self) -> bool {
        self.text.is_empty()
    }

    /// Number of whitespace-separated words in the text.
    pub fn word_count(&self) -> usize {
        let words = self.text.split_whitespace();
        words.count()
    }
}
/// A text-chunking algorithm that can be looked up by name.
///
/// Implementors must be `Send + Sync`.
pub trait ChunkStrategy: Send + Sync {
    /// Identifier of the strategy; `ChunkerRegistry::register` uses it as the key.
    fn name(&self) -> &'static str;

    /// Splits `text` into an ordered list of chunks.
    fn chunk(&self, text: &str) -> Vec<RichChunk>;
}
/// Groups whole sentences into chunks bounded by `max_chars`.
pub struct SentenceChunker {
    /// Target maximum chunk length. A single sentence longer than this still
    /// becomes a chunk on its own.
    pub max_chars: usize,
    /// How many sentences at the end of one chunk are repeated at the start of
    /// the next one (0 disables overlap).
    pub overlap_sentences: usize,
}

impl SentenceChunker {
    /// Creates a chunker with no sentence overlap.
    pub fn new(max_chars: usize) -> Self {
        Self {
            overlap_sentences: 0,
            max_chars,
        }
    }

    /// Builder-style setter for [`SentenceChunker::overlap_sentences`].
    pub fn with_overlap(self, sentences: usize) -> Self {
        Self {
            overlap_sentences: sentences,
            ..self
        }
    }
}
impl ChunkStrategy for SentenceChunker {
    fn name(&self) -> &'static str {
        "sentence"
    }

    /// Greedily packs consecutive sentences (as found by `split_sentences`) into
    /// chunks whose joined length, in `char`s, stays within `max_chars`.
    ///
    /// * Sentences in a chunk are joined with a single space, so whitespace
    ///   between sentences is normalised in the chunk text.
    /// * `char_start`/`char_end` are `char` offsets into `text`. They run from the
    ///   first character of the chunk's first sentence to the end of its last
    ///   sentence. Because of the normalisation above, `char_end - char_start` can
    ///   exceed the chunk's own length when sentences were separated by more than
    ///   one whitespace character.
    /// * A sentence longer than `max_chars` is emitted on its own.
    /// * With `overlap_sentences > 0`, the next chunk restarts `overlap_sentences`
    ///   sentences before the end of the previous one. If the previous chunk held
    ///   no more sentences than that, it advances by one sentence instead. Once the
    ///   last sentence has been emitted, no further chunks are produced.
    fn chunk(&self, text: &str) -> Vec<RichChunk> {
        if text.is_empty() {
            return Vec::new();
        }
        let sentences = split_sentences(text);
        if sentences.is_empty() {
            return Vec::new();
        }

        // (char_start, char_end) of each sentence inside `text`, computed in one
        // forward pass. Sentences are trimmed, in-order slices of `text` separated
        // only by whitespace, so each one is found exactly at its true position
        // when searching from the end of the previous one.
        let mut spans: Vec<(usize, usize)> = Vec::with_capacity(sentences.len());
        let mut byte_pos = 0usize;
        let mut char_pos = 0usize;
        for &sentence in &sentences {
            let rel = text[byte_pos..]
                .find(sentence)
                .expect("split_sentences returns in-order slices of its input");
            let byte_start = byte_pos + rel;
            char_pos += text[byte_pos..byte_start].chars().count();
            let start = char_pos;
            char_pos += sentence.chars().count();
            byte_pos = byte_start + sentence.len();
            spans.push((start, char_pos));
        }
        // Length of sentence `k` in chars (not bytes, so non-ASCII text is measured correctly).
        let width = |k: usize| spans[k].1 - spans[k].0;

        let mut chunks: Vec<RichChunk> = Vec::new();
        let mut i = 0usize;
        while i < sentences.len() {
            // The first sentence is always taken, even if it alone exceeds `max_chars`.
            let mut total_chars = width(i);
            let mut j = i + 1;
            while j < sentences.len() {
                // +1 for the joining space.
                let added = width(j) + 1;
                if total_chars + added > self.max_chars {
                    break;
                }
                total_chars += added;
                j += 1;
            }

            let chunk_text = sentences[i..j].join(" ");
            chunks.push(RichChunk::new(
                chunk_text,
                spans[i].0,
                spans[j - 1].1,
                chunks.len(),
            ));

            if j >= sentences.len() {
                // Everything is covered; stepping back for overlap would only
                // produce a chunk that is wholly contained in this one.
                break;
            }
            let consumed = j - i;
            i += if consumed > self.overlap_sentences {
                consumed - self.overlap_sentences
            } else {
                1
            };
        }
        chunks
    }
}
pub struct RecursiveCharSplitter {
pub max_chars: usize,
pub overlap: usize,
pub separators: Vec<String>,
}
impl RecursiveCharSplitter {
pub fn new(max_chars: usize) -> Self {
Self {
max_chars,
overlap: 0,
separators: Self::default_separators(),
}
}
pub fn with_overlap(mut self, overlap: usize) -> Self {
self.overlap = overlap;
self
}
pub fn with_separators(mut self, seps: Vec<String>) -> Self {
self.separators = seps;
self
}
pub fn default_separators() -> Vec<String> {
vec![
"\n\n".to_string(),
"\n".to_string(),
" ".to_string(),
"".to_string(),
]
}
fn split_recursive(&self, text: &str, seps: &[String]) -> Vec<String> {
if text.chars().count() <= self.max_chars {
return vec![text.to_string()];
}
let sep = match seps.first() {
Some(s) => s,
None => {
return split_by_chars(text, self.max_chars, self.overlap);
}
};
let remaining_seps = &seps[1..];
if sep.is_empty() {
return split_by_chars(text, self.max_chars, self.overlap);
}
let parts: Vec<&str> = text.split(sep.as_str()).collect();
let mut result: Vec<String> = Vec::new();
let mut current_group: Vec<&str> = Vec::new();
let mut current_len = 0usize;
for part in &parts {
let part_len = part.chars().count();
let sep_len = if current_group.is_empty() {
0
} else {
sep.chars().count()
};
if current_len + sep_len + part_len > self.max_chars && !current_group.is_empty() {
let joined = current_group.join(sep.as_str());
if joined.chars().count() > self.max_chars {
let sub = self.split_recursive(&joined, remaining_seps);
result.extend(sub);
} else {
result.push(joined);
}
if self.overlap > 0 {
let mut overlap_items: Vec<&str> = Vec::new();
let mut overlap_len = 0usize;
for &item in current_group.iter().rev() {
let item_len = item.chars().count() + sep.chars().count();
if overlap_len + item_len > self.overlap {
break;
}
overlap_items.push(item);
overlap_len += item_len;
}
overlap_items.reverse();
current_group = overlap_items;
current_len = current_group
.iter()
.map(|s| s.chars().count())
.sum::<usize>()
+ if current_group.len() > 1 {
(current_group.len() - 1) * sep.chars().count()
} else {
0
};
} else {
current_group.clear();
current_len = 0;
}
}
if part_len > self.max_chars {
if !current_group.is_empty() {
result.push(current_group.join(sep.as_str()));
current_group.clear();
current_len = 0;
}
let sub = self.split_recursive(part, remaining_seps);
result.extend(sub);
} else {
let sep_add = if current_group.is_empty() {
0
} else {
sep.chars().count()
};
current_len += sep_add + part_len;
current_group.push(part);
}
}
if !current_group.is_empty() {
let joined = current_group.join(sep.as_str());
if joined.chars().count() > self.max_chars {
let sub = self.split_recursive(&joined, remaining_seps);
result.extend(sub);
} else {
result.push(joined);
}
}
result
}
}
impl ChunkStrategy for RecursiveCharSplitter {
    fn name(&self) -> &'static str {
        "recursive"
    }

    /// Splits `text` with `split_recursive`, then locates each piece in `text` to
    /// fill in `char_start`/`char_end` (in `char`s).
    ///
    /// Empty pieces are skipped. Pieces are located by forward search. When
    /// `overlap > 0`, consecutive pieces may share up to `overlap` characters, so
    /// the search for the next piece resumes `overlap` characters before the end
    /// of the current one rather than at its end. If a piece cannot be found, it
    /// is assigned the current search position.
    fn chunk(&self, text: &str) -> Vec<RichChunk> {
        if text.is_empty() {
            return Vec::new();
        }
        // Borrowing `self.separators` alongside `&self` is fine; no clone needed.
        let pieces = self.split_recursive(text, &self.separators);
        let text_chars: Vec<char> = text.chars().collect();
        let total_chars = text_chars.len();
        let mut chunks: Vec<RichChunk> = Vec::with_capacity(pieces.len());
        let mut search_from = 0usize;
        for piece in pieces {
            if piece.is_empty() {
                continue;
            }
            let piece_len = piece.chars().count();
            let start = find_substring_char_offset(&text_chars, &piece, search_from)
                .unwrap_or(search_from);
            let end = (start + piece_len).min(total_chars);
            chunks.push(RichChunk::new(piece, start, end, chunks.len()));
            // The next piece may begin up to `overlap` chars before this one ends.
            search_from = start + piece_len.saturating_sub(self.overlap);
        }
        chunks
    }
}
/// Fixed-size character windows advanced by a fixed step.
pub struct SlidingWindowChunker {
    /// Window length in `char`s.
    pub window_size: usize,
    /// Distance in `char`s between consecutive window starts. The constructors
    /// clamp it to at least 1.
    pub step_size: usize,
}

impl SlidingWindowChunker {
    /// Creates a chunker; a `step_size` of 0 is raised to 1.
    pub fn new(window_size: usize, step_size: usize) -> Self {
        let step_size = if step_size == 0 { 1 } else { step_size };
        Self {
            window_size,
            step_size,
        }
    }

    /// Windows that tile the text without overlap (step equals window size).
    pub fn non_overlapping(size: usize) -> Self {
        Self::new(size, size)
    }

    /// Windows that overlap by about half: the step is `size / 2`, but at least 1.
    pub fn with_50pct_overlap(size: usize) -> Self {
        Self::new(size, std::cmp::max(size / 2, 1))
    }
}
impl ChunkStrategy for SlidingWindowChunker {
    fn name(&self) -> &'static str {
        "sliding_window"
    }

    /// Emits windows of up to `window_size` chars, starting every `step_size`
    /// chars. It stops after the first window that reaches the end of `text`.
    ///
    /// `char_start`/`char_end` are exact `char` offsets into `text`. If
    /// `step_size` exceeds `window_size`, the characters between windows are not
    /// covered. Empty `text` or a zero `window_size` yields no chunks.
    fn chunk(&self, text: &str) -> Vec<RichChunk> {
        if text.is_empty() || self.window_size == 0 {
            return Vec::new();
        }
        // `step_size` is a public field, so it can be 0 even though the
        // constructors forbid it. Clamp here so the loop always makes progress
        // instead of spinning forever.
        let step = self.step_size.max(1);
        let chars: Vec<char> = text.chars().collect();
        let total = chars.len();
        let mut chunks: Vec<RichChunk> = Vec::new();
        let mut start = 0usize;
        while start < total {
            // Saturating so a huge `window_size` cannot overflow.
            let end = start.saturating_add(self.window_size).min(total);
            let chunk_text: String = chars[start..end].iter().collect();
            chunks.push(RichChunk::new(chunk_text, start, end, chunks.len()));
            if end == total {
                break;
            }
            start += step;
        }
        chunks
    }
}
/// Splits Markdown into sections at ATX headings (`#` to `######`). Sections
/// longer than `max_chars` are split further with [`RecursiveCharSplitter`].
pub struct MarkdownChunker {
    /// Maximum section length in `char`s before a section is split further.
    pub max_chars: usize,
    /// Heading-level threshold consulted when deciding which headings start a new
    /// section. NOTE(review): the default of 1 currently lets every level (1–6)
    /// start a section; confirm the intended meaning of this field.
    pub min_heading_level: u8,
}

impl MarkdownChunker {
    /// Creates a chunker with `min_heading_level` set to 1.
    pub fn new(max_chars: usize) -> Self {
        Self {
            max_chars,
            min_heading_level: 1,
        }
    }

    /// Returns the ATX heading level (1–6) of `line`, or 0 if it is not a heading.
    ///
    /// Leading whitespace is ignored. The run of `#`s must be 1–6 long and be
    /// followed by a space, a tab, or the end of the line. CommonMark accepts
    /// either spaces or tabs there. So `#hashtag` and `####### x` are not headings.
    fn heading_level(line: &str) -> u8 {
        let trimmed = line.trim_start();
        let hashes = trimmed.bytes().take_while(|&b| b == b'#').count();
        if hashes == 0 || hashes > 6 {
            return 0;
        }
        // `#` is ASCII, so `hashes` is a valid char boundary.
        match trimmed[hashes..].chars().next() {
            // `hashes` is 1..=6 here, so the cast cannot truncate.
            None | Some(' ') | Some('\t') => hashes as u8,
            _ => 0,
        }
    }
}
impl ChunkStrategy for MarkdownChunker {
    fn name(&self) -> &'static str {
        "markdown"
    }

    /// Splits `text` into sections that each begin at a heading line. Each
    /// section becomes one chunk or, if it exceeds `max_chars`, the chunks
    /// produced by a `RecursiveCharSplitter` with the same `max_chars`. Those
    /// chunks' offsets are shifted by the section's start.
    ///
    /// * Lines are read with `str::lines` and sections are re-joined with `"\n"`.
    ///   Offsets therefore assume `"\n"` line endings; `"\r\n"` input shifts
    ///   later offsets.
    /// * Lines inside fenced code blocks (```` ``` ```` or `~~~`) never start a
    ///   section. For example, `# comment` in a shell or Python snippet is not a
    ///   heading.
    /// * Whitespace-only sections produce no chunk but still advance the offsets.
    fn chunk(&self, text: &str) -> Vec<RichChunk> {
        // Fence character ('`' or '~') if `line` opens or closes a fenced code block.
        fn fence_marker(line: &str) -> Option<char> {
            let trimmed = line.trim_start();
            if trimmed.starts_with("```") {
                Some('`')
            } else if trimmed.starts_with("~~~") {
                Some('~')
            } else {
                None
            }
        }

        if text.is_empty() {
            return Vec::new();
        }
        // `saturating_add` avoids a u8 overflow panic for very large `min_heading_level`.
        let max_split_level = self.min_heading_level.saturating_add(5);

        let mut sections: Vec<String> = Vec::new();
        let mut current: Vec<&str> = Vec::new();
        // Marker of the fenced code block we are currently inside, if any.
        let mut open_fence: Option<char> = None;
        for line in text.lines() {
            let level = Self::heading_level(line);
            let starts_section = open_fence.is_none() && level > 0 && level <= max_split_level;
            if starts_section && !current.is_empty() {
                sections.push(current.join("\n"));
                current.clear();
            }
            if let Some(marker) = fence_marker(line) {
                open_fence = match open_fence {
                    None => Some(marker),
                    Some(open) if open == marker => None,
                    other => other,
                };
            }
            current.push(line);
        }
        if !current.is_empty() {
            sections.push(current.join("\n"));
        }

        let splitter = RecursiveCharSplitter::new(self.max_chars);
        let mut chunks: Vec<RichChunk> = Vec::new();
        let mut char_cursor = 0usize;
        for section in &sections {
            let section_len = section.chars().count();
            if section.trim().is_empty() {
                // Nothing to emit; the offset still advances below.
            } else if section_len <= self.max_chars {
                let start = char_cursor;
                let end = start + section_len;
                chunks.push(RichChunk::new(section.clone(), start, end, chunks.len()));
            } else {
                for mut sc in splitter.chunk(section) {
                    sc.char_start += char_cursor;
                    sc.char_end += char_cursor;
                    sc.chunk_index = chunks.len();
                    chunks.push(sc);
                }
            }
            // +1 for the "\n" that separated this section from the next one in `text`.
            char_cursor += section_len + 1;
        }
        chunks
    }
}
/// Name-indexed collection of chunking strategies.
pub struct ChunkerRegistry {
    /// Strategies keyed by the value of their `name()` method.
    strategies: HashMap<String, Box<dyn ChunkStrategy>>,
}

impl ChunkerRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let strategies = HashMap::new();
        Self { strategies }
    }

    /// Adds `strategy` under its own name. Any strategy already registered under
    /// that name is replaced.
    pub fn register(&mut self, strategy: Box<dyn ChunkStrategy>) {
        let key = String::from(strategy.name());
        self.strategies.insert(key, strategy);
    }

    /// Runs the strategy registered as `name` on `text`. Returns `None` if no
    /// such strategy exists.
    pub fn chunk(&self, name: &str, text: &str) -> Option<Vec<RichChunk>> {
        let strategy = self.strategies.get(name)?;
        Some(strategy.chunk(text))
    }

    /// Registered strategy names in ascending order.
    pub fn available_strategies(&self) -> Vec<&str> {
        let mut names: Vec<&str> = self.strategies.keys().map(String::as_str).collect();
        // Keys are unique, so an unstable sort gives the same order as a stable one.
        names.sort_unstable();
        names
    }

    /// Registry pre-loaded with the sentence, recursive, sliding-window
    /// (non-overlapping) and markdown strategies, each configured with 512.
    pub fn default_registry() -> Self {
        let builtins: Vec<Box<dyn ChunkStrategy>> = vec![
            Box::new(SentenceChunker::new(512)),
            Box::new(RecursiveCharSplitter::new(512)),
            Box::new(SlidingWindowChunker::non_overlapping(512)),
            Box::new(MarkdownChunker::new(512)),
        ];
        let mut registry = Self::new();
        for strategy in builtins {
            registry.register(strategy);
        }
        registry
    }
}

/// The default registry is *empty*. Use [`ChunkerRegistry::default_registry`]
/// for one that contains the built-in strategies.
impl Default for ChunkerRegistry {
    fn default() -> Self {
        Self::new()
    }
}
/// Splits `text` into trimmed sentences.
///
/// A sentence ends at a run of `.`, `!` or `?` that is followed by ASCII
/// whitespace (space, tab, `\n`, `\r`) or by the end of the text. Punctuation
/// inside a token (e.g. `3.14`) does not end a sentence. Any trailing text
/// without terminal punctuation becomes the last sentence. Empty pieces are
/// dropped. Every returned item is a slice of `text`.
fn split_sentences(text: &str) -> Vec<&str> {
    let is_terminator = |b: u8| matches!(b, b'.' | b'!' | b'?');
    let is_gap = |b: u8| matches!(b, b' ' | b'\t' | b'\n' | b'\r');
    // Scanning bytes is safe for slicing: the bytes tested here are all ASCII,
    // and ASCII bytes never occur inside a multi-byte UTF-8 sequence.
    let raw = text.as_bytes();
    let n = raw.len();

    let mut out: Vec<&str> = Vec::new();
    let mut seg_start = 0usize;
    let mut pos = 0usize;
    while pos < n {
        if !is_terminator(raw[pos]) {
            pos += 1;
            continue;
        }
        // Swallow a run of terminators such as "?!" or "...".
        let mut cut = pos + 1;
        while cut < n && is_terminator(raw[cut]) {
            cut += 1;
        }
        if cut < n && !is_gap(raw[cut]) {
            // Punctuation inside a token; keep scanning from the next byte.
            pos += 1;
            continue;
        }
        while cut < n && is_gap(raw[cut]) {
            cut += 1;
        }
        let piece = text[seg_start..cut].trim();
        if !piece.is_empty() {
            out.push(piece);
        }
        seg_start = cut;
        pos = cut;
    }

    let rest = text[seg_start..].trim();
    if !rest.is_empty() {
        out.push(rest);
    }
    out
}
/// Cuts `text` into windows of up to `max_chars` chars whose starts are
/// `max_chars - overlap` chars apart (at least 1). It stops after the first
/// window that reaches the end of the text. Empty input yields no windows.
fn split_by_chars(text: &str, max_chars: usize, overlap: usize) -> Vec<String> {
    let glyphs: Vec<char> = text.chars().collect();
    if glyphs.is_empty() {
        return Vec::new();
    }
    let stride = max_chars.saturating_sub(overlap).max(1);
    let mut pieces = Vec::new();
    for offset in (0..glyphs.len()).step_by(stride) {
        let stop = glyphs.len().min(offset + max_chars);
        pieces.push(glyphs[offset..stop].iter().collect::<String>());
        if stop == glyphs.len() {
            break;
        }
    }
    pieces
}
/// Returns the first `char` index `>= from` at which `needle` occurs in
/// `haystack`, or `None`. An empty needle matches immediately at `from`, even
/// if `from` is past the end.
fn find_substring_char_offset(haystack: &[char], needle: &str, from: usize) -> Option<usize> {
    let pattern: Vec<char> = needle.chars().collect();
    if pattern.is_empty() {
        return Some(from);
    }
    // `windows` yields nothing when the needle is longer than the haystack.
    haystack
        .windows(pattern.len())
        .enumerate()
        .skip(from)
        .find(|(_, window)| *window == pattern.as_slice())
        .map(|(idx, _)| idx)
}
#[cfg(test)]
mod inline_tests {
    use super::*;

    #[test]
    fn rich_chunk_len_and_words() {
        let chunk = RichChunk::new("hello world".to_string(), 0, 11, 0);
        assert_eq!(chunk.len(), 11);
        assert_eq!(chunk.word_count(), 2);
        assert!(!chunk.is_empty());
        assert!(chunk.metadata.is_empty());
    }

    #[test]
    fn sentence_chunker_keeps_short_text_whole() {
        let text = "Hello world. This is a test. Another sentence here.";
        let chunks = SentenceChunker::new(200).chunk(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, text);
        assert_eq!((chunks[0].char_start, chunks[0].char_end), (0, 51));
    }

    #[test]
    fn sentence_chunker_splits() {
        // A limit of 20 fits only one of these sentences per chunk.
        let text = "Hello world. This is a test. Another sentence here.";
        let chunks = SentenceChunker::new(20).chunk(text);
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        assert_eq!(texts, ["Hello world.", "This is a test.", "Another sentence here."]);
        assert_eq!((chunks[1].char_start, chunks[1].char_end), (13, 28));
        assert_eq!((chunks[2].char_start, chunks[2].char_end), (29, 51));
        assert_eq!(chunks[2].chunk_index, 2);
    }

    #[test]
    fn recursive_splitter_short() {
        let splitter = RecursiveCharSplitter::new(1000);
        let text = "short text";
        let chunks = splitter.chunk(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, text);
        assert_eq!((chunks[0].char_start, chunks[0].char_end), (0, 10));
    }

    #[test]
    fn recursive_splitter_respects_max_chars() {
        let text = "alpha beta gamma delta epsilon";
        let chunks = RecursiveCharSplitter::new(10).chunk(text);
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        assert_eq!(texts, ["alpha beta", "gamma", "delta", "epsilon"]);
        for c in &chunks {
            assert!(c.len() <= 10);
            let span: String = text
                .chars()
                .skip(c.char_start)
                .take(c.char_end - c.char_start)
                .collect();
            assert_eq!(span, c.text);
        }
    }

    #[test]
    fn sliding_window_basic() {
        let chunker = SlidingWindowChunker::non_overlapping(5);
        let chunks = chunker.chunk("abcdefghij");
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].text, "abcde");
        assert_eq!(chunks[1].text, "fghij");
        assert_eq!((chunks[1].char_start, chunks[1].char_end), (5, 10));
    }

    #[test]
    fn markdown_splits_on_headings() {
        let chunks = MarkdownChunker::new(100).chunk("# A\nintro\n## B\nbody");
        let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
        assert_eq!(texts, ["# A\nintro", "## B\nbody"]);
        assert_eq!((chunks[1].char_start, chunks[1].char_end), (10, 19));
    }

    #[test]
    fn registry_lists_builtin_strategies() {
        let registry = ChunkerRegistry::default_registry();
        assert_eq!(
            registry.available_strategies(),
            ["markdown", "recursive", "sentence", "sliding_window"]
        );
        assert!(registry.chunk("missing", "text").is_none());
        assert!(ChunkerRegistry::default().available_strategies().is_empty());
    }
}