1use crate::core::Result;
12use crate::vector::EmbeddingGenerator;
13
/// A contiguous run of sentences produced by the semantic chunker.
///
/// Sentence positions are zero-based indices into the sentence list the
/// chunker derived from the input text.
#[derive(Clone, Debug)]
pub struct SemanticChunk {
    /// Text covered by this chunk.
    pub content: String,

    /// Index of the first sentence in this chunk (inclusive).
    pub start_sentence: usize,

    /// Index one past the last sentence in this chunk (exclusive).
    pub end_sentence: usize,

    /// Number of sentences in this chunk (`end_sentence - start_sentence`).
    pub sentence_count: usize,
}
29
/// How the distance threshold that marks a chunk boundary is derived.
///
/// In every case a boundary is placed where the distance between compared
/// sentences is strictly greater than the resulting threshold.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum BreakpointStrategy {
    /// Threshold is the `threshold_amount`-th percentile (0-100) of the
    /// observed distances.
    Percentile,

    /// Threshold is the mean of the observed distances plus
    /// `threshold_amount` standard deviations.
    StandardDeviation,

    /// Threshold is `threshold_amount` itself, used as a raw distance.
    Absolute,
}
42
43#[derive(Debug, Clone)]
45pub struct SemanticChunkerConfig {
46 pub breakpoint_strategy: BreakpointStrategy,
48
49 pub threshold_amount: f32,
54
55 pub min_chunk_size: usize,
57
58 pub max_chunk_size: usize,
60
61 pub buffer_size: usize,
63}
64
65impl Default for SemanticChunkerConfig {
66 fn default() -> Self {
67 Self {
68 breakpoint_strategy: BreakpointStrategy::Percentile,
69 threshold_amount: 95.0,
70 min_chunk_size: 1,
71 max_chunk_size: 0, buffer_size: 1,
73 }
74 }
75}
76
77pub struct SemanticChunker {
79 config: SemanticChunkerConfig,
80 embedding_generator: EmbeddingGenerator,
81}
82
83impl SemanticChunker {
84 pub fn new(config: SemanticChunkerConfig, embedding_generator: EmbeddingGenerator) -> Self {
86 Self {
87 config,
88 embedding_generator,
89 }
90 }
91
92 pub fn chunk(&mut self, text: &str) -> Result<Vec<SemanticChunk>> {
94 let sentences = self.split_sentences(text);
96
97 if sentences.is_empty() {
98 return Ok(Vec::new());
99 }
100
101 if sentences.len() == 1 {
102 return Ok(vec![SemanticChunk {
103 content: text.to_string(),
104 start_sentence: 0,
105 end_sentence: 1,
106 sentence_count: 1,
107 }]);
108 }
109
110 let embeddings = self.embed_sentences(&sentences)?;
112
113 let similarity_diffs = self.calculate_similarity_differences(&embeddings);
115
116 let breakpoints = self.determine_breakpoints(&similarity_diffs)?;
118
119 let chunks = self.create_chunks(&sentences, &breakpoints);
121
122 Ok(chunks)
123 }
124
125 fn split_sentences(&self, text: &str) -> Vec<String> {
127 let mut sentences = Vec::new();
128 let mut current_sentence = String::new();
129
130 for line in text.lines() {
131 let line = line.trim();
132 if line.is_empty() {
133 if !current_sentence.is_empty() {
134 sentences.push(current_sentence.clone());
135 current_sentence.clear();
136 }
137 continue;
138 }
139
140 for part in line.split_inclusive(&['.', '!', '?']) {
142 let part = part.trim();
143 if part.is_empty() {
144 continue;
145 }
146
147 current_sentence.push_str(part);
148 current_sentence.push(' ');
149
150 if part.ends_with('.') || part.ends_with('!') || part.ends_with('?') {
152 sentences.push(current_sentence.trim().to_string());
153 current_sentence.clear();
154 }
155 }
156 }
157
158 if !current_sentence.trim().is_empty() {
160 sentences.push(current_sentence.trim().to_string());
161 }
162
163 sentences
164 }
165
166 fn embed_sentences(&mut self, sentences: &[String]) -> Result<Vec<Vec<f32>>> {
168 let mut embeddings = Vec::new();
169
170 for sentence in sentences {
171 let embedding = self.embedding_generator.generate_embedding(sentence);
172 embeddings.push(embedding);
173 }
174
175 Ok(embeddings)
176 }
177
178 fn calculate_similarity_differences(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
180 let mut diffs = Vec::new();
181
182 for i in 0..embeddings.len().saturating_sub(self.config.buffer_size) {
183 let sim = self.cosine_similarity(&embeddings[i], &embeddings[i + self.config.buffer_size]);
184
185 let distance = 1.0 - sim;
188 diffs.push(distance);
189 }
190
191 diffs
192 }
193
194 fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
196 if a.len() != b.len() {
197 return 0.0;
198 }
199
200 let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
201 let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
202 let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
203
204 if mag_a == 0.0 || mag_b == 0.0 {
205 return 0.0;
206 }
207
208 dot / (mag_a * mag_b)
209 }
210
211 fn determine_breakpoints(&self, diffs: &[f32]) -> Result<Vec<usize>> {
213 if diffs.is_empty() {
214 return Ok(Vec::new());
215 }
216
217 let threshold = match self.config.breakpoint_strategy {
218 BreakpointStrategy::Percentile => self.calculate_percentile_threshold(diffs),
219 BreakpointStrategy::StandardDeviation => self.calculate_std_threshold(diffs),
220 BreakpointStrategy::Absolute => self.config.threshold_amount,
221 };
222
223 let mut breakpoints = Vec::new();
225 for (i, &diff) in diffs.iter().enumerate() {
226 if diff > threshold {
227 breakpoints.push(i + 1);
229 }
230 }
231
232 Ok(breakpoints)
233 }
234
235 fn calculate_percentile_threshold(&self, diffs: &[f32]) -> f32 {
237 let mut sorted = diffs.to_vec();
238 sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240 let percentile = self.config.threshold_amount / 100.0;
241 let index = ((sorted.len() as f32 * percentile) as usize).min(sorted.len() - 1);
242
243 sorted[index]
244 }
245
246 fn calculate_std_threshold(&self, diffs: &[f32]) -> f32 {
248 let mean: f32 = diffs.iter().sum::<f32>() / diffs.len() as f32;
249
250 let variance: f32 = diffs.iter()
251 .map(|&x| (x - mean).powi(2))
252 .sum::<f32>() / diffs.len() as f32;
253
254 let std_dev = variance.sqrt();
255
256 mean + (self.config.threshold_amount * std_dev)
257 }
258
259 fn create_chunks(&self, sentences: &[String], breakpoints: &[usize]) -> Vec<SemanticChunk> {
261 let mut chunks = Vec::new();
262 let mut start_idx = 0;
263
264 let mut all_breakpoints = breakpoints.to_vec();
265 all_breakpoints.push(sentences.len()); for &end_idx in &all_breakpoints {
268 if end_idx <= start_idx {
269 continue;
270 }
271
272 let sentence_count = end_idx - start_idx;
273
274 if sentence_count < self.config.min_chunk_size {
276 continue;
277 }
278
279 if self.config.max_chunk_size > 0 && sentence_count > self.config.max_chunk_size {
280 let mut sub_start = start_idx;
282 while sub_start < end_idx {
283 let sub_end = (sub_start + self.config.max_chunk_size).min(end_idx);
284 let content = sentences[sub_start..sub_end].join(" ");
285
286 chunks.push(SemanticChunk {
287 content,
288 start_sentence: sub_start,
289 end_sentence: sub_end,
290 sentence_count: sub_end - sub_start,
291 });
292
293 sub_start = sub_end;
294 }
295 } else {
296 let content = sentences[start_idx..end_idx].join(" ");
297
298 chunks.push(SemanticChunk {
299 content,
300 start_sentence: start_idx,
301 end_sentence: end_idx,
302 sentence_count,
303 });
304 }
305
306 start_idx = end_idx;
307 }
308
309 chunks
310 }
311
312 pub fn config(&self) -> &SemanticChunkerConfig {
314 &self.config
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a chunker with the given configuration and a fresh embedding
    /// generator constructed the same way in every test.
    fn chunker_with(config: SemanticChunkerConfig) -> SemanticChunker {
        SemanticChunker::new(config, EmbeddingGenerator::new(384))
    }

    /// Three differently terminated sentences are split into three entries.
    #[test]
    fn test_sentence_splitting() {
        let chunker = chunker_with(SemanticChunkerConfig::default());

        let sentences = chunker
            .split_sentences("This is sentence one. This is sentence two! Is this sentence three?");

        assert_eq!(sentences.len(), 3);
        assert!(sentences[0].contains("sentence one"));
        assert!(sentences[1].contains("sentence two"));
        assert!(sentences[2].contains("sentence three"));
    }

    /// Identical, orthogonal and opposite vectors give 1, 0 and -1.
    #[test]
    fn test_cosine_similarity() {
        let chunker = chunker_with(SemanticChunkerConfig::default());
        let similarity = |a: &[f32], b: &[f32]| chunker.cosine_similarity(a, b);

        // Same direction.
        let identical = similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((identical - 1.0).abs() < 0.001);

        // Perpendicular.
        let orthogonal = similarity(&[1.0, 0.0], &[0.0, 1.0]);
        assert!(orthogonal.abs() < 0.001);

        // Opposite direction.
        let opposite = similarity(&[1.0, 0.0], &[-1.0, 0.0]);
        assert!((opposite + 1.0).abs() < 0.001);
    }

    /// The 95th percentile of ten evenly spaced values is at least 0.9.
    #[test]
    fn test_percentile_threshold() {
        let chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 95.0,
            ..Default::default()
        });

        let diffs = vec![0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0];

        assert!(chunker.calculate_percentile_threshold(&diffs) >= 0.9);
    }

    /// With zero spread the threshold collapses to the mean.
    #[test]
    fn test_std_threshold() {
        let chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::StandardDeviation,
            threshold_amount: 3.0,
            ..Default::default()
        });

        let diffs = vec![0.5, 0.5, 0.5, 0.5, 0.5];
        let threshold = chunker.calculate_std_threshold(&diffs);

        assert!((threshold - 0.5).abs() < 0.001);
    }

    /// End-to-end smoke test: chunking yields non-empty, non-degenerate chunks.
    #[test]
    fn test_semantic_chunking_basic() {
        let mut chunker = chunker_with(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 50.0,
            min_chunk_size: 1,
            max_chunk_size: 0,
            buffer_size: 1,
        });

        let text = "Alice loves programming. Bob also codes daily. \
                    The weather is sunny. Rain is expected tomorrow.";

        let chunks = chunker.chunk(text).unwrap();

        assert!(!chunks.is_empty());

        for chunk in chunks.iter() {
            assert!(!chunk.content.is_empty());
            assert!(chunk.sentence_count > 0);
        }
    }
}