use std::fmt;
/// One chunk of text produced by [`chunk_text`].
///
/// Offsets are counted in `char`s (Unicode scalar values), not bytes, so they
/// cannot be used directly to slice the original `&str`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TextChunk {
/// Zero-based position of this chunk in the returned `Vec`.
pub index: usize,
/// Offset (in chars) at which this chunk begins in the input text.
pub start: usize,
/// Exclusive end offset (in chars); always `start` plus the chunk's char count.
pub end: usize,
/// The chunk's contents.
pub text: String,
}
/// How [`chunk_text`] decides where chunk boundaries may fall.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkStrategy {
    /// Fixed-size windows of characters.
    Character,
    /// Pack whole sentences (split on `.`, `!`, `?` followed by whitespace).
    Sentence,
    /// Pack whole paragraphs (split on blank lines).
    Paragraph,
}

impl ChunkStrategy {
    /// Parses a strategy name case-insensitively.
    ///
    /// Accepts the full names `character`/`sentence`/`paragraph` and the short
    /// forms `char`/`sent`/`para`. Any other input (including surrounding
    /// whitespace) yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        let strategy = match normalized.as_str() {
            "character" | "char" => Self::Character,
            "sentence" | "sent" => Self::Sentence,
            "paragraph" | "para" => Self::Paragraph,
            _ => return None,
        };
        Some(strategy)
    }
}

impl fmt::Display for ChunkStrategy {
    /// Writes the canonical (full, lowercase) strategy name.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Character => "character",
            Self::Sentence => "sentence",
            Self::Paragraph => "paragraph",
        };
        f.write_str(name)
    }
}
/// Parameter-validation errors returned by [`chunk_text`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkError {
    /// `chunk_size` was 0.
    InvalidChunkSize,
    /// `overlap` was greater than or equal to `chunk_size`.
    OverlapTooLarge,
}

impl fmt::Display for ChunkError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidChunkSize => write!(f, "chunk_size must be greater than 0"),
            Self::OverlapTooLarge => write!(f, "overlap must be less than chunk_size"),
        }
    }
}

// Implementing the standard error trait lets callers propagate `ChunkError` with `?`
// into `Box<dyn std::error::Error>` (or any error type built on it).
impl std::error::Error for ChunkError {}
pub fn chunk_text(
text: &str,
chunk_size: usize,
overlap: usize,
strategy: ChunkStrategy,
) -> Result<Vec<TextChunk>, ChunkError> {
if chunk_size == 0 {
return Err(ChunkError::InvalidChunkSize);
}
if overlap >= chunk_size {
return Err(ChunkError::OverlapTooLarge);
}
if text.is_empty() {
return Ok(Vec::new());
}
match strategy {
ChunkStrategy::Character => chunk_by_characters(text, chunk_size, overlap),
ChunkStrategy::Sentence => chunk_by_sentences(text, chunk_size, overlap),
ChunkStrategy::Paragraph => chunk_by_paragraphs(text, chunk_size, overlap),
}
}
/// Cuts `text` into windows of `chunk_size` chars, advancing by
/// `chunk_size - overlap` chars each time.
///
/// Offsets are char indices. The final window may be shorter. Iteration stops as
/// soon as a window reaches the end of the text, so no trailing window is made up
/// entirely of overlap. Callers must ensure `overlap < chunk_size`; otherwise the
/// stride computation underflows.
fn chunk_by_characters(
    text: &str,
    chunk_size: usize,
    overlap: usize,
) -> Result<Vec<TextChunk>, ChunkError> {
    // Index by char, not byte, so windows never split a multi-byte code point.
    let chars: Vec<char> = text.chars().collect();
    let total = chars.len();
    let stride = chunk_size - overlap;
    let mut chunks: Vec<TextChunk> = Vec::new();
    let mut window_start = 0usize;
    while window_start < total {
        let window_end = total.min(window_start + chunk_size);
        let window = &chars[window_start..window_end];
        chunks.push(TextChunk {
            index: chunks.len(),
            start: window_start,
            end: window_end,
            text: window.iter().collect(),
        });
        if window_end == total {
            break;
        }
        window_start += stride;
    }
    Ok(chunks)
}
/// Chunks `text` by packing whole sentences (as produced by `split_sentences`)
/// into chunks of at most `chunk_size` chars. Any single sentence longer than
/// that is handed to `build_chunks_from_segments`, which splits it by characters.
///
/// # Errors
/// Propagates any error from `build_chunks_from_segments`.
fn chunk_by_sentences(
    text: &str,
    chunk_size: usize,
    overlap: usize,
) -> Result<Vec<TextChunk>, ChunkError> {
    let segments = split_sentences(text);
    if segments.is_empty() {
        Ok(Vec::new())
    } else {
        build_chunks_from_segments(&segments, chunk_size, overlap)
    }
}
/// Chunks `text` by packing whole paragraphs (as produced by `split_paragraphs`,
/// i.e. separated by `"\n\n"` or `"\r\n\r\n"`) into chunks of at most
/// `chunk_size` chars. A paragraph longer than that is split by characters inside
/// `build_chunks_from_segments`.
///
/// # Errors
/// Propagates any error from `build_chunks_from_segments`.
fn chunk_by_paragraphs(
    text: &str,
    chunk_size: usize,
    overlap: usize,
) -> Result<Vec<TextChunk>, ChunkError> {
    let paragraphs = split_paragraphs(text);
    if paragraphs.is_empty() {
        return Ok(Vec::new());
    }
    // `build_chunks_from_segments` takes a slice; pass the Vec by reference.
    build_chunks_from_segments(&paragraphs, chunk_size, overlap)
}
/// Splits `text` into sentence segments, returned as `(char_offset, segment)`.
///
/// A sentence ends at `.`, `!` or `?` when followed by whitespace or end of
/// input. Horizontal whitespace after the terminator (anything whitespace except
/// `'\n'`) stays with the sentence; a newline is left to start the next segment.
/// Segments are contiguous and together reproduce `text` exactly. Any trailing
/// text without a terminator becomes the last segment.
fn split_sentences(text: &str) -> Vec<(usize, String)> {
    let chars: Vec<char> = text.chars().collect();
    let n = chars.len();
    let is_terminator = |c: char| matches!(c, '.' | '!' | '?');
    let mut out = Vec::new();
    let mut seg_start = 0usize;
    let mut cursor = 0usize;
    while cursor < n {
        let ends_here = is_terminator(chars[cursor])
            && chars.get(cursor + 1).map_or(true, |c| c.is_whitespace());
        if !ends_here {
            cursor += 1;
            continue;
        }
        // Swallow trailing spaces/tabs/CRs, but stop at the first newline or non-space.
        let seg_end = chars[cursor + 1..]
            .iter()
            .position(|&c| !c.is_whitespace() || c == '\n')
            .map_or(n, |k| cursor + 1 + k);
        out.push((seg_start, chars[seg_start..seg_end].iter().collect()));
        seg_start = seg_end;
        cursor = seg_end;
    }
    if seg_start < n {
        out.push((seg_start, chars[seg_start..].iter().collect()));
    }
    out
}
/// Splits `text` into paragraph segments, returned as `(char_offset, segment)`.
///
/// Paragraphs are separated by `"\n\n"` or `"\r\n\r\n"` (see
/// `find_paragraph_break`). Each segment holds one paragraph *plus* the break
/// that follows it, so consecutive segments are contiguous in the source. This
/// mirrors `split_sentences`, which keeps trailing whitespace with each sentence.
/// It is required because `build_chunks_from_segments` concatenates segments
/// verbatim and computes `end` as `start + char count`. Dropping the breaks would
/// glue paragraphs together ("one.two.") and make `end` fall short of the real
/// span in the input.
///
/// Empty paragraphs (runs of consecutive breaks) produce no segment of their own;
/// their break text is appended to the preceding segment. Breaks before the first
/// paragraph are skipped, and the first segment's offset accounts for them.
fn split_paragraphs(text: &str) -> Vec<(usize, String)> {
    let mut segments: Vec<(usize, String)> = Vec::new();
    // Offsets are counted in chars, matching the other splitters.
    let mut char_offset = 0usize;
    let mut remaining = text;
    while let Some(pos) = find_paragraph_break(remaining) {
        // `pos` is a byte index from `str::find`; breaks are ASCII, so these byte
        // lengths always land on char boundaries.
        let break_len = if remaining[pos..].starts_with("\r\n\r\n") { 4 } else { 2 };
        let para = &remaining[..pos];
        let piece = &remaining[..pos + break_len];
        if !para.is_empty() {
            segments.push((char_offset, piece.to_string()));
        } else if let Some((_, previous)) = segments.last_mut() {
            // Extra blank lines stay attached to the previous paragraph so that
            // segments remain contiguous.
            previous.push_str(piece);
        }
        char_offset += piece.chars().count();
        remaining = &remaining[pos + break_len..];
    }
    if !remaining.is_empty() {
        segments.push((char_offset, remaining.to_string()));
    }
    segments
}
/// Returns the byte index of the first paragraph break in `text`, if any.
///
/// A break is either `"\n\n"` or `"\r\n\r\n"`; whichever starts first wins. The two
/// patterns can never start at the same index, because their first bytes differ.
/// The result comes from `str::find`, so it is a valid char boundary for slicing.
fn find_paragraph_break(text: &str) -> Option<usize> {
    let candidates = [text.find("\n\n"), text.find("\r\n\r\n")];
    candidates.iter().flatten().min().copied()
}
/// Packs pre-split segments into chunks of at most `chunk_size` chars.
///
/// `segments` are `(char_offset, text)` pairs in document order. Consecutive
/// segments are concatenated verbatim into one chunk until adding the next one
/// would exceed `chunk_size`. A chunk's `end` is `start + char count`, so offsets
/// match the source only when segments are contiguous.
///
/// - A segment longer than `chunk_size` flushes the pending chunk (without
///   overlap) and is split with `chunk_by_characters`.
/// - When a chunk is flushed because the next segment does not fit, up to
///   `overlap` trailing chars are carried into the next chunk. The carry is only
///   used when the flushed chunk is longer than `overlap`. It is trimmed so that
///   carry + next segment never exceeds `chunk_size`. Previously the untrimmed
///   carry could produce chunks longer than `chunk_size`.
///
/// # Errors
/// Propagates errors from `chunk_by_characters`.
fn build_chunks_from_segments(
    segments: &[(usize, String)],
    chunk_size: usize,
    overlap: usize,
) -> Result<Vec<TextChunk>, ChunkError> {
    let mut chunks = Vec::new();
    let mut current_text = String::new();
    // Char length of `current_text`, tracked incrementally instead of recounting.
    let mut current_chars = 0usize;
    let mut current_start: Option<usize> = None;
    let mut index = 0usize;
    for (seg_offset, seg_text) in segments {
        let seg_offset = *seg_offset;
        let seg_chars = seg_text.chars().count();
        if seg_chars > chunk_size {
            if !current_text.is_empty() {
                let start = current_start.unwrap_or(0);
                chunks.push(TextChunk {
                    index,
                    start,
                    end: start + current_chars,
                    text: std::mem::take(&mut current_text),
                });
                index += 1;
                current_chars = 0;
                current_start = None;
            }
            for sub in chunk_by_characters(seg_text, chunk_size, overlap)? {
                chunks.push(TextChunk {
                    index,
                    start: seg_offset + sub.start,
                    end: seg_offset + sub.end,
                    text: sub.text,
                });
                index += 1;
            }
            continue;
        }
        if !current_text.is_empty() && current_chars + seg_chars > chunk_size {
            let start = current_start.unwrap_or(0);
            let end = start + current_chars;
            let flushed = std::mem::take(&mut current_text);
            // Never carry more than fits alongside this segment (seg_chars <= chunk_size here).
            let carry = overlap.min(chunk_size - seg_chars);
            if carry > 0 && current_chars > overlap {
                current_text = flushed.chars().skip(current_chars - carry).collect();
                current_chars = carry;
                current_start = Some(end - carry);
            } else {
                current_chars = 0;
                current_start = None;
            }
            chunks.push(TextChunk {
                index,
                start,
                end,
                text: flushed,
            });
            index += 1;
        }
        if current_start.is_none() {
            current_start = Some(seg_offset);
        }
        current_text.push_str(seg_text);
        current_chars += seg_chars;
    }
    if !current_text.is_empty() {
        let start = current_start.unwrap_or(0);
        chunks.push(TextChunk {
            index,
            start,
            end: start + current_chars,
            text: current_text,
        });
    }
    Ok(chunks)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs `chunk_text` and unwraps the result; every caller here expects success.
    fn run(text: &str, size: usize, overlap: usize, strategy: ChunkStrategy) -> Vec<TextChunk> {
        chunk_text(text, size, overlap, strategy).unwrap()
    }

    /// Collects just the chunk texts, for compact comparisons.
    fn texts(chunks: &[TextChunk]) -> Vec<&str> {
        chunks.iter().map(|c| c.text.as_str()).collect()
    }

    #[test]
    fn character_basic() {
        let chunks = run("Hello, World! This is a test.", 10, 0, ChunkStrategy::Character);
        assert_eq!(texts(&chunks), ["Hello, Wor", "ld! This i", "s a test."]);
        assert_eq!((chunks[0].start, chunks[0].end), (0, 10));
    }

    #[test]
    fn character_with_overlap() {
        let chunks = run("abcdefghijklmnop", 8, 3, ChunkStrategy::Character);
        assert_eq!(texts(&chunks), ["abcdefgh", "fghijklm", "klmnop"]);
        assert_eq!(chunks[1].start, 5);
    }

    #[test]
    fn sentence_basic() {
        let chunks = run(
            "First sentence. Second sentence. Third sentence.",
            20,
            0,
            ChunkStrategy::Sentence,
        );
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("First"));
    }

    #[test]
    fn paragraph_basic() {
        let chunks = run(
            "Paragraph one.\n\nParagraph two.\n\nParagraph three.",
            20,
            0,
            ChunkStrategy::Paragraph,
        );
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("Paragraph one"));
    }

    #[test]
    fn empty_text() {
        assert!(run("", 10, 0, ChunkStrategy::Character).is_empty());
    }

    #[test]
    fn text_smaller_than_chunk() {
        let chunks = run("short", 100, 0, ChunkStrategy::Character);
        assert_eq!(chunks.len(), 1);
        let only = &chunks[0];
        assert_eq!((only.text.as_str(), only.start, only.end), ("short", 0, 5));
    }

    #[test]
    fn invalid_params() {
        for (size, overlap) in [(0, 0), (5, 5), (5, 10)] {
            assert!(chunk_text("text", size, overlap, ChunkStrategy::Character).is_err());
        }
    }

    #[test]
    fn utf8_safety() {
        let chunks = run("🌍🌎🌏🌍🌎🌏", 3, 0, ChunkStrategy::Character);
        assert_eq!(chunks.len(), 2);
        assert!(chunks.iter().all(|c| c.text == "🌍🌎🌏"));
    }

    #[test]
    fn sentence_fallback_to_character() {
        let text = "This is a very long sentence that exceeds the chunk size limit.";
        let chunks = run(text, 20, 0, ChunkStrategy::Sentence);
        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|c| c.text.chars().count() <= 20));
    }

    #[test]
    fn deterministic() {
        let text = "Deterministic output means same input produces same output every time.";
        let summarize = |chunks: Vec<TextChunk>| -> Vec<(String, usize, usize)> {
            chunks.into_iter().map(|c| (c.text, c.start, c.end)).collect()
        };
        let first = summarize(run(text, 15, 3, ChunkStrategy::Character));
        let second = summarize(run(text, 15, 3, ChunkStrategy::Character));
        assert_eq!(first, second);
    }

    #[test]
    fn overlap_produces_shared_chars() {
        let chunks = run("0123456789abcdef", 8, 4, ChunkStrategy::Character);
        assert_eq!(chunks.len(), 3);
        let first: Vec<char> = chunks[0].text.chars().collect();
        let tail: String = first[first.len().saturating_sub(4)..].iter().collect();
        let head: String = chunks[1].text.chars().take(4).collect();
        assert_eq!(tail, head);
    }
}