//! Document chunking: strategies for splitting raw text into overlapping chunks.

use crate::{Chunk, Document, DocumentType, EmbeddingIds};
use uuid::Uuid;

/// Configuration for splitting a document into chunks.
#[derive(Debug, Clone)]
pub struct ChunkingConfig {
    /// Target chunk size, in tokens.
    pub chunk_size: usize,
    /// Overlap between consecutive chunks, in tokens.
    pub chunk_overlap: usize,
    /// Minimum chunk size, in characters.
    pub min_chunk_size: usize,
    /// Strategy used to decide where chunk boundaries fall.
    pub strategy: ChunkingStrategy,
    /// Whether chunking should try to respect sentence boundaries.
    pub respect_sentences: bool,
}

impl Default for ChunkingConfig {
    fn default() -> Self {
        Self {
            chunk_size: 512,
            chunk_overlap: 50,
            min_chunk_size: 100,
            strategy: ChunkingStrategy::Recursive,
            respect_sentences: true,
        }
    }
}

/// How chunk boundaries are chosen.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChunkingStrategy {
    /// Fixed-size windows with overlap, ignoring document structure.
    FixedSize,
    /// Paragraph-based chunking that keeps related prose together.
    Semantic,
    /// Recursive splitting on a hierarchy of delimiters.
    Recursive,
    /// Strategy picked per document type (code, markdown, prose).
    DocumentAware,
}

/// Errors produced by the chunking functions.
#[derive(Debug, thiserror::Error)]
pub enum ChunkingError {
    #[error("Text too short for chunking: {0} characters")]
    TextTooShort(usize),
    #[error("Invalid chunk size: {0}")]
    InvalidChunkSize(usize),
    #[error("Chunking failed: {0}")]
    ChunkingFailed(String),
}

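/// Split a document's raw text into chunks using the configured strategy.
///
/// Returns an empty `Vec` for empty input, `TextTooShort` when the text is
/// shorter than `min_chunk_size`, and `InvalidChunkSize` when `chunk_size`
/// is below `min_chunk_size`. A usage sketch (assuming a `Document` built
/// as in the tests below):
///
/// ```ignore
/// let config = ChunkingConfig::default();
/// let chunks = chunk_document(&doc, &config)?;
/// for chunk in &chunks {
///     println!("chunk {} covers {}..{}", chunk.index, chunk.start_char, chunk.end_char);
/// }
/// ```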
pub fn chunk_document(
    document: &Document,
    config: &ChunkingConfig,
) -> Result<Vec<Chunk>, ChunkingError> {
    let text = &document.content.raw;

    if text.is_empty() {
        return Ok(Vec::new());
    }

    if text.len() < config.min_chunk_size {
        return Err(ChunkingError::TextTooShort(text.len()));
    }

    if config.chunk_size < config.min_chunk_size {
        return Err(ChunkingError::InvalidChunkSize(config.chunk_size));
    }

    let chunks = match config.strategy {
        ChunkingStrategy::FixedSize => chunk_fixed_size(text, config, document.id),
        ChunkingStrategy::Semantic => chunk_semantic(text, config, document.id, &document.doc_type),
        ChunkingStrategy::Recursive => {
            chunk_recursive(text, config, document.id, &document.doc_type)
        }
        ChunkingStrategy::DocumentAware => {
            chunk_document_aware(text, config, document.id, &document.doc_type)
        }
    }?;

    Ok(chunks)
}

fn chunk_fixed_size(
    text: &str,
    config: &ChunkingConfig,
    _document_id: Uuid,
) -> Result<Vec<Chunk>, ChunkingError> {
    let mut chunks = Vec::new();
    let chunk_size_chars = estimate_chars_from_tokens(config.chunk_size);
    let overlap_chars = estimate_chars_from_tokens(config.chunk_overlap);

    let mut start = 0;
    let mut index = 0;

    while start < text.len() {
        // Clamp the window end to a char boundary so slicing cannot panic on
        // multi-byte UTF-8 characters.
        let mut end = (start + chunk_size_chars).min(text.len());
        while !text.is_char_boundary(end) {
            end -= 1;
        }
        let chunk_text = &text[start..end];

        // Stop rather than emit a trailing fragment below the minimum size.
        if chunk_text.trim().len() < config.min_chunk_size {
            break;
        }

        let token_count = super::estimate_tokens(chunk_text);

        chunks.push(Chunk {
            id: Uuid::new_v4(),
            text: chunk_text.to_string(),
            index,
            start_char: start,
            end_char: end,
            token_count: Some(token_count),
            section: None,
            page: None,
            embedding_ids: EmbeddingIds::default(),
        });
        index += 1;

        // Stop once the window reaches the end of the text; otherwise step
        // forward, keeping `overlap_chars` of trailing context, clamping to a
        // char boundary while guaranteeing forward progress.
        if end >= text.len() {
            break;
        }
        let mut next_start = end.saturating_sub(overlap_chars);
        if next_start <= start {
            next_start = start + 1;
        }
        while next_start < text.len() && !text.is_char_boundary(next_start) {
            next_start += 1;
        }
        start = next_start;
    }

    Ok(chunks)
}

fn chunk_semantic(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
    doc_type: &DocumentType,
) -> Result<Vec<Chunk>, ChunkingError> {
    // Prefer paragraph boundaries; fall back to fixed-size chunking when the
    // text has no paragraph structure.
    let paragraphs = super::split_paragraphs(text);

    if paragraphs.is_empty() {
        return chunk_fixed_size(text, config, document_id);
    }

    let mut chunks = Vec::new();
    let mut current_chunk = String::new();
    let mut current_start = 0;
    let mut chunk_index = 0;
    let chunk_size_chars = estimate_chars_from_tokens(config.chunk_size);
    let overlap_chars = estimate_chars_from_tokens(config.chunk_overlap);

    for paragraph in paragraphs.iter() {
        let para_text = paragraph.trim();
        if para_text.is_empty() {
            continue;
        }

        // Flush the current chunk when adding this paragraph would overflow it.
        if !current_chunk.is_empty()
            && (current_chunk.len() + para_text.len() + 1) > chunk_size_chars
        {
            let end_pos = current_start + current_chunk.len();
            let token_count = super::estimate_tokens(&current_chunk);

            chunks.push(Chunk {
                id: Uuid::new_v4(),
                text: current_chunk.clone(),
                index: chunk_index,
                start_char: current_start,
                end_char: end_pos,
                token_count: Some(token_count),
                section: extract_section_header(paragraph, doc_type),
                page: None,
                embedding_ids: EmbeddingIds::default(),
            });

            // Seed the next chunk with trailing context from the previous one.
            let overlap_text = extract_overlap(&current_chunk, overlap_chars);
            current_chunk = format!("{}{}", overlap_text, para_text);
            current_start = end_pos.saturating_sub(overlap_chars);
            chunk_index += 1;
        } else {
            if current_chunk.is_empty() {
                current_start = text.find(para_text).unwrap_or(current_start);
            } else {
                current_chunk.push_str("\n\n");
            }
            current_chunk.push_str(para_text);
        }
    }

    // Flush whatever remains.
    if !current_chunk.trim().is_empty() {
        let end_pos = current_start + current_chunk.len();
        let token_count = super::estimate_tokens(&current_chunk);

        chunks.push(Chunk {
            id: Uuid::new_v4(),
            text: current_chunk,
            index: chunk_index,
            start_char: current_start,
            end_char: end_pos,
            token_count: Some(token_count),
            section: None,
            page: None,
            embedding_ids: EmbeddingIds::default(),
        });
    }

    Ok(chunks)
}

fn chunk_recursive(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
    doc_type: &DocumentType,
) -> Result<Vec<Chunk>, ChunkingError> {
    // Code gets an extra, coarser delimiter: triple newlines often separate
    // top-level items such as functions.
    let delimiters = if matches!(doc_type, DocumentType::Code) {
        vec!["\n\n\n", "\n\n", "\n", ". ", " "]
    } else {
        vec!["\n\n", "\n", ". ", " "]
    };

    chunk_recursive_internal(text, config, document_id, &delimiters, 0)
}

fn chunk_recursive_internal(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
    delimiters: &[&str],
    delimiter_idx: usize,
) -> Result<Vec<Chunk>, ChunkingError> {
    // Out of delimiters to try: fall back to fixed-size windows.
    if delimiter_idx >= delimiters.len() {
        return chunk_fixed_size(text, config, document_id);
    }

    let delimiter = delimiters[delimiter_idx];
    let chunk_size_chars = estimate_chars_from_tokens(config.chunk_size);
    let overlap_chars = estimate_chars_from_tokens(config.chunk_overlap);

    let parts: Vec<&str> = text.split(delimiter).collect();

    // This delimiter does not split the text; try the next, finer one.
    if parts.len() <= 1 {
        return chunk_recursive_internal(text, config, document_id, delimiters, delimiter_idx + 1);
    }

    let mut chunks = Vec::new();
    let mut current_chunk = String::new();
    let mut current_start = 0;
    let mut chunk_index = 0;

    for part in parts {
        let part_trimmed = part.trim();
        if part_trimmed.is_empty() {
            continue;
        }

        let part_with_delim = if current_chunk.is_empty() {
            part_trimmed.to_string()
        } else {
            format!("{}{}", delimiter, part_trimmed)
        };

        if (current_chunk.len() + part_with_delim.len()) > chunk_size_chars
            && !current_chunk.is_empty()
        {
            let end_pos = current_start + current_chunk.len();
            let token_count = super::estimate_tokens(&current_chunk);

            chunks.push(Chunk {
                id: Uuid::new_v4(),
                text: current_chunk.clone(),
                index: chunk_index,
                start_char: current_start,
                end_char: end_pos,
                token_count: Some(token_count),
                section: None,
                page: None,
                embedding_ids: EmbeddingIds::default(),
            });

            let overlap_text = extract_overlap(&current_chunk, overlap_chars);
            current_chunk = format!("{}{}", overlap_text, part_with_delim);
            current_start = end_pos.saturating_sub(overlap_chars);
            chunk_index += 1;
        } else {
            if current_chunk.is_empty() {
                current_start = text.find(part_trimmed).unwrap_or(current_start);
            }
            current_chunk.push_str(&part_with_delim);
        }
    }

    if !current_chunk.trim().is_empty() {
        let end_pos = current_start + current_chunk.len();
        let token_count = super::estimate_tokens(&current_chunk);

        chunks.push(Chunk {
            id: Uuid::new_v4(),
            text: current_chunk,
            index: chunk_index,
            start_char: current_start,
            end_char: end_pos,
            token_count: Some(token_count),
            section: None,
            page: None,
            embedding_ids: EmbeddingIds::default(),
        });
    }

    // Re-split any chunk that is still far larger than the target size,
    // using the next, finer delimiter.
    let mut final_chunks = Vec::new();
    for chunk in chunks {
        if chunk.text.len() > chunk_size_chars * 2 {
            let sub_chunks = chunk_recursive_internal(
                &chunk.text,
                config,
                document_id,
                delimiters,
                delimiter_idx + 1,
            )?;
            final_chunks.extend(sub_chunks);
        } else {
            final_chunks.push(chunk);
        }
    }

    Ok(final_chunks)
}

fn chunk_document_aware(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
    doc_type: &DocumentType,
) -> Result<Vec<Chunk>, ChunkingError> {
    match doc_type {
        DocumentType::Code => chunk_code_aware(text, config, document_id),
        DocumentType::Documentation | DocumentType::Paper => {
            // Use header-aware chunking when the text looks like markdown.
            if text.contains('#') {
                chunk_markdown_aware(text, config, document_id)
            } else {
                chunk_semantic(text, config, document_id, doc_type)
            }
        }
        _ => chunk_recursive(text, config, document_id, doc_type),
    }
}

fn chunk_code_aware(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
) -> Result<Vec<Chunk>, ChunkingError> {
    // Triple newlines are a rough boundary between top-level items
    // (functions, types) in most source files.
    let parts: Vec<&str> = text.split("\n\n\n").collect();

    if parts.len() <= 1 {
        return chunk_recursive(text, config, document_id, &DocumentType::Code);
    }

    let mut chunks = Vec::new();
    let mut current_pos = 0;

    for (idx, part) in parts.iter().enumerate() {
        let part_trimmed = part.trim();
        if part_trimmed.is_empty() {
            continue;
        }

        // Locate the trimmed part in the original text to record offsets.
        let start_pos = text[current_pos..]
            .find(part_trimmed)
            .map(|p| current_pos + p)
            .unwrap_or(current_pos);
        let end_pos = start_pos + part_trimmed.len();
        let token_count = super::estimate_tokens(part_trimmed);

        chunks.push(Chunk {
            id: Uuid::new_v4(),
            text: part_trimmed.to_string(),
            index: idx,
            start_char: start_pos,
            end_char: end_pos,
            token_count: Some(token_count),
            section: extract_function_name(part_trimmed),
            page: None,
            embedding_ids: EmbeddingIds::default(),
        });

        current_pos = end_pos;
    }

    Ok(chunks)
}

fn chunk_markdown_aware(
    text: &str,
    config: &ChunkingConfig,
    document_id: Uuid,
) -> Result<Vec<Chunk>, ChunkingError> {
    // Match ATX-style markdown headers (one to six '#' at the start of a line).
    let header_pattern = regex::Regex::new(r"(?m)^#{1,6}\s+.+$").expect("Invalid regex pattern");
    let mut chunks = Vec::new();
    let mut last_header_end = 0;
    let mut chunk_index = 0;
    let chunk_size_chars = estimate_chars_from_tokens(config.chunk_size);

    for mat in header_pattern.find_iter(text) {
        let header_start = mat.start();

        if header_start > last_header_end {
            let section_text = text[last_header_end..header_start].trim();
            if !section_text.is_empty() && section_text.len() >= config.min_chunk_size {
                let token_count = super::estimate_tokens(section_text);
                // Each section begins with its own header (if any), so take
                // the last header found inside the section text.
                let header_text = extract_previous_header(section_text);

                chunks.push(Chunk {
                    id: Uuid::new_v4(),
                    text: section_text.to_string(),
                    index: chunk_index,
                    start_char: last_header_end,
                    end_char: header_start,
                    token_count: Some(token_count),
                    section: header_text,
                    page: None,
                    embedding_ids: EmbeddingIds::default(),
                });
                chunk_index += 1;
            }
        }

        last_header_end = header_start;
    }

    // Emit the final section after the last header.
    if last_header_end < text.len() {
        let section_text = text[last_header_end..].trim();
        if !section_text.is_empty() && section_text.len() >= config.min_chunk_size {
            let token_count = super::estimate_tokens(section_text);
            let header_text = extract_previous_header(section_text);

            chunks.push(Chunk {
                id: Uuid::new_v4(),
                text: section_text.to_string(),
                index: chunk_index,
                start_char: last_header_end,
                end_char: text.len(),
                token_count: Some(token_count),
                section: header_text,
                page: None,
                embedding_ids: EmbeddingIds::default(),
            });
        }
    }

    // Fall back to semantic chunking when header-based sections are missing
    // or far larger than the target size.
    if chunks.is_empty() || chunks.iter().any(|c| c.text.len() > chunk_size_chars * 2) {
        return chunk_semantic(text, config, document_id, &DocumentType::Documentation);
    }

    Ok(chunks)
}

/// Rough character estimate for a token count, using the common ~4
/// characters-per-token heuristic for English-like text.
fn estimate_chars_from_tokens(tokens: usize) -> usize {
    tokens * 4
}

fn extract_overlap(text: &str, overlap_chars: usize) -> String {
    if text.len() <= overlap_chars {
        return text.to_string();
    }

    // Clamp to a char boundary so slicing cannot panic on multi-byte UTF-8.
    let mut start_search = text.len().saturating_sub(overlap_chars);
    while !text.is_char_boundary(start_search) {
        start_search += 1;
    }
    let overlap_region = &text[start_search..];

    // Prefer starting the overlap at a sentence boundary: an uppercase letter
    // preceded by ". ", "! ", or "? ".
    if let Some(sentence_start) = overlap_region.find(|c: char| c.is_uppercase()) {
        if start_search + sentence_start >= 2
            && text.is_char_boundary(start_search + sentence_start - 2)
        {
            let prev_chars =
                &text[start_search + sentence_start - 2..start_search + sentence_start];
            if prev_chars.ends_with(". ")
                || prev_chars.ends_with("! ")
                || prev_chars.ends_with("? ")
            {
                return text[start_search + sentence_start..].to_string();
            }
        }
    }

    text[start_search..].to_string()
}

fn extract_section_header(paragraph: &str, _doc_type: &DocumentType) -> Option<String> {
    // Markdown-style header on the first line.
    if let Some(header_match) = regex::Regex::new(r"^#{1,6}\s+(.+)$")
        .ok()
        .and_then(|re| re.captures(paragraph.lines().next().unwrap_or("")))
    {
        return header_match.get(1).map(|m| m.as_str().trim().to_string());
    }

    // All-caps first line, as in plain-text section titles.
    if let Some(first_line) = paragraph.lines().next() {
        if first_line.len() > 5
            && first_line
                .chars()
                .all(|c| c.is_uppercase() || c.is_whitespace() || c.is_ascii_punctuation())
        {
            return Some(first_line.trim().to_string());
        }
    }

    None
}

fn extract_function_name(code: &str) -> Option<String> {
    // Function-definition patterns for Rust, JavaScript, and Python.
    let patterns = vec![
        r"fn\s+(\w+)",
        r"function\s+(\w+)",
        r"def\s+(\w+)",
        r"pub\s+fn\s+(\w+)",
    ];

    for pattern in patterns {
        if let Some(captures) = regex::Regex::new(pattern)
            .ok()
            .and_then(|re| re.captures(code))
        {
            return captures.get(1).map(|m| m.as_str().to_string());
        }
    }

    None
}

fn extract_previous_header(text: &str) -> Option<String> {
    regex::Regex::new(r"(?m)^#{1,6}\s+(.+)$")
        .ok()
        .and_then(|re| {
            re.captures_iter(text)
                .last()
                .and_then(|cap| cap.get(1))
                .map(|m| m.as_str().trim().to_string())
        })
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DocumentType, Source, SourceType};
    use chrono::Utc;

    fn create_test_document(text: &str, doc_type: DocumentType) -> Document {
        let source = Source {
            source_type: SourceType::Local,
            url: None,
            path: Some("/test/doc.txt".to_string()),
            arxiv_id: None,
            github_repo: None,
            retrieved_at: Utc::now(),
            version: None,
        };

        Document::new(doc_type, source).with_content(text.to_string())
    }

    #[test]
    fn test_fixed_size_chunking() {
        let text = "This is a test document. ".repeat(100);
        let doc = create_test_document(&text, DocumentType::Note);
        let config = ChunkingConfig {
            chunk_size: 512,
            chunk_overlap: 50,
            min_chunk_size: 100,
            strategy: ChunkingStrategy::FixedSize,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(!chunks.is_empty());
        assert!(chunks.len() > 1);

        // Every chunk should be non-empty, well-ordered, and carry a token count.
        for chunk in &chunks {
            assert!(!chunk.text.is_empty());
            assert!(chunk.start_char < chunk.end_char);
            assert!(chunk.token_count.is_some());
        }
    }
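
    // A sketch of the overlap invariant: with FixedSize chunking, each chunk
    // should begin before the previous one ends.
    #[test]
    fn test_fixed_size_overlap() {
        let text = "This is a test document. ".repeat(100);
        let doc = create_test_document(&text, DocumentType::Note);
        let config = ChunkingConfig {
            strategy: ChunkingStrategy::FixedSize,
            ..Default::default()
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(chunks.len() > 1);
        for pair in chunks.windows(2) {
            assert!(pair[1].start_char < pair[0].end_char);
        }
    }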

    #[test]
    fn test_semantic_chunking() {
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let doc = create_test_document(text, DocumentType::Documentation);
        let config = ChunkingConfig {
            chunk_size: 100,
            chunk_overlap: 10,
            min_chunk_size: 10,
            strategy: ChunkingStrategy::Semantic,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(!chunks.is_empty());
    }
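
    // A sketch of coalescing behavior: paragraphs that together fit within
    // the target chunk size should be merged into a single chunk.
    #[test]
    fn test_semantic_coalesces_small_paragraphs() {
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let doc = create_test_document(text, DocumentType::Documentation);
        let config = ChunkingConfig {
            chunk_size: 100,
            chunk_overlap: 10,
            min_chunk_size: 10,
            strategy: ChunkingStrategy::Semantic,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert_eq!(chunks.len(), 1);
        assert!(chunks[0].text.contains("Third paragraph."));
    }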

    #[test]
    fn test_recursive_chunking() {
        let text = "Sentence one. Sentence two. Sentence three. ".repeat(20);
        let doc = create_test_document(&text, DocumentType::Note);
        let config = ChunkingConfig {
            chunk_size: 200,
            chunk_overlap: 20,
            min_chunk_size: 50,
            strategy: ChunkingStrategy::Recursive,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(!chunks.is_empty());
    }
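
    // A sketch of the code path: DocumentAware chunking of DocumentType::Code
    // splits on triple newlines and labels each chunk with the function name
    // found in it.
    #[test]
    fn test_code_aware_chunking() {
        let text = "fn first() {\n    let a = 1;\n}\n\n\nfn second() {\n    let b = 2;\n}";
        let doc = create_test_document(text, DocumentType::Code);
        let config = ChunkingConfig {
            chunk_size: 200,
            chunk_overlap: 10,
            min_chunk_size: 5,
            strategy: ChunkingStrategy::DocumentAware,
            respect_sentences: false,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].section.as_deref(), Some("first"));
        assert_eq!(chunks[1].section.as_deref(), Some("second"));
    }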

    #[test]
    fn test_markdown_aware_chunking() {
        let text =
            "# Header 1\n\nContent under header 1.\n\n## Header 2\n\nContent under header 2.";
        let doc = create_test_document(text, DocumentType::Documentation);
        let config = ChunkingConfig {
            chunk_size: 200,
            chunk_overlap: 10,
            min_chunk_size: 10,
            strategy: ChunkingStrategy::DocumentAware,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(!chunks.is_empty());

        // At least one chunk should carry a section header.
        assert!(chunks.iter().any(|c| c.section.is_some()));
    }
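
    // A sketch of section labeling: each markdown section should carry the
    // header that introduces it.
    #[test]
    fn test_markdown_section_labels() {
        let text =
            "# Header 1\n\nContent under header 1.\n\n## Header 2\n\nContent under header 2.";
        let doc = create_test_document(text, DocumentType::Documentation);
        let config = ChunkingConfig {
            chunk_size: 200,
            chunk_overlap: 10,
            min_chunk_size: 10,
            strategy: ChunkingStrategy::DocumentAware,
            respect_sentences: true,
        };

        let chunks = chunk_document(&doc, &config).unwrap();
        assert_eq!(chunks.len(), 2);
        assert_eq!(chunks[0].section.as_deref(), Some("Header 1"));
        assert_eq!(chunks[1].section.as_deref(), Some("Header 2"));
    }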

    #[test]
    fn test_empty_text() {
        let doc = create_test_document("", DocumentType::Note);
        let config = ChunkingConfig::default();
        let chunks = chunk_document(&doc, &config).unwrap();
        assert!(chunks.is_empty());
    }

    #[test]
    fn test_text_too_short() {
        let doc = create_test_document("Short", DocumentType::Note);
        let config = ChunkingConfig {
            min_chunk_size: 100,
            ..Default::default()
        };
        let result = chunk_document(&doc, &config);
        assert!(result.is_err());
    }
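
    // A sketch of the remaining validation path: a chunk_size smaller than
    // min_chunk_size is rejected up front.
    #[test]
    fn test_invalid_chunk_size() {
        let text = "word ".repeat(100);
        let doc = create_test_document(&text, DocumentType::Note);
        let config = ChunkingConfig {
            chunk_size: 10,
            min_chunk_size: 100,
            ..Default::default()
        };
        let result = chunk_document(&doc, &config);
        assert!(matches!(result, Err(ChunkingError::InvalidChunkSize(10))));
    }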

    #[test]
    fn test_overlap_extraction() {
        let text = "This is a long sentence. This is another sentence. Final sentence.";
        let overlap = extract_overlap(text, 20);
        assert!(!overlap.is_empty());
        assert!(overlap.len() <= 20);
        // The overlap should snap to the start of the last sentence.
        assert_eq!(overlap, "Final sentence.");
    }
730}