use regex::Regex;
use std::sync::LazyLock;
use crate::dictionaries::abbreviations::is_abbreviation;
static DECIMAL_PATTERN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\d+\.\d+").expect("valid regex"));
static URL_PATTERN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"(?:https?://|www\.)\S+").expect("valid regex"));
static EMAIL_PATTERN: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b").expect("valid regex")
});
static INITIALS_PATTERN: LazyLock<Regex> =
LazyLock::new(|| Regex::new(r"\b[A-Z]\.(?:[A-Z]\.)*").expect("valid regex"));
#[tracing::instrument(skip_all, fields(text_len = text.len()))]
pub fn split_sentences(text: &str) -> Vec<String> {
if text.trim().is_empty() {
return Vec::new();
}
let min_length = 3;
let mut sentences = Vec::new();
let mut current = String::new();
let chars: Vec<char> = text.chars().collect();
let mut i = 0;
while i < chars.len() {
let ch = chars[i];
current.push(ch);
if is_sentence_terminator(ch) {
let context = extract_context(&chars, i);
if is_sentence_boundary(&context, ¤t) {
let sentence = current.trim().to_string();
if sentence.len() >= min_length {
sentences.push(sentence);
}
current.clear();
}
}
i += 1;
}
let sentence = current.trim().to_string();
if sentence.len() >= min_length {
sentences.push(sentence);
}
sentences
}
pub fn extract_words(text: &str) -> Vec<String> {
text.split_whitespace()
.map(|w| w.trim_matches(|c: char| !c.is_alphanumeric() && c != '\'' && c != '-'))
.filter(|w| !w.is_empty())
.map(|w| w.to_lowercase())
.collect()
}
pub fn split_paragraphs(text: &str) -> Vec<String> {
text.split("\n\n")
.map(|p| p.trim().to_string())
.filter(|p| !p.is_empty())
.collect()
}
const fn is_sentence_terminator(ch: char) -> bool {
matches!(ch, '.' | '!' | '?')
}
struct SentenceContext {
punctuation: char,
word_before: String,
char_after: Option<char>,
text_after: String,
is_end_of_text: bool,
}
fn extract_context(chars: &[char], pos: usize) -> SentenceContext {
let before = get_word_before(chars, pos);
let mut after_start = pos + 1;
while after_start < chars.len() && chars[after_start].is_whitespace() {
after_start += 1;
}
let after_char = chars.get(after_start).copied();
let after_text: String = chars[after_start..].iter().take(20).collect();
SentenceContext {
punctuation: chars[pos],
word_before: before,
char_after: after_char,
text_after: after_text,
is_end_of_text: pos == chars.len() - 1,
}
}
fn get_word_before(chars: &[char], pos: usize) -> String {
let mut i = pos;
while i > 0 {
i -= 1;
if !chars[i].is_whitespace() && chars[i] != '.' {
break;
}
}
let mut word_chars = Vec::new();
loop {
if chars[i].is_alphanumeric() || chars[i] == '.' {
word_chars.push(chars[i]);
} else {
break;
}
if i == 0 {
break;
}
i -= 1;
}
word_chars.reverse();
word_chars.iter().collect()
}
fn is_sentence_boundary(context: &SentenceContext, current_sentence: &str) -> bool {
if context.is_end_of_text {
return true;
}
if context.punctuation == '!' || context.punctuation == '?' {
return check_next_char_capitalization(context);
}
if is_likely_abbreviation(&context.word_before) {
return false;
}
if is_likely_initial(&context.word_before) {
return false;
}
if is_decimal_number(current_sentence) {
return false;
}
if current_sentence.ends_with("...") {
return false;
}
if contains_url_or_email(current_sentence) {
return false;
}
if let Some(next_char) = context.char_after
&& next_char.is_ascii_digit()
&& context
.word_before
.chars()
.last()
.is_some_and(|c| c.is_ascii_digit())
{
return false;
}
if let Some(next_char) = context.char_after {
if next_char.is_uppercase() {
return true;
}
if next_char.is_lowercase() {
return false;
}
}
true
}
fn check_next_char_capitalization(context: &SentenceContext) -> bool {
if let Some(next_char) = context.char_after {
if next_char.is_uppercase() {
return true;
}
if next_char == '"' || next_char == '\'' {
return context
.text_after
.chars()
.nth(1)
.is_some_and(|c| c.is_uppercase());
}
}
true
}
fn is_likely_abbreviation(word: &str) -> bool {
if word.is_empty() {
return false;
}
let word_clean = word.trim_end_matches('.');
if is_abbreviation(word_clean) {
return true;
}
word_clean.len() == 1 && word_clean.chars().next().is_some_and(|c| c.is_uppercase())
}
fn is_likely_initial(word: &str) -> bool {
if word.is_empty() {
return false;
}
if word.len() == 2
&& word.chars().next().is_some_and(|c| c.is_uppercase())
&& word.ends_with('.')
{
return true;
}
INITIALS_PATTERN.is_match(word)
}
fn is_decimal_number(sentence: &str) -> bool {
let last_part: String = sentence
.chars()
.rev()
.take(10)
.collect::<String>()
.chars()
.rev()
.collect();
DECIMAL_PATTERN.is_match(&last_part)
}
fn contains_url_or_email(sentence: &str) -> bool {
let last_part: String = sentence
.chars()
.rev()
.take(50)
.collect::<String>()
.chars()
.rev()
.collect();
URL_PATTERN.is_match(&last_part) || EMAIL_PATTERN.is_match(&last_part)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn basic_sentences() {
let sentences = split_sentences("This is a sentence. This is another sentence.");
assert_eq!(sentences.len(), 2);
assert_eq!(sentences[0], "This is a sentence.");
assert_eq!(sentences[1], "This is another sentence.");
}
#[test]
fn abbreviations_not_split() {
let sentences = split_sentences("Dr. Smith went to the store. He bought milk.");
assert_eq!(sentences.len(), 2);
assert!(sentences[0].contains("Dr. Smith"));
}
#[test]
fn decimal_numbers_not_split() {
let sentences = split_sentences("The price is 3.14 dollars. That's cheap.");
assert_eq!(sentences.len(), 2);
assert!(sentences[0].contains("3.14"));
}
#[test]
fn question_and_exclamation() {
let sentences = split_sentences("Are you serious? I can't believe it! This is amazing.");
assert_eq!(sentences.len(), 3);
}
#[test]
fn empty_input() {
assert!(split_sentences("").is_empty());
assert!(split_sentences(" ").is_empty());
}
#[test]
fn extract_words_basic() {
let words = extract_words("Hello, world! This is a test.");
assert_eq!(words, vec!["hello", "world", "this", "is", "a", "test"]);
}
#[test]
fn split_paragraphs_basic() {
let text = "First paragraph.\n\nSecond paragraph.\n\nThird.";
let paras = split_paragraphs(text);
assert_eq!(paras.len(), 3);
}
}