Skip to main content

scirs2_text/
segmentation.rs

1//! Text Chunking and Sentence Segmentation (`segmentation.rs`)
2//!
3//! Provides:
4//!
5//! - [`SentenceSegmenter`] -- boundary-aware sentence splitter with
6//!   abbreviation handling
7//! - [`TextChunker`] -- sliding-window chunker with configurable overlap,
8//!   optionally respecting sentence boundaries
9//! - [`TextChunk`] -- metadata-rich chunk descriptor
10
11use std::collections::HashSet;
12
13// ---------------------------------------------------------------------------
14// Built-in English abbreviations
15// ---------------------------------------------------------------------------
16
/// Build the default set of English abbreviations (stored without the
/// trailing period) that should not terminate a sentence.
fn builtin_abbreviations() -> HashSet<String> {
    const ABBREVIATIONS: &[&str] = &[
        // Titles
        "Mr", "Mrs", "Ms", "Miss", "Dr", "Prof", "Rev", "Gen", "Col", "Capt", "Lt", "Sgt", "Cpl",
        "Pte", "Sr", "Jr",
        // Geographic
        "St", "Ave", "Blvd", "Rd", "Ln", "Ct", "Pl", "Mt", "Ft",
        // Time / month
        "Jan", "Feb", "Mar", "Apr", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec", "Mon", "Tue",
        "Wed", "Thu", "Fri", "Sat", "Sun",
        // Miscellaneous
        "etc", "vs", "approx", "est", "dept", "corp", "co", "inc", "Fig", "fig", "Vol", "vol",
        "No", "Nos", "pp", "Ch", "Sec", "e.g", "i.e", "et", "al", "n.b", "N.B", "Esq",
    ];
    ABBREVIATIONS.iter().copied().map(String::from).collect()
}
32
33// ---------------------------------------------------------------------------
34// SentenceSegmenter
35// ---------------------------------------------------------------------------
36
/// Heuristic sentence-boundary detector.
///
/// Rules:
/// 1. A run of `.`, `!` and/or `?` followed by whitespace and then an
///    upper-case letter, a digit, or an opening quote/bracket (or followed by
///    the end of the text) is a sentence boundary.
/// 2. A period directly after a known abbreviation (matched as written or in
///    lower-case form) or after a single upper-case ASCII letter (an initial,
///    as in `J. Smith`) is NOT a boundary.
/// 3. An ellipsis (`...`) is NOT a boundary.
/// 4. A period between two ASCII digits (as in `3.14`) is NOT a boundary.
///
/// # Example
///
/// ```rust
/// use scirs2_text::segmentation::SentenceSegmenter;
///
/// let seg = SentenceSegmenter::new();
/// let sentences = seg.segment("Hello, Dr. Smith. How are you today?");
/// assert_eq!(sentences.len(), 2);
/// ```
#[derive(Debug, Clone)]
pub struct SentenceSegmenter {
    /// Known abbreviations, stored without their trailing period.
    abbreviations: HashSet<String>,
    /// Minimum byte length of a sentence. `segment_owned` appends shorter
    /// sentences to the preceding one; `segment` does not apply this limit.
    pub min_sentence_len: usize,
}
61
62impl Default for SentenceSegmenter {
63    fn default() -> Self {
64        Self::new()
65    }
66}
67
68impl SentenceSegmenter {
69    /// Create a new segmenter with built-in English abbreviations.
70    pub fn new() -> Self {
71        Self {
72            abbreviations: builtin_abbreviations(),
73            min_sentence_len: 10,
74        }
75    }
76
77    /// Create a segmenter with a custom abbreviation list.
78    pub fn with_abbreviations(abbrevs: Vec<String>) -> Self {
79        let mut set = builtin_abbreviations();
80        for a in abbrevs {
81            set.insert(a);
82        }
83        Self {
84            abbreviations: set,
85            min_sentence_len: 10,
86        }
87    }
88
89    /// Segment `text` into sentence string slices.
90    pub fn segment<'a>(&self, text: &'a str) -> Vec<&'a str> {
91        if text.trim().is_empty() {
92            return Vec::new();
93        }
94        let boundaries = self.find_boundaries(text);
95        let mut result: Vec<&'a str> = Vec::new();
96        let mut start = 0;
97
98        for end in boundaries {
99            let slice = text[start..end].trim();
100            if !slice.is_empty() {
101                result.push(slice);
102            }
103            start = end;
104        }
105
106        let tail = text[start..].trim();
107        if !tail.is_empty() {
108            result.push(tail);
109        }
110
111        result
112    }
113
114    /// Segment `text` and return owned `String`s.
115    pub fn segment_owned(&self, text: &str) -> Vec<String> {
116        if text.trim().is_empty() {
117            return Vec::new();
118        }
119        let raw: Vec<String> = self.segment(text).iter().map(|s| s.to_string()).collect();
120
121        // Merge very short fragments into the previous sentence.
122        let mut result: Vec<String> = Vec::new();
123        for s in raw {
124            if s.len() < self.min_sentence_len && !result.is_empty() {
125                if let Some(last) = result.last_mut() {
126                    last.push(' ');
127                    last.push_str(&s);
128                }
129            } else {
130                result.push(s);
131            }
132        }
133        result
134    }
135
136    // -----------------------------------------------------------------------
137    // Internal helpers
138    // -----------------------------------------------------------------------
139
140    fn find_boundaries(&self, text: &str) -> Vec<usize> {
141        let chars: Vec<(usize, char)> = text.char_indices().collect();
142        let n = chars.len();
143        let mut boundaries: Vec<usize> = Vec::new();
144
145        let mut i = 0usize;
146        while i < n {
147            let (byte_pos, ch) = chars[i];
148
149            if ch == '.' || ch == '!' || ch == '?' {
150                // Check for ellipsis (...)
151                if ch == '.' && i + 2 < n && chars[i + 1].1 == '.' && chars[i + 2].1 == '.' {
152                    i += 3;
153                    continue;
154                }
155
156                // Check if this period follows an abbreviation.
157                if ch == '.' && self.is_abbreviation_period(text, byte_pos) {
158                    i += 1;
159                    continue;
160                }
161
162                // Check if period is inside a number like 3.14
163                if ch == '.' && self.is_decimal_period(text, byte_pos) {
164                    i += 1;
165                    continue;
166                }
167
168                // Find end of punctuation cluster
169                let mut end_i = i + 1;
170                while end_i < n
171                    && (chars[end_i].1 == '!' || chars[end_i].1 == '?' || chars[end_i].1 == '.')
172                {
173                    end_i += 1;
174                }
175
176                let boundary_byte = if end_i < n {
177                    chars[end_i].0
178                } else {
179                    text.len()
180                };
181
182                if self.is_sentence_boundary(text, boundary_byte) {
183                    boundaries.push(boundary_byte);
184                }
185
186                i = end_i;
187                continue;
188            }
189
190            i += 1;
191        }
192
193        boundaries
194    }
195
196    fn is_abbreviation_period(&self, text: &str, period_byte: usize) -> bool {
197        let prefix = &text[..period_byte];
198        let word = prefix
199            .rsplit(|c: char| !c.is_alphabetic() && c != '.')
200            .next()
201            .unwrap_or("");
202        self.abbreviations.contains(word)
203            || self.abbreviations.contains(&word.to_lowercase())
204            || (word.len() == 1 && word.chars().next().is_some_and(|c| c.is_uppercase()))
205    }
206
207    fn is_decimal_period(&self, text: &str, period_byte: usize) -> bool {
208        // Check if preceded by a digit and followed by a digit (e.g., 3.14)
209        let before = text[..period_byte]
210            .chars()
211            .next_back()
212            .is_some_and(|c| c.is_ascii_digit());
213        let after = text[period_byte + 1..]
214            .chars()
215            .next()
216            .is_some_and(|c| c.is_ascii_digit());
217        before && after
218    }
219
220    fn is_sentence_boundary(&self, text: &str, pos: usize) -> bool {
221        if pos >= text.len() {
222            return true;
223        }
224        let after = &text[pos..];
225        let trimmed = after.trim_start();
226        if trimmed.is_empty() {
227            return true;
228        }
229        trimmed.chars().next().is_some_and(|c| {
230            c.is_uppercase()
231                || c.is_ascii_digit()
232                || matches!(c, '"' | '\'' | '(' | '[' | '\u{201C}' | '\u{2018}')
233        })
234    }
235}
236
237// ---------------------------------------------------------------------------
238// TextChunker
239// ---------------------------------------------------------------------------
240
/// A chunk of text with positional metadata.
///
/// In token mode `text` is exactly `source[start..end]`. In sentence mode
/// `text` is the chunk's sentences joined by single spaces, so whitespace
/// between sentences may differ from the source span.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunk {
    /// The chunk text.
    pub text: String,
    /// Byte offset of the chunk start in the source text.
    pub start: usize,
    /// Byte offset of the chunk end (exclusive) in the source text.
    pub end: usize,
    /// Zero-based index of this chunk.
    pub chunk_index: usize,
    /// Total number of chunks produced.
    pub total_chunks: usize,
}
255
/// Sliding-window text chunker.
///
/// In token mode (the default) the input is split on whitespace and emitted
/// in windows of `chunk_size` tokens that advance by `chunk_size - overlap`
/// tokens (at least one). With [`TextChunker::with_sentence_boundaries`],
/// whole sentences are packed into each chunk instead.
///
/// # Example
///
/// ```rust
/// use scirs2_text::segmentation::TextChunker;
///
/// let chunker = TextChunker::new(10, 2);
/// let chunks = chunker.chunk("Rust is fast. Rust is safe. Rust is fun.");
/// assert!(!chunks.is_empty());
/// ```
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TextChunker {
    /// Maximum number of whitespace-delimited tokens per chunk. In sentence
    /// mode this is a budget filled with whole sentences; a single sentence
    /// longer than the budget still forms its own chunk.
    pub chunk_size: usize,
    /// Number of tokens shared between consecutive chunks
    /// ([`TextChunker::new`] clamps it below `chunk_size`). In sentence mode it
    /// is converted to a sentence count at roughly one sentence per 10 tokens.
    pub overlap: usize,
    /// If `true`, chunk on sentence boundaries instead of raw token windows.
    pub by_sentence: bool,
}
275
276impl Default for TextChunker {
277    fn default() -> Self {
278        Self::new(512, 50)
279    }
280}
281
282impl TextChunker {
283    /// Create a new chunker.
284    pub fn new(chunk_size: usize, overlap: usize) -> Self {
285        let safe_overlap = if overlap >= chunk_size {
286            chunk_size.saturating_sub(1)
287        } else {
288            overlap
289        };
290        Self {
291            chunk_size,
292            overlap: safe_overlap,
293            by_sentence: false,
294        }
295    }
296
297    /// Enable sentence-boundary-respecting mode.
298    pub fn with_sentence_boundaries(mut self) -> Self {
299        self.by_sentence = true;
300        self
301    }
302
303    /// Chunk `text` and return plain `String` chunks.
304    pub fn chunk(&self, text: &str) -> Vec<String> {
305        self.chunk_with_metadata(text)
306            .into_iter()
307            .map(|c| c.text)
308            .collect()
309    }
310
311    /// Chunk `text` and return `TextChunk` structs with metadata.
312    pub fn chunk_with_metadata(&self, text: &str) -> Vec<TextChunk> {
313        if text.is_empty() {
314            return Vec::new();
315        }
316
317        if self.by_sentence {
318            self.chunk_by_sentence(text)
319        } else {
320            self.chunk_by_tokens(text)
321        }
322    }
323
324    // -----------------------------------------------------------------------
325    // Token-based chunking
326    // -----------------------------------------------------------------------
327
328    fn chunk_by_tokens(&self, text: &str) -> Vec<TextChunk> {
329        let tokens: Vec<(usize, usize)> = token_byte_ranges(text);
330
331        if tokens.is_empty() {
332            return Vec::new();
333        }
334
335        let step = self.chunk_size.saturating_sub(self.overlap).max(1);
336        let n = tokens.len();
337
338        let mut raw: Vec<(usize, usize)> = Vec::new();
339        let mut start_idx = 0usize;
340        while start_idx < n {
341            let end_idx = (start_idx + self.chunk_size).min(n);
342            let chunk_start_byte = tokens[start_idx].0;
343            let chunk_end_byte = tokens[end_idx - 1].1;
344            raw.push((chunk_start_byte, chunk_end_byte));
345            if end_idx >= n {
346                break;
347            }
348            start_idx += step;
349        }
350
351        let total = raw.len();
352        raw.into_iter()
353            .enumerate()
354            .map(|(idx, (start, end))| TextChunk {
355                text: text[start..end].to_string(),
356                start,
357                end,
358                chunk_index: idx,
359                total_chunks: total,
360            })
361            .collect()
362    }
363
364    // -----------------------------------------------------------------------
365    // Sentence-boundary-aware chunking
366    // -----------------------------------------------------------------------
367
368    fn chunk_by_sentence(&self, text: &str) -> Vec<TextChunk> {
369        let segmenter = SentenceSegmenter::new();
370        let sentences = segmenter.segment(text);
371
372        if sentences.is_empty() {
373            return Vec::new();
374        }
375
376        let mut chunks_data: Vec<(String, usize, usize)> = Vec::new();
377        let overlap_sentences = (self.overlap / 10).max(1);
378        let mut i = 0;
379
380        while i < sentences.len() {
381            let mut word_count = 0;
382            let mut j = i;
383            let mut chunk_parts: Vec<&str> = Vec::new();
384
385            while j < sentences.len() {
386                let sentence = sentences[j];
387                let wc = sentence.split_whitespace().count();
388                if word_count + wc > self.chunk_size && !chunk_parts.is_empty() {
389                    break;
390                }
391                chunk_parts.push(sentence);
392                word_count += wc;
393                j += 1;
394            }
395
396            if !chunk_parts.is_empty() {
397                let combined = chunk_parts.join(" ");
398                let start_byte = text.find(chunk_parts[0]).unwrap_or(0);
399                let last = chunk_parts[chunk_parts.len() - 1];
400                let last_start = text.rfind(last).unwrap_or(start_byte);
401                let end_byte = (last_start + last.len()).min(text.len());
402                chunks_data.push((combined, start_byte, end_byte));
403            }
404
405            let advance = (j - i).saturating_sub(overlap_sentences).max(1);
406            i += advance;
407        }
408
409        let total = chunks_data.len();
410        chunks_data
411            .into_iter()
412            .enumerate()
413            .map(|(idx, (text_s, start, end))| TextChunk {
414                text: text_s,
415                start,
416                end,
417                chunk_index: idx,
418                total_chunks: total,
419            })
420            .collect()
421    }
422}
423
424// ---------------------------------------------------------------------------
425// Helpers
426// ---------------------------------------------------------------------------
427
/// Return `(start_byte, end_byte)` pairs (end exclusive) for every maximal run
/// of non-whitespace characters in `text`, in order of appearance.
pub fn token_byte_ranges(text: &str) -> Vec<(usize, usize)> {
    let mut ranges = Vec::new();
    // Byte offset where the current token began, if we are inside one.
    let mut open: Option<usize> = None;

    for (pos, c) in text.char_indices() {
        match (c.is_whitespace(), open) {
            (true, Some(begin)) => {
                ranges.push((begin, pos));
                open = None;
            }
            (false, None) => open = Some(pos),
            _ => {}
        }
    }

    if let Some(begin) = open {
        ranges.push((begin, text.len()));
    }
    ranges
}
450
451// ---------------------------------------------------------------------------
452// Tests
453// ---------------------------------------------------------------------------
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Segment `text` with a default [`SentenceSegmenter`].
    fn split(text: &str) -> Vec<&str> {
        SentenceSegmenter::new().segment(text)
    }

    // ---- SentenceSegmenter ------------------------------------------------

    #[test]
    fn test_basic_segmentation() {
        let found = split("Hello world. How are you? I am fine.");
        assert_eq!(found.len(), 3, "Expected 3 sentences, got {:?}", found);
    }

    #[test]
    fn test_abbreviation_not_split() {
        let found = split("We met Dr. Smith today. He is well.");
        assert_eq!(
            found.len(),
            2,
            "Abbreviation should not create extra splits: {:?}",
            found
        );
    }

    #[test]
    fn test_exclamation_and_question() {
        assert!(split("Amazing! Really? Yes absolutely.").len() >= 2);
    }

    #[test]
    fn test_segment_owned() {
        let owned =
            SentenceSegmenter::new().segment_owned("First sentence. Second sentence. Third sentence.");
        assert!(!owned.is_empty());
        assert!(owned.iter().all(|s| !s.is_empty()));
    }

    #[test]
    fn test_empty_text_returns_empty() {
        let segmenter = SentenceSegmenter::new();
        assert!(segmenter.segment("").is_empty());
        assert!(segmenter.segment_owned("").is_empty());
    }

    #[test]
    fn test_single_sentence() {
        assert_eq!(split("This is just one sentence").len(), 1);
    }

    #[test]
    fn test_with_abbreviations() {
        let segmenter = SentenceSegmenter::with_abbreviations(vec!["Esq".to_string()]);
        let found = segmenter.segment("John Smith, Esq. is present. He said hello.");
        assert_eq!(found.len(), 2, "Got {:?}", found);
    }

    #[test]
    fn test_no_false_split_on_decimal() {
        let found = split("Pi is about 3.14159 in value. That is a fact.");
        assert_eq!(found.len(), 2, "Got {:?}", found);
    }

    // ---- TextChunker ------------------------------------------------------

    #[test]
    fn test_chunker_basic() {
        let chunks =
            TextChunker::new(5, 1).chunk("one two three four five six seven eight nine ten");
        assert!(!chunks.is_empty());
        for chunk in &chunks {
            let wc = chunk.split_whitespace().count();
            assert!(wc <= 5, "Chunk '{}' has {} words", chunk, wc);
        }
    }

    #[test]
    fn test_chunker_overlap() {
        let chunks = TextChunker::new(4, 2).chunk("a b c d e f g h");
        assert!(chunks.len() >= 2);
        let first: Vec<&str> = chunks[0].split_whitespace().collect();
        let second: Vec<&str> = chunks[1].split_whitespace().collect();
        let tail = &first[first.len().saturating_sub(2)..];
        let head = &second[..second.len().min(2)];
        assert_eq!(tail, head, "Overlap should share tokens");
    }

    #[test]
    fn test_chunker_with_metadata() {
        let text = "alpha beta gamma delta epsilon";
        let chunks = TextChunker::new(3, 0).chunk_with_metadata(text);
        let total = chunks.len();
        for (i, chunk) in chunks.iter().enumerate() {
            assert_eq!(chunk.chunk_index, i);
            assert_eq!(chunk.total_chunks, total);
            assert_eq!(&text[chunk.start..chunk.end], chunk.text.as_str());
        }
    }

    #[test]
    fn test_chunker_empty_text() {
        let chunker = TextChunker::new(10, 2);
        assert!(chunker.chunk("").is_empty());
        assert!(chunker.chunk_with_metadata("").is_empty());
    }

    #[test]
    fn test_chunker_short_text() {
        let text = "just three words";
        assert_eq!(TextChunker::new(100, 10).chunk(text), vec![text.to_string()]);
    }

    #[test]
    fn test_chunker_by_sentence() {
        let chunker = TextChunker::new(20, 5).with_sentence_boundaries();
        let text = "The quick brown fox jumps. A lazy dog sleeps. The sun is shining.";
        assert!(!chunker.chunk(text).is_empty());
    }

    #[test]
    fn test_chunker_overlap_clamped() {
        let chunker = TextChunker::new(3, 10);
        assert!(chunker.overlap < chunker.chunk_size);
    }

    // ---- Helpers ----------------------------------------------------------

    #[test]
    fn test_token_byte_ranges() {
        let text = "hello world foo";
        let words: Vec<&str> = token_byte_ranges(text)
            .into_iter()
            .map(|(start, end)| &text[start..end])
            .collect();
        assert_eq!(words, ["hello", "world", "foo"]);
    }
}