// graphrag_core/text/semantic_coherence.rs

//! Semantic Coherence Scoring for Boundary-Aware Chunking
//!
//! This module implements semantic coherence analysis using sentence embeddings
//! to optimize chunk boundaries for maximum semantic unity.
//!
//! Key capabilities:
//! - Cosine similarity calculation between sentence embeddings
//! - Intra-chunk coherence scoring
//! - Optimal split-point detection via greedy search
//! - Adaptive threshold based on embedding distances
//!
//! ## References
//!
//! - BAR-RAG Paper: "Boundary-Aware Retrieval-Augmented Generation"
//! - Target: +40% semantic coherence improvement

use crate::core::error::{GraphRAGError, Result};
use crate::embeddings::EmbeddingProvider;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Configuration for semantic coherence scoring.
///
/// See [`Default`] for the baseline values used when no explicit
/// configuration is supplied.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceConfig {
    /// Minimum coherence score threshold (0.0-1.0)
    pub min_coherence_threshold: f32,

    /// Maximum sentences per chunk for coherence analysis
    pub max_sentences_per_chunk: usize,

    /// Minimum sentences per chunk
    pub min_sentences_per_chunk: usize,

    /// Window size for local coherence calculation
    pub coherence_window_size: usize,

    /// Weight for adjacent sentence similarity (vs all pairs)
    pub adjacency_weight: f32,

    /// Enable adaptive threshold based on content
    pub adaptive_threshold: bool,

    /// Batch size for embedding generation
    // NOTE(review): not read anywhere in this module — confirm it is
    // consumed by the embedding provider or callers elsewhere.
    pub embedding_batch_size: usize,
}

47impl Default for CoherenceConfig {
48    fn default() -> Self {
49        Self {
50            min_coherence_threshold: 0.65,
51            max_sentences_per_chunk: 20,
52            min_sentences_per_chunk: 2,
53            coherence_window_size: 3,
54            adjacency_weight: 0.7,
55            adaptive_threshold: true,
56            embedding_batch_size: 32,
57        }
58    }
59}
60
/// Represents a candidate chunk with coherence score.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// Text content
    pub text: String,

    /// Start position in original text (byte offset, inclusive)
    pub start_pos: usize,

    /// End position in original text (byte offset, exclusive)
    pub end_pos: usize,

    /// Coherence score (0.0-1.0, higher = more coherent)
    pub coherence_score: f32,

    /// Number of sentences in chunk
    pub sentence_count: usize,

    /// Average embedding similarity
    // Currently mirrors `coherence_score` (see `create_chunks`); no
    // separate pairwise statistic is computed yet.
    pub avg_similarity: f32,
}

/// Result of split-point optimization.
#[derive(Debug, Clone)]
pub struct OptimalSplit {
    /// Split positions (byte offsets into the original text), sorted
    pub split_positions: Vec<usize>,

    /// Resulting chunks with scores, in document order
    pub chunks: Vec<ScoredChunk>,

    /// Overall coherence score (mean over `chunks`)
    pub overall_coherence: f32,

    /// Number of greedy iterations performed before convergence
    pub optimization_iterations: usize,
}

/// Semantic coherence scorer using sentence embeddings.
///
/// Holds the scoring configuration plus a shared handle to the embedding
/// provider used to embed individual sentences.
pub struct SemanticCoherenceScorer {
    config: CoherenceConfig,
    embedding_provider: Arc<dyn EmbeddingProvider>,
}

105impl SemanticCoherenceScorer {
106    /// Create a new semantic coherence scorer
107    pub fn new(config: CoherenceConfig, embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
108        Self {
109            config,
110            embedding_provider,
111        }
112    }
113
114    /// Score the semantic coherence of a text chunk
115    ///
116    /// Returns a score between 0.0 (incoherent) and 1.0 (highly coherent).
117    /// High coherence = high cosine similarity between sentence embeddings.
118    pub async fn score_chunk_coherence(&self, text: &str) -> Result<f32> {
119        // Split into sentences
120        let sentences = self.split_sentences(text);
121
122        if sentences.len() < 2 {
123            // Single sentence = perfect coherence
124            return Ok(1.0);
125        }
126
127        // Limit to max sentences for efficiency
128        let sentences: Vec<&str> = sentences
129            .iter()
130            .take(self.config.max_sentences_per_chunk)
131            .map(|s| s.as_str())
132            .collect();
133
134        // Generate embeddings for all sentences
135        let embeddings = self
136            .embedding_provider
137            .embed_batch(&sentences)
138            .await
139            .map_err(|e| GraphRAGError::Embedding {
140                message: e.to_string(),
141            })?;
142
143        if embeddings.len() != sentences.len() {
144            return Err(GraphRAGError::TextProcessing {
145                message: "Embedding count mismatch".to_string(),
146            });
147        }
148
149        // Calculate coherence score
150        let coherence = self.calculate_coherence(&embeddings);
151
152        Ok(coherence)
153    }
154
155    /// Calculate coherence from sentence embeddings
156    ///
157    /// Uses a combination of:
158    /// 1. Adjacent sentence similarity (weighted higher)
159    /// 2. All-pairs average similarity
160    fn calculate_coherence(&self, embeddings: &[Vec<f32>]) -> f32 {
161        if embeddings.len() < 2 {
162            return 1.0;
163        }
164
165        // Calculate adjacent sentence similarities
166        let mut adjacent_similarities = Vec::new();
167        for i in 0..embeddings.len() - 1 {
168            let sim = self.cosine_similarity(&embeddings[i], &embeddings[i + 1]);
169            adjacent_similarities.push(sim);
170        }
171
172        let adjacent_avg =
173            adjacent_similarities.iter().sum::<f32>() / adjacent_similarities.len() as f32;
174
175        // Calculate window-based similarities
176        let window_avg = if self.config.coherence_window_size > 1 {
177            let mut window_similarities = Vec::new();
178            for i in 0..embeddings.len() {
179                let window_start = i.saturating_sub(self.config.coherence_window_size / 2);
180                let window_end =
181                    (i + self.config.coherence_window_size / 2 + 1).min(embeddings.len());
182
183                for j in window_start..window_end {
184                    if i != j {
185                        let sim = self.cosine_similarity(&embeddings[i], &embeddings[j]);
186                        window_similarities.push(sim);
187                    }
188                }
189            }
190
191            if window_similarities.is_empty() {
192                adjacent_avg
193            } else {
194                window_similarities.iter().sum::<f32>() / window_similarities.len() as f32
195            }
196        } else {
197            adjacent_avg
198        };
199
200        // Weighted combination
201        let coherence = self.config.adjacency_weight * adjacent_avg
202            + (1.0 - self.config.adjacency_weight) * window_avg;
203
204        coherence.clamp(0.0, 1.0)
205    }
206
207    /// Find optimal split points in text to maximize chunk coherence
208    ///
209    /// Uses a greedy algorithm:
210    /// 1. Start with no splits
211    /// 2. Try all candidate split points
212    /// 3. Pick split that maximizes average chunk coherence
213    /// 4. Repeat until coherence stops improving
214    pub async fn find_optimal_split(
215        &self,
216        text: &str,
217        candidate_boundaries: &[usize],
218    ) -> Result<OptimalSplit> {
219        if candidate_boundaries.is_empty() {
220            // No boundaries = single chunk
221            let score = self.score_chunk_coherence(text).await?;
222            return Ok(OptimalSplit {
223                split_positions: vec![],
224                chunks: vec![ScoredChunk {
225                    text: text.to_string(),
226                    start_pos: 0,
227                    end_pos: text.len(),
228                    coherence_score: score,
229                    sentence_count: self.split_sentences(text).len(),
230                    avg_similarity: score,
231                }],
232                overall_coherence: score,
233                optimization_iterations: 1,
234            });
235        }
236
237        // Greedy split optimization
238        let mut current_splits: Vec<usize> = vec![];
239        let mut iterations = 0;
240        let max_iterations = 100;
241
242        loop {
243            iterations += 1;
244            if iterations > max_iterations {
245                break;
246            }
247
248            // Generate candidate chunks with current splits
249            let current_chunks = self.create_chunks(text, &current_splits).await?;
250            let current_score = current_chunks
251                .iter()
252                .map(|c| c.coherence_score)
253                .sum::<f32>()
254                / current_chunks.len() as f32;
255
256            // Try adding each candidate boundary
257            let mut best_new_split: Option<usize> = None;
258            let mut best_score = current_score;
259
260            for &boundary in candidate_boundaries {
261                if current_splits.contains(&boundary) {
262                    continue;
263                }
264
265                // Try this split
266                let mut test_splits = current_splits.clone();
267                test_splits.push(boundary);
268                test_splits.sort_unstable();
269
270                let test_chunks = self.create_chunks(text, &test_splits).await?;
271                let test_score = test_chunks.iter().map(|c| c.coherence_score).sum::<f32>()
272                    / test_chunks.len() as f32;
273
274                if test_score > best_score {
275                    best_score = test_score;
276                    best_new_split = Some(boundary);
277                }
278            }
279
280            // If no improvement, stop
281            if best_new_split.is_none() {
282                break;
283            }
284
285            // Add best split
286            current_splits.push(best_new_split.unwrap());
287            current_splits.sort_unstable();
288
289            // Check minimum chunk size constraint
290            if !self.validate_splits(text, &current_splits) {
291                current_splits.pop();
292                break;
293            }
294        }
295
296        // Generate final chunks
297        let final_chunks = self.create_chunks(text, &current_splits).await?;
298        let overall_coherence =
299            final_chunks.iter().map(|c| c.coherence_score).sum::<f32>() / final_chunks.len() as f32;
300
301        Ok(OptimalSplit {
302            split_positions: current_splits,
303            chunks: final_chunks,
304            overall_coherence,
305            optimization_iterations: iterations,
306        })
307    }
308
309    /// Create scored chunks from text and split positions
310    async fn create_chunks(&self, text: &str, splits: &[usize]) -> Result<Vec<ScoredChunk>> {
311        let mut chunks = Vec::new();
312        let mut boundaries = vec![0];
313        boundaries.extend_from_slice(splits);
314        boundaries.push(text.len());
315
316        for i in 0..boundaries.len() - 1 {
317            let start = boundaries[i];
318            let end = boundaries[i + 1];
319            let chunk_text = &text[start..end];
320
321            let coherence = self.score_chunk_coherence(chunk_text).await?;
322            let sentences = self.split_sentences(chunk_text);
323
324            chunks.push(ScoredChunk {
325                text: chunk_text.to_string(),
326                start_pos: start,
327                end_pos: end,
328                coherence_score: coherence,
329                sentence_count: sentences.len(),
330                avg_similarity: coherence,
331            });
332        }
333
334        Ok(chunks)
335    }
336
337    /// Validate that splits create chunks meeting minimum size requirements
338    fn validate_splits(&self, text: &str, splits: &[usize]) -> bool {
339        let mut boundaries = vec![0];
340        boundaries.extend_from_slice(splits);
341        boundaries.push(text.len());
342
343        for i in 0..boundaries.len() - 1 {
344            let start = boundaries[i];
345            let end = boundaries[i + 1];
346            let chunk_text = &text[start..end];
347            let sentences = self.split_sentences(chunk_text);
348
349            if sentences.len() < self.config.min_sentences_per_chunk {
350                return false;
351            }
352        }
353
354        true
355    }
356
357    /// Calculate cosine similarity between two embedding vectors
358    pub fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
359        if a.len() != b.len() || a.is_empty() {
360            return 0.0;
361        }
362
363        let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
364        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
365        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
366
367        if norm_a == 0.0 || norm_b == 0.0 {
368            return 0.0;
369        }
370
371        (dot_product / (norm_a * norm_b)).clamp(-1.0, 1.0)
372    }
373
374    /// Split text into sentences (simple implementation)
375    ///
376    /// This is a basic sentence splitter. For production, consider using
377    /// a dedicated NLP library or more sophisticated tokenization.
378    fn split_sentences(&self, text: &str) -> Vec<String> {
379        let mut sentences = Vec::new();
380        let mut current_sentence = String::new();
381        let mut chars = text.chars().peekable();
382
383        while let Some(ch) = chars.next() {
384            current_sentence.push(ch);
385
386            // Check for sentence endings
387            if matches!(ch, '.' | '!' | '?') {
388                // Look ahead for whitespace
389                if let Some(&next_ch) = chars.peek() {
390                    if next_ch.is_whitespace() || next_ch == '\n' {
391                        let trimmed = current_sentence.trim();
392                        if !trimmed.is_empty() && trimmed.len() > 3 {
393                            sentences.push(trimmed.to_string());
394                            current_sentence.clear();
395                        }
396                    }
397                } else {
398                    // End of text
399                    let trimmed = current_sentence.trim();
400                    if !trimmed.is_empty() {
401                        sentences.push(trimmed.to_string());
402                        current_sentence.clear();
403                    }
404                }
405            }
406        }
407
408        // Add remaining text as final sentence
409        let trimmed = current_sentence.trim();
410        if !trimmed.is_empty() && trimmed.len() > 3 {
411            sentences.push(trimmed.to_string());
412        }
413
414        sentences
415    }
416
417    /// Calculate adaptive threshold based on content characteristics
418    pub fn calculate_adaptive_threshold(&self, text: &str) -> f32 {
419        if !self.config.adaptive_threshold {
420            return self.config.min_coherence_threshold;
421        }
422
423        let sentences = self.split_sentences(text);
424        let sentence_count = sentences.len();
425
426        // Adjust threshold based on document characteristics
427        let base_threshold = self.config.min_coherence_threshold;
428
429        // Longer documents = slightly more tolerant
430        let length_factor = (sentence_count as f32 / 50.0).min(1.0);
431        let adjusted = base_threshold - (length_factor * 0.05);
432
433        adjusted.clamp(0.5, 0.9)
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::EmbeddingProvider;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Mock embedding provider returning deterministic, unit-normalized
    /// vectors derived from the input text's length.
    struct MockEmbeddingProvider {
        dimension: usize,
    }

    impl MockEmbeddingProvider {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn initialize(&mut self) -> Result<()> {
            Ok(())
        }

        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            // Deterministic pseudo-embedding seeded by text length.
            let mut embedding = vec![0.0; self.dimension];
            let hash = text.len() as f32;
            for (i, val) in embedding.iter_mut().enumerate() {
                *val = ((hash + i as f32) * 0.1).sin();
            }
            // Normalize to unit length so cosine similarity is well-behaved.
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            for val in &mut embedding {
                *val /= norm;
            }
            Ok(embedding)
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut results = Vec::with_capacity(texts.len());
            for text in texts {
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }

        fn dimensions(&self) -> usize {
            self.dimension
        }

        fn is_available(&self) -> bool {
            true
        }

        fn provider_name(&self) -> &str {
            "MockProvider"
        }
    }

    /// Build a scorer with the given config and a 384-dim mock provider.
    fn scorer_with(config: CoherenceConfig) -> SemanticCoherenceScorer {
        SemanticCoherenceScorer::new(config, Arc::new(MockEmbeddingProvider::new(384)))
    }

    /// Build a scorer with the default config.
    fn default_scorer() -> SemanticCoherenceScorer {
        scorer_with(CoherenceConfig::default())
    }

    // Purely synchronous method: no async runtime required.
    #[test]
    fn test_cosine_similarity() {
        let scorer = default_scorer();

        // Identical vectors = 1.0
        let sim = scorer.cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors = 0.0
        let sim = scorer.cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]);
        assert!(sim.abs() < 0.001);

        // Opposite vectors = -1.0
        let sim = scorer.cosine_similarity(&[1.0, 0.0, 0.0], &[-1.0, 0.0, 0.0]);
        assert!((sim + 1.0).abs() < 0.001);
    }

    #[test]
    fn test_sentence_splitting() {
        let scorer = default_scorer();

        let text = "This is sentence one. This is sentence two! Is this sentence three?";
        let sentences = scorer.split_sentences(text);

        assert_eq!(sentences.len(), 3);
        assert!(sentences[0].contains("sentence one"));
        assert!(sentences[1].contains("sentence two"));
        assert!(sentences[2].contains("sentence three"));
    }

    #[tokio::test]
    async fn test_score_chunk_coherence() {
        let scorer = default_scorer();

        let text = "This is a test. This is another test. Testing continues here.";
        let score = scorer.score_chunk_coherence(text).await.unwrap();

        // Must be a valid score in [0, 1].
        assert!((0.0..=1.0).contains(&score));
    }

    #[tokio::test]
    async fn test_single_sentence_coherence() {
        let scorer = default_scorer();

        // A single sentence is perfectly coherent by definition.
        let score = scorer
            .score_chunk_coherence("This is a single sentence.")
            .await
            .unwrap();
        assert_eq!(score, 1.0);
    }

    #[tokio::test]
    async fn test_find_optimal_split_no_boundaries() {
        let scorer = default_scorer();

        let text = "First sentence. Second sentence. Third sentence.";
        let result = scorer.find_optimal_split(text, &[]).await.unwrap();

        // No candidate boundaries = one unsplit chunk.
        assert_eq!(result.chunks.len(), 1);
        assert_eq!(result.split_positions.len(), 0);
    }

    #[tokio::test]
    async fn test_create_chunks() {
        let scorer = default_scorer();

        let text = "First part. Second part. Third part.";
        let splits = vec![12, 25]; // Split after "First part." and "Second part."

        let chunks = scorer.create_chunks(text, &splits).await.unwrap();

        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].text.contains("First"));
        assert!(chunks[1].text.contains("Second"));
        assert!(chunks[2].text.contains("Third"));
    }

    #[test]
    fn test_validate_splits() {
        let scorer = scorer_with(CoherenceConfig {
            min_sentences_per_chunk: 2,
            ..Default::default()
        });

        let text = "Sentence one. Sentence two. Sentence three. Sentence four. Sentence five.";

        // Valid: both resulting chunks contain 2+ sentences.
        assert!(scorer.validate_splits(text, &[26]));

        // Invalid: first chunk would hold a single sentence.
        assert!(!scorer.validate_splits(text, &[14]));
    }

    #[test]
    fn test_adaptive_threshold() {
        let scorer = scorer_with(CoherenceConfig {
            adaptive_threshold: true,
            ..Default::default()
        });

        let threshold_short = scorer.calculate_adaptive_threshold("One. Two. Three.");

        let long_text = (0..100)
            .map(|i| format!("Sentence {}.", i))
            .collect::<Vec<_>>()
            .join(" ");
        let threshold_long = scorer.calculate_adaptive_threshold(&long_text);

        // Longer text gets an equal-or-lower (more tolerant) threshold.
        assert!(threshold_long <= threshold_short);
        assert!((0.5..=0.9).contains(&threshold_short));
        assert!((0.5..=0.9).contains(&threshold_long));
    }

    #[test]
    fn test_coherence_calculation() {
        let scorer = default_scorer();

        // Similar embeddings → high coherence.
        let embeddings = vec![
            vec![1.0, 0.1, 0.1],
            vec![0.9, 0.15, 0.15],
            vec![0.95, 0.12, 0.12],
        ];
        assert!(scorer.calculate_coherence(&embeddings) > 0.5);

        // Mutually orthogonal embeddings → low coherence.
        let embeddings = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
        ];
        assert!(scorer.calculate_coherence(&embeddings) < 0.5);
    }
}