1use regex::Regex;
19use serde::{Deserialize, Serialize};
20use std::collections::HashSet;
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct BoundaryDetectionConfig {
25 pub detect_sentences: bool,
27
28 pub detect_paragraphs: bool,
30
31 pub detect_headings: bool,
33
34 pub detect_lists: bool,
36
37 pub detect_code_blocks: bool,
39
40 pub min_sentence_length: usize,
42
43 pub heading_markers: Vec<String>,
45}
46
47impl Default for BoundaryDetectionConfig {
48 fn default() -> Self {
49 Self {
50 detect_sentences: true,
51 detect_paragraphs: true,
52 detect_headings: true,
53 detect_lists: true,
54 detect_code_blocks: true,
55 min_sentence_length: 10,
56 heading_markers: vec![
57 "Chapter".to_string(),
58 "Section".to_string(),
59 "Introduction".to_string(),
60 "Conclusion".to_string(),
61 ],
62 }
63 }
64}
65
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
68pub enum BoundaryType {
69 Sentence,
71 Paragraph,
73 Heading,
75 List,
77 CodeBlock,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct Boundary {
84 pub position: usize,
86
87 pub boundary_type: BoundaryType,
89
90 pub confidence: f32,
92
93 pub context: Option<String>,
95}
96
97pub struct BoundaryDetector {
99 config: BoundaryDetectionConfig,
100
101 sentence_endings: Regex,
103 markdown_heading: Regex,
104 numbered_list: Regex,
105 bullet_list: Regex,
106 code_block_fence: Regex,
107 rst_heading_underline: Regex,
108}
109
110impl BoundaryDetector {
111 pub fn new() -> Self {
113 Self::with_config(BoundaryDetectionConfig::default())
114 }
115
116 pub fn with_config(config: BoundaryDetectionConfig) -> Self {
118 Self {
119 config,
120 sentence_endings: Regex::new(r"[.!?]+[\s]+").unwrap(),
122 markdown_heading: Regex::new(r"^#{1,6}\s+.+$").unwrap(),
123 numbered_list: Regex::new(r"^\d+[.)]\s+").unwrap(),
124 bullet_list: Regex::new(r"^[\-\*\+]\s+").unwrap(),
125 code_block_fence: Regex::new(r"^```").unwrap(),
126 rst_heading_underline: Regex::new("^[=\\-~^\"]+\\s*$").unwrap(),
127 }
128 }
129
130 pub fn detect_boundaries(&self, text: &str) -> Vec<Boundary> {
132 let mut boundaries = Vec::new();
133
134 if self.config.detect_sentences {
135 boundaries.extend(self.detect_sentence_boundaries(text));
136 }
137
138 if self.config.detect_paragraphs {
139 boundaries.extend(self.detect_paragraph_boundaries(text));
140 }
141
142 if self.config.detect_headings {
143 boundaries.extend(self.detect_heading_boundaries(text));
144 }
145
146 if self.config.detect_lists {
147 boundaries.extend(self.detect_list_boundaries(text));
148 }
149
150 if self.config.detect_code_blocks {
151 boundaries.extend(self.detect_code_block_boundaries(text));
152 }
153
154 boundaries.sort_by_key(|b| b.position);
156 boundaries.dedup_by_key(|b| b.position);
157
158 boundaries
159 }
160
161 fn detect_sentence_boundaries(&self, text: &str) -> Vec<Boundary> {
163 let mut boundaries = Vec::new();
164
165 let abbreviations: HashSet<&str> = [
167 "Dr.", "Mr.", "Mrs.", "Ms.", "Prof.", "Sr.", "Jr.", "etc.", "e.g.", "i.e.", "vs.",
168 "cf.", "Jan.", "Feb.", "Mar.", "Apr.", "Jun.", "Jul.", "Aug.", "Sep.", "Oct.", "Nov.",
169 "Dec.",
170 ]
171 .iter()
172 .copied()
173 .collect();
174
175 for mat in self.sentence_endings.find_iter(text) {
177 let position = mat.start();
178
179 let before_text = &text[..position];
181 let is_abbreviation = abbreviations
182 .iter()
183 .any(|abbr| before_text.ends_with(&abbr[..abbr.len() - 1]));
184
185 if !is_abbreviation {
186 let sentence_start = boundaries
188 .last()
189 .map(|b: &Boundary| b.position)
190 .unwrap_or(0);
191 let sentence_length = position - sentence_start;
192
193 if sentence_length >= self.config.min_sentence_length {
194 boundaries.push(Boundary {
195 position: mat.end(),
196 boundary_type: BoundaryType::Sentence,
197 confidence: 0.9,
198 context: None,
199 });
200 }
201 }
202 }
203
204 boundaries
205 }
206
207 fn detect_paragraph_boundaries(&self, text: &str) -> Vec<Boundary> {
209 let mut boundaries = Vec::new();
210
211 let paragraph_regex = Regex::new(r"\n\s*\n").unwrap();
213
214 for mat in paragraph_regex.find_iter(text) {
215 boundaries.push(Boundary {
216 position: mat.end(),
217 boundary_type: BoundaryType::Paragraph,
218 confidence: 1.0,
219 context: None,
220 });
221 }
222
223 boundaries
224 }
225
226 fn detect_heading_boundaries(&self, text: &str) -> Vec<Boundary> {
228 let mut boundaries = Vec::new();
229
230 let lines: Vec<&str> = text.lines().collect();
231 let mut current_pos = 0;
232
233 for (i, line) in lines.iter().enumerate() {
234 let line_start = current_pos;
235 let line_trimmed = line.trim();
236
237 if self.markdown_heading.is_match(line) {
239 let heading_text = line_trimmed.trim_start_matches('#').trim();
240 boundaries.push(Boundary {
241 position: line_start,
242 boundary_type: BoundaryType::Heading,
243 confidence: 0.95,
244 context: Some(heading_text.to_string()),
245 });
246 }
247
248 if i > 0 && self.rst_heading_underline.is_match(line_trimmed) {
250 let prev_line = lines[i - 1].trim();
251 if !prev_line.is_empty() && line_trimmed.len() >= prev_line.len() {
252 boundaries.push(Boundary {
253 position: line_start,
254 boundary_type: BoundaryType::Heading,
255 confidence: 0.9,
256 context: Some(prev_line.to_string()),
257 });
258 }
259 }
260
261 if line_trimmed.len() > 3
263 && line_trimmed
264 .chars()
265 .all(|c| c.is_uppercase() || c.is_whitespace() || c.is_numeric())
266 && line_trimmed.chars().any(|c| c.is_alphabetic())
267 {
268 boundaries.push(Boundary {
269 position: line_start,
270 boundary_type: BoundaryType::Heading,
271 confidence: 0.7,
272 context: Some(line_trimmed.to_string()),
273 });
274 }
275
276 for marker in &self.config.heading_markers {
278 if line_trimmed.starts_with(marker) {
279 boundaries.push(Boundary {
280 position: line_start,
281 boundary_type: BoundaryType::Heading,
282 confidence: 0.85,
283 context: Some(line_trimmed.to_string()),
284 });
285 break;
286 }
287 }
288
289 current_pos += line.len() + 1; }
291
292 boundaries
293 }
294
295 fn detect_list_boundaries(&self, text: &str) -> Vec<Boundary> {
297 let mut boundaries = Vec::new();
298
299 let lines: Vec<&str> = text.lines().collect();
300 let mut current_pos = 0;
301 let mut in_list = false;
302
303 for line in lines {
304 let line_trimmed = line.trim();
305
306 let is_list_item = self.numbered_list.is_match(line_trimmed)
308 || self.bullet_list.is_match(line_trimmed);
309
310 if is_list_item && !in_list {
312 boundaries.push(Boundary {
313 position: current_pos,
314 boundary_type: BoundaryType::List,
315 confidence: 0.9,
316 context: Some("list_start".to_string()),
317 });
318 in_list = true;
319 }
320
321 if !is_list_item && in_list && !line_trimmed.is_empty() {
323 boundaries.push(Boundary {
324 position: current_pos,
325 boundary_type: BoundaryType::List,
326 confidence: 0.9,
327 context: Some("list_end".to_string()),
328 });
329 in_list = false;
330 }
331
332 current_pos += line.len() + 1;
333 }
334
335 boundaries
336 }
337
338 fn detect_code_block_boundaries(&self, text: &str) -> Vec<Boundary> {
340 let mut boundaries = Vec::new();
341
342 let lines: Vec<&str> = text.lines().collect();
343 let mut current_pos = 0;
344 let mut in_code_block = false;
345
346 for line in lines {
347 let line_trimmed = line.trim();
348
349 if self.code_block_fence.is_match(line_trimmed) {
351 boundaries.push(Boundary {
352 position: current_pos,
353 boundary_type: BoundaryType::CodeBlock,
354 confidence: 1.0,
355 context: if in_code_block {
356 Some("code_end".to_string())
357 } else {
358 Some("code_start".to_string())
359 },
360 });
361 in_code_block = !in_code_block;
362 }
363
364 if !in_code_block && line.starts_with(" ") && !line_trimmed.is_empty() {
366 boundaries.push(Boundary {
367 position: current_pos,
368 boundary_type: BoundaryType::CodeBlock,
369 confidence: 0.7,
370 context: Some("indented_code".to_string()),
371 });
372 }
373
374 current_pos += line.len() + 1;
375 }
376
377 boundaries
378 }
379
380 pub fn get_boundaries_by_type(
382 &self,
383 boundaries: &[Boundary],
384 boundary_type: BoundaryType,
385 ) -> Vec<usize> {
386 boundaries
387 .iter()
388 .filter(|b| b.boundary_type == boundary_type)
389 .map(|b| b.position)
390 .collect()
391 }
392
393 pub fn get_strongest_boundary_at<'a>(
395 &self,
396 boundaries: &'a [Boundary],
397 position: usize,
398 tolerance: usize,
399 ) -> Option<&'a Boundary> {
400 boundaries
401 .iter()
402 .filter(|b| {
403 let dist = if b.position > position {
404 b.position - position
405 } else {
406 position - b.position
407 };
408 dist <= tolerance
409 })
410 .max_by(|a, b| {
411 a.confidence
412 .partial_cmp(&b.confidence)
413 .unwrap_or(std::cmp::Ordering::Equal)
414 })
415 }
416}
417
418impl Default for BoundaryDetector {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_sentence_detection() {
        let detector = BoundaryDetector::new();
        let text = "This is a sentence. This is another! And a third?";

        let boundaries = detector.detect_sentence_boundaries(text);

        // The sentence pattern needs whitespace after the punctuation, so the
        // final "?" at the very end of the text is not a boundary.
        assert_eq!(boundaries.len(), 2);
        assert_eq!(boundaries[0].boundary_type, BoundaryType::Sentence);
        assert_eq!(boundaries[0].position, 20);
        assert_eq!(boundaries[1].position, 37);
    }

    #[test]
    fn test_abbreviation_handling() {
        let detector = BoundaryDetector::new();
        let text = "Dr. Smith went to the store. He bought milk.";

        let boundaries = detector.detect_sentence_boundaries(text);

        // "Dr." is skipped as an abbreviation; only "store. " is a boundary.
        assert_eq!(boundaries.len(), 1);
    }

    #[test]
    fn test_paragraph_detection() {
        let detector = BoundaryDetector::new();
        let text = "First paragraph.\n\nSecond paragraph.\n\nThird paragraph.";

        let boundaries = detector.detect_paragraph_boundaries(text);

        assert_eq!(boundaries.len(), 2);
        assert_eq!(boundaries[0].boundary_type, BoundaryType::Paragraph);
    }

    #[test]
    fn test_markdown_heading_detection() {
        let detector = BoundaryDetector::new();
        let text = "# Main Heading\n\n## Subheading\n\n### Sub-subheading";

        let boundaries = detector.detect_heading_boundaries(text);

        assert!(boundaries.len() >= 3);
        assert!(boundaries
            .iter()
            .all(|b| b.boundary_type == BoundaryType::Heading));
    }

    #[test]
    fn test_list_detection() {
        let detector = BoundaryDetector::new();
        let text = "Regular text\n- Item 1\n- Item 2\n* Item 3\nMore text";

        let boundaries = detector.detect_list_boundaries(text);

        // One boundary where the list starts and one where it ends.
        assert!(boundaries.len() >= 2);
        assert_eq!(boundaries[0].boundary_type, BoundaryType::List);
    }

    #[test]
    fn test_code_block_detection() {
        let detector = BoundaryDetector::new();
        let text = "Some text\n```python\ncode here\n```\nMore text";

        let boundaries = detector.detect_code_block_boundaries(text);

        // Opening and closing fence.
        assert_eq!(boundaries.len(), 2);
        assert_eq!(boundaries[0].boundary_type, BoundaryType::CodeBlock);
    }

    #[test]
    fn test_combined_detection() {
        let detector = BoundaryDetector::new();
        let text = "# Heading\n\nFirst paragraph. Second sentence.\n\n- List item 1\n- List item 2\n\nLast paragraph.";

        let boundaries = detector.detect_boundaries(text);

        assert!(boundaries.len() > 5);

        let types: HashSet<_> = boundaries.iter().map(|b| b.boundary_type).collect();
        assert!(types.contains(&BoundaryType::Heading));
        assert!(types.contains(&BoundaryType::Paragraph));
        assert!(types.contains(&BoundaryType::List));
    }

    #[test]
    fn test_get_strongest_boundary() {
        let detector = BoundaryDetector::new();
        let boundaries = vec![
            Boundary {
                position: 100,
                boundary_type: BoundaryType::Sentence,
                confidence: 0.7,
                context: None,
            },
            Boundary {
                position: 105,
                boundary_type: BoundaryType::Paragraph,
                confidence: 0.95,
                context: None,
            },
        ];

        let strongest = detector.get_strongest_boundary_at(&boundaries, 102, 10);
        assert!(strongest.is_some());
        assert_eq!(strongest.unwrap().boundary_type, BoundaryType::Paragraph);
        assert_eq!(strongest.unwrap().confidence, 0.95);
    }
}