mod dfs_chunker;
mod linear_chunker;
mod overlap;

use crate::splitting::{Separator, SeparatorGroup, TextSplit, TextSplitter};

use alith_models::tokenizer::Tokenizer;
use anyhow::Result;
use dfs_chunker::DfsTextChunker;
use linear_chunker::LinearChunker;
use overlap::OverlapChunker;
use rayon::iter::{IntoParallelRefIterator, ParallelIterator};
use std::{
    collections::VecDeque,
    sync::{
        Arc,
        atomic::{AtomicBool, Ordering},
    },
};

pub const DEFAULT_CHUNK_SIZE: usize = 1024;

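/// Convenience wrapper around [`TextChunker`]: chunks `text` so that each
/// chunk stays within `max_chunk_token_size` tokens, optionally overlapping
/// neighboring chunks by `overlap_percent`. Returns `Ok(None)` when no valid
/// chunking exists for the input.
///
/// A minimal usage sketch (the input file is illustrative):
///
/// ```ignore
/// let text = std::fs::read_to_string("document.txt")?;
/// if let Some(chunks) = chunk_text(&text, 512, Some(0.1))? {
///     for chunk in &chunks {
///         println!("--- chunk: {} chars ---", chunk.len());
///     }
/// }
/// ```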
pub fn chunk_text(
    text: &str,
    max_chunk_token_size: u32,
    overlap_percent: Option<f32>,
) -> Result<Option<Vec<String>>> {
    let mut splitter = TextChunker::new()?.max_chunk_token_size(max_chunk_token_size);
    if let Some(overlap_percent) = overlap_percent {
        splitter = splitter.overlap_percent(overlap_percent);
    }
    Ok(splitter.run(text))
}

const ABSOLUTE_LENGTH_MAX_DEFAULT: u32 = 1024;
const ABSOLUTE_LENGTH_MIN_DEFAULT_RATIO: f32 = 0.75;
const TOKENIZER_TIKTOKEN_DEFAULT: &str = "gpt-4";

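/// Builder-style chunker that tries all candidate separators in parallel:
/// paragraphs, newlines, sentences, words, and finally graphemes. The first
/// separator that yields chunks within the token budget wins.
///
/// A sketch of the builder flow:
///
/// ```ignore
/// let chunks = TextChunker::new()?
///     .max_chunk_token_size(512)
///     .overlap_percent(0.1)
///     .run(&text)
///     .expect("chunking failed");
/// ```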
pub struct TextChunker {
    tokenizer: Arc<Tokenizer>,
    absolute_length_max: u32,
    absolute_length_min: Option<u32>,
    overlap_percent: Option<f32>,
    use_dfs_semantic_splitter: bool,
}

impl TextChunker {
    pub fn new() -> Result<Self> {
        Ok(Self {
            tokenizer: Arc::new(Tokenizer::new_tiktoken(TOKENIZER_TIKTOKEN_DEFAULT)?),
            absolute_length_max: ABSOLUTE_LENGTH_MAX_DEFAULT,
            absolute_length_min: None,
            overlap_percent: None,
            use_dfs_semantic_splitter: true,
        })
    }

    pub fn new_with_tokenizer(custom_tokenizer: &Arc<Tokenizer>) -> Self {
        Self {
            tokenizer: Arc::clone(custom_tokenizer),
            absolute_length_max: ABSOLUTE_LENGTH_MAX_DEFAULT,
            absolute_length_min: None,
            overlap_percent: None,
            use_dfs_semantic_splitter: true,
        }
    }

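    /// Sets the hard upper bound, in tokens, for any emitted chunk.
    /// Defaults to `ABSOLUTE_LENGTH_MAX_DEFAULT` (1024).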
    pub fn max_chunk_token_size(mut self, max_chunk_token_size: u32) -> Self {
        self.absolute_length_max = max_chunk_token_size;
        self
    }

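    /// Sets the lower bound, in tokens, for emitted chunks. When unset, it
    /// defaults to 75% of the maximum (`ABSOLUTE_LENGTH_MIN_DEFAULT_RATIO`).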
    pub fn min_chunk_token_size(mut self, min_chunk_token_size: u32) -> Self {
        self.absolute_length_min = Some(min_chunk_token_size);
        self
    }

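    /// Enables or disables the DFS splitter for semantic separators
    /// (paragraphs, sentences). When disabled, the linear chunker is used
    /// for every separator. Defaults to `true`.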
    pub fn use_dfs_semantic_splitter(mut self, use_dfs_semantic_splitter: bool) -> Self {
        self.use_dfs_semantic_splitter = use_dfs_semantic_splitter;
        self
    }

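    /// Sets the fraction of each chunk reserved for overlap with its
    /// neighbors. Values outside the supported range of `0.01..=0.5` fall
    /// back to `0.10`.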
    pub fn overlap_percent(mut self, overlap_percent: f32) -> Self {
        self.overlap_percent = if (0.01..=0.5).contains(&overlap_percent) {
            Some(overlap_percent)
        } else {
            Some(0.10)
        };
        self
    }

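    /// Runs the chunker and returns the chunks as plain strings, or `None`
    /// if the text could not be chunked within the configured bounds.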
    pub fn run(&self, incoming_text: &str) -> Option<Vec<String>> {
        Some(self.text_chunker(incoming_text)?.chunks_to_text())
    }

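    /// Like [`TextChunker::run`], but returns the full [`ChunkerResult`],
    /// which retains per-chunk token counts and timing information.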
    pub fn run_return_result(&self, incoming_text: &str) -> Option<ChunkerResult> {
        self.text_chunker(incoming_text)
    }

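    /// Core driver: tries every separator in parallel and returns the first
    /// successful result. Each worker checks the shared `chunks_found` flag
    /// so it can bail out early once a sibling has already succeeded.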
    fn text_chunker(&self, incoming_text: &str) -> Option<ChunkerResult> {
        let chunking_start_time = std::time::Instant::now();
        let chunks_found: Arc<AtomicBool> = Arc::new(AtomicBool::new(false));

        Separator::get_all().par_iter().find_map_any(|separator| {
            if chunks_found.load(Ordering::Relaxed) {
                return None;
            }
            let config = Arc::new(ChunkerConfig::new(
                &chunks_found,
                separator.clone(),
                incoming_text,
                self.absolute_length_max,
                self.absolute_length_min,
                self.overlap_percent,
                self.tokenizer(),
            )?);
            if chunks_found.load(Ordering::Relaxed) {
                return None;
            }
            // The whole text already fits in a single chunk; no splitting needed.
            if config.initial_separator == Separator::None {
                chunks_found.store(true, Ordering::Relaxed);
                return Some(ChunkerResult::new(
                    incoming_text,
                    &config,
                    chunking_start_time,
                    vec![Chunk::dummy_chunk(&config, incoming_text)],
                ));
            }

            // Prefer the DFS splitter for semantic separators; fall back to
            // the linear chunker if it fails.
            if config.initial_separator.group() == SeparatorGroup::Semantic
                && self.use_dfs_semantic_splitter
            {
                if let Some(chunks) = DfsTextChunker::run(&config) {
                    match OverlapChunker::run(&config, chunks) {
                        Ok(chunks) => {
                            chunks_found.store(true, Ordering::Relaxed);
                            return Some(ChunkerResult::new(
                                incoming_text,
                                &config,
                                chunking_start_time,
                                chunks,
                            ));
                        }
                        Err(e) => {
                            eprintln!("Error: {:#?}", e);
                        }
                    }
                }
            }
            let chunks = LinearChunker::run(&config)?;
            match OverlapChunker::run(&config, chunks) {
                Ok(chunks) => {
                    chunks_found.store(true, Ordering::Relaxed);
                    Some(ChunkerResult::new(
                        incoming_text,
                        &config,
                        chunking_start_time,
                        chunks,
                    ))
                }
                Err(e) => {
                    eprintln!("Error: {:#?}", e);
                    None
                }
            }
        })
    }

    #[inline]
    fn tokenizer(&self) -> Arc<Tokenizer> {
        Arc::clone(&self.tokenizer)
    }
}

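/// Per-separator chunking state shared by the DFS, linear, and overlap
/// chunkers: the cleaned source text, the initial splits with their token
/// counts, and the effective length bounds.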
pub struct ChunkerConfig {
    chunks_found: Arc<AtomicBool>,
    absolute_length_max: u32,
    absolute_length_min: u32,
    length_max: f32,
    overlap_percent: Option<f32>,
    tokenizer: Arc<Tokenizer>,
    base_text: Arc<str>,
    initial_separator: Separator,
    initial_splits: VecDeque<TextSplit>,
}

impl ChunkerConfig {
    fn new(
        chunks_found: &Arc<AtomicBool>,
        separator: Separator,
        incoming_text: &str,
        absolute_length_max: u32,
        absolute_length_min: Option<u32>,
        overlap_percent: Option<f32>,
        tokenizer: Arc<Tokenizer>,
    ) -> Option<Self> {
        // Reserve part of the token budget for overlap so chunks still fit
        // within absolute_length_max once overlap text is added back.
        let length_max = if let Some(overlap_percent) = overlap_percent {
            (absolute_length_max as f32 - (absolute_length_max as f32 * overlap_percent)).floor()
        } else {
            absolute_length_max as f32
        };
        let absolute_length_min = absolute_length_min
            .unwrap_or((absolute_length_max as f32 * ABSOLUTE_LENGTH_MIN_DEFAULT_RATIO) as u32);
        // Chunking is impossible if the effective (overlap-reduced) budget
        // cannot exceed the minimum chunk size.
        if length_max <= absolute_length_min as f32 {
            panic!(
                "\nThe combination of absolute_length_max: {:#?} and overlap_percent: {:#?} yields an effective chunk length less than or equal to absolute_length_min: {:#?}.",
                absolute_length_max, overlap_percent, absolute_length_min
            );
        }

        let mut config = Self {
            chunks_found: Arc::clone(chunks_found),
            absolute_length_max,
            absolute_length_min,
            length_max,
            overlap_percent,
            tokenizer,
            base_text: Arc::from(separator.clean_text(incoming_text)),
            initial_separator: separator.clone(),
            initial_splits: VecDeque::new(),
        };

        // If the cleaned text already fits in one chunk, signal that no
        // splitting is required.
        let cleaned_text_token_count = config.tokenizer.count_tokens(&config.base_text);
        if cleaned_text_token_count <= absolute_length_max {
            config.initial_separator = Separator::None;
            return Some(config);
        }
        let mut splits = TextSplitter::new()
            .recursive(false)
            .clean_text(false)
            .on_separator(&separator)
            .split_text(&config.base_text)?;
        splits.iter_mut().for_each(|split| {
            config.set_split_token_count(split);
        });
        let splits_token_count = config.estimate_splits_token_count(&splits);
        let chunk_count = (splits_token_count / config.length_max).ceil() as usize;
        if splits.len() < chunk_count {
            eprintln!(
                "\nChunking is impossible for separator: {:#?}. Splits count: {:#?} is less than the minimum chunk_count: {:#?}.",
                separator,
                splits.len(),
                chunk_count,
            );
            return None;
        }

        config.initial_splits = splits;
        Some(config)
    }

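    /// Re-splits an oversized split with the next-finer separator and
    /// computes token counts for each resulting piece. Returns `None` when
    /// the split cannot be divided any further.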
    fn split_split(&self, split: TextSplit) -> Option<VecDeque<TextSplit>> {
        let mut new_splits: VecDeque<TextSplit> = split.split()?;
        new_splits.iter_mut().for_each(|split| {
            self.set_split_token_count(split);
        });
        Some(new_splits)
    }

    fn set_split_token_count(&self, split: &mut TextSplit) {
        if split.token_count.is_none() {
            split.token_count = Some(self.tokenizer.count_tokens(split.text()));
        }
    }

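    /// Cheaply estimates the token count of the text produced by joining
    /// `splits`, without re-tokenizing. Grapheme splits are approximated at
    /// a fraction of a token, and the per-separator ratios below discount
    /// the whitespace consumed by splitting; the constants are empirical
    /// tuning values.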
    fn estimate_splits_token_count(&self, splits: &VecDeque<TextSplit>) -> f32 {
        let mut last_separator = Separator::None;
        let mut total_tokens = 0.0;
        for split in splits {
            let split_tokens = match split.split_separator {
                // Graphemes are typically sub-token; estimate fractionally.
                Separator::GraphemesUnicode => match last_separator {
                    Separator::None | Separator::GraphemesUnicode => 0.55,
                    _ => 1.0,
                },
                _ => split.token_count.unwrap() as f32,
            };
            if last_separator != Separator::None {
                let white_space_ratio = match split.split_separator {
                    Separator::None => {
                        unreachable!()
                    }
                    Separator::TwoPlusEoL => 0.999,
                    Separator::SingleEol => 0.999,
                    Separator::SentencesRuleBased => 0.998,
                    Separator::SentencesUnicode => 0.998,
                    Separator::WordsUnicode => 0.89,
                    Separator::GraphemesUnicode => 1.0,
                };
                total_tokens += split_tokens * white_space_ratio;
            } else {
                total_tokens += split_tokens;
            }
            last_separator = split.split_separator.clone();
        }
        total_tokens
    }
}

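/// A chunk under construction: an ordered run of splits plus lazily
/// computed text and token counts. The cached `text` and `token_count`
/// are invalidated whenever a split is added or removed.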
#[derive(Clone)]
pub struct Chunk {
    text: Option<String>,
    used_splits: VecDeque<TextSplit>,
    token_count: Option<usize>,
    estimated_token_count: f32,
    config: Arc<ChunkerConfig>,
}

impl Chunk {
    fn new(config: &Arc<ChunkerConfig>) -> Self {
        Chunk {
            text: None,
            used_splits: VecDeque::new(),
            token_count: Some(0),
            estimated_token_count: 0.0,
            config: Arc::clone(config),
        }
    }

    /// Wraps text that already fits within the budget in a single chunk,
    /// bypassing the split machinery entirely.
    fn dummy_chunk(config: &Arc<ChunkerConfig>, text: &str) -> Self {
        Chunk {
            text: Some(text.to_string()),
            used_splits: VecDeque::new(),
            token_count: Some(0),
            estimated_token_count: 0.0,
            config: Arc::clone(config),
        }
    }

    fn add_split(&mut self, split: TextSplit, backwards: bool) {
        if backwards {
            self.used_splits.push_front(split);
        } else {
            self.used_splits.push_back(split);
        }
        // Any cached text or exact token count is now stale.
        self.estimated_token_count = self.config.estimate_splits_token_count(&self.used_splits);
        self.token_count = None;
        self.text = None;
    }

    fn remove_split(&mut self, backwards: bool) -> TextSplit {
        let split = if backwards {
            self.used_splits.pop_front().unwrap()
        } else {
            self.used_splits.pop_back().unwrap()
        };
        self.estimated_token_count = self.config.estimate_splits_token_count(&self.used_splits);
        self.token_count = None;
        self.text = None;
        split
    }

    /// Returns the chunk's token count. With `estimated == true`, the cheap
    /// estimate is returned when no exact count is cached; otherwise the
    /// chunk text is tokenized and the exact count is cached.
    fn token_count(&mut self, estimated: bool) -> f32 {
        if let Some(token_count) = self.token_count {
            token_count as f32
        } else if estimated {
            self.estimated_token_count
        } else {
            let text = self.text();
            let token_count = self.config.tokenizer.count_tokens(&text) as usize;
            self.token_count = Some(token_count);
            self.estimated_token_count = token_count as f32;
            token_count as f32
        }
    }

    /// Joins the chunk's splits into text, caching the result.
    fn text(&mut self) -> String {
        if let Some(text) = &self.text {
            text.to_owned()
        } else {
            let text = TextSplitter::splits_to_text(&self.used_splits, false);
            self.text = Some(text.clone());
            text
        }
    }
}

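/// The outcome of a successful chunking run: the final chunks along with
/// the separator that produced them and how long chunking took.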
pub struct ChunkerResult {
    incoming_text: Arc<str>,
    initial_separator: Separator,
    chunks: Vec<Chunk>,
    tokenizer: Arc<Tokenizer>,
    chunking_duration: std::time::Duration,
}

impl ChunkerResult {
    fn new(
        incoming_text: &str,
        config: &Arc<ChunkerConfig>,
        chunking_start_time: std::time::Instant,
        mut chunks: Vec<Chunk>,
    ) -> ChunkerResult {
        // Force each chunk to render and cache its text up front.
        chunks.iter_mut().for_each(|chunk| {
            chunk.text();
        });
        ChunkerResult {
            incoming_text: Arc::from(incoming_text),
            initial_separator: config.initial_separator.clone(),
            chunks,
            tokenizer: Arc::clone(&config.tokenizer),
            chunking_duration: chunking_start_time.elapsed(),
        }
    }

    pub fn chunks_to_text(&mut self) -> Vec<String> {
        self.chunks.iter_mut().map(|chunk| chunk.text()).collect()
    }

    /// Returns the exact token count of each chunk. Does not mutate the
    /// chunks, so it only needs a shared reference.
    pub fn token_counts(&self) -> Vec<u32> {
        let mut token_counts: Vec<u32> = Vec::with_capacity(self.chunks.len());
        for chunk in &self.chunks {
            let chunk_text = if let Some(text) = &chunk.text {
                text.to_owned()
            } else {
                TextSplitter::splits_to_text(&chunk.used_splits, false)
            };
            token_counts.push(self.tokenizer.count_tokens(&chunk_text));
        }
        token_counts
    }
}

impl std::fmt::Debug for ChunkerResult {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut chunk_token_sizes = Vec::with_capacity(self.chunks.len());
        let mut largest_token_size = 0;
        let mut smallest_token_size = u32::MAX;
        let mut all_chunks_token_count = 0;
        let mut chunk_char_sizes = Vec::with_capacity(self.chunks.len());
        let mut largest_char_size = 0;
        let mut smallest_char_size = u32::MAX;
        let mut all_chunks_char_count = 0;

        for chunk in &self.chunks {
            let chunk_text = if let Some(text) = &chunk.text {
                text.to_owned()
            } else {
                panic!("Chunk text not found.")
            };
            let token_count = self.tokenizer.count_tokens(&chunk_text);
            let char_count = u32::try_from(chunk_text.chars().count()).unwrap();
            chunk_token_sizes.push(token_count);
            chunk_char_sizes.push(char_count);
            all_chunks_token_count += token_count;
            all_chunks_char_count += char_count;
            largest_token_size = largest_token_size.max(token_count);
            largest_char_size = largest_char_size.max(char_count);
            smallest_token_size = smallest_token_size.min(token_count);
            smallest_char_size = smallest_char_size.min(char_count);
        }
        f.debug_struct("\nChunkerResult")
            .field("chunk_count", &self.chunks.len())
            .field("chunk_token_sizes", &chunk_token_sizes)
            .field(
                "avg_token_size",
                &(all_chunks_token_count / u32::try_from(self.chunks.len()).unwrap()),
            )
            .field("largest_token_size", &largest_token_size)
            .field("smallest_token_size", &smallest_token_size)
            .field(
                "incoming_text_token_count",
                &self.tokenizer.count_tokens(&self.incoming_text),
            )
            .field("all_chunks_token_count", &all_chunks_token_count)
            .field("chunk_char_sizes", &chunk_char_sizes)
            .field(
                "avg_char_size",
                &(all_chunks_char_count / u32::try_from(self.chunks.len()).unwrap()),
            )
            .field("largest_char_size", &largest_char_size)
            .field("smallest_char_size", &smallest_char_size)
            .field(
                "incoming_text_char_count",
                &self.incoming_text.chars().count(),
            )
            .field("all_chunks_char_count", &all_chunks_char_count)
            .field("chunking_duration", &self.chunking_duration)
            .field("initial_separator", &self.initial_separator)
            .finish()
    }
}

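/// Minimal interface for pluggable chunkers, with default chunk size and
/// overlap settings.
///
/// A sketch of an implementor (the `FixedChunker` type is hypothetical):
///
/// ```ignore
/// struct FixedChunker {
///     text: String,
/// }
///
/// impl Chunker for FixedChunker {
///     fn chunk(&self) -> Result<Vec<String>, ChunkError> {
///         chunk_text(&self.text, self.chunk_size() as u32, self.overlap_percent())
///             .map_err(|e| ChunkError::Normal(e.to_string()))?
///             .ok_or_else(|| ChunkError::Normal("no chunks produced".to_string()))
///     }
/// }
/// ```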
pub trait Chunker: Send + Sync {
    fn chunk_size(&self) -> usize {
        DEFAULT_CHUNK_SIZE
    }

    fn overlap_percent(&self) -> Option<f32> {
        None
    }

    fn chunk(&self) -> Result<Vec<String>, ChunkError>;
}

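/// Errors surfaced by [`Chunker::chunk`] implementations.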
#[derive(Debug, thiserror::Error)]
pub enum ChunkError {
    #[error("A normal chunk error occurred: {0}")]
    Normal(String),
}