use crate::{Document, Result};
/// A contiguous piece of a document produced by `DocumentChunker`.
#[derive(Debug, Clone)]
pub struct DocumentChunk {
/// Unique identifier of the chunk, formatted as `chunk_{index}`.
pub id: String,
/// The chunk text (whitespace-normalized: tokens joined by single spaces).
pub content: String,
/// Number of whitespace-separated tokens in `content`.
pub tokens: usize,
/// Pages this chunk spans; empty when chunking plain text without page info.
pub page_numbers: Vec<usize>,
/// Zero-based position of this chunk within the document's chunk sequence.
pub chunk_index: usize,
/// Positional and quality metadata for the chunk.
pub metadata: ChunkMetadata,
}
/// Quality and position metadata attached to each `DocumentChunk`.
#[derive(Debug, Clone, Default)]
pub struct ChunkMetadata {
/// Character span and page range of the chunk within the source text.
pub position: ChunkPosition,
/// Confidence in the chunking, in `[0.0, 1.0]`; the chunker currently
/// always emits 1.0 (deterministic splitting).
pub confidence: f32,
/// Whether the chunk was cut at sentence-ending punctuation (`.`, `!`, `?`).
pub sentence_boundary_respected: bool,
}
/// Character span and page range of a chunk.
#[derive(Debug, Clone, Default)]
pub struct ChunkPosition {
/// Offset of the chunk's first character.
pub start_char: usize,
/// Offset one past the chunk's last character.
pub end_char: usize,
/// First page the chunk touches (0 when no page info is available).
pub first_page: usize,
/// Last page the chunk touches (0 when no page info is available).
pub last_page: usize,
}
/// Splits document text into overlapping, token-bounded chunks.
#[derive(Debug, Clone)]
pub struct DocumentChunker {
// Maximum number of whitespace tokens per chunk.
chunk_size: usize,
// Number of tokens repeated between consecutive chunks for context.
overlap: usize,
}
impl DocumentChunker {
    /// Creates a chunker producing chunks of at most `chunk_size` whitespace
    /// tokens, with `overlap` tokens repeated between consecutive chunks.
    ///
    /// A `chunk_size` of 0 is clamped to 1 so the chunking loop always makes
    /// progress (it previously looped forever emitting empty chunks).
    pub fn new(chunk_size: usize, overlap: usize) -> Self {
        Self {
            chunk_size: chunk_size.max(1),
            overlap,
        }
    }

    /// Default configuration: 512-token chunks with a 50-token overlap.
    // NOTE(review): an inherent `default` (clippy: should_implement_trait) is
    // kept instead of `impl Default` to avoid touching the public API surface.
    pub fn default() -> Self {
        Self::new(512, 50)
    }

    /// Extracts the full text of `doc` and chunks it (no page tracking).
    ///
    /// # Errors
    /// Propagates any error from `Document::extract_text`.
    pub fn chunk_document(&self, doc: &Document) -> Result<Vec<DocumentChunk>> {
        let full_text = doc.extract_text()?;
        self.chunk_text(&full_text)
    }

    /// Chunks plain text with no page information; `page_numbers` in the
    /// resulting chunks will be empty and both page fields will be 0.
    pub fn chunk_text(&self, text: &str) -> Result<Vec<DocumentChunk>> {
        self.chunk_text_internal(text, &[], 0)
    }

    /// Chunks text supplied as `(page_number, page_text)` pairs, tracking
    /// which pages each chunk spans.
    ///
    /// Returns an empty vector for empty input (this previously panicked on
    /// an out-of-bounds index).
    pub fn chunk_text_with_pages(
        &self,
        page_texts: &[(usize, String)],
    ) -> Result<Vec<DocumentChunk>> {
        // Fix: indexing the first page number of an empty slice used to panic.
        let first_page = match page_texts.first() {
            Some(&(num, _)) => num,
            None => return Ok(Vec::new()),
        };
        let mut full_text = String::new();
        // Cumulative char offsets marking the end of each page in `full_text`.
        let mut page_boundaries = vec![0];
        for (_page_num, text) in page_texts {
            if !full_text.is_empty() {
                full_text.push_str("\n\n");
            }
            full_text.push_str(text);
            page_boundaries.push(full_text.len());
        }
        self.chunk_text_internal(&full_text, &page_boundaries, first_page)
    }

    /// Core chunking routine.
    ///
    /// `page_boundaries` holds cumulative character offsets marking the end of
    /// each page (empty when no page info exists); `first_page` is the number
    /// of the first page.
    ///
    /// Character positions in the returned metadata refer to the normalized
    /// text (tokens joined by single spaces), so they can undercount relative
    /// to `page_boundaries` when the source contained runs of whitespace.
    // NOTE(review): page attribution is therefore approximate — confirm
    // whether exact source offsets are required by downstream consumers.
    fn chunk_text_internal(
        &self,
        text: &str,
        page_boundaries: &[usize],
        first_page: usize,
    ) -> Result<Vec<DocumentChunk>> {
        let tokens: Vec<&str> = text.split_whitespace().collect();
        if tokens.is_empty() {
            return Ok(Vec::new());
        }
        // Start offset of every token in the normalized (space-joined) text.
        // Precomputing these fixes the previous accumulator-based positions,
        // which assigned non-overlapping (wrong) char ranges to overlapping
        // chunks.
        let mut token_starts = Vec::with_capacity(tokens.len());
        let mut pos = 0usize;
        for t in &tokens {
            token_starts.push(pos);
            pos += t.len() + 1; // +1 for the joining space
        }

        let mut chunks = Vec::new();
        let mut start = 0;
        while start < tokens.len() {
            let mut end = (start + self.chunk_size).min(tokens.len());
            // Try to end the chunk on sentence punctuation by scanning the
            // last (up to) 10 tokens of the window, nearest-to-`end` first.
            // The lower bound is clamped to `start + 1` so a boundary before
            // the chunk can never produce an empty or backwards range (which
            // previously caused empty chunks and an infinite loop).
            let mut sentence_boundary_respected = false;
            if end < tokens.len() {
                let window_start = end.saturating_sub(10).max(start + 1);
                for i in (window_start..end).rev() {
                    let token = tokens[i];
                    if token.ends_with('.') || token.ends_with('!') || token.ends_with('?') {
                        end = i + 1;
                        sentence_boundary_respected = true;
                        break;
                    }
                }
            }

            let chunk_tokens = &tokens[start..end];
            let content = chunk_tokens.join(" ");
            let start_char = token_starts[start];
            let end_char = token_starts[end - 1] + tokens[end - 1].len();

            let (page_nums, first_pg, last_pg) =
                Self::pages_for_span(page_boundaries, first_page, start_char, end_char);

            let chunk_idx = chunks.len();
            chunks.push(DocumentChunk {
                id: format!("chunk_{}", chunk_idx),
                content,
                tokens: chunk_tokens.len(),
                page_numbers: page_nums,
                chunk_index: chunk_idx,
                metadata: ChunkMetadata {
                    position: ChunkPosition {
                        start_char,
                        end_char,
                        first_page: first_pg,
                        last_page: last_pg,
                    },
                    // Splitting is deterministic, so confidence is fixed at 1.0.
                    confidence: 1.0,
                    sentence_boundary_respected,
                },
            });

            if end >= tokens.len() {
                break;
            }
            // Step back by `overlap` tokens to preserve context, but always
            // guarantee forward progress: the old guard never fired once a
            // sentence boundary shortened `end`, so `overlap >= end - start`
            // could re-emit the same chunk forever.
            let next = end.saturating_sub(self.overlap);
            start = if next > start { next } else { end };
        }
        Ok(chunks)
    }

    /// Maps a character span onto the pages it overlaps.
    ///
    /// Returns `(pages, first_page, last_page)`; empty/zeros when no page
    /// boundaries are known. Falls back to attributing the span to
    /// `first_page` if it overlaps no recorded page (matches prior behavior).
    fn pages_for_span(
        page_boundaries: &[usize],
        first_page: usize,
        start_char: usize,
        end_char: usize,
    ) -> (Vec<usize>, usize, usize) {
        if page_boundaries.is_empty() {
            return (Vec::new(), 0, 0);
        }
        let mut pages = Vec::new();
        // boundary[idx-1]..boundary[idx] is the span of page `first_page + idx - 1`.
        for idx in 1..page_boundaries.len() {
            if start_char < page_boundaries[idx] && end_char > page_boundaries[idx - 1] {
                pages.push(first_page + idx - 1);
            }
        }
        if pages.is_empty() {
            pages.push(first_page);
        }
        let first = pages[0];
        let last = *pages.last().unwrap();
        (pages, first, last)
    }

    /// Rough token-count estimate: whitespace word count scaled by 1.33
    /// (an empirical words-to-subword-tokens ratio for English text).
    pub fn estimate_tokens(text: &str) -> usize {
        let words = text.split_whitespace().count();
        ((words as f32) * 1.33) as usize
    }
}
#[cfg(test)]
mod tests {
use super::*;
// 25 tokens with chunk_size 10 / overlap 2 should produce chunks starting at
// tokens 0, 8, 16 — i.e. sizes 10, 10 and 9 — with sequential ids/indices.
#[test]
fn test_basic_chunking() {
let chunker = DocumentChunker::new(10, 2);
let text = (0..25)
.map(|i| format!("word{}", i))
.collect::<Vec<_>>()
.join(" ");
let chunks = chunker.chunk_text(&text).unwrap();
assert_eq!(chunks.len(), 3, "Should create 3 chunks");
assert_eq!(chunks[0].tokens, 10);
assert_eq!(chunks[0].chunk_index, 0);
assert_eq!(chunks[0].id, "chunk_0");
assert_eq!(chunks[0].metadata.position.start_char, 0);
assert_eq!(chunks[1].tokens, 10);
assert_eq!(chunks[1].chunk_index, 1);
assert_eq!(chunks[2].tokens, 9);
assert_eq!(chunks[2].chunk_index, 2);
}
// The last `overlap` tokens of one chunk must reappear at the start of the
// next chunk. Note: chunk0_end is collected via .rev(), so it reads ["e","d"].
#[test]
fn test_overlap_preserves_context() {
let chunker = DocumentChunker::new(5, 2);
let text = "a b c d e f g h i j";
let chunks = chunker.chunk_text(&text).unwrap();
let chunk0_end = chunks[0]
.content
.split_whitespace()
.rev()
.take(2)
.collect::<Vec<_>>();
let chunk1_start = chunks[1]
.content
.split_whitespace()
.take(2)
.collect::<Vec<_>>();
assert_eq!(chunk0_end, vec!["e", "d"]);
assert_eq!(chunk1_start, vec!["d", "e"]);
}
// Empty input must yield no chunks rather than an error or a sentinel chunk.
#[test]
fn test_empty_text() {
let chunker = DocumentChunker::new(10, 2);
let chunks = chunker.chunk_text("").unwrap();
assert_eq!(chunks.len(), 0);
}
// Text shorter than chunk_size fits in a single chunk of all its tokens.
#[test]
fn test_text_smaller_than_chunk_size() {
let chunker = DocumentChunker::new(100, 10);
let text = "just a few words";
let chunks = chunker.chunk_text(&text).unwrap();
assert_eq!(chunks.len(), 1);
assert_eq!(chunks[0].tokens, 4);
}
// estimate_tokens scales word count by ~1.33; verify the small, empty, and
// 100-word cases land in the expected bands.
#[test]
fn test_token_estimation() {
let tokens = DocumentChunker::estimate_tokens("hello world");
assert!(
tokens >= 2 && tokens <= 3,
"Expected ~2-3 tokens, got {}",
tokens
);
assert_eq!(DocumentChunker::estimate_tokens(""), 0);
let long_text = (0..100)
.map(|i| format!("word{}", i))
.collect::<Vec<_>>()
.join(" ");
let tokens_long = DocumentChunker::estimate_tokens(&long_text);
assert!(
tokens_long >= 120 && tokens_long <= 140,
"Expected ~133 tokens, got {}",
tokens_long
);
}
// Chunk ids are derived from the chunk index, so deduplicating them must not
// lose any entries.
#[test]
fn test_chunk_ids_are_unique() {
let chunker = DocumentChunker::new(5, 1);
let text = (0..20)
.map(|i| format!("word{}", i))
.collect::<Vec<_>>()
.join(" ");
let chunks = chunker.chunk_text(&text).unwrap();
let ids: Vec<String> = chunks.iter().map(|c| c.id.clone()).collect();
let unique_ids: std::collections::HashSet<_> = ids.iter().collect();
assert_eq!(
ids.len(),
unique_ids.len(),
"All chunk IDs should be unique"
);
}
// When sentence punctuation falls near a chunk end, the chunk should be cut
// there, flagged via sentence_boundary_respected, and actually end with
// '.', '!' or '?'. The final chunk is exempt (it simply runs out of text).
#[test]
fn test_sentence_boundary_detection() {
let chunker = DocumentChunker::new(10, 2);
let text = "This is the first sentence. This is the second sentence. This is the third sentence. And here is a fourth one.";
let chunks = chunker.chunk_text(&text).unwrap();
let has_boundary_respect = chunks
.iter()
.any(|c| c.metadata.sentence_boundary_respected);
assert!(
has_boundary_respect,
"At least some chunks should respect sentence boundaries"
);
for (i, chunk) in chunks.iter().enumerate() {
if i < chunks.len() - 1 && chunk.metadata.sentence_boundary_respected {
assert!(
chunk.content.ends_with('.')
|| chunk.content.ends_with('!')
|| chunk.content.ends_with('?'),
"Chunk {} should end with sentence punctuation",
i
);
}
}
}
// Page-aware chunking must attribute every chunk to at least one 1-based
// page, starting at page 1 for the first chunk.
#[test]
fn test_page_tracking() {
let chunker = DocumentChunker::new(10, 2);
let page_texts = vec![
(1, "This is page one content.".to_string()),
(2, "This is page two content.".to_string()),
(3, "This is page three content.".to_string()),
];
let chunks = chunker.chunk_text_with_pages(&page_texts).unwrap();
for chunk in &chunks {
assert!(
!chunk.page_numbers.is_empty(),
"Chunk should have page numbers"
);
assert!(
chunk.metadata.position.first_page > 0,
"First page should be > 0"
);
assert!(
chunk.metadata.position.last_page > 0,
"Last page should be > 0"
);
}
assert_eq!(
chunks[0].metadata.position.first_page, 1,
"First chunk should start at page 1"
);
}
// Char positions must start at 0, be non-degenerate (end > start), and stay
// roughly monotonic across chunks (the +10 slack tolerates overlap effects).
#[test]
fn test_metadata_position_tracking() {
let chunker = DocumentChunker::new(5, 1);
let text = "word1 word2 word3 word4 word5 word6 word7 word8 word9 word10";
let chunks = chunker.chunk_text(&text).unwrap();
for i in 0..chunks.len() - 1 {
assert!(
chunks[i].metadata.position.end_char
<= chunks[i + 1].metadata.position.start_char + 10,
"Chunks should have reasonable character positions"
);
}
assert_eq!(chunks[0].metadata.position.start_char, 0);
for chunk in &chunks {
assert!(
chunk.metadata.position.end_char > chunk.metadata.position.start_char,
"End char should be greater than start char"
);
}
}
// Confidence is documented as a [0.0, 1.0] score; every chunk must be in range.
#[test]
fn test_confidence_scores() {
let chunker = DocumentChunker::new(10, 2);
let text = "This is a test document with multiple sentences.";
let chunks = chunker.chunk_text(&text).unwrap();
for chunk in &chunks {
assert!(
chunk.metadata.confidence >= 0.0 && chunk.metadata.confidence <= 1.0,
"Confidence should be between 0.0 and 1.0"
);
}
}
// Smoke performance test: 100 pages of 200 words must chunk in < 500ms.
// NOTE(review): wall-clock bound — may be flaky on very slow/loaded CI hosts.
#[test]
fn test_performance_100_pages() {
use std::time::Instant;
let chunker = DocumentChunker::new(512, 50);
let page_texts: Vec<(usize, String)> = (1..=100)
.map(|page_num| {
let words: Vec<String> = (0..200).map(|i| format!("word{}", i)).collect();
(page_num, words.join(" "))
})
.collect();
let start = Instant::now();
let chunks = chunker.chunk_text_with_pages(&page_texts).unwrap();
let duration = start.elapsed();
tracing::debug!("Chunked 100 pages in {:?}", duration);
tracing::debug!("Created {} chunks", chunks.len());
assert!(
duration.as_millis() < 500,
"Chunking 100 pages took {:?}, should be < 500ms",
duration
);
}
}