1use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct BoundaryDetectionConfig {
25 pub detect_sentences: bool,
27
28 pub detect_paragraphs: bool,
30
31 pub detect_headings: bool,
33
34 pub detect_lists: bool,
36
37 pub detect_code_blocks: bool,
39
40 pub min_sentence_length: usize,
42
43 pub heading_markers: Vec<String>,
45}
46
47impl Default for BoundaryDetectionConfig {
48 fn default() -> Self {
49 Self {
50 detect_sentences: true,
51 detect_paragraphs: true,
52 detect_headings: true,
53 detect_lists: true,
54 detect_code_blocks: true,
55 min_sentence_length: 10,
56 heading_markers: vec![
57 "Chapter".to_string(),
58 "Section".to_string(),
59 "Introduction".to_string(),
60 "Conclusion".to_string(),
61 ],
62 }
63 }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
68pub enum BoundaryType {
69 Sentence,
71 Paragraph,
73 Heading,
75 List,
77 CodeBlock,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Boundary {
84 pub position: usize,
86
87 pub boundary_type: BoundaryType,
89
90 pub confidence: f32,
92
93 pub context: Option<String>,
95}
96
97pub struct BoundaryDetector {
99 config: BoundaryDetectionConfig,
100
101 sentence_endings: Regex,
103 markdown_heading: Regex,
104 numbered_list: Regex,
105 bullet_list: Regex,
106 code_block_fence: Regex,
107 rst_heading_underline: Regex,
108}
109
110impl BoundaryDetector {
111 pub fn new() -> Self {
113 Self::with_config(BoundaryDetectionConfig::default())
114 }
115
116 pub fn with_config(config: BoundaryDetectionConfig) -> Self {
118 Self {
119 config,
120 sentence_endings: Regex::new(r"[.!?]+[\s]+").expect("static regex literal"),
122 markdown_heading: Regex::new(r"^#{1,6}\s+.+$").expect("static regex literal"),
123 numbered_list: Regex::new(r"^\d+[.)]\s+").expect("static regex literal"),
124 bullet_list: Regex::new(r"^[\-\*\+]\s+").expect("static regex literal"),
125 code_block_fence: Regex::new(r"^```").expect("static regex literal"),
126 rst_heading_underline: Regex::new("^[=\\-~^\"]+\\s*$").expect("static regex literal"),
127 }
128 }
129
130 pub fn detect_boundaries(&self, text: &str) -> Vec<Boundary> {
132 let mut boundaries = Vec::new();
133
134 if self.config.detect_sentences {
135 boundaries.extend(self.detect_sentence_boundaries(text));
136 }
137
138 if self.config.detect_paragraphs {
139 boundaries.extend(self.detect_paragraph_boundaries(text));
140 }
141
142 if self.config.detect_headings {
143 boundaries.extend(self.detect_heading_boundaries(text));
144 }
145
146 if self.config.detect_lists {
147 boundaries.extend(self.detect_list_boundaries(text));
148 }
149
150 if self.config.detect_code_blocks {
151 boundaries.extend(self.detect_code_block_boundaries(text));
152 }
153
154 boundaries.sort_by_key(|b| b.position);
156 boundaries.dedup_by_key(|b| b.position);
157
158 boundaries
159 }
160
161 fn detect_sentence_boundaries(&self, text: &str) -> Vec<Boundary> {
163 let mut boundaries = Vec::new();
164
165 let abbreviations: HashSet<&str> = [
167 "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "etc.", "e.g.", "i.e.", "vs.",
168 "cf.", "Jan.", "Feb.", "Mar.", "Apr.", "Jun.", "Jul.", "Aug.", "Sep.", "Oct.", "Nov.",
169 "Dec.",
170 ]
171 .iter()
172 .copied()
173 .collect();
174
175 for mat in self.sentence_endings.find_iter(text) {
177 let position = mat.start();
178
179 let before_text = &text[..position];
181 let is_abbreviation = abbreviations
182 .iter()
183 .any(|abbr| before_text.ends_with(&abbr[..abbr.len() - 1]));
184
185 if !is_abbreviation {
186 let sentence_start = boundaries
188 .last()
189 .map(|b: &Boundary| b.position)
190 .unwrap_or(0);
191 let sentence_length = position - sentence_start;
192
193 if sentence_length >= self.config.min_sentence_length {
194 boundaries.push(Boundary {
195 position: mat.end(),
196 boundary_type: BoundaryType::Sentence,
197 confidence: 0.9,
198 context: None,
199 });
200 }
201 }
202 }
203
204 boundaries
205 }
206
207 fn detect_paragraph_boundaries(&self, text: &str) -> Vec<Boundary> {
209 let mut boundaries = Vec::new();
210
211 let paragraph_regex = Regex::new(r"\n\s*\n").expect("static regex literal");
213
214 for mat in paragraph_regex.find_iter(text) {
215 boundaries.push(Boundary {
216 position: mat.end(),
217 boundary_type: BoundaryType::Paragraph,
218 confidence: 1.0,
219 context: None,
220 });
221 }
222
223 boundaries
224 }
225
226 fn detect_heading_boundaries(&self, text: &str) -> Vec<Boundary> {
228 let mut boundaries = Vec::new();
229
230 let lines: Vec<&str> = text.lines().collect();
231 let mut current_pos = 0;
232
233 for (i, line) in lines.iter().enumerate() {
234 let line_start = current_pos;
235 let line_trimmed = line.trim();
236
237 if self.markdown_heading.is_match(line) {
239 let heading_text = line_trimmed.trim_start_matches('#').trim();
240 boundaries.push(Boundary {
241 position: line_start,
242 boundary_type: BoundaryType::Heading,
243 confidence: 0.95,
244 context: Some(heading_text.to_string()),
245 });
246 }
247
248 if i > 0 && self.rst_heading_underline.is_match(line_trimmed) {
250 let prev_line = lines[i - 1].trim();
251 if !prev_line.is_empty() && line_trimmed.len() >= prev_line.len() {
252 boundaries.push(Boundary {
253 position: line_start,
254 boundary_type: BoundaryType::Heading,
255 confidence: 0.9,
256 context: Some(prev_line.to_string()),
257 });
258 }
259 }
260
261 if line_trimmed.len() > 3
263 && line_trimmed
264 .chars()
265 .all(|c| c.is_uppercase() || c.is_whitespace() || c.is_numeric())
266 && line_trimmed.chars().any(|c| c.is_alphabetic())
267 {
268 boundaries.push(Boundary {
269 position: line_start,
270 boundary_type: BoundaryType::Heading,
271 confidence: 0.7,
272 context: Some(line_trimmed.to_string()),
273 });
274 }
275
276 for marker in &self.config.heading_markers {
278 if line_trimmed.starts_with(marker) {
279 boundaries.push(Boundary {
280 position: line_start,
281 boundary_type: BoundaryType::Heading,
282 confidence: 0.85,
283 context: Some(line_trimmed.to_string()),
284 });
285 break;
286 }
287 }
288
289 current_pos += line.len() + 1; }
291
292 boundaries
293 }
294
295 fn detect_list_boundaries(&self, text: &str) -> Vec<Boundary> {
297 let mut boundaries = Vec::new();
298
299 let lines: Vec<&str> = text.lines().collect();
300 let mut current_pos = 0;
301 let mut in_list = false;
302
303 for line in lines {
304 let line_trimmed = line.trim();
305
306 let is_list_item = self.numbered_list.is_match(line_trimmed)
308 || self.bullet_list.is_match(line_trimmed);
309
310 if is_list_item && !in_list {
312 boundaries.push(Boundary {
313 position: current_pos,
314 boundary_type: BoundaryType::List,
315 confidence: 0.9,
316 context: Some("list_start".to_string()),
317 });
318 in_list = true;
319 }
320
321 if !is_list_item && in_list && !line_trimmed.is_empty() {
323 boundaries.push(Boundary {
324 position: current_pos,
325 boundary_type: BoundaryType::List,
326 confidence: 0.9,
327 context: Some("list_end".to_string()),
328 });
329 in_list = false;
330 }
331
332 current_pos += line.len() + 1;
333 }
334
335 boundaries
336 }
337
338 fn detect_code_block_boundaries(&self, text: &str) -> Vec<Boundary> {
340 let mut boundaries = Vec::new();
341
342 let lines: Vec<&str> = text.lines().collect();
343 let mut current_pos = 0;
344 let mut in_code_block = false;
345
346 for line in lines {
347 let line_trimmed = line.trim();
348
349 if self.code_block_fence.is_match(line_trimmed) {
351 boundaries.push(Boundary {
352 position: current_pos,
353 boundary_type: BoundaryType::CodeBlock,
354 confidence: 1.0,
355 context: if in_code_block {
356 Some("code_end".to_string())
357 } else {
358 Some("code_start".to_string())
359 },
360 });
361 in_code_block = !in_code_block;
362 }
363
364 if !in_code_block && line.starts_with(" ") && !line_trimmed.is_empty() {
366 boundaries.push(Boundary {
367 position: current_pos,
368 boundary_type: BoundaryType::CodeBlock,
369 confidence: 0.7,
370 context: Some("indented_code".to_string()),
371 });
372 }
373
374 current_pos += line.len() + 1;
375 }
376
377 boundaries
378 }
379
380 pub fn get_boundaries_by_type(
382 &self,
383 boundaries: &[Boundary],
384 boundary_type: BoundaryType,
385 ) -> Vec<usize> {
386 boundaries
387 .iter()
388 .filter(|b| b.boundary_type == boundary_type)
389 .map(|b| b.position)
390 .collect()
391 }
392
393 pub fn get_strongest_boundary_at<'a>(
395 &self,
396 boundaries: &'a [Boundary],
397 position: usize,
398 tolerance: usize,
399 ) -> Option<&'a Boundary> {
400 boundaries
401 .iter()
402 .filter(|b| {
403 let dist = b.position.abs_diff(position);
404 dist <= tolerance
405 })
406 .max_by(|a, b| {
407 a.confidence
408 .partial_cmp(&b.confidence)
409 .unwrap_or(std::cmp::Ordering::Equal)
410 })
411 }
412}
413
414impl Default for BoundaryDetector {
415 fn default() -> Self {
416 Self::new()
417 }
418}
419
#[cfg(test)]
mod tests {
    use super::*;

    /// Detector with the default configuration, shared by every test.
    fn detector() -> BoundaryDetector {
        BoundaryDetector::new()
    }

    /// "Dr." must not end a sentence; only the period after "store" counts.
    #[test]
    fn test_abbreviation_handling() {
        let found = detector()
            .detect_sentence_boundaries("Dr. Smith went to the store. He bought milk.");

        assert_eq!(found.len(), 1);
    }

    /// Two blank-line breaks separate three paragraphs.
    #[test]
    fn test_paragraph_detection() {
        let found = detector().detect_paragraph_boundaries(
            "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.",
        );

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].boundary_type, BoundaryType::Paragraph);
    }

    /// Each `#`-prefixed line yields at least one heading boundary.
    #[test]
    fn test_markdown_heading_detection() {
        let found = detector()
            .detect_heading_boundaries("# Main Heading\n\n## Subheading\n\n### Sub-subheading");

        assert!(found.len() >= 3);
        assert!(found
            .iter()
            .all(|boundary| boundary.boundary_type == BoundaryType::Heading));
    }

    /// A bullet run between prose lines produces a start and an end boundary.
    #[test]
    fn test_list_detection() {
        let found = detector()
            .detect_list_boundaries("Regular text\n- Item 1\n- Item 2\n* Item 3\nMore text");

        assert!(found.len() >= 2);
        assert_eq!(found[0].boundary_type, BoundaryType::List);
    }

    /// An opening and a closing fence produce exactly two boundaries.
    #[test]
    fn test_code_block_detection() {
        let found = detector()
            .detect_code_block_boundaries("Some text\n```python\ncode here\n```\nMore text");

        assert_eq!(found.len(), 2);
        assert_eq!(found[0].boundary_type, BoundaryType::CodeBlock);
    }

    /// Of two boundaries within tolerance, the more confident one is returned.
    #[test]
    fn test_get_strongest_boundary() {
        let weaker = Boundary {
            position: 100,
            boundary_type: BoundaryType::Sentence,
            confidence: 0.7,
            context: None,
        };
        let stronger = Boundary {
            position: 105,
            boundary_type: BoundaryType::Paragraph,
            confidence: 0.95,
            context: None,
        };
        let boundaries = vec![weaker, stronger];

        let strongest = detector()
            .get_strongest_boundary_at(&boundaries, 102, 10)
            .expect("both boundaries lie within the tolerance");

        assert_eq!(strongest.boundary_type, BoundaryType::Paragraph);
        assert_eq!(strongest.confidence, 0.95);
    }
}