1use std::fmt;
14
/// One chunk of text produced by the chunking functions in this module.
///
/// Offsets are expressed in `char`s (Unicode scalar values), not bytes, and
/// `end - start` equals the number of `char`s in `text`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TextChunk {
    /// Zero-based position of this chunk in the returned list.
    pub index: usize,
    /// Inclusive start offset of the chunk, counted in `char`s.
    pub start: usize,
    /// Exclusive end offset of the chunk, counted in `char`s.
    pub end: usize,
    /// The chunk's content as an owned string.
    pub text: String,
}
27
/// Selects how input text is divided into chunks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkStrategy {
    /// Fixed-size windows of characters.
    Character,
    /// Chunks assembled from whole sentences.
    Sentence,
    /// Chunks assembled from whole paragraphs.
    Paragraph,
}

impl ChunkStrategy {
    /// Parses a strategy name, ignoring letter case.
    ///
    /// Accepts the full names (`"character"`, `"sentence"`, `"paragraph"`) and
    /// the short forms (`"char"`, `"sent"`, `"para"`). Any other input,
    /// including input with surrounding whitespace, yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let normalized = s.to_lowercase();
        let strategy = match normalized.as_str() {
            "character" | "char" => Self::Character,
            "sentence" | "sent" => Self::Sentence,
            "paragraph" | "para" => Self::Paragraph,
            _ => return None,
        };
        Some(strategy)
    }
}

impl fmt::Display for ChunkStrategy {
    /// Writes the canonical lowercase name, which `parse` accepts back.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::Character => "character",
            Self::Sentence => "sentence",
            Self::Paragraph => "paragraph",
        };
        f.write_str(name)
    }
}
56
/// Errors reported when the chunking parameters are invalid.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ChunkError {
    /// `chunk_size` was zero.
    InvalidChunkSize,
    /// `overlap` was greater than or equal to `chunk_size`.
    OverlapTooLarge,
}

impl fmt::Display for ChunkError {
    /// Writes a short human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidChunkSize => write!(f, "chunk_size must be greater than 0"),
            Self::OverlapTooLarge => write!(f, "overlap must be less than chunk_size"),
        }
    }
}

// Implementing the standard error trait lets callers box this type as
// `Box<dyn std::error::Error>` and propagate it with `?` into generic error types.
impl std::error::Error for ChunkError {}
74
75pub fn chunk_text(
83 text: &str,
84 chunk_size: usize,
85 overlap: usize,
86 strategy: ChunkStrategy,
87) -> Result<Vec<TextChunk>, ChunkError> {
88 if chunk_size == 0 {
89 return Err(ChunkError::InvalidChunkSize);
90 }
91 if overlap >= chunk_size {
92 return Err(ChunkError::OverlapTooLarge);
93 }
94 if text.is_empty() {
95 return Ok(Vec::new());
96 }
97
98 match strategy {
99 ChunkStrategy::Character => chunk_by_characters(text, chunk_size, overlap),
100 ChunkStrategy::Sentence => chunk_by_sentences(text, chunk_size, overlap),
101 ChunkStrategy::Paragraph => chunk_by_paragraphs(text, chunk_size, overlap),
102 }
103}
104
105fn chunk_by_characters(
107 text: &str,
108 chunk_size: usize,
109 overlap: usize,
110) -> Result<Vec<TextChunk>, ChunkError> {
111 let chars: Vec<char> = text.chars().collect();
112 let total = chars.len();
113 let step = chunk_size - overlap;
114 let mut chunks = Vec::new();
115 let mut pos = 0usize;
116 let mut index = 0usize;
117
118 while pos < total {
119 let end = (pos + chunk_size).min(total);
120 let chunk_chars = &chars[pos..end];
121 let text_content: String = chunk_chars.iter().collect();
122
123 chunks.push(TextChunk {
125 index,
126 start: pos,
127 end,
128 text: text_content,
129 });
130
131 index += 1;
132 pos += step;
133
134 if end == total {
136 break;
137 }
138 }
139
140 Ok(chunks)
141}
142
143fn chunk_by_sentences(
148 text: &str,
149 chunk_size: usize,
150 overlap: usize,
151) -> Result<Vec<TextChunk>, ChunkError> {
152 let sentences = split_sentences(text);
153 if sentences.is_empty() {
154 return Ok(Vec::new());
155 }
156
157 build_chunks_from_segments(&sentences, chunk_size, overlap)
158}
159
160fn chunk_by_paragraphs(
165 text: &str,
166 chunk_size: usize,
167 overlap: usize,
168) -> Result<Vec<TextChunk>, ChunkError> {
169 let paragraphs = split_paragraphs(text);
170 if paragraphs.is_empty() {
171 return Ok(Vec::new());
172 }
173
174 build_chunks_from_segments(¶graphs, chunk_size, overlap)
175}
176
/// Splits `text` into sentence-like segments.
///
/// A sentence ends at `.`, `!` or `?` when that character is followed by
/// whitespace or is the last character of the text. Whitespace directly after
/// the terminator is kept with the sentence, up to but not including the first
/// newline. Any text left after the last terminator becomes a final segment.
///
/// Each returned pair is `(offset, segment)`, where `offset` is the segment's
/// starting position counted in `char`s. Concatenating the segments in order
/// reproduces `text`.
fn split_sentences(text: &str) -> Vec<(usize, String)> {
    let symbols: Vec<char> = text.chars().collect();
    let len = symbols.len();

    let mut out = Vec::new();
    let mut seg_begin = 0usize;
    let mut cursor = 0usize;

    while cursor < len {
        let is_terminator = matches!(symbols[cursor], '.' | '!' | '?');
        let gap_follows = symbols
            .get(cursor + 1)
            .map_or(true, |next| next.is_whitespace());

        if !(is_terminator && gap_follows) {
            cursor += 1;
            continue;
        }

        // Attach trailing same-line whitespace to the sentence it follows.
        let trailing = symbols[cursor + 1..]
            .iter()
            .take_while(|c| c.is_whitespace() && **c != '\n')
            .count();
        let seg_end = cursor + 1 + trailing;

        out.push((
            seg_begin,
            symbols[seg_begin..seg_end].iter().collect::<String>(),
        ));
        seg_begin = seg_end;
        cursor = seg_end;
    }

    // Text without a closing terminator still forms a segment.
    if seg_begin < len {
        out.push((seg_begin, symbols[seg_begin..].iter().collect::<String>()));
    }

    out
}
214
215fn split_paragraphs(text: &str) -> Vec<(usize, String)> {
217 let mut segments = Vec::new();
218 let mut char_offset = 0usize;
219
220 let mut remaining = text;
222 while let Some(pos) = find_paragraph_break(remaining) {
223 let para = &remaining[..pos];
224 let para_chars: Vec<char> = para.chars().collect();
225 if !para_chars.is_empty() {
226 segments.push((char_offset, para.to_string()));
227 }
228 char_offset += para.chars().count();
229
230 let break_str = &remaining[pos..];
232 let break_len = if break_str.starts_with("\r\n\r\n") {
233 4
234 } else {
235 2 };
237 let break_chars = remaining[pos..pos + break_len].chars().count();
238 char_offset += break_chars;
239 remaining = &remaining[pos + break_len..];
240 }
241
242 if !remaining.is_empty() {
244 segments.push((char_offset, remaining.to_string()));
245 }
246
247 segments
248}
249
/// Returns the byte index of the first paragraph break in `text`, if any.
///
/// A break is either `"\r\n\r\n"` or `"\n\n"`; when both occur, the one that
/// starts earlier wins.
fn find_paragraph_break(text: &str) -> Option<usize> {
    let crlf_break = text.find("\r\n\r\n");
    let lf_break = text.find("\n\n");

    match (crlf_break, lf_break) {
        (Some(crlf), Some(lf)) => Some(crlf.min(lf)),
        (Some(crlf), None) => Some(crlf),
        (None, lf) => lf,
    }
}
264
265fn build_chunks_from_segments(
269 segments: &[(usize, String)],
270 chunk_size: usize,
271 overlap: usize,
272) -> Result<Vec<TextChunk>, ChunkError> {
273 let mut chunks = Vec::new();
274 let mut current_text = String::new();
275 let mut current_start: Option<usize> = None;
276 let mut index = 0usize;
277
278 for (seg_offset, seg_text) in segments {
279 let seg_chars = seg_text.chars().count();
280
281 if seg_chars > chunk_size {
283 if !current_text.is_empty() {
285 let start = current_start.unwrap_or(0);
286 let end = start + current_text.chars().count();
287 chunks.push(TextChunk {
288 index,
289 start,
290 end,
291 text: std::mem::take(&mut current_text),
292 });
293 index += 1;
294 current_start = None;
295 }
296
297 let sub_chunks = chunk_by_characters(seg_text, chunk_size, overlap)?;
299 for sub in sub_chunks {
300 chunks.push(TextChunk {
301 index,
302 start: seg_offset + sub.start,
303 end: seg_offset + sub.end,
304 text: sub.text,
305 });
306 index += 1;
307 }
308 continue;
309 }
310
311 let current_chars = current_text.chars().count();
312 if current_chars + seg_chars > chunk_size && !current_text.is_empty() {
314 let start = current_start.unwrap_or(0);
316 let end = start + current_chars;
317 chunks.push(TextChunk {
318 index,
319 start,
320 end,
321 text: current_text.clone(),
322 });
323 index += 1;
324
325 if overlap > 0 && current_chars > overlap {
327 let chars: Vec<char> = current_text.chars().collect();
328 let overlap_chars = &chars[current_chars - overlap..];
329 current_text = overlap_chars.iter().collect();
330 current_start = Some(end - overlap);
331 } else {
332 current_text.clear();
333 current_start = None;
334 }
335 }
336
337 if current_start.is_none() {
338 current_start = Some(*seg_offset);
339 }
340 current_text.push_str(seg_text);
341 }
342
343 if !current_text.is_empty() {
345 let start = current_start.unwrap_or(0);
346 let end = start + current_text.chars().count();
347 chunks.push(TextChunk {
348 index,
349 start,
350 end,
351 text: current_text,
352 });
353 }
354
355 Ok(chunks)
356}
357
// Unit tests for the chunking entry point `chunk_text`, covering each
// strategy, parameter validation, and Unicode handling.
#[cfg(test)]
mod tests {
    use super::*;

    // 29 characters in windows of 10 with no overlap: 10 + 10 + 9.
    #[test]
    fn character_basic() {
        let text = "Hello, World! This is a test.";
        let chunks = chunk_text(text, 10, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].text, "Hello, Wor");
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 10);
        assert_eq!(chunks[1].text, "ld! This i");
        assert_eq!(chunks[2].text, "s a test.");
    }

    // Window 8 with overlap 3 advances by 5 characters per chunk, so the
    // windows over 16 characters are 0..8, 5..13 and 10..16.
    #[test]
    fn character_with_overlap() {
        let text = "abcdefghijklmnop";
        let chunks = chunk_text(text, 8, 3, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].text, "abcdefgh");
        assert_eq!(chunks[1].text, "fghijklm");
        assert_eq!(chunks[1].start, 5);
        assert_eq!(chunks[2].text, "klmnop");
    }

    // Two adjacent sentences do not fit in 20 characters together, so the
    // text must be split into more than one chunk.
    #[test]
    fn sentence_basic() {
        let text = "First sentence. Second sentence. Third sentence.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Sentence).unwrap();
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("First"));
    }

    // Paragraphs are separated by blank lines; two of them exceed 20
    // characters combined, so more than one chunk is expected.
    #[test]
    fn paragraph_basic() {
        let text = "Paragraph one.\n\nParagraph two.\n\nParagraph three.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Paragraph).unwrap();
        assert!(chunks.len() >= 2);
        assert!(chunks[0].text.contains("Paragraph one"));
    }

    // Empty input with valid parameters yields no chunks rather than an error.
    #[test]
    fn empty_text() {
        let chunks = chunk_text("", 10, 0, ChunkStrategy::Character).unwrap();
        assert!(chunks.is_empty());
    }

    // Text shorter than the window is returned as a single chunk.
    #[test]
    fn text_smaller_than_chunk() {
        let text = "short";
        let chunks = chunk_text(text, 100, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0].text, "short");
        assert_eq!(chunks[0].start, 0);
        assert_eq!(chunks[0].end, 5);
    }

    // A zero chunk size and an overlap equal to or larger than the chunk size
    // are both rejected.
    #[test]
    fn invalid_params() {
        assert!(chunk_text("text", 0, 0, ChunkStrategy::Character).is_err());
        assert!(chunk_text("text", 5, 5, ChunkStrategy::Character).is_err());
        assert!(chunk_text("text", 5, 10, ChunkStrategy::Character).is_err());
    }

    // Each emoji is several bytes long; windows are measured in characters,
    // so no emoji may be cut in the middle.
    #[test]
    fn utf8_safety() {
        let text = "🌍🌎🌏🌍🌎🌏";
        let chunks = chunk_text(text, 3, 0, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].text, "🌍🌎🌏");
        assert_eq!(chunks[1].text, "🌍🌎🌏");
    }

    // A single sentence longer than the chunk size is split by characters so
    // that every chunk still respects the limit.
    #[test]
    fn sentence_fallback_to_character() {
        let text = "This is a very long sentence that exceeds the chunk size limit.";
        let chunks = chunk_text(text, 20, 0, ChunkStrategy::Sentence).unwrap();
        assert!(chunks.len() > 1);
        for chunk in &chunks {
            assert!(chunk.text.chars().count() <= 20);
        }
    }

    // Running the chunker twice on the same input gives identical chunks.
    #[test]
    fn deterministic() {
        let text = "Deterministic output means same input produces same output every time.";
        let a = chunk_text(text, 15, 3, ChunkStrategy::Character).unwrap();
        let b = chunk_text(text, 15, 3, ChunkStrategy::Character).unwrap();
        assert_eq!(a.len(), b.len());
        for (ca, cb) in a.iter().zip(b.iter()) {
            assert_eq!(ca.text, cb.text);
            assert_eq!(ca.start, cb.start);
            assert_eq!(ca.end, cb.end);
        }
    }

    // With overlap 4, the last four characters of one chunk are the first
    // four characters of the next.
    #[test]
    fn overlap_produces_shared_chars() {
        let text = "0123456789abcdef";
        let chunks = chunk_text(text, 8, 4, ChunkStrategy::Character).unwrap();
        assert_eq!(chunks.len(), 3);
        // Take the tail by reversing, taking four, then reversing back.
        let c0_tail: String = chunks[0]
            .text
            .chars()
            .rev()
            .take(4)
            .collect::<Vec<_>>()
            .into_iter()
            .rev()
            .collect();
        let c1_head: String = chunks[1].text.chars().take(4).collect();
        assert_eq!(c0_tail, c1_head);
    }
}