use crate::core::Result;
use crate::vector::EmbeddingGenerator;
/// A run of consecutive sentences emitted by the chunker as one unit.
///
/// The sentence indices refer to positions in the sentence list produced
/// while splitting the input text.
#[derive(Debug, Clone)]
pub struct SemanticChunk {
    /// Text of the chunk.
    pub content: String,
    /// Index of the first sentence in the chunk (inclusive).
    pub start_sentence: usize,
    /// Index one past the last sentence in the chunk (exclusive).
    pub end_sentence: usize,
    /// Number of sentences covered by the chunk.
    pub sentence_count: usize,
}
/// How the distance threshold that marks a chunk boundary is derived.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum BreakpointStrategy {
    /// The threshold is a percentile of the observed sentence distances;
    /// `threshold_amount` is the percentile on a 0-100 scale.
    Percentile,
    /// The threshold is the mean distance plus `threshold_amount` standard
    /// deviations.
    StandardDeviation,
    /// `threshold_amount` is used directly as the distance threshold.
    Absolute,
}
/// Tunable parameters of the semantic chunker.
#[derive(Debug, Clone)]
pub struct SemanticChunkerConfig {
    /// Rule used to turn sentence distances into a breakpoint threshold.
    pub breakpoint_strategy: BreakpointStrategy,
    /// Strategy-dependent amount: a percentile on a 0-100 scale, a multiplier
    /// for the standard deviation, or the distance threshold itself.
    pub threshold_amount: f32,
    /// Minimum number of sentences a segment needs to be closed at a
    /// breakpoint; a shorter segment stays open and extends to the next one.
    pub min_chunk_size: usize,
    /// Maximum number of sentences per chunk; `0` disables the limit.
    pub max_chunk_size: usize,
    /// Offset, in sentences, between the two embeddings whose distance is
    /// measured (`1` compares each sentence with the one right after it).
    pub buffer_size: usize,
}
impl Default for SemanticChunkerConfig {
fn default() -> Self {
Self {
breakpoint_strategy: BreakpointStrategy::Percentile,
threshold_amount: 95.0,
min_chunk_size: 1,
max_chunk_size: 0, buffer_size: 1,
}
}
}
/// Groups the sentences of a text into chunks, placing boundaries where the
/// embeddings of compared sentences are far apart.
pub struct SemanticChunker {
    /// Parameters controlling breakpoint detection and chunk sizes.
    config: SemanticChunkerConfig,
    /// Produces the embedding of each sentence.
    embedding_generator: EmbeddingGenerator,
}
impl SemanticChunker {
    /// Creates a chunker from its configuration and the generator used to
    /// embed individual sentences.
    pub fn new(config: SemanticChunkerConfig, embedding_generator: EmbeddingGenerator) -> Self {
        Self {
            config,
            embedding_generator,
        }
    }

    /// Splits `text` into chunks of consecutive sentences.
    ///
    /// The text is split into sentences, every sentence is embedded, the
    /// cosine distance between compared sentences is measured, and a chunk
    /// boundary is placed wherever that distance exceeds the threshold chosen
    /// by the configured [`BreakpointStrategy`].
    ///
    /// Text without any sentence yields an empty vector. Text consisting of a
    /// single sentence is returned verbatim (not trimmed) as one chunk and no
    /// embedding is computed.
    ///
    /// # Errors
    ///
    /// Propagates errors from the embedding and breakpoint steps. As written,
    /// neither step produces one.
    pub fn chunk(&mut self, text: &str) -> Result<Vec<SemanticChunk>> {
        let sentences = self.split_sentences(text);
        if sentences.is_empty() {
            return Ok(Vec::new());
        }
        if sentences.len() == 1 {
            return Ok(vec![SemanticChunk {
                content: text.to_string(),
                start_sentence: 0,
                end_sentence: 1,
                sentence_count: 1,
            }]);
        }

        let embeddings = self.embed_sentences(&sentences)?;
        let distances = self.calculate_similarity_differences(&embeddings);
        let breakpoints = self.determine_breakpoints(&distances)?;
        Ok(self.create_chunks(&sentences, &breakpoints))
    }

    /// Splits `text` into trimmed sentences.
    ///
    /// A sentence ends at `.`, `!` or `?`, or at a blank line (paragraph
    /// break). Text that continues across a single line break is joined with
    /// one space. A run of terminators such as `...` or `?!` stays attached to
    /// the sentence it closes instead of becoming a sentence of its own.
    ///
    /// This is a purely punctuation-based heuristic: abbreviations and decimal
    /// numbers are split at their dots.
    fn split_sentences(&self, text: &str) -> Vec<String> {
        let terminators: &[char] = &['.', '!', '?'];
        let mut sentences: Vec<String> = Vec::new();
        let mut current = String::new();

        for line in text.lines() {
            let line = line.trim();
            if line.is_empty() {
                // Paragraph break: flush whatever is pending, without the
                // separator space appended after the last fragment.
                let pending = current.trim();
                if !pending.is_empty() {
                    sentences.push(pending.to_string());
                }
                current.clear();
                continue;
            }

            // Set once a sentence has been completed on this line, so that
            // stray terminators are only glued to a sentence from the same line.
            let mut closed_on_line = false;
            for part in line.split_inclusive(terminators) {
                let part = part.trim();
                if part.is_empty() {
                    continue;
                }

                let only_terminators = part.chars().all(|c| terminators.contains(&c));
                if closed_on_line && current.is_empty() && only_terminators {
                    if let Some(previous) = sentences.last_mut() {
                        previous.push_str(part);
                        continue;
                    }
                }

                current.push_str(part);
                current.push(' ');
                if part.ends_with(terminators) {
                    sentences.push(current.trim().to_string());
                    current.clear();
                    closed_on_line = true;
                }
            }
        }

        let tail = current.trim();
        if !tail.is_empty() {
            sentences.push(tail.to_string());
        }
        sentences
    }

    /// Embeds every sentence, preserving order.
    fn embed_sentences(&mut self, sentences: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut embeddings = Vec::with_capacity(sentences.len());
        for sentence in sentences {
            embeddings.push(self.embedding_generator.generate_embedding(sentence));
        }
        Ok(embeddings)
    }

    /// Returns the cosine distance (`1 - similarity`) between each embedding
    /// and the one `buffer_size` positions after it.
    ///
    /// A `buffer_size` of zero would compare every embedding with itself, so
    /// it is treated as one. The result is empty when there are not enough
    /// embeddings to form a pair.
    fn calculate_similarity_differences(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
        let step = self.config.buffer_size.max(1);
        embeddings
            .iter()
            .zip(embeddings.iter().skip(step))
            .map(|(left, right)| 1.0 - self.cosine_similarity(left, right))
            .collect()
    }

    /// Cosine similarity of `a` and `b`, clamped to `[-1, 1]`.
    ///
    /// Returns `0.0` when the lengths differ or when either vector has zero
    /// magnitude.
    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let (dot, sq_a, sq_b) = a.iter().zip(b.iter()).fold(
            (0.0_f32, 0.0_f32, 0.0_f32),
            |(dot, sq_a, sq_b), (&x, &y)| (dot + x * y, sq_a + x * x, sq_b + y * y),
        );
        let mag_a = sq_a.sqrt();
        let mag_b = sq_b.sqrt();
        if mag_a == 0.0 || mag_b == 0.0 {
            return 0.0;
        }

        // Rounding can push the ratio marginally outside the valid range,
        // which would make the derived distance negative.
        (dot / (mag_a * mag_b)).clamp(-1.0, 1.0)
    }

    /// Returns the sentence indices at which a new chunk starts.
    ///
    /// `diffs[i]` is the distance measured starting at sentence `i`; a value
    /// strictly greater than the strategy's threshold yields breakpoint
    /// `i + 1`.
    fn determine_breakpoints(&self, diffs: &[f32]) -> Result<Vec<usize>> {
        if diffs.is_empty() {
            return Ok(Vec::new());
        }

        let threshold = match self.config.breakpoint_strategy {
            BreakpointStrategy::Percentile => self.calculate_percentile_threshold(diffs),
            BreakpointStrategy::StandardDeviation => self.calculate_std_threshold(diffs),
            BreakpointStrategy::Absolute => self.config.threshold_amount,
        };

        let breakpoints = diffs
            .iter()
            .enumerate()
            .filter_map(|(i, &diff)| if diff > threshold { Some(i + 1) } else { None })
            .collect();
        Ok(breakpoints)
    }

    /// Returns the `threshold_amount`-th percentile of `diffs`, linearly
    /// interpolated between the two nearest ranks.
    ///
    /// Interpolation matters because breakpoints require a distance strictly
    /// above the threshold: picking an existing element makes a high
    /// percentile of a short list equal to its maximum, which nothing can
    /// exceed. The percentile is clamped to `0..=100`. Empty input yields
    /// positive infinity, i.e. a threshold no distance exceeds.
    fn calculate_percentile_threshold(&self, diffs: &[f32]) -> f32 {
        if diffs.is_empty() {
            return f32::INFINITY;
        }

        let mut sorted = diffs.to_vec();
        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

        let last = sorted.len() - 1;
        let fraction = (self.config.threshold_amount / 100.0).clamp(0.0, 1.0);
        let rank = fraction * last as f32;
        // `min` guards against float rounding for very long inputs.
        let lower = (rank.floor() as usize).min(last);
        let upper = (lower + 1).min(last);
        let weight = rank - lower as f32;

        sorted[lower] + (sorted[upper] - sorted[lower]) * weight
    }

    /// Returns the mean of `diffs` plus `threshold_amount` population standard
    /// deviations. Empty input yields positive infinity.
    fn calculate_std_threshold(&self, diffs: &[f32]) -> f32 {
        if diffs.is_empty() {
            return f32::INFINITY;
        }

        let count = diffs.len() as f32;
        let mean = diffs.iter().sum::<f32>() / count;
        let variance = diffs.iter().map(|&x| (x - mean).powi(2)).sum::<f32>() / count;
        mean + self.config.threshold_amount * variance.sqrt()
    }

    /// Builds chunks from `sentences`, cutting at `breakpoints` (indices at
    /// which a new chunk starts).
    ///
    /// * A segment shorter than `min_chunk_size` is not closed at its
    ///   breakpoint; it extends to the next one.
    /// * The final segment is never dropped: when it is too short it is merged
    ///   into the preceding chunk if that respects `max_chunk_size`, otherwise
    ///   it is emitted as is.
    /// * A segment longer than a non-zero `max_chunk_size` is split into
    ///   pieces of at most that many sentences.
    ///
    /// Every sentence therefore ends up in exactly one chunk.
    fn create_chunks(&self, sentences: &[String], breakpoints: &[usize]) -> Vec<SemanticChunk> {
        let total = sentences.len();
        let min_size = self.config.min_chunk_size;
        let max_size = self.config.max_chunk_size;
        let mut chunks: Vec<SemanticChunk> = Vec::new();
        let mut start_idx = 0;

        for raw_end in breakpoints.iter().copied().chain(std::iter::once(total)) {
            // Never slice past the end, whatever the caller passed in.
            let end_idx = raw_end.min(total);
            if end_idx <= start_idx {
                continue;
            }
            let sentence_count = end_idx - start_idx;

            if sentence_count < min_size {
                if end_idx < total {
                    // Too short: keep the segment open until the next breakpoint.
                    continue;
                }
                // Nothing follows the tail, so fold it into the previous chunk
                // when the size limit allows it.
                if let Some(previous) = chunks.last_mut() {
                    let merged_count = previous.sentence_count + sentence_count;
                    if max_size == 0 || merged_count <= max_size {
                        let merged_start = previous.start_sentence;
                        *previous = Self::build_chunk(sentences, merged_start, end_idx);
                        start_idx = end_idx;
                        continue;
                    }
                }
            }

            if max_size > 0 && sentence_count > max_size {
                for sub_start in (start_idx..end_idx).step_by(max_size) {
                    let sub_end = (sub_start + max_size).min(end_idx);
                    chunks.push(Self::build_chunk(sentences, sub_start, sub_end));
                }
            } else {
                chunks.push(Self::build_chunk(sentences, start_idx, end_idx));
            }
            start_idx = end_idx;
        }

        chunks
    }

    /// Joins `sentences[start..end]` with single spaces into one chunk.
    fn build_chunk(sentences: &[String], start: usize, end: usize) -> SemanticChunk {
        SemanticChunk {
            content: sentences[start..end].join(" "),
            start_sentence: start,
            end_sentence: end,
            sentence_count: end - start,
        }
    }

    /// Returns the configuration this chunker was created with.
    pub fn config(&self) -> &SemanticChunkerConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chunker for `config` around a fresh embedding generator.
    fn chunker_with(config: SemanticChunkerConfig) -> SemanticChunker {
        SemanticChunker::new(config, EmbeddingGenerator::new(384))
    }

    /// Each of the three terminators ends a sentence.
    #[test]
    fn test_sentence_splitting() {
        let chunker = chunker_with(SemanticChunkerConfig::default());
        let sentences = chunker
            .split_sentences("This is sentence one. This is sentence two! Is this sentence three?");

        assert_eq!(sentences.len(), 3);
        assert!(sentences[0].contains("sentence one"));
        assert!(sentences[1].contains("sentence two"));
        assert!(sentences[2].contains("sentence three"));
    }

    /// Identical, orthogonal and opposite vectors give 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let chunker = chunker_with(SemanticChunkerConfig::default());
        let check = |a: &[f32], b: &[f32], expected: f32| {
            let sim = chunker.cosine_similarity(a, b);
            assert!((sim - expected).abs() < 0.001);
        };

        check(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0], 1.0);
        check(&[1.0, 0.0], &[0.0, 1.0], 0.0);
        check(&[1.0, 0.0], &[-1.0, 0.0], -1.0);
    }

    /// The 95th percentile of ten evenly spaced values sits at or above 0.9.
    #[test]
    fn test_percentile_threshold() {
        let chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 95.0,
            ..Default::default()
        });

        let diffs = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];
        assert!(chunker.calculate_percentile_threshold(&diffs) >= 0.9);
    }

    /// With zero spread the threshold collapses to the mean.
    #[test]
    fn test_std_threshold() {
        let chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::StandardDeviation,
            threshold_amount: 3.0,
            ..Default::default()
        });

        let diffs = vec![0.5, 0.5, 0.5, 0.5, 0.5];
        let threshold = chunker.calculate_std_threshold(&diffs);
        assert!((threshold - 0.5).abs() < 0.001);
    }

    /// End-to-end run: chunking a short text yields non-empty chunks.
    #[test]
    fn test_semantic_chunking_basic() {
        let mut chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 50.0,
            min_chunk_size: 1,
            max_chunk_size: 0,
            buffer_size: 1,
        });

        let text = "Alice loves programming. Bob also codes daily. \
                    The weather is sunny. Rain is expected tomorrow.";
        let chunks = chunker.chunk(text).unwrap();

        assert!(!chunks.is_empty());
        for chunk in chunks.iter() {
            assert!(!chunk.content.is_empty());
            assert!(chunk.sentence_count > 0);
        }
    }
}