Skip to main content

llm_text/
splitting.rs

1use std::{
2    collections::VecDeque,
3    ops::Range,
4    sync::{Arc, LazyLock},
5};
6
7use regex::Regex;
8
9/// Split text into sentences using improved Unicode-aware sentence boundary detection.
10pub fn split_text_into_sentences(text: &str, keep_separator: bool) -> Vec<String> {
11    let sentences: Vec<&str> =
12        unicode_segmentation::UnicodeSegmentation::split_sentence_bounds(text)
13            .filter(|s| !s.trim().is_empty())
14            .collect();
15
16    if keep_separator {
17        sentences.into_iter().map(|s| s.to_string()).collect()
18    } else {
19        sentences.into_iter().map(|s| s.trim().to_string()).collect()
20    }
21}
22
23/// Split text into sentence indices using Unicode-aware sentence boundary detection.
24pub fn split_text_into_indices(text: &str, keep_separator: bool) -> Vec<Range<usize>> {
25    let mut indices = Vec::new();
26    let mut last_end = 0;
27
28    for sentence in unicode_segmentation::UnicodeSegmentation::split_sentence_bounds(text) {
29        if sentence.trim().is_empty() {
30            last_end += sentence.len();
31            continue;
32        }
33
34        let start = if keep_separator {
35            last_end
36        } else {
37            // Find the start of non-whitespace content
38            text[last_end..]
39                .char_indices()
40                .find(|(_, c)| !c.is_whitespace())
41                .map(|(i, _)| last_end + i)
42                .unwrap_or(last_end)
43        };
44
45        let end = if keep_separator {
46            last_end + sentence.len()
47        } else {
48            // Compute end of non-whitespace content using trim_end
49            let sentence_end = last_end + sentence.len();
50            let trimmed_len = text[last_end..sentence_end].trim_end().len();
51            if trimmed_len == 0 { start } else { last_end + trimmed_len }
52        };
53
54        if start < end {
55            indices.push(Range { start, end });
56        }
57
58        last_end += sentence.len();
59    }
60
61    indices
62}
63
/// Abstraction over anything that can report how many tokens a piece of text
/// occupies.
///
/// Implement it to plug in a concrete tokenizer (e.g., tiktoken or a
/// HuggingFace tokenizer).
pub trait Tokenizer {
    /// Returns the number of tokens `text` is made of.
    fn count_tokens(&self, text: &str) -> u32;
}
70
/// A simple character-based tokenizer that approximates token count.
/// Useful as a fallback when no real tokenizer is available.
///
/// `Default` is implemented by hand so that a default-constructed tokenizer
/// uses the documented ratio of 4. A derived `Default` would yield `0.0`,
/// which turns every token estimate into a division by zero.
#[derive(Debug, Clone)]
pub struct CharRatioTokenizer {
    /// Approximate characters per token (default: 4, similar to GPT models)
    pub chars_per_token: f32,
}

impl Default for CharRatioTokenizer {
    /// Returns a tokenizer with the documented default of 4 characters per token.
    fn default() -> Self {
        Self { chars_per_token: 4.0 }
    }
}
78
79impl CharRatioTokenizer {
80    pub fn new() -> Self {
81        Self { chars_per_token: 4.0 }
82    }
83
84    pub fn with_ratio(mut self, ratio: f32) -> Self {
85        self.chars_per_token = ratio;
86        self
87    }
88}
89
90impl Tokenizer for CharRatioTokenizer {
91    fn count_tokens(&self, text: &str) -> u32 {
92        (text.len() as f32 / self.chars_per_token).ceil() as u32
93    }
94}
95
96// --- From mod.rs ---
97
98#[derive(Default)]
99pub struct TextSplitter {
100    pub split_separator: Separator,
101    pub recursive: bool,
102    pub clean_text: bool,
103    /// Optional maximum token size for chunks. When set, recursive splitting
104    /// will only drill down to finer separators when a chunk exceeds this limit.
105    pub max_token_size: Option<u32>,
106    /// Optional tokenizer for accurate token counting. If not set, token
107    /// counting is skipped and max_token_size constraint is ignored.
108    pub tokenizer: Option<std::sync::Arc<dyn Tokenizer + Send + Sync>>,
109}
110
111impl TextSplitter {
112    pub fn new() -> Self {
113        Self {
114            split_separator: Separator::TwoPlusEoL,
115            recursive: true,
116            clean_text: true,
117            max_token_size: None,
118            tokenizer: None,
119        }
120    }
121
122    pub fn split_text(&self, text: &str) -> Option<VecDeque<TextSplit>> {
123        let base_text: Arc<str> = if self.clean_text {
124            Arc::from(self.split_separator.clean_text(text.as_ref()))
125        } else {
126            Arc::from(text)
127        };
128
129        let mut split_separator = self.split_separator.clone();
130        let split_indices = if self.recursive {
131            // When max_token_size is set with a tokenizer, only drill down to
132            // finer separators when chunks exceed the token limit
133            if let (Some(max_tokens), Some(tokenizer)) = (self.max_token_size, &self.tokenizer) {
134                self.split_with_token_limit(
135                    &base_text,
136                    &mut split_separator,
137                    max_tokens,
138                    tokenizer,
139                )?
140            } else {
141                // Original behavior: recursive split without token constraints
142                loop {
143                    let split_indices = split_separator.split_text_into_indices(&base_text);
144                    if split_indices.len() > 1 {
145                        break split_indices;
146                    } else {
147                        split_separator = split_separator.next()?;
148                    }
149                }
150            }
151        } else {
152            split_separator.split_text_into_indices(&base_text)
153        };
154        if split_indices.len() < 2 {
155            return None;
156        }
157
158        // Create splits, computing token counts if a tokenizer is available
159        let splits: VecDeque<TextSplit> = if let Some(ref tokenizer) = self.tokenizer {
160            split_indices
161                .into_iter()
162                .map(|indices| {
163                    TextSplit::with_tokenizer(
164                        &indices,
165                        &split_separator,
166                        &base_text,
167                        tokenizer.as_ref(),
168                    )
169                })
170                .collect()
171        } else {
172            split_indices
173                .into_iter()
174                .map(|indices| TextSplit::new(&indices, &split_separator, &base_text))
175                .collect()
176        };
177
178        Some(splits)
179    }
180
181    /// Split text with token size constraint. Only drill down to finer separators
182    /// when chunks exceed max_token_size, preserving larger context when possible.
183    fn split_with_token_limit(
184        &self,
185        base_text: &Arc<str>,
186        split_separator: &mut Separator,
187        max_tokens: u32,
188        tokenizer: &Arc<dyn Tokenizer + Send + Sync>,
189    ) -> Option<Vec<Range<usize>>> {
190        loop {
191            let split_indices = split_separator.split_text_into_indices(base_text);
192
193            // Check if any chunk exceeds the token limit
194            let needs_finer_split = split_indices.iter().any(|indices| {
195                let chunk_text = &base_text[indices.clone()];
196                tokenizer.count_tokens(chunk_text) > max_tokens
197            });
198
199            if needs_finer_split && split_indices.len() > 1 {
200                // Some chunks are too large, try finer separator
201                *split_separator = split_separator.next()?;
202            } else if split_indices.len() > 1 {
203                // All chunks fit within token limit, or we've reached finest separator
204                break Some(split_indices);
205            } else {
206                // Single chunk, try finer separator
207                *split_separator = split_separator.next()?;
208            }
209        }
210    }
211
212    pub fn on_two_plus_newline(mut self) -> Self {
213        self.split_separator = Separator::TwoPlusEoL;
214        self
215    }
216
217    pub fn on_single_newline(mut self) -> Self {
218        self.split_separator = Separator::SingleEol;
219        self
220    }
221
222    pub fn on_sentences_rule_based(mut self) -> Self {
223        self.split_separator = Separator::SentencesRuleBased;
224        self
225    }
226
227    pub fn on_sentences_unicode(mut self) -> Self {
228        self.split_separator = Separator::SentencesUnicode;
229        self
230    }
231
232    pub fn on_words_unicode(mut self) -> Self {
233        self.split_separator = Separator::WordsUnicode;
234        self
235    }
236
237    pub fn on_graphemes_unicode(mut self) -> Self {
238        self.split_separator = Separator::GraphemesUnicode;
239        self
240    }
241
242    pub fn on_separator(mut self, split_separator: &Separator) -> Self {
243        self.split_separator = split_separator.clone();
244        self
245    }
246
247    pub fn recursive(mut self, recursive: bool) -> Self {
248        self.recursive = recursive;
249        self
250    }
251
252    pub fn clean_text(mut self, clean_text: bool) -> Self {
253        self.clean_text = clean_text;
254        self
255    }
256
257    /// Set the maximum token size for chunks. When set with a tokenizer,
258    /// recursive splitting will only drill down to finer separators when
259    /// a chunk exceeds this limit.
260    pub fn max_token_size(mut self, max_tokens: u32) -> Self {
261        self.max_token_size = Some(max_tokens);
262        self
263    }
264
265    /// Set the tokenizer for accurate token counting.
266    pub fn with_tokenizer(mut self, tokenizer: Arc<dyn Tokenizer + Send + Sync>) -> Self {
267        self.tokenizer = Some(tokenizer);
268        self
269    }
270
271    /// Split a TextSplit into smaller splits using the configured separator.
272    /// This is the public API that accepts a TextSplit directly.
273    pub fn split_text_split(self, split: &TextSplit) -> Option<VecDeque<TextSplit>> {
274        self.split_split_inner(&split.base_text, &split.indices)
275    }
276
277    /// Internal implementation that works with raw base_text and indices.
278    fn split_split_inner(
279        self,
280        base_text: &Arc<str>,
281        split_indices: &Range<usize>,
282    ) -> Option<VecDeque<TextSplit>> {
283        let start_offset = split_indices.start;
284        let split_text = &base_text[split_indices.clone()];
285
286        let mut split_separator = self.split_separator.clone();
287        let split_indices = loop {
288            let split_indices = split_separator.split_text_into_indices(split_text);
289            if split_indices.len() > 1 {
290                break split_indices;
291            } else {
292                split_separator = split_separator.next()?;
293            }
294        };
295        Some(
296            split_indices
297                .into_iter()
298                .map(|indices| {
299                    let start = start_offset + indices.start;
300                    let end = start_offset + indices.end;
301                    TextSplit::new(&Range { start, end }, &split_separator, base_text)
302                })
303                .collect(),
304        )
305    }
306
307    pub fn splits_to_text(splits: &VecDeque<TextSplit>, with_separator: bool) -> String {
308        let mut text = String::new();
309        let mut last_separator = Separator::None;
310        for (i, split) in splits.iter().enumerate() {
311            if last_separator == Separator::GraphemesUnicode &&
312                split.split_separator != Separator::GraphemesUnicode
313            {
314                text.push(' ');
315            };
316            last_separator = split.split_separator.clone();
317            match split.split_separator {
318                Separator::TwoPlusEoL => {
319                    text.push_str(split.text());
320                    if with_separator {
321                        text.push_str("\n\n");
322                    } else if i < splits.len() - 1 {
323                        text.push(' ');
324                    }
325                }
326                Separator::SingleEol => {
327                    text.push_str(split.text());
328                    if with_separator {
329                        text.push('\n');
330                    } else if i < splits.len() - 1 {
331                        text.push(' ');
332                    }
333                }
334                Separator::SentencesRuleBased |
335                Separator::SentencesUnicode |
336                Separator::WordsUnicode => {
337                    text.push_str(split.text());
338                    if i < splits.len() - 1 {
339                        text.push(' ');
340                    }
341                }
342                Separator::GraphemesUnicode => {
343                    text.push_str(split.text());
344                }
345                Separator::None => unreachable!(),
346            }
347        }
348        text
349    }
350}
351
352#[derive(Debug, Clone)]
353pub struct TextSplit {
354    pub indices: Range<usize>,
355    pub split_separator: Separator,
356    pub base_text: Arc<str>,
357    /// Token count computed by the tokenizer, if one was provided.
358    pub token_count: Option<u32>,
359}
360
361impl TextSplit {
362    fn new(indices: &Range<usize>, split_separator: &Separator, base_text: &Arc<str>) -> Self {
363        Self {
364            indices: indices.clone(),
365            split_separator: split_separator.clone(),
366            base_text: Arc::clone(base_text),
367            token_count: None,
368        }
369    }
370
371    /// Create a new TextSplit with token count computed using the provided tokenizer.
372    fn with_tokenizer(
373        indices: &Range<usize>,
374        split_separator: &Separator,
375        base_text: &Arc<str>,
376        tokenizer: &dyn Tokenizer,
377    ) -> Self {
378        let text = &base_text[indices.clone()];
379        Self {
380            indices: indices.clone(),
381            split_separator: split_separator.clone(),
382            base_text: Arc::clone(base_text),
383            token_count: Some(tokenizer.count_tokens(text)),
384        }
385    }
386
387    pub fn char_count(&mut self) -> usize {
388        self.text().chars().count()
389    }
390
391    pub fn text(&self) -> &str {
392        &self.base_text[self.indices.clone()]
393    }
394
395    pub fn split(&self) -> Option<VecDeque<TextSplit>> {
396        TextSplitter::default()
397            .on_separator(&self.split_separator.next()?)
398            .split_split_inner(&self.base_text, &self.indices)
399    }
400}
401
/// Coarse classification of separators, as returned by `Separator::group`.
///
/// Derives `Debug` so values can be used in `assert_eq!` and `{:?}`
/// formatting, and `Clone`/`Copy`/`Eq` because the enum is a plain
/// field-less tag.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeparatorGroup {
    /// Paragraph-, line- and sentence-level separators.
    Semantic,
    /// Word- and grapheme-level separators.
    Syntactic,
}
407impl SeparatorGroup {
408    pub fn get(&self) -> Vec<Separator> {
409        match self {
410            Self::Semantic => vec![
411                Separator::TwoPlusEoL,
412                Separator::SingleEol,
413                Separator::SentencesRuleBased,
414                Separator::SentencesUnicode,
415            ],
416            Self::Syntactic => vec![Separator::WordsUnicode, Separator::GraphemesUnicode],
417        }
418    }
419}
420
/// The granularities at which text can be split, listed from coarsest to
/// finest, plus a `None` placeholder.
#[derive(PartialEq, Debug, Clone, Default)]
pub enum Separator {
    /// Runs of two or more newlines.
    #[default]
    TwoPlusEoL,
    /// Single newline characters.
    SingleEol,
    /// Sentences, via this module's `split_text_into_indices` helper.
    SentencesRuleBased,
    /// Sentences, via `unicode_segmentation` sentence bounds.
    SentencesUnicode,
    /// Words, via `unicode_segmentation` word indices.
    WordsUnicode,
    /// Grapheme clusters, via `unicode_segmentation`.
    GraphemesUnicode,
    /// Placeholder meaning "no separator"; most `Separator` methods treat it
    /// as unreachable and panic.
    None,
}
432
433impl Separator {
434    pub fn get_all() -> Vec<Self> {
435        vec![
436            Self::TwoPlusEoL,
437            Self::SingleEol,
438            Self::SentencesRuleBased,
439            Self::SentencesUnicode,
440            Self::WordsUnicode,
441            // Self::GraphemesUnicode,
442        ]
443    }
444
445    pub fn group(&self) -> SeparatorGroup {
446        match self {
447            Self::TwoPlusEoL |
448            Self::SingleEol |
449            Self::SentencesRuleBased |
450            Self::SentencesUnicode => SeparatorGroup::Semantic,
451            Self::WordsUnicode | Self::GraphemesUnicode => SeparatorGroup::Syntactic,
452            Self::None => unreachable!(),
453        }
454    }
455
456    pub fn clean_text(&self, text: &str) -> String {
457        match self {
458            Self::TwoPlusEoL => {
459                crate::text::TextCleaner::new().reduce_newlines_to_double_newline().run(text)
460            }
461            Self::SingleEol => {
462                crate::text::TextCleaner::new().reduce_newlines_to_single_newline().run(text)
463            }
464            Self::SentencesRuleBased |
465            Self::SentencesUnicode |
466            Self::WordsUnicode |
467            Self::GraphemesUnicode => {
468                crate::text::TextCleaner::new().reduce_newlines_to_single_space().run(text)
469            }
470            Self::None => unreachable!(),
471        }
472    }
473
474    pub fn split_text_into_indices<T: AsRef<str>>(&self, text: T) -> Vec<Range<usize>> {
475        let mut split_indices: Vec<Range<usize>> = Vec::new();
476        match self {
477            Self::TwoPlusEoL => {
478                let mut last_end = 0;
479                for m in TWO_PLUS_NEWLINE_REGEX.find_iter(text.as_ref()) {
480                    let start = m.start();
481                    let end = m.end();
482                    if start > last_end {
483                        split_indices.push(Range { start: last_end, end: start });
484                    }
485                    split_indices.push(Range { start, end });
486                    last_end = end;
487                }
488                if last_end < text.as_ref().len() {
489                    split_indices.push(Range { start: last_end, end: text.as_ref().len() });
490                }
491            }
492            Self::SingleEol => {
493                // Use native string matching instead of regex for single newline
494                let text_ref = text.as_ref();
495                let mut last_end = 0;
496                for (idx, _) in text_ref.match_indices('\n') {
497                    if idx > last_end {
498                        split_indices.push(Range { start: last_end, end: idx });
499                    }
500                    split_indices.push(Range { start: idx, end: idx + 1 });
501                    last_end = idx + 1;
502                }
503                if last_end < text_ref.len() {
504                    split_indices.push(Range { start: last_end, end: text_ref.len() });
505                }
506            }
507            Self::SentencesRuleBased => {
508                split_indices = split_text_into_indices(text.as_ref(), true);
509            }
510            Self::SentencesUnicode | Self::WordsUnicode | Self::GraphemesUnicode => {
511                let indices: Vec<(usize, &str)> = match self {
512                    Self::SentencesUnicode => {
513                        unicode_segmentation::UnicodeSegmentation::split_sentence_bound_indices(
514                            text.as_ref(),
515                        )
516                        .collect()
517                    }
518                    Self::WordsUnicode => {
519                        unicode_segmentation::UnicodeSegmentation::unicode_word_indices(
520                            text.as_ref(),
521                        )
522                        .collect()
523                    }
524                    Self::GraphemesUnicode => {
525                        unicode_segmentation::UnicodeSegmentation::grapheme_indices(
526                            text.as_ref(),
527                            true,
528                        )
529                        .collect()
530                    }
531                    _ => unreachable!(),
532                };
533                for i in 0..indices.len() {
534                    let end_index =
535                        if i == indices.len() - 1 { text.as_ref().len() } else { indices[i + 1].0 };
536                    split_indices.push(Range { start: indices[i].0, end: end_index });
537                }
538            }
539            Self::None => unreachable!(),
540        }
541        split_indices
542            .into_iter()
543            .filter_map(|indices| self.trim_range(&indices, text.as_ref()))
544            .collect()
545    }
546
547    pub fn next(&self) -> Option<Self> {
548        match self {
549            Self::TwoPlusEoL => Some(Self::SingleEol),
550            Self::SingleEol => Some(Self::SentencesRuleBased),
551            Self::SentencesRuleBased => Some(Self::SentencesUnicode),
552            Self::SentencesUnicode => Some(Self::WordsUnicode),
553            Self::WordsUnicode => Some(Self::GraphemesUnicode),
554            Self::GraphemesUnicode => None,
555            Self::None => unreachable!(),
556        }
557    }
558    fn trim_range<T: AsRef<str>>(&self, indices: &Range<usize>, text: T) -> Option<Range<usize>> {
559        let (start, end) = match self {
560            Self::TwoPlusEoL |
561            Self::SingleEol |
562            Self::SentencesRuleBased |
563            Self::SentencesUnicode => {
564                let start = text.as_ref()[indices.start..indices.end]
565                    .char_indices()
566                    .find(|(_, c)| !c.is_whitespace())
567                    .map(|(i, _)| indices.start + i)
568                    .unwrap_or(indices.end);
569                let end = if indices.end == text.as_ref().len() {
570                    text.as_ref().len()
571                } else {
572                    text.as_ref()[indices.start..indices.end]
573                        .char_indices()
574                        .rev()
575                        .find(|(_, c)| !c.is_whitespace())
576                        .map(|(i, c)| indices.start + i + c.len_utf8())
577                        .unwrap_or(start)
578                };
579                (start, end)
580            }
581            Self::WordsUnicode => {
582                let start = text.as_ref()[..indices.start]
583                    .char_indices()
584                    .rev()
585                    .find(|(_, c)| c.is_whitespace())
586                    .map(|(i, c)| i + c.len_utf8())
587                    .unwrap_or(indices.start);
588                let end = if indices.end == text.as_ref().len() {
589                    text.as_ref().len()
590                } else {
591                    text.as_ref()[indices.start..indices.end]
592                        .char_indices()
593                        .find(|(_, c)| c.is_whitespace())
594                        .map(|(i, _)| indices.start + i)
595                        .unwrap_or(start)
596                };
597                (start, end)
598            }
599            Self::GraphemesUnicode => (indices.start, indices.end),
600            Self::None => unreachable!(),
601        };
602
603        if start >= end { None } else { Some(Range { start, end }) }
604    }
605}
606
607pub static TWO_PLUS_NEWLINE_REGEX: LazyLock<Regex> =
608    LazyLock::new(|| Regex::new(r"\n{2,}").unwrap());