use crate::exceptions::{LangExtractError, LangExtractResult};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Classification assigned to each token produced by the tokenizer.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
    /// Alphabetic word.
    Word = 0,
    /// Numeric literal.
    Number = 1,
    /// Punctuation or other symbol run.
    Punctuation = 2,
    /// Slash-delimited abbreviation such as "km/h".
    Acronym = 3,
}

/// Character offsets of a token within the source text.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CharInterval {
    /// Start offset (inclusive).
    pub start_pos: usize,
    /// End offset (exclusive).
    pub end_pos: usize,
}

impl CharInterval {
    /// Creates a new interval from start and end offsets.
    pub fn new(start_pos: usize, end_pos: usize) -> Self {
        Self { start_pos, end_pos }
    }

    /// Returns the interval length, or zero if the bounds are inverted.
    pub fn length(&self) -> usize {
        self.end_pos.saturating_sub(self.start_pos)
    }
}

/// Half-open range of token indices, `[start_index, end_index)`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenInterval {
    /// Index of the first token (inclusive).
    pub start_index: usize,
    /// Index one past the last token (exclusive).
    pub end_index: usize,
}

impl TokenInterval {
    /// Creates a new interval, rejecting empty or inverted ranges.
    pub fn new(start_index: usize, end_index: usize) -> LangExtractResult<Self> {
        if start_index >= end_index {
            return Err(LangExtractError::invalid_input(format!(
                "Start index {} must be < end index {}",
                start_index, end_index
            )));
        }
        Ok(Self {
            start_index,
            end_index,
        })
    }
}

/// A single token: its position in the token sequence, its classification,
/// and its character offsets in the source text.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    /// Position of the token in the token sequence.
    pub index: usize,
    /// Classification of the token.
    pub token_type: TokenType,
    /// Character offsets of the token in the source text.
    pub char_interval: CharInterval,
    /// True if a newline separates this token from the previous one.
    pub first_token_after_newline: bool,
}

impl Token {
    /// Creates a new token.
    pub fn new(
        index: usize,
        token_type: TokenType,
        char_interval: CharInterval,
        first_token_after_newline: bool,
    ) -> Self {
        Self {
            index,
            token_type,
            char_interval,
            first_token_after_newline,
        }
    }
}

/// Source text together with the tokens extracted from it.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenizedText {
    /// The original text.
    pub text: String,
    /// Tokens in document order.
    pub tokens: Vec<Token>,
}

impl TokenizedText {
    /// Creates an empty tokenization for the given text.
    pub fn new(text: String) -> Self {
        Self {
            text,
            tokens: Vec::new(),
        }
    }

    /// Returns the number of tokens.
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Returns true if no tokens were produced.
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

/// Regex-based tokenizer that splits text into words, numbers, punctuation,
/// and slash-delimited abbreviations, and detects sentence boundaries over
/// the resulting tokens.
pub struct Tokenizer {
    letters_pattern: Regex,
    digits_pattern: Regex,
    symbols_pattern: Regex,
    slash_abbrev_pattern: Regex,
    token_pattern: Regex,
    word_pattern: Regex,
    end_of_sentence_pattern: Regex,
    known_abbreviations: HashSet<String>,
}

impl Tokenizer {
    /// Creates a tokenizer, compiling the regex patterns it relies on.
    pub fn new() -> LangExtractResult<Self> {
        let letters_pattern = Regex::new(r"[A-Za-z]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile letters regex: {}", e))
        })?;

        let digits_pattern = Regex::new(r"[0-9]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile digits regex: {}", e))
        })?;

        let symbols_pattern = Regex::new(r"[^A-Za-z0-9\s]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile symbols regex: {}", e))
        })?;

        let slash_abbrev_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+").map_err(|e| {
            LangExtractError::configuration(format!(
                "Failed to compile slash abbreviation regex: {}",
                e
            ))
        })?;

        // Alternation order gives slash abbreviations the highest precedence,
        // followed by letter runs, digit runs, and finally symbol runs.
        let token_pattern =
            Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+|[A-Za-z]+|[0-9]+|[^A-Za-z0-9\s]+")
                .map_err(|e| {
                    LangExtractError::configuration(format!("Failed to compile token regex: {}", e))
                })?;

        let word_pattern = Regex::new(r"^(?:[A-Za-z]+|[0-9]+)$").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile word regex: {}", e))
        })?;

        let end_of_sentence_pattern = Regex::new(r"[.?!]$").map_err(|e| {
            LangExtractError::configuration(format!(
                "Failed to compile end of sentence regex: {}",
                e
            ))
        })?;

        // Abbreviations that end in a period but do not end a sentence.
        // Note: entries that split into more than two tokens (e.g. "Ph.D.",
        // "et al.") are not caught by the two-token abbreviation check in
        // `is_end_of_sentence_token`.
        let known_abbreviations = [
            "Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "St.", "Ave.", "Blvd.", "Rd.", "Ltd.", "Inc.",
            "Corp.", "vs.", "etc.", "et al.", "i.e.", "e.g.", "cf.", "a.m.", "p.m.", "U.S.",
            "U.K.", "Ph.D.",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        Ok(Self {
            letters_pattern,
            digits_pattern,
            symbols_pattern,
            slash_abbrev_pattern,
            token_pattern,
            word_pattern,
            end_of_sentence_pattern,
            known_abbreviations,
        })
    }

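    /// Tokenizes `text`, recording each token's character offsets and whether
    /// a newline precedes it.
    ///
    /// Illustrative sketch of the expected output (marked `ignore` since the
    /// crate path needed by a doc-test is project-specific):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new()?;
    /// let tokenized = tokenizer.tokenize("Dr. Smith arrived.")?;
    /// // Tokens: "Dr", ".", "Smith", "arrived", "."
    /// assert_eq!(tokenized.len(), 5);
    /// assert_eq!(tokenized.tokens[0].token_type, TokenType::Word);
    /// ```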
    pub fn tokenize(&self, text: &str) -> LangExtractResult<TokenizedText> {
        let mut tokenized = TokenizedText::new(text.to_string());
        let mut previous_end = 0;

        for (token_index, token_match) in self.token_pattern.find_iter(text).enumerate() {
            let start_pos = token_match.start();
            let end_pos = token_match.end();
            let matched_text = token_match.as_str();

            // A newline in the gap since the previous token marks this token
            // as the first one on a new line.
            let first_token_after_newline = if token_index > 0 {
                let gap = &text[previous_end..start_pos];
                gap.contains('\n') || gap.contains('\r')
            } else {
                false
            };

            let token_type = self.classify_token(matched_text);

            let token = Token::new(
                token_index,
                token_type,
                CharInterval::new(start_pos, end_pos),
                first_token_after_newline,
            );

            tokenized.tokens.push(token);
            previous_end = end_pos;
        }

        Ok(tokenized)
    }

    /// Classifies a matched token string into a [`TokenType`].
    fn classify_token(&self, text: &str) -> TokenType {
        // Check slash abbreviations first: the digit pattern is unanchored, so
        // a token such as "24/7" would otherwise be misclassified as a number.
        if self.slash_abbrev_pattern.is_match(text) {
            TokenType::Acronym
        } else if self.digits_pattern.is_match(text) {
            TokenType::Number
        } else if self.word_pattern.is_match(text) {
            TokenType::Word
        } else {
            TokenType::Punctuation
        }
    }

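    /// Returns the substring of the original text covered by `token_interval`,
    /// including any whitespace between the tokens.
    ///
    /// Illustrative sketch (marked `ignore` since the crate path needed by a
    /// doc-test is project-specific):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new()?;
    /// let tokenized = tokenizer.tokenize("Hello, world!")?;
    /// let interval = TokenInterval::new(0, 2)?;
    /// // Tokens 0..2 are "Hello" and ",", so the covered text is "Hello,".
    /// assert_eq!(tokenizer.tokens_text(&tokenized, &interval)?, "Hello,");
    /// ```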
    pub fn tokens_text(
        &self,
        tokenized_text: &TokenizedText,
        token_interval: &TokenInterval,
    ) -> LangExtractResult<String> {
        if token_interval.start_index >= token_interval.end_index {
            return Err(LangExtractError::invalid_input(format!(
                "Invalid token interval: start_index={}, end_index={}",
                token_interval.start_index, token_interval.end_index
            )));
        }

        if token_interval.end_index > tokenized_text.tokens.len() {
            return Err(LangExtractError::invalid_input(format!(
                "Token interval end_index {} exceeds token count {}",
                token_interval.end_index,
                tokenized_text.tokens.len()
            )));
        }

        if tokenized_text.tokens.is_empty() {
            return Ok(String::new());
        }

        let start_token = &tokenized_text.tokens[token_interval.start_index];
        let end_token = &tokenized_text.tokens[token_interval.end_index - 1];

        let start_char = start_token.char_interval.start_pos;
        let end_char = end_token.char_interval.end_pos;

        Ok(tokenized_text.text[start_char..end_char].to_string())
    }

    /// Returns true if the token at `current_idx` ends a sentence, i.e. it
    /// ends with `.`, `?`, or `!` and, together with the preceding token, does
    /// not form a known abbreviation such as "Dr.".
    pub fn is_end_of_sentence_token(
        &self,
        text: &str,
        tokens: &[Token],
        current_idx: usize,
    ) -> bool {
        if current_idx >= tokens.len() {
            return false;
        }

        let current_token = &tokens[current_idx];
        let current_token_text =
            &text[current_token.char_interval.start_pos..current_token.char_interval.end_pos];

        if self.end_of_sentence_pattern.is_match(current_token_text) {
            if current_idx > 0 {
                let prev_token = &tokens[current_idx - 1];
                let prev_token_text =
                    &text[prev_token.char_interval.start_pos..prev_token.char_interval.end_pos];
                let combined = format!("{}{}", prev_token_text, current_token_text);

                // Periods that complete a known abbreviation do not end a sentence.
                if self.known_abbreviations.contains(&combined) {
                    return false;
                }
            }
            return true;
        }
        false
    }

    /// Returns true if a newline between the token at `current_idx` and the
    /// next token, followed by an uppercase first character, indicates a
    /// sentence break.
    pub fn is_sentence_break_after_newline(
        &self,
        text: &str,
        tokens: &[Token],
        current_idx: usize,
    ) -> bool {
        if current_idx + 1 >= tokens.len() {
            return false;
        }

        let current_token = &tokens[current_idx];
        let next_token = &tokens[current_idx + 1];

        let gap_start = current_token.char_interval.end_pos;
        let gap_end = next_token.char_interval.start_pos;

        if gap_start >= gap_end {
            return false;
        }

        let gap_text = &text[gap_start..gap_end];
        if !gap_text.contains('\n') {
            return false;
        }

        // Only treat the newline as a break if the next token looks like the
        // start of a new sentence, i.e. begins with an uppercase character.
        let next_token_text =
            &text[next_token.char_interval.start_pos..next_token.char_interval.end_pos];
        !next_token_text.is_empty() && next_token_text.chars().next().unwrap().is_uppercase()
    }

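    /// Finds the sentence that starts at `start_token_index`, ending at
    /// sentence-final punctuation, at a newline followed by an uppercase
    /// token, or at the end of the token stream.
    ///
    /// Illustrative sketch (marked `ignore` since the crate path needed by a
    /// doc-test is project-specific):
    ///
    /// ```ignore
    /// let tokenizer = Tokenizer::new()?;
    /// let tokenized = tokenizer.tokenize("First sentence. Second one.")?;
    /// let range = tokenizer.find_sentence_range(&tokenized.text, &tokenized.tokens, 0)?;
    /// // "First sentence." spans tokens 0..3: "First", "sentence", "."
    /// assert_eq!((range.start_index, range.end_index), (0, 3));
    /// ```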
    pub fn find_sentence_range(
        &self,
        text: &str,
        tokens: &[Token],
        start_token_index: usize,
    ) -> LangExtractResult<TokenInterval> {
        if start_token_index >= tokens.len() {
            return Err(LangExtractError::invalid_input(format!(
                "start_token_index {} out of range. Total tokens: {}",
                start_token_index,
                tokens.len()
            )));
        }

        let mut i = start_token_index;
        while i < tokens.len() {
            if tokens[i].token_type == TokenType::Punctuation
                && self.is_end_of_sentence_token(text, tokens, i)
            {
                return TokenInterval::new(start_token_index, i + 1);
            }
            if self.is_sentence_break_after_newline(text, tokens, i) {
                return TokenInterval::new(start_token_index, i + 1);
            }
            i += 1;
        }

        TokenInterval::new(start_token_index, tokens.len())
    }
}

impl Default for Tokenizer {
    fn default() -> Self {
        // The built-in patterns are valid, so compilation should not fail here.
        Self::new().expect("Failed to create default tokenizer")
    }
}

#[cfg(test)]
mod tests;

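/// Iterator over sentence [`TokenInterval`]s of a [`TokenizedText`], starting
/// from a given token position.
///
/// Illustrative sketch (marked `ignore` since the crate path needed by a
/// doc-test is project-specific):
///
/// ```ignore
/// let tokenizer = Tokenizer::new()?;
/// let tokenized = tokenizer.tokenize("One. Two.")?;
/// let sentences = SentenceIterator::new(&tokenized, &tokenizer, 0)?;
/// for interval in sentences {
///     println!("{}", tokenizer.tokens_text(&tokenized, &interval?)?);
/// }
/// ```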
pub struct SentenceIterator<'a> {
    tokenized_text: &'a TokenizedText,
    tokenizer: &'a Tokenizer,
    current_token_pos: usize,
    token_len: usize,
}

impl<'a> SentenceIterator<'a> {
    /// Creates an iterator positioned at `current_token_pos`, validating that
    /// the position is not past the end of the token stream.
    pub fn new(
        tokenized_text: &'a TokenizedText,
        tokenizer: &'a Tokenizer,
        current_token_pos: usize,
    ) -> LangExtractResult<Self> {
        let token_len = tokenized_text.tokens.len();

        if current_token_pos > token_len {
            return Err(LangExtractError::invalid_input(format!(
                "Current token position {} is past the length of the document {}",
                current_token_pos, token_len
            )));
        }

        Ok(Self {
            tokenized_text,
            tokenizer,
            current_token_pos,
            token_len,
        })
    }
}

impl<'a> Iterator for SentenceIterator<'a> {
    type Item = LangExtractResult<TokenInterval>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_token_pos >= self.token_len {
            return None;
        }

        match self.tokenizer.find_sentence_range(
            &self.tokenized_text.text,
            &self.tokenized_text.tokens,
            self.current_token_pos,
        ) {
            Ok(sentence_range) => {
                // Rebuild the interval from the current position; this also
                // re-validates that the range is non-empty.
                let adjusted_range = match TokenInterval::new(
                    self.current_token_pos,
                    sentence_range.end_index,
                ) {
                    Ok(range) => range,
                    Err(e) => return Some(Err(e)),
                };

                self.current_token_pos = sentence_range.end_index;
                Some(Ok(adjusted_range))
            }
            Err(e) => Some(Err(e)),
        }
    }
}

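/// Convenience function that builds a [`Tokenizer`] and tokenizes `text` in
/// one call.
///
/// Illustrative sketch (marked `ignore` since the crate path needed by a
/// doc-test is project-specific):
///
/// ```ignore
/// let tokenized = tokenize("Hello, world!")?;
/// assert_eq!(tokenized.len(), 4); // "Hello", ",", "world", "!"
/// ```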
pub fn tokenize(text: &str) -> LangExtractResult<TokenizedText> {
    let tokenizer = Tokenizer::new()?;
    tokenizer.tokenize(text)
}