1use std::ops::Range;
7
8use itertools::Itertools;
9use memchr::memchr2_iter;
10
11use crate::{
12 splitter::{SemanticLevel, Splitter},
13 ChunkConfig, ChunkSizer,
14};
15
16use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
17
18#[derive(Debug)]
22pub struct TextSplitter<Sizer>
23where
24 Sizer: ChunkSizer,
25{
26 chunk_config: ChunkConfig<Sizer>,
28}
29
30impl<Sizer> TextSplitter<Sizer>
31where
32 Sizer: ChunkSizer,
33{
34 #[must_use]
43 pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
44 Self {
45 chunk_config: chunk_config.into(),
46 }
47 }
48
49 pub fn chunks<'splitter, 'text: 'splitter>(
81 &'splitter self,
82 text: &'text str,
83 ) -> impl Iterator<Item = &'text str> + 'splitter {
84 Splitter::<_>::chunks(self, text)
85 }
86
87 pub fn chunk_indices<'splitter, 'text: 'splitter>(
102 &'splitter self,
103 text: &'text str,
104 ) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
105 Splitter::<_>::chunk_indices(self, text)
106 }
107
108 pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
127 &'splitter self,
128 text: &'text str,
129 ) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
130 Splitter::<_>::chunk_char_indices(self, text)
131 }
132}
133
134impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
135where
136 Sizer: ChunkSizer,
137{
138 type Level = LineBreaks;
139
140 fn chunk_config(&self) -> &ChunkConfig<Sizer> {
141 &self.chunk_config
142 }
143
144 fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
145 memchr2_iter(b'\n', b'\r', text.as_bytes())
146 .map(|i| i..i + 1)
147 .coalesce(|a, b| {
148 if a.end == b.start {
149 Ok(a.start..b.end)
150 } else {
151 Err((a, b))
152 }
153 })
154 .map(|range| {
155 let level = GRAPHEME_SEGMENTER
156 .segment_str(text.get(range.start..range.end).unwrap())
157 .tuple_windows::<(usize, usize)>()
158 .count();
159 (
160 match level {
161 0 => unreachable!("regex should always match at least one newline"),
162 n => LineBreaks(n),
163 },
164 range,
165 )
166 })
167 .collect()
168 }
169}
170
171#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
177pub struct LineBreaks(usize);
178
179impl SemanticLevel for LineBreaks {}
180
#[cfg(test)]
mod tests {
    use std::cmp::min;

    use fake::{Fake, Faker};

    use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};

    use super::*;

    #[test]
    fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
        let text = Faker.fake::<String>();
        // Capacity equals the text's character count, so the text fits in one chunk.
        let capacity = text.chars().count();
        let splitter = TextSplitter::new(ChunkConfig::new(capacity).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(&text).collect();

        assert_eq!(vec![&text], chunks);
    }

    #[test]
    fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
        let first = Faker.fake::<String>();
        let second = Faker.fake::<String>();
        let text = format!("{first}{second}");
        // One more than half the character count: two chunks can hold the whole text.
        let max_chunk_size = text.chars().count() / 2 + 1;
        let splitter = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(&text).collect();

        assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));

        // The first chunk must start the same way the first input does.
        let len = min(first.len(), chunks[0].len());
        assert_eq!(first[..len], chunks[0][..len]);
        // The second chunk must end the same way the second input does.
        let len = min(second.len(), chunks[1].len());
        assert_eq!(
            second[(second.len() - len)..],
            chunks[1][chunks[1].len() - len..]
        );

        assert_eq!(chunks.join(""), text);
    }

    #[test]
    fn empty_string() {
        let text = "";
        let splitter = TextSplitter::new(ChunkConfig::new(100).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        assert!(chunks.is_empty());
    }

    #[test]
    fn can_handle_unicode_characters() {
        // Each character here takes more than one byte.
        let text = "éé";
        let splitter = TextSplitter::new(ChunkConfig::new(1).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    /// Sizer that measures a chunk by its length in bytes (`str::len`).
    struct Str;

    impl ChunkSizer for Str {
        fn size(&self, chunk: &str) -> usize {
            chunk.len()
        }
    }

    #[test]
    fn custom_len_function() {
        // Each character here takes more than one byte.
        let text = "éé";
        let splitter = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn handles_char_bigger_than_len() {
        // A single character is already larger than the byte capacity of 1.
        let text = "éé";
        let splitter = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        // Each character still comes back as its own chunk.
        assert_eq!(vec!["é", "é"], chunks);
    }

    #[test]
    fn chunk_by_graphemes() {
        let text = "a̐éö̲\r\n";
        let splitter = TextSplitter::new(ChunkConfig::new(3).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        // \r\n is kept together as one chunk.
        assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
    }

    #[test]
    fn trim_char_indices() {
        let text = " a b ";
        let splitter = TextSplitter::new(1);
        let chunks: Vec<_> = splitter.chunk_indices(text).collect();

        assert_eq!(vec![(1, "a"), (3, "b")], chunks);
    }

    #[test]
    fn chunk_char_indices() {
        let text = " a b ";
        let splitter = TextSplitter::new(1);
        let chunks: Vec<_> = splitter.chunk_char_indices(text).collect();

        let expected = vec![
            ChunkCharIndex {
                chunk: "a",
                byte_offset: 1,
                char_offset: 1,
            },
            ChunkCharIndex {
                chunk: "b",
                byte_offset: 3,
                char_offset: 3,
            },
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn graphemes_fallback_to_chars() {
        let text = "a̐éö̲\r\n";
        let splitter = TextSplitter::new(ChunkConfig::new(1).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        let expected = vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn trim_grapheme_indices() {
        let text = "\r\na̐éö̲\r\n";
        let splitter = TextSplitter::new(3);
        let chunks: Vec<_> = splitter.chunk_indices(text).collect();

        assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
    }

    #[test]
    fn grapheme_char_indices() {
        let text = "\r\na̐éö̲\r\n";
        let splitter = TextSplitter::new(3);
        let chunks: Vec<_> = splitter.chunk_char_indices(text).collect();

        let expected = vec![
            ChunkCharIndex {
                chunk: "a̐é",
                byte_offset: 2,
                char_offset: 2,
            },
            ChunkCharIndex {
                chunk: "ö̲",
                byte_offset: 7,
                char_offset: 5,
            },
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn chunk_by_words() {
        let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
        let splitter = TextSplitter::new(ChunkConfig::new(10).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        let expected = vec![
            "The quick ",
            "(\"brown\") ",
            "fox can't ",
            "jump 32.3 ",
            "feet, ",
            "right?",
        ];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn words_fallback_to_graphemes() {
        let text = "Thé quick\r\n";
        let splitter = TextSplitter::new(ChunkConfig::new(2).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
    }

    #[test]
    fn trim_word_indices() {
        let text = "Some text from a document";
        let splitter = TextSplitter::new(10);
        let chunks: Vec<_> = splitter.chunk_indices(text).collect();

        let expected = vec![(0, "Some text"), (10, "from a"), (17, "document")];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn chunk_by_sentences() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let splitter = TextSplitter::new(ChunkConfig::new(21).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        let expected = vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn sentences_falls_back_to_words() {
        let text = "Mr. Fox jumped. [...] The dog was too lazy.";
        let splitter = TextSplitter::new(ChunkConfig::new(16).with_trim(false));
        let chunks: Vec<_> = splitter.chunks(text).collect();

        let expected = vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn trim_sentence_indices() {
        let text = "Some text. From a document.";
        let splitter = TextSplitter::new(10);
        let chunks: Vec<_> = splitter.chunk_indices(text).collect();

        let expected = vec![(0, "Some text."), (11, "From a"), (18, "document.")];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn trim_paragraph_indices() {
        let text = "Some text\n\nfrom a\ndocument";
        let splitter = TextSplitter::new(10);
        let chunks: Vec<_> = splitter.chunk_indices(text).collect();

        let expected = vec![(0, "Some text"), (11, "from a"), (18, "document")];
        assert_eq!(expected, chunks);
    }

    #[test]
    fn correctly_determines_newlines() {
        let text = "\r\n\r\ntext\n\n\ntext2";
        let splitter = TextSplitter::new(10);
        let linebreaks = SemanticSplitRanges::new(splitter.parse(text));

        // "\r\n\r\n" counts as two line breaks, "\n\n\n" as three.
        let expected = vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)];
        assert_eq!(expected, linebreaks.ranges);
    }
}