1use std::fmt;
12
/// One chunk of text produced by the chunkers in this module.
///
/// `start` and `end` are measured in `char`s (Unicode scalar values), not
/// bytes, because every chunker here counts positions with `chars()`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TextChunk {
    /// Zero-based position of this chunk in the returned sequence.
    pub index: usize,
    /// Char offset at which this chunk begins.
    pub start: usize,
    /// Exclusive char offset at which this chunk ends.
    pub end: usize,
    /// The text carried by this chunk.
    pub text: String,
}
25
/// How [`chunk_text`] decides where one chunk ends and the next begins.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkStrategy {
    /// Fixed-size windows counted in `char`s.
    Character,
    /// Group whole sentences (split after `.`, `!` or `?` followed by whitespace).
    Sentence,
    /// Group whole paragraphs (split at blank lines).
    Paragraph,
}

impl ChunkStrategy {
    /// Parses a strategy name, ignoring letter case.
    ///
    /// Accepts the full names (`"character"`, `"sentence"`, `"paragraph"`) and
    /// the short forms (`"char"`, `"sent"`, `"para"`). Anything else, including
    /// input with surrounding whitespace, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        const ALIASES: [(&str, ChunkStrategy); 6] = [
            ("character", ChunkStrategy::Character),
            ("char", ChunkStrategy::Character),
            ("sentence", ChunkStrategy::Sentence),
            ("sent", ChunkStrategy::Sentence),
            ("paragraph", ChunkStrategy::Paragraph),
            ("para", ChunkStrategy::Paragraph),
        ];

        let wanted = s.to_lowercase();
        ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted)
            .map(|&(_, strategy)| strategy)
    }
}

impl fmt::Display for ChunkStrategy {
    /// Writes the canonical lowercase name, which [`ChunkStrategy::parse`] accepts back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::Character => "character",
            Self::Sentence => "sentence",
            Self::Paragraph => "paragraph",
        };
        f.write_str(label)
    }
}
54
/// Reasons [`chunk_text`] rejects its size parameters.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkError {
    /// `chunk_size` was zero.
    InvalidChunkSize,
    /// `overlap` was greater than or equal to `chunk_size`.
    OverlapTooLarge,
}

impl fmt::Display for ChunkError {
    /// Writes a short, user-facing description of the problem.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            Self::InvalidChunkSize => "chunk_size must be greater than 0",
            Self::OverlapTooLarge => "overlap must be less than chunk_size",
        };
        f.write_str(message)
    }
}

// Implementing the standard error trait lets callers box this type as
// `Box<dyn std::error::Error>` and propagate it with `?` through generic
// error-handling code.
impl std::error::Error for ChunkError {}
72
73pub fn chunk_text(
81 text: &str,
82 chunk_size: usize,
83 overlap: usize,
84 strategy: ChunkStrategy,
85) -> Result<Vec<TextChunk>, ChunkError> {
86 if chunk_size == 0 {
87 return Err(ChunkError::InvalidChunkSize);
88 }
89 if overlap >= chunk_size {
90 return Err(ChunkError::OverlapTooLarge);
91 }
92 if text.is_empty() {
93 return Ok(Vec::new());
94 }
95
96 match strategy {
97 ChunkStrategy::Character => chunk_by_characters(text, chunk_size, overlap),
98 ChunkStrategy::Sentence => chunk_by_sentences(text, chunk_size, overlap),
99 ChunkStrategy::Paragraph => chunk_by_paragraphs(text, chunk_size, overlap),
100 }
101}
102
103fn chunk_by_characters(
105 text: &str,
106 chunk_size: usize,
107 overlap: usize,
108) -> Result<Vec<TextChunk>, ChunkError> {
109 let chars: Vec<char> = text.chars().collect();
110 let total = chars.len();
111 let step = chunk_size - overlap;
112 let mut chunks = Vec::new();
113 let mut pos = 0usize;
114 let mut index = 0usize;
115
116 while pos < total {
117 let end = (pos + chunk_size).min(total);
118 let chunk_chars = &chars[pos..end];
119 let text_content: String = chunk_chars.iter().collect();
120
121 chunks.push(TextChunk {
123 index,
124 start: pos,
125 end,
126 text: text_content,
127 });
128
129 index += 1;
130 pos += step;
131
132 if end == total {
134 break;
135 }
136 }
137
138 Ok(chunks)
139}
140
141fn chunk_by_sentences(
146 text: &str,
147 chunk_size: usize,
148 overlap: usize,
149) -> Result<Vec<TextChunk>, ChunkError> {
150 let sentences = split_sentences(text);
151 if sentences.is_empty() {
152 return Ok(Vec::new());
153 }
154
155 build_chunks_from_segments(&sentences, chunk_size, overlap)
156}
157
158fn chunk_by_paragraphs(
163 text: &str,
164 chunk_size: usize,
165 overlap: usize,
166) -> Result<Vec<TextChunk>, ChunkError> {
167 let paragraphs = split_paragraphs(text);
168 if paragraphs.is_empty() {
169 return Ok(Vec::new());
170 }
171
172 build_chunks_from_segments(¶graphs, chunk_size, overlap)
173}
174
/// Cuts `text` into sentence-like segments.
///
/// A sentence ends at `.`, `!` or `?` when that character is the last one in
/// the text or is followed by whitespace. Whitespace after the terminator,
/// up to but not including the first `'\n'`, stays attached to the sentence
/// it follows. Text after the final terminator becomes one last segment.
///
/// Returns `(offset, segment)` pairs where `offset` is the segment's starting
/// position counted in `char`s. The segments are contiguous: concatenating
/// them reproduces `text`.
fn split_sentences(text: &str) -> Vec<(usize, String)> {
    let glyphs: Vec<char> = text.chars().collect();
    let len = glyphs.len();

    let mut out = Vec::new();
    let mut seg_start = 0usize;
    let mut cursor = 0usize;

    while cursor < len {
        let is_terminator = matches!(glyphs[cursor], '.' | '!' | '?');
        let at_boundary = glyphs
            .get(cursor + 1)
            .map_or(true, |next| next.is_whitespace());

        if !(is_terminator && at_boundary) {
            cursor += 1;
            continue;
        }

        // Absorb trailing blanks, but leave a newline for the next segment.
        let trailing = glyphs[cursor + 1..]
            .iter()
            .take_while(|c| c.is_whitespace() && **c != '\n')
            .count();
        let seg_end = cursor + 1 + trailing;

        out.push((
            seg_start,
            glyphs[seg_start..seg_end].iter().collect::<String>(),
        ));
        seg_start = seg_end;
        cursor = seg_end;
    }

    if seg_start < len {
        out.push((seg_start, glyphs[seg_start..].iter().collect()));
    }

    out
}
212
213fn split_paragraphs(text: &str) -> Vec<(usize, String)> {
215 let mut segments = Vec::new();
216 let mut char_offset = 0usize;
217
218 let mut remaining = text;
220 while let Some(pos) = find_paragraph_break(remaining) {
221 let para = &remaining[..pos];
222 let para_chars: Vec<char> = para.chars().collect();
223 if !para_chars.is_empty() {
224 segments.push((char_offset, para.to_string()));
225 }
226 char_offset += para.chars().count();
227
228 let break_str = &remaining[pos..];
230 let break_len = if break_str.starts_with("\r\n\r\n") {
231 4
232 } else {
233 2 };
235 let break_chars = remaining[pos..pos + break_len].chars().count();
236 char_offset += break_chars;
237 remaining = &remaining[pos + break_len..];
238 }
239
240 if !remaining.is_empty() {
242 segments.push((char_offset, remaining.to_string()));
243 }
244
245 segments
246}
247
/// Returns the byte offset of the first paragraph break in `text`.
///
/// A break is either `"\r\n\r\n"` or `"\n\n"`; whichever occurs first wins.
/// Returns `None` when the text contains neither.
fn find_paragraph_break(text: &str) -> Option<usize> {
    let crlf_break = text.find("\r\n\r\n");
    let lf_break = text.find("\n\n");

    match (crlf_break, lf_break) {
        (Some(crlf), Some(lf)) => Some(crlf.min(lf)),
        (found, None) | (None, found) => found,
    }
}
262
263fn build_chunks_from_segments(
267 segments: &[(usize, String)],
268 chunk_size: usize,
269 overlap: usize,
270) -> Result<Vec<TextChunk>, ChunkError> {
271 let mut chunks = Vec::new();
272 let mut current_text = String::new();
273 let mut current_start: Option<usize> = None;
274 let mut index = 0usize;
275
276 for (seg_offset, seg_text) in segments {
277 let seg_chars = seg_text.chars().count();
278
279 if seg_chars > chunk_size {
281 if !current_text.is_empty() {
283 let start = current_start.unwrap_or(0);
284 let end = start + current_text.chars().count();
285 chunks.push(TextChunk {
286 index,
287 start,
288 end,
289 text: std::mem::take(&mut current_text),
290 });
291 index += 1;
292 current_start = None;
293 }
294
295 let sub_chunks = chunk_by_characters(seg_text, chunk_size, overlap)?;
297 for sub in sub_chunks {
298 chunks.push(TextChunk {
299 index,
300 start: seg_offset + sub.start,
301 end: seg_offset + sub.end,
302 text: sub.text,
303 });
304 index += 1;
305 }
306 continue;
307 }
308
309 let current_chars = current_text.chars().count();
310 if current_chars + seg_chars > chunk_size && !current_text.is_empty() {
312 let start = current_start.unwrap_or(0);
314 let end = start + current_chars;
315 chunks.push(TextChunk {
316 index,
317 start,
318 end,
319 text: current_text.clone(),
320 });
321 index += 1;
322
323 if overlap > 0 && current_chars > overlap {
325 let chars: Vec<char> = current_text.chars().collect();
326 let overlap_chars = &chars[current_chars - overlap..];
327 current_text = overlap_chars.iter().collect();
328 current_start = Some(end - overlap);
329 } else {
330 current_text.clear();
331 current_start = None;
332 }
333 }
334
335 if current_start.is_none() {
336 current_start = Some(*seg_offset);
337 }
338 current_text.push_str(seg_text);
339 }
340
341 if !current_text.is_empty() {
343 let start = current_start.unwrap_or(0);
344 let end = start + current_text.chars().count();
345 chunks.push(TextChunk {
346 index,
347 start,
348 end,
349 text: current_text,
350 });
351 }
352
353 Ok(chunks)
354}
355
// Unit tests for the public `chunk_text` entry point, covering each strategy,
// parameter validation, and char-based (not byte-based) offsets.
#[cfg(test)]
mod tests {
    use super::*;

    // 29 chars in windows of 10 with no overlap: 10 + 10 + 9.
    #[test]
    fn character_basic() {
        let text = "Hello, World! This is a test.";
        let chunks = chunk_text(text, 10, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].text, "Hello, Wor");
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 10);
        assert_eq!(chunks[1].text, "ld! This i");
        assert_eq!(chunks[2].text, "s a test.");
    }

    // Step is chunk_size - overlap = 5, so windows start at 0, 5 and 10.
    #[test]
    fn character_with_overlap() {
        let text = "abcdefghijklmnop";
        let chunks = chunk_text(text, 8, 3, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].text, "abcdefgh");
        assert_eq!(chunks[1].text, "fghijklm");
        assert_eq!(chunks[1].start, 5);
        assert_eq!(chunks[2].text, "klmnop");
    }

    // Each sentence (with its trailing space) is 15-17 chars, so no two of
    // them fit together in a 20-char chunk.
    #[test]
    fn sentence_basic() {
        let text = "First sentence. Second sentence. Third sentence.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Sentence).unwrap();
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("First"));
    }

    // Paragraphs are 14-16 chars each, so no two of them fit in 20 chars.
    #[test]
    fn paragraph_basic() {
        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Paragraph).unwrap();
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("Paragraph one"));
    }

    // Empty input is valid and produces no chunks.
    #[test]
    fn empty_text() {
        let chunks = chunk_text("", 10, 0, ChunkStrategy::Character).unwrap();
        assert!(chunks.is_empty());
    }

    // Text shorter than the window comes back as a single chunk.
    #[test]
    fn text_smaller_than_chunk() {
        let text = "short";
        let chunks = chunk_text(text, 100, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "short");
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 5);
    }

    // Zero chunk size, overlap == chunk size, and overlap > chunk size are all rejected.
    #[test]
    fn invalid_params() {
        assert!(chunk_text("text", 0, 0, ChunkStrategy::Character).is_err());
        assert!(chunk_text("text", 5, 5, ChunkStrategy::Character).is_err());
        assert!(chunk_text("text", 5, 10, ChunkStrategy::Character).is_err());
    }

    // Each emoji is a multi-byte code point; windows are counted in chars, so
    // none of them is split.
    #[test]
    fn utf8_safety() {
        let text = "🌍🌎🌏🌍🌎🌏";
        let chunks = chunk_text(text, 3, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].text, "🌍🌎🌏");
        assert_eq!(chunks[1].text, "🌍🌎🌏");
    }

    // A single sentence longer than the chunk size is split by characters.
    #[test]
    fn sentence_fallback_to_character() {
        let text = "This is a very long sentence that exceeds the chunk size limit.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Sentence).unwrap();
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.text.chars().count() <= 20);
        }
    }

    // Two runs over the same input must agree chunk by chunk.
    #[test]
    fn deterministic() {
        let text = "Deterministic output means same input produces same output every time.";
        let a = chunk_text(text, 15, 3, ChunkStrategy::Character).unwrap();
        let b = chunk_text(text, 15, 3, ChunkStrategy::Character).unwrap();
        assert_eq!(a.len(), b.len());
        for (ca, cb) in a.iter().zip(b.iter()) {
            assert_eq!(ca.text, cb.text);
            assert_eq!(ca.start, cb.start);
            assert_eq!(ca.end, cb.end);
        }
    }

    // With overlap 4, the last 4 chars of chunk 0 equal the first 4 of chunk 1.
    #[test]
    fn overlap_produces_shared_chars() {
        let text = "0123456789abcdef";
        let chunks = chunk_text(text, 8, 4, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        // Take the tail by reversing, taking 4, then restoring the order.
        let c0_tail: String = chunks[0]
            .text
            .chars()
            .rev()
            .take(4)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();
        let c1_head: String = chunks[1].text.chars().take(4).collect();
        assert_eq!(c0_tail, c1_head);
    }
}