Skip to main content

graphrag_core/text/
chunking_strategies.rs

1//! Trait-based chunking strategy implementations
2//!
3//! This module provides concrete implementations of the ChunkingStrategy trait
4//! that wrap existing chunking logic while maintaining a clean, minimal interface.
5
6use crate::{
7    core::{ChunkId, ChunkingStrategy, DocumentId, TextChunk},
8    text::{HierarchicalChunker, SemanticChunker},
9};
10
11use std::sync::atomic::{AtomicU64, Ordering};
12
/// Process-wide sequence number used to build `<document_id>_<n>` chunk IDs.
///
/// Each `fetch_add` returns a value that has not been handed out before in this process
/// (until the `u64` wraps), so IDs built from it do not repeat.
static CHUNK_COUNTER: AtomicU64 = AtomicU64::new(u64::MIN);
15
16/// Hierarchical chunking strategy wrapper
17///
18/// Wraps the existing HierarchicalChunker to implement ChunkingStrategy trait.
19/// This strategy respects semantic boundaries (paragraphs, sentences, words).
20pub struct HierarchicalChunkingStrategy {
21    inner: HierarchicalChunker,
22    chunk_size: usize,
23    overlap: usize,
24    document_id: DocumentId,
25}
26
27impl HierarchicalChunkingStrategy {
28    /// Create a new hierarchical chunking strategy
29    pub fn new(chunk_size: usize, overlap: usize, document_id: DocumentId) -> Self {
30        Self {
31            inner: HierarchicalChunker::new().with_min_size(50),
32            chunk_size,
33            overlap,
34            document_id,
35        }
36    }
37
38    /// Set minimum chunk size
39    pub fn with_min_size(mut self, min_size: usize) -> Self {
40        self.inner = self.inner.with_min_size(min_size);
41        self
42    }
43}
44
45impl ChunkingStrategy for HierarchicalChunkingStrategy {
46    fn chunk(&self, text: &str) -> Vec<TextChunk> {
47        let chunks_text = self.inner.chunk_text(text, self.chunk_size, self.overlap);
48        let mut chunks = Vec::new();
49        let mut current_pos = 0;
50
51        for chunk_content in chunks_text {
52            if !chunk_content.trim().is_empty() {
53                let chunk_id = ChunkId::new(format!(
54                    "{}_{}",
55                    self.document_id,
56                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
57                ));
58                let chunk_start = current_pos;
59                let chunk_end = chunk_start + chunk_content.len();
60
61                let chunk = TextChunk::new(
62                    chunk_id,
63                    self.document_id.clone(),
64                    chunk_content.clone(),
65                    chunk_start,
66                    chunk_end,
67                );
68                chunks.push(chunk);
69                current_pos = chunk_end;
70            } else {
71                current_pos += chunk_content.len();
72            }
73        }
74
75        chunks
76    }
77}
78
79/// Semantic chunking strategy wrapper
80///
81/// Wraps the existing SemanticChunker to implement ChunkingStrategy trait.
82/// This strategy uses embedding similarity to determine natural breakpoints.
83pub struct SemanticChunkingStrategy {
84    _inner: SemanticChunker,
85    document_id: DocumentId,
86}
87
88impl SemanticChunkingStrategy {
89    /// Create a new semantic chunking strategy
90    pub fn new(chunker: SemanticChunker, document_id: DocumentId) -> Self {
91        Self {
92            _inner: chunker,
93            document_id,
94        }
95    }
96}
97
98impl ChunkingStrategy for SemanticChunkingStrategy {
99    fn chunk(&self, text: &str) -> Vec<TextChunk> {
100        // Note: This is a simplified implementation
101        // In a real scenario, you would need to handle the async nature of semantic chunking
102        // or use a synchronous embedding generator
103
104        // For now, fall back to a simple sentence-based approach
105        let sentences: Vec<&str> = text
106            .split(&['.', '!', '?'][..])
107            .filter(|s| !s.trim().is_empty())
108            .collect();
109
110        let mut chunks = Vec::new();
111        let mut current_pos = 0;
112
113        // Group sentences into chunks of reasonable size
114        let chunk_size = 5; // sentences per chunk
115        for chunk_sentences in sentences.chunks(chunk_size) {
116            let chunk_content = chunk_sentences.join(". ") + ".";
117            let chunk_id = ChunkId::new(format!(
118                "{}_{}",
119                self.document_id,
120                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
121            ));
122            let chunk_start = current_pos;
123            let chunk_end = chunk_start + chunk_content.len();
124
125            let chunk = TextChunk::new(
126                chunk_id,
127                self.document_id.clone(),
128                chunk_content,
129                chunk_start,
130                chunk_end,
131            );
132            chunks.push(chunk);
133            current_pos = chunk_end;
134        }
135
136        chunks
137    }
138}
139
/// Splits Rust source into chunks at item boundaries using tree-sitter.
///
/// Functions, `impl` blocks, structs, enums, modules and traits each become one chunk, so
/// every chunk is a syntactically complete item.
#[cfg(feature = "code-chunking")]
pub struct RustCodeChunkingStrategy {
    min_chunk_size: usize,
    document_id: DocumentId,
}

#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Creates a strategy that skips items shorter than `min_chunk_size` bytes and tags every
    /// chunk with `document_id`.
    pub fn new(min_chunk_size: usize, document_id: DocumentId) -> Self {
        Self {
            document_id,
            min_chunk_size,
        }
    }
}
160
#[cfg(feature = "code-chunking")]
impl ChunkingStrategy for RustCodeChunkingStrategy {
    /// Parses `text` as Rust and emits one chunk per qualifying item.
    ///
    /// If no item qualifies (for example, the input holds only expressions) and `text` is not
    /// blank, the whole input becomes a single chunk spanning `0..text.len()`.
    ///
    /// # Panics
    /// Panics if the tree-sitter Rust grammar cannot be loaded or parsing returns no tree.
    fn chunk(&self, text: &str) -> Vec<TextChunk> {
        let mut parser = tree_sitter::Parser::new();
        let grammar = tree_sitter_rust::language();
        parser
            .set_language(&grammar)
            .expect("Error loading Rust grammar");
        let tree = parser.parse(text, None).expect("Error parsing Rust code");

        let mut chunks = Vec::new();
        // Walk the syntax tree from the root and collect item-level chunks.
        self.extract_chunks(&tree.root_node(), text, &mut chunks);

        let needs_fallback = chunks.is_empty() && !text.trim().is_empty();
        if needs_fallback {
            let whole_id = ChunkId::new(format!(
                "{}_{}",
                self.document_id,
                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
            ));
            chunks.push(TextChunk::new(
                whole_id,
                self.document_id.clone(),
                text.to_owned(),
                0,
                text.len(),
            ));
        }

        chunks
    }
}
200
#[cfg(feature = "code-chunking")]
impl RustCodeChunkingStrategy {
    /// Recursively collects item-level chunks from the subtree rooted at `node`.
    ///
    /// Matched items are functions, `impl` blocks, structs, enums, modules and traits. Each
    /// becomes one chunk and is not descended into, so methods and nested items stay inside
    /// their enclosing item's chunk. Matched items shorter than `min_chunk_size` bytes are
    /// dropped. Every other node kind, including the `source_file` root, is simply traversed.
    ///
    /// Chunk offsets are tree-sitter byte offsets into `source`.
    fn extract_chunks(&self, node: &tree_sitter::Node, source: &str, chunks: &mut Vec<TextChunk>) {
        let is_item = matches!(
            node.kind(),
            "function_item" | "impl_item" | "struct_item" | "enum_item" | "mod_item" | "trait_item"
        );

        if !is_item {
            let mut cursor = node.walk();
            for child in node.children(&mut cursor) {
                self.extract_chunks(&child, source, chunks);
            }
            return;
        }

        let (start, end) = (node.start_byte(), node.end_byte());
        // tree-sitter already reports byte offsets, so no conversion is needed. `get` returns
        // `None` instead of panicking if they don't land on char boundaries of `source`.
        if let Some(content) = source.get(start..end) {
            if content.len() >= self.min_chunk_size {
                let chunk_id = ChunkId::new(format!(
                    "{}_{}",
                    self.document_id,
                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
                ));
                chunks.push(TextChunk::new(
                    chunk_id,
                    self.document_id.clone(),
                    content.to_string(),
                    start,
                    end,
                ));
            }
        }
    }
}
256
257/// Boundary-Aware Chunking Strategy (BAR-RAG)
258///
259/// This strategy implements the BAR-RAG (Boundary-Aware Retrieval-Augmented Generation)
260/// approach by:
261/// 1. Detecting semantic boundaries in text (sentences, paragraphs, headings, etc.)
262/// 2. Scoring chunk coherence using sentence embeddings
263/// 3. Finding optimal split points that maximize semantic unity
264///
265/// **Performance Target**: +40% semantic coherence, -60% entity fragmentation
266pub struct BoundaryAwareChunkingStrategy {
267    #[cfg_attr(not(feature = "async"), allow(dead_code))]
268    boundary_detector: crate::text::BoundaryDetector,
269    #[cfg_attr(not(feature = "async"), allow(dead_code))]
270    coherence_scorer: std::sync::Arc<crate::text::SemanticCoherenceScorer>,
271    max_chunk_chars: usize,
272    #[cfg_attr(not(feature = "async"), allow(dead_code))]
273    min_chunk_chars: usize,
274    document_id: DocumentId,
275}
276
277impl BoundaryAwareChunkingStrategy {
278    /// Create a new boundary-aware chunking strategy
279    ///
280    /// # Arguments
281    /// * `boundary_config` - Configuration for boundary detection
282    /// * `coherence_config` - Configuration for coherence scoring
283    /// * `embedding_provider` - Provider for generating sentence embeddings
284    /// * `max_chunk_chars` - Maximum characters per chunk
285    /// * `min_chunk_chars` - Minimum characters per chunk
286    /// * `document_id` - Document identifier for chunk IDs
287    pub fn new(
288        boundary_config: crate::text::BoundaryDetectionConfig,
289        coherence_config: crate::text::CoherenceConfig,
290        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
291        max_chunk_chars: usize,
292        min_chunk_chars: usize,
293        document_id: DocumentId,
294    ) -> Self {
295        Self {
296            boundary_detector: crate::text::BoundaryDetector::with_config(boundary_config),
297            coherence_scorer: std::sync::Arc::new(crate::text::SemanticCoherenceScorer::new(
298                coherence_config,
299                embedding_provider,
300            )),
301            max_chunk_chars,
302            min_chunk_chars,
303            document_id,
304        }
305    }
306
307    /// Create with default configuration
308    pub fn with_defaults(
309        embedding_provider: std::sync::Arc<dyn crate::embeddings::EmbeddingProvider>,
310        document_id: DocumentId,
311    ) -> Self {
312        Self::new(
313            crate::text::BoundaryDetectionConfig::default(),
314            crate::text::CoherenceConfig::default(),
315            embedding_provider,
316            2000, // max chars
317            200,  // min chars
318            document_id,
319        )
320    }
321
322    /// Chunk text asynchronously (helper for async contexts)
323    #[cfg(feature = "async")]
324    async fn chunk_async(&self, text: &str) -> Vec<TextChunk> {
325        // 1. Detect all semantic boundaries
326        let boundaries = self.boundary_detector.detect_boundaries(text);
327
328        // Extract boundary positions suitable for splitting
329        let boundary_positions: Vec<usize> = boundaries
330            .iter()
331            .filter(|b| {
332                // Filter boundaries that are good split points
333                matches!(
334                    b.boundary_type,
335                    crate::text::BoundaryType::Paragraph
336                        | crate::text::BoundaryType::Heading
337                        | crate::text::BoundaryType::CodeBlock
338                )
339            })
340            .map(|b| b.position)
341            .collect();
342
343        // 2. Find optimal splits using coherence scoring
344        let optimal_result = self
345            .coherence_scorer
346            .find_optimal_split(text, &boundary_positions)
347            .await;
348
349        let chunks = match optimal_result {
350            Ok(result) => {
351                // Use optimally scored chunks
352                self.create_text_chunks_from_scored(&result.chunks)
353            },
354            Err(_) => {
355                // Fallback: use boundary positions directly
356                self.create_text_chunks_from_boundaries(text, &boundary_positions)
357            },
358        };
359
360        // 3. Enforce size constraints
361        self.enforce_size_constraints(chunks)
362    }
363
364    /// Create TextChunk objects from scored chunks
365    #[cfg(feature = "async")]
366    fn create_text_chunks_from_scored(
367        &self,
368        scored_chunks: &[crate::text::ScoredChunk],
369    ) -> Vec<TextChunk> {
370        scored_chunks
371            .iter()
372            .enumerate()
373            .map(|(i, sc)| {
374                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
375                let mut chunk = TextChunk::new(
376                    chunk_id,
377                    self.document_id.clone(),
378                    sc.text.clone(),
379                    sc.start_pos,
380                    sc.end_pos,
381                );
382
383                // Add coherence score to metadata
384                chunk.metadata.custom.insert(
385                    "coherence_score".to_string(),
386                    sc.coherence_score.to_string(),
387                );
388                chunk
389                    .metadata
390                    .custom
391                    .insert("sentence_count".to_string(), sc.sentence_count.to_string());
392
393                chunk
394            })
395            .collect()
396    }
397
398    /// Create TextChunk objects from boundary positions (fallback)
399    #[cfg(feature = "async")]
400    fn create_text_chunks_from_boundaries(
401        &self,
402        text: &str,
403        boundaries: &[usize],
404    ) -> Vec<TextChunk> {
405        let mut chunks = Vec::new();
406        let mut prev_pos = 0;
407
408        for (i, &pos) in boundaries.iter().enumerate() {
409            if pos > prev_pos {
410                let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, i));
411                let chunk = TextChunk::new(
412                    chunk_id,
413                    self.document_id.clone(),
414                    text[prev_pos..pos].to_string(),
415                    prev_pos,
416                    pos,
417                );
418                chunks.push(chunk);
419                prev_pos = pos;
420            }
421        }
422
423        // Add final chunk
424        if prev_pos < text.len() {
425            let chunk_id = ChunkId::new(format!("{}_{}", self.document_id, chunks.len()));
426            let chunk = TextChunk::new(
427                chunk_id,
428                self.document_id.clone(),
429                text[prev_pos..].to_string(),
430                prev_pos,
431                text.len(),
432            );
433            chunks.push(chunk);
434        }
435
436        chunks
437    }
438
439    /// Enforce size constraints on chunks
440    #[cfg(feature = "async")]
441    fn enforce_size_constraints(&self, mut chunks: Vec<TextChunk>) -> Vec<TextChunk> {
442        let mut result = Vec::new();
443
444        for chunk in chunks.drain(..) {
445            let chunk_len = chunk.content.len();
446
447            if chunk_len > self.max_chunk_chars {
448                // Split large chunks at sentence boundaries
449                result.extend(self.split_large_chunk(chunk));
450            } else if chunk_len < self.min_chunk_chars && !result.is_empty() {
451                // Merge small chunks with previous
452                if let Some(mut prev_chunk) = result.pop() {
453                    prev_chunk.content.push(' ');
454                    prev_chunk.content.push_str(&chunk.content);
455                    prev_chunk.end_offset = chunk.end_offset;
456                    result.push(prev_chunk);
457                } else {
458                    result.push(chunk);
459                }
460            } else {
461                result.push(chunk);
462            }
463        }
464
465        result
466    }
467
468    /// Split a large chunk at sentence boundaries
469    #[cfg(feature = "async")]
470    fn split_large_chunk(&self, chunk: TextChunk) -> Vec<TextChunk> {
471        // Simple split at sentence boundaries
472        let sentences: Vec<&str> = chunk
473            .content
474            .split(&['.', '!', '?'][..])
475            .filter(|s| !s.trim().is_empty())
476            .collect();
477
478        let mut sub_chunks = Vec::new();
479        let mut current_text = String::new();
480        let mut current_start = chunk.start_offset;
481
482        for sentence in sentences {
483            if current_text.len() + sentence.len() > self.max_chunk_chars
484                && !current_text.is_empty()
485            {
486                // Create chunk
487                let chunk_id = ChunkId::new(format!(
488                    "{}_{}",
489                    self.document_id,
490                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
491                ));
492                let end = current_start + current_text.len();
493                sub_chunks.push(TextChunk::new(
494                    chunk_id,
495                    self.document_id.clone(),
496                    current_text.clone(),
497                    current_start,
498                    end,
499                ));
500
501                current_start = end;
502                current_text.clear();
503            }
504
505            current_text.push_str(sentence);
506            current_text.push('.');
507        }
508
509        // Add remaining text
510        if !current_text.is_empty() {
511            let chunk_id = ChunkId::new(format!(
512                "{}_{}",
513                self.document_id,
514                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
515            ));
516            sub_chunks.push(TextChunk::new(
517                chunk_id,
518                self.document_id.clone(),
519                current_text,
520                current_start,
521                chunk.end_offset,
522            ));
523        }
524
525        sub_chunks
526    }
527}
528
529impl ChunkingStrategy for BoundaryAwareChunkingStrategy {
530    #[cfg(feature = "async")]
531    fn chunk(&self, text: &str) -> Vec<TextChunk> {
532        // Create a Tokio runtime to execute async code synchronously
533        // This allows us to use async coherence scoring in a sync context
534        let runtime = tokio::runtime::Runtime::new().expect("Failed to create Tokio runtime");
535
536        runtime.block_on(self.chunk_async(text))
537    }
538
539    #[cfg(not(feature = "async"))]
540    fn chunk(&self, text: &str) -> Vec<TextChunk> {
541        // Sync fallback: simple sentence-based chunking without coherence scoring
542        let sentences: Vec<&str> = text
543            .split(['.', '!', '?'])
544            .map(|s| s.trim())
545            .filter(|s| !s.is_empty())
546            .collect();
547
548        let mut chunks = Vec::new();
549        let mut current = String::new();
550        let mut start_offset = 0;
551
552        for sentence in &sentences {
553            if current.len() + sentence.len() > self.max_chunk_chars && !current.is_empty() {
554                let chunk_id = ChunkId::new(format!(
555                    "{}_{}",
556                    self.document_id,
557                    CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
558                ));
559                let end = start_offset + current.len();
560                chunks.push(TextChunk::new(
561                    chunk_id,
562                    self.document_id.clone(),
563                    current.clone(),
564                    start_offset,
565                    end,
566                ));
567                start_offset = end;
568                current.clear();
569            }
570            if !current.is_empty() {
571                current.push(' ');
572            }
573            current.push_str(sentence);
574        }
575
576        if !current.is_empty() {
577            let chunk_id = ChunkId::new(format!(
578                "{}_{}",
579                self.document_id,
580                CHUNK_COUNTER.fetch_add(1, Ordering::SeqCst)
581            ));
582            let end = start_offset + current.len();
583            chunks.push(TextChunk::new(
584                chunk_id,
585                self.document_id.clone(),
586                current,
587                start_offset,
588                end,
589            ));
590        }
591
592        chunks
593    }
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_hierarchical_chunking_strategy() {
        let document_id = DocumentId::new("test_doc".to_string());
        let strategy = HierarchicalChunkingStrategy::new(100, 20, document_id);

        let text = "This is paragraph one.\n\nThis is paragraph two with more content to test chunking behavior.";
        let chunks = strategy.chunk(text);

        assert!(!chunks.is_empty());
        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert!(chunk.start_offset < chunk.end_offset);
        }
    }

    #[test]
    fn test_semantic_chunking_strategy() {
        // A `SemanticChunker` needs an embedding generator, and no mock is available in this
        // module, so this only checks that the default configuration can be built.
        // TODO: exercise `SemanticChunkingStrategy::chunk` once a mock embedding generator
        // exists.
        let _config = crate::text::semantic_chunking::SemanticChunkerConfig::default();
    }

    #[test]
    #[cfg(feature = "code-chunking")]
    fn test_rust_code_chunking_strategy() {
        let document_id = DocumentId::new("rust_code".to_string());
        let strategy = RustCodeChunkingStrategy::new(10, document_id);

        let rust_code = r#"
fn main() {
    println!("Hello, world!");
}

struct Point {
    x: f64,
    y: f64,
}

impl Point {
    fn new(x: f64, y: f64) -> Self {
        Point { x, y }
    }
}
"#;

        let chunks = strategy.chunk(rust_code);

        // `main`, `struct Point` and `impl Point` are each long enough to form a chunk; the
        // method inside `impl Point` stays part of the impl chunk.
        assert_eq!(chunks.len(), 3);
        assert!(chunks[0].content.starts_with("fn main"));
        assert!(chunks[1].content.starts_with("struct Point"));
        assert!(chunks[2].content.starts_with("impl Point"));

        for chunk in &chunks {
            assert!(!chunk.content.is_empty());
            assert!(chunk.start_offset < chunk.end_offset);
            // Offsets must index back into the source text.
            assert_eq!(&rust_code[chunk.start_offset..chunk.end_offset], chunk.content);
        }
    }
}