alith_core/chunking.rs

mod dfs_chunker;
mod linear_chunker;
mod overlap;

use crate::splitting::{Separator, SeparatorGroup, TextSplit, TextSplitter};

use alith_models::tokenizer::Tokenizer;
use anyhow::Result;
use dfs_chunker::DfsTextChunker;
use linear_chunker::LinearChunker;
use overlap::OverlapChunker;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{
    collections::VecDeque,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
};

pub const DEFAULT_CHUNK_SIZE: usize = 1024;

/// A convenience alternative to constructing a [`TextChunker`] directly.
///
/// * `text` - The natural language text to chunk.
/// * `max_chunk_token_size` - The maximum token size per chunk. Inclusive.
/// * `overlap_percent` - The percentage of overlap between chunks. Default is None.
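///
/// A minimal usage sketch (the import path is assumed, so the doctest is marked `ignore`):
/// ```ignore
/// use alith_core::chunking::chunk_text;
///
/// let chunks = chunk_text("Some long document text...", 512, Some(0.10))?
///     .expect("the text could not be chunked");
/// for chunk in &chunks {
///     println!("{chunk}");
/// }
/// ```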
pub fn chunk_text(
    text: &str,
    max_chunk_token_size: u32,
    overlap_percent: Option<f32>,
) -> Result<Option<Vec<String>>> {
    let mut splitter = TextChunker::new()?.max_chunk_token_size(max_chunk_token_size);
    if let Some(overlap_percent) = overlap_percent {
        splitter = splitter.overlap_percent(overlap_percent);
    }
    Ok(splitter.run(text))
}

const ABSOLUTE_LENGTH_MAX_DEFAULT: u32 = 1024;
const ABSOLUTE_LENGTH_MIN_DEFAULT_RATIO: f32 = 0.75;
const TOKENIZER_TIKTOKEN_DEFAULT: &str = "gpt-4";

/// Splits text by paragraphs, newlines, sentences, spaces, and finally graphemes, and builds chunks from the splits that are within the desired token ranges.
pub struct TextChunker {
    /// An atomic reference to the tokenizer. Defaults to the TikToken tokenizer.
    tokenizer: Arc<Tokenizer>,
    /// Inclusive hard limit.
    absolute_length_max: u32,
    /// This is used solely for the [`DfsTextChunker`] to determine the minimum chunk size. Default is 75% of the `absolute_length_max`.
    absolute_length_min: Option<u32>,
    /// The percentage of overlap between chunks. Default is None.
    overlap_percent: Option<f32>,
    /// Whether to use the DFS semantic splitter to attempt to build valid chunks. Default is true.
    use_dfs_semantic_splitter: bool,
}

impl TextChunker {
    /// Creates a new instance of the [`TextChunker`] struct using the default TikToken tokenizer.
    pub fn new() -> Result<Self> {
        Ok(Self {
            tokenizer: Arc::new(Tokenizer::new_tiktoken(TOKENIZER_TIKTOKEN_DEFAULT)?),
            absolute_length_max: ABSOLUTE_LENGTH_MAX_DEFAULT,
            absolute_length_min: None,
            overlap_percent: None,
            use_dfs_semantic_splitter: true,
        })
    }
    /// Creates a new instance of the [`TextChunker`] struct using a custom tokenizer, such as a Hugging Face tokenizer.
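    ///
    /// A sketch of supplying a shared tokenizer (construction details assumed; marked `ignore`):
    /// ```ignore
    /// let tokenizer = Arc::new(Tokenizer::new_tiktoken("gpt-4")?);
    /// let chunker = TextChunker::new_with_tokenizer(&tokenizer);
    /// ```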
    pub fn new_with_tokenizer(custom_tokenizer: &Arc<Tokenizer>) -> Self {
        Self {
            tokenizer: Arc::clone(custom_tokenizer),
            absolute_length_max: ABSOLUTE_LENGTH_MAX_DEFAULT,
            absolute_length_min: None,
            overlap_percent: None,
            use_dfs_semantic_splitter: true,
        }
    }

    /// Sets the maximum token size for the chunks. Default is 1024.
    ///
    /// * `max_chunk_token_size` - The maximum token size per chunk. Inclusive.
    pub fn max_chunk_token_size(mut self, max_chunk_token_size: u32) -> Self {
        self.absolute_length_max = max_chunk_token_size;
        self
    }

    /// Sets the minimum token size for the chunks. Default is 75% of the `absolute_length_max`. Used solely for the [`DfsTextChunker`] to determine the minimum chunk size.
    ///
    /// * `min_chunk_token_size` - The minimum token size per chunk.
    pub fn min_chunk_token_size(mut self, min_chunk_token_size: u32) -> Self {
        self.absolute_length_min = Some(min_chunk_token_size);
        self
    }

    /// The [`DfsTextChunker`] is faster and fully respects semantic separators. However, it produces less balanced chunk sizes and will fail if the text cannot be split.
    /// By default the [`TextChunker`] attempts to chunk with the [`DfsTextChunker`] first, and if that fails, it falls back to the [`LinearChunker`].
    ///
    /// * `use_dfs_semantic_splitter` - Whether to use the DFS semantic splitter to attempt to build valid chunks. Default is true.
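    ///
    /// For example, to skip the DFS attempt and rely on the [`LinearChunker`] alone (sketch; marked `ignore`):
    /// ```ignore
    /// let chunker = TextChunker::new()?.use_dfs_semantic_splitter(false);
    /// ```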
    pub fn use_dfs_semantic_splitter(mut self, use_dfs_semantic_splitter: bool) -> Self {
        self.use_dfs_semantic_splitter = use_dfs_semantic_splitter;
        self
    }

    /// Sets the percentage of overlap between chunks. Default is None.
    /// The full percentage is applied forward for the first chunk, and backwards for the last chunk.
    /// Middle chunks split the percentage evenly between forward and backwards.
    ///
    /// * `overlap_percent` - The percentage of overlap between chunks. Minimum is 0.01, and maximum is 0.5. Default is None.
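    ///
    /// Note that out-of-range values fall back to 0.10 rather than erroring (sketch; marked `ignore`):
    /// ```ignore
    /// // 0.75 is outside 0.01..=0.5, so an overlap of 0.10 is used instead.
    /// let chunker = TextChunker::new()?.overlap_percent(0.75);
    /// ```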
    pub fn overlap_percent(mut self, overlap_percent: f32) -> Self {
        self.overlap_percent = if !(0.01..=0.5).contains(&overlap_percent) {
            Some(0.10)
        } else {
            Some(overlap_percent)
        };
        self
    }

    /// Runs the [`TextChunker`] on the incoming text and returns the chunks as a vector of strings.
    ///
    /// * `incoming_text` - The natural language text to chunk.
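    ///
    /// A typical builder-and-run sketch (the public path is assumed; marked `ignore`):
    /// ```ignore
    /// let chunks: Option<Vec<String>> = TextChunker::new()?
    ///     .max_chunk_token_size(512)
    ///     .overlap_percent(0.10)
    ///     .run("Some long document text...");
    /// ```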
    pub fn run(&self, incoming_text: &str) -> Option<Vec<String>> {
        Some(self.text_chunker(incoming_text)?.chunks_to_text())
    }

    /// Runs the [`TextChunker`] on the incoming text and returns the chunks as a [`ChunkerResult`].
    /// The [`ChunkerResult`] contains the incoming text, the initial separator used, the chunks, the tokenizer, and the chunking duration. Useful for testing, benching, and diagnostics.
    ///
    /// * `incoming_text` - The natural language text to chunk.
    pub fn run_return_result(&self, incoming_text: &str) -> Option<ChunkerResult> {
        self.text_chunker(incoming_text)
    }

    /// Backend runner for [`TextChunker`].
    /// Attempts to chunk the incoming text on every [`Separator`], first using the [`DfsTextChunker`] and then the [`LinearChunker`].
    /// Returns whichever [`Separator`] chunking attempt succeeds first; if none succeed, returns None.
    /// If the incoming text's token count is within `absolute_length_max`, it returns a single chunk.
    fn text_chunker(&self, incoming_text: &str) -> Option<ChunkerResult> {
        let chunking_start_time = std::time::Instant::now();
        // A flag signaling that chunks have been found, so all other threads can stop searching.
        let chunks_found: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

        // Parallelize the search for the first successful chunking attempt.
        Separator::get_all().par_iter().find_map_any(|separator| {
            if chunks_found.load(Ordering::Relaxed) {
                return None;
            }
            let config = Arc::new(ChunkerConfig::new(
                &chunks_found,
                separator.clone(),
                incoming_text,
                self.absolute_length_max,
                self.absolute_length_min,
                self.overlap_percent,
                self.tokenizer(),
            )?);
            if chunks_found.load(Ordering::Relaxed) {
                return None;
            }
            // If the text's token count is within `absolute_length_max`, `initial_separator` will have been set to Separator::None, and we return a single chunk.
            if config.initial_separator == Separator::None {
                chunks_found.store(true, Ordering::Relaxed);
                return Some(ChunkerResult::new(
                    incoming_text,
                    &config,
                    chunking_start_time,
                    vec![Chunk::dummy_chunk(&config, incoming_text)],
                ));
            };

            if config.initial_separator.group() == SeparatorGroup::Semantic
                && self.use_dfs_semantic_splitter
            {
                let chunks: Option<Vec<Chunk>> = DfsTextChunker::run(&config);
                if let Some(chunks) = chunks {
                    let chunks = OverlapChunker::run(&config, chunks);
                    match chunks {
                        Ok(chunks) => {
                            chunks_found.store(true, Ordering::Relaxed);
                            return Some(ChunkerResult::new(
                                incoming_text,
                                &config,
                                chunking_start_time,
                                chunks,
                            ));
                        }
                        Err(e) => {
                            eprintln!("Error: {:#?}", e);
                        }
                    }
                }
            }
            let chunks = LinearChunker::run(&config)?;
            let chunks = OverlapChunker::run(&config, chunks);
            match chunks {
                Ok(chunks) => {
                    chunks_found.store(true, Ordering::Relaxed);
                    Some(ChunkerResult::new(
                        incoming_text,
                        &config,
                        chunking_start_time,
                        chunks,
                    ))
                }
                Err(e) => {
                    eprintln!("Error: {:#?}", e);
                    None
                }
            }
        })
    }

    #[inline]
    fn tokenizer(&self) -> Arc<Tokenizer> {
        Arc::clone(&self.tokenizer)
    }
}

/// Configuration used by the [`TextChunker`], [`DfsTextChunker`], [`LinearChunker`], and [`OverlapChunker`] to build chunks.
/// Instantiated by the [`TextChunker`] on each [`Separator`] and passed to the chunkers.
pub struct ChunkerConfig {
    chunks_found: Arc<AtomicBool>,
    absolute_length_max: u32,
    absolute_length_min: u32,
    length_max: f32,
    overlap_percent: Option<f32>,
    tokenizer: Arc<Tokenizer>,
    base_text: Arc<str>,
    initial_separator: Separator,
    initial_splits: VecDeque<TextSplit>,
}

impl ChunkerConfig {
    fn new(
        chunks_found: &Arc<AtomicBool>,
        separator: Separator,
        incoming_text: &str,
        absolute_length_max: u32,
        absolute_length_min: Option<u32>,
        overlap_percent: Option<f32>,
        tokenizer: Arc<Tokenizer>,
    ) -> Option<Self> {
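        // `length_max` reserves room for overlap. For example, with the defaults
        // absolute_length_max = 1024 and overlap_percent = Some(0.10):
        // length_max = floor(1024.0 - 1024.0 * 0.10) = 921.0.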
        let length_max = if let Some(overlap_percent) = overlap_percent {
            (absolute_length_max as f32 - (absolute_length_max as f32 * overlap_percent)).floor()
        } else {
            absolute_length_max as f32
        };
        let absolute_length_min = if let Some(absolute_length_min) = absolute_length_min {
            absolute_length_min
        } else {
            (absolute_length_max as f32 * ABSOLUTE_LENGTH_MIN_DEFAULT_RATIO) as u32
        };
        if absolute_length_max <= absolute_length_min {
            panic!(
                "\nThe combination of absolute_length_max: {:#?} and overlap_percent: {:#?} is less than or equal to absolute_length_min: {:#?}.",
                absolute_length_max, overlap_percent, absolute_length_min
            );
        }

        let mut config = Self {
            chunks_found: Arc::clone(chunks_found),
            absolute_length_max,
            absolute_length_min,
            length_max,
            overlap_percent,
            tokenizer,
            base_text: Arc::from(separator.clean_text(incoming_text)),
            initial_separator: separator.clone(),
            initial_splits: VecDeque::new(),
        };

        let cleaned_text_token_count = config.tokenizer.count_tokens(&config.base_text);
        if cleaned_text_token_count <= absolute_length_max {
            config.initial_separator = Separator::None;
            return Some(config);
        }
        let splits = if let Some(mut splits) = TextSplitter::new()
            .recursive(false)
            .clean_text(false)
            .on_separator(&separator)
            .split_text(&config.base_text)
        {
            splits.iter_mut().for_each(|split| {
                config.set_split_token_count(split);
            });
            splits
        } else {
            return None;
        };
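        // Feasibility check: for example, if the splits are estimated at 2000
        // tokens and length_max is 921.0, at least ceil(2000 / 921) = 3 chunks
        // are required, so fewer than 3 splits cannot be assembled into valid
        // chunks with this separator.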
        let splits_token_count = config.estimate_splits_token_count(&splits);
        let chunk_count = (splits_token_count / config.length_max).ceil() as usize;
        if splits.len() < chunk_count {
            eprintln!(
                "\nChunking is impossible for separator: {:#?}. Splits count: {:#?} is less than the minimum chunk_count: {:#?}.",
                separator,
                splits.len(),
                chunk_count,
            );
            return None;
        };

        config.initial_splits = splits;
        Some(config)
    }

    /// Splits an existing [`TextSplit`] into multiple [`TextSplit`]s on the next [`Separator`].
    /// If no splits are found, it attempts to split on the following [`Separator`].
    /// If it reaches the final [`Separator`] without successfully splitting, it returns None.
    fn split_split(&self, split: TextSplit) -> Option<VecDeque<TextSplit>> {
        let mut new_splits: VecDeque<TextSplit> = split.split()?;
        new_splits.iter_mut().for_each(|split| {
            self.set_split_token_count(split);
        });
        Some(new_splits)
    }

    fn set_split_token_count(&self, split: &mut TextSplit) {
        if split.token_count.is_none() {
            let token_count = self.tokenizer.count_tokens(split.text());
            split.token_count = Some(token_count);
        }
    }

    /// Estimates the token count of the splits.
    /// This is used for estimating the remaining token count, and also to estimate the token count of chunks.
    /// The estimate is approximate rather than exact, because joining splits on whitespace changes token boundaries.
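    /// For example, two adjacent word splits of 3 tokens each are estimated as
    /// 3.0 + 3.0 * 0.89 = 5.67 tokens, since rejoining on whitespace typically
    /// merges token boundaries.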
    fn estimate_splits_token_count(&self, splits: &VecDeque<TextSplit>) -> f32 {
        let mut last_separator = Separator::None;
        let mut total_tokens = 0.0;
        for split in splits {
            let split_tokens = match split.split_separator {
                Separator::GraphemesUnicode => match last_separator {
                    Separator::None | Separator::GraphemesUnicode => 0.55,
                    _ => 1.0,
                },
                _ => split.token_count.unwrap() as f32,
            };
            if last_separator != Separator::None {
                let white_space_ratio = match split.split_separator {
                    Separator::None => {
                        unreachable!()
                    }
                    Separator::TwoPlusEoL => 0.999,
                    Separator::SingleEol => 0.999,
                    Separator::SentencesRuleBased => 0.998,
                    Separator::SentencesUnicode => 0.998,
                    Separator::WordsUnicode => 0.89,
                    Separator::GraphemesUnicode => 1.0,
                };
                total_tokens += split_tokens * white_space_ratio;
            } else {
                total_tokens += split_tokens;
            }
            last_separator = split.split_separator.clone();
        }
        total_tokens
    }
}

#[derive(Clone)]
pub struct Chunk {
    text: Option<String>,
    used_splits: VecDeque<TextSplit>,
    token_count: Option<usize>,
    estimated_token_count: f32,
    config: Arc<ChunkerConfig>,
}

impl Chunk {
    fn new(config: &Arc<ChunkerConfig>) -> Self {
        Chunk {
            text: None,
            used_splits: VecDeque::new(),
            token_count: Some(0),
            estimated_token_count: 0.0,
            config: Arc::clone(config),
        }
    }

    fn dummy_chunk(config: &Arc<ChunkerConfig>, text: &str) -> Self {
        Chunk {
            text: Some(text.to_string()),
            used_splits: VecDeque::new(),
            token_count: Some(0),
            estimated_token_count: 0.0,
            config: Arc::clone(config),
        }
    }

    fn add_split(&mut self, split: TextSplit, backwards: bool) {
        if backwards {
            self.used_splits.push_front(split);
        } else {
            self.used_splits.push_back(split);
        }
        self.estimated_token_count = self.config.estimate_splits_token_count(&self.used_splits);
        self.token_count = None;
        self.text = None;
    }

    fn remove_split(&mut self, backwards: bool) -> TextSplit {
        let split = if backwards {
            self.used_splits.pop_front().unwrap()
        } else {
            self.used_splits.pop_back().unwrap()
        };
        self.estimated_token_count = self.config.estimate_splits_token_count(&self.used_splits);
        self.token_count = None;
        self.text = None;
        split
    }

    fn token_count(&mut self, estimated: bool) -> f32 {
        if let Some(token_count) = self.token_count {
            token_count as f32
        } else if estimated {
            self.estimated_token_count
        } else {
            let text = self.text();
            let token_count = self.config.tokenizer.count_tokens(&text) as usize;
            self.token_count = Some(token_count);
            self.estimated_token_count = token_count as f32;
            token_count as f32
        }
    }

    fn text(&mut self) -> String {
        if let Some(text) = &self.text {
            text.to_owned()
        } else {
            let text = TextSplitter::splits_to_text(&self.used_splits, false);
            self.text = Some(text.clone());
            text
        }
    }
}

pub struct ChunkerResult {
    incoming_text: Arc<str>,
    initial_separator: Separator,
    chunks: Vec<Chunk>,
    tokenizer: Arc<Tokenizer>,
    chunking_duration: std::time::Duration,
}

impl ChunkerResult {
    fn new(
        incoming_text: &str,
        config: &Arc<ChunkerConfig>,
        chunking_start_time: std::time::Instant,
        mut chunks: Vec<Chunk>,
    ) -> ChunkerResult {
        chunks.iter_mut().for_each(|chunk| {
            chunk.text();
        });
        ChunkerResult {
            incoming_text: Arc::from(incoming_text),
            initial_separator: config.initial_separator.clone(),
            chunks,
            tokenizer: Arc::clone(&config.tokenizer),
            chunking_duration: chunking_start_time.elapsed(),
        }
    }

    pub fn chunks_to_text(&mut self) -> Vec<String> {
        self.chunks.iter_mut().map(|chunk| chunk.text()).collect()
    }

    pub fn token_counts(&mut self) -> Vec<u32> {
        let mut token_counts: Vec<u32> = Vec::with_capacity(self.chunks.len());
        for chunk in &self.chunks {
            let chunk_text = if let Some(text) = &chunk.text {
                text.to_owned()
            } else {
                TextSplitter::splits_to_text(&chunk.used_splits, false)
            };
            token_counts.push(self.tokenizer.count_tokens(&chunk_text));
        }
        token_counts
    }
}

impl std::fmt::Debug for ChunkerResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut chunk_token_sizes = Vec::with_capacity(self.chunks.len());
        let mut largest_token_size = 0;
        let mut smallest_token_size = u32::MAX;
        let mut all_chunks_token_count = 0;
        let mut chunk_char_sizes = Vec::with_capacity(self.chunks.len());
        let mut largest_char_size = 0;
        let mut smallest_char_size = u32::MAX;
        let mut all_chunks_char_count = 0;

        for chunk in &self.chunks {
            let chunk_text = if let Some(text) = &chunk.text {
                text.to_owned()
            } else {
                panic!("Chunk text not found.")
            };
            let token_count = self.tokenizer.count_tokens(&chunk_text);
            let char_count = u32::try_from(chunk_text.chars().count()).unwrap();
            chunk_token_sizes.push(token_count);
            chunk_char_sizes.push(char_count);
            all_chunks_token_count += token_count;
            all_chunks_char_count += char_count;
            if token_count > largest_token_size {
                largest_token_size = token_count;
            }
            if char_count > largest_char_size {
                largest_char_size = char_count;
            }
            if token_count < smallest_token_size {
                smallest_token_size = token_count;
            }
            if char_count < smallest_char_size {
                smallest_char_size = char_count;
            }
        }
        f.debug_struct("\nChunkerResult")
            .field("chunk_count", &self.chunks.len())
            .field("chunk_token_sizes", &chunk_token_sizes)
            .field(
                "avg_token_size",
                &(all_chunks_token_count / u32::try_from(self.chunks.len()).unwrap()),
            )
            .field("largest_token_size", &largest_token_size)
            .field("smallest_token_size", &smallest_token_size)
            .field(
                "incoming_text_token_count",
                &self.tokenizer.count_tokens(&self.incoming_text),
            )
            .field("all_chunks_token_count", &all_chunks_token_count)
            .field("chunk_char_sizes", &chunk_char_sizes)
            .field(
                "avg_char_size",
                &(all_chunks_char_count / u32::try_from(self.chunks.len()).unwrap()),
            )
            .field("largest_char_size", &largest_char_size)
            .field("smallest_char_size", &smallest_char_size)
            .field(
                "incoming_text_char_count",
                &self.incoming_text.chars().count(),
            )
            .field("all_chunks_char_count", &all_chunks_char_count)
            .field("chunking_duration", &self.chunking_duration)
            .field("initial_separator", &self.initial_separator)
            .finish()
    }
}

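/// A generic interface for types that produce text chunks.
///
/// A minimal implementor sketch (the `CharChunker` type is hypothetical; marked `ignore`):
/// ```ignore
/// struct CharChunker {
///     text: String,
/// }
///
/// impl Chunker for CharChunker {
///     fn chunk(&self) -> Result<Vec<String>, ChunkError> {
///         // Naive fixed-size chunking by characters, for illustration only.
///         Ok(self
///             .text
///             .chars()
///             .collect::<Vec<char>>()
///             .chunks(self.chunk_size())
///             .map(|window| window.iter().collect())
///             .collect())
///     }
/// }
/// ```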
pub trait Chunker: Send + Sync {
    fn chunk_size(&self) -> usize {
        DEFAULT_CHUNK_SIZE
    }

    fn overlap_percent(&self) -> Option<f32> {
        None
    }

    fn chunk(&self) -> Result<Vec<String>, ChunkError>;
}

/// An enumeration of possible errors that may occur during chunk operations.
#[derive(Debug, thiserror::Error)]
pub enum ChunkError {
    /// A generic chunk error.
    #[error("A normal chunk error occurred: {0}")]
    Normal(String),
}