use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Settings that control which kinds of boundaries a `BoundaryDetector`
/// looks for and how the individual heuristics are tuned.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BoundaryDetectionConfig {
/// Run the sentence detector when scanning text.
pub detect_sentences: bool,
/// Run the paragraph (blank-line gap) detector when scanning text.
pub detect_paragraphs: bool,
/// Run the heading detector when scanning text.
pub detect_headings: bool,
/// Run the list start/end detector when scanning text.
pub detect_lists: bool,
/// Run the code-block detector when scanning text.
pub detect_code_blocks: bool,
/// Minimum length, in bytes, a sentence must have (measured from the
/// previously accepted sentence boundary, or the start of the text) for
/// its ending to be reported.
pub min_sentence_length: usize,
/// Line prefixes that mark a heading: a line whose trimmed text starts
/// with one of these strings is reported as a heading.
pub heading_markers: Vec<String>,
}
impl Default for BoundaryDetectionConfig {
    /// Builds the default configuration: every detector enabled, a minimum
    /// sentence length of 10, and the heading markers "Chapter", "Section",
    /// "Introduction" and "Conclusion".
    fn default() -> Self {
        // Built first so the struct literal below stays a flat list of settings.
        let heading_markers: Vec<String> = ["Chapter", "Section", "Introduction", "Conclusion"]
            .iter()
            .copied()
            .map(String::from)
            .collect();

        Self {
            detect_sentences: true,
            detect_paragraphs: true,
            detect_headings: true,
            detect_lists: true,
            detect_code_blocks: true,
            min_sentence_length: 10,
            heading_markers,
        }
    }
}
/// The kind of structural boundary a `Boundary` represents.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum BoundaryType {
/// End of a sentence (terminal punctuation followed by whitespace).
Sentence,
/// Gap between paragraphs (a blank-line separator).
Paragraph,
/// A heading line.
Heading,
/// The start or the end of a list.
List,
/// A code-block fence line or an indented code line.
CodeBlock,
}
/// A single detected boundary inside a piece of text.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Boundary {
/// Byte offset into the analysed text at which the boundary falls.
pub position: usize,
/// Which detector category produced this boundary.
pub boundary_type: BoundaryType,
/// How sure the detector is about this boundary; the detectors in this
/// file assign values between 0.7 and 1.0.
pub confidence: f32,
/// Optional detail supplied by the detector, such as the heading text or
/// a tag like "list_start" / "code_end".
pub context: Option<String>,
}
/// Finds structural boundaries (sentences, paragraphs, headings, lists and
/// code blocks) in text, using a configuration plus a set of regular
/// expressions that are compiled once when the detector is constructed.
pub struct BoundaryDetector {
// Which detectors run and how they are tuned.
config: BoundaryDetectionConfig,
// Matches a run of `.`, `!` or `?` followed by whitespace.
sentence_endings: Regex,
// Matches a line of one to six `#` characters, whitespace, then text.
markdown_heading: Regex,
// Matches a list item prefix made of digits, `.` or `)`, then whitespace.
numbered_list: Regex,
// Matches a list item prefix made of `-`, `*` or `+`, then whitespace.
bullet_list: Regex,
// Matches a line beginning with three backticks.
code_block_fence: Regex,
// Matches a line consisting only of `=`, `-`, `~`, `^` or `"` characters
// (optionally followed by whitespace), as used to underline a heading.
rst_heading_underline: Regex,
}
impl BoundaryDetector {
    /// Abbreviations (stored without their trailing period) whose period must
    /// not be treated as the end of a sentence.
    const ABBREVIATIONS: &'static [&'static str] = &[
        "Dr", "Mr", "Mrs", "Ms", "Prof", "Sr", "Jr", "etc", "e.g", "i.e", "vs", "cf", "Jan",
        "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec",
    ];

    /// Creates a detector that uses the default configuration.
    pub fn new() -> Self {
        Self::with_config(BoundaryDetectionConfig::default())
    }

    /// Creates a detector that uses `config`, compiling every pattern once.
    ///
    /// # Panics
    ///
    /// Only if one of the built-in pattern literals fails to compile, which
    /// would be a programming error in this file.
    pub fn with_config(config: BoundaryDetectionConfig) -> Self {
        Self {
            config,
            sentence_endings: Regex::new(r"[.!?]+[\s]+").expect("static regex literal"),
            markdown_heading: Regex::new(r"^#{1,6}\s+.+$").expect("static regex literal"),
            numbered_list: Regex::new(r"^\d+[.)]\s+").expect("static regex literal"),
            bullet_list: Regex::new(r"^[\-\*\+]\s+").expect("static regex literal"),
            code_block_fence: Regex::new(r"^```").expect("static regex literal"),
            rst_heading_underline: Regex::new("^[=\\-~^\"]+\\s*$").expect("static regex literal"),
        }
    }

    /// Runs every detector enabled in the configuration over `text`.
    ///
    /// The result is sorted by byte position and holds at most one boundary
    /// per position. When several detectors report the same position, the
    /// one with the highest confidence is kept; on equal confidence the
    /// detector that ran first wins (sentence, paragraph, heading, list,
    /// code block).
    pub fn detect_boundaries(&self, text: &str) -> Vec<Boundary> {
        let mut boundaries = Vec::new();
        if self.config.detect_sentences {
            boundaries.extend(self.detect_sentence_boundaries(text));
        }
        if self.config.detect_paragraphs {
            boundaries.extend(self.detect_paragraph_boundaries(text));
        }
        if self.config.detect_headings {
            boundaries.extend(self.detect_heading_boundaries(text));
        }
        if self.config.detect_lists {
            boundaries.extend(self.detect_list_boundaries(text));
        }
        if self.config.detect_code_blocks {
            boundaries.extend(self.detect_code_block_boundaries(text));
        }

        // Order by position, strongest first within a position, so that the
        // dedup below (which keeps the first of each run) retains the most
        // confident boundary instead of whichever detector happened to run
        // first. The sort is stable, so ties keep detector order.
        boundaries.sort_by(|a, b| {
            a.position.cmp(&b.position).then_with(|| {
                b.confidence
                    .partial_cmp(&a.confidence)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
        });
        boundaries.dedup_by_key(|b| b.position);
        boundaries
    }

    /// Yields `(byte_offset, line)` for every line of `text`.
    ///
    /// Lines are split on `\n`; the terminator (`\n` or `\r\n`) is removed
    /// from the yielded slice but still counted in the offsets, so positions
    /// stay exact for both Unix and Windows line endings.
    fn lines_with_offsets(text: &str) -> impl Iterator<Item = (usize, &str)> {
        let mut offset = 0;
        text.split_inclusive('\n').map(move |raw| {
            let start = offset;
            offset += raw.len();
            let line = match raw.strip_suffix('\n') {
                Some(without_lf) => without_lf.strip_suffix('\r').unwrap_or(without_lf),
                None => raw,
            };
            (start, line)
        })
    }

    /// Returns `true` when `preceding` ends with a known abbreviation that
    /// stands as its own word (at the start of the text or after a
    /// non-alphanumeric character), so that e.g. "revs" is not mistaken
    /// for "vs".
    fn ends_with_abbreviation(preceding: &str) -> bool {
        Self::ABBREVIATIONS.iter().any(|&abbr| {
            preceding.strip_suffix(abbr).map_or(false, |head| {
                !head
                    .chars()
                    .next_back()
                    .map_or(false, |c| c.is_alphanumeric())
            })
        })
    }

    /// Reports a boundary just after each sentence terminator (`.`, `!`, `?`
    /// run plus the whitespace that follows it).
    ///
    /// Terminators that directly follow a known abbreviation are skipped, as
    /// are sentences shorter than `min_sentence_length` bytes (measured from
    /// the previously accepted boundary, or from the start of the text).
    fn detect_sentence_boundaries(&self, text: &str) -> Vec<Boundary> {
        let mut boundaries = Vec::new();
        let mut sentence_start = 0;
        for mat in self.sentence_endings.find_iter(text) {
            let punctuation_start = mat.start();
            if Self::ends_with_abbreviation(&text[..punctuation_start]) {
                continue;
            }
            // Matches never overlap, so `punctuation_start >= sentence_start`.
            if punctuation_start - sentence_start < self.config.min_sentence_length {
                continue;
            }
            boundaries.push(Boundary {
                position: mat.end(),
                boundary_type: BoundaryType::Sentence,
                confidence: 0.9,
                context: None,
            });
            sentence_start = mat.end();
        }
        boundaries
    }

    /// Reports a boundary at the first byte after every blank-line gap
    /// (a newline, optional whitespace, and another newline).
    fn detect_paragraph_boundaries(&self, text: &str) -> Vec<Boundary> {
        // NOTE(review): this pattern is compiled on every call; caching it
        // next to the other patterns would need a new struct field.
        let paragraph_regex = Regex::new(r"\n\s*\n").expect("static regex literal");
        paragraph_regex
            .find_iter(text)
            .map(|mat| Boundary {
                position: mat.end(),
                boundary_type: BoundaryType::Paragraph,
                confidence: 1.0,
                context: None,
            })
            .collect()
    }

    /// Reports heading boundaries found by four independent heuristics:
    ///
    /// * underlined headings (a line of `=`, `-`, `~`, `^` or `"` at least as
    ///   long as the title above it) — confidence 0.9, positioned at the
    ///   start of the title line;
    /// * Markdown `#` headings — confidence 0.95;
    /// * lines longer than three bytes made only of uppercase letters,
    ///   digits and whitespace — confidence 0.7;
    /// * lines starting with one of the configured heading markers —
    ///   confidence 0.85.
    ///
    /// A line may be reported by more than one heuristic.
    fn detect_heading_boundaries(&self, text: &str) -> Vec<Boundary> {
        let mut boundaries = Vec::new();
        // Start offset and trimmed text of the line before the current one.
        let mut previous: Option<(usize, &str)> = None;

        for (line_start, line) in Self::lines_with_offsets(text) {
            let trimmed = line.trim();

            // Checked first so the output stays ordered by position: the
            // boundary belongs to the title line, not to the underline.
            if let Some((title_start, title)) = previous {
                // Compare character counts, not bytes, so titles containing
                // multi-byte characters are still recognised.
                if !title.is_empty()
                    && self.rst_heading_underline.is_match(trimmed)
                    && trimmed.chars().count() >= title.chars().count()
                {
                    boundaries.push(Boundary {
                        position: title_start,
                        boundary_type: BoundaryType::Heading,
                        confidence: 0.9,
                        context: Some(title.to_string()),
                    });
                }
            }

            // Matched against the untrimmed line: the `#` must be in column 0.
            if self.markdown_heading.is_match(line) {
                let heading_text = trimmed.trim_start_matches('#').trim();
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::Heading,
                    confidence: 0.95,
                    context: Some(heading_text.to_string()),
                });
            }

            let is_all_caps = trimmed.len() > 3
                && trimmed
                    .chars()
                    .all(|c| c.is_uppercase() || c.is_whitespace() || c.is_numeric())
                && trimmed.chars().any(|c| c.is_alphabetic());
            if is_all_caps {
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::Heading,
                    confidence: 0.7,
                    context: Some(trimmed.to_string()),
                });
            }

            // An empty marker would match every line, so it is ignored.
            let starts_with_marker = self
                .config
                .heading_markers
                .iter()
                .any(|marker| !marker.is_empty() && trimmed.starts_with(marker.as_str()));
            if starts_with_marker {
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::Heading,
                    confidence: 0.85,
                    context: Some(trimmed.to_string()),
                });
            }

            previous = Some((line_start, trimmed));
        }
        boundaries
    }

    /// Reports where lists begin and end.
    ///
    /// A "list_start" boundary is placed at the first numbered or bulleted
    /// item after non-list text; a "list_end" boundary is placed at the first
    /// non-empty, non-item line that follows. Blank lines do not end a list,
    /// and a list that runs to the end of the text produces no "list_end".
    fn detect_list_boundaries(&self, text: &str) -> Vec<Boundary> {
        let mut boundaries = Vec::new();
        let mut in_list = false;

        for (line_start, line) in Self::lines_with_offsets(text) {
            let trimmed = line.trim();
            let is_list_item =
                self.numbered_list.is_match(trimmed) || self.bullet_list.is_match(trimmed);

            if is_list_item && !in_list {
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::List,
                    confidence: 0.9,
                    context: Some("list_start".to_string()),
                });
                in_list = true;
            } else if !is_list_item && in_list && !trimmed.is_empty() {
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::List,
                    confidence: 0.9,
                    context: Some("list_end".to_string()),
                });
                in_list = false;
            }
        }
        boundaries
    }

    /// Reports code-block boundaries.
    ///
    /// Every fence line (three backticks after trimming) yields a boundary
    /// tagged alternately "code_start" and "code_end" with confidence 1.0.
    /// Outside fenced blocks, each non-blank line indented by at least four
    /// spaces or a tab yields an "indented_code" boundary with confidence 0.7.
    fn detect_code_block_boundaries(&self, text: &str) -> Vec<Boundary> {
        let mut boundaries = Vec::new();
        let mut in_code_block = false;

        for (line_start, line) in Self::lines_with_offsets(text) {
            let trimmed = line.trim();

            if self.code_block_fence.is_match(trimmed) {
                let tag = if in_code_block { "code_end" } else { "code_start" };
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::CodeBlock,
                    confidence: 1.0,
                    context: Some(tag.to_string()),
                });
                in_code_block = !in_code_block;
            } else if !in_code_block
                && !trimmed.is_empty()
                && (line.starts_with("    ") || line.starts_with('\t'))
            {
                // A fence line is never also counted as indented code, and a
                // single leading space is not enough to count as code.
                boundaries.push(Boundary {
                    position: line_start,
                    boundary_type: BoundaryType::CodeBlock,
                    confidence: 0.7,
                    context: Some("indented_code".to_string()),
                });
            }
        }
        boundaries
    }

    /// Returns the positions of all `boundaries` whose type equals
    /// `boundary_type`, in the order they appear in the slice.
    pub fn get_boundaries_by_type(
        &self,
        boundaries: &[Boundary],
        boundary_type: BoundaryType,
    ) -> Vec<usize> {
        boundaries
            .iter()
            .filter(|b| b.boundary_type == boundary_type)
            .map(|b| b.position)
            .collect()
    }

    /// Returns the highest-confidence boundary lying within `tolerance` bytes
    /// of `position`, or `None` when no boundary is that close.
    ///
    /// When several candidates share the highest confidence, the one that
    /// appears last in `boundaries` is returned.
    pub fn get_strongest_boundary_at<'a>(
        &self,
        boundaries: &'a [Boundary],
        position: usize,
        tolerance: usize,
    ) -> Option<&'a Boundary> {
        boundaries
            .iter()
            .filter(|b| b.position.abs_diff(position) <= tolerance)
            .max_by(|a, b| {
                a.confidence
                    .partial_cmp(&b.confidence)
                    .unwrap_or(std::cmp::Ordering::Equal)
            })
    }
}
impl Default for BoundaryDetector {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The period after "Dr" must not end a sentence; only the one after
    /// "store" is followed by whitespace and long enough to count.
    #[test]
    fn test_abbreviation_handling() {
        let input = "Dr. Smith went to the store. He bought milk.";
        let found = BoundaryDetector::new().detect_sentence_boundaries(input);
        assert_eq!(found.len(), 1);
    }

    /// Two blank-line gaps separate three paragraphs.
    #[test]
    fn test_paragraph_detection() {
        let input = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";
        let found = BoundaryDetector::new().detect_paragraph_boundaries(input);
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].boundary_type, BoundaryType::Paragraph);
    }

    /// Each of the three `#` lines is reported, and only as headings.
    #[test]
    fn test_markdown_heading_detection() {
        let input = "# Main Heading\n\n## Subheading\n\n### Sub-subheading";
        let found = BoundaryDetector::new().detect_heading_boundaries(input);
        assert!(found.len() >= 3);
        assert!(found
            .iter()
            .all(|boundary| boundary.boundary_type == BoundaryType::Heading));
    }

    /// A bullet list surrounded by plain text yields a start and an end.
    #[test]
    fn test_list_detection() {
        let input = "Regular text\n- Item 1\n- Item 2\n* Item 3\nMore text";
        let found = BoundaryDetector::new().detect_list_boundaries(input);
        assert!(found.len() >= 2);
        assert_eq!(found[0].boundary_type, BoundaryType::List);
    }

    /// The opening and closing fences are the only code-block boundaries.
    #[test]
    fn test_code_block_detection() {
        let input = "Some text\n```python\ncode here\n```\nMore text";
        let found = BoundaryDetector::new().detect_code_block_boundaries(input);
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].boundary_type, BoundaryType::CodeBlock);
    }

    /// With both candidates inside the tolerance window, the one with the
    /// higher confidence wins.
    #[test]
    fn test_get_strongest_boundary() {
        let detector = BoundaryDetector::new();
        let weaker = Boundary {
            position: 100,
            boundary_type: BoundaryType::Sentence,
            confidence: 0.7,
            context: None,
        };
        let stronger = Boundary {
            position: 105,
            boundary_type: BoundaryType::Paragraph,
            confidence: 0.95,
            context: None,
        };
        let candidates = vec![weaker, stronger];

        let best = detector.get_strongest_boundary_at(&candidates, 102, 10);
        assert!(best.is_some());
        let best = best.unwrap();
        assert_eq!(best.boundary_type, BoundaryType::Paragraph);
        assert_eq!(best.confidence, 0.95);
    }
}