use crate::character::validate_chunk_config;
use crate::chunk::{chunks_from_spans, TextChunk, TextSpan};
use crate::error::ChunkError;
use crate::sizing::{CharSizer, ChunkConfig, ChunkSizer, FunctionSizer};
use std::sync::Arc;
/// Sentence terminators used by [`SemanticChunker::new`]. When one matches, the
/// matched text stays at the end of the sentence it closes.
const DEFAULT_DELIMITERS: &[&str] = &[
    ". ",  // period + space
    "! ",  // exclamation + space
    "? ",  // question + space
    "\n", // line break
];

/// Splits text into chunks at sentence boundaries, using embeddings of sliding
/// sentence windows (when an embedder is configured) to break where the topic shifts.
///
/// Without an embedder, with `window_size == 0`, or when the text has no more
/// sentences than `window_size`, sentences are packed greedily by size instead.
#[derive(Clone)]
pub struct SemanticChunker<S = CharSizer> {
    /// Maximum chunk size and overlap, measured with the sizer `S`.
    pub(crate) config: ChunkConfig<S>,
    /// Length measure used for size checks; the builder derives it from the sizer.
    length_fn: crate::LengthFn,
    /// Strings that terminate a sentence; a match is kept at the end of its sentence.
    pub(crate) delimiters: Vec<String>,
    /// Sentences with fewer characters than this are merged with neighbouring ones.
    pub(crate) min_characters_per_sentence: usize,
    /// Whether joined spans are passed through `TextSpan::trim` before being emitted.
    pub(crate) strip_whitespace: bool,
    /// Number of consecutive sentences embedded together per window; `0` disables
    /// semantic splitting.
    pub(crate) window_size: usize,
    /// Optional batch embedder; when absent only size-based splitting is performed.
    pub(crate) embedder: Option<crate::EmbedderHandle>,
    /// How many intermediate chunks a reconnection may span; `0` disables reconnection.
    pub(crate) skip_window: usize,
    /// Minimum cosine similarity for reconnecting two chunks across an aside.
    pub(crate) reconnect_similarity_threshold: f32,
    /// Maximum length, in characters, of the aside text a reconnection may absorb.
    pub(crate) max_aside_length: usize,
}
impl SemanticChunker<CharSizer> {
    /// Creates a chunker that measures size with `CharSizer` and uses default tuning.
    ///
    /// Defaults:
    /// - windows of 3 sentences;
    /// - reconnection disabled (`skip_window = 0`), with threshold 0.75 and asides of at
    ///   most 512 characters;
    /// - delimiters `". "`, `"! "`, `"? "` and newline;
    /// - sentences of at least 12 characters;
    /// - whitespace stripping on.
    ///
    /// No embedder is set, so chunking is size-based until one is configured. Unlike
    /// [`SemanticChunkerBuilder::build`], this performs no validation of `chunk_size`
    /// against `chunk_overlap`.
    pub fn new(chunk_size: usize, chunk_overlap: usize) -> Self {
        let delimiters = DEFAULT_DELIMITERS
            .iter()
            .copied()
            .map(String::from)
            .collect();
        Self {
            config: ChunkConfig::new(chunk_size, chunk_overlap, CharSizer),
            length_fn: Arc::new(crate::char_len),
            embedder: None,
            delimiters,
            min_characters_per_sentence: 12,
            strip_whitespace: true,
            window_size: 3,
            skip_window: 0,
            reconnect_similarity_threshold: 0.75,
            max_aside_length: 512,
        }
    }

    /// Returns a builder pre-populated with [`SemanticChunker::new(1000, 200)`](SemanticChunker::new).
    pub fn builder() -> SemanticChunkerBuilder<CharSizer> {
        Default::default()
    }
}
impl<S> SemanticChunker<S>
where
S: ChunkSizer,
{
pub fn split_text(&self, text: &str) -> Vec<String> {
self.split_chunks(text)
.into_iter()
.map(|chunk| chunk.text.to_string())
.collect()
}
pub fn split_chunks<'a>(&self, text: &'a str) -> Vec<TextChunk<'a>> {
if text.is_empty() {
return Vec::new();
}
let len_fn = self.length_fn.as_ref();
let mut sentences = split_into_sentence_spans(text, &self.delimiters);
sentences = merge_short_sentence_spans(text, sentences, self.min_characters_per_sentence);
if sentences.is_empty() {
return Vec::new();
}
if self.window_size == 0 || sentences.len() <= self.window_size {
let spans = greedy_sentence_chunk_spans(
text,
&sentences,
self.config.chunk_size,
self.config.chunk_overlap,
self.strip_whitespace,
len_fn,
);
return chunks_from_spans(text, spans, len_fn);
}
let Some(embedder) = &self.embedder else {
let spans = greedy_sentence_chunk_spans(
text,
&sentences,
self.config.chunk_size,
self.config.chunk_overlap,
self.strip_whitespace,
len_fn,
);
return chunks_from_spans(text, spans, len_fn);
};
let boundaries =
semantic_boundaries_batch(text, &sentences, self.window_size, embedder.as_ref())
.expect(
"semantic embedding failed; use try_split_chunks to handle embedding errors",
);
let mut chunks =
chunk_spans_by_boundaries(text, &sentences, &boundaries, self.strip_whitespace);
if chunks.is_empty() {
if let Some(span) =
join_span_range(text, &sentences, 0, sentences.len(), self.strip_whitespace)
{
chunks = vec![span];
}
}
if self.skip_window > 0 {
chunks = reconnect_skip_window_spans_batch(
text,
chunks,
self.skip_window,
self.reconnect_similarity_threshold,
self.max_aside_length,
embedder.as_ref(),
)
.expect("semantic embedding failed; use try_split_chunks to handle embedding errors");
}
let mut final_chunks: Vec<TextSpan> = Vec::new();
for chunk in chunks {
if len_fn(chunk.text(text)) <= self.config.chunk_size {
final_chunks.push(chunk);
continue;
}
let chunk_sentences = split_into_sentence_spans(chunk.text(text), &self.delimiters)
.into_iter()
.map(|span| TextSpan::new(span.start + chunk.start, span.end + chunk.start))
.collect::<Vec<_>>();
let sub = greedy_sentence_chunk_spans(
text,
&chunk_sentences,
self.config.chunk_size,
self.config.chunk_overlap,
self.strip_whitespace,
len_fn,
);
final_chunks.extend(sub);
}
chunks_from_spans(text, final_chunks, len_fn)
}
pub fn try_split_text(&self, text: &str) -> Result<Vec<String>, ChunkError> {
Ok(self
.try_split_chunks(text)?
.into_iter()
.map(|chunk| chunk.text.to_string())
.collect())
}
pub fn try_split_chunks<'a>(&self, text: &'a str) -> Result<Vec<TextChunk<'a>>, ChunkError> {
match &self.embedder {
Some(_) => self.split_chunks_with_batch_embedder(text),
None => Ok(self.split_chunks(text)),
}
}
pub fn try_chunks<'a>(
&self,
text: &'a str,
) -> Result<std::vec::IntoIter<TextChunk<'a>>, ChunkError> {
Ok(self.try_split_chunks(text)?.into_iter())
}
fn split_chunks_with_batch_embedder<'a>(
&self,
text: &'a str,
) -> Result<Vec<TextChunk<'a>>, ChunkError> {
if text.is_empty() {
return Ok(Vec::new());
}
let len_fn = self.length_fn.as_ref();
let mut sentences = split_into_sentence_spans(text, &self.delimiters);
sentences = merge_short_sentence_spans(text, sentences, self.min_characters_per_sentence);
if sentences.is_empty() {
return Ok(Vec::new());
}
if self.window_size == 0 || sentences.len() <= self.window_size {
let spans = greedy_sentence_chunk_spans(
text,
&sentences,
self.config.chunk_size,
self.config.chunk_overlap,
self.strip_whitespace,
len_fn,
);
return Ok(chunks_from_spans(text, spans, len_fn));
}
let boundaries = semantic_boundaries_batch(
text,
&sentences,
self.window_size,
self.embedder
.as_ref()
.expect("checked before batch semantic chunking")
.as_ref(),
)?;
let mut chunks =
chunk_spans_by_boundaries(text, &sentences, &boundaries, self.strip_whitespace);
if chunks.is_empty() {
if let Some(span) =
join_span_range(text, &sentences, 0, sentences.len(), self.strip_whitespace)
{
chunks = vec![span];
}
}
if self.skip_window > 0 {
chunks = reconnect_skip_window_spans_batch(
text,
chunks,
self.skip_window,
self.reconnect_similarity_threshold,
self.max_aside_length,
self.embedder
.as_ref()
.expect("checked before batch semantic chunking")
.as_ref(),
)?;
}
let mut final_chunks = Vec::new();
for chunk in chunks {
if len_fn(chunk.text(text)) <= self.config.chunk_size {
final_chunks.push(chunk);
continue;
}
let chunk_sentences = split_into_sentence_spans(chunk.text(text), &self.delimiters)
.into_iter()
.map(|span| TextSpan::new(span.start + chunk.start, span.end + chunk.start))
.collect::<Vec<_>>();
let sub = greedy_sentence_chunk_spans(
text,
&chunk_sentences,
self.config.chunk_size,
self.config.chunk_overlap,
self.strip_whitespace,
len_fn,
);
final_chunks.extend(sub);
}
Ok(chunks_from_spans(text, final_chunks, len_fn))
}
}
/// Finds sentence indices at which a new chunk should start, based on topic shifts.
///
/// The function proceeds in four steps:
/// 1. Windows of `window_size` consecutive sentences are embedded in a single batch.
/// 2. Each window is compared with the next by cosine similarity.
/// 3. The similarity curve is smoothed with `savgol_smooth`.
/// 4. A local minimum at index `i` (the comparison of windows `i` and `i + 1`) yields a
///    boundary before sentence `i + window_size`, the sentence window `i + 1` adds.
///
/// Returns an empty vector when `window_size` is zero or there are not more sentences
/// than `window_size`. The returned indices are ascending and unique.
///
/// # Errors
///
/// Propagates embedder errors, and fails if the embedder returns a different number of
/// vectors than windows, or vectors that are empty, ragged, or non-finite.
fn semantic_boundaries_batch(
    input: &str,
    sentences: &[TextSpan],
    window_size: usize,
    embedder: &dyn crate::Embedder,
) -> Result<Vec<usize>, ChunkError> {
    // A zero-width window would index `sentences[0 + 0 - 1]` and underflow.
    if window_size == 0 || sentences.len() <= window_size {
        return Ok(Vec::new());
    }
    // `sentences.len() > window_size` guarantees at least two windows.
    let windows: Vec<&str> = (0..=sentences.len() - window_size)
        .map(|i| &input[sentences[i].start..sentences[i + window_size - 1].end])
        .collect();
    let embeddings = embedder.embed_batch(&windows)?;
    if embeddings.len() != windows.len() {
        return Err(ChunkError::embedding_failure(format!(
            "embedder returned {} embeddings for {} inputs",
            embeddings.len(),
            windows.len()
        )));
    }
    validate_embedding_shapes(&embeddings)?;
    let sims: Vec<f32> = embeddings
        .windows(2)
        .map(|pair| cosine_similarity(&pair[0], &pair[1]))
        .collect();
    let smoothed = savgol_smooth(&sims);
    // `local_minima` yields ascending indices and the offset is constant, so the
    // boundaries are already sorted and unique; the filter is purely defensive.
    Ok(local_minima(&smoothed)
        .into_iter()
        .map(|i| i + window_size)
        .filter(|&b| b < sentences.len())
        .collect())
}
/// Checks that every embedding is non-empty, has the same dimension as the first,
/// and contains only finite values. An empty batch is accepted.
///
/// # Errors
///
/// Returns an embedding-failure error naming the first offending index (or a generic
/// message when the very first vector is empty).
fn validate_embedding_shapes(embeddings: &[Vec<f32>]) -> Result<(), ChunkError> {
    let dimension = match embeddings.first() {
        None => return Ok(()),
        Some(first) if first.is_empty() => {
            return Err(ChunkError::embedding_failure(
                "embedder returned empty embedding vectors",
            ));
        }
        Some(first) => first.len(),
    };
    embeddings
        .iter()
        .enumerate()
        .try_for_each(|(idx, embedding)| {
            if embedding.is_empty() {
                Err(ChunkError::embedding_failure(format!(
                    "embedder returned empty embedding vector at index {idx}"
                )))
            } else if embedding.len() != dimension {
                Err(ChunkError::embedding_failure(format!(
                    "embedder returned inconsistent embedding dimension at index {idx}: expected {dimension}, got {}",
                    embedding.len()
                )))
            } else if embedding.iter().all(|value| value.is_finite()) {
                Ok(())
            } else {
                Err(ChunkError::embedding_failure(format!(
                    "embedder returned non-finite embedding value at index {idx}"
                )))
            }
        })
}
/// Fluent builder for [`SemanticChunker`]. Finish with
/// [`SemanticChunkerBuilder::build`], which validates the configuration.
#[derive(Clone)]
pub struct SemanticChunkerBuilder<S = CharSizer> {
    /// The chunker being configured; returned as-is by `build` once validated.
    inner: SemanticChunker<S>,
}

impl Default for SemanticChunkerBuilder<CharSizer> {
    /// Starts from `SemanticChunker::new(1000, 200)`: chunk size 1000 and overlap 200,
    /// measured with `CharSizer`.
    fn default() -> Self {
        let inner = SemanticChunker::new(1_000, 200);
        Self { inner }
    }
}
impl<S> SemanticChunkerBuilder<S>
where
    S: ChunkSizer,
{
    /// Sets the maximum chunk size, measured with the configured sizer.
    pub fn chunk_size(mut self, chunk_size: usize) -> Self {
        self.inner.config.chunk_size = chunk_size;
        self
    }

    /// Sets the size budget of trailing sentences repeated at the start of the next
    /// chunk during size-based packing.
    pub fn chunk_overlap(mut self, chunk_overlap: usize) -> Self {
        self.inner.config.chunk_overlap = chunk_overlap;
        self
    }

    /// Sets how many consecutive sentences form one embedding window; `0` disables
    /// semantic splitting.
    pub fn window_size(mut self, window_size: usize) -> Self {
        self.inner.window_size = window_size;
        self
    }

    /// Sets how many intermediate chunks a reconnection may span; `0` disables
    /// reconnection.
    pub fn skip_window(mut self, skip_window: usize) -> Self {
        self.inner.skip_window = skip_window;
        self
    }

    /// Sets the minimum cosine similarity for reconnecting chunks. `build` rejects
    /// values outside `0.0..=1.0`.
    pub fn reconnect_similarity_threshold(mut self, threshold: f32) -> Self {
        self.inner.reconnect_similarity_threshold = threshold;
        self
    }

    /// Sets the maximum length, in characters, of the aside a reconnection may absorb.
    pub fn max_aside_length(mut self, max_aside_length: usize) -> Self {
        self.inner.max_aside_length = max_aside_length;
        self
    }

    /// Replaces the sentence delimiters. `build` requires at least one, and rejects
    /// empty strings.
    pub fn delimiters(mut self, delimiters: impl IntoIterator<Item = impl Into<String>>) -> Self {
        self.inner.delimiters = delimiters.into_iter().map(Into::into).collect();
        self
    }

    /// Sets the minimum sentence length in characters; shorter sentences are merged
    /// with neighbouring ones.
    pub fn min_characters_per_sentence(mut self, min_characters_per_sentence: usize) -> Self {
        self.inner.min_characters_per_sentence = min_characters_per_sentence;
        self
    }

    /// Sets whether joined spans are passed through `TextSpan::trim` before being emitted.
    pub fn strip_whitespace(mut self, strip_whitespace: bool) -> Self {
        self.inner.strip_whitespace = strip_whitespace;
        self
    }

    /// Switches to another sizer. All other settings are kept, and the length
    /// function is rebuilt from the new sizer so size checks stay consistent.
    pub fn sizer<T>(self, sizer: T) -> SemanticChunkerBuilder<T>
    where
        T: ChunkSizer,
    {
        let inner = self.inner;
        let length_sizer = sizer.clone();
        SemanticChunkerBuilder {
            inner: SemanticChunker {
                config: ChunkConfig::new(
                    inner.config.chunk_size,
                    inner.config.chunk_overlap,
                    sizer,
                ),
                window_size: inner.window_size,
                skip_window: inner.skip_window,
                reconnect_similarity_threshold: inner.reconnect_similarity_threshold,
                max_aside_length: inner.max_aside_length,
                delimiters: inner.delimiters,
                min_characters_per_sentence: inner.min_characters_per_sentence,
                strip_whitespace: inner.strip_whitespace,
                length_fn: Arc::new(move |value: &str| length_sizer.size(value)),
                embedder: inner.embedder,
            },
        }
    }

    /// Measures sizes with a custom length function (wrapped in a `FunctionSizer`).
    pub fn length_fn(self, length_fn: crate::LengthFn) -> SemanticChunkerBuilder<FunctionSizer> {
        self.sizer(FunctionSizer::new(length_fn))
    }

    /// Uses a per-text embedding function, adapted into a batch embedder that calls
    /// it once for each input and never reports an error.
    pub fn embedding_fn(mut self, embedding_fn: crate::EmbeddingFn) -> Self {
        let adapter: crate::EmbedderHandle = std::sync::Arc::new(
            move |inputs: &[&str]| -> Result<Vec<Vec<f32>>, ChunkError> {
                Ok(inputs.iter().map(|input| embedding_fn(input)).collect())
            },
        );
        self.inner.embedder = Some(adapter);
        self
    }

    /// Uses the given batch embedder.
    pub fn embedder(mut self, embedder: crate::EmbedderHandle) -> Self {
        self.inner.embedder = Some(embedder);
        self
    }

    /// Validates the configuration and returns the chunker.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `validate_chunk_config` rejects the size/overlap pair;
    /// - no delimiters are configured;
    /// - a delimiter is the empty string (it would match everywhere without advancing);
    /// - the reconnect threshold is outside `0.0..=1.0` (including NaN).
    pub fn build(self) -> Result<SemanticChunker<S>, ChunkError> {
        validate_chunk_config(
            self.inner.config.chunk_size,
            self.inner.config.chunk_overlap,
        )?;
        if self.inner.delimiters.is_empty() {
            return Err(ChunkError::invalid_configuration(
                "semantic chunker requires at least one delimiter",
            ));
        }
        if self.inner.delimiters.iter().any(|delimiter| delimiter.is_empty()) {
            return Err(ChunkError::invalid_configuration(
                "semantic chunker delimiters must not be empty strings",
            ));
        }
        if !(0.0..=1.0).contains(&self.inner.reconnect_similarity_threshold) {
            return Err(ChunkError::invalid_configuration(
                "reconnect_similarity_threshold must be between 0.0 and 1.0",
            ));
        }
        Ok(self.inner)
    }
}
/// Cuts the sentence list at `boundaries` and joins each resulting run of sentences
/// into one span.
///
/// Boundaries are applied in order. One that is not strictly after the previous
/// accepted cut, or lies beyond `sentences.len()`, is ignored. Runs for which
/// `join_span_range` yields `None` are omitted.
fn chunk_spans_by_boundaries(
    input: &str,
    sentences: &[TextSpan],
    boundaries: &[usize],
    strip_whitespace: bool,
) -> Vec<TextSpan> {
    let total = sentences.len();
    let mut cuts = vec![0usize];
    let mut last_cut = 0usize;
    for &boundary in boundaries {
        if boundary > last_cut && boundary <= total {
            cuts.push(boundary);
            last_cut = boundary;
        }
    }
    if last_cut < total {
        cuts.push(total);
    }
    cuts.windows(2)
        .filter_map(|run| join_span_range(input, sentences, run[0], run[1], strip_whitespace))
        .collect()
}
/// Packs consecutive sentences greedily into chunks whose measured length stays
/// within `chunk_size`.
///
/// When the next sentence would overflow the current chunk, the chunk is emitted. The
/// next chunk is then seeded with trailing sentences whose combined length is at most
/// `chunk_overlap` and still leaves room for the incoming sentence within
/// `chunk_size`. A single sentence longer than `chunk_size` still becomes its own
/// (oversized) chunk.
///
/// Emitted spans come from `join_indices_to_span` (trimmed when `strip_whitespace`).
fn greedy_sentence_chunk_spans(
    input: &str,
    sentences: &[TextSpan],
    chunk_size: usize,
    chunk_overlap: usize,
    strip_whitespace: bool,
    length_fn: &dyn Fn(&str) -> usize,
) -> Vec<TextSpan> {
    let mut chunks = Vec::new();
    let mut current: Vec<usize> = Vec::new();
    let mut current_len = 0usize;
    for (i, sentence) in sentences.iter().enumerate() {
        let sentence_len = length_fn(sentence.text(input));
        if !current.is_empty() && current_len + sentence_len > chunk_size {
            chunks.extend(join_indices_to_span(
                input,
                sentences,
                &current,
                strip_whitespace,
            ));
            current.clear();
            current_len = 0;
            // Walk backwards collecting overlap, but never so much that the overlap
            // plus sentence `i` would exceed `chunk_size`.
            let mut overlap_start = i;
            while chunk_overlap > 0 && overlap_start > 0 {
                let candidate_len = length_fn(sentences[overlap_start - 1].text(input));
                let seeded = current_len + candidate_len;
                if seeded > chunk_overlap || seeded + sentence_len > chunk_size {
                    break;
                }
                current_len = seeded;
                overlap_start -= 1;
            }
            current.extend(overlap_start..i);
        }
        current.push(i);
        current_len += sentence_len;
    }
    // `join_indices_to_span` yields `None` for an empty index list (no sentences).
    chunks.extend(join_indices_to_span(
        input,
        sentences,
        &current,
        strip_whitespace,
    ));
    chunks
}
/// Reconnects a chunk with a later, similar chunk when only a short aside separates
/// them, merging everything in between into one span.
///
/// For each starting chunk `i`, candidates `j` from `i + 2` up to `i + 1 + skip_window`
/// are tried. A candidate qualifies when both conditions hold:
/// - the text from the start of chunk `i + 1` to the start of chunk `j` has at most
///   `max_aside_length` characters;
/// - the cosine similarity of the embeddings of chunks `i` and `j` is at least
///   `threshold`.
///
/// The farthest qualifying `j` wins, producing one span covering chunks `i..=j`, and
/// scanning resumes after it.
///
/// Each chunk is embedded at most once. Embeddings are fetched lazily on first use and
/// cached, rather than re-embedding chunk `i` for every candidate gap.
///
/// # Errors
///
/// Propagates embedder errors and rejects responses with the wrong vector count or
/// with empty, non-finite, or mismatched-dimension vectors.
fn reconnect_skip_window_spans_batch(
    input: &str,
    chunks: Vec<TextSpan>,
    skip_window: usize,
    threshold: f32,
    max_aside_length: usize,
    embedder: &dyn crate::Embedder,
) -> Result<Vec<TextSpan>, ChunkError> {
    if chunks.len() < 3 || skip_window == 0 {
        return Ok(chunks);
    }
    let mut cache: Vec<Option<Vec<f32>>> = vec![None; chunks.len()];
    let mut out = Vec::new();
    let mut i = 0usize;
    while i < chunks.len() {
        let mut best_end = None;
        // Keep at least one aside chunk between `i` and any candidate `j`.
        let max_gap = skip_window.min(chunks.len().saturating_sub(i + 2));
        for gap in 1..=max_gap {
            let j = i + gap + 1;
            let aside_len = input[chunks[i + 1].start..chunks[j].start].chars().count();
            if aside_len > max_aside_length {
                continue;
            }
            embed_uncached_chunks(input, &chunks, &[i, j], &mut cache, embedder)?;
            let left = cache[i].as_deref().expect("chunk i embedded above");
            let right = cache[j].as_deref().expect("chunk j embedded above");
            // Vectors may come from different batches, so re-check their dimensions here.
            if left.len() != right.len() {
                return Err(ChunkError::embedding_failure(format!(
                    "embedder returned inconsistent embedding dimensions for reconnect inputs: {} and {}",
                    left.len(),
                    right.len()
                )));
            }
            if cosine_similarity(left, right) >= threshold {
                best_end = Some(j);
            }
        }
        if let Some(end) = best_end {
            out.push(TextSpan::new(chunks[i].start, chunks[end].end));
            i = end + 1;
        } else {
            out.push(chunks[i]);
            i += 1;
        }
    }
    Ok(out)
}

/// Embeds every chunk in `indices` that has no cached embedding yet, in a single
/// `embed_batch` call, and stores the validated vectors in `cache`.
fn embed_uncached_chunks(
    input: &str,
    chunks: &[TextSpan],
    indices: &[usize],
    cache: &mut [Option<Vec<f32>>],
    embedder: &dyn crate::Embedder,
) -> Result<(), ChunkError> {
    let missing: Vec<usize> = indices
        .iter()
        .copied()
        .filter(|&k| cache[k].is_none())
        .collect();
    if missing.is_empty() {
        return Ok(());
    }
    let inputs: Vec<&str> = missing.iter().map(|&k| chunks[k].text(input)).collect();
    let embeddings = embedder.embed_batch(&inputs)?;
    if embeddings.len() != inputs.len() {
        return Err(ChunkError::embedding_failure(format!(
            "embedder returned {} embeddings for {} reconnect inputs",
            embeddings.len(),
            inputs.len()
        )));
    }
    validate_embedding_shapes(&embeddings)?;
    for (k, embedding) in missing.into_iter().zip(embeddings) {
        cache[k] = Some(embedding);
    }
    Ok(())
}
/// Splits `text` into sentence spans ending just after each delimiter match.
///
/// At each position the earliest match among `delimiters` ends the sentence. When two
/// delimiters match at the same position, the one listed first wins. Text after the
/// last match forms a final sentence.
///
/// Empty delimiters are ignored: they match at every position without consuming
/// input and would otherwise loop forever. Each delimiter's next match position is
/// cached and only re-searched once the cursor passes it, so absent or rare
/// delimiters do not cause a full rescan per sentence.
fn split_into_sentence_spans(text: &str, delimiters: &[String]) -> Vec<TextSpan> {
    let delims: Vec<&str> = delimiters
        .iter()
        .map(String::as_str)
        .filter(|delim| !delim.is_empty())
        .collect();
    // Absolute byte offset of each delimiter's next match at or after the cursor.
    let mut next_match: Vec<Option<usize>> = delims.iter().map(|d| text.find(*d)).collect();
    let mut sentences = Vec::new();
    let mut start = 0usize;
    while start < text.len() {
        let mut earliest: Option<(usize, usize)> = None; // (match offset, delimiter length)
        for (delim, next) in delims.iter().zip(next_match.iter_mut()) {
            if next.is_some_and(|pos| pos < start) {
                *next = text[start..].find(*delim).map(|offset| start + offset);
            }
            if let Some(pos) = *next {
                if earliest.is_none_or(|(best, _)| pos < best) {
                    earliest = Some((pos, delim.len()));
                }
            }
        }
        match earliest {
            Some((pos, delim_len)) => {
                // Delimiters are non-empty, so `end > start` and the loop always advances.
                let end = pos + delim_len;
                sentences.push(TextSpan::new(start, end));
                start = end;
            }
            None => {
                sentences.push(TextSpan::new(start, text.len()));
                break;
            }
        }
    }
    sentences
}
/// Merges consecutive sentences until each merged span has at least `min_chars`
/// characters.
///
/// A short trailing remainder is folded into the previous merged span. If nothing
/// ever reaches `min_chars`, the remainder becomes the only span.
fn merge_short_sentence_spans(
    input: &str,
    sentences: Vec<TextSpan>,
    min_chars: usize,
) -> Vec<TextSpan> {
    let mut merged: Vec<TextSpan> = Vec::with_capacity(sentences.len());
    // The open, still-too-short span waiting for more sentences.
    let mut pending: Option<TextSpan> = None;
    for sentence in sentences {
        let candidate = match pending.take() {
            Some(open) => TextSpan::new(open.start, sentence.end),
            None => TextSpan::new(sentence.start, sentence.end),
        };
        if candidate.text(input).chars().count() >= min_chars {
            merged.push(candidate);
        } else {
            pending = Some(candidate);
        }
    }
    if let Some(rest) = pending {
        match merged.last_mut() {
            Some(previous) => previous.end = rest.end,
            None => merged.push(rest),
        }
    }
    merged
}
/// Joins `spans[start..end]` into one span from the first start to the last end.
///
/// Returns `None` for an empty range. When `strip_whitespace` is set, the joined span
/// is passed through `TextSpan::trim`, which may also yield `None`.
fn join_span_range(
    input: &str,
    spans: &[TextSpan],
    start: usize,
    end: usize,
    strip_whitespace: bool,
) -> Option<TextSpan> {
    (start < end)
        .then(|| TextSpan::new(spans[start].start, spans[end - 1].end))
        .and_then(|joined| {
            if strip_whitespace {
                joined.trim(input)
            } else {
                Some(joined)
            }
        })
}

/// Joins the spans at the first and last of `indices` into one covering span.
///
/// Returns `None` when `indices` is empty. When `strip_whitespace` is set, the joined
/// span is passed through `TextSpan::trim`.
fn join_indices_to_span(
    input: &str,
    spans: &[TextSpan],
    indices: &[usize],
    strip_whitespace: bool,
) -> Option<TextSpan> {
    let (&first, &last) = (indices.first()?, indices.last()?);
    let joined = TextSpan::new(spans[first].start, spans[last].end);
    match strip_whitespace {
        true => joined.trim(input),
        false => Some(joined),
    }
}
/// Test-only, string-based model of skip-window reconnection.
///
/// Chunk `cursor` is merged (by concatenation) with the farthest later chunk `target`
/// (at most `skip_window` asides ahead) that meets two conditions:
/// - the intervening chunks total at most `max_aside_length` characters;
/// - the embedding similarity with chunk `cursor` is at least `threshold`.
#[cfg(test)]
fn reconnect_skip_windows(
    chunks: Vec<String>,
    skip_window: usize,
    threshold: f32,
    max_aside_length: usize,
    embedding_fn: &dyn Fn(&str) -> Vec<f32>,
) -> Vec<String> {
    if skip_window == 0 || chunks.len() < 3 {
        return chunks;
    }
    let count = chunks.len();
    let mut merged: Vec<String> = Vec::new();
    let mut cursor = 0usize;
    while cursor < count {
        let reach = skip_window.min(count.saturating_sub(cursor + 2));
        let last_match = (1..=reach)
            .map(|gap| cursor + gap + 1)
            .filter(|&target| {
                let aside: usize = chunks[cursor + 1..target]
                    .iter()
                    .map(|c| c.chars().count())
                    .sum();
                aside <= max_aside_length
            })
            .filter(|&target| {
                let left = embedding_fn(&chunks[cursor]);
                let right = embedding_fn(&chunks[target]);
                cosine_similarity(&left, &right) >= threshold
            })
            .last();
        match last_match {
            Some(end) => {
                merged.push(chunks[cursor..=end].concat());
                cursor = end + 1;
            }
            None => {
                merged.push(chunks[cursor].clone());
                cursor += 1;
            }
        }
    }
    merged
}
/// Cosine similarity of two vectors.
///
/// Returns `0.0` when either vector is empty, the lengths differ, or either vector
/// has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a_sq, norm_b_sq) = a.iter().zip(b).fold(
        (0.0f32, 0.0f32, 0.0f32),
        |(dot, na, nb), (&x, &y)| (dot + x * y, na + x * x, nb + y * y),
    );
    match norm_a_sq == 0.0 || norm_b_sq == 0.0 {
        true => 0.0,
        false => dot / (norm_a_sq.sqrt() * norm_b_sq.sqrt()),
    }
}
/// Smooths `values` with a 5-point Savitzky–Golay filter (quadratic fit), whose
/// coefficients are (-3, 12, 17, 12, -3) / 35.
///
/// The first two and last two samples are copied unchanged; inputs shorter than five
/// samples are returned as-is.
fn savgol_smooth(values: &[f32]) -> Vec<f32> {
    const KERNEL: [f32; 5] = [
        -3.0 / 35.0,
        12.0 / 35.0,
        17.0 / 35.0,
        12.0 / 35.0,
        -3.0 / 35.0,
    ];
    let mut smoothed = values.to_vec();
    // `windows` yields nothing for fewer than five samples, leaving the copy untouched.
    for (offset, window) in values.windows(KERNEL.len()).enumerate() {
        smoothed[offset + 2] = KERNEL
            .iter()
            .zip(window)
            .fold(0.0f32, |acc, (weight, sample)| acc + weight * sample);
    }
    smoothed
}
/// Indices of interior local minima in `values`.
///
/// A point counts when it is strictly below its left neighbour and at most its right
/// neighbour, so a flat valley reports only its first point. The endpoints are never
/// reported.
fn local_minima(values: &[f32]) -> Vec<usize> {
    values
        .windows(3)
        .enumerate()
        .filter_map(|(left, triple)| {
            (triple[1] < triple[0] && triple[1] <= triple[2]).then_some(left + 1)
        })
        .collect()
}
// Unit tests for the semantic chunker. They use a deterministic keyword-count
// "embedding" so topic shifts are predictable without a real model.
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy two-dimensional embedding of the lowercased text:
    /// - component 0 counts database-related keywords;
    /// - component 1 counts weather-related keywords.
    ///
    /// Matching is by substring, so e.g. "tables" also counts as "table".
    fn topic_embedding(text: &str) -> Vec<f32> {
        let lower = text.to_lowercase();
        let db = ["database", "sql", "table", "vectorizer"]
            .iter()
            .map(|k| lower.matches(k).count() as f32)
            .sum::<f32>();
        let weather = ["weather", "rain", "forecast", "temperature"]
            .iter()
            .map(|k| lower.matches(k).count() as f32)
            .sum::<f32>();
        vec![db, weather]
    }

    // Three database-themed sentences followed by three weather-themed ones should
    // yield exactly two chunks, split where the topic changes. The chunk size is large
    // enough that size enforcement never interferes.
    #[test]
    fn test_semantic_chunker_topic_boundary_split() {
        let chunker = SemanticChunker::builder()
            .chunk_size(10_000)
            .chunk_overlap(0)
            .window_size(2)
            .min_characters_per_sentence(1)
            .embedding_fn(std::sync::Arc::new(topic_embedding))
            .build()
            .unwrap();
        let text = "SQL tables store rows. Vectorizer jobs build embeddings. Queries retrieve context. Weather forecasts predict rain. Temperature drops overnight. Storm alerts were issued.";
        let chunks = chunker.split_text(text);
        assert_eq!(
            chunks.len(),
            2,
            "Expected semantic split into 2 chunks: {:?}",
            chunks
        );
        assert!(chunks[0].contains("SQL tables"));
        assert!(chunks[1].contains("Weather forecasts"));
    }

    // Even when semantic splitting keeps everything together, no emitted chunk may
    // exceed the 60-character limit.
    #[test]
    fn test_semantic_chunker_enforces_chunk_size() {
        let chunker = SemanticChunker::builder()
            .chunk_size(60)
            .chunk_overlap(0)
            .window_size(2)
            .min_characters_per_sentence(1)
            .embedding_fn(std::sync::Arc::new(topic_embedding))
            .build()
            .unwrap();
        let text = "SQL tables store rows for applications. Vectorizer jobs build embeddings for semantic search. Queries retrieve context from matching chunks.";
        let chunks = chunker.split_text(text);
        assert!(chunks.len() >= 2);
        for chunk in chunks {
            assert!(chunk.chars().count() <= 60, "Chunk too large: {:?}", chunk);
        }
    }

    // With no embedder configured, chunking falls back to size-based packing, so
    // text longer than the chunk size still splits into several chunks.
    #[test]
    fn test_semantic_chunker_without_embedding_fn_falls_back() {
        let chunker = SemanticChunker::builder()
            .chunk_size(40)
            .chunk_overlap(0)
            .window_size(3)
            .skip_window(0)
            .min_characters_per_sentence(1)
            .build()
            .unwrap();
        let text = "Sentence one has enough text. Sentence two has enough text. Sentence three has enough text.";
        let chunks = chunker.split_text(text);
        assert!(chunks.len() >= 2, "Expected fallback size-based split");
    }

    // The test-only `reconnect_skip_windows` helper should absorb a short weather
    // aside between two database chunks, producing a single merged chunk.
    #[test]
    fn test_reconnect_skip_windows_merges_tangential_aside() {
        let chunks = vec![
            "Database schemas and SQL indexes.".to_string(),
            "Rain and storm weather updates.".to_string(),
            "Vectorizer query planning and SQL tuning.".to_string(),
        ];
        let merged = reconnect_skip_windows(chunks, 1, 0.5, 200, &topic_embedding);
        assert_eq!(
            merged.len(),
            1,
            "Expected aside reconnection merge: {:?}",
            merged
        );
        assert!(merged[0].contains("weather"));
        assert!(merged[0].contains("Vectorizer"));
    }
}