use std::collections::HashMap;
use cognis_core::documents::Document;
use regex::Regex;
use serde_json::Value;
use super::TextSplitter;
/// Strategy used to locate sentence boundaries.
///
/// `PartialEq`/`Eq` are derived so that a configured pattern can be compared directly
/// (for example in tests) instead of going through `matches!`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SentencePattern {
    /// Heuristic splitter that skips common abbreviations, initials, decimals, URLs and
    /// e-mail-like words. This is the default variant.
    #[default]
    Default,
    /// Splits after `.`, `!` or `?` when the terminator is followed by whitespace.
    Simple,
    /// Splits after ASCII terminators as well as the ideographic full stop, the fullwidth
    /// `!`/`?` and the horizontal ellipsis; following whitespace is optional.
    Unicode,
    /// Splits at the end of every match of the given regular expression.
    Custom(String),
}
/// Splits text into sentences and packs those sentences into size-bounded chunks.
///
/// All sizes are compared against `String::len()`, i.e. they are measured in bytes.
#[derive(Debug, Clone)]
pub struct SentenceTextSplitter {
    /// Target upper bound for the length of a chunk. A single sentence longer than this
    /// is still emitted as one chunk.
    pub chunk_size: usize,
    /// Number of trailing sentences of a finished chunk that are repeated at the start
    /// of the next chunk.
    pub chunk_overlap: usize,
    /// When set, finished chunks go through an extra pass that tries to merge
    /// undersized chunks into their neighbours.
    pub min_chunk_size: Option<usize>,
    /// Strategy used to find sentence boundaries.
    pub separator_pattern: SentencePattern,
    /// Whether sentences and chunks are trimmed before being emitted.
    pub strip_whitespace: bool,
    /// Whether the text is first cut at blank lines (`"\n\n"`) so that paragraphs that
    /// fit into `chunk_size` are kept as indivisible units.
    pub preserve_paragraphs: bool,
}
impl Default for SentenceTextSplitter {
fn default() -> Self {
Self {
chunk_size: 1000,
chunk_overlap: 0,
min_chunk_size: None,
separator_pattern: SentencePattern::Default,
strip_whitespace: true,
preserve_paragraphs: false,
}
}
}
impl SentenceTextSplitter {
pub fn new() -> Self {
Self::default()
}
pub fn builder() -> SentenceTextSplitterBuilder {
SentenceTextSplitterBuilder::default()
}
/// Splits `text` into sentences using the configured [`SentencePattern`].
///
/// Returns an empty vector for empty input; otherwise the work is delegated to the
/// splitter that matches `self.separator_pattern`.
pub fn split_into_sentences(&self, text: &str) -> Vec<String> {
    match &self.separator_pattern {
        // Empty input short-circuits before any pattern-specific work.
        _ if text.is_empty() => Vec::new(),
        SentencePattern::Custom(regex) => self.split_sentences_custom(text, regex),
        SentencePattern::Unicode => self.split_sentences_unicode(text),
        SentencePattern::Simple => self.split_sentences_simple(text),
        SentencePattern::Default => self.split_sentences_default(text),
    }
}
/// Splits `text` into sentences with the built-in heuristic detector.
///
/// A `.`, `!` or `?` ends a sentence only when the very next character is whitespace.
/// A terminator that is the last character of the input therefore stays in the
/// trailing sentence. A period is additionally ignored when:
///
/// * the word it closes is a known abbreviation such as `Dr.` or `etc.`
///   (compared ASCII case-insensitively),
/// * it follows an uppercase ASCII letter that sits at the very start of the text or
///   directly after another period (as in `U.S.`),
/// * it follows an ASCII digit and the next non-whitespace character is a digit too,
/// * the word it closes contains `://`, starts with `www.` or contains `@`.
///
/// Whitespace between two sentences is always dropped. When `strip_whitespace` is set,
/// every sentence is trimmed as well.
fn split_sentences_default(&self, text: &str) -> Vec<String> {
    // Words whose trailing period must not be treated as the end of a sentence.
    const ABBREVIATIONS: &[&str] = &[
        "Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "Sr.", "Jr.", "St.", "Gen.", "Gov.", "Sgt.",
        "Cpl.", "Pvt.", "Lt.", "Col.", "Capt.", "Maj.", "Rev.", "Hon.", "Pres.", "Inc.",
        "Corp.", "Ltd.", "Co.", "vs.", "etc.", "approx.", "dept.", "est.", "vol.", "fig.",
        "no.",
    ];

    /// Stores `span` as a sentence, trimming it first when `strip` is set and skipping
    /// it when nothing is left.
    fn push_sentence(out: &mut Vec<String>, span: &[char], strip: bool) {
        let sentence: String = span.iter().collect();
        let sentence = if strip {
            sentence.trim().to_string()
        } else {
            sentence
        };
        if !sentence.is_empty() {
            out.push(sentence);
        }
    }

    // Work on chars so that indices never land inside a multi-byte code point.
    let chars: Vec<char> = text.chars().collect();
    let len = chars.len();
    let mut sentences: Vec<String> = Vec::new();
    let mut start = 0;
    let mut i = 0;

    while i < len {
        let ch = chars[i];
        let is_candidate =
            matches!(ch, '.' | '!' | '?') && i + 1 < len && chars[i + 1].is_whitespace();
        if !is_candidate {
            i += 1;
            continue;
        }

        if ch == '.' {
            // The word closed by this period: everything after the last whitespace
            // inside the current sentence. Only this word is materialised, not the
            // whole sentence.
            let word_start = chars[start..i]
                .iter()
                .rposition(|c| c.is_whitespace())
                .map_or(start, |pos| start + pos + 1);
            let word: String = chars[word_start..=i].iter().collect();
            let previous = i.checked_sub(1).map(|j| chars[j]);

            let is_abbrev = ABBREVIATIONS.iter().any(|a| word.eq_ignore_ascii_case(a));
            let is_initial = matches!(previous, Some(c) if c.is_ascii_uppercase())
                && (i < 2 || chars[i - 2] == '.');
            // Scan forward lazily instead of copying the rest of the text for every
            // candidate period.
            let is_decimal = matches!(previous, Some(c) if c.is_ascii_digit())
                && matches!(
                    chars[i + 1..].iter().find(|c| !c.is_whitespace()),
                    Some(c) if c.is_ascii_digit()
                );
            let is_url = word.contains("://") || word.starts_with("www.");
            // `word` always ends with the period under test, so `@` alone decides.
            let is_email = word.contains('@');

            if is_abbrev || is_initial || is_decimal || is_url || is_email {
                i += 1;
                continue;
            }
        }

        push_sentence(&mut sentences, &chars[start..=i], self.strip_whitespace);

        // The next sentence begins at the first non-whitespace character.
        start = i + 1;
        while start < len && chars[start].is_whitespace() {
            start += 1;
        }
        i = start;
    }

    // Whatever follows the last boundary forms the final sentence.
    if start < len {
        push_sentence(&mut sentences, &chars[start..], self.strip_whitespace);
    }
    sentences
}
/// Splits `text` after every `.`, `!` or `?` that is followed by whitespace.
///
/// The terminator stays with its sentence and the whitespace run after it is dropped.
/// Text after the last boundary becomes the final sentence. When `strip_whitespace`
/// is set each piece is trimmed, and empty pieces are never emitted.
fn split_sentences_simple(&self, text: &str) -> Vec<String> {
    let boundary = Regex::new(r"([.!?])\s+").unwrap();
    let strip = self.strip_whitespace;
    // Applies the whitespace policy and filters out pieces that end up empty.
    let tidy = |piece: &str| -> Option<String> {
        let piece = if strip { piece.trim() } else { piece };
        (!piece.is_empty()).then(|| piece.to_string())
    };

    let mut pieces = Vec::new();
    let mut cursor = 0;
    for hit in boundary.find_iter(text) {
        // The match begins at the one-byte ASCII terminator, which belongs to the
        // sentence; hence the inclusive upper bound.
        pieces.extend(tidy(&text[cursor..=hit.start()]));
        cursor = hit.end();
    }
    if cursor < text.len() {
        pieces.extend(tidy(&text[cursor..]));
    }
    pieces
}
/// Splits `text` at ASCII and CJK sentence terminators.
///
/// Recognised terminators are `.`, `!`, `?`, U+3002 (ideographic full stop), U+FF01
/// (fullwidth exclamation mark), U+FF1F (fullwidth question mark) and U+2026
/// (horizontal ellipsis). Whitespace after a terminator is optional, so scripts that
/// do not separate sentences with spaces are handled. The flip side is that every
/// terminator splits, including the period inside a decimal number.
///
/// A run of consecutive terminators (`...`, `?!`) is treated as a single boundary and
/// stays attached to the sentence it closes. Matching each terminator on its own would
/// emit every extra terminator as a separate one-character "sentence".
fn split_sentences_unicode(&self, text: &str) -> Vec<String> {
    let re = Regex::new(r"[.!?\u{3002}\u{FF01}\u{FF1F}\u{2026}]+\s*")
        .expect("the terminator pattern is a valid regular expression");
    let mut sentences = Vec::new();
    let mut last = 0;
    for mat in re.find_iter(text) {
        // Keep the terminator run with the sentence but not the whitespace after it.
        // Lengths are in bytes because terminators may be multi-byte.
        let end = mat.start() + mat.as_str().trim_end().len();
        let sentence = &text[last..end];
        let sentence = if self.strip_whitespace {
            sentence.trim()
        } else {
            sentence
        };
        if !sentence.is_empty() {
            sentences.push(sentence.to_string());
        }
        last = mat.end();
    }
    // Text after the final terminator forms the last sentence.
    if last < text.len() {
        let remaining = if self.strip_whitespace {
            text[last..].trim()
        } else {
            &text[last..]
        };
        if !remaining.is_empty() {
            sentences.push(remaining.to_string());
        }
    }
    sentences
}
/// Splits `text` at the end of every match of the caller-supplied regex `pattern`.
///
/// Each piece runs from the end of the previous match to the end of the current one,
/// so the matched separator stays inside the piece unless trimming removes it. Text
/// after the last match becomes the final piece. If `pattern` does not compile, the
/// whole input is returned as a single, unmodified element.
fn split_sentences_custom(&self, text: &str, pattern: &str) -> Vec<String> {
    let separator = match Regex::new(pattern) {
        Ok(compiled) => compiled,
        // An invalid pattern disables splitting rather than failing.
        Err(_) => return vec![text.to_string()],
    };
    let strip = self.strip_whitespace;
    // Applies the whitespace policy and filters out pieces that end up empty.
    let tidy = |piece: &str| -> Option<String> {
        let piece = if strip { piece.trim() } else { piece };
        (!piece.is_empty()).then(|| piece.to_string())
    };

    let mut pieces = Vec::new();
    let mut cursor = 0;
    for hit in separator.find_iter(text) {
        let stop = hit.end();
        pieces.extend(tidy(&text[cursor..stop]));
        cursor = stop;
    }
    if cursor < text.len() {
        pieces.extend(tidy(&text[cursor..]));
    }
    pieces
}
/// Splits `text` into chunks of whole sentences.
///
/// Empty input yields no chunks. With `preserve_paragraphs` enabled the
/// paragraph-aware path is used; otherwise the text is split into sentences which are
/// then merged into chunks.
pub fn split_text(&self, text: &str) -> Vec<String> {
    match (text.is_empty(), self.preserve_paragraphs) {
        (true, _) => Vec::new(),
        (false, true) => self.split_preserving_paragraphs(text),
        (false, false) => {
            let sentences = self.split_into_sentences(text);
            self.merge_sentences_into_chunks(&sentences)
        }
    }
}
/// Splits every document's `page_content` into chunk documents.
///
/// Contents and metadata are copied out of `documents` and handed to
/// `create_documents`, so each resulting chunk carries a clone of the metadata of the
/// document it came from.
pub fn split_documents(&self, documents: &[Document]) -> Vec<Document> {
    let (contents, metadatas): (Vec<String>, Vec<HashMap<String, Value>>) = documents
        .iter()
        .map(|doc| (doc.page_content.as_str().to_string(), doc.metadata.clone()))
        .unzip();
    self.create_documents(&contents, Some(&metadatas))
}
/// Splits each text and wraps every resulting chunk in a [`Document`].
///
/// `metadatas[i]`, when present, is cloned onto every chunk produced from `texts[i]`.
/// Texts without a matching entry (or all texts when `metadatas` is `None`) get empty
/// metadata.
pub fn create_documents(
    &self,
    texts: &[String],
    metadatas: Option<&[HashMap<String, Value>]>,
) -> Vec<Document> {
    let mut documents = Vec::new();
    for (index, text) in texts.iter().enumerate() {
        let metadata = match metadatas.and_then(|all| all.get(index)) {
            Some(found) => found.clone(),
            None => HashMap::new(),
        };
        documents.extend(
            self.split_text(text)
                .into_iter()
                .map(|chunk| Document::new(chunk).with_metadata(metadata.clone())),
        );
    }
    documents
}
/// Greedily packs `sentences` into chunks of at most `chunk_size` bytes.
///
/// Sentences are joined with a single space. A chunk is closed as soon as adding the
/// next sentence (plus its separator) would exceed `chunk_size`. A sentence that is
/// longer than `chunk_size` on its own is still emitted as one chunk.
///
/// After a chunk is closed, up to `chunk_overlap` of its trailing sentences are carried
/// into the next chunk, but never the whole chunk: at least the oldest sentence is
/// always dropped. Without that guarantee a buffer holding no more than
/// `chunk_overlap` sentences would never shrink, and each following chunk would repeat
/// the entire previous one and grow without bound. Because the carried sentences are
/// not re-checked against `chunk_size`, a chunk that starts with overlap can exceed
/// the limit.
///
/// When `min_chunk_size` is set, the finished chunks are passed through
/// `merge_small_chunks`.
fn merge_sentences_into_chunks(&self, sentences: &[String]) -> Vec<String> {
    /// Joins the buffered sentences and stores the result unless it is empty.
    fn flush(buffer: &[&str], strip: bool, out: &mut Vec<String>) {
        let joined = buffer.join(" ");
        let chunk = if strip {
            joined.trim().to_string()
        } else {
            joined
        };
        if !chunk.is_empty() {
            out.push(chunk);
        }
    }

    if sentences.is_empty() {
        return Vec::new();
    }

    let mut chunks: Vec<String> = Vec::new();
    let mut current: Vec<&str> = Vec::new();
    // Byte length `current` would have once joined with single spaces.
    let mut current_len: usize = 0;

    for sentence in sentences {
        let s_len = sentence.len();
        let added = if current.is_empty() { s_len } else { s_len + 1 };

        if !current.is_empty() && current_len + added > self.chunk_size {
            flush(&current, self.strip_whitespace, &mut chunks);

            // Keep the overlap tail, but always let go of at least one sentence so the
            // buffer makes progress.
            let keep = self.chunk_overlap.min(current.len() - 1);
            let tail = current.split_off(current.len() - keep);
            current = tail;
            current_len =
                current.iter().map(|s| s.len()).sum::<usize>() + keep.saturating_sub(1);
        }

        current_len = if current.is_empty() {
            s_len
        } else {
            current_len + 1 + s_len
        };
        current.push(sentence);
    }

    if !current.is_empty() {
        flush(&current, self.strip_whitespace, &mut chunks);
    }

    match self.min_chunk_size {
        Some(min_size) => self.merge_small_chunks(chunks, min_size),
        None => chunks,
    }
}
/// Post-processes `chunks` by joining neighbours that fit into `chunk_size` together.
///
/// Chunks are accumulated left to right: a chunk is appended (separated by one space)
/// to the running accumulator whenever the result stays within `chunk_size`;
/// otherwise the accumulator is emitted and a new one is started. Inside the loop
/// `min_size` plays no role, because every case in which the next chunk does not fit
/// ends the same way. `min_size` only matters for the final accumulator: if it is
/// shorter than `min_size`, it is folded into the previously emitted chunk when that
/// still fits into `chunk_size`.
fn merge_small_chunks(&self, chunks: Vec<String>, min_size: usize) -> Vec<String> {
    let mut merged: Vec<String> = Vec::with_capacity(chunks.len());
    let mut accumulator = String::new();

    for chunk in chunks {
        if accumulator.is_empty() {
            accumulator = chunk;
        } else if accumulator.len() + 1 + chunk.len() <= self.chunk_size {
            accumulator.push(' ');
            accumulator.push_str(&chunk);
        } else {
            // No room left: emit what was gathered and start over with `chunk`.
            merged.push(std::mem::replace(&mut accumulator, chunk));
        }
    }

    if !accumulator.is_empty() {
        // An undersized tail is absorbed by the previous chunk when it fits.
        let absorb = accumulator.len() < min_size
            && merged
                .last()
                .map_or(false, |last| last.len() + 1 + accumulator.len() <= self.chunk_size);
        if absorb {
            if let Some(last) = merged.last_mut() {
                last.push(' ');
                last.push_str(&accumulator);
            }
        } else {
            merged.push(accumulator);
        }
    }
    merged
}
fn split_preserving_paragraphs(&self, text: &str) -> Vec<String> {
let paragraphs: Vec<&str> = text.split("\n\n").collect();
let mut all_sentences: Vec<String> = Vec::new();
for para in ¶graphs {
let trimmed = if self.strip_whitespace {
para.trim()
} else {
para
};
if trimmed.is_empty() {
continue;
}
if trimmed.len() <= self.chunk_size && !all_sentences.is_empty() {
all_sentences.push(trimmed.to_string());
} else if trimmed.len() <= self.chunk_size {
all_sentences.push(trimmed.to_string());
} else {
let para_sentences = self.split_into_sentences(trimmed);
all_sentences.extend(para_sentences);
}
}
self.merge_sentences_into_chunks(&all_sentences)
}
}
// `TextSplitter` implementation: every method forwards to the inherent method or
// public field of the same name.
impl TextSplitter for SentenceTextSplitter {
// Delegates to the inherent `split_text`. The call is written with the fully
// qualified type path so it is explicit that the inherent method is meant.
fn split_text(&self, text: &str) -> Vec<String> {
SentenceTextSplitter::split_text(self, text)
}
// Returns the configured `chunk_size` field.
fn chunk_size(&self) -> usize {
self.chunk_size
}
// Returns the configured `chunk_overlap` field.
fn chunk_overlap(&self) -> usize {
self.chunk_overlap
}
}
/// Builder for [`SentenceTextSplitter`].
///
/// Each field mirrors the splitter field of the same name and is copied over
/// unchanged by `build`.
#[derive(Debug, Clone)]
pub struct SentenceTextSplitterBuilder {
    /// Value for `SentenceTextSplitter::chunk_size`.
    chunk_size: usize,
    /// Value for `SentenceTextSplitter::chunk_overlap`.
    chunk_overlap: usize,
    /// Value for `SentenceTextSplitter::min_chunk_size`.
    min_chunk_size: Option<usize>,
    /// Value for `SentenceTextSplitter::separator_pattern`.
    separator_pattern: SentencePattern,
    /// Value for `SentenceTextSplitter::strip_whitespace`.
    strip_whitespace: bool,
    /// Value for `SentenceTextSplitter::preserve_paragraphs`.
    preserve_paragraphs: bool,
}
impl Default for SentenceTextSplitterBuilder {
fn default() -> Self {
Self {
chunk_size: 1000,
chunk_overlap: 0,
min_chunk_size: None,
separator_pattern: SentencePattern::Default,
strip_whitespace: true,
preserve_paragraphs: false,
}
}
}
impl SentenceTextSplitterBuilder {
    /// Sets the chunk size limit.
    pub fn chunk_size(self, size: usize) -> Self {
        Self {
            chunk_size: size,
            ..self
        }
    }

    /// Sets the chunk overlap.
    pub fn chunk_overlap(self, overlap: usize) -> Self {
        Self {
            chunk_overlap: overlap,
            ..self
        }
    }

    /// Sets the minimum chunk size; the value is stored as `Some(size)`.
    pub fn min_chunk_size(self, size: usize) -> Self {
        Self {
            min_chunk_size: Some(size),
            ..self
        }
    }

    /// Selects the sentence boundary strategy.
    pub fn separator_pattern(self, pattern: SentencePattern) -> Self {
        Self {
            separator_pattern: pattern,
            ..self
        }
    }

    /// Enables or disables whitespace stripping.
    pub fn strip_whitespace(self, strip: bool) -> Self {
        Self {
            strip_whitespace: strip,
            ..self
        }
    }

    /// Enables or disables paragraph preservation.
    pub fn preserve_paragraphs(self, preserve: bool) -> Self {
        Self {
            preserve_paragraphs: preserve,
            ..self
        }
    }

    /// Consumes the builder and produces a [`SentenceTextSplitter`] with every field
    /// copied over unchanged.
    pub fn build(self) -> SentenceTextSplitter {
        let Self {
            chunk_size,
            chunk_overlap,
            min_chunk_size,
            separator_pattern,
            strip_whitespace,
            preserve_paragraphs,
        } = self;
        SentenceTextSplitter {
            chunk_size,
            chunk_overlap,
            min_chunk_size,
            separator_pattern,
            strip_whitespace,
            preserve_paragraphs,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain sentences are split at `.` and `?` followed by whitespace.
    #[test]
    fn test_basic_sentence_splitting() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let sentences = splitter.split_into_sentences("Hello world. This is a test. How are you?");
        assert_eq!(sentences.len(), 3, "Got {:?}", sentences);
        assert_eq!(sentences[0], "Hello world.");
        assert_eq!(sentences[1], "This is a test.");
        assert_eq!(sentences[2], "How are you?");
    }

    /// Without overlap, no chunk may exceed the configured `chunk_size` as long as every
    /// single sentence fits into it.
    #[test]
    fn test_chunk_size_enforcement() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(40)
            .chunk_overlap(0)
            .build();
        let text =
            "First sentence here. Second sentence here. Third sentence here. Fourth sentence.";
        let chunks = splitter.split_text(text);
        assert!(
            chunks.len() > 1,
            "Expected multiple chunks, got {:?}",
            chunks
        );
        for chunk in &chunks {
            // Check against the configured limit itself rather than a looser bound.
            assert!(
                chunk.len() <= splitter.chunk_size,
                "Chunk exceeds size: {:?} (len {})",
                chunk,
                chunk.len()
            );
        }
    }

    /// With an overlap of one sentence, the last sentence of a chunk opens the next one.
    #[test]
    fn test_sentence_overlap() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(30)
            .chunk_overlap(1)
            .build();
        let text = "Sentence one. Sentence two. Sentence three. Sentence four.";
        let chunks = splitter.split_text(text);
        assert!(
            chunks.len() >= 2,
            "Expected multiple chunks, got {:?}",
            chunks
        );
        let first_sentences = splitter.split_into_sentences(&chunks[0]);
        let second_sentences = splitter.split_into_sentences(&chunks[1]);
        let last_of_first = first_sentences.last().unwrap();
        let first_of_second = second_sentences.first().unwrap();
        assert_eq!(
            last_of_first, first_of_second,
            "Expected 1-sentence overlap between chunks"
        );
    }

    /// Titles such as `Mr.` and `Dr.` do not end a sentence.
    #[test]
    fn test_abbreviation_mr_dr() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let text = "Mr. Smith went to Washington. Dr. Jones stayed home.";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(
            sentences.len(),
            2,
            "Should not split on Mr. or Dr., got {:?}",
            sentences
        );
        assert_eq!(sentences[0], "Mr. Smith went to Washington.");
        assert_eq!(sentences[1], "Dr. Jones stayed home.");
    }

    /// The period inside a decimal number is not a sentence boundary.
    #[test]
    fn test_decimal_numbers_not_split() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let text = "The value is 3.14 approximately. Pi is important.";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(
            sentences.len(),
            2,
            "Should not split on decimal point, got {:?}",
            sentences
        );
        assert!(sentences[0].contains("3.14"));
    }

    /// Repeated terminators (`!!`, `...`) still produce sentence boundaries.
    #[test]
    fn test_multiple_punctuation() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let text = "What a day!! It was incredible... Really.";
        let sentences = splitter.split_into_sentences(text);
        assert!(
            sentences.len() >= 2,
            "Should handle multiple punctuation, got {:?}",
            sentences
        );
    }

    /// Empty input yields neither chunks nor sentences.
    #[test]
    fn test_empty_text() {
        let splitter = SentenceTextSplitter::new();
        let chunks = splitter.split_text("");
        assert!(chunks.is_empty());
        let sentences = splitter.split_into_sentences("");
        assert!(sentences.is_empty());
    }

    /// A single sentence comes back as a single, unchanged chunk.
    #[test]
    fn test_single_sentence() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let text = "Just one sentence here.";
        let chunks = splitter.split_text(text);
        assert_eq!(chunks.len(), 1);
        assert_eq!(chunks[0], "Just one sentence here.");
    }

    /// With `preserve_paragraphs`, a paragraph that fits stays together in one chunk.
    #[test]
    fn test_paragraph_preservation() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(200)
            .chunk_overlap(0)
            .preserve_paragraphs(true)
            .build();
        let text = "First paragraph sentence one. Sentence two.\n\n\
Second paragraph sentence one. Sentence two.\n\n\
Third paragraph.";
        let chunks = splitter.split_text(text);
        assert!(!chunks.is_empty());
        let has_full_para = chunks
            .iter()
            .any(|c| c.contains("First paragraph") && c.contains("Sentence two."));
        assert!(
            has_full_para,
            "Expected paragraphs to be preserved, got {:?}",
            chunks
        );
    }

    /// The `Simple` pattern splits on `.`, `!` and `?` followed by whitespace.
    #[test]
    fn test_simple_pattern() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(1000)
            .separator_pattern(SentencePattern::Simple)
            .build();
        let text = "Hello world. This is a test! Are you sure? Yes.";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(
            sentences.len(),
            4,
            "Simple split on .!? got {:?}",
            sentences
        );
    }

    /// A custom regex is used as the sentence separator.
    #[test]
    fn test_custom_regex_pattern() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(1000)
            .separator_pattern(SentencePattern::Custom(r"[;]\s*".to_string()))
            .build();
        let text = "part one; part two; part three";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(sentences.len(), 3, "Custom split on ; got {:?}", sentences);
    }

    /// Short sentences that fit into `chunk_size` together end up in one chunk when a
    /// minimum chunk size is configured.
    #[test]
    fn test_min_chunk_size_merging() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(100)
            .chunk_overlap(0)
            .min_chunk_size(30)
            .build();
        let text = "Hi. Ok. Sure. This is a longer sentence that has more content.";
        let chunks = splitter.split_text(text);
        // The whole text is shorter than `chunk_size`, so the tiny sentences must not
        // survive as chunks of their own.
        assert_eq!(chunks.len(), 1, "Expected a single chunk, got {:?}", chunks);
        assert_eq!(chunks[0], text);
    }

    /// Every chunk document inherits the metadata of its source document.
    #[test]
    fn test_split_documents_with_metadata() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(30)
            .chunk_overlap(0)
            .build();
        let mut meta = HashMap::new();
        meta.insert("source".to_string(), Value::String("doc.txt".to_string()));
        let doc = Document::new("First sentence. Second sentence. Third sentence.")
            .with_metadata(meta.clone());
        let result = splitter.split_documents(&[doc]);
        assert!(
            result.len() >= 2,
            "Expected multiple doc chunks, got {:?}",
            result
        );
        for d in &result {
            assert_eq!(
                d.metadata.get("source"),
                Some(&Value::String("doc.txt".to_string())),
                "Metadata should be preserved"
            );
        }
    }

    /// The `Unicode` pattern splits on the ideographic full stop without whitespace.
    #[test]
    fn test_unicode_sentence_terminators() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(1000)
            .separator_pattern(SentencePattern::Unicode)
            .build();
        let text = "First sentence\u{3002}Second sentence\u{3002}Third";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(
            sentences.len(),
            3,
            "Unicode terminators should split, got {:?}",
            sentences
        );
    }

    /// Every builder setter is reflected in the built splitter.
    #[test]
    fn test_builder_pattern() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(500)
            .chunk_overlap(2)
            .min_chunk_size(50)
            .separator_pattern(SentencePattern::Simple)
            .strip_whitespace(false)
            .preserve_paragraphs(true)
            .build();
        assert_eq!(splitter.chunk_size, 500);
        assert_eq!(splitter.chunk_overlap, 2);
        assert_eq!(splitter.min_chunk_size, Some(50));
        assert!(!splitter.strip_whitespace);
        assert!(splitter.preserve_paragraphs);
        assert!(matches!(
            splitter.separator_pattern,
            SentencePattern::Simple
        ));
    }

    /// A sentence longer than `chunk_size` is kept whole instead of being cut.
    #[test]
    fn test_long_single_sentence_exceeds_chunk_size() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(20)
            .chunk_overlap(0)
            .build();
        let text =
            "This is one very long sentence that clearly exceeds the chunk size limit by a lot.";
        let chunks = splitter.split_text(text);
        assert!(!chunks.is_empty(), "Should produce at least one chunk");
        assert_eq!(
            chunks.len(),
            1,
            "Single sentence should not be split mid-sentence"
        );
        assert_eq!(chunks[0], text);
    }

    /// Dotted initialisms such as `U.S.` and `U.K.` do not end a sentence.
    #[test]
    fn test_us_uk_abbreviation() {
        let splitter = SentenceTextSplitter::builder().chunk_size(1000).build();
        let text = "The U.S. is a country. The U.K. is also a country.";
        let sentences = splitter.split_into_sentences(text);
        assert_eq!(
            sentences.len(),
            2,
            "Should not split on U.S. or U.K., got {:?}",
            sentences
        );
    }

    /// `create_documents` pairs each text with the metadata at the same index.
    #[test]
    fn test_create_documents_method() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(30)
            .chunk_overlap(0)
            .build();
        let texts = vec![
            "First text sentence one. Sentence two.".to_string(),
            "Second text sentence one. Sentence two.".to_string(),
        ];
        let mut meta1 = HashMap::new();
        meta1.insert("idx".to_string(), Value::Number(0.into()));
        let mut meta2 = HashMap::new();
        meta2.insert("idx".to_string(), Value::Number(1.into()));
        let metadatas = vec![meta1, meta2];
        let docs = splitter.create_documents(&texts, Some(&metadatas));
        assert!(docs.len() >= 2, "Should produce multiple documents");
        let first_text_docs: Vec<_> = docs
            .iter()
            .filter(|d| d.metadata.get("idx") == Some(&Value::Number(0.into())))
            .collect();
        assert!(!first_text_docs.is_empty());
    }

    /// The splitter is usable through a `TextSplitter` trait object.
    #[test]
    fn test_text_splitter_trait() {
        let splitter = SentenceTextSplitter::builder()
            .chunk_size(50)
            .chunk_overlap(0)
            .build();
        let trait_obj: &dyn TextSplitter = &splitter;
        assert_eq!(trait_obj.chunk_size(), 50);
        assert_eq!(trait_obj.chunk_overlap(), 0);
        let chunks = trait_obj.split_text("Hello world. Goodbye world.");
        assert!(!chunks.is_empty());
    }
}