langextract_rust/tokenizer.rs

//! Text tokenization functionality.
//!
//! Provides regex-based tokenization that splits text into word-level and
//! punctuation-level tokens. Tokenization is needed to align extracted data
//! with the source text and to form sentence boundaries for LLM information
//! extraction.
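//!
//! A minimal usage sketch (the `langextract_rust::tokenizer` path is an
//! assumption about the crate layout; adjust it to match your `lib.rs`):
//!
//! ```ignore
//! use langextract_rust::tokenizer::Tokenizer;
//!
//! let tokenizer = Tokenizer::new().expect("built-in regexes should compile");
//! let tokenized = tokenizer.tokenize("Patient takes 50mg/day.").expect("tokenization succeeds");
//! for token in &tokenized.tokens {
//!     // Each token records the byte span it occupies in the original text.
//!     let span = &tokenized.text[token.char_interval.start_pos..token.char_interval.end_pos];
//!     println!("{:?}: {:?}", token.token_type, span);
//! }
//! ```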

use crate::exceptions::{LangExtractError, LangExtractResult};
use regex::Regex;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;

/// Enumeration of token types produced during tokenization
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenType {
    /// Represents an alphabetical word token
    Word = 0,
    /// Represents a numeric token
    Number = 1,
    /// Represents punctuation characters
    Punctuation = 2,
    /// Represents an acronym or slash-delimited abbreviation
    Acronym = 3,
}

/// Represents a character interval in text
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CharInterval {
    /// Starting position in the original text (inclusive byte offset, as reported by the regex matcher)
    pub start_pos: usize,
    /// Ending position in the original text (exclusive byte offset)
    pub end_pos: usize,
}

impl CharInterval {
    /// Create a new character interval
    pub fn new(start_pos: usize, end_pos: usize) -> Self {
        Self { start_pos, end_pos }
    }

    /// Get the length of the interval
    pub fn length(&self) -> usize {
        self.end_pos.saturating_sub(self.start_pos)
    }
}

/// Represents a token interval over tokens in tokenized text
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct TokenInterval {
    /// The index of the first token in the interval
    pub start_index: usize,
    /// The index one past the last token in the interval
    pub end_index: usize,
}

impl TokenInterval {
    /// Create a new token interval
    pub fn new(start_index: usize, end_index: usize) -> LangExtractResult<Self> {
        if start_index >= end_index {
            return Err(LangExtractError::invalid_input(format!(
                "Start index {} must be < end index {}",
                start_index, end_index
            )));
        }
        Ok(Self {
            start_index,
            end_index,
        })
    }
}

/// Represents a token extracted from text
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Token {
    /// The position of the token in the sequence of tokens
    pub index: usize,
    /// The type of the token
    pub token_type: TokenType,
    /// The character interval within the original text that this token spans
    pub char_interval: CharInterval,
    /// True if the token immediately follows a newline or carriage return
    pub first_token_after_newline: bool,
}

impl Token {
    /// Create a new token
    pub fn new(
        index: usize,
        token_type: TokenType,
        char_interval: CharInterval,
        first_token_after_newline: bool,
    ) -> Self {
        Self {
            index,
            token_type,
            char_interval,
            first_token_after_newline,
        }
    }
}

/// Holds the result of tokenizing a text string
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TokenizedText {
    /// The original text that was tokenized
    pub text: String,
    /// A list of Token objects extracted from the text
    pub tokens: Vec<Token>,
}

impl TokenizedText {
    /// Create a new tokenized text
    pub fn new(text: String) -> Self {
        Self {
            text,
            tokens: Vec::new(),
        }
    }

    /// Get the number of tokens
    pub fn len(&self) -> usize {
        self.tokens.len()
    }

    /// Check if tokenized text is empty
    pub fn is_empty(&self) -> bool {
        self.tokens.is_empty()
    }
}

/// Text tokenizer for splitting text into tokens
pub struct Tokenizer {
    letters_pattern: Regex,
    digits_pattern: Regex,
    symbols_pattern: Regex,
    slash_abbrev_pattern: Regex,
    token_pattern: Regex,
    word_pattern: Regex,
    end_of_sentence_pattern: Regex,
    known_abbreviations: HashSet<String>,
}

impl Tokenizer {
    /// Create a new tokenizer
    pub fn new() -> LangExtractResult<Self> {
        // Regex patterns for tokenization (matching Python implementation)
        let letters_pattern = Regex::new(r"[A-Za-z]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile letters regex: {}", e))
        })?;

        let digits_pattern = Regex::new(r"[0-9]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile digits regex: {}", e))
        })?;

        let symbols_pattern = Regex::new(r"[^A-Za-z0-9\s]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile symbols regex: {}", e))
        })?;

        let slash_abbrev_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile slash abbreviation regex: {}", e))
        })?;

        let token_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+|[A-Za-z]+|[0-9]+|[^A-Za-z0-9\s]+").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile token regex: {}", e))
        })?;

        let word_pattern = Regex::new(r"^(?:[A-Za-z]+|[0-9]+)$").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile word regex: {}", e))
        })?;

        let end_of_sentence_pattern = Regex::new(r"[.?!]$").map_err(|e| {
            LangExtractError::configuration(format!("Failed to compile end of sentence regex: {}", e))
        })?;

        // Known abbreviations that should not count as sentence enders.
        // Note: the sentence-boundary check only joins the single token before a
        // period with the period itself, so only one-word entries such as "Dr."
        // can currently match; multi-token entries ("et al.", "i.e.", "U.S.", ...)
        // are kept for completeness.
        let known_abbreviations = [
            "Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "St.", "Ave.", "Blvd.", "Rd.", "Ltd.", "Inc.", "Corp.",
            "vs.", "etc.", "et al.", "i.e.", "e.g.", "cf.", "a.m.", "p.m.", "U.S.", "U.K.", "Ph.D.",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();

        Ok(Self {
            letters_pattern,
            digits_pattern,
            symbols_pattern,
            slash_abbrev_pattern,
            token_pattern,
            word_pattern,
            end_of_sentence_pattern,
            known_abbreviations,
        })
    }

    /// Tokenize text into tokens
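    ///
    /// A minimal sketch of the expected behavior (illustrative only; the
    /// `langextract_rust::tokenizer` path is an assumption about the crate layout):
    ///
    /// ```ignore
    /// use langextract_rust::tokenizer::{TokenType, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new().unwrap();
    /// let tokenized = tokenizer.tokenize("Dose: 20 mg").unwrap();
    ///
    /// // Expected tokens: "Dose", ":", "20", "mg"
    /// assert_eq!(tokenized.len(), 4);
    /// assert_eq!(tokenized.tokens[2].token_type, TokenType::Number);
    /// ```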
    pub fn tokenize(&self, text: &str) -> LangExtractResult<TokenizedText> {
        let mut tokenized = TokenizedText::new(text.to_string());
        let mut previous_end = 0;

        for (token_index, token_match) in self.token_pattern.find_iter(text).enumerate() {
            let start_pos = token_match.start();
            let end_pos = token_match.end();
            let matched_text = token_match.as_str();

            // Check if there's a newline in the gap before this token
            let first_token_after_newline = if token_index > 0 {
                let gap = &text[previous_end..start_pos];
                gap.contains('\n') || gap.contains('\r')
            } else {
                false
            };

            // Classify token type
            let token_type = self.classify_token(matched_text);

            let token = Token::new(
                token_index,
                token_type,
                CharInterval::new(start_pos, end_pos),
                first_token_after_newline,
            );

            tokenized.tokens.push(token);
            previous_end = end_pos;
        }

        Ok(tokenized)
    }

    /// Classify a token's type based on its content
    fn classify_token(&self, text: &str) -> TokenType {
        // Check slash-delimited abbreviations first: digits_pattern is unanchored,
        // so slash tokens containing digits (e.g. "10/20" or "mg/m2") would
        // otherwise be misclassified as numbers.
        if self.slash_abbrev_pattern.is_match(text) {
            TokenType::Acronym
        } else if self.digits_pattern.is_match(text) {
            TokenType::Number
        } else if self.word_pattern.is_match(text) {
            TokenType::Word
        } else {
            TokenType::Punctuation
        }
    }

    /// Reconstruct text from a token interval
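    ///
    /// A small illustrative sketch (module path assumed; the reconstruction
    /// simply slices the original text using the byte offsets stored in each
    /// token's `char_interval`):
    ///
    /// ```ignore
    /// use langextract_rust::tokenizer::{TokenInterval, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new().unwrap();
    /// let tokenized = tokenizer.tokenize("alpha beta gamma").unwrap();
    ///
    /// // Tokens 0..2 cover "alpha" and "beta", including the whitespace between them.
    /// let interval = TokenInterval::new(0, 2).unwrap();
    /// assert_eq!(tokenizer.tokens_text(&tokenized, &interval).unwrap(), "alpha beta");
    /// ```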
    pub fn tokens_text(
        &self,
        tokenized_text: &TokenizedText,
        token_interval: &TokenInterval,
    ) -> LangExtractResult<String> {
        if token_interval.start_index >= token_interval.end_index {
            return Err(LangExtractError::invalid_input(format!(
                "Invalid token interval: start_index={}, end_index={}",
                token_interval.start_index, token_interval.end_index
            )));
        }

        if token_interval.end_index > tokenized_text.tokens.len() {
            return Err(LangExtractError::invalid_input(format!(
                "Token interval end_index {} exceeds token count {}",
                token_interval.end_index,
                tokenized_text.tokens.len()
            )));
        }

        if tokenized_text.tokens.is_empty() {
            return Ok(String::new());
        }

        let start_token = &tokenized_text.tokens[token_interval.start_index];
        let end_token = &tokenized_text.tokens[token_interval.end_index - 1];

        let start_char = start_token.char_interval.start_pos;
        let end_char = end_token.char_interval.end_pos;

        Ok(tokenized_text.text[start_char..end_char].to_string())
    }

    /// Check if a punctuation token ends a sentence
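    ///
    /// Illustrative sketch (module path assumed): the period after a known
    /// abbreviation such as "Dr." is not treated as a sentence end, while the
    /// final period is.
    ///
    /// ```ignore
    /// use langextract_rust::tokenizer::Tokenizer;
    ///
    /// let tokenizer = Tokenizer::new().unwrap();
    /// let tokenized = tokenizer.tokenize("Dr. Smith arrived.").unwrap();
    /// // Tokens: "Dr", ".", "Smith", "arrived", "."
    /// assert!(!tokenizer.is_end_of_sentence_token(&tokenized.text, &tokenized.tokens, 1));
    /// assert!(tokenizer.is_end_of_sentence_token(&tokenized.text, &tokenized.tokens, 4));
    /// ```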
    pub fn is_end_of_sentence_token(
        &self,
        text: &str,
        tokens: &[Token],
        current_idx: usize,
    ) -> bool {
        if current_idx >= tokens.len() {
            return false;
        }

        let current_token = &tokens[current_idx];
        let current_token_text = &text[current_token.char_interval.start_pos..current_token.char_interval.end_pos];

        if self.end_of_sentence_pattern.is_match(current_token_text) {
            // Check if it's part of a known abbreviation
            if current_idx > 0 {
                let prev_token = &tokens[current_idx - 1];
                let prev_token_text = &text[prev_token.char_interval.start_pos..prev_token.char_interval.end_pos];
                let combined = format!("{}{}", prev_token_text, current_token_text);

                if self.known_abbreviations.contains(&combined) {
                    return false;
                }
            }
            return true;
        }
        false
    }

    /// Check if there's a sentence break after a newline
    pub fn is_sentence_break_after_newline(
        &self,
        text: &str,
        tokens: &[Token],
        current_idx: usize,
    ) -> bool {
        if current_idx + 1 >= tokens.len() {
            return false;
        }

        let current_token = &tokens[current_idx];
        let next_token = &tokens[current_idx + 1];

        // Check for newline in the gap between tokens
        let gap_start = current_token.char_interval.end_pos;
        let gap_end = next_token.char_interval.start_pos;

        if gap_start >= gap_end {
            return false;
        }

        let gap_text = &text[gap_start..gap_end];
        if !gap_text.contains('\n') {
            return false;
        }

        // Check if next token starts with uppercase
        let next_token_text = &text[next_token.char_interval.start_pos..next_token.char_interval.end_pos];
        !next_token_text.is_empty() && next_token_text.chars().next().unwrap().is_uppercase()
    }

    /// Find sentence range starting from a given token index
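    ///
    /// Illustrative sketch (module path assumed): the returned interval covers
    /// the tokens of the sentence that starts at `start_token_index`, including
    /// its terminating punctuation.
    ///
    /// ```ignore
    /// use langextract_rust::tokenizer::Tokenizer;
    ///
    /// let tokenizer = Tokenizer::new().unwrap();
    /// let tokenized = tokenizer.tokenize("First sentence. Second one.").unwrap();
    ///
    /// let range = tokenizer
    ///     .find_sentence_range(&tokenized.text, &tokenized.tokens, 0)
    ///     .unwrap();
    /// assert_eq!(
    ///     tokenizer.tokens_text(&tokenized, &range).unwrap(),
    ///     "First sentence."
    /// );
    /// ```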
    pub fn find_sentence_range(
        &self,
        text: &str,
        tokens: &[Token],
        start_token_index: usize,
    ) -> LangExtractResult<TokenInterval> {
        if start_token_index >= tokens.len() {
            return Err(LangExtractError::invalid_input(format!(
                "start_token_index {} out of range. Total tokens: {}",
                start_token_index,
                tokens.len()
            )));
        }

        let mut i = start_token_index;
        while i < tokens.len() {
            if tokens[i].token_type == TokenType::Punctuation {
                if self.is_end_of_sentence_token(text, tokens, i) {
                    return TokenInterval::new(start_token_index, i + 1);
                }
            }
            if self.is_sentence_break_after_newline(text, tokens, i) {
                return TokenInterval::new(start_token_index, i + 1);
            }
            i += 1;
        }

        TokenInterval::new(start_token_index, tokens.len())
    }
}

impl Default for Tokenizer {
    fn default() -> Self {
        Self::new().expect("Failed to create default tokenizer")
    }
}

#[cfg(test)]
mod tests;

/// Iterator for processing sentences in tokenized text
pub struct SentenceIterator<'a> {
    tokenized_text: &'a TokenizedText,
    tokenizer: &'a Tokenizer,
    current_token_pos: usize,
    token_len: usize,
}

impl<'a> SentenceIterator<'a> {
    /// Create a new sentence iterator
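    ///
    /// Illustrative sketch (module path assumed): iterating from token 0 yields
    /// one token interval per sentence.
    ///
    /// ```ignore
    /// use langextract_rust::tokenizer::{SentenceIterator, Tokenizer};
    ///
    /// let tokenizer = Tokenizer::new().unwrap();
    /// let tokenized = tokenizer.tokenize("One here. Two there.").unwrap();
    ///
    /// let sentences: Vec<String> = SentenceIterator::new(&tokenized, &tokenizer, 0)
    ///     .unwrap()
    ///     .map(|interval| tokenizer.tokens_text(&tokenized, &interval.unwrap()).unwrap())
    ///     .collect();
    /// assert_eq!(sentences, vec!["One here.", "Two there."]);
    /// ```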
    pub fn new(
        tokenized_text: &'a TokenizedText,
        tokenizer: &'a Tokenizer,
        current_token_pos: usize,
    ) -> LangExtractResult<Self> {
        let token_len = tokenized_text.tokens.len();

        if current_token_pos > token_len {
            return Err(LangExtractError::invalid_input(format!(
                "Current token position {} is past the length of the document {}",
                current_token_pos, token_len
            )));
        }

        Ok(Self {
            tokenized_text,
            tokenizer,
            current_token_pos,
            token_len,
        })
    }
}

impl<'a> Iterator for SentenceIterator<'a> {
    type Item = LangExtractResult<TokenInterval>;

    fn next(&mut self) -> Option<Self::Item> {
        if self.current_token_pos >= self.token_len {
            return None;
        }

        // Find the sentence range starting from current position
        match self.tokenizer.find_sentence_range(
            &self.tokenized_text.text,
            &self.tokenized_text.tokens,
            self.current_token_pos,
        ) {
            Ok(sentence_range) => {
                // Start the sentence from the current token position.
                // If we are in the middle of a sentence, we should start from there.
                let adjusted_range = match TokenInterval::new(
                    self.current_token_pos,
                    sentence_range.end_index,
                ) {
                    Ok(range) => range,
                    Err(e) => return Some(Err(e)),
                };

                self.current_token_pos = sentence_range.end_index;
                Some(Ok(adjusted_range))
            }
            Err(e) => Some(Err(e)),
        }
    }
}

/// Convenience function for creating a tokenizer and tokenizing text
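///
/// A one-shot sketch (module path assumed); prefer reusing a [`Tokenizer`] when
/// tokenizing many texts, since each call recompiles the regexes.
///
/// ```ignore
/// use langextract_rust::tokenizer::tokenize;
///
/// let tokenized = tokenize("Hello world.").unwrap();
/// assert_eq!(tokenized.len(), 3); // "Hello", "world", "."
/// ```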
pub fn tokenize(text: &str) -> LangExtractResult<TokenizedText> {
    let tokenizer = Tokenizer::new()?;
    tokenizer.tokenize(text)
}