use crate::chunker::Chunk;
use crate::embedding::{l2_normalize, Embedder};
use crate::error::RagError;
/// Groups consecutive sentences into chunks by embedding similarity:
/// a sentence joins the current chunk while it stays close enough to
/// the chunk's running mean embedding.
pub struct SemanticChunker<'a, E: Embedder> {
// Borrowed embedding backend used to embed each sentence.
embedder: &'a E,
// Minimum cosine similarity for a sentence to join the current chunk.
similarity_threshold: f32,
// Lower bound on sentences per chunk (kept >= 1 by `with_min_sentences`).
min_sentences_per_chunk: usize,
}
impl<'a, E: Embedder> SemanticChunker<'a, E> {
/// Creates a chunker over `embedder` that merges adjacent sentences
/// while their cosine similarity to the chunk's running mean embedding
/// stays at or above `similarity_threshold`.
pub fn new(embedder: &'a E, similarity_threshold: f32) -> Self {
    Self {
        embedder,
        similarity_threshold,
        min_sentences_per_chunk: 1,
    }
}
/// Sets the minimum number of sentences per chunk (clamped to at least 1).
#[must_use]
pub fn with_min_sentences(mut self, min: usize) -> Self {
    self.min_sentences_per_chunk = min.max(1);
    self
}
/// Splits `text` into semantically coherent chunks for document `doc_id`.
///
/// Each sentence is embedded and compared against the running mean
/// embedding of the current chunk; it is appended when the similarity
/// is at or above the threshold (or the chunk is still below the
/// minimum sentence count), otherwise it starts a new chunk.
///
/// Returns an empty vector for blank input, and a single chunk when
/// no sentence boundary is found or only one sentence exists.
///
/// # Errors
///
/// Propagates embedding failures from the backend and returns
/// `RagError::NonFinite` when an embedding contains NaN/inf.
pub fn chunk(&self, text: &str, doc_id: usize) -> Result<Vec<Chunk>, RagError> {
    if text.trim().is_empty() {
        return Ok(Vec::new());
    }
    let sentences = split_sentences(text);
    if sentences.is_empty() {
        return Ok(vec![Chunk::new(text.trim().to_string(), doc_id, 0, 0)]);
    }
    if sentences.len() == 1 {
        let (start, s) = sentences[0];
        return Ok(vec![Chunk::new(s.to_string(), doc_id, 0, start)]);
    }
    let mut chunks: Vec<Chunk> = Vec::new();
    let mut current_text = String::new();
    let mut current_start = sentences[0].0;
    let mut current_mean: Vec<f32> = self.embed_unit(sentences[0].1)?;
    let mut current_count: usize = 1;
    current_text.push_str(sentences[0].1);
    for (offset, sentence) in &sentences[1..] {
        let emb = self.embed_unit(sentence)?;
        // BUG FIX: this argument was corrupted to `¤t_mean`
        // (`&curren` turned into an HTML entity), which did not compile.
        let similarity = cosine_unit(&current_mean, &emb);
        let must_extend = current_count < self.min_sentences_per_chunk;
        if similarity >= self.similarity_threshold || must_extend {
            if !current_text.is_empty() {
                current_text.push(' ');
            }
            current_text.push_str(sentence);
            current_count += 1;
            // Incremental update of the running mean: m += (e - m) / n.
            for (m, e) in current_mean.iter_mut().zip(emb.iter()) {
                *m += (*e - *m) / current_count as f32;
            }
        } else {
            // Similarity dropped below threshold: close the current chunk
            // and start a new one seeded with this sentence.
            chunks.push(Chunk::new(
                std::mem::take(&mut current_text),
                doc_id,
                chunks.len(),
                current_start,
            ));
            current_start = *offset;
            current_mean = emb;
            current_count = 1;
            current_text.push_str(sentence);
        }
    }
    // Flush the final, still-open chunk.
    if !current_text.is_empty() {
        chunks.push(Chunk::new(
            current_text,
            doc_id,
            chunks.len(),
            current_start,
        ));
    }
    Ok(chunks)
}
/// Embeds `text`, rejects non-finite values, and L2-normalizes the result.
fn embed_unit(&self, text: &str) -> Result<Vec<f32>, RagError> {
    let mut v = self.embedder.embed(text)?;
    if v.iter().any(|x| !x.is_finite()) {
        return Err(RagError::NonFinite);
    }
    l2_normalize(&mut v);
    Ok(v)
}
}
/// Splits `text` into trimmed sentence slices, each paired with the byte
/// offset of its first non-whitespace character in `text`.
///
/// A sentence ends at a run of `.`/`!`/`?` plus any trailing whitespace;
/// leftover text without a terminator becomes a final entry. Only ASCII
/// bytes are matched, so every slice boundary is a valid char boundary.
fn split_sentences(text: &str) -> Vec<(usize, &str)> {
    // Byte offset of the first non-whitespace char at or after `from`.
    fn anchor_at(text: &str, from: usize) -> usize {
        from + text[from..]
            .char_indices()
            .find(|(_, c)| !c.is_whitespace())
            .map_or(0, |(b, _)| b)
    }

    let bytes = text.as_bytes();
    let mut sentences = Vec::new();
    let mut seg_start = 0usize;
    let mut cursor = 0usize;
    while cursor < bytes.len() {
        if !matches!(bytes[cursor], b'.' | b'!' | b'?') {
            cursor += 1;
            continue;
        }
        // Consume the entire terminator run ("...", "?!", etc.).
        let mut end = cursor + 1;
        while end < bytes.len() && matches!(bytes[end], b'.' | b'!' | b'?') {
            end += 1;
        }
        // Fold any trailing whitespace into this segment.
        while end < bytes.len() && matches!(bytes[end], b' ' | b'\t' | b'\n' | b'\r') {
            end += 1;
        }
        let sentence = text[seg_start..end].trim();
        if !sentence.is_empty() {
            sentences.push((anchor_at(text, seg_start), sentence));
        }
        seg_start = end;
        cursor = end;
    }
    let tail = text[seg_start..].trim();
    if !tail.is_empty() {
        sentences.push((anchor_at(text, seg_start), tail));
    }
    sentences
}
/// Cosine similarity between two vectors, clamped to `[-1, 1]`.
///
/// Returns 0.0 for mismatched lengths or near-zero norms so callers
/// never observe NaN or out-of-range values.
fn cosine_unit(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a < 1e-10 || norm_b < 1e-10 {
        0.0
    } else {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::IdentityEmbedder;

    // Blank input must yield no chunks at all.
    #[test]
    fn empty_input_yields_no_chunks() {
        let embedder = IdentityEmbedder::new(16).expect("valid dim");
        let result = SemanticChunker::new(&embedder, 0.5)
            .chunk("", 0)
            .expect("chunk");
        assert_eq!(result.len(), 0);
    }

    // Text without a sentence terminator falls back to one chunk.
    #[test]
    fn single_sentence_fallback() {
        let embedder = IdentityEmbedder::new(16).expect("valid dim");
        let result = SemanticChunker::new(&embedder, 0.5)
            .chunk("Only one sentence here", 0)
            .expect("chunk");
        assert_eq!(result.len(), 1);
    }

    // A threshold below -1 can never be violated, so all sentences merge.
    #[test]
    fn threshold_zero_merges_everything() {
        let embedder = IdentityEmbedder::new(16).expect("valid dim");
        let result = SemanticChunker::new(&embedder, -2.0)
            .chunk("Alpha. Beta. Gamma.", 0)
            .expect("chunk");
        assert_eq!(result.len(), 1);
    }
}