#![allow(clippy::unused_self)]
use crate::types::ConversionOptions;
/// Strategy for splitting a document's text into retrieval-sized chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[allow(missing_docs)]
pub enum ChunkStrategy {
    /// Fixed token budget per chunk, breaking at natural boundaries.
    TokenBased,
    /// Paragraph-aligned chunks grouped up to the token budget.
    Semantic,
    /// Fixed-size window advanced by a constant step, with overlap.
    SlidingWindow,
}
/// A contiguous piece of the source text produced by a `Chunker`.
#[derive(Debug, Clone)]
pub struct TextChunk {
    /// The chunk's text content.
    pub content: String,
    /// Zero-based position of this chunk in the output sequence.
    pub index: usize,
    /// Byte offset into the original text where the chunk starts.
    pub start_offset: usize,
    /// Byte offset into the original text where the chunk ends (exclusive).
    pub end_offset: usize,
    /// Estimated token count of `content` (heuristic: ~4 bytes per token).
    pub token_count: usize,
}
/// Splits text into chunks using a configurable strategy.
#[derive(Debug)]
pub struct Chunker {
    /// Which chunking algorithm to apply.
    strategy: ChunkStrategy,
    /// Maximum chunk size, in estimated tokens.
    max_chunk_size: usize,
    /// Desired overlap between consecutive chunks, in estimated tokens.
    overlap: usize,
}
impl Chunker {
    /// Creates a chunker with an explicit strategy, token budget, and overlap.
    ///
    /// `max_chunk_size` and `overlap` are measured in estimated tokens
    /// (see [`Chunker::estimate_tokens`]).
    pub fn new(strategy: ChunkStrategy, max_chunk_size: usize, overlap: usize) -> Self {
        Self {
            strategy,
            max_chunk_size,
            overlap,
        }
    }

    /// Builds a chunker from conversion options: semantic strategy, the
    /// options' chunk-size budget, and a fixed 200-token overlap.
    pub fn from_options(options: &ConversionOptions) -> Self {
        Self::new(ChunkStrategy::Semantic, options.max_chunk_size, 200)
    }

    /// Splits `text` into chunks according to the configured strategy.
    pub fn chunk(&self, text: &str) -> Vec<TextChunk> {
        match self.strategy {
            ChunkStrategy::TokenBased => self.chunk_token_based(text),
            ChunkStrategy::Semantic => self.chunk_semantic(text),
            ChunkStrategy::SlidingWindow => self.chunk_sliding_window(text),
        }
    }

    /// Budget-bound chunks that prefer natural break points (sentence end,
    /// paragraph break, newline, space), with consecutive chunks overlapping
    /// by roughly `self.overlap` tokens.
    fn chunk_token_based(&self, text: &str) -> Vec<TextChunk> {
        let mut chunks = Vec::new();
        // Heuristic: ~4 bytes per token, matching `estimate_tokens`.
        let char_limit = self.max_chunk_size * 4;
        let overlap_chars = self.overlap * 4;
        let mut current_pos = 0;
        let mut chunk_index = 0;
        while current_pos < text.len() {
            // Snap to a char boundary so slicing can't panic on multi-byte UTF-8.
            let end_pos =
                Self::floor_char_boundary(text, (current_pos + char_limit).min(text.len()));
            let chunk_end = if end_pos < text.len() {
                self.find_break_point(text, current_pos, end_pos)
            } else {
                end_pos
            };
            // Guarantee forward progress even for degenerate budgets.
            let chunk_end = if chunk_end <= current_pos {
                Self::ceil_char_boundary(text, current_pos + 1)
            } else {
                chunk_end
            };
            let content = text[current_pos..chunk_end].to_string();
            let token_count = self.estimate_tokens(&content);
            chunks.push(TextChunk {
                content,
                index: chunk_index,
                start_offset: current_pos,
                end_offset: chunk_end,
                token_count,
            });
            // Step back by the overlap; if the whole chunk is smaller than the
            // overlap, skip the overlap rather than re-emitting the same span.
            let next_pos = if chunk_end > overlap_chars {
                chunk_end - overlap_chars
            } else {
                chunk_end
            };
            let next_pos = Self::floor_char_boundary(text, next_pos);
            current_pos = if next_pos > current_pos {
                next_pos
            } else {
                Self::ceil_char_boundary(text, current_pos + 1)
            };
            chunk_index += 1;
        }
        chunks
    }

    /// Paragraph-aligned chunks: paragraphs (split on "\n\n") are grouped
    /// until adding one more would exceed the token budget.
    fn chunk_semantic(&self, text: &str) -> Vec<TextChunk> {
        let mut chunks = Vec::new();
        let mut current_chunk = String::new();
        let mut chunk_start = 0;
        let mut chunk_index = 0;
        // Byte offset of the current paragraph in `text`; tracked explicitly so
        // start/end offsets account for the "\n\n" separators that `split`
        // drops (the previous version let offsets drift by 2 per boundary).
        let mut offset = 0;
        for paragraph in text.split("\n\n") {
            let paragraph_tokens = self.estimate_tokens(paragraph);
            let current_tokens = self.estimate_tokens(&current_chunk);
            if current_tokens + paragraph_tokens > self.max_chunk_size && !current_chunk.is_empty()
            {
                let chunk_end = chunk_start + current_chunk.len();
                chunks.push(TextChunk {
                    // Trimmed, so `content` may be shorter than the offset span.
                    content: current_chunk.trim().to_string(),
                    index: chunk_index,
                    start_offset: chunk_start,
                    end_offset: chunk_end,
                    token_count: current_tokens,
                });
                chunk_index += 1;
                // The next chunk begins at this (not yet appended) paragraph.
                chunk_start = offset;
                current_chunk.clear();
            }
            if !current_chunk.is_empty() {
                current_chunk.push_str("\n\n");
            }
            current_chunk.push_str(paragraph);
            offset += paragraph.len() + 2; // +2 for the "\n\n" separator
        }
        // Flush whatever is left after the final paragraph.
        if !current_chunk.is_empty() {
            let chunk_end = chunk_start + current_chunk.len();
            let token_count = self.estimate_tokens(&current_chunk);
            chunks.push(TextChunk {
                content: current_chunk.trim().to_string(),
                index: chunk_index,
                start_offset: chunk_start,
                end_offset: chunk_end,
                token_count,
            });
        }
        chunks
    }

    /// Fixed-size window advanced by a constant step, ignoring text structure.
    fn chunk_sliding_window(&self, text: &str) -> Vec<TextChunk> {
        let mut chunks = Vec::new();
        let char_limit = self.max_chunk_size * 4;
        // saturating_sub prevents a usize underflow panic when the overlap is
        // at least as large as the window; max(1) guarantees forward progress.
        let step_size = char_limit.saturating_sub(self.overlap * 4).max(1);
        let mut chunk_index = 0;
        let mut current_pos = 0;
        while current_pos < text.len() {
            // Snap to char boundaries so slicing can't panic on multi-byte UTF-8.
            let end_pos =
                Self::floor_char_boundary(text, (current_pos + char_limit).min(text.len()));
            let end_pos = if end_pos <= current_pos {
                Self::ceil_char_boundary(text, current_pos + 1)
            } else {
                end_pos
            };
            let content = text[current_pos..end_pos].to_string();
            let token_count = self.estimate_tokens(&content);
            chunks.push(TextChunk {
                content,
                index: chunk_index,
                start_offset: current_pos,
                end_offset: end_pos,
                token_count,
            });
            if end_pos >= text.len() {
                break;
            }
            current_pos = Self::ceil_char_boundary(text, current_pos + step_size);
            chunk_index += 1;
        }
        chunks
    }

    /// Finds the best byte offset in `text[start..preferred_end]` to end a
    /// chunk: sentence end, then paragraph break, then newline, then space,
    /// falling back to `preferred_end`. All searched patterns are ASCII, so
    /// every returned offset is a valid char boundary.
    fn find_break_point(&self, text: &str, start: usize, preferred_end: usize) -> usize {
        let search_text = &text[start..preferred_end];
        if let Some(pos) = search_text.rfind(". ") {
            return start + pos + 1; // keep the '.' in this chunk
        }
        if let Some(pos) = search_text.rfind("\n\n") {
            return start + pos + 2; // keep the blank line in this chunk
        }
        if let Some(pos) = search_text.rfind('\n') {
            return start + pos + 1;
        }
        if let Some(pos) = search_text.rfind(' ') {
            return start + pos + 1;
        }
        preferred_end
    }

    /// Rough token estimate: ~4 bytes per token, rounded up.
    fn estimate_tokens(&self, text: &str) -> usize {
        (text.len() + 3) / 4
    }

    /// Largest index `<= idx` (clamped to `text.len()`) that is a valid UTF-8
    /// char boundary of `text`.
    fn floor_char_boundary(text: &str, idx: usize) -> usize {
        let mut idx = idx.min(text.len());
        while !text.is_char_boundary(idx) {
            idx -= 1;
        }
        idx
    }

    /// Smallest index `>= idx` (clamped to `text.len()`) that is a valid UTF-8
    /// char boundary of `text`.
    fn ceil_char_boundary(text: &str, idx: usize) -> usize {
        let mut idx = idx.min(text.len());
        while !text.is_char_boundary(idx) {
            idx += 1;
        }
        idx
    }
}
impl Default for Chunker {
fn default() -> Self {
Self::new(ChunkStrategy::Semantic, 2048, 200)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_token_based_chunking() {
        let chunker = Chunker::new(ChunkStrategy::TokenBased, 100, 20);
        let input = "This is a test. ".repeat(50);
        let produced = chunker.chunk(&input);
        assert!(!produced.is_empty());
        // Break-point search can push a chunk past the nominal budget; allow slack.
        assert!(produced.iter().all(|c| c.token_count <= 150));
    }

    #[test]
    fn test_semantic_chunking() {
        let chunker = Chunker::new(ChunkStrategy::Semantic, 100, 0);
        let input = "Paragraph 1. This is the first paragraph.\n\n\
                     Paragraph 2. This is the second paragraph.\n\n\
                     Paragraph 3. This is the third paragraph.";
        let produced = chunker.chunk(input);
        assert!(!produced.is_empty());
        assert!(produced.iter().all(|c| !c.content.trim().is_empty()));
    }

    #[test]
    fn test_sliding_window_chunking() {
        let chunker = Chunker::new(ChunkStrategy::SlidingWindow, 50, 10);
        let input = "word ".repeat(200);
        let produced = chunker.chunk(&input);
        assert!(produced.len() > 1);
        // Each window should share at least one word with its predecessor's tail.
        for pair in produced.windows(2) {
            let tail_from = pair[0].content.len().saturating_sub(50);
            let tail = &pair[0].content[tail_from..];
            let head = &pair[1].content[..50.min(pair[1].content.len())];
            assert!(
                tail.split_whitespace().any(|word| head.contains(word)),
                "Expected overlap between chunks"
            );
        }
    }

    #[test]
    fn test_chunk_metadata() {
        let chunker = Chunker::default();
        let input = "Test content. ".repeat(100);
        for (expected, chunk) in chunker.chunk(&input).iter().enumerate() {
            assert_eq!(chunk.index, expected);
            assert!(chunk.end_offset > chunk.start_offset);
            assert!(chunk.token_count > 0);
        }
    }

    #[test]
    fn test_empty_text() {
        let produced = Chunker::default().chunk("");
        assert!(produced.len() <= 1);
    }

    #[test]
    fn test_token_estimation() {
        let estimate = Chunker::default().estimate_tokens("This is a test");
        assert!((3..=5).contains(&estimate));
    }
}