use std::ops::Range;
use itertools::Itertools;
use memchr::memchr2_iter;
use crate::{
splitter::{SemanticLevel, Splitter},
ChunkConfig, ChunkSizer,
};
use super::{fallback::GRAPHEME_SEGMENTER, ChunkCharIndex};
/// Splits text into chunks no larger than a configured maximum size.
///
/// Break points are chosen at semantic boundaries (graphemes, words,
/// sentences, runs of line breaks — see the tests in this file), falling
/// back to finer levels when a coarser unit does not fit.
#[derive(Debug)]
pub struct TextSplitter<Sizer>
where
Sizer: ChunkSizer,
{
/// Chunking configuration: target chunk size, the sizer used to measure
/// chunks, and trimming behavior. See [`ChunkConfig`].
chunk_config: ChunkConfig<Sizer>,
}
impl<Sizer> TextSplitter<Sizer>
where
Sizer: ChunkSizer,
{
/// Creates a new [`TextSplitter`] from anything convertible into a
/// [`ChunkConfig`] (e.g. a bare maximum chunk size, as the tests do).
#[must_use]
pub fn new(chunk_config: impl Into<ChunkConfig<Sizer>>) -> Self {
Self {
chunk_config: chunk_config.into(),
}
}
/// Returns an iterator over chunks of `text`, each within the configured
/// maximum chunk size. Delegates to the shared [`Splitter`] machinery.
pub fn chunks<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = &'text str> + 'splitter {
Splitter::<_>::chunks(self, text)
}
/// Like [`Self::chunks`], but each item also carries the chunk's byte
/// offset within the original `text`.
pub fn chunk_indices<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = (usize, &'text str)> + 'splitter {
Splitter::<_>::chunk_indices(self, text)
}
/// Like [`Self::chunk_indices`], but yields [`ChunkCharIndex`] values
/// carrying both the byte offset and the character offset of each chunk
/// (these differ when earlier text contains multi-byte characters).
pub fn chunk_char_indices<'splitter, 'text: 'splitter>(
&'splitter self,
text: &'text str,
) -> impl Iterator<Item = ChunkCharIndex<'text>> + 'splitter {
Splitter::<_>::chunk_char_indices(self, text)
}
}
impl<Sizer> Splitter<Sizer> for TextSplitter<Sizer>
where
    Sizer: ChunkSizer,
{
    type Level = LineBreaks;

    fn chunk_config(&self) -> &ChunkConfig<Sizer> {
        &self.chunk_config
    }

    /// Finds every run of newline characters in `text` and reports it as a
    /// semantic break, leveled by how many line-break graphemes the run
    /// contains (so `"\n\n"` outranks `"\n"`).
    fn parse(&self, text: &str) -> Vec<(Self::Level, Range<usize>)> {
        memchr2_iter(b'\n', b'\r', text.as_bytes())
            // Each match is the byte position of a single '\n' or '\r'.
            .map(|i| i..i + 1)
            // Merge adjacent newline bytes into one contiguous range so a
            // run like "\r\n\r\n" becomes a single break of higher level.
            .coalesce(|a, b| {
                if a.end == b.start {
                    Ok(a.start..b.end)
                } else {
                    Err((a, b))
                }
            })
            .map(|range| {
                // Count graphemes rather than bytes: "\r\n" is a single
                // grapheme cluster, i.e. one line break, not two.
                // `tuple_windows` pairs consecutive segment boundaries, so
                // `count()` yields the number of graphemes in the run.
                let level = GRAPHEME_SEGMENTER
                    .segment_str(&text[range.start..range.end])
                    .tuple_windows::<(usize, usize)>()
                    .count();
                (
                    match level {
                        // Every range produced above covers at least one
                        // newline byte found by memchr, so it must contain
                        // at least one grapheme.
                        0 => unreachable!("memchr should always match at least one newline"),
                        n => LineBreaks(n),
                    },
                    range,
                )
            })
            .collect()
    }
}
/// Semantic level for a run of newline characters: the number of line-break
/// graphemes in the run. Higher counts (e.g. blank lines) order above single
/// breaks via the derived `Ord`, so they are preferred as split points.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Ord, PartialOrd)]
pub struct LineBreaks(usize);
// Marker impl: lets the shared `Splitter` machinery treat `LineBreaks` as a
// semantic level it can rank and split on.
impl SemanticLevel for LineBreaks {}
#[cfg(test)]
mod tests {
use std::cmp::min;
use fake::{Fake, Faker};
use crate::{splitter::SemanticSplitRanges, ChunkCharIndex};
use super::*;
// Text that already fits within the max chunk size comes back as one
// unmodified chunk.
#[test]
fn returns_one_chunk_if_text_is_shorter_than_max_chunk_size() {
let text = Faker.fake::<String>();
let chunks = TextSplitter::new(ChunkConfig::new(text.chars().count()).with_trim(false))
.chunks(&text)
.collect::<Vec<_>>();
assert_eq!(vec![&text], chunks);
}
// With trimming off, splitting is lossless: chunks respect the size limit,
// preserve order, and concatenate back to the original text.
#[test]
fn returns_two_chunks_if_text_is_longer_than_max_chunk_size() {
let text1 = Faker.fake::<String>();
let text2 = Faker.fake::<String>();
let text = format!("{text1}{text2}");
let max_chunk_size = text.chars().count() / 2 + 1;
let chunks = TextSplitter::new(ChunkConfig::new(max_chunk_size).with_trim(false))
.chunks(&text)
.collect::<Vec<_>>();
assert!(chunks.iter().all(|c| c.chars().count() <= max_chunk_size));
// First chunk starts with (a prefix of) text1...
let len = min(text1.len(), chunks[0].len());
assert_eq!(text1[..len], chunks[0][..len]);
// ...and the second chunk ends with (a suffix of) text2.
let len = min(text2.len(), chunks[1].len());
assert_eq!(
text2[(text2.len() - len)..],
chunks[1][chunks[1].len() - len..]
);
assert_eq!(chunks.join(""), text);
}
// Empty input yields no chunks at all, not a single empty chunk.
#[test]
fn empty_string() {
let text = "";
let chunks = TextSplitter::new(ChunkConfig::new(100).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert!(chunks.is_empty());
}
// Chunk size is measured in characters by default, so multi-byte "é"
// still counts as one.
#[test]
fn can_handle_unicode_characters() {
let text = "éé"; let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(vec!["é", "é"], chunks);
}
// A custom sizer that measures chunks in bytes instead of characters.
struct Str;
impl ChunkSizer for Str {
fn size(&self, chunk: &str) -> usize {
chunk.len()
}
}
// With the byte-based sizer, size 2 fits exactly one 2-byte "é" per chunk.
#[test]
fn custom_len_function() {
let text = "éé"; let chunks = TextSplitter::new(ChunkConfig::new(2).with_sizer(Str).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(vec!["é", "é"], chunks);
}
// A single character can exceed the max size (2-byte "é" vs limit 1);
// it is still emitted as its own chunk rather than split or dropped.
#[test]
fn handles_char_bigger_than_len() {
let text = "éé"; let chunks = TextSplitter::new(ChunkConfig::new(1).with_sizer(Str).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(vec!["é", "é"], chunks);
}
// Grapheme clusters (base char + combining marks, and "\r\n") are kept
// intact when they fit within the chunk size.
#[test]
fn chunk_by_graphemes() {
let text = "a̐éö̲\r\n";
let chunks = TextSplitter::new(ChunkConfig::new(3).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(vec!["a̐é", "ö̲", "\r\n"], chunks);
}
// With trimming on (the default), surrounding whitespace is stripped and
// the reported byte offsets point at the trimmed chunks.
#[test]
fn trim_char_indices() {
let text = " a b ";
let chunks = TextSplitter::new(1).chunk_indices(text).collect::<Vec<_>>();
assert_eq!(vec![(1, "a"), (3, "b")], chunks);
}
// chunk_char_indices reports both byte and char offsets; for pure ASCII
// the two are equal.
#[test]
fn chunk_char_indices() {
let text = " a b ";
let chunks = TextSplitter::new(1)
.chunk_char_indices(text)
.collect::<Vec<_>>();
assert_eq!(
vec![
ChunkCharIndex {
chunk: "a",
byte_offset: 1,
char_offset: 1
},
ChunkCharIndex {
chunk: "b",
byte_offset: 3,
char_offset: 3,
},
],
chunks
);
}
// When even one grapheme exceeds the size limit, splitting falls back to
// individual chars — combining marks become separate chunks.
#[test]
fn graphemes_fallback_to_chars() {
let text = "a̐éö̲\r\n";
let chunks = TextSplitter::new(ChunkConfig::new(1).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(
vec!["a", "\u{310}", "é", "ö", "\u{332}", "\r", "\n"],
chunks
);
}
// Trimming drops the leading/trailing newline runs; offsets are bytes, so
// the combining marks push the second chunk to byte 7.
#[test]
fn trim_grapheme_indices() {
let text = "\r\na̐éö̲\r\n";
let chunks = TextSplitter::new(3).chunk_indices(text).collect::<Vec<_>>();
assert_eq!(vec![(2, "a̐é"), (7, "ö̲")], chunks);
}
// Byte and char offsets diverge here: "ö̲" starts at byte 7 but is only
// the 6th character (char_offset 5) because of multi-byte graphemes.
#[test]
fn grapheme_char_indices() {
let text = "\r\na̐éö̲\r\n";
let chunks = TextSplitter::new(3)
.chunk_char_indices(text)
.collect::<Vec<_>>();
assert_eq!(
vec![
ChunkCharIndex {
chunk: "a̐é",
byte_offset: 2,
char_offset: 2
},
ChunkCharIndex {
chunk: "ö̲",
byte_offset: 7,
char_offset: 5
}
],
chunks
);
}
// Word boundaries are respected: contractions ("can't"), punctuation, and
// numbers stay intact within chunks.
#[test]
fn chunk_by_words() {
let text = "The quick (\"brown\") fox can't jump 32.3 feet, right?";
let chunks = TextSplitter::new(ChunkConfig::new(10).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(
vec![
"The quick ",
"(\"brown\") ",
"fox can't ",
"jump 32.3 ",
"feet, ",
"right?"
],
chunks
);
}
// A word longer than the limit is split at grapheme boundaries instead.
#[test]
fn words_fallback_to_graphemes() {
let text = "Thé quick\r\n";
let chunks = TextSplitter::new(ChunkConfig::new(2).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(vec!["Th", "é ", "qu", "ic", "k", "\r\n"], chunks);
}
// With trimming, trailing spaces are excluded from chunks and offsets
// point at the first non-space byte of each chunk.
#[test]
fn trim_word_indices() {
let text = "Some text from a document";
let chunks = TextSplitter::new(10)
.chunk_indices(text)
.collect::<Vec<_>>();
assert_eq!(
vec![(0, "Some text"), (10, "from a"), (17, "document")],
chunks
);
}
// Sentence boundaries are preferred break points; "Mr." does not end a
// sentence here.
#[test]
fn chunk_by_sentences() {
let text = "Mr. Fox jumped. [...] The dog was too lazy.";
let chunks = TextSplitter::new(ChunkConfig::new(21).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(
vec!["Mr. Fox jumped. ", "[...] ", "The dog was too lazy."],
chunks
);
}
// A sentence longer than the limit is split at word boundaries instead.
#[test]
fn sentences_falls_back_to_words() {
let text = "Mr. Fox jumped. [...] The dog was too lazy.";
let chunks = TextSplitter::new(ChunkConfig::new(16).with_trim(false))
.chunks(text)
.collect::<Vec<_>>();
assert_eq!(
vec!["Mr. Fox jumped. ", "[...] ", "The dog was too ", "lazy."],
chunks
);
}
// Trimmed sentence chunks keep their punctuation; offsets skip the
// inter-sentence space.
#[test]
fn trim_sentence_indices() {
let text = "Some text. From a document.";
let chunks = TextSplitter::new(10)
.chunk_indices(text)
.collect::<Vec<_>>();
assert_eq!(
vec![(0, "Some text."), (11, "From a"), (18, "document.")],
chunks
);
}
// Newline runs (paragraph breaks and single line breaks) are trimmed out
// of the resulting chunks.
#[test]
fn trim_paragraph_indices() {
let text = "Some text\n\nfrom a\ndocument";
let chunks = TextSplitter::new(10)
.chunk_indices(text)
.collect::<Vec<_>>();
assert_eq!(
vec![(0, "Some text"), (11, "from a"), (18, "document")],
chunks
);
}
// parse() levels newline runs by grapheme count: "\r\n\r\n" is two line
// breaks (bytes 0..4), "\n\n\n" is three (bytes 8..11).
#[test]
fn correctly_determines_newlines() {
let text = "\r\n\r\ntext\n\n\ntext2";
let splitter = TextSplitter::new(10);
let linebreaks = SemanticSplitRanges::new(splitter.parse(text));
assert_eq!(
vec![(LineBreaks(2), 0..4), (LineBreaks(3), 8..11)],
linebreaks.ranges
);
}
}