Skip to main content

graphrag_core/text/
boundary_detection.rs

1//! Semantic Boundary Detection for Boundary-Aware Chunking
2//!
3//! This module implements intelligent detection of semantic boundaries in text,
4//! enabling chunking strategies that respect natural document structure.
5//!
6//! Key capabilities:
7//! - Sentence boundary detection (NLTK-style rules)
8//! - Paragraph detection (newline patterns)
9//! - Heading detection (Markdown, RST, plaintext)
10//! - List boundary detection
11//! - Code block detection
12//!
13//! ## References
14//!
15//! - BAR-RAG Paper: "Boundary-Aware Retrieval-Augmented Generation"
16//! - Target: +40% semantic coherence, -60% entity fragmentation
17
18use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22/// Configuration for boundary detection
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct BoundaryDetectionConfig {
25    /// Enable sentence boundary detection
26    pub detect_sentences: bool,
27
28    /// Enable paragraph boundary detection
29    pub detect_paragraphs: bool,
30
31    /// Enable heading boundary detection
32    pub detect_headings: bool,
33
34    /// Enable list boundary detection
35    pub detect_lists: bool,
36
37    /// Enable code block boundary detection
38    pub detect_code_blocks: bool,
39
40    /// Minimum sentence length (characters)
41    pub min_sentence_length: usize,
42
43    /// Heading markers (for plaintext detection)
44    pub heading_markers: Vec<String>,
45}
46
47impl Default for BoundaryDetectionConfig {
48    fn default() -> Self {
49        Self {
50            detect_sentences: true,
51            detect_paragraphs: true,
52            detect_headings: true,
53            detect_lists: true,
54            detect_code_blocks: true,
55            min_sentence_length: 10,
56            heading_markers: vec![
57                "Chapter".to_string(),
58                "Section".to_string(),
59                "Introduction".to_string(),
60                "Conclusion".to_string(),
61            ],
62        }
63    }
64}
65
66/// Type of boundary detected
67#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
68pub enum BoundaryType {
69    /// Sentence boundary (. ! ?)
70    Sentence,
71    /// Paragraph boundary (double newline)
72    Paragraph,
73    /// Heading boundary (markdown #, RST underline)
74    Heading,
75    /// List boundary (bullet points, numbered lists)
76    List,
77    /// Code block boundary (```, indented blocks)
78    CodeBlock,
79}
80
81/// Represents a detected boundary in text
82#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Boundary {
84    /// Position in text (byte offset)
85    pub position: usize,
86
87    /// Type of boundary
88    pub boundary_type: BoundaryType,
89
90    /// Confidence score (0.0-1.0)
91    pub confidence: f32,
92
93    /// Optional context (e.g., heading text)
94    pub context: Option<String>,
95}
96
97/// Boundary detector for semantic text segmentation
98pub struct BoundaryDetector {
99    config: BoundaryDetectionConfig,
100
101    // Cached regex patterns
102    sentence_endings: Regex,
103    markdown_heading: Regex,
104    numbered_list: Regex,
105    bullet_list: Regex,
106    code_block_fence: Regex,
107    rst_heading_underline: Regex,
108}
109
110impl BoundaryDetector {
111    /// Create a new boundary detector with default configuration
112    pub fn new() -> Self {
113        Self::with_config(BoundaryDetectionConfig::default())
114    }
115
116    /// Create a boundary detector with custom configuration
117    pub fn with_config(config: BoundaryDetectionConfig) -> Self {
118        Self {
119            config,
120            // Compile regex patterns once
121            sentence_endings: Regex::new(r"[.!?]+[\s]+").unwrap(),
122            markdown_heading: Regex::new(r"^#{1,6}\s+.+$").unwrap(),
123            numbered_list: Regex::new(r"^\d+[.)]\s+").unwrap(),
124            bullet_list: Regex::new(r"^[\-\*\+]\s+").unwrap(),
125            code_block_fence: Regex::new(r"^```").unwrap(),
126            rst_heading_underline: Regex::new("^[=\\-~^\"]+\\s*$").unwrap(),
127        }
128    }
129
130    /// Detect all semantic boundaries in text
131    pub fn detect_boundaries(&self, text: &str) -> Vec<Boundary> {
132        let mut boundaries = Vec::new();
133
134        if self.config.detect_sentences {
135            boundaries.extend(self.detect_sentence_boundaries(text));
136        }
137
138        if self.config.detect_paragraphs {
139            boundaries.extend(self.detect_paragraph_boundaries(text));
140        }
141
142        if self.config.detect_headings {
143            boundaries.extend(self.detect_heading_boundaries(text));
144        }
145
146        if self.config.detect_lists {
147            boundaries.extend(self.detect_list_boundaries(text));
148        }
149
150        if self.config.detect_code_blocks {
151            boundaries.extend(self.detect_code_block_boundaries(text));
152        }
153
154        // Sort by position and deduplicate
155        boundaries.sort_by_key(|b| b.position);
156        boundaries.dedup_by_key(|b| b.position);
157
158        boundaries
159    }
160
161    /// Detect sentence boundaries using NLTK-style rules
162    fn detect_sentence_boundaries(&self, text: &str) -> Vec<Boundary> {
163        let mut boundaries = Vec::new();
164
165        // Common abbreviations that shouldn't end sentences
166        let abbreviations: HashSet<&str> = [
167            "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "etc.", "e.g.", "i.e.", "vs.",
168            "cf.", "Jan.", "Feb.", "Mar.", "Apr.", "Jun.", "Jul.", "Aug.", "Sep.", "Oct.", "Nov.",
169            "Dec.",
170        ]
171        .iter()
172        .copied()
173        .collect();
174
175        // Find all potential sentence endings
176        for mat in self.sentence_endings.find_iter(text) {
177            let position = mat.start();
178
179            // Check if this is a false positive (abbreviation)
180            let before_text = &text[..position];
181            let is_abbreviation = abbreviations
182                .iter()
183                .any(|abbr| before_text.ends_with(&abbr[..abbr.len() - 1]));
184
185            if !is_abbreviation {
186                // Check minimum sentence length
187                let sentence_start = boundaries
188                    .last()
189                    .map(|b: &Boundary| b.position)
190                    .unwrap_or(0);
191                let sentence_length = position - sentence_start;
192
193                if sentence_length >= self.config.min_sentence_length {
194                    boundaries.push(Boundary {
195                        position: mat.end(),
196                        boundary_type: BoundaryType::Sentence,
197                        confidence: 0.9,
198                        context: None,
199                    });
200                }
201            }
202        }
203
204        boundaries
205    }
206
207    /// Detect paragraph boundaries (double newlines)
208    fn detect_paragraph_boundaries(&self, text: &str) -> Vec<Boundary> {
209        let mut boundaries = Vec::new();
210
211        // Look for double newlines (paragraph breaks)
212        let paragraph_regex = Regex::new(r"\n\s*\n").unwrap();
213
214        for mat in paragraph_regex.find_iter(text) {
215            boundaries.push(Boundary {
216                position: mat.end(),
217                boundary_type: BoundaryType::Paragraph,
218                confidence: 1.0,
219                context: None,
220            });
221        }
222
223        boundaries
224    }
225
226    /// Detect heading boundaries (Markdown, RST, plaintext)
227    fn detect_heading_boundaries(&self, text: &str) -> Vec<Boundary> {
228        let mut boundaries = Vec::new();
229
230        let lines: Vec<&str> = text.lines().collect();
231        let mut current_pos = 0;
232
233        for (i, line) in lines.iter().enumerate() {
234            let line_start = current_pos;
235            let line_trimmed = line.trim();
236
237            // Markdown headings (# ## ###)
238            if self.markdown_heading.is_match(line) {
239                let heading_text = line_trimmed.trim_start_matches('#').trim();
240                boundaries.push(Boundary {
241                    position: line_start,
242                    boundary_type: BoundaryType::Heading,
243                    confidence: 0.95,
244                    context: Some(heading_text.to_string()),
245                });
246            }
247
248            // RST-style underlined headings
249            if i > 0 && self.rst_heading_underline.is_match(line_trimmed) {
250                let prev_line = lines[i - 1].trim();
251                if !prev_line.is_empty() && line_trimmed.len() >= prev_line.len() {
252                    boundaries.push(Boundary {
253                        position: line_start,
254                        boundary_type: BoundaryType::Heading,
255                        confidence: 0.9,
256                        context: Some(prev_line.to_string()),
257                    });
258                }
259            }
260
261            // Plaintext heading detection (ALL CAPS, or starts with heading marker)
262            if line_trimmed.len() > 3
263                && line_trimmed
264                    .chars()
265                    .all(|c| c.is_uppercase() || c.is_whitespace() || c.is_numeric())
266                && line_trimmed.chars().any(|c| c.is_alphabetic())
267            {
268                boundaries.push(Boundary {
269                    position: line_start,
270                    boundary_type: BoundaryType::Heading,
271                    confidence: 0.7,
272                    context: Some(line_trimmed.to_string()),
273                });
274            }
275
276            // Heading markers (Chapter, Section, etc.)
277            for marker in &self.config.heading_markers {
278                if line_trimmed.starts_with(marker) {
279                    boundaries.push(Boundary {
280                        position: line_start,
281                        boundary_type: BoundaryType::Heading,
282                        confidence: 0.85,
283                        context: Some(line_trimmed.to_string()),
284                    });
285                    break;
286                }
287            }
288
289            current_pos += line.len() + 1; // +1 for newline
290        }
291
292        boundaries
293    }
294
295    /// Detect list boundaries
296    fn detect_list_boundaries(&self, text: &str) -> Vec<Boundary> {
297        let mut boundaries = Vec::new();
298
299        let lines: Vec<&str> = text.lines().collect();
300        let mut current_pos = 0;
301        let mut in_list = false;
302
303        for line in lines {
304            let line_trimmed = line.trim();
305
306            // Check for list item
307            let is_list_item = self.numbered_list.is_match(line_trimmed)
308                || self.bullet_list.is_match(line_trimmed);
309
310            // Transition into list
311            if is_list_item && !in_list {
312                boundaries.push(Boundary {
313                    position: current_pos,
314                    boundary_type: BoundaryType::List,
315                    confidence: 0.9,
316                    context: Some("list_start".to_string()),
317                });
318                in_list = true;
319            }
320
321            // Transition out of list
322            if !is_list_item && in_list && !line_trimmed.is_empty() {
323                boundaries.push(Boundary {
324                    position: current_pos,
325                    boundary_type: BoundaryType::List,
326                    confidence: 0.9,
327                    context: Some("list_end".to_string()),
328                });
329                in_list = false;
330            }
331
332            current_pos += line.len() + 1;
333        }
334
335        boundaries
336    }
337
338    /// Detect code block boundaries
339    fn detect_code_block_boundaries(&self, text: &str) -> Vec<Boundary> {
340        let mut boundaries = Vec::new();
341
342        let lines: Vec<&str> = text.lines().collect();
343        let mut current_pos = 0;
344        let mut in_code_block = false;
345
346        for line in lines {
347            let line_trimmed = line.trim();
348
349            // Fenced code blocks (```)
350            if self.code_block_fence.is_match(line_trimmed) {
351                boundaries.push(Boundary {
352                    position: current_pos,
353                    boundary_type: BoundaryType::CodeBlock,
354                    confidence: 1.0,
355                    context: if in_code_block {
356                        Some("code_end".to_string())
357                    } else {
358                        Some("code_start".to_string())
359                    },
360                });
361                in_code_block = !in_code_block;
362            }
363
364            // Indented code blocks (4+ spaces at start)
365            if !in_code_block && line.starts_with("    ") && !line_trimmed.is_empty() {
366                boundaries.push(Boundary {
367                    position: current_pos,
368                    boundary_type: BoundaryType::CodeBlock,
369                    confidence: 0.7,
370                    context: Some("indented_code".to_string()),
371                });
372            }
373
374            current_pos += line.len() + 1;
375        }
376
377        boundaries
378    }
379
380    /// Get boundary positions of a specific type
381    pub fn get_boundaries_by_type(
382        &self,
383        boundaries: &[Boundary],
384        boundary_type: BoundaryType,
385    ) -> Vec<usize> {
386        boundaries
387            .iter()
388            .filter(|b| b.boundary_type == boundary_type)
389            .map(|b| b.position)
390            .collect()
391    }
392
393    /// Find the strongest boundary type at a given position
394    pub fn get_strongest_boundary_at<'a>(
395        &self,
396        boundaries: &'a [Boundary],
397        position: usize,
398        tolerance: usize,
399    ) -> Option<&'a Boundary> {
400        boundaries
401            .iter()
402            .filter(|b| {
403                let dist = if b.position > position {
404                    b.position - position
405                } else {
406                    position - b.position
407                };
408                dist <= tolerance
409            })
410            .max_by(|a, b| {
411                a.confidence
412                    .partial_cmp(&b.confidence)
413                    .unwrap_or(std::cmp::Ordering::Equal)
414            })
415    }
416}
417
418impl Default for BoundaryDetector {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Sentence endings are only recognised when followed by whitespace, so
    /// three sentences yield the two boundaries that lie *between* them.
    #[test]
    fn test_sentence_detection() {
        let detector = BoundaryDetector::new();
        let text = "This is a sentence. This is another! And a third?";

        let boundaries = detector.detect_sentence_boundaries(text);

        // The final '?' ends the text with no trailing whitespace and is
        // therefore not a boundary.
        assert_eq!(boundaries.len(), 2);
        assert!(boundaries
            .iter()
            .all(|b| b.boundary_type == BoundaryType::Sentence));

        // Each boundary sits at the first byte of the following sentence.
        assert_eq!(boundaries[0].position, 20);
        assert_eq!(boundaries[1].position, 37);
    }

    /// "Dr." must not be treated as the end of a sentence.
    #[test]
    fn test_abbreviation_handling() {
        let detector = BoundaryDetector::new();
        let text = "Dr. Smith went to the store. He bought milk.";

        let boundaries = detector.detect_sentence_boundaries(text);

        // Should detect only the second period, not "Dr."
        assert_eq!(boundaries.len(), 1);
    }

    /// Two blank-line breaks separate three paragraphs.
    #[test]
    fn test_paragraph_detection() {
        let detector = BoundaryDetector::new();
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";

        let boundaries = detector.detect_paragraph_boundaries(text);

        assert_eq!(boundaries.len(), 2);
        assert_eq!(boundaries[0].boundary_type, BoundaryType::Paragraph);
    }

    /// Markdown headings of several levels are all reported.
    #[test]
    fn test_markdown_heading_detection() {
        let detector = BoundaryDetector::new();
        let text = "# Main Heading\n\n## Subheading\n\n### Sub-subheading";

        let boundaries = detector.detect_heading_boundaries(text);

        assert!(boundaries.len() >= 3);
        assert!(boundaries
            .iter()
            .all(|b| b.boundary_type == BoundaryType::Heading));
    }

    /// A list embedded in regular text produces a start and an end boundary.
    #[test]
    fn test_list_detection() {
        let detector = BoundaryDetector::new();
        let text = "Regular text\n- Item 1\n- Item 2\n* Item 3\nMore text";

        let boundaries = detector.detect_list_boundaries(text);

        assert!(boundaries.len() >= 2); // Start and end
        assert_eq!(boundaries[0].boundary_type, BoundaryType::List);
    }

    /// A fenced code block produces one boundary per fence.
    #[test]
    fn test_code_block_detection() {
        let detector = BoundaryDetector::new();
        let text = "Some text\n```python\ncode here\n```\nMore text";

        let boundaries = detector.detect_code_block_boundaries(text);

        assert_eq!(boundaries.len(), 2); // Start and end
        assert_eq!(boundaries[0].boundary_type, BoundaryType::CodeBlock);
    }

    /// All detectors together.
    ///
    /// In this input both list boundaries share their byte offset with a
    /// sentence and/or paragraph boundary, so the assertions below only hold
    /// when boundaries of different types at the same offset are all retained.
    #[test]
    fn test_combined_detection() {
        let detector = BoundaryDetector::new();
        let text = "# Heading\n\nFirst paragraph. Second sentence.\n\n- List item 1\n- List item 2\n\nLast paragraph.";

        let boundaries = detector.detect_boundaries(text);

        // Should detect headings, paragraphs, sentences, and lists
        assert!(boundaries.len() > 5);

        let types: HashSet<_> = boundaries.iter().map(|b| b.boundary_type).collect();
        assert!(types.contains(&BoundaryType::Heading));
        assert!(types.contains(&BoundaryType::Paragraph));
        assert!(types.contains(&BoundaryType::List));
    }

    /// Among boundaries within tolerance, the most confident one wins.
    #[test]
    fn test_get_strongest_boundary() {
        let detector = BoundaryDetector::new();
        let boundaries = vec![
            Boundary {
                position: 100,
                boundary_type: BoundaryType::Sentence,
                confidence: 0.7,
                context: None,
            },
            Boundary {
                position: 105,
                boundary_type: BoundaryType::Paragraph,
                confidence: 0.95,
                context: None,
            },
        ];

        let strongest = detector.get_strongest_boundary_at(&boundaries, 102, 10);
        assert!(strongest.is_some());
        assert_eq!(strongest.unwrap().boundary_type, BoundaryType::Paragraph);
        assert_eq!(strongest.unwrap().confidence, 0.95);
    }
}