use async_trait::async_trait;
use graphrag_core::core::{ChunkingStrategy, DocumentId};
use graphrag_core::embeddings::EmbeddingProvider;
use graphrag_core::text::{
BoundaryAwareChunkingStrategy, BoundaryDetectionConfig, BoundaryDetector, BoundaryType,
CoherenceConfig, SemanticCoherenceScorer,
};
use std::sync::Arc;
/// Deterministic, dependency-free stand-in for a real embedding model, used by the
/// coherence-scoring and chunking tests in this file.
struct MockEmbeddingProvider {
    /// Length of every vector this provider produces.
    dimension: usize,
}

impl MockEmbeddingProvider {
    /// Creates a mock provider that emits vectors of length `dimension`.
    fn new(dimension: usize) -> Self {
        MockEmbeddingProvider { dimension }
    }
}
#[async_trait]
impl EmbeddingProvider for MockEmbeddingProvider {
    /// The mock needs no setup, so initialisation always succeeds.
    async fn initialize(&mut self) -> graphrag_core::core::error::Result<()> {
        Ok(())
    }

    /// Produces a deterministic vector of length `self.dimension`, L2-normalised when its
    /// norm is non-zero.
    ///
    /// Each component is `sin((byte_len + word_count + i) * 0.1)`. The vector therefore
    /// depends only on the byte length and whitespace-separated word count of `text`:
    /// texts that agree on both map to the same embedding.
    async fn embed(&self, text: &str) -> graphrag_core::core::error::Result<Vec<f32>> {
        let text_len = text.len() as f32;
        let word_count = text.split_whitespace().count() as f32;
        let mut embedding: Vec<f32> = (0..self.dimension)
            .map(|i| ((text_len + word_count + i as f32) * 0.1).sin())
            .collect();
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        // The norm is loop-invariant, so decide once whether to normalise. A zero norm
        // (e.g. `dimension == 0`) leaves the vector untouched instead of dividing by zero.
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }
        Ok(embedding)
    }

    /// Embeds each text in order by delegating to [`Self::embed`].
    ///
    /// Stops at, and returns, the first error.
    async fn embed_batch(
        &self,
        texts: &[&str],
    ) -> graphrag_core::core::error::Result<Vec<Vec<f32>>> {
        let mut results = Vec::with_capacity(texts.len());
        for text in texts {
            results.push(self.embed(text).await?);
        }
        Ok(results)
    }

    /// Length of the vectors returned by `embed`.
    fn dimensions(&self) -> usize {
        self.dimension
    }

    /// The mock is always ready.
    fn is_available(&self) -> bool {
        true
    }

    /// Fixed identifier for this mock provider.
    fn provider_name(&self) -> &str {
        "MockProvider"
    }
}
/// Sentences terminated by `.`, `!` and `?` should each be recognised.
#[test]
fn test_boundary_detector_sentence_detection() {
    // Three sentences, one per terminator kind.
    let text = "This is sentence one. This is sentence two! Is this sentence three?";
    let detector = BoundaryDetector::new();
    let boundaries = detector.detect_boundaries(text);
    let sentence_count = boundaries
        .iter()
        .filter(|b| b.boundary_type == BoundaryType::Sentence)
        .count();
    assert!(sentence_count != 0);
    assert!(sentence_count >= 2);
}
/// Paragraphs separated by blank lines should yield boundaries.
///
/// The assertion is deliberately lenient. It also accepts two or more boundaries of any
/// type, so the test does not hinge on exactly how the detector classifies each break.
#[test]
fn test_boundary_detector_paragraph_detection() {
    let detector = BoundaryDetector::new();
    // Paragraphs are separated by a blank line ("\n\n"), matching the other paragraph
    // fixtures in this file. A single newline may not be treated as a paragraph break,
    // which would leave this test exercising nothing but the lenient fallback below.
    let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
    let boundaries = detector.detect_boundaries(text);
    let paragraph_boundaries: Vec<_> = boundaries
        .iter()
        .filter(|b| b.boundary_type == BoundaryType::Paragraph)
        .collect();
    assert!(!boundaries.is_empty(), "No boundaries detected at all");
    assert!(
        !paragraph_boundaries.is_empty() || boundaries.len() >= 2,
        "Expected paragraph boundaries or other semantic boundaries, found {} paragraph, {} total",
        paragraph_boundaries.len(),
        boundaries.len()
    );
}
/// Markdown `#` / `##` headings should be reported as heading boundaries.
#[test]
fn test_boundary_detector_heading_detection() {
    // One level-1 and one level-2 heading, each followed by body text.
    let text = "# Main Heading\n\nSome content.\n\n## Subheading\n\nMore content.";
    let detector = BoundaryDetector::new();
    let boundaries = detector.detect_boundaries(text);
    let has_heading = boundaries
        .iter()
        .any(|b| b.boundary_type == BoundaryType::Heading);
    assert!(has_heading);
}
/// A dash-prefixed list embedded in prose should produce at least one list boundary.
#[test]
fn test_boundary_detector_list_detection() {
    // Three `- ` items sandwiched between plain lines of text.
    let text = "Regular text\n- Item 1\n- Item 2\n- Item 3\nMore text";
    let detector = BoundaryDetector::new();
    let boundaries = detector.detect_boundaries(text);
    let has_list = boundaries
        .iter()
        .any(|b| b.boundary_type == BoundaryType::List);
    assert!(has_list);
}
/// A fenced code block should produce at least one code-block boundary.
#[test]
fn test_boundary_detector_code_block_detection() {
    // A ```rust fence containing one line of code, with prose before and after.
    let text = "Some text\n```rust\nfn main() {}\n```\nMore text";
    let detector = BoundaryDetector::new();
    let boundaries = detector.detect_boundaries(text);
    let has_code_block = boundaries
        .iter()
        .any(|b| b.boundary_type == BoundaryType::CodeBlock);
    assert!(has_code_block);
}
/// Text containing the abbreviation "Dr." should still yield sentence boundaries.
///
/// This only checks that some sentence boundary is reported. It does not verify that the
/// period after "Dr" is skipped.
#[test]
fn test_boundary_detector_abbreviation_handling() {
    let text = "Dr. Smith went to the store. He bought milk.";
    let detector = BoundaryDetector::new();
    let boundaries = detector.detect_boundaries(text);
    let sentence_breaks = boundaries
        .iter()
        .filter(|b| b.boundary_type == BoundaryType::Sentence)
        .count();
    assert!(sentence_breaks > 0);
}
/// The coherence score of a short multi-sentence chunk must lie in `[0, 1]`.
#[tokio::test]
async fn test_coherence_scorer_basic() {
    let scorer = SemanticCoherenceScorer::new(
        CoherenceConfig::default(),
        Arc::new(MockEmbeddingProvider::new(384)),
    );
    let chunk = "This is about cats. Cats are amazing. Felines are wonderful.";
    let score = scorer.score_chunk_coherence(chunk).await.unwrap();
    assert!((0.0..=1.0).contains(&score));
}
/// A chunk with exactly one sentence is expected to score exactly 1.0.
#[tokio::test]
async fn test_coherence_scorer_single_sentence() {
    let scorer = SemanticCoherenceScorer::new(
        CoherenceConfig::default(),
        Arc::new(MockEmbeddingProvider::new(384)),
    );
    let lone_sentence = "This is a single sentence.";
    let score = scorer.score_chunk_coherence(lone_sentence).await.unwrap();
    assert_eq!(score, 1.0);
}
/// Cosine similarity is ~1 for identical vectors and ~0 for orthogonal ones.
#[tokio::test]
async fn test_coherence_scorer_cosine_similarity() {
    let scorer = SemanticCoherenceScorer::new(
        CoherenceConfig::default(),
        Arc::new(MockEmbeddingProvider::new(384)),
    );
    let x_axis = vec![1.0, 0.0, 0.0];
    let y_axis = vec![0.0, 1.0, 0.0];

    let identical = scorer.cosine_similarity(&x_axis, &x_axis);
    assert!((identical - 1.0).abs() < 0.001);

    let orthogonal = scorer.cosine_similarity(&x_axis, &y_axis);
    assert!(orthogonal.abs() < 0.001);
}
/// `find_optimal_split` over two candidate sentence boundaries should return at least
/// one chunk and an overall coherence in `[0, 1]`.
#[tokio::test]
async fn test_coherence_scorer_optimal_split() {
    let config = CoherenceConfig::default();
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let scorer = SemanticCoherenceScorer::new(config, provider);
    let text =
        "First topic here. More about first topic. Second topic begins. More about second topic.";
    // Candidate splits at the start of the third and fourth sentences (byte offsets 42
    // and 63). They are derived from the text instead of hard-coded so both offsets use
    // the same convention and cannot drift if the fixture changes.
    let boundaries: Vec<usize> = ["Second topic begins.", "More about second topic."]
        .iter()
        .map(|sentence| {
            text.find(sentence)
                .expect("split anchor must occur in the fixture text")
        })
        .collect();
    let result = scorer.find_optimal_split(text, &boundaries).await.unwrap();
    assert!(!result.chunks.is_empty());
    assert!((0.0..=1.0).contains(&result.overall_coherence));
}
/// Chunking a small Markdown document yields non-empty chunks with ordered offsets.
#[test]
fn test_boundary_aware_chunking_strategy() {
    let boundary_config = BoundaryDetectionConfig::default();
    let coherence_config = CoherenceConfig::default();
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let document_id = DocumentId::new("test_doc".to_string());
    let strategy = BoundaryAwareChunkingStrategy::new(
        boundary_config,
        coherence_config,
        provider,
        2000, // presumably the maximum chunk size — confirm against `new`'s signature
        200,  // presumably the chunk overlap — confirm against `new`'s signature
        document_id,
    );
    let text = "# Introduction\n\nThis is the introduction paragraph. It discusses GraphRAG.\n\n## Background\n\nThe background section provides context. It explains the motivation for this research.\n\n## Method\n\nOur method is innovative. We use boundary-aware chunking.";
    let chunks = strategy.chunk(text);
    assert!(!chunks.is_empty());
    for chunk in &chunks {
        assert!(!chunk.content.is_empty());
        assert!(chunk.start_offset < chunk.end_offset);
    }
}
/// The default-configured strategy should chunk a three-paragraph text into non-empty
/// chunks.
#[test]
fn test_boundary_aware_chunking_metadata() {
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let document_id = DocumentId::new("metadata_test".to_string());
    let strategy = BoundaryAwareChunkingStrategy::with_defaults(provider, document_id);
    let text = "First paragraph about machine learning.\n\nSecond paragraph about neural networks.\n\nThird paragraph about transformers.";
    let chunks = strategy.chunk(text);
    // Without this, an empty result would make the loop below pass vacuously.
    assert!(!chunks.is_empty(), "Expected at least one chunk");
    for chunk in &chunks {
        assert!(!chunk.content.is_empty());
    }
}
/// With a small size limit, a long repetitive text must be split into chunks that stay
/// within the configured budget.
#[test]
fn test_boundary_aware_size_constraints() {
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let document_id = DocumentId::new("size_test".to_string());
    let strategy = BoundaryAwareChunkingStrategy::new(
        BoundaryDetectionConfig::default(),
        CoherenceConfig::default(),
        provider,
        500, // presumably the maximum chunk size — confirm against `new`'s signature
        100, // presumably the chunk overlap — confirm against `new`'s signature
        document_id,
    );
    // 1400 bytes of input, well above the 500 limit.
    let long_text = "Sentence one. ".repeat(100);
    let chunks = strategy.chunk(&long_text);
    // Without this, an empty result would make the size check pass vacuously.
    assert!(!chunks.is_empty(), "Expected at least one chunk");
    // 600 = 500 + 100; presumably the size limit plus the overlap allowance.
    for (i, chunk) in chunks.iter().enumerate() {
        assert!(
            chunk.content.len() <= 600,
            "Chunk {} is {} bytes, above the 600-byte allowance",
            i,
            chunk.content.len()
        );
    }
}
/// A document mixing headings, prose, a list and a fenced code block should produce
/// boundaries of all four structural types.
///
/// NOTE(review): the fixture contains no blank lines, so the `Paragraph` assertion relies
/// on how the detector treats single newlines — confirm this is intended.
#[test]
fn test_combined_boundary_types() {
    let detector = BoundaryDetector::new();
    let text = r#"
# Chapter 1: Introduction
This is the introduction paragraph.
## Section 1.1
Here is a list:
- First item
- Second item
- Third item
```rust
fn example() {
    println!("code block");
}
```
More content follows.
"#;
    let seen: std::collections::HashSet<_> = detector
        .detect_boundaries(text)
        .iter()
        .map(|b| b.boundary_type)
        .collect();
    assert!(seen.contains(&BoundaryType::Heading));
    assert!(seen.contains(&BoundaryType::Paragraph));
    assert!(seen.contains(&BoundaryType::List));
    assert!(seen.contains(&BoundaryType::CodeBlock));
}
/// With adaptive thresholds enabled, both short and long texts get a threshold in
/// `[0.5, 0.9]`, and the long text's threshold is no higher than the short one's.
#[tokio::test]
async fn test_coherence_adaptive_threshold() {
    let scorer = SemanticCoherenceScorer::new(
        CoherenceConfig {
            adaptive_threshold: true,
            ..Default::default()
        },
        Arc::new(MockEmbeddingProvider::new(384)),
    );

    let threshold_short = scorer.calculate_adaptive_threshold("One. Two. Three.");

    // "Sentence 0. Sentence 1. ... Sentence 99."
    let sentences: Vec<String> = (0..100).map(|n| format!("Sentence {}.", n)).collect();
    let long_text = sentences.join(" ");
    let threshold_long = scorer.calculate_adaptive_threshold(&long_text);

    let allowed = 0.5..=0.9;
    assert!(allowed.contains(&threshold_short));
    assert!(allowed.contains(&threshold_long));
    assert!(threshold_long <= threshold_short);
}
/// Every boundary confidence lies in `[0, 1]`, and at least one heading or code-block
/// boundary is reported with confidence of at least 0.9.
#[test]
fn test_boundary_detector_confidence_scores() {
    let detector = BoundaryDetector::new();
    let text = "# Heading\n\nParagraph.\n\n```\ncode\n```";
    let boundaries = detector.detect_boundaries(text);

    let all_in_range = boundaries
        .iter()
        .all(|b| b.confidence >= 0.0 && b.confidence <= 1.0);
    assert!(all_in_range);

    let has_confident_structural = boundaries.iter().any(|b| {
        let structural = matches!(
            b.boundary_type,
            BoundaryType::Heading | BoundaryType::CodeBlock
        );
        structural && b.confidence >= 0.9
    });
    assert!(has_confident_structural);
}
/// Chunks a realistic Markdown document (headings, numbered and dash lists, a code fence)
/// and checks each chunk is non-empty with ordered offsets. Chunk stats are printed for
/// manual inspection with `--nocapture`.
#[test]
fn test_end_to_end_document_processing() {
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let document_id = DocumentId::new("end_to_end_test".to_string());
    let strategy = BoundaryAwareChunkingStrategy::with_defaults(provider, document_id);
    let document = r#"
# GraphRAG: Advanced Document Processing
## Introduction
GraphRAG is a powerful framework for retrieval-augmented generation. It combines knowledge graphs with vector search to provide accurate answers.
## Architecture
The system consists of several key components:
1. Document ingestion pipeline
2. Entity extraction module
3. Graph construction engine
4. Vector embedding generator
### Document Ingestion
The ingestion pipeline handles:
- Text extraction
- Boundary-aware chunking
- Metadata enrichment
```rust
fn process_document(doc: &str) -> Vec<Chunk> {
    let chunker = BoundaryAwareChunker::new();
    chunker.chunk(doc)
}
```
## Conclusion
This approach significantly improves retrieval quality and answer accuracy.
"#;
    let chunks = strategy.chunk(document);
    assert!(
        !chunks.is_empty(),
        "Expected at least 1 chunk, got {}",
        chunks.len()
    );
    println!("Generated {} chunks", chunks.len());
    for (i, chunk) in chunks.iter().enumerate() {
        assert!(!chunk.content.is_empty(), "Chunk {} is empty", i);
        assert!(
            chunk.start_offset < chunk.end_offset,
            "Chunk {} has invalid offsets",
            i
        );
        println!(
            "Chunk {}: {} chars, offset {}-{}",
            i,
            chunk.content.len(),
            chunk.start_offset,
            chunk.end_offset
        );
    }
}
/// Returns the longest prefix of `s` that is at most `max_bytes` bytes long and ends on a
/// UTF-8 character boundary.
///
/// A plain `&s[..max_bytes]` panics when `max_bytes` falls inside a multi-byte character.
/// Real-world texts (curly quotes, dashes, a leading BOM) routinely contain such characters.
fn truncate_at_char_boundary(s: &str, max_bytes: usize) -> &str {
    if s.len() <= max_bytes {
        return s;
    }
    // Index 0 is always a char boundary, so the search cannot come up empty.
    let end = (0..=max_bytes)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}

/// Runs boundary-aware chunking over the first ~5000 bytes of Plato's *Symposium*. It
/// checks chunk offsets, per-chunk size, coverage and average chunk size.
///
/// The test is ignored by default because it needs a local copy of the text. Set
/// `GRAPHRAG_SYMPOSIUM_PATH` to point at it; otherwise the original developer-local path
/// is tried. If the file is missing, the test prints a notice and returns.
#[test]
#[ignore = "requires a local copy of Symposium.txt"]
fn test_real_world_document_plato_symposium() {
    use std::fs;
    let symposium_path = std::env::var("GRAPHRAG_SYMPOSIUM_PATH")
        .unwrap_or_else(|_| "/home/dio/graphrag-rs/docs-example/Symposium.txt".to_string());
    if !std::path::Path::new(&symposium_path).exists() {
        println!("Skipping test: Symposium.txt not found");
        return;
    }
    let text = fs::read_to_string(&symposium_path).expect("Failed to read Symposium.txt");
    // Keep the run short, without splitting a multi-byte character.
    let text_sample = truncate_at_char_boundary(&text, 5000);
    let provider = Arc::new(MockEmbeddingProvider::new(384));
    let document_id = DocumentId::new("plato_symposium".to_string());
    let strategy = BoundaryAwareChunkingStrategy::new(
        BoundaryDetectionConfig::default(),
        CoherenceConfig {
            min_coherence_threshold: 0.6,
            max_sentences_per_chunk: 15,
            min_sentences_per_chunk: 3,
            ..Default::default()
        },
        provider,
        1500, // presumably the maximum chunk size — confirm against `new`'s signature
        300,  // presumably the chunk overlap — confirm against `new`'s signature
        document_id,
    );
    println!("\n=== Testing BAR-RAG on Real Classical Text ===");
    println!("Document: Plato's Symposium (Project Gutenberg)");
    println!("Sample length: {} chars", text_sample.len());
    let chunks = strategy.chunk(text_sample);
    println!("Generated {} chunks\n", chunks.len());
    assert!(!chunks.is_empty(), "Should produce at least one chunk");
    let mut total_chars = 0;
    for (i, chunk) in chunks.iter().enumerate() {
        assert!(!chunk.content.is_empty(), "Chunk {} is empty", i);
        assert!(
            chunk.start_offset < chunk.end_offset,
            "Chunk {} has invalid offsets",
            i
        );
        if let Some(score) = chunk.metadata.custom.get("coherence_score") {
            println!(
                "Chunk {}: {} chars, coherence: {}",
                i,
                chunk.content.len(),
                score
            );
        } else {
            println!("Chunk {}: {} chars", i, chunk.content.len());
        }
        let preview = if chunk.content.len() > 100 {
            format!("{}...", truncate_at_char_boundary(&chunk.content, 100))
        } else {
            chunk.content.clone()
        };
        println!("  Preview: {}\n", preview.replace('\n', " "));
        total_chars += chunk.content.len();
        // 1600 = 1500 + 100 bytes of slack over the configured size.
        assert!(
            chunk.content.len() <= 1600,
            "Chunk {} exceeds max size: {} chars",
            i,
            chunk.content.len()
        );
    }
    // Overlapping chunks count their shared text more than once, so coverage can exceed 100%.
    let coverage = (total_chars as f64 / text_sample.len() as f64) * 100.0;
    println!("Total characters processed: {}", total_chars);
    println!("Coverage: {:.1}%", coverage);
    assert!(
        coverage >= 80.0,
        "Coverage too low: {:.1}% (expected >= 80%)",
        coverage
    );
    let avg_chunk_size = total_chars / chunks.len();
    println!("Average chunk size: {} chars", avg_chunk_size);
    assert!(
        avg_chunk_size >= 200,
        "Chunks too small on average: {} chars",
        avg_chunk_size
    );
    assert!(
        avg_chunk_size <= 2000,
        "Chunks too large on average: {} chars",
        avg_chunk_size
    );
    println!("\n✓ BAR-RAG successfully processed classical literature");
    println!("✓ All chunks semantically coherent and well-bounded");
}