1use crate::exceptions::{LangExtractError, LangExtractResult};
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12
13#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum TokenType {
16 Word = 0,
18 Number = 1,
20 Punctuation = 2,
22 Acronym = 3,
24}
25
26#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28pub struct TokenCharSpan {
29 pub start_pos: usize,
31 pub end_pos: usize,
33}
34
35impl TokenCharSpan {
36 pub fn new(start_pos: usize, end_pos: usize) -> Self {
38 Self { start_pos, end_pos }
39 }
40
41 pub fn length(&self) -> usize {
43 self.end_pos.saturating_sub(self.start_pos)
44 }
45}
46
47impl From<TokenCharSpan> for crate::data::CharInterval {
48 fn from(span: TokenCharSpan) -> Self {
49 crate::data::CharInterval::new(Some(span.start_pos), Some(span.end_pos))
50 }
51}
52
53#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub struct TokenInterval {
56 pub start_index: usize,
58 pub end_index: usize,
60}
61
62impl TokenInterval {
63 pub fn new(start_index: usize, end_index: usize) -> LangExtractResult<Self> {
65 if start_index >= end_index {
66 return Err(LangExtractError::invalid_input(format!(
67 "Start index {} must be < end index {}",
68 start_index, end_index
69 )));
70 }
71 Ok(Self {
72 start_index,
73 end_index,
74 })
75 }
76}
77
78#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80pub struct Token {
81 pub index: usize,
83 pub token_type: TokenType,
85 pub char_interval: TokenCharSpan,
87 pub first_token_after_newline: bool,
89}
90
91impl Token {
92 pub fn new(
94 index: usize,
95 token_type: TokenType,
96 char_interval: TokenCharSpan,
97 first_token_after_newline: bool,
98 ) -> Self {
99 Self {
100 index,
101 token_type,
102 char_interval,
103 first_token_after_newline,
104 }
105 }
106}
107
108#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110pub struct TokenizedText {
111 pub text: String,
113 pub tokens: Vec<Token>,
115}
116
117impl TokenizedText {
118 pub fn new(text: String) -> Self {
120 Self {
121 text,
122 tokens: Vec::new(),
123 }
124 }
125
126 pub fn len(&self) -> usize {
128 self.tokens.len()
129 }
130
131 pub fn is_empty(&self) -> bool {
133 self.tokens.is_empty()
134 }
135}
136
137pub struct Tokenizer {
139 _letters_pattern: Regex,
140 digits_pattern: Regex,
141 _symbols_pattern: Regex,
142 slash_abbrev_pattern: Regex,
143 token_pattern: Regex,
144 word_pattern: Regex,
145 end_of_sentence_pattern: Regex,
146 known_abbreviations: HashSet<String>,
147}
148
149impl Tokenizer {
150 pub fn new() -> LangExtractResult<Self> {
152 let letters_pattern = Regex::new(r"[A-Za-z]+").map_err(|e| {
154 LangExtractError::configuration(format!("Failed to compile letters regex: {}", e))
155 })?;
156
157 let digits_pattern = Regex::new(r"[0-9]+").map_err(|e| {
158 LangExtractError::configuration(format!("Failed to compile digits regex: {}", e))
159 })?;
160
161 let symbols_pattern = Regex::new(r"[^A-Za-z0-9\s]+").map_err(|e| {
162 LangExtractError::configuration(format!("Failed to compile symbols regex: {}", e))
163 })?;
164
165 let slash_abbrev_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+").map_err(|e| {
166 LangExtractError::configuration(format!("Failed to compile slash abbreviation regex: {}", e))
167 })?;
168
169 let token_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+|[A-Za-z]+|[0-9]+|[^A-Za-z0-9\s]+").map_err(|e| {
170 LangExtractError::configuration(format!("Failed to compile token regex: {}", e))
171 })?;
172
173 let word_pattern = Regex::new(r"^(?:[A-Za-z]+|[0-9]+)$").map_err(|e| {
174 LangExtractError::configuration(format!("Failed to compile word regex: {}", e))
175 })?;
176
177 let end_of_sentence_pattern = Regex::new(r"[.?!]$").map_err(|e| {
178 LangExtractError::configuration(format!("Failed to compile end of sentence regex: {}", e))
179 })?;
180
181 let known_abbreviations = [
183 "Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "St.", "Ave.", "Blvd.", "Rd.", "Ltd.", "Inc.", "Corp.",
184 "vs.", "etc.", "et al.", "i.e.", "e.g.", "cf.", "a.m.", "p.m.", "U.S.", "U.K.", "Ph.D.",
185 ]
186 .iter()
187 .map(|s| s.to_string())
188 .collect();
189
190 Ok(Self {
191 _letters_pattern: letters_pattern,
192 digits_pattern,
193 _symbols_pattern: symbols_pattern,
194 slash_abbrev_pattern,
195 token_pattern,
196 word_pattern,
197 end_of_sentence_pattern,
198 known_abbreviations,
199 })
200 }
201
202 pub fn tokenize(&self, text: &str) -> LangExtractResult<TokenizedText> {
204 let mut tokenized = TokenizedText::new(text.to_string());
205 let mut previous_end = 0;
206
207 for (token_index, token_match) in self.token_pattern.find_iter(text).enumerate() {
208 let start_pos = token_match.start();
209 let end_pos = token_match.end();
210 let matched_text = token_match.as_str();
211
212 let first_token_after_newline = if token_index > 0 {
214 let gap = &text[previous_end..start_pos];
215 gap.contains('\n') || gap.contains('\r')
216 } else {
217 false
218 };
219
220 let token_type = self.classify_token(matched_text);
222
223 let token = Token::new(
224 token_index,
225 token_type,
226 TokenCharSpan::new(start_pos, end_pos),
227 first_token_after_newline,
228 );
229
230 tokenized.tokens.push(token);
231 previous_end = end_pos;
232 }
233
234 Ok(tokenized)
235 }
236
237 fn classify_token(&self, text: &str) -> TokenType {
239 if self.digits_pattern.is_match(text) {
240 TokenType::Number
241 } else if self.slash_abbrev_pattern.is_match(text) {
242 TokenType::Acronym
243 } else if self.word_pattern.is_match(text) {
244 TokenType::Word
245 } else {
246 TokenType::Punctuation
247 }
248 }
249
250 pub fn tokens_text(
252 &self,
253 tokenized_text: &TokenizedText,
254 token_interval: &TokenInterval,
255 ) -> LangExtractResult<String> {
256 if token_interval.start_index >= token_interval.end_index {
257 return Err(LangExtractError::invalid_input(format!(
258 "Invalid token interval: start_index={}, end_index={}",
259 token_interval.start_index, token_interval.end_index
260 )));
261 }
262
263 if token_interval.end_index > tokenized_text.tokens.len() {
264 return Err(LangExtractError::invalid_input(format!(
265 "Token interval end_index {} exceeds token count {}",
266 token_interval.end_index,
267 tokenized_text.tokens.len()
268 )));
269 }
270
271 if tokenized_text.tokens.is_empty() {
272 return Ok(String::new());
273 }
274
275 let start_token = &tokenized_text.tokens[token_interval.start_index];
276 let end_token = &tokenized_text.tokens[token_interval.end_index - 1];
277
278 let start_char = start_token.char_interval.start_pos;
279 let end_char = end_token.char_interval.end_pos;
280
281 Ok(tokenized_text.text[start_char..end_char].to_string())
282 }
283
284 pub fn is_end_of_sentence_token(
286 &self,
287 text: &str,
288 tokens: &[Token],
289 current_idx: usize,
290 ) -> bool {
291 if current_idx >= tokens.len() {
292 return false;
293 }
294
295 let current_token = &tokens[current_idx];
296 let current_token_text = &text[current_token.char_interval.start_pos..current_token.char_interval.end_pos];
297
298 if self.end_of_sentence_pattern.is_match(current_token_text) {
299 if current_idx > 0 {
301 let prev_token = &tokens[current_idx - 1];
302 let prev_token_text = &text[prev_token.char_interval.start_pos..prev_token.char_interval.end_pos];
303 let combined = format!("{}{}", prev_token_text, current_token_text);
304
305 if self.known_abbreviations.contains(&combined) {
306 return false;
307 }
308 }
309 return true;
310 }
311 false
312 }
313
314 pub fn is_sentence_break_after_newline(
316 &self,
317 text: &str,
318 tokens: &[Token],
319 current_idx: usize,
320 ) -> bool {
321 if current_idx + 1 >= tokens.len() {
322 return false;
323 }
324
325 let current_token = &tokens[current_idx];
326 let next_token = &tokens[current_idx + 1];
327
328 let gap_start = current_token.char_interval.end_pos;
330 let gap_end = next_token.char_interval.start_pos;
331
332 if gap_start >= gap_end {
333 return false;
334 }
335
336 let gap_text = &text[gap_start..gap_end];
337 if !gap_text.contains('\n') {
338 return false;
339 }
340
341 let next_token_text = &text[next_token.char_interval.start_pos..next_token.char_interval.end_pos];
343 !next_token_text.is_empty() && next_token_text.chars().next().unwrap().is_uppercase()
344 }
345
346 pub fn find_sentence_range(
348 &self,
349 text: &str,
350 tokens: &[Token],
351 start_token_index: usize,
352 ) -> LangExtractResult<TokenInterval> {
353 if start_token_index >= tokens.len() {
354 return Err(LangExtractError::invalid_input(format!(
355 "start_token_index {} out of range. Total tokens: {}",
356 start_token_index,
357 tokens.len()
358 )));
359 }
360
361 let mut i = start_token_index;
362 while i < tokens.len() {
363 if tokens[i].token_type == TokenType::Punctuation {
364 if self.is_end_of_sentence_token(text, tokens, i) {
365 return TokenInterval::new(start_token_index, i + 1);
366 }
367 }
368 if self.is_sentence_break_after_newline(text, tokens, i) {
369 return TokenInterval::new(start_token_index, i + 1);
370 }
371 i += 1;
372 }
373
374 TokenInterval::new(start_token_index, tokens.len())
375 }
376}
377
378impl Default for Tokenizer {
379 fn default() -> Self {
380 Self::new().expect("Failed to create default tokenizer")
381 }
382}
383
384#[cfg(test)]
385mod tests;
386
387pub struct SentenceIterator<'a> {
389 tokenized_text: &'a TokenizedText,
390 tokenizer: &'a Tokenizer,
391 current_token_pos: usize,
392 token_len: usize,
393}
394
395impl<'a> SentenceIterator<'a> {
396 pub fn new(
398 tokenized_text: &'a TokenizedText,
399 tokenizer: &'a Tokenizer,
400 current_token_pos: usize,
401 ) -> LangExtractResult<Self> {
402 let token_len = tokenized_text.tokens.len();
403
404 if current_token_pos > token_len {
405 return Err(LangExtractError::invalid_input(format!(
406 "Current token position {} is past the length of the document {}",
407 current_token_pos, token_len
408 )));
409 }
410
411 Ok(Self {
412 tokenized_text,
413 tokenizer,
414 current_token_pos,
415 token_len,
416 })
417 }
418}
419
420impl<'a> Iterator for SentenceIterator<'a> {
421 type Item = LangExtractResult<TokenInterval>;
422
423 fn next(&mut self) -> Option<Self::Item> {
424 if self.current_token_pos >= self.token_len {
425 return None;
426 }
427
428 match self.tokenizer.find_sentence_range(
430 &self.tokenized_text.text,
431 &self.tokenized_text.tokens,
432 self.current_token_pos,
433 ) {
434 Ok(sentence_range) => {
435 let adjusted_range = match TokenInterval::new(
438 self.current_token_pos,
439 sentence_range.end_index,
440 ) {
441 Ok(range) => range,
442 Err(e) => return Some(Err(e)),
443 };
444
445 self.current_token_pos = sentence_range.end_index;
446 Some(Ok(adjusted_range))
447 }
448 Err(e) => Some(Err(e)),
449 }
450 }
451}
452
453pub fn tokenize(text: &str) -> LangExtractResult<TokenizedText> {
455 let tokenizer = Tokenizer::new()?;
456 tokenizer.tokenize(text)
457}