1use crate::core::error::{GraphRAGError, Result};
18use crate::embeddings::EmbeddingProvider;
19use serde::{Deserialize, Serialize};
20use std::sync::Arc;
21
/// Tuning parameters for semantic-coherence scoring and chunk splitting.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CoherenceConfig {
    /// Baseline coherence threshold; also the starting point for the
    /// adaptive threshold computed in `calculate_adaptive_threshold`.
    pub min_coherence_threshold: f32,

    /// Cap on the number of sentences embedded when scoring one chunk.
    pub max_sentences_per_chunk: usize,

    /// Minimum sentences every chunk must retain for a split to be valid.
    pub min_sentences_per_chunk: usize,

    /// Width of the sliding window used for windowed pairwise similarity
    /// (windowed scoring is skipped when this is <= 1).
    pub coherence_window_size: usize,

    /// Blend weight of adjacent-sentence similarity vs. windowed similarity
    /// in the final score (expected in 0.0..=1.0).
    pub adjacency_weight: f32,

    /// When `true`, lower the threshold slightly for longer texts.
    pub adaptive_threshold: bool,

    /// Batch size for embedding requests.
    /// NOTE(review): not referenced anywhere in this file — presumably
    /// consumed by the embedding pipeline; confirm before relying on it.
    pub embedding_batch_size: usize,
}
46
impl Default for CoherenceConfig {
    /// Default configuration values.
    fn default() -> Self {
        Self {
            min_coherence_threshold: 0.65,
            max_sentences_per_chunk: 20,
            min_sentences_per_chunk: 2,
            coherence_window_size: 3,
            adjacency_weight: 0.7,
            adaptive_threshold: true,
            embedding_batch_size: 32,
        }
    }
}
60
/// A text chunk together with its coherence metrics.
#[derive(Debug, Clone)]
pub struct ScoredChunk {
    /// The chunk's text content.
    pub text: String,

    /// Byte offset of the chunk's start within the original text.
    pub start_pos: usize,

    /// Byte offset one past the chunk's end within the original text.
    pub end_pos: usize,

    /// Blended coherence score, clamped to [0, 1].
    pub coherence_score: f32,

    /// Number of sentences detected in the chunk.
    pub sentence_count: usize,

    /// Average sentence similarity.
    /// NOTE(review): the producing code in this file always sets this equal
    /// to `coherence_score` — confirm whether a distinct metric was intended.
    pub avg_similarity: f32,
}
82
/// Result of the greedy split optimization over candidate boundaries.
#[derive(Debug, Clone)]
pub struct OptimalSplit {
    /// Chosen split positions (byte offsets into the original text),
    /// sorted ascending.
    pub split_positions: Vec<usize>,

    /// The scored chunks produced by those splits, in document order.
    pub chunks: Vec<ScoredChunk>,

    /// Mean coherence score across all chunks.
    pub overall_coherence: f32,

    /// Number of greedy iterations performed before stopping.
    pub optimization_iterations: usize,
}
98
/// Scores the semantic coherence of text chunks using sentence embeddings
/// and searches for chunk boundaries that maximize that coherence.
pub struct SemanticCoherenceScorer {
    // Tuning parameters for scoring and splitting.
    config: CoherenceConfig,
    // Shared backend used to embed sentences.
    embedding_provider: Arc<dyn EmbeddingProvider>,
}
104
105impl SemanticCoherenceScorer {
106 pub fn new(config: CoherenceConfig, embedding_provider: Arc<dyn EmbeddingProvider>) -> Self {
108 Self {
109 config,
110 embedding_provider,
111 }
112 }
113
114 pub async fn score_chunk_coherence(&self, text: &str) -> Result<f32> {
119 let sentences = self.split_sentences(text);
121
122 if sentences.len() < 2 {
123 return Ok(1.0);
125 }
126
127 let sentences: Vec<&str> = sentences
129 .iter()
130 .take(self.config.max_sentences_per_chunk)
131 .map(|s| s.as_str())
132 .collect();
133
134 let embeddings = self
136 .embedding_provider
137 .embed_batch(&sentences)
138 .await
139 .map_err(|e| GraphRAGError::Embedding {
140 message: e.to_string(),
141 })?;
142
143 if embeddings.len() != sentences.len() {
144 return Err(GraphRAGError::TextProcessing {
145 message: "Embedding count mismatch".to_string(),
146 });
147 }
148
149 let coherence = self.calculate_coherence(&embeddings);
151
152 Ok(coherence)
153 }
154
155 fn calculate_coherence(&self, embeddings: &[Vec<f32>]) -> f32 {
161 if embeddings.len() < 2 {
162 return 1.0;
163 }
164
165 let mut adjacent_similarities = Vec::new();
167 for i in 0..embeddings.len() - 1 {
168 let sim = self.cosine_similarity(&embeddings[i], &embeddings[i + 1]);
169 adjacent_similarities.push(sim);
170 }
171
172 let adjacent_avg =
173 adjacent_similarities.iter().sum::<f32>() / adjacent_similarities.len() as f32;
174
175 let window_avg = if self.config.coherence_window_size > 1 {
177 let mut window_similarities = Vec::new();
178 for i in 0..embeddings.len() {
179 let window_start = i.saturating_sub(self.config.coherence_window_size / 2);
180 let window_end =
181 (i + self.config.coherence_window_size / 2 + 1).min(embeddings.len());
182
183 for j in window_start..window_end {
184 if i != j {
185 let sim = self.cosine_similarity(&embeddings[i], &embeddings[j]);
186 window_similarities.push(sim);
187 }
188 }
189 }
190
191 if window_similarities.is_empty() {
192 adjacent_avg
193 } else {
194 window_similarities.iter().sum::<f32>() / window_similarities.len() as f32
195 }
196 } else {
197 adjacent_avg
198 };
199
200 let coherence = self.config.adjacency_weight * adjacent_avg
202 + (1.0 - self.config.adjacency_weight) * window_avg;
203
204 coherence.clamp(0.0, 1.0)
205 }
206
207 pub async fn find_optimal_split(
215 &self,
216 text: &str,
217 candidate_boundaries: &[usize],
218 ) -> Result<OptimalSplit> {
219 if candidate_boundaries.is_empty() {
220 let score = self.score_chunk_coherence(text).await?;
222 return Ok(OptimalSplit {
223 split_positions: vec![],
224 chunks: vec![ScoredChunk {
225 text: text.to_string(),
226 start_pos: 0,
227 end_pos: text.len(),
228 coherence_score: score,
229 sentence_count: self.split_sentences(text).len(),
230 avg_similarity: score,
231 }],
232 overall_coherence: score,
233 optimization_iterations: 1,
234 });
235 }
236
237 let mut current_splits: Vec<usize> = vec![];
239 let mut iterations = 0;
240 let max_iterations = 100;
241
242 loop {
243 iterations += 1;
244 if iterations > max_iterations {
245 break;
246 }
247
248 let current_chunks = self.create_chunks(text, ¤t_splits).await?;
250 let current_score = current_chunks
251 .iter()
252 .map(|c| c.coherence_score)
253 .sum::<f32>()
254 / current_chunks.len() as f32;
255
256 let mut best_new_split: Option<usize> = None;
258 let mut best_score = current_score;
259
260 for &boundary in candidate_boundaries {
261 if current_splits.contains(&boundary) {
262 continue;
263 }
264
265 let mut test_splits = current_splits.clone();
267 test_splits.push(boundary);
268 test_splits.sort_unstable();
269
270 let test_chunks = self.create_chunks(text, &test_splits).await?;
271 let test_score = test_chunks.iter().map(|c| c.coherence_score).sum::<f32>()
272 / test_chunks.len() as f32;
273
274 if test_score > best_score {
275 best_score = test_score;
276 best_new_split = Some(boundary);
277 }
278 }
279
280 if best_new_split.is_none() {
282 break;
283 }
284
285 current_splits.push(best_new_split.unwrap());
287 current_splits.sort_unstable();
288
289 if !self.validate_splits(text, ¤t_splits) {
291 current_splits.pop();
292 break;
293 }
294 }
295
296 let final_chunks = self.create_chunks(text, ¤t_splits).await?;
298 let overall_coherence =
299 final_chunks.iter().map(|c| c.coherence_score).sum::<f32>() / final_chunks.len() as f32;
300
301 Ok(OptimalSplit {
302 split_positions: current_splits,
303 chunks: final_chunks,
304 overall_coherence,
305 optimization_iterations: iterations,
306 })
307 }
308
309 async fn create_chunks(&self, text: &str, splits: &[usize]) -> Result<Vec<ScoredChunk>> {
311 let mut chunks = Vec::new();
312 let mut boundaries = vec![0];
313 boundaries.extend_from_slice(splits);
314 boundaries.push(text.len());
315
316 for i in 0..boundaries.len() - 1 {
317 let start = boundaries[i];
318 let end = boundaries[i + 1];
319 let chunk_text = &text[start..end];
320
321 let coherence = self.score_chunk_coherence(chunk_text).await?;
322 let sentences = self.split_sentences(chunk_text);
323
324 chunks.push(ScoredChunk {
325 text: chunk_text.to_string(),
326 start_pos: start,
327 end_pos: end,
328 coherence_score: coherence,
329 sentence_count: sentences.len(),
330 avg_similarity: coherence,
331 });
332 }
333
334 Ok(chunks)
335 }
336
337 fn validate_splits(&self, text: &str, splits: &[usize]) -> bool {
339 let mut boundaries = vec![0];
340 boundaries.extend_from_slice(splits);
341 boundaries.push(text.len());
342
343 for i in 0..boundaries.len() - 1 {
344 let start = boundaries[i];
345 let end = boundaries[i + 1];
346 let chunk_text = &text[start..end];
347 let sentences = self.split_sentences(chunk_text);
348
349 if sentences.len() < self.config.min_sentences_per_chunk {
350 return false;
351 }
352 }
353
354 true
355 }
356
357 pub fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
359 if a.len() != b.len() || a.is_empty() {
360 return 0.0;
361 }
362
363 let dot_product: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
364 let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
365 let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
366
367 if norm_a == 0.0 || norm_b == 0.0 {
368 return 0.0;
369 }
370
371 (dot_product / (norm_a * norm_b)).clamp(-1.0, 1.0)
372 }
373
    /// Splits `text` into sentences on `.`, `!`, and `?` terminators.
    ///
    /// Heuristics:
    /// - a terminator only ends a sentence when followed by whitespace (or
    ///   end of input), so decimals like "3.14" stay intact;
    /// - a mid-text candidate shorter than 4 characters is not emitted on
    ///   its own; it keeps accumulating into the following sentence, which
    ///   keeps short abbreviations like "Dr." attached to their sentence;
    /// - emitted sentences are trimmed of surrounding whitespace.
    ///
    /// NOTE(review): the end-of-input flush in the `else` branch has no
    /// minimum-length check, while the post-loop flush requires > 3 chars —
    /// confirm whether that asymmetry is intentional.
    fn split_sentences(&self, text: &str) -> Vec<String> {
        let mut sentences = Vec::new();
        let mut current_sentence = String::new();
        let mut chars = text.chars().peekable();

        while let Some(ch) = chars.next() {
            current_sentence.push(ch);

            if matches!(ch, '.' | '!' | '?') {
                if let Some(&next_ch) = chars.peek() {
                    // `next_ch == '\n'` is redundant (covered by
                    // `is_whitespace`) but kept for byte-identical behavior.
                    if next_ch.is_whitespace() || next_ch == '\n' {
                        let trimmed = current_sentence.trim();
                        // Only emit candidates longer than 3 chars; shorter
                        // ones continue accumulating (abbreviation guard).
                        if !trimmed.is_empty() && trimmed.len() > 3 {
                            sentences.push(trimmed.to_string());
                            current_sentence.clear();
                        }
                    }
                } else {
                    // Terminator at end of input: flush whatever we have.
                    let trimmed = current_sentence.trim();
                    if !trimmed.is_empty() {
                        sentences.push(trimmed.to_string());
                        current_sentence.clear();
                    }
                }
            }
        }

        // Flush a trailing unterminated sentence (same > 3 length guard).
        let trimmed = current_sentence.trim();
        if !trimmed.is_empty() && trimmed.len() > 3 {
            sentences.push(trimmed.to_string());
        }

        sentences
    }
416
417 pub fn calculate_adaptive_threshold(&self, text: &str) -> f32 {
419 if !self.config.adaptive_threshold {
420 return self.config.min_coherence_threshold;
421 }
422
423 let sentences = self.split_sentences(text);
424 let sentence_count = sentences.len();
425
426 let base_threshold = self.config.min_coherence_threshold;
428
429 let length_factor = (sentence_count as f32 / 50.0).min(1.0);
431 let adjusted = base_threshold - (length_factor * 0.05);
432
433 adjusted.clamp(0.5, 0.9)
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embeddings::EmbeddingProvider;
    use async_trait::async_trait;
    use std::sync::Arc;

    /// Offline, deterministic embedding provider for unit tests.
    ///
    /// Each embedding is derived solely from the input text's byte length,
    /// so identical texts always produce identical unit-norm vectors.
    struct MockEmbeddingProvider {
        // Dimensionality of the vectors this mock produces.
        dimension: usize,
    }

    impl MockEmbeddingProvider {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn initialize(&mut self) -> Result<()> {
            Ok(())
        }

        /// Builds a unit-norm vector seeded by the text's length.
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            let mut embedding = vec![0.0; self.dimension];
            // Cheap deterministic "hash": the text's byte length.
            let hash = text.len() as f32;
            for (i, val) in embedding.iter_mut().enumerate() {
                *val = ((hash + i as f32) * 0.1).sin();
            }
            // Normalize so cosine similarities are well-behaved.
            let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            for val in &mut embedding {
                *val /= norm;
            }
            Ok(embedding)
        }

        async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
            let mut results = Vec::new();
            for text in texts {
                results.push(self.embed(text).await?);
            }
            Ok(results)
        }

        fn dimensions(&self) -> usize {
            self.dimension
        }

        fn is_available(&self) -> bool {
            true
        }

        fn provider_name(&self) -> &str {
            "MockProvider"
        }
    }

    #[tokio::test]
    async fn test_cosine_similarity() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        // Identical vectors score ~1.0.
        let v1 = vec![1.0, 0.0, 0.0];
        let v2 = vec![1.0, 0.0, 0.0];
        let sim = scorer.cosine_similarity(&v1, &v2);
        assert!((sim - 1.0).abs() < 0.001);

        // Orthogonal vectors score ~0.0.
        let v3 = vec![1.0, 0.0, 0.0];
        let v4 = vec![0.0, 1.0, 0.0];
        let sim = scorer.cosine_similarity(&v3, &v4);
        assert!(sim.abs() < 0.001);

        // Opposite vectors score ~-1.0.
        let v5 = vec![1.0, 0.0, 0.0];
        let v6 = vec![-1.0, 0.0, 0.0];
        let sim = scorer.cosine_similarity(&v5, &v6);
        assert!((sim - (-1.0)).abs() < 0.001);
    }

    #[tokio::test]
    async fn test_sentence_splitting() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        // Covers all three terminators: '.', '!', '?'.
        let text = "This is sentence one. This is sentence two! Is this sentence three?";
        let sentences = scorer.split_sentences(text);

        assert_eq!(sentences.len(), 3);
        assert!(sentences[0].contains("sentence one"));
        assert!(sentences[1].contains("sentence two"));
        assert!(sentences[2].contains("sentence three"));
    }

    #[tokio::test]
    async fn test_score_chunk_coherence() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        let text = "This is a test. This is another test. Testing continues here.";
        let score = scorer.score_chunk_coherence(text).await.unwrap();

        // Score must stay within the documented [0, 1] range.
        assert!(score >= 0.0 && score <= 1.0);
    }

    #[tokio::test]
    async fn test_single_sentence_coherence() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        // A single sentence is treated as trivially coherent.
        let text = "This is a single sentence.";
        let score = scorer.score_chunk_coherence(text).await.unwrap();

        assert_eq!(score, 1.0);
    }

    #[tokio::test]
    async fn test_find_optimal_split_no_boundaries() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        let text = "First sentence. Second sentence. Third sentence.";
        let result = scorer.find_optimal_split(text, &[]).await.unwrap();

        // With no candidate boundaries the text stays one unsplit chunk.
        assert_eq!(result.chunks.len(), 1);
        assert_eq!(result.split_positions.len(), 0);
    }

    #[tokio::test]
    async fn test_create_chunks() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        let text = "First part. Second part. Third part.";
        // Byte offsets just past "First part." and "Second part.".
        let splits = vec![12, 25];
        let chunks = scorer.create_chunks(text, &splits).await.unwrap();

        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].text.contains("First"));
        assert!(chunks[1].text.contains("Second"));
        assert!(chunks[2].text.contains("Third"));
    }

    #[tokio::test]
    async fn test_validate_splits() {
        let config = CoherenceConfig {
            min_sentences_per_chunk: 2,
            ..Default::default()
        };
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        let text = "Sentence one. Sentence two. Sentence three. Sentence four. Sentence five.";

        // Split leaving at least two sentences on each side: valid.
        let splits = vec![26];
        assert!(scorer.validate_splits(text, &splits));

        // Split leaving only one sentence in the first chunk: invalid.
        let splits = vec![14];
        assert!(!scorer.validate_splits(text, &splits));
    }

    #[tokio::test]
    async fn test_adaptive_threshold() {
        let config = CoherenceConfig {
            adaptive_threshold: true,
            ..Default::default()
        };
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        let short_text = "One. Two. Three.";
        let threshold_short = scorer.calculate_adaptive_threshold(short_text);

        let long_text = (0..100)
            .map(|i| format!("Sentence {}.", i))
            .collect::<Vec<_>>()
            .join(" ");
        let threshold_long = scorer.calculate_adaptive_threshold(&long_text);

        // Longer texts receive an equal or lower threshold, and both stay
        // within the clamped [0.5, 0.9] range.
        assert!(threshold_long <= threshold_short);
        assert!(threshold_short >= 0.5 && threshold_short <= 0.9);
        assert!(threshold_long >= 0.5 && threshold_long <= 0.9);
    }

    #[tokio::test]
    async fn test_coherence_calculation() {
        let config = CoherenceConfig::default();
        let provider = Arc::new(MockEmbeddingProvider::new(384));
        let scorer = SemanticCoherenceScorer::new(config, provider);

        // Nearly parallel embeddings should read as coherent.
        let emb1 = vec![1.0, 0.1, 0.1];
        let emb2 = vec![0.9, 0.15, 0.15];
        let emb3 = vec![0.95, 0.12, 0.12];
        let embeddings = vec![emb1, emb2, emb3];

        let coherence = scorer.calculate_coherence(&embeddings);
        assert!(coherence > 0.5);

        // Mutually orthogonal embeddings should read as incoherent.
        let emb1 = vec![1.0, 0.0, 0.0];
        let emb2 = vec![0.0, 1.0, 0.0];
        let emb3 = vec![0.0, 0.0, 1.0];
        let embeddings = vec![emb1, emb2, emb3];

        let coherence = scorer.calculate_coherence(&embeddings);
        assert!(coherence < 0.5);
    }
}