Skip to main content

alith_core/
splitting.rs

1pub mod markdown;
2pub mod rule_based;
3
4use crate::cleaner::TextCleaner;
5use regex::Regex;
6use std::{
7    collections::VecDeque,
8    ops::Range,
9    sync::{Arc, LazyLock},
10};
11use text_splitter::ChunkConfigError;
12use thiserror::Error;
13
14pub use markdown::split_markdown;
15pub use rule_based::split_text_into_indices;
16
17#[derive(Error, Debug)]
18pub enum SplitError {
19    #[error("Chunk config error: {0}")]
20    ChunkConfigError(#[from] ChunkConfigError),
21}
22
23#[derive(Default)]
24pub struct TextSplitter {
25    pub split_separator: Separator,
26    pub recursive: bool,
27    pub clean_text: bool,
28}
29
30impl TextSplitter {
31    pub fn new() -> Self {
32        Self {
33            split_separator: Separator::TwoPlusEoL,
34            recursive: true,
35            clean_text: true,
36        }
37    }
38
39    pub fn split_text(&self, text: &str) -> Option<VecDeque<TextSplit>> {
40        let base_text: Arc<str> = if self.clean_text {
41            Arc::from(self.split_separator.clean_text(text.as_ref()))
42        } else {
43            Arc::from(text)
44        };
45
46        let mut split_separator = self.split_separator.clone();
47        let split_indices = if self.recursive {
48            loop {
49                let split_indices = split_separator.split_text_into_indices(&base_text);
50                if split_indices.len() > 1 {
51                    break split_indices;
52                } else {
53                    split_separator = split_separator.next()?;
54                }
55            }
56        } else {
57            split_separator.split_text_into_indices(&base_text)
58        };
59        if split_indices.len() < 2 {
60            return None;
61        }
62
63        Some(
64            split_indices
65                .into_iter()
66                .map(|indices| TextSplit::new(&indices, &split_separator, &base_text))
67                .collect(),
68        )
69    }
70
71    pub fn on_two_plus_newline(mut self) -> Self {
72        self.split_separator = Separator::TwoPlusEoL;
73        self
74    }
75
76    pub fn on_single_newline(mut self) -> Self {
77        self.split_separator = Separator::SingleEol;
78        self
79    }
80
81    pub fn on_sentences_rule_based(mut self) -> Self {
82        self.split_separator = Separator::SentencesRuleBased;
83        self
84    }
85
86    pub fn on_sentences_unicode(mut self) -> Self {
87        self.split_separator = Separator::SentencesUnicode;
88        self
89    }
90
91    pub fn on_words_unicode(mut self) -> Self {
92        self.split_separator = Separator::WordsUnicode;
93        self
94    }
95
96    pub fn on_graphemes_unicode(mut self) -> Self {
97        self.split_separator = Separator::GraphemesUnicode;
98        self
99    }
100
101    pub fn on_separator(mut self, split_separator: &Separator) -> Self {
102        self.split_separator = split_separator.clone();
103        self
104    }
105
106    pub fn recursive(mut self, recursive: bool) -> Self {
107        self.recursive = recursive;
108        self
109    }
110
111    pub fn clean_text(mut self, clean_text: bool) -> Self {
112        self.clean_text = clean_text;
113        self
114    }
115
116    pub fn split_split(
117        self,
118        base_text: &Arc<str>,
119        split_indices: &Range<usize>,
120    ) -> Option<VecDeque<TextSplit>> {
121        let start_offset = split_indices.start;
122        let split_text = &base_text[split_indices.clone()];
123
124        let mut split_separator = self.split_separator.clone();
125        let split_indices = loop {
126            let split_indices = split_separator.split_text_into_indices(split_text);
127            if split_indices.len() > 1 {
128                break split_indices;
129            } else {
130                split_separator = split_separator.next()?;
131            }
132        };
133        Some(
134            split_indices
135                .into_iter()
136                .map(|indices| {
137                    let start = start_offset + indices.start;
138                    let end = start_offset + indices.end;
139                    TextSplit::new(&Range { start, end }, &split_separator, base_text)
140                })
141                .collect(),
142        )
143    }
144
145    pub fn splits_to_text(splits: &VecDeque<TextSplit>, with_seperator: bool) -> String {
146        let mut text = String::new();
147        let mut last_separator = Separator::None;
148        for (i, split) in splits.iter().enumerate() {
149            if last_separator == Separator::GraphemesUnicode
150                && split.split_separator != Separator::GraphemesUnicode
151            {
152                text.push(' ');
153            };
154            last_separator = split.split_separator.clone();
155            match split.split_separator {
156                Separator::TwoPlusEoL => {
157                    text.push_str(split.text());
158                    if with_seperator {
159                        text.push_str("\n\n");
160                    } else if i < splits.len() - 1 {
161                        text.push(' ');
162                    }
163                }
164                Separator::SingleEol => {
165                    text.push_str(split.text());
166                    if with_seperator {
167                        text.push('\n');
168                    } else if i < splits.len() - 1 {
169                        text.push(' ');
170                    }
171                }
172                Separator::SentencesRuleBased
173                | Separator::SentencesUnicode
174                | Separator::WordsUnicode => {
175                    text.push_str(split.text());
176                    if i < splits.len() - 1 {
177                        text.push(' ');
178                    }
179                }
180                Separator::GraphemesUnicode => {
181                    text.push_str(split.text());
182                }
183                Separator::None => unreachable!(),
184            }
185        }
186        text
187    }
188}
189
190#[derive(Debug, Clone)]
191pub struct TextSplit {
192    pub indices: Range<usize>,
193    pub split_separator: Separator,
194    pub base_text: Arc<str>,
195    pub token_count: Option<u32>,
196}
197
198impl TextSplit {
199    fn new(indices: &Range<usize>, split_separator: &Separator, base_text: &Arc<str>) -> Self {
200        Self {
201            indices: indices.clone(),
202            split_separator: split_separator.clone(),
203            base_text: Arc::clone(base_text),
204
205            token_count: None,
206        }
207    }
208
209    pub fn char_count(&mut self) -> usize {
210        self.text().chars().count()
211    }
212
213    pub fn text(&self) -> &str {
214        &self.base_text[self.indices.clone()]
215    }
216
217    pub fn split(&self) -> Option<VecDeque<TextSplit>> {
218        TextSplitter::default()
219            .on_separator(&self.split_separator.next()?)
220            .split_split(&self.base_text, &self.indices)
221    }
222}
223
224#[derive(PartialEq)]
225pub enum SeparatorGroup {
226    Semantic,
227    Syntactic,
228}
229impl SeparatorGroup {
230    pub fn get(&self) -> Vec<Separator> {
231        match self {
232            Self::Semantic => vec![
233                Separator::TwoPlusEoL,
234                Separator::SingleEol,
235                Separator::SentencesRuleBased,
236                Separator::SentencesUnicode,
237            ],
238            Self::Syntactic => vec![Separator::WordsUnicode, Separator::GraphemesUnicode],
239        }
240    }
241}
242
/// Strategy used to cut text into pieces, ordered from coarsest
/// (`TwoPlusEoL`) to finest (`GraphemesUnicode`); see [`Separator::next`].
#[derive(PartialEq, Eq, Hash, Debug, Clone, Default)]
pub enum Separator {
    /// Runs of two or more newlines. This is the default.
    #[default]
    TwoPlusEoL,
    /// Every single newline.
    SingleEol,
    /// Sentence boundaries from the rule-based splitter in `rule_based`.
    SentencesRuleBased,
    /// Unicode sentence boundaries (via `unicode_segmentation`).
    SentencesUnicode,
    /// Unicode word boundaries (via `unicode_segmentation`).
    WordsUnicode,
    /// Unicode extended grapheme clusters (via `unicode_segmentation`).
    GraphemesUnicode,
    /// Placeholder with no splitting behavior; the splitting methods treat it
    /// as unreachable.
    None,
}
254
255impl Separator {
256    pub fn get_all() -> Vec<Self> {
257        vec![
258            Self::TwoPlusEoL,
259            Self::SingleEol,
260            Self::SentencesRuleBased,
261            Self::SentencesUnicode,
262            Self::WordsUnicode,
263            // Self::GraphemesUnicode,
264        ]
265    }
266
267    pub fn group(&self) -> SeparatorGroup {
268        match self {
269            Self::TwoPlusEoL
270            | Self::SingleEol
271            | Self::SentencesRuleBased
272            | Self::SentencesUnicode => SeparatorGroup::Semantic,
273            Self::WordsUnicode | Self::GraphemesUnicode => SeparatorGroup::Syntactic,
274            Self::None => unreachable!(),
275        }
276    }
277
278    pub fn clean_text(&self, text: &str) -> String {
279        match self {
280            Self::TwoPlusEoL => TextCleaner::new()
281                .reduce_newlines_to_double_newline()
282                .run(text),
283            Self::SingleEol => TextCleaner::new()
284                .reduce_newlines_to_single_newline()
285                .run(text),
286            Self::SentencesRuleBased
287            | Self::SentencesUnicode
288            | Self::WordsUnicode
289            | Self::GraphemesUnicode => TextCleaner::new()
290                .reduce_newlines_to_single_space()
291                .run(text),
292            Self::None => unreachable!(),
293        }
294    }
295
296    pub fn split_text_into_indices<T: AsRef<str>>(&self, text: T) -> Vec<Range<usize>> {
297        let mut split_indices: Vec<Range<usize>> = Vec::new();
298        match self {
299            Self::TwoPlusEoL | Self::SingleEol => {
300                let pattern_matches = match self {
301                    Self::TwoPlusEoL => TWO_PLUS_NEWLINE_REGEX.find_iter(text.as_ref()),
302                    Self::SingleEol => SINGLE_NEWLINE_REGEX.find_iter(text.as_ref()),
303                    _ => unreachable!(),
304                };
305                let mut last_end = 0;
306                for m in pattern_matches {
307                    let start = m.start();
308                    let end = m.end();
309                    if start > last_end {
310                        split_indices.push(Range {
311                            start: last_end,
312                            end: start,
313                        });
314                    }
315                    split_indices.push(Range { start, end });
316                    last_end = end;
317                }
318                if last_end < text.as_ref().len() {
319                    split_indices.push(Range {
320                        start: last_end,
321                        end: text.as_ref().len(),
322                    });
323                }
324            }
325            Self::SentencesRuleBased => {
326                split_indices = split_text_into_indices(text.as_ref(), true);
327            }
328            Self::SentencesUnicode | Self::WordsUnicode | Self::GraphemesUnicode => {
329                let indices: Vec<(usize, &str)> = match self {
330                    Self::SentencesUnicode => {
331                        unicode_segmentation::UnicodeSegmentation::split_sentence_bound_indices(
332                            text.as_ref(),
333                        )
334                        .collect()
335                    }
336                    Self::WordsUnicode => {
337                        unicode_segmentation::UnicodeSegmentation::unicode_word_indices(
338                            text.as_ref(),
339                        )
340                        .collect()
341                    }
342                    Self::GraphemesUnicode => {
343                        unicode_segmentation::UnicodeSegmentation::grapheme_indices(
344                            text.as_ref(),
345                            true,
346                        )
347                        .collect()
348                    }
349                    _ => unreachable!(),
350                };
351                for i in 0..indices.len() {
352                    let end_index = if i == indices.len() - 1 {
353                        text.as_ref().len()
354                    } else {
355                        indices[i + 1].0
356                    };
357                    split_indices.push(Range {
358                        start: indices[i].0,
359                        end: end_index,
360                    });
361                }
362            }
363            Self::None => unreachable!(),
364        }
365        split_indices
366            .into_iter()
367            .filter_map(|indices| self.trim_range(&indices, text.as_ref()))
368            .collect()
369    }
370
371    pub fn next(&self) -> Option<Self> {
372        match self {
373            Self::TwoPlusEoL => Some(Self::SingleEol),
374            Self::SingleEol => Some(Self::SentencesRuleBased),
375            Self::SentencesRuleBased => Some(Self::SentencesUnicode),
376            Self::SentencesUnicode => Some(Self::WordsUnicode),
377            Self::WordsUnicode => Some(Self::GraphemesUnicode),
378            Self::GraphemesUnicode => None,
379            Self::None => unreachable!(),
380        }
381    }
382    fn trim_range<T: AsRef<str>>(&self, indices: &Range<usize>, text: T) -> Option<Range<usize>> {
383        let (start, end) = match self {
384            Self::TwoPlusEoL
385            | Self::SingleEol
386            | Self::SentencesRuleBased
387            | Self::SentencesUnicode => {
388                let start = text.as_ref()[indices.start..indices.end]
389                    .char_indices()
390                    .find(|(_, c)| !c.is_whitespace())
391                    .map(|(i, _)| indices.start + i)
392                    .unwrap_or(indices.end);
393                let end = if indices.end == text.as_ref().len() {
394                    text.as_ref().len()
395                } else {
396                    text.as_ref()[indices.start..indices.end]
397                        .char_indices()
398                        .rev()
399                        .find(|(_, c)| !c.is_whitespace())
400                        .map(|(i, c)| indices.start + i + c.len_utf8())
401                        .unwrap_or(start)
402                };
403                (start, end)
404            }
405            Self::WordsUnicode => {
406                let start = text.as_ref()[..indices.start]
407                    .char_indices()
408                    .rev()
409                    .find(|(_, c)| c.is_whitespace())
410                    .map(|(i, c)| i + c.len_utf8())
411                    .unwrap_or(indices.start);
412                let end = if indices.end == text.as_ref().len() {
413                    text.as_ref().len()
414                } else {
415                    text.as_ref()[indices.start..indices.end]
416                        .char_indices()
417                        .find(|(_, c)| c.is_whitespace())
418                        .map(|(i, _)| indices.start + i)
419                        .unwrap_or(start)
420                };
421                (start, end)
422            }
423            Self::GraphemesUnicode => (indices.start, indices.end),
424            Self::None => unreachable!(),
425        };
426
427        if start >= end {
428            None
429        } else {
430            Some(Range { start, end })
431        }
432    }
433}
434
435pub static TWO_PLUS_NEWLINE_REGEX: LazyLock<Regex> =
436    LazyLock::new(|| Regex::new(r"\n{2,}").unwrap());
437pub static SINGLE_NEWLINE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\n").unwrap());
438
439#[inline]
440pub fn split_text(text: &str) -> Vec<String> {
441    match TextSplitter::new().split_text(text) {
442        Some(splits) => splits
443            .iter()
444            .map(|split| split.text().to_string())
445            .collect(),
446        None => vec![],
447    }
448}