use std::sync::Arc;
use crate::character::validate_chunk_config;
use crate::chunk::{measured_spans, TextChunk, TextChunkIter, TextSpan};
use crate::error::ChunkError;
use crate::merge::merge_spans;
use crate::sizing::{CharSizer, ChunkConfig, ChunkSizer, FunctionSizer};
/// Text splitter that walks an ordered list of delimiter classes
/// (`SPLIT_LEVELS`), splitting on progressively finer boundaries while a piece
/// still measures larger than `config.chunk_size`.
///
/// `S` is the sizer type stored inside `config`; it defaults to `CharSizer`.
#[derive(Clone)]
pub struct SemchunkSplitter<S = CharSizer> {
/// Chunk size limit, chunk overlap and the sizer instance.
pub(crate) config: ChunkConfig<S>,
/// NOTE(review): set by `new` and by the builder, but nothing in this file
/// reads it — confirm whether it is consumed elsewhere in the crate.
pub(crate) memoize: bool,
/// When true, `find_delimiter` searches the `STRICT_*` delimiter tables
/// instead of the default ones.
pub(crate) strict_mode: bool,
/// When true, the initial span is passed through `TextSpan::trim` before
/// splitting, and the flag is forwarded to every `merge_spans` call.
pub(crate) strip_whitespace: bool,
/// Callable used to measure a piece of text against `config.chunk_size`.
length_fn: crate::LengthFn,
}
/// One class of split boundary. `SPLIT_LEVELS` lists the variants in the order
/// the splitter tries them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum SplitLevel {
    /// Longest run of `'\n'` characters found in the text.
    Newlines,
    /// Longest run of `'\t'` characters found in the text.
    Tabs,
    /// Entries of `SENTENCE_TERMINATORS` / `STRICT_SENTENCE_TERMINATORS`.
    SentenceTerminators,
    /// Entries of `CLAUSE_SEPARATORS` / `STRICT_CLAUSE_SEPARATORS`.
    ClauseSeparators,
    /// Entries of `BRACKET_BOUNDARIES` / `STRICT_BRACKET_BOUNDARIES`.
    BracketBoundaries,
    /// Entries of `QUOTE_BOUNDARIES` / `STRICT_QUOTE_BOUNDARIES`.
    QuoteBoundaries,
    /// Entries of `SENTENCE_INTERRUPTERS` / `STRICT_SENTENCE_INTERRUPTERS`.
    SentenceInterrupters,
    /// Entries of `SYMBOL_BOUNDARIES` / `STRICT_SYMBOL_BOUNDARIES`.
    SymbolBoundaries,
    /// Longest run of `' '` characters found in the text.
    Whitespace,
    /// Entries of `WORD_JOINERS` / `STRICT_WORD_JOINERS`.
    WordJoiners,
    /// Last resort: every `char` becomes its own span before merging.
    Characters,
}
/// Delimiter classes in the order `split_recursive` tries them. Index 0 is
/// tried first; the splitter moves to the next entry whenever a level offers
/// no delimiter, yields at most one piece, or leaves a merged chunk that still
/// measures larger than the configured chunk size.
const SPLIT_LEVELS: &[SplitLevel] = &[
SplitLevel::Newlines,
SplitLevel::Tabs,
SplitLevel::SentenceTerminators,
SplitLevel::ClauseSeparators,
SplitLevel::BracketBoundaries,
SplitLevel::QuoteBoundaries,
SplitLevel::SentenceInterrupters,
SplitLevel::SymbolBoundaries,
SplitLevel::Whitespace,
SplitLevel::WordJoiners,
SplitLevel::Characters,
];
// Delimiter tables consumed by `find_best_delimiter`. That function picks the
// longest entry (by byte length) that occurs in the text, then the one whose
// first occurrence is earliest; list order only decides any remaining tie.
// Each `STRICT_*` table is used when `strict_mode` is enabled and adds
// newline-suffixed forms of the same punctuation.

/// Sentence-ending punctuation, with and without a trailing space.
const SENTENCE_TERMINATORS: &[&str] = &[
"... ", "...", "… ", "…", ". ", "? ", "! ", "* ", ".", "?", "!", "*",
];
/// Strict-mode sentence terminators: the default set plus `\n`-suffixed forms.
const STRICT_SENTENCE_TERMINATORS: &[&str] = &[
"...\n", "... ", "...", "…\n", "… ", "…", ".\n", "?\n", "!\n", "*\n", ". ", "? ", "! ", "* ",
".", "?", "!", "*",
];
/// Comma and semicolon, with and without a trailing space.
const CLAUSE_SEPARATORS: &[&str] = &[", ", "; ", ",", ";"];
/// Strict-mode clause separators: the default set plus `\n`-suffixed forms.
const STRICT_CLAUSE_SEPARATORS: &[&str] = &[",\n", ";\n", ", ", "; ", ",", ";"];
/// Opening and closing round, square and curly brackets, with and without a
/// trailing space.
const BRACKET_BOUNDARIES: &[&str] = &[
") ", "] ", "} ", "( ", "[ ", "{ ", ")", "]", "}", "(", "[", "{",
];
/// Strict-mode bracket boundaries: the default set plus closing brackets
/// followed by `\n`.
const STRICT_BRACKET_BOUNDARIES: &[&str] = &[
")\n", "]\n", "}\n", "( ", "[ ", "{ ", ") ", "] ", "} ", ")", "]", "}", "(", "[", "{",
];
/// Double and single quote characters, with and without a trailing space.
const QUOTE_BOUNDARIES: &[&str] = &["\" ", "' ", "\"", "'"];
/// Strict-mode quote boundaries: the default set plus `\n`-suffixed forms.
const STRICT_QUOTE_BOUNDARIES: &[&str] = &["\"\n", "'\n", "\" ", "' ", "\"", "'"];
/// Double-dash and colon forms.
const SENTENCE_INTERRUPTERS: &[&str] = &[" -- ", "-- ", ": ", ":", "--"];
/// Strict-mode sentence interrupters.
/// NOTE(review): unlike the default table this one also lists the ASCII
/// ellipsis forms that already appear in the sentence-terminator tables —
/// confirm that is intentional.
const STRICT_SENTENCE_INTERRUPTERS: &[&str] = &[
"...\n", "... ", "...", " -- ", "-- ", ": ", ":\n", ":", "--",
];
/// Symbol characters `@ # $ % = + |`, with and without a trailing space.
const SYMBOL_BOUNDARIES: &[&str] = &[
"@ ", "# ", "$ ", "% ", "= ", "+ ", "| ", "@", "#", "$", "%", "=", "+", "|",
];
/// Strict-mode symbol boundaries: the default set plus `\n`-suffixed forms.
const STRICT_SYMBOL_BOUNDARIES: &[&str] = &[
"@\n", "#\n", "$\n", "%\n", "=\n", "+\n", "|\n", "@ ", "# ", "$ ", "% ", "= ", "+ ", "| ", "@",
"#", "$", "%", "=", "+", "|",
];
/// Slash, backslash, ampersand and hyphen, either space-padded or bare.
const WORD_JOINERS: &[&str] = &[" / ", " \\ ", " & ", " - ", "/", "\\", "&", "-"];
/// Strict-mode word joiners: the default set plus forms with a leading space
/// and a trailing `\n`.
const STRICT_WORD_JOINERS: &[&str] = &[
" /\n", " \\\n", " &\n", " -\n", " / ", " \\ ", " & ", " - ", "/", "\\", "&", "-",
];
impl SemchunkSplitter<CharSizer> {
    /// Creates a splitter that measures text with `crate::char_len` and uses
    /// `CharSizer` as its sizer.
    ///
    /// Defaults: `memoize = true`, `strict_mode = false`,
    /// `strip_whitespace = true`. No validation of the two arguments happens
    /// here; the builder's `build` is the path that calls
    /// `validate_chunk_config`.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        let config = ChunkConfig::new(chunk_size, chunk_overlap, CharSizer);
        Self {
            config,
            memoize: true,
            strict_mode: false,
            strip_whitespace: true,
            length_fn: Arc::new(crate::char_len),
        }
    }

    /// Returns a builder initialised with the builder's `Default` settings.
    pub fn builder() -> SemchunkSplitterBuilder<CharSizer> {
        Default::default()
    }
}
impl<S> SemchunkSplitter<S>
where
    S: ChunkSizer,
{
    /// Splits `text` and returns the text of every chunk as an owned `String`.
    pub fn split_text(&self, text: &str) -> Vec<String> {
        let mut pieces = Vec::new();
        for chunk in self.chunks(text) {
            pieces.push(chunk.text.to_string());
        }
        pieces
    }

    /// Returns an iterator over the chunks of `text`.
    ///
    /// The spans are computed up front by `chunk_spans`; they are then passed
    /// through `measured_spans` and wrapped in a `TextChunkIter`.
    pub fn chunks<'a>(&'a self, text: &'a str) -> impl Iterator<Item = TextChunk<'a>> + 'a {
        let len_fn = self.length_fn.as_ref();
        let spans = self.chunk_spans(text, len_fn);
        TextChunkIter::new(text, measured_spans(text, spans.into_iter(), len_fn))
    }

    /// Same as [`chunks`](Self::chunks), collected into a `Vec`.
    pub fn split_chunks<'a>(&'a self, text: &'a str) -> Vec<TextChunk<'a>> {
        let mut collected = Vec::new();
        collected.extend(self.chunks(text));
        collected
    }

    /// Computes the chunk spans for `text`.
    ///
    /// Returns no spans for empty input (or when `TextSpan::trim` returns
    /// `None` with `strip_whitespace` enabled), a single span when the working
    /// text already fits `chunk_size`, and the result of `split_recursive`
    /// starting at level 0 otherwise.
    fn chunk_spans(&self, text: &str, len_fn: &dyn Fn(&str) -> usize) -> Vec<TextSpan> {
        let full = TextSpan::new(0, text.len());
        let initial = if !self.strip_whitespace {
            full
        } else if let Some(trimmed) = full.trim(text) {
            trimmed
        } else {
            return Vec::new();
        };

        let work_text = initial.text(text);
        if work_text.is_empty() {
            Vec::new()
        } else if len_fn(work_text) <= self.config.chunk_size {
            vec![initial]
        } else {
            self.split_recursive(text, work_text, initial.start, 0, len_fn)
        }
    }

    /// Splits `text` (which starts at byte `base_offset` of `input`) using the
    /// delimiter class at `SPLIT_LEVELS[level_idx]`, recursing into the next
    /// level for pieces that still measure larger than `chunk_size`.
    ///
    /// * `input` - the complete original text; spans are expressed relative to it.
    /// * `text` - the slice currently being split.
    /// * `base_offset` - byte offset of `text` within `input`.
    /// * `level_idx` - index into `SPLIT_LEVELS`.
    /// * `length_fn` - measures a piece of text.
    fn split_recursive(
        &self,
        input: &str,
        text: &str,
        base_offset: usize,
        level_idx: usize,
        length_fn: &dyn Fn(&str) -> usize,
    ) -> Vec<TextSpan> {
        if text.is_empty() {
            return Vec::new();
        }

        // Used both when the text already fits and when the levels are exhausted.
        let whole_text = || vec![TextSpan::new(base_offset, base_offset + text.len())];
        if length_fn(text) <= self.config.chunk_size {
            return whole_text();
        }
        let Some(&level) = SPLIT_LEVELS.get(level_idx) else {
            return whole_text();
        };

        if level == SplitLevel::Characters {
            // Finest granularity: one span per `char`, then let `merge_spans`
            // pack them back together.
            let mut chars: Vec<TextSpan> = Vec::new();
            for (start, ch) in text.char_indices() {
                let begin = base_offset + start;
                chars.push(TextSpan::new(begin, begin + ch.len_utf8()));
            }
            return merge_spans(
                input,
                &chars,
                self.config.chunk_size,
                self.config.chunk_overlap,
                self.strip_whitespace,
                length_fn,
            );
        }

        // Retry the very same text with the next delimiter class.
        let descend =
            || self.split_recursive(input, text, base_offset, level_idx + 1, length_fn);

        let Some(delimiter) = self.find_delimiter(text, level) else {
            return descend();
        };
        let is_whitespace_delim = matches!(
            level,
            SplitLevel::Newlines | SplitLevel::Tabs | SplitLevel::Whitespace
        );
        let splits = split_semchunk_spans(text, base_offset, &delimiter, is_whitespace_delim);
        // Zero or one piece means this level made no progress.
        if splits.len() <= 1 {
            return descend();
        }

        let merged = merge_spans(
            input,
            &splits,
            self.config.chunk_size,
            self.config.chunk_overlap,
            self.strip_whitespace,
            length_fn,
        );

        let mut result: Vec<TextSpan> = Vec::new();
        for chunk in merged {
            if length_fn(chunk.text(input)) <= self.config.chunk_size {
                result.push(chunk);
            } else {
                // Still too large after merging: split it with the next level.
                result.extend(self.split_recursive(
                    input,
                    chunk.text(input),
                    chunk.start,
                    level_idx + 1,
                    length_fn,
                ));
            }
        }
        result
    }

    /// Chooses the delimiter string to split `text` on for `level`.
    ///
    /// Run-based levels return the longest run of their character; table-based
    /// levels delegate to `find_best_delimiter` with the strict or default
    /// table depending on `strict_mode`. `Characters` returns an empty string.
    /// Returns `None` when the level has nothing to split on.
    fn find_delimiter(&self, text: &str, level: SplitLevel) -> Option<String> {
        let strict = self.strict_mode;
        let candidates: &[&str] = match level {
            SplitLevel::Newlines => return find_longest_sequence(text, '\n'),
            SplitLevel::Tabs => return find_longest_sequence(text, '\t'),
            SplitLevel::Whitespace => return find_longest_space_sequence(text),
            SplitLevel::Characters => return Some(String::new()),
            SplitLevel::SentenceTerminators if strict => STRICT_SENTENCE_TERMINATORS,
            SplitLevel::SentenceTerminators => SENTENCE_TERMINATORS,
            SplitLevel::ClauseSeparators if strict => STRICT_CLAUSE_SEPARATORS,
            SplitLevel::ClauseSeparators => CLAUSE_SEPARATORS,
            SplitLevel::BracketBoundaries if strict => STRICT_BRACKET_BOUNDARIES,
            SplitLevel::BracketBoundaries => BRACKET_BOUNDARIES,
            SplitLevel::QuoteBoundaries if strict => STRICT_QUOTE_BOUNDARIES,
            SplitLevel::QuoteBoundaries => QUOTE_BOUNDARIES,
            SplitLevel::SentenceInterrupters if strict => STRICT_SENTENCE_INTERRUPTERS,
            SplitLevel::SentenceInterrupters => SENTENCE_INTERRUPTERS,
            SplitLevel::SymbolBoundaries if strict => STRICT_SYMBOL_BOUNDARIES,
            SplitLevel::SymbolBoundaries => SYMBOL_BOUNDARIES,
            SplitLevel::WordJoiners if strict => STRICT_WORD_JOINERS,
            SplitLevel::WordJoiners => WORD_JOINERS,
        };
        find_best_delimiter(text, candidates)
    }
}
/// Builder for `SemchunkSplitter`.
///
/// It holds a complete splitter and mutates it in place; `build` validates the
/// chunk size / overlap pair before handing the splitter out.
#[derive(Clone)]
pub struct SemchunkSplitterBuilder<S = CharSizer> {
/// The splitter under construction.
inner: SemchunkSplitter<S>,
}
impl Default for SemchunkSplitterBuilder<CharSizer> {
    /// Starts from `SemchunkSplitter::new(1000, 200)`, i.e. a chunk size of
    /// 1000 and an overlap of 200.
    fn default() -> Self {
        let inner = SemchunkSplitter::new(1000, 200);
        Self { inner }
    }
}
impl<S> SemchunkSplitterBuilder<S>
where
    S: ChunkSizer,
{
    /// Sets the chunk size limit.
    pub fn chunk_size(self, chunk_size: usize) -> Self {
        let Self { mut inner } = self;
        inner.config.chunk_size = chunk_size;
        Self { inner }
    }

    /// Sets the chunk overlap.
    pub fn chunk_overlap(self, chunk_overlap: usize) -> Self {
        let Self { mut inner } = self;
        inner.config.chunk_overlap = chunk_overlap;
        Self { inner }
    }

    /// Sets the `memoize` flag stored on the splitter.
    pub fn memoize(self, memoize: bool) -> Self {
        let Self { mut inner } = self;
        inner.memoize = memoize;
        Self { inner }
    }

    /// Chooses between the default (`false`) and `STRICT_*` (`true`) delimiter
    /// tables.
    pub fn strict_mode(self, strict_mode: bool) -> Self {
        let Self { mut inner } = self;
        inner.strict_mode = strict_mode;
        Self { inner }
    }

    /// Sets the `strip_whitespace` flag stored on the splitter.
    pub fn strip_whitespace(self, strip_whitespace: bool) -> Self {
        let Self { mut inner } = self;
        inner.strip_whitespace = strip_whitespace;
        Self { inner }
    }

    /// Replaces the sizer, changing the builder's type parameter to `T`.
    ///
    /// Chunk size, overlap and all flags are carried over. The splitter's
    /// length function becomes a closure that calls `size` on a clone of
    /// `sizer`.
    pub fn sizer<T>(self, sizer: T) -> SemchunkSplitterBuilder<T>
    where
        T: ChunkSizer,
    {
        let previous = self.inner;
        // The closure needs its own copy because `sizer` itself moves into the config.
        let length_sizer = sizer.clone();
        let config = ChunkConfig::new(
            previous.config.chunk_size,
            previous.config.chunk_overlap,
            sizer,
        );
        SemchunkSplitterBuilder {
            inner: SemchunkSplitter {
                config,
                memoize: previous.memoize,
                strict_mode: previous.strict_mode,
                strip_whitespace: previous.strip_whitespace,
                length_fn: Arc::new(move |value: &str| length_sizer.size(value)),
            },
        }
    }

    /// Uses `length_fn` for measuring by wrapping it in a `FunctionSizer` and
    /// delegating to [`sizer`](Self::sizer).
    pub fn length_fn(self, length_fn: crate::LengthFn) -> SemchunkSplitterBuilder<FunctionSizer> {
        let wrapped = FunctionSizer::new(length_fn);
        self.sizer(wrapped)
    }

    /// Validates the configured chunk size and overlap and returns the
    /// splitter.
    ///
    /// # Errors
    ///
    /// Propagates the error produced by `validate_chunk_config`.
    pub fn build(self) -> Result<SemchunkSplitter<S>, ChunkError> {
        let config = &self.inner.config;
        validate_chunk_config(config.chunk_size, config.chunk_overlap)?;
        Ok(self.inner)
    }
}
/// Splits `text` at every non-overlapping occurrence of `delimiter` and
/// returns the pieces as spans shifted by `base_offset`.
///
/// * `text` - the slice to split.
/// * `base_offset` - added to every byte position so spans index the caller's
///   full input.
/// * `delimiter` - the exact string to split on.
/// * `is_whitespace_delim` - when true the delimiter is dropped from the
///   output; when false it stays attached to the end of the piece before it.
///
/// Pieces that would be empty are omitted. An empty `delimiter` is treated as
/// "nothing to split on": the whole text comes back as a single span (or no
/// span when `text` is empty).
fn split_semchunk_spans(
    text: &str,
    base_offset: usize,
    delimiter: &str,
    is_whitespace_delim: bool,
) -> Vec<TextSpan> {
    let mut spans = Vec::new();

    // An empty pattern matches at every position without consuming input, so a
    // scan that advances by `delimiter.len()` would never make progress.
    if delimiter.is_empty() {
        if !text.is_empty() {
            spans.push(TextSpan::new(base_offset, base_offset + text.len()));
        }
        return spans;
    }

    let mut start = 0usize;
    for (delim_start, matched) in text.match_indices(delimiter) {
        let delim_end = delim_start + matched.len();
        let end = if is_whitespace_delim {
            delim_start
        } else {
            delim_end
        };
        if start < end {
            spans.push(TextSpan::new(base_offset + start, base_offset + end));
        }
        start = delim_end;
    }

    // Whatever follows the last delimiter.
    if start < text.len() {
        spans.push(TextSpan::new(base_offset + start, base_offset + text.len()));
    }
    spans
}
/// Returns the longest run of consecutive `ch` characters in `text` as a
/// string, or `None` when `ch` does not occur at all.
fn find_longest_sequence(text: &str, ch: char) -> Option<String> {
    // Splitting on everything that is not `ch` leaves exactly the runs of `ch`
    // (plus empty strings between adjacent non-`ch` characters).
    let longest = text
        .split(|c: char| c != ch)
        .map(|run| run.chars().count())
        .max()
        .unwrap_or(0);

    if longest == 0 {
        return None;
    }
    Some(ch.to_string().repeat(longest))
}
/// Returns the longest run of consecutive `' '` characters in `text` as a
/// string, or `None` when `text` contains no space.
fn find_longest_space_sequence(text: &str) -> Option<String> {
    // Accumulator is (longest run seen so far, length of the run in progress).
    let (longest, _) = text.chars().fold((0usize, 0usize), |(best, run), c| {
        if c == ' ' {
            let run = run + 1;
            (best.max(run), run)
        } else {
            (best, 0)
        }
    });

    (longest > 0).then(|| " ".repeat(longest))
}
/// Picks which of `delimiters` to split `text` on.
///
/// Only delimiters that occur in `text` are considered. Among those the
/// longest one (by byte length) wins; equal lengths are decided by the
/// earliest first occurrence, and a remaining tie goes to the entry listed
/// first. Returns `None` when none of the delimiters occurs.
fn find_best_delimiter(text: &str, delimiters: &[&str]) -> Option<String> {
    delimiters
        .iter()
        .filter_map(|candidate| text.find(*candidate).map(|pos| (*candidate, pos)))
        // `min_by_key` keeps the first of several equal minima, which gives
        // the "listed first" tie-break for free.
        .min_by_key(|&(candidate, pos)| (std::cmp::Reverse(candidate.len()), pos))
        .map(|(candidate, _)| candidate.to_string())
}
/// Test-only helper: appends `delimiter` to every piece except the last one.
///
/// Empty pieces are dropped, except an empty first piece that is not also the
/// last one, which yields the bare delimiter.
#[cfg(test)]
fn reattach_delimiter(splits: &[&str], delimiter: &str) -> Vec<String> {
    let last_index = splits.len().saturating_sub(1);
    splits
        .iter()
        .enumerate()
        .filter_map(|(index, piece)| {
            if index > 0 && piece.is_empty() {
                None
            } else if index < last_index {
                Some([*piece, delimiter].concat())
            } else if piece.is_empty() {
                None
            } else {
                Some(piece.to_string())
            }
        })
        .collect()
}
// Unit tests for the splitter and its helper functions. Splitters built with
// `SemchunkSplitter::new` measure text in characters, so size assertions use
// `chars().count()`.
#[cfg(test)]
mod tests {
use super::*;
// Text longer than the limit is split into at least two chunks, none above 15 chars.
#[test]
fn test_semchunk_basic() {
let splitter = SemchunkSplitter::new(15, 0);
let result = splitter.split_text("Hello world. How are you?");
assert!(result.len() >= 2);
for chunk in &result {
assert!(
chunk.chars().count() <= 15,
"Chunk too long: {:?} ({})",
chunk,
chunk.chars().count()
);
}
}
// Paragraphs separated by blank lines are split and every chunk respects the limit.
#[test]
fn test_semchunk_paragraph_boundaries() {
let splitter = SemchunkSplitter::new(20, 0);
let result = splitter.split_text("Para one.\n\nPara two.\n\nPara three.");
assert!(result.len() >= 2);
for chunk in &result {
assert!(chunk.chars().count() <= 20, "Chunk too long: {:?}", chunk);
}
}
// Single-line text with several sentences is split and every chunk respects the limit.
#[test]
fn test_semchunk_sentence_boundaries() {
let splitter = SemchunkSplitter::new(30, 0);
let text = "This is sentence one. This is sentence two. And sentence three.";
let result = splitter.split_text(text);
assert!(result.len() >= 2);
for chunk in &result {
assert!(chunk.chars().count() <= 30, "Chunk too long: {:?}", chunk);
}
}
// No chunk may begin with the ". " delimiter.
#[test]
fn test_semchunk_punctuation_reattachment() {
let splitter = SemchunkSplitter::new(25, 0);
let result = splitter.split_text("Hello world. Goodbye world.");
for chunk in &result {
assert!(
!chunk.starts_with(". "),
"Delimiter should be reattached: {:?}",
chunk
);
}
}
// With newline runs of length 3 and 1 in the input, the first chunk is exactly "AAAAAAA".
#[test]
fn test_semchunk_longest_sequence_preference() {
let splitter = SemchunkSplitter::new(10, 0);
let text = "AAAAAAA\n\n\nBBBBBBB\nCCCCCCC";
let result = splitter.split_text(text);
assert!(result.len() >= 2);
assert_eq!(result[0], "AAAAAAA");
}
// Text without any delimiter is still split, with every chunk at most 5 chars.
#[test]
fn test_semchunk_fallback_to_characters() {
let splitter = SemchunkSplitter::new(5, 0);
let text = "abcdefghij";
let result = splitter.split_text(text);
assert!(result.len() >= 2);
for chunk in &result {
assert!(chunk.chars().count() <= 5, "Chunk too long: {:?}", chunk);
}
}
// Empty input produces no chunks.
#[test]
fn test_semchunk_empty_text() {
let splitter = SemchunkSplitter::new(100, 0);
let result = splitter.split_text("");
assert!(result.is_empty());
}
// Input consisting only of spaces and newlines produces no chunks.
#[test]
fn test_semchunk_whitespace_only() {
let splitter = SemchunkSplitter::new(100, 0);
let result = splitter.split_text(" \n\n ");
assert!(result.is_empty());
}
// Input shorter than the limit comes back as a single, unchanged chunk.
#[test]
fn test_semchunk_fits_in_one_chunk() {
let splitter = SemchunkSplitter::new(100, 0);
let result = splitter.split_text("Short text");
assert_eq!(result, vec!["Short text"]);
}
// `find_longest_sequence` returns the longest run, or `None` when the char is absent.
#[test]
fn test_find_longest_sequence() {
assert_eq!(
find_longest_sequence("a\n\n\nb\nc", '\n'),
Some("\n\n\n".to_string())
);
assert_eq!(find_longest_sequence("abc", '\n'), None);
assert_eq!(find_longest_sequence("\n", '\n'), Some("\n".to_string()));
}
// `reattach_delimiter` appends the delimiter to every piece except the last.
#[test]
fn test_reattach_delimiter() {
let splits = vec!["Hello", "World", "Foo"];
let result = reattach_delimiter(&splits, ". ");
assert_eq!(result, vec!["Hello. ", "World. ", "Foo"]);
}
// With size 60 and overlap 35 the first chunk holds the first two sentences in full.
#[test]
fn test_semchunk_overlap() {
let splitter = SemchunkSplitter::new(60, 35);
let text = "Schemas define structure. Vectorizers create embeddings. Workers process pending rows. Queries retrieve semantic context.";
let result = splitter.split_text(text);
assert!(
result.len() >= 2,
"Expected multiple chunks, got {:?}",
result
);
assert!(
result[0].contains("Schemas define structure.")
&& result[0].contains("Vectorizers create embeddings."),
"Expected first chunk to preserve full sentence boundaries: {:?}",
result
);
}
// The longest matching delimiter wins even when a shorter one is listed first.
#[test]
fn test_find_best_delimiter_prefers_longest_delimiter() {
let text = "Alpha... Beta: Gamma";
let delim = find_best_delimiter(text, &[". ", "...", ":"]);
assert_eq!(delim, Some("...".to_string()));
}
// Text made of bracketed and quoted words is split with every chunk at most 18 chars.
#[test]
fn test_semchunk_bracket_and_quote_boundaries() {
let splitter = SemchunkSplitter::new(18, 0);
let text = "Alpha (beta) [gamma] {delta} \"epsilon\" 'zeta'";
let chunks = splitter.split_text(text);
assert!(
chunks.len() >= 2,
"Expected multiple chunks, got {:?}",
chunks
);
assert!(
chunks.iter().all(|c| c.chars().count() <= 18),
"Chunk exceeded size: {:?}",
chunks
);
}
// Text joined only by `=`, `+` and `|` is split with every chunk at most 12 chars.
#[test]
fn test_semchunk_symbol_boundaries() {
let splitter = SemchunkSplitter::new(12, 0);
let text = "alpha=value+delta|gamma";
let chunks = splitter.split_text(text);
assert!(
chunks.len() >= 2,
"Expected multiple chunks, got {:?}",
chunks
);
assert!(
chunks.iter().all(|c| c.chars().count() <= 12),
"Chunk exceeded size: {:?}",
chunks
);
}
// Splitting still respects the size limit when `memoize` is false.
#[test]
fn test_semchunk_without_memoization() {
let splitter = SemchunkSplitter {
memoize: false,
..SemchunkSplitter::new(20, 5)
};
let text = "This is sentence one. This is sentence two. This is sentence three.";
let chunks = splitter.split_text(text);
assert!(chunks.len() >= 2);
for chunk in chunks {
assert!(
chunk.chars().count() <= 20,
"Chunk exceeded size: {:?}",
chunk
);
}
}
// Strict mode yields at least as many chunks as default mode, and its first chunk contains "Alpha.".
#[test]
fn test_semchunk_strict_mode_prefers_newline_aware_delimiters() {
let default = SemchunkSplitter {
strict_mode: false,
..SemchunkSplitter::new(25, 0)
};
let strict = SemchunkSplitter {
strict_mode: true,
..SemchunkSplitter::new(25, 0)
};
let text = "Alpha.\nBeta. Gamma.";
let default_chunks = default.split_text(text);
let strict_chunks = strict.split_text(text);
assert!(strict_chunks.len() >= default_chunks.len());
assert!(
strict_chunks[0].contains("Alpha."),
"Strict mode should preserve first sentence boundary: {:?}",
strict_chunks
);
}
}