1pub mod markdown;
2pub mod rule_based;
3
4use crate::cleaner::TextCleaner;
5use regex::Regex;
6use std::{
7 collections::VecDeque,
8 ops::Range,
9 sync::{Arc, LazyLock},
10};
11use text_splitter::ChunkConfigError;
12use thiserror::Error;
13
14pub use markdown::split_markdown;
15pub use rule_based::split_text_into_indices;
16
17#[derive(Error, Debug)]
18pub enum SplitError {
19 #[error("Chunk config error: {0}")]
20 ChunkConfigError(#[from] ChunkConfigError),
21}
22
23#[derive(Default)]
24pub struct TextSplitter {
25 pub split_separator: Separator,
26 pub recursive: bool,
27 pub clean_text: bool,
28}
29
30impl TextSplitter {
31 pub fn new() -> Self {
32 Self {
33 split_separator: Separator::TwoPlusEoL,
34 recursive: true,
35 clean_text: true,
36 }
37 }
38
39 pub fn split_text(&self, text: &str) -> Option<VecDeque<TextSplit>> {
40 let base_text: Arc<str> = if self.clean_text {
41 Arc::from(self.split_separator.clean_text(text.as_ref()))
42 } else {
43 Arc::from(text)
44 };
45
46 let mut split_separator = self.split_separator.clone();
47 let split_indices = if self.recursive {
48 loop {
49 let split_indices = split_separator.split_text_into_indices(&base_text);
50 if split_indices.len() > 1 {
51 break split_indices;
52 } else {
53 split_separator = split_separator.next()?;
54 }
55 }
56 } else {
57 split_separator.split_text_into_indices(&base_text)
58 };
59 if split_indices.len() < 2 {
60 return None;
61 }
62
63 Some(
64 split_indices
65 .into_iter()
66 .map(|indices| TextSplit::new(&indices, &split_separator, &base_text))
67 .collect(),
68 )
69 }
70
71 pub fn on_two_plus_newline(mut self) -> Self {
72 self.split_separator = Separator::TwoPlusEoL;
73 self
74 }
75
76 pub fn on_single_newline(mut self) -> Self {
77 self.split_separator = Separator::SingleEol;
78 self
79 }
80
81 pub fn on_sentences_rule_based(mut self) -> Self {
82 self.split_separator = Separator::SentencesRuleBased;
83 self
84 }
85
86 pub fn on_sentences_unicode(mut self) -> Self {
87 self.split_separator = Separator::SentencesUnicode;
88 self
89 }
90
91 pub fn on_words_unicode(mut self) -> Self {
92 self.split_separator = Separator::WordsUnicode;
93 self
94 }
95
96 pub fn on_graphemes_unicode(mut self) -> Self {
97 self.split_separator = Separator::GraphemesUnicode;
98 self
99 }
100
101 pub fn on_separator(mut self, split_separator: &Separator) -> Self {
102 self.split_separator = split_separator.clone();
103 self
104 }
105
106 pub fn recursive(mut self, recursive: bool) -> Self {
107 self.recursive = recursive;
108 self
109 }
110
111 pub fn clean_text(mut self, clean_text: bool) -> Self {
112 self.clean_text = clean_text;
113 self
114 }
115
116 pub fn split_split(
117 self,
118 base_text: &Arc<str>,
119 split_indices: &Range<usize>,
120 ) -> Option<VecDeque<TextSplit>> {
121 let start_offset = split_indices.start;
122 let split_text = &base_text[split_indices.clone()];
123
124 let mut split_separator = self.split_separator.clone();
125 let split_indices = loop {
126 let split_indices = split_separator.split_text_into_indices(split_text);
127 if split_indices.len() > 1 {
128 break split_indices;
129 } else {
130 split_separator = split_separator.next()?;
131 }
132 };
133 Some(
134 split_indices
135 .into_iter()
136 .map(|indices| {
137 let start = start_offset + indices.start;
138 let end = start_offset + indices.end;
139 TextSplit::new(&Range { start, end }, &split_separator, base_text)
140 })
141 .collect(),
142 )
143 }
144
145 pub fn splits_to_text(splits: &VecDeque<TextSplit>, with_seperator: bool) -> String {
146 let mut text = String::new();
147 let mut last_separator = Separator::None;
148 for (i, split) in splits.iter().enumerate() {
149 if last_separator == Separator::GraphemesUnicode
150 && split.split_separator != Separator::GraphemesUnicode
151 {
152 text.push(' ');
153 };
154 last_separator = split.split_separator.clone();
155 match split.split_separator {
156 Separator::TwoPlusEoL => {
157 text.push_str(split.text());
158 if with_seperator {
159 text.push_str("\n\n");
160 } else if i < splits.len() - 1 {
161 text.push(' ');
162 }
163 }
164 Separator::SingleEol => {
165 text.push_str(split.text());
166 if with_seperator {
167 text.push('\n');
168 } else if i < splits.len() - 1 {
169 text.push(' ');
170 }
171 }
172 Separator::SentencesRuleBased
173 | Separator::SentencesUnicode
174 | Separator::WordsUnicode => {
175 text.push_str(split.text());
176 if i < splits.len() - 1 {
177 text.push(' ');
178 }
179 }
180 Separator::GraphemesUnicode => {
181 text.push_str(split.text());
182 }
183 Separator::None => unreachable!(),
184 }
185 }
186 text
187 }
188}
189
190#[derive(Debug, Clone)]
191pub struct TextSplit {
192 pub indices: Range<usize>,
193 pub split_separator: Separator,
194 pub base_text: Arc<str>,
195 pub token_count: Option<u32>,
196}
197
198impl TextSplit {
199 fn new(indices: &Range<usize>, split_separator: &Separator, base_text: &Arc<str>) -> Self {
200 Self {
201 indices: indices.clone(),
202 split_separator: split_separator.clone(),
203 base_text: Arc::clone(base_text),
204
205 token_count: None,
206 }
207 }
208
209 pub fn char_count(&mut self) -> usize {
210 self.text().chars().count()
211 }
212
213 pub fn text(&self) -> &str {
214 &self.base_text[self.indices.clone()]
215 }
216
217 pub fn split(&self) -> Option<VecDeque<TextSplit>> {
218 TextSplitter::default()
219 .on_separator(&self.split_separator.next()?)
220 .split_split(&self.base_text, &self.indices)
221 }
222}
223
/// Coarse classification of separators, as returned by `Separator::group`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SeparatorGroup {
    /// Paragraph-, line- and sentence-level separators.
    Semantic,
    /// Word- and grapheme-level separators.
    Syntactic,
}
229impl SeparatorGroup {
230 pub fn get(&self) -> Vec<Separator> {
231 match self {
232 Self::Semantic => vec![
233 Separator::TwoPlusEoL,
234 Separator::SingleEol,
235 Separator::SentencesRuleBased,
236 Separator::SentencesUnicode,
237 ],
238 Self::Syntactic => vec![Separator::WordsUnicode, Separator::GraphemesUnicode],
239 }
240 }
241}
242
/// Boundary kinds used to cut text, declared from coarsest to finest.
///
/// `Separator::next` walks the variants in declaration order. `None` is a
/// sentinel, and most methods treat it as unreachable.
#[derive(Debug, Clone, PartialEq, Default)]
pub enum Separator {
    /// Two or more consecutive newlines. This is the default.
    #[default]
    TwoPlusEoL,
    /// A single newline.
    SingleEol,
    /// Sentences found by the crate's rule-based splitter.
    SentencesRuleBased,
    /// Sentences found via `unicode_segmentation` sentence boundaries.
    SentencesUnicode,
    /// Words found via `unicode_segmentation` word boundaries.
    WordsUnicode,
    /// Grapheme clusters found via `unicode_segmentation`.
    GraphemesUnicode,
    /// Sentinel meaning "no separator"; not a valid splitting mode.
    None,
}
254
255impl Separator {
256 pub fn get_all() -> Vec<Self> {
257 vec![
258 Self::TwoPlusEoL,
259 Self::SingleEol,
260 Self::SentencesRuleBased,
261 Self::SentencesUnicode,
262 Self::WordsUnicode,
263 ]
265 }
266
267 pub fn group(&self) -> SeparatorGroup {
268 match self {
269 Self::TwoPlusEoL
270 | Self::SingleEol
271 | Self::SentencesRuleBased
272 | Self::SentencesUnicode => SeparatorGroup::Semantic,
273 Self::WordsUnicode | Self::GraphemesUnicode => SeparatorGroup::Syntactic,
274 Self::None => unreachable!(),
275 }
276 }
277
278 pub fn clean_text(&self, text: &str) -> String {
279 match self {
280 Self::TwoPlusEoL => TextCleaner::new()
281 .reduce_newlines_to_double_newline()
282 .run(text),
283 Self::SingleEol => TextCleaner::new()
284 .reduce_newlines_to_single_newline()
285 .run(text),
286 Self::SentencesRuleBased
287 | Self::SentencesUnicode
288 | Self::WordsUnicode
289 | Self::GraphemesUnicode => TextCleaner::new()
290 .reduce_newlines_to_single_space()
291 .run(text),
292 Self::None => unreachable!(),
293 }
294 }
295
296 pub fn split_text_into_indices<T: AsRef<str>>(&self, text: T) -> Vec<Range<usize>> {
297 let mut split_indices: Vec<Range<usize>> = Vec::new();
298 match self {
299 Self::TwoPlusEoL | Self::SingleEol => {
300 let pattern_matches = match self {
301 Self::TwoPlusEoL => TWO_PLUS_NEWLINE_REGEX.find_iter(text.as_ref()),
302 Self::SingleEol => SINGLE_NEWLINE_REGEX.find_iter(text.as_ref()),
303 _ => unreachable!(),
304 };
305 let mut last_end = 0;
306 for m in pattern_matches {
307 let start = m.start();
308 let end = m.end();
309 if start > last_end {
310 split_indices.push(Range {
311 start: last_end,
312 end: start,
313 });
314 }
315 split_indices.push(Range { start, end });
316 last_end = end;
317 }
318 if last_end < text.as_ref().len() {
319 split_indices.push(Range {
320 start: last_end,
321 end: text.as_ref().len(),
322 });
323 }
324 }
325 Self::SentencesRuleBased => {
326 split_indices = split_text_into_indices(text.as_ref(), true);
327 }
328 Self::SentencesUnicode | Self::WordsUnicode | Self::GraphemesUnicode => {
329 let indices: Vec<(usize, &str)> = match self {
330 Self::SentencesUnicode => {
331 unicode_segmentation::UnicodeSegmentation::split_sentence_bound_indices(
332 text.as_ref(),
333 )
334 .collect()
335 }
336 Self::WordsUnicode => {
337 unicode_segmentation::UnicodeSegmentation::unicode_word_indices(
338 text.as_ref(),
339 )
340 .collect()
341 }
342 Self::GraphemesUnicode => {
343 unicode_segmentation::UnicodeSegmentation::grapheme_indices(
344 text.as_ref(),
345 true,
346 )
347 .collect()
348 }
349 _ => unreachable!(),
350 };
351 for i in 0..indices.len() {
352 let end_index = if i == indices.len() - 1 {
353 text.as_ref().len()
354 } else {
355 indices[i + 1].0
356 };
357 split_indices.push(Range {
358 start: indices[i].0,
359 end: end_index,
360 });
361 }
362 }
363 Self::None => unreachable!(),
364 }
365 split_indices
366 .into_iter()
367 .filter_map(|indices| self.trim_range(&indices, text.as_ref()))
368 .collect()
369 }
370
371 pub fn next(&self) -> Option<Self> {
372 match self {
373 Self::TwoPlusEoL => Some(Self::SingleEol),
374 Self::SingleEol => Some(Self::SentencesRuleBased),
375 Self::SentencesRuleBased => Some(Self::SentencesUnicode),
376 Self::SentencesUnicode => Some(Self::WordsUnicode),
377 Self::WordsUnicode => Some(Self::GraphemesUnicode),
378 Self::GraphemesUnicode => None,
379 Self::None => unreachable!(),
380 }
381 }
382 fn trim_range<T: AsRef<str>>(&self, indices: &Range<usize>, text: T) -> Option<Range<usize>> {
383 let (start, end) = match self {
384 Self::TwoPlusEoL
385 | Self::SingleEol
386 | Self::SentencesRuleBased
387 | Self::SentencesUnicode => {
388 let start = text.as_ref()[indices.start..indices.end]
389 .char_indices()
390 .find(|(_, c)| !c.is_whitespace())
391 .map(|(i, _)| indices.start + i)
392 .unwrap_or(indices.end);
393 let end = if indices.end == text.as_ref().len() {
394 text.as_ref().len()
395 } else {
396 text.as_ref()[indices.start..indices.end]
397 .char_indices()
398 .rev()
399 .find(|(_, c)| !c.is_whitespace())
400 .map(|(i, c)| indices.start + i + c.len_utf8())
401 .unwrap_or(start)
402 };
403 (start, end)
404 }
405 Self::WordsUnicode => {
406 let start = text.as_ref()[..indices.start]
407 .char_indices()
408 .rev()
409 .find(|(_, c)| c.is_whitespace())
410 .map(|(i, c)| i + c.len_utf8())
411 .unwrap_or(indices.start);
412 let end = if indices.end == text.as_ref().len() {
413 text.as_ref().len()
414 } else {
415 text.as_ref()[indices.start..indices.end]
416 .char_indices()
417 .find(|(_, c)| c.is_whitespace())
418 .map(|(i, _)| indices.start + i)
419 .unwrap_or(start)
420 };
421 (start, end)
422 }
423 Self::GraphemesUnicode => (indices.start, indices.end),
424 Self::None => unreachable!(),
425 };
426
427 if start >= end {
428 None
429 } else {
430 Some(Range { start, end })
431 }
432 }
433}
434
435pub static TWO_PLUS_NEWLINE_REGEX: LazyLock<Regex> =
436 LazyLock::new(|| Regex::new(r"\n{2,}").unwrap());
437pub static SINGLE_NEWLINE_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(r"\n").unwrap());
438
439#[inline]
440pub fn split_text(text: &str) -> Vec<String> {
441 match TextSplitter::new().split_text(text) {
442 Some(splits) => splits
443 .iter()
444 .map(|split| split.text().to_string())
445 .collect(),
446 None => vec![],
447 }
448}