Skip to main content

retrieval_kit/
chunking.rs

/// Character-based chunking configuration used by [`chunk_text`].
///
/// Both sizes are measured in Unicode scalar values (`char`s), not bytes.
/// `chunk_text` rejects a configuration whose `chunk_size` is zero or whose
/// `overlap_size` is not strictly smaller than `chunk_size`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct ChunkingConfig {
    /// Target chunk size in Unicode scalar values. Actual chunks can be somewhat
    /// shorter or longer because split points snap to nearby text boundaries.
    pub chunk_size: usize,
    /// Approximate number of characters shared by adjacent chunks; `0` disables overlap.
    pub overlap_size: usize,
}

impl Default for ChunkingConfig {
    /// 1500-character chunks with roughly 200 characters of overlap.
    fn default() -> Self {
        ChunkingConfig {
            chunk_size: 1500,
            overlap_size: 200,
        }
    }
}
18
/// Errors returned by [`chunk_text`].
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum ChunkingError {
    /// `chunk_size` was zero.
    InvalidChunkSize,
    /// `overlap_size` was not strictly smaller than `chunk_size`.
    InvalidOverlapSize {
        /// The rejected chunk size.
        chunk_size: usize,
        /// The rejected overlap size.
        overlap_size: usize,
    },
    /// Tokenization of the document failed.
    ///
    /// NOTE(review): the character-based `chunk_text` never returns this variant;
    /// confirm which caller produces it.
    EmbeddingTokenizer,
}

impl std::fmt::Display for ChunkingError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            ChunkingError::InvalidChunkSize => f.write_str("chunk_size must be greater than zero"),
            ChunkingError::InvalidOverlapSize {
                chunk_size,
                overlap_size,
            } => write!(
                f,
                "overlap_size ({}) must be smaller than chunk_size ({})",
                overlap_size, chunk_size
            ),
            ChunkingError::EmbeddingTokenizer => {
                f.write_str("failed to tokenize document for chunking")
            }
        }
    }
}

/// No underlying cause is carried by any variant, so the default `source()` (`None`) applies.
impl std::error::Error for ChunkingError {}
46
/// How good a position is for splitting text, ordered weakest to strongest.
///
/// The derived `Ord` follows declaration order. The boundary searches compare strength
/// before distance, so a stronger boundary wins over a closer weaker one. Do not
/// reorder the variants.
#[derive(Clone, Copy, Debug, Eq, Ord, PartialEq, PartialOrd)]
enum BoundaryStrength {
    /// The preceding character is whitespace.
    Whitespace,
    /// Follows `;`, `:` or `,` (optionally followed by closing quotes/brackets), with
    /// whitespace next.
    Weak,
    /// Follows `.`, `!` or `?` (optionally followed by closing quotes/brackets), with
    /// whitespace next.
    Medium,
    /// Start of a new line, which includes paragraph breaks.
    Strong,
}
54
/// How far, in chars, `choose_chunk_end` searches backwards from the ideal chunk end
/// for a split point.
const LOOKBACK_WINDOW: usize = 400;
/// How far, in chars, `choose_chunk_end` searches forwards from the ideal chunk end
/// when the backward search finds nothing.
const LOOKAHEAD_WINDOW: usize = 200;
/// Half-width, in chars, of the window around the desired overlap start that
/// `choose_overlap_start` searches.
const OVERLAP_WINDOW: usize = 120;
58
59/// Split text into trimmed, non-empty character chunks with approximate overlap.
60pub fn chunk_text(text: &str, config: ChunkingConfig) -> Result<Vec<String>, ChunkingError> {
61    validate_config(config)?;
62
63    if text.trim().is_empty() {
64        return Ok(Vec::new());
65    }
66
67    let chars: Vec<char> = text.chars().collect();
68    let total_chars = chars.len();
69    let mut boundaries = collect_boundaries(text, &chars);
70    boundaries.push(text.len());
71
72    let mut chunks = Vec::new();
73    let mut start_char = first_non_whitespace_char(&chars, 0);
74
75    while start_char < total_chars {
76        let end_char = choose_chunk_end(&chars, &boundaries, start_char, config.chunk_size);
77        let start_byte = char_to_byte_idx(text, start_char);
78        let end_byte = char_to_byte_idx(text, end_char);
79
80        let chunk = text[start_byte..end_byte].trim();
81        if !chunk.is_empty() {
82            chunks.push(chunk.to_string());
83        }
84
85        if end_char >= total_chars {
86            break;
87        }
88
89        let next_start = choose_overlap_start(
90            &chars,
91            &boundaries,
92            start_char,
93            end_char,
94            config.overlap_size,
95        );
96
97        if next_start <= start_char {
98            start_char = first_non_whitespace_char(&chars, end_char);
99        } else {
100            start_char = first_non_whitespace_char(&chars, next_start);
101        }
102    }
103
104    Ok(chunks)
105}
106
107fn validate_config(config: ChunkingConfig) -> Result<(), ChunkingError> {
108    if config.chunk_size == 0 {
109        return Err(ChunkingError::InvalidChunkSize);
110    }
111
112    if config.overlap_size >= config.chunk_size {
113        return Err(ChunkingError::InvalidOverlapSize {
114            chunk_size: config.chunk_size,
115            overlap_size: config.overlap_size,
116        });
117    }
118
119    Ok(())
120}
121
122fn collect_boundaries(text: &str, chars: &[char]) -> Vec<usize> {
123    let mut boundaries = Vec::new();
124
125    for boundary in 1..chars.len() {
126        if classify_boundary(chars, boundary).is_some() {
127            boundaries.push(char_to_byte_idx(text, boundary));
128        }
129    }
130
131    boundaries
132}
133
134fn choose_chunk_end(
135    chars: &[char],
136    boundaries: &[usize],
137    start_char: usize,
138    chunk_size: usize,
139) -> usize {
140    let total_chars = chars.len();
141    let ideal_end = (start_char + chunk_size).min(total_chars);
142
143    if ideal_end >= total_chars {
144        return total_chars;
145    }
146
147    let lookback_start = (ideal_end.saturating_sub(LOOKBACK_WINDOW)).max(start_char + 1);
148    let before = best_boundary_in_range(chars, boundaries, lookback_start, ideal_end, true);
149    if let Some(boundary) = before {
150        return boundary;
151    }
152
153    let lookahead_end = (ideal_end + LOOKAHEAD_WINDOW).min(total_chars);
154    let after = best_boundary_in_range(chars, boundaries, ideal_end + 1, lookahead_end, false);
155    if let Some(boundary) = after {
156        return boundary;
157    }
158
159    nearest_safe_boundary(chars, start_char + 1, total_chars, ideal_end).unwrap_or(ideal_end)
160}
161
162fn choose_overlap_start(
163    chars: &[char],
164    boundaries: &[usize],
165    start_char: usize,
166    end_char: usize,
167    overlap_size: usize,
168) -> usize {
169    if overlap_size == 0 || end_char <= start_char + 1 {
170        return end_char;
171    }
172
173    let desired_start = end_char.saturating_sub(overlap_size).max(start_char + 1);
174    let search_start = desired_start
175        .saturating_sub(OVERLAP_WINDOW)
176        .max(start_char + 1);
177    let search_end = (desired_start + OVERLAP_WINDOW).min(end_char.saturating_sub(1));
178
179    if search_start <= search_end
180        && let Some(boundary) =
181            nearest_semantic_boundary(chars, boundaries, search_start, search_end, desired_start)
182    {
183        return boundary;
184    }
185
186    end_char
187}
188
189fn best_boundary_in_range(
190    chars: &[char],
191    boundaries: &[usize],
192    start_char: usize,
193    end_char: usize,
194    prefer_later: bool,
195) -> Option<usize> {
196    let mut best: Option<(BoundaryStrength, usize, usize)> = None;
197
198    for boundary_byte in boundaries {
199        let boundary_char = byte_to_char_idx(chars, *boundary_byte);
200        if boundary_char < start_char || boundary_char > end_char {
201            continue;
202        }
203
204        let Some(strength) = classify_boundary(chars, boundary_char) else {
205            continue;
206        };
207
208        let distance = if prefer_later {
209            end_char.saturating_sub(boundary_char)
210        } else {
211            boundary_char.saturating_sub(start_char)
212        };
213
214        match best {
215            None => best = Some((strength, distance, boundary_char)),
216            Some((best_strength, best_distance, best_boundary)) => {
217                let should_replace = strength > best_strength
218                    || (strength == best_strength
219                        && (distance < best_distance
220                            || (distance == best_distance
221                                && ((prefer_later && boundary_char > best_boundary)
222                                    || (!prefer_later && boundary_char < best_boundary)))));
223                if should_replace {
224                    best = Some((strength, distance, boundary_char));
225                }
226            }
227        }
228    }
229
230    best.map(|(_, _, boundary)| boundary)
231}
232
233fn nearest_semantic_boundary(
234    chars: &[char],
235    boundaries: &[usize],
236    start_char: usize,
237    end_char: usize,
238    desired_char: usize,
239) -> Option<usize> {
240    let mut best: Option<(BoundaryStrength, usize, usize)> = None;
241
242    for boundary_byte in boundaries {
243        let boundary_char = byte_to_char_idx(chars, *boundary_byte);
244        if boundary_char < start_char || boundary_char > end_char {
245            continue;
246        }
247
248        let Some(strength) = classify_boundary(chars, boundary_char) else {
249            continue;
250        };
251
252        let distance = boundary_char.abs_diff(desired_char);
253        match best {
254            None => best = Some((strength, distance, boundary_char)),
255            Some((best_strength, best_distance, best_boundary)) => {
256                let should_replace = strength > best_strength
257                    || (strength == best_strength
258                        && (distance < best_distance
259                            || (distance == best_distance && boundary_char < best_boundary)));
260                if should_replace {
261                    best = Some((strength, distance, boundary_char));
262                }
263            }
264        }
265    }
266
267    best.map(|(_, _, boundary)| boundary)
268}
269
270fn nearest_safe_boundary(
271    chars: &[char],
272    start_char: usize,
273    end_char: usize,
274    desired_char: usize,
275) -> Option<usize> {
276    if start_char > end_char {
277        return None;
278    }
279
280    let mut best_whitespace: Option<(usize, usize)> = None;
281
282    for boundary in start_char..=end_char {
283        let Some(strength) = classify_boundary(chars, boundary) else {
284            continue;
285        };
286
287        if strength == BoundaryStrength::Whitespace {
288            let distance = boundary.abs_diff(desired_char);
289            match best_whitespace {
290                None => best_whitespace = Some((distance, boundary)),
291                Some((best_distance, best_boundary)) => {
292                    if distance < best_distance
293                        || (distance == best_distance && boundary < best_boundary)
294                    {
295                        best_whitespace = Some((distance, boundary));
296                    }
297                }
298            }
299        }
300    }
301
302    if let Some((_, boundary)) = best_whitespace {
303        return Some(boundary);
304    }
305
306    Some(desired_char.clamp(start_char, end_char))
307}
308
309fn classify_boundary(chars: &[char], boundary: usize) -> Option<BoundaryStrength> {
310    if boundary == 0 || boundary >= chars.len() {
311        return None;
312    }
313
314    let prev = chars[boundary - 1];
315    let next = chars[boundary];
316
317    if is_paragraph_break(chars, boundary) {
318        return Some(BoundaryStrength::Strong);
319    }
320
321    if prev == '\n' {
322        return Some(BoundaryStrength::Strong);
323    }
324
325    let semantic_prev = previous_semantic_char(chars, boundary - 1);
326    if matches!(semantic_prev, Some('.' | '!' | '?')) && next.is_whitespace() {
327        return Some(BoundaryStrength::Medium);
328    }
329
330    if matches!(semantic_prev, Some(';' | ':' | ',')) && next.is_whitespace() {
331        return Some(BoundaryStrength::Weak);
332    }
333
334    if prev.is_whitespace() {
335        return Some(BoundaryStrength::Whitespace);
336    }
337
338    None
339}
340
/// Reports whether `boundary` sits right after a `'\n'` and begins a blank line. That is,
/// the next `'\n'` arrives before any non-whitespace char.
///
/// Panics if `boundary` is zero or greater than `chars.len()`.
fn is_paragraph_break(chars: &[char], boundary: usize) -> bool {
    if chars[boundary - 1] != '\n' {
        return false;
    }

    let first_significant = chars[boundary..]
        .iter()
        .copied()
        .find(|&ch| ch == '\n' || !ch.is_whitespace());
    first_significant == Some('\n')
}
360
/// Walks backwards from `idx` (inclusive) past closing quotes and brackets
/// (`"`, `'`, `)`, `]`, `}`) and returns the first other char. Returns `None` if every
/// char from the start of `chars` through `idx` is such a closer.
///
/// Panics if `idx >= chars.len()`.
fn previous_semantic_char(chars: &[char], idx: usize) -> Option<char> {
    chars[..=idx]
        .iter()
        .rev()
        .copied()
        .find(|&ch| !matches!(ch, '"' | '\'' | ')' | ']' | '}'))
}
373
/// Returns the first index at or after `idx` whose char is not whitespace. If only
/// whitespace remains, returns `chars.len()`; an `idx` already past the end is returned
/// unchanged.
fn first_non_whitespace_char(chars: &[char], idx: usize) -> usize {
    match chars.get(idx..) {
        Some(rest) => idx + rest.iter().take_while(|ch| ch.is_whitespace()).count(),
        None => idx,
    }
}
380
/// Byte offset in `text` of the char at index `char_idx`. Returns `text.len()` when
/// `char_idx` is at or past the end of the text.
///
/// Runs in O(`char_idx`) because it walks the string from the start.
fn char_to_byte_idx(text: &str, char_idx: usize) -> usize {
    match text.char_indices().nth(char_idx) {
        Some((byte_idx, _)) => byte_idx,
        None => text.len(),
    }
}
391
/// Index of the char whose UTF-8 encoding starts at `byte_idx`, measured in the string
/// that `chars` was collected from.
///
/// Returns `chars.len()` in three cases: `byte_idx` equals the total encoded length,
/// lies past it, or falls inside a multi-byte char. Runs in O(`chars.len()`) in the
/// worst case.
fn byte_to_char_idx(chars: &[char], byte_idx: usize) -> usize {
    chars
        .iter()
        .scan(0usize, |offset, ch| {
            let char_start = *offset;
            *offset += ch.len_utf8();
            Some(char_start)
        })
        .position(|char_start| char_start == byte_idx)
        .unwrap_or(chars.len())
}
402
#[cfg(test)]
mod tests {
    use super::{ChunkingConfig, ChunkingError, chunk_text};

    #[test]
    fn uses_expected_defaults() {
        let config = ChunkingConfig::default();

        assert_eq!(config.chunk_size, 1500);
        assert_eq!(config.overlap_size, 200);
    }

    #[test]
    fn rejects_zero_chunk_size() {
        let error = chunk_text(
            "hello world",
            ChunkingConfig {
                chunk_size: 0,
                overlap_size: 0,
            },
        )
        .unwrap_err();

        assert_eq!(error, ChunkingError::InvalidChunkSize);
    }

    #[test]
    fn rejects_overlap_equal_to_chunk_size() {
        let error = chunk_text(
            "hello world",
            ChunkingConfig {
                chunk_size: 200,
                overlap_size: 200,
            },
        )
        .unwrap_err();

        assert_eq!(
            error,
            ChunkingError::InvalidOverlapSize {
                chunk_size: 200,
                overlap_size: 200,
            }
        );
    }

    #[test]
    fn rejects_overlap_larger_than_chunk_size() {
        let error = chunk_text(
            "hello world",
            ChunkingConfig {
                chunk_size: 10,
                overlap_size: 11,
            },
        )
        .unwrap_err();

        assert_eq!(
            error,
            ChunkingError::InvalidOverlapSize {
                chunk_size: 10,
                overlap_size: 11,
            }
        );
    }

    // The config is validated before the blank-input short-circuit, so a bad config is
    // reported even when there is nothing to chunk.
    #[test]
    fn validates_config_even_for_blank_input() {
        let error = chunk_text(
            "   ",
            ChunkingConfig {
                chunk_size: 0,
                overlap_size: 0,
            },
        )
        .unwrap_err();

        assert_eq!(error, ChunkingError::InvalidChunkSize);
    }

    #[test]
    fn returns_single_chunk_for_short_input() {
        let chunks = chunk_text(
            "Small inputs should stay together.",
            ChunkingConfig {
                chunk_size: 100,
                overlap_size: 20,
            },
        )
        .unwrap();

        assert_eq!(chunks, vec!["Small inputs should stay together."]);
    }

    #[test]
    fn prefers_paragraph_boundaries() {
        let text = "First paragraph has a few sentences. It should stay grouped.\n\nSecond paragraph also has enough text to trigger chunking when the target is small.";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 70,
                overlap_size: 15,
            },
        )
        .unwrap();

        assert!(chunks.len() >= 2);
        assert!(chunks[0].ends_with("grouped."));
        // The second chunk may begin with overlap from the first paragraph, so only require
        // that the second paragraph's opening appears in it.
        assert!(chunks[1].contains("Second paragraph"));
    }

    #[test]
    fn prefers_sentence_boundaries() {
        let text = "Alpha beta gamma. Delta epsilon zeta. Eta theta iota. Kappa lambda mu.";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 30,
                overlap_size: 8,
            },
        )
        .unwrap();

        assert!(chunks.len() >= 2);
        assert!(chunks[0].ends_with('.'));
        assert!(chunks[1].contains("Delta epsilon zeta.") || chunks[1].contains("Eta theta iota."));
    }

    #[test]
    fn overlap_is_approximate_and_word_safe() {
        let text = "This is the first sentence. This is the second sentence with more words. This is the third sentence with even more words.";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 55,
                overlap_size: 12,
            },
        )
        .unwrap();

        assert!(chunks.len() >= 2);
        assert!(
            chunks[0].split_whitespace().last().is_some()
                && chunks[1].split_whitespace().next().is_some()
        );
        // One of the last few meaningful words of the first chunk reappears in the second.
        let trailing_words = normalized_words(chunks[0].as_str());
        let leading_words = normalized_words(chunks[1].as_str());
        assert!(
            trailing_words
                .iter()
                .rev()
                .take(4)
                .any(|word| word.len() > 3 && leading_words.contains(word))
        );
    }

    #[test]
    fn falls_back_to_whitespace_for_punctuation_poor_text() {
        let text = "alpha beta gamma delta epsilon zeta eta theta iota kappa lambda mu";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 24,
                overlap_size: 6,
            },
        )
        .unwrap();

        assert!(chunks.len() >= 2);
        assert!(!chunks.iter().any(|chunk| chunk.contains("  ")));
        assert!(chunks.iter().all(|chunk| !chunk.starts_with(' ')));
        assert!(chunks.iter().all(|chunk| !chunk.ends_with(' ')));
    }

    #[test]
    fn progresses_even_without_whitespace() {
        let text = "supercalifragilisticexpialidociousandbeyond";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 10,
                overlap_size: 3,
            },
        )
        .unwrap();

        assert!(chunks.len() > 1);
        assert!(chunks.iter().all(|chunk| !chunk.is_empty()));
        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn handles_unicode_without_invalid_boundaries() {
        let text = "Здравей свят. Добре дошли в retrieval kit. Това е тест.";
        let chunks = chunk_text(
            text,
            ChunkingConfig {
                chunk_size: 24,
                overlap_size: 6,
            },
        )
        .unwrap();

        assert!(chunks.len() >= 2);
        assert!(chunks.iter().all(|chunk| !chunk.is_empty()));
        // Every chunk is a trimmed slice of the input, so no multi-byte char was split.
        assert!(chunks.iter().all(|chunk| text.contains(chunk.as_str())));
        assert!(chunks.iter().all(|chunk| chunk.trim() == chunk.as_str()));
    }

    #[test]
    fn returns_empty_for_whitespace_only_input() {
        let chunks = chunk_text("   \n\t  ", ChunkingConfig::default()).unwrap();

        assert!(chunks.is_empty());
    }

    /// Lower-cased words of `input` with surrounding punctuation removed.
    fn normalized_words(input: &str) -> Vec<String> {
        input
            .split_whitespace()
            .map(|word| {
                word.trim_matches(|ch: char| !ch.is_alphanumeric())
                    .to_lowercase()
            })
            .filter(|word| !word.is_empty())
            .collect()
    }
}