//! Semantic Chunking for RAG
//!
//! This module implements semantic chunking that splits text based on
//! semantic similarity rather than fixed character/token counts.
//!
//! Key innovation: Uses sentence embeddings and cosine similarity to
//! determine natural breakpoints, creating semantically cohesive chunks.
//!
//! Reference: LangChain SemanticChunker, Greg Kamradt's 5 Levels of Text Splitting

use crate::core::Result;
use crate::vector::EmbeddingGenerator;
/// Chunk of semantically similar sentences produced by [`SemanticChunker`].
///
/// `sentence_count` always equals `end_sentence - start_sentence`.
// PartialEq/Eq derived so chunks can be compared in tests and deduplicated;
// all fields (String, usize) support full equality.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SemanticChunk {
    /// The text content of the chunk
    pub content: String,

    /// Start sentence index
    pub start_sentence: usize,

    /// End sentence index (exclusive)
    pub end_sentence: usize,

    /// Number of sentences in this chunk
    pub sentence_count: usize,
}

/// Strategy for determining chunk breakpoints.
// Eq and Hash derived (fieldless enum) so the strategy can be used as a
// map key or in set-based configuration comparisons.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BreakpointStrategy {
    /// Use percentile of similarity differences (e.g., 95th percentile)
    Percentile,

    /// Use standard deviation of similarity differences (e.g., 3σ)
    StandardDeviation,

    /// Use absolute threshold (e.g., similarity < 0.5)
    Absolute,
}

43/// Configuration for semantic chunking
44#[derive(Debug, Clone)]
45pub struct SemanticChunkerConfig {
46    /// Strategy for determining breakpoints
47    pub breakpoint_strategy: BreakpointStrategy,
48
49    /// Threshold amount:
50    /// - Percentile: 0-100 (default: 95.0)
51    /// - StandardDeviation: number of std devs (default: 3.0)
52    /// - Absolute: similarity threshold (default: 0.5)
53    pub threshold_amount: f32,
54
55    /// Minimum chunk size in sentences
56    pub min_chunk_size: usize,
57
58    /// Maximum chunk size in sentences (0 = unlimited)
59    pub max_chunk_size: usize,
60
61    /// Buffer size for comparing sentences (default: 1 = compare consecutive)
62    pub buffer_size: usize,
63}
64
65impl Default for SemanticChunkerConfig {
66    fn default() -> Self {
67        Self {
68            breakpoint_strategy: BreakpointStrategy::Percentile,
69            threshold_amount: 95.0,
70            min_chunk_size: 1,
71            max_chunk_size: 0, // unlimited
72            buffer_size: 1,
73        }
74    }
75}
76
/// Semantic text chunker that splits based on embedding similarity
pub struct SemanticChunker {
    // Breakpoint strategy, thresholds, and chunk-size constraints
    config: SemanticChunkerConfig,
    // Produces per-sentence embeddings; `chunk` takes `&mut self` because
    // embedding generation goes through `&mut` access to this field
    embedding_generator: EmbeddingGenerator,
}

83impl SemanticChunker {
84    /// Create a new semantic chunker
85    pub fn new(config: SemanticChunkerConfig, embedding_generator: EmbeddingGenerator) -> Self {
86        Self {
87            config,
88            embedding_generator,
89        }
90    }
91
92    /// Split text into semantic chunks
93    pub fn chunk(&mut self, text: &str) -> Result<Vec<SemanticChunk>> {
94        // 1. Split into sentences
95        let sentences = self.split_sentences(text);
96
97        if sentences.is_empty() {
98            return Ok(Vec::new());
99        }
100
101        if sentences.len() == 1 {
102            return Ok(vec![SemanticChunk {
103                content: text.to_string(),
104                start_sentence: 0,
105                end_sentence: 1,
106                sentence_count: 1,
107            }]);
108        }
109
110        // 2. Generate embeddings for each sentence
111        let embeddings = self.embed_sentences(&sentences)?;
112
113        // 3. Calculate similarity differences between consecutive sentences
114        let similarity_diffs = self.calculate_similarity_differences(&embeddings);
115
116        // 4. Determine breakpoints based on strategy
117        let breakpoints = self.determine_breakpoints(&similarity_diffs)?;
118
119        // 5. Create chunks from sentences using breakpoints
120        let chunks = self.create_chunks(&sentences, &breakpoints);
121
122        Ok(chunks)
123    }
124
125    /// Split text into sentences using simple sentence tokenization
126    fn split_sentences(&self, text: &str) -> Vec<String> {
127        let mut sentences = Vec::new();
128        let mut current_sentence = String::new();
129
130        for line in text.lines() {
131            let line = line.trim();
132            if line.is_empty() {
133                if !current_sentence.is_empty() {
134                    sentences.push(current_sentence.clone());
135                    current_sentence.clear();
136                }
137                continue;
138            }
139
140            // Split on sentence boundaries: . ! ?
141            for part in line.split_inclusive(&['.', '!', '?']) {
142                let part = part.trim();
143                if part.is_empty() {
144                    continue;
145                }
146
147                current_sentence.push_str(part);
148                current_sentence.push(' ');
149
150                // Check if this looks like end of sentence
151                if part.ends_with('.') || part.ends_with('!') || part.ends_with('?') {
152                    sentences.push(current_sentence.trim().to_string());
153                    current_sentence.clear();
154                }
155            }
156        }
157
158        // Add any remaining text
159        if !current_sentence.trim().is_empty() {
160            sentences.push(current_sentence.trim().to_string());
161        }
162
163        sentences
164    }
165
166    /// Generate embeddings for all sentences
167    fn embed_sentences(&mut self, sentences: &[String]) -> Result<Vec<Vec<f32>>> {
168        let mut embeddings = Vec::new();
169
170        for sentence in sentences {
171            let embedding = self.embedding_generator.generate_embedding(sentence);
172            embeddings.push(embedding);
173        }
174
175        Ok(embeddings)
176    }
177
178    /// Calculate cosine similarity differences between consecutive sentences
179    fn calculate_similarity_differences(&self, embeddings: &[Vec<f32>]) -> Vec<f32> {
180        let mut diffs = Vec::new();
181
182        for i in 0..embeddings.len().saturating_sub(self.config.buffer_size) {
183            let sim = self.cosine_similarity(&embeddings[i], &embeddings[i + self.config.buffer_size]);
184
185            // Convert similarity to difference (distance)
186            // Higher distance = more dissimilar = potential breakpoint
187            let distance = 1.0 - sim;
188            diffs.push(distance);
189        }
190
191        diffs
192    }
193
194    /// Calculate cosine similarity between two vectors
195    fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
196        if a.len() != b.len() {
197            return 0.0;
198        }
199
200        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
201        let mag_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
202        let mag_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
203
204        if mag_a == 0.0 || mag_b == 0.0 {
205            return 0.0;
206        }
207
208        dot / (mag_a * mag_b)
209    }
210
211    /// Determine chunk breakpoints based on similarity differences
212    fn determine_breakpoints(&self, diffs: &[f32]) -> Result<Vec<usize>> {
213        if diffs.is_empty() {
214            return Ok(Vec::new());
215        }
216
217        let threshold = match self.config.breakpoint_strategy {
218            BreakpointStrategy::Percentile => self.calculate_percentile_threshold(diffs),
219            BreakpointStrategy::StandardDeviation => self.calculate_std_threshold(diffs),
220            BreakpointStrategy::Absolute => self.config.threshold_amount,
221        };
222
223        // Find indices where difference exceeds threshold
224        let mut breakpoints = Vec::new();
225        for (i, &diff) in diffs.iter().enumerate() {
226            if diff > threshold {
227                // +1 because diff[i] is between sentence[i] and sentence[i+1]
228                breakpoints.push(i + 1);
229            }
230        }
231
232        Ok(breakpoints)
233    }
234
235    /// Calculate threshold based on percentile
236    fn calculate_percentile_threshold(&self, diffs: &[f32]) -> f32 {
237        let mut sorted = diffs.to_vec();
238        sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
239
240        let percentile = self.config.threshold_amount / 100.0;
241        let index = ((sorted.len() as f32 * percentile) as usize).min(sorted.len() - 1);
242
243        sorted[index]
244    }
245
246    /// Calculate threshold based on standard deviation
247    fn calculate_std_threshold(&self, diffs: &[f32]) -> f32 {
248        let mean: f32 = diffs.iter().sum::<f32>() / diffs.len() as f32;
249
250        let variance: f32 = diffs.iter()
251            .map(|&x| (x - mean).powi(2))
252            .sum::<f32>() / diffs.len() as f32;
253
254        let std_dev = variance.sqrt();
255
256        mean + (self.config.threshold_amount * std_dev)
257    }
258
259    /// Create chunks from sentences using breakpoints
260    fn create_chunks(&self, sentences: &[String], breakpoints: &[usize]) -> Vec<SemanticChunk> {
261        let mut chunks = Vec::new();
262        let mut start_idx = 0;
263
264        let mut all_breakpoints = breakpoints.to_vec();
265        all_breakpoints.push(sentences.len()); // Add final breakpoint
266
267        for &end_idx in &all_breakpoints {
268            if end_idx <= start_idx {
269                continue;
270            }
271
272            let sentence_count = end_idx - start_idx;
273
274            // Check size constraints
275            if sentence_count < self.config.min_chunk_size {
276                continue;
277            }
278
279            if self.config.max_chunk_size > 0 && sentence_count > self.config.max_chunk_size {
280                // Split large chunk into smaller ones
281                let mut sub_start = start_idx;
282                while sub_start < end_idx {
283                    let sub_end = (sub_start + self.config.max_chunk_size).min(end_idx);
284                    let content = sentences[sub_start..sub_end].join(" ");
285
286                    chunks.push(SemanticChunk {
287                        content,
288                        start_sentence: sub_start,
289                        end_sentence: sub_end,
290                        sentence_count: sub_end - sub_start,
291                    });
292
293                    sub_start = sub_end;
294                }
295            } else {
296                let content = sentences[start_idx..end_idx].join(" ");
297
298                chunks.push(SemanticChunk {
299                    content,
300                    start_sentence: start_idx,
301                    end_sentence: end_idx,
302                    sentence_count,
303                });
304            }
305
306            start_idx = end_idx;
307        }
308
309        chunks
310    }
311
312    /// Get configuration
313    pub fn config(&self) -> &SemanticChunkerConfig {
314        &self.config
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a chunker with the given config and a 384-dim test embedder.
    fn make_chunker(config: SemanticChunkerConfig) -> SemanticChunker {
        SemanticChunker::new(config, EmbeddingGenerator::new(384))
    }

    #[test]
    fn test_sentence_splitting() {
        let chunker = make_chunker(SemanticChunkerConfig::default());

        let sentences = chunker
            .split_sentences("This is sentence one. This is sentence two! Is this sentence three?");

        assert_eq!(sentences.len(), 3);
        let expected = ["sentence one", "sentence two", "sentence three"];
        for (sentence, needle) in sentences.iter().zip(expected.iter().copied()) {
            assert!(sentence.contains(needle));
        }
    }

    #[test]
    fn test_cosine_similarity() {
        let chunker = make_chunker(SemanticChunkerConfig::default());

        // (a, b, expected similarity)
        let cases: &[(&[f32], &[f32], f32)] = &[
            (&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0], 1.0), // identical
            (&[1.0, 0.0], &[0.0, 1.0], 0.0),           // orthogonal
            (&[1.0, 0.0], &[-1.0, 0.0], -1.0),         // opposite
        ];

        for &(a, b, expected) in cases {
            let sim = chunker.cosine_similarity(a, b);
            assert!((sim - expected).abs() < 0.001);
        }
    }

    #[test]
    fn test_percentile_threshold() {
        let chunker = make_chunker(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 95.0,
            ..Default::default()
        });

        let diffs: Vec<f32> = (1..=10).map(|n| n as f32 / 10.0).collect();
        let threshold = chunker.calculate_percentile_threshold(&diffs);

        // 95th percentile of 10 evenly spaced values lands near the top
        assert!(threshold >= 0.9);
    }

    #[test]
    fn test_std_threshold() {
        let chunker = make_chunker(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::StandardDeviation,
            threshold_amount: 3.0,
            ..Default::default()
        });

        // Identical values -> zero std dev -> threshold collapses to the mean
        let threshold = chunker.calculate_std_threshold(&[0.5; 5]);
        assert!((threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_semantic_chunking_basic() {
        let mut chunker = make_chunker(SemanticChunkerConfig {
            breakpoint_strategy: BreakpointStrategy::Percentile,
            threshold_amount: 50.0, // lower threshold so breakpoints are likely
            min_chunk_size: 1,
            max_chunk_size: 0,
            buffer_size: 1,
        });

        let text = "Alice loves programming. Bob also codes daily. \
                    The weather is sunny. Rain is expected tomorrow.";

        let chunks = chunker.chunk(text).unwrap();

        // At least one chunk, and every chunk carries text and sentences
        assert!(!chunks.is_empty());
        assert!(chunks
            .iter()
            .all(|c| !c.content.is_empty() && c.sentence_count > 0));
    }
}