Skip to main content

graphrag_core/text/
chunking_strategies.rs

1//! Trait-based chunking strategy implementations
2//!
3//! This module provides concrete implementations of the ChunkingStrategy trait
4//! that wrap existing chunking logic while maintaining a clean, minimal interface.
5
6use crate::{
7    core::{ChunkId, ChunkingStrategy, DocumentId, TextChunk},
8    text::{HierarchicalChunker, SemanticChunker},
9};
10
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Global counter for generating unique chunk IDs
///
/// Process-wide and monotonically increasing, so IDs minted from it are unique
/// across all strategies and documents within a single process run. IDs are
/// NOT stable across runs (the counter restarts at 0).
static CHUNK_COUNTER: AtomicU64 = AtomicU64::new(0);
15
16/// Hierarchical chunking strategy wrapper
17///
18/// Wraps the existing HierarchicalChunker to implement ChunkingStrategy trait.
19/// This strategy respects semantic boundaries (paragraphs, sentences, words).
pub struct HierarchicalChunkingStrategy {
    // Wrapped chunker that performs the actual boundary-aware splitting.
    inner: HierarchicalChunker,
    // Target chunk size forwarded to `chunk_text` (units defined by
    // HierarchicalChunker — presumably characters/bytes; confirm there).
    chunk_size: usize,
    // Overlap between consecutive chunks, forwarded to `chunk_text`.
    overlap: usize,
    // Document the produced chunks belong to; used to build chunk IDs.
    document_id: DocumentId,
}
26
27impl HierarchicalChunkingStrategy {
28    /// Create a new hierarchical chunking strategy
29    pub fn new(chunk_size: usize, overlap: usize, document_id: DocumentId) -> Self {
30        Self {
31            inner: HierarchicalChunker::new().with_min_size(50),
32            chunk_size,
33            overlap,
34            document_id,
35        }
36    }
37
38    /// Set minimum chunk size
39    pub fn with_min_size(mut self, min_size: usize) -> Self {
40        self.inner = self.inner.with_min_size(min_size);
41        self
42    }
43}
44
45impl ChunkingStrategy for HierarchicalChunkingStrategy {
46    fn chunk(&self, text: &str) -> Vec<TextChunk> {
47        let chunks_text = self.inner.chunk_text(text, self.chunk_size, self.overlap);
48        let mut chunks = Vec::new();
49        let mut current_pos = 0;
50
51        for chunk_content in chunks_text {
52            if !chunk_content.trim().is_empty() {
53                let chunk_id = ChunkId::new(format!(
54                    "{}_{}",
55                    self.document_id,
56                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
57                ));
58                let chunk_start = current_pos;
59                let chunk_end = chunk_start + chunk_content.len();
60
61                let chunk = TextChunk::new(
62                    chunk_id,
63                    self.document_id.clone(),
64                    chunk_content.clone(),
65                    chunk_start,
66                    chunk_end,
67                );
68                chunks.push(chunk);
69                current_pos = chunk_end;
70            } else {
71                current_pos += chunk_content.len();
72            }
73        }
74
75        chunks
76    }
77}
78
79/// Semantic chunking strategy wrapper
80///
81/// Wraps the existing SemanticChunker to implement ChunkingStrategy trait.
82/// This strategy uses embedding similarity to determine natural breakpoints.
pub struct SemanticChunkingStrategy {
    // Wrapped semantic chunker. Currently unused (hence the underscore): the
    // `ChunkingStrategy::chunk` method is synchronous while semantic chunking
    // is async, so `chunk` falls back to a sentence-grouping heuristic.
    _inner: SemanticChunker,
    // Document the produced chunks belong to; used to build chunk IDs.
    document_id: DocumentId,
}
87
88impl SemanticChunkingStrategy {
89    /// Create a new semantic chunking strategy
90    pub fn new(chunker: SemanticChunker, document_id: DocumentId) -> Self {
91        Self {
92            _inner: chunker,
93            document_id,
94        }
95    }
96}
97
98impl ChunkingStrategy for SemanticChunkingStrategy {
99    fn chunk(&self, text: &str) -> Vec<TextChunk> {
100        // Note: This is a simplified implementation
101        // In a real scenario, you would need to handle the async nature of semantic chunking
102        // or use a synchronous embedding generator
103
104        // For now, fall back to a simple sentence-based approach
105        let sentences: Vec<&str> = text
106            .split(&['.', '!', '?'][..])
107            .filter(|s| !s.trim().is_empty())
108            .collect();
109
110        let mut chunks = Vec::new();
111        let mut current_pos = 0;
112
113        // Group sentences into chunks of reasonable size
114        let chunk_size = 5; // sentences per chunk
115        for chunk_sentences in sentences.chunks(chunk_size) {
116            let chunk_content = chunk_sentences.join(". ") + ".";
117            let chunk_id = ChunkId::new(format!(
118                "{}_{}",
119                self.document_id,
120                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
121            ));
122            let chunk_start = current_pos;
123            let chunk_end = chunk_start + chunk_content.len();
124
125            let chunk = TextChunk::new(
126                chunk_id,
127                self.document_id.clone(),
128                chunk_content,
129                chunk_start,
130                chunk_end,
131            );
132            chunks.push(chunk);
133            current_pos = chunk_end;
134        }
135
136        chunks
137    }
138}
139
140/// Rust code chunking strategy using tree-sitter
141///
142/// Parses Rust code using tree-sitter and creates chunks at function/method boundaries.
143/// This ensures that code chunks are syntactically complete and meaningful.
#[cfg(feature = "code-chunking")]
pub struct RustCodeChunkingStrategy {
    // Minimum chunk length in bytes; AST items whose source text is shorter
    // (e.g. tiny type aliases) are skipped by `extract_chunks`.
    min_chunk_size: usize,
    // Document the produced chunks belong to; used to build chunk IDs.
    document_id: DocumentId,
}
149
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Create a new Rust code chunking strategy.
    ///
    /// `min_chunk_size` filters out AST items whose source text is shorter
    /// than the given number of bytes.
    pub fn new(min_chunk_size: usize, document_id: DocumentId) -> Self {
        RustCodeChunkingStrategy {
            document_id,
            min_chunk_size,
        }
    }
}
160
#[cfg(feature = "code-chunking")]
impl ChunkingStrategy for RustCodeChunkingStrategy {
    /// Parse `text` as Rust source and emit one chunk per top-level item.
    ///
    /// Falls back to a single chunk covering the whole input when the parse
    /// yields no extractable items (e.g. the input is only expressions).
    ///
    /// # Panics
    /// Panics if the Rust grammar fails to load or the parser returns no tree
    /// (both indicate a broken build rather than bad input).
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        use tree_sitter::Parser;

        let mut parser = Parser::new();
        parser
            .set_language(&tree_sitter_rust::language())
            .expect("Error loading Rust grammar");

        let tree = parser.parse(text, None).expect("Error parsing Rust code");

        let mut chunks = Vec::new();
        self.extract_chunks(&tree.root_node(), text, &mut chunks);

        // Nothing extracted: produce one whole-input chunk unless blank.
        if chunks.is_empty() && !text.trim().is_empty() {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                text.to_string(),
                0,
                text.len(),
            ));
        }

        chunks
    }
}
200
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Recursively walk the AST, emitting one chunk per item-level node
    /// (functions, impls, structs, enums, modules, traits) whose source text
    /// is at least `self.min_chunk_size` bytes long.
    ///
    /// Item nodes are NOT recursed into, so nested items (e.g. methods inside
    /// an `impl`) remain part of their parent's chunk.
    fn extract_chunks(&self, node: &tree_sitter::Node, source: &str, chunks: &mut Vec<TextChunk>) {
        match node.kind() {
            // Item kinds that become chunks.
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "mod_item"
            | "trait_item" => {
                // tree-sitter byte offsets index directly into `source`; the
                // previous `source.len() - source[b..].len()` "conversion" was
                // an identity and has been removed.
                let start_pos = node.start_byte();
                let end_pos = node.end_byte();
                let chunk_content = &source[start_pos..end_pos];

                if chunk_content.len() >= self.min_chunk_size {
                    let chunk_id = ChunkId::new(format!(
                        "{}_{}",
                        self.document_id,
                        CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
                    ));
                    chunks.push(TextChunk::new(
                        chunk_id,
                        self.document_id.clone(),
                        chunk_content.to_string(),
                        start_pos,
                        end_pos,
                    ));
                }
            },

            // Everything else — including the `source_file` root, which the
            // old code handled with an identical duplicate arm — recurses into
            // its children.
            _ => {
                let mut child = node.child(0);
                while let Some(current) = child {
                    self.extract_chunks(&current, source, chunks);
                    child = current.next_sibling();
                }
            },
        }
    }
}
256
257/// Boundary-Aware Chunking Strategy (BAR-RAG)
258///
259/// This strategy implements the BAR-RAG (Boundary-Aware Retrieval-Augmented Generation)
260/// approach by:
261/// 1. Detecting semantic boundaries in text (sentences, paragraphs, headings, etc.)
262/// 2. Scoring chunk coherence using sentence embeddings
263/// 3. Finding optimal split points that maximize semantic unity
264///
265/// **Performance Target**: +40% semantic coherence, -60% entity fragmentation
pub struct BoundaryAwareChunkingStrategy {
    // Detects candidate split points (paragraphs, headings, code blocks, ...).
    boundary_detector: crate::text::BoundaryDetector,
    // Scores candidate splits by semantic coherence (async, embedding-based).
    coherence_scorer: std::sync::Arc<crate::text::SemanticCoherenceScorer>,
    // Hard upper bound on chunk length; despite the name these are measured in
    // bytes (`content.len()`), not characters. Oversized chunks get split.
    max_chunk_chars: usize,
    // Soft lower bound on chunk length (bytes); undersized chunks are merged
    // into the preceding chunk.
    min_chunk_chars: usize,
    // Document the produced chunks belong to; used to build chunk IDs.
    document_id: DocumentId,
}
273
274impl BoundaryAwareChunkingStrategy {
275    /// Create a new boundary-aware chunking strategy
276    ///
277    /// # Arguments
278    /// * `boundary_config` - Configuration for boundary detection
279    /// * `coherence_config` - Configuration for coherence scoring
280    /// * `embedding_provider` - Provider for generating sentence embeddings
281    /// * `max_chunk_chars` - Maximum characters per chunk
282    /// * `min_chunk_chars` - Minimum characters per chunk
283    /// * `document_id` - Document identifier for chunk IDs
284    pub fn new(
285        boundary_config: crate::text::BoundaryDetectionConfig,
286        coherence_config: crate::text::CoherenceConfig,
287        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
288        max_chunk_chars: usize,
289        min_chunk_chars: usize,
290        document_id: DocumentId,
291    ) -> Self {
292        Self {
293            boundary_detector: crate::text::BoundaryDetector::with_config(boundary_config),
294            coherence_scorer: std::sync::Arc::new(crate::text::SemanticCoherenceScorer::new(
295                coherence_config,
296                embedding_provider,
297            )),
298            max_chunk_chars,
299            min_chunk_chars,
300            document_id,
301        }
302    }
303
304    /// Create with default configuration
305    pub fn with_defaults(
306        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
307        document_id: DocumentId,
308    ) -> Self {
309        Self::new(
310            crate::text::BoundaryDetectionConfig::default(),
311            crate::text::CoherenceConfig::default(),
312            embedding_provider,
313            2000, // max chars
314            200,  // min chars
315            document_id,
316        )
317    }
318
319    /// Chunk text asynchronously (helper for async contexts)
320    async fn chunk_async(&self, text: &str) -> Vec<TextChunk> {
321        // 1. Detect all semantic boundaries
322        let boundaries = self.boundary_detector.detect_boundaries(text);
323
324        // Extract boundary positions suitable for splitting
325        let boundary_positions: Vec<usize> = boundaries
326            .iter()
327            .filter(|b| {
328                // Filter boundaries that are good split points
329                matches!(
330                    b.boundary_type,
331                    crate::text::BoundaryType::Paragraph
332                        | crate::text::BoundaryType::Heading
333                        | crate::text::BoundaryType::CodeBlock
334                )
335            })
336            .map(|b| b.position)
337            .collect();
338
339        // 2. Find optimal splits using coherence scoring
340        let optimal_result = self
341            .coherence_scorer
342            .find_optimal_split(text, &boundary_positions)
343            .await;
344
345        let chunks = match optimal_result {
346            Ok(result) => {
347                // Use optimally scored chunks
348                self.create_text_chunks_from_scored(&result.chunks)
349            },
350            Err(_) => {
351                // Fallback: use boundary positions directly
352                self.create_text_chunks_from_boundaries(text, &boundary_positions)
353            },
354        };
355
356        // 3. Enforce size constraints
357        self.enforce_size_constraints(chunks)
358    }
359
360    /// Create TextChunk objects from scored chunks
361    fn create_text_chunks_from_scored(
362        &self,
363        scored_chunks: &[crate::text::ScoredChunk],
364    ) -> Vec<TextChunk> {
365        scored_chunks
366            .iter()
367            .enumerate()
368            .map(|(i, sc)| {
369                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
370                let mut chunk = TextChunk::new(
371                    chunk_id,
372                    self.document_id.clone(),
373                    sc.text.clone(),
374                    sc.start_pos,
375                    sc.end_pos,
376                );
377
378                // Add coherence score to metadata
379                chunk.metadata.custom.insert(
380                    "coherence_score".to_string(),
381                    sc.coherence_score.to_string(),
382                );
383                chunk
384                    .metadata
385                    .custom
386                    .insert("sentence_count".to_string(), sc.sentence_count.to_string());
387
388                chunk
389            })
390            .collect()
391    }
392
393    /// Create TextChunk objects from boundary positions (fallback)
394    fn create_text_chunks_from_boundaries(
395        &self,
396        text: &str,
397        boundaries: &[usize],
398    ) -> Vec<TextChunk> {
399        let mut chunks = Vec::new();
400        let mut prev_pos = 0;
401
402        for (i, &pos) in boundaries.iter().enumerate() {
403            if pos > prev_pos {
404                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
405                let chunk = TextChunk::new(
406                    chunk_id,
407                    self.document_id.clone(),
408                    text[prev_pos..pos].to_string(),
409                    prev_pos,
410                    pos,
411                );
412                chunks.push(chunk);
413                prev_pos = pos;
414            }
415        }
416
417        // Add final chunk
418        if prev_pos < text.len() {
419            let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, chunks.len()));
420            let chunk = TextChunk::new(
421                chunk_id,
422                self.document_id.clone(),
423                text[prev_pos..].to_string(),
424                prev_pos,
425                text.len(),
426            );
427            chunks.push(chunk);
428        }
429
430        chunks
431    }
432
433    /// Enforce size constraints on chunks
434    fn enforce_size_constraints(&self, mut chunks: Vec<TextChunk>) -> Vec<TextChunk> {
435        let mut result = Vec::new();
436
437        for chunk in chunks.drain(..) {
438            let chunk_len = chunk.content.len();
439
440            if chunk_len > self.max_chunk_chars {
441                // Split large chunks at sentence boundaries
442                result.extend(self.split_large_chunk(chunk));
443            } else if chunk_len < self.min_chunk_chars && !result.is_empty() {
444                // Merge small chunks with previous
445                if let Some(mut prev_chunk) = result.pop() {
446                    prev_chunk.content.push(' ');
447                    prev_chunk.content.push_str(&chunk.content);
448                    prev_chunk.end_offset = chunk.end_offset;
449                    result.push(prev_chunk);
450                } else {
451                    result.push(chunk);
452                }
453            } else {
454                result.push(chunk);
455            }
456        }
457
458        result
459    }
460
    /// Split a large chunk at sentence boundaries
    ///
    /// Greedily packs sentences into sub-chunks of at most
    /// `self.max_chunk_chars` bytes; a single sentence longer than the limit
    /// still becomes its own oversized sub-chunk.
    ///
    /// NOTE(review): splitting on '.'/'!'/'?' discards the original
    /// terminator and re-appends '.', so '!'/'?' endings get rewritten and the
    /// reconstructed offsets are approximate (delimiters and inter-sentence
    /// whitespace are not counted) rather than exact source positions —
    /// confirm downstream consumers tolerate this.
    fn split_large_chunk(&self, chunk: TextChunk) -> Vec<TextChunk> {
        // Simple split at sentence boundaries
        let sentences: Vec<&str> = chunk
            .content
            .split(&['.', '!', '?'][..])
            .filter(|s| !s.trim().is_empty())
            .collect();

        let mut sub_chunks = Vec::new();
        let mut current_text = String::new();
        // Running start offset; begins at the parent chunk's start.
        let mut current_start = chunk.start_offset;

        for sentence in sentences {
            // Flush the accumulated buffer once adding this sentence would
            // exceed the cap (never flush an empty buffer).
            if current_text.len() + sentence.len() > self.max_chunk_chars
                && !current_text.is_empty()
            {
                // Create chunk
                let chunk_id = ChunkId::new(format!(
                    "{}_{}",
                    self.document_id,
                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
                ));
                let end = current_start + current_text.len();
                sub_chunks.push(TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    current_text.clone(),
                    current_start,
                    end,
                ));

                current_start = end;
                current_text.clear();
            }

            current_text.push_str(sentence);
            current_text.push('.');
        }

        // Add remaining text; its end is pinned to the parent chunk's end.
        if !current_text.is_empty() {
            let chunk_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            sub_chunks.push(TextChunk::new(
                chunk_id,
                self.document_id.clone(),
                current_text,
                current_start,
                chunk.end_offset,
            ));
        }

        sub_chunks
    }
519}
520
521impl ChunkingStrategy for BoundaryAwareChunkingStrategy {
522    fn chunk(&self, text: &str) -> Vec<TextChunk> {
523        // Create a Tokio runtime to execute async code synchronously
524        // This allows us to use async coherence scoring in a sync context
525        let runtime = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
526
527        runtime.block_on(self.chunk_async(text))
528    }
529}
530
531#[cfg(test)]
532mod tests {
533    use super::*;
534
535    #[test]
536    fn test_hierarchical_chunking_strategy() {
537        let document_id = DocumentId::new("test_doc".to_string());
538        let strategy = HierarchicalChunkingStrategy::new(100, 20, document_id);
539
540        let text = "This is paragraph one.\n\nThis is paragraph two with more content to test chunking behavior.";
541        let chunks = strategy.chunk(text);
542
543        assert!(!chunks.is_empty());
544        for chunk in &chunks {
545            assert!(!chunk.content.is_empty());
546            assert!(chunk.start_offset < chunk.end_offset);
547        }
548    }
549
550    #[test]
551    fn test_semantic_chunking_strategy() {
552        let document_id = DocumentId::new("test_doc".to_string());
553        // Note: In a real test, you would create a proper SemanticChunker
554        // For now, we'll use a mock approach
555        let config = crate::text::semantic_chunking::SemanticChunkerConfig::default();
556        // We can't easily create a mock embedding generator here, so skip the test
557        // let embedding_gen = crate::vector::EmbeddingGenerator::mock();
558        // let chunker = SemanticChunker::new(config, embedding_gen);
559        // let strategy = SemanticChunkingStrategy::new(chunker, document_id);
560        //
561        // let text = "First sentence. Second sentence. Third sentence. Fourth sentence. Fifth sentence. Sixth sentence.";
562        // let chunks = strategy.chunk(text);
563        //
564        // assert!(!chunks.is_empty());
565        // for chunk in &chunks {
566        //     assert!(!chunk.content.is_empty());
567        // }
568    }
569
570    #[test]
571    #[cfg(feature = "code-chunking")]
572    fn test_rust_code_chunking_strategy() {
573        let document_id = DocumentId::new("rust_code".to_string());
574        let strategy = RustCodeChunkingStrategy::new(10, document_id);
575
576        let rust_code = r#"
577fn main() {
578    println!("Hello, world!");
579}
580
581struct Point {
582    x: f64,
583    y: f64,
584}
585
586impl Point {
587    fn new(x: f64, y: f64) -> Self {
588        Point { x, y }
589    }
590}
591"#;
592
593        let chunks = strategy.chunk(rust_code);
594
595        assert!(!chunks.is_empty());
596        // Should find at least main function and struct/impl blocks
597        assert!(chunks.len() >= 2);
598
599        for chunk in &chunks {
600            assert!(!chunk.content.is_empty());
601            assert!(chunk.start_offset < chunk.end_offset);
602        }
603    }
604}