Skip to main content

langextract_rust/
tokenizer.rs

//! Text tokenization functionality.
//!
//! Splits text into regex-based, word-level (and punctuation-level) tokens.
//! Tokenization is needed to align extracted data with the source text and
//! to determine the sentence boundaries used for LLM information extraction.
7
8use crate::exceptions::{LangExtractError, LangExtractResult};
9use regex::Regex;
10use serde::{Deserialize, Serialize};
11use std::collections::HashSet;
12
13/// Enumeration of token types produced during tokenization
14#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
15pub enum TokenType {
16    /// Represents an alphabetical word token
17    Word = 0,
18    /// Represents a numeric token
19    Number = 1,
20    /// Represents punctuation characters
21    Punctuation = 2,
22    /// Represents an acronym or slash-delimited abbreviation
23    Acronym = 3,
24}
25
26/// Represents a character interval in text
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
28pub struct TokenCharSpan {
29    /// Starting character index (inclusive)
30    pub start_pos: usize,
31    /// Ending character index (exclusive)
32    pub end_pos: usize,
33}
34
35impl TokenCharSpan {
36    /// Create a new character interval
37    pub fn new(start_pos: usize, end_pos: usize) -> Self {
38        Self { start_pos, end_pos }
39    }
40
41    /// Get the length of the interval
42    pub fn length(&self) -> usize {
43        self.end_pos.saturating_sub(self.start_pos)
44    }
45}
46
47impl From<TokenCharSpan> for crate::data::CharInterval {
48    fn from(span: TokenCharSpan) -> Self {
49        crate::data::CharInterval::new(Some(span.start_pos), Some(span.end_pos))
50    }
51}
52
53/// Represents a token interval over tokens in tokenized text
54#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
55pub struct TokenInterval {
56    /// The index of the first token in the interval
57    pub start_index: usize,
58    /// The index one past the last token in the interval
59    pub end_index: usize,
60}
61
62impl TokenInterval {
63    /// Create a new token interval
64    pub fn new(start_index: usize, end_index: usize) -> LangExtractResult<Self> {
65        if start_index >= end_index {
66            return Err(LangExtractError::invalid_input(format!(
67                "Start index {} must be < end index {}",
68                start_index, end_index
69            )));
70        }
71        Ok(Self {
72            start_index,
73            end_index,
74        })
75    }
76}
77
78/// Represents a token extracted from text
79#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
80pub struct Token {
81    /// The position of the token in the sequence of tokens
82    pub index: usize,
83    /// The type of the token
84    pub token_type: TokenType,
85    /// The character interval within the original text that this token spans
86    pub char_interval: TokenCharSpan,
87    /// True if the token immediately follows a newline or carriage return
88    pub first_token_after_newline: bool,
89}
90
91impl Token {
92    /// Create a new token
93    pub fn new(
94        index: usize,
95        token_type: TokenType,
96        char_interval: TokenCharSpan,
97        first_token_after_newline: bool,
98    ) -> Self {
99        Self {
100            index,
101            token_type,
102            char_interval,
103            first_token_after_newline,
104        }
105    }
106}
107
108/// Holds the result of tokenizing a text string
109#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
110pub struct TokenizedText {
111    /// The original text that was tokenized
112    pub text: String,
113    /// A list of Token objects extracted from the text
114    pub tokens: Vec<Token>,
115}
116
117impl TokenizedText {
118    /// Create a new tokenized text
119    pub fn new(text: String) -> Self {
120        Self {
121            text,
122            tokens: Vec::new(),
123        }
124    }
125
126    /// Get the number of tokens
127    pub fn len(&self) -> usize {
128        self.tokens.len()
129    }
130
131    /// Check if tokenized text is empty
132    pub fn is_empty(&self) -> bool {
133        self.tokens.is_empty()
134    }
135}
136
137/// Text tokenizer for splitting text into tokens
138pub struct Tokenizer {
139    _letters_pattern: Regex,
140    digits_pattern: Regex,
141    _symbols_pattern: Regex,
142    slash_abbrev_pattern: Regex,
143    token_pattern: Regex,
144    word_pattern: Regex,
145    end_of_sentence_pattern: Regex,
146    known_abbreviations: HashSet<String>,
147}
148
149impl Tokenizer {
150    /// Create a new tokenizer
151    pub fn new() -> LangExtractResult<Self> {
152        // Regex patterns for tokenization (matching Python implementation)
153        let letters_pattern = Regex::new(r"[A-Za-z]+").map_err(|e| {
154            LangExtractError::configuration(format!("Failed to compile letters regex: {}", e))
155        })?;
156
157        let digits_pattern = Regex::new(r"[0-9]+").map_err(|e| {
158            LangExtractError::configuration(format!("Failed to compile digits regex: {}", e))
159        })?;
160
161        let symbols_pattern = Regex::new(r"[^A-Za-z0-9\s]+").map_err(|e| {
162            LangExtractError::configuration(format!("Failed to compile symbols regex: {}", e))
163        })?;
164
165        let slash_abbrev_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+").map_err(|e| {
166            LangExtractError::configuration(format!("Failed to compile slash abbreviation regex: {}", e))
167        })?;
168
169        let token_pattern = Regex::new(r"[A-Za-z0-9]+(?:/[A-Za-z0-9]+)+|[A-Za-z]+|[0-9]+|[^A-Za-z0-9\s]+").map_err(|e| {
170            LangExtractError::configuration(format!("Failed to compile token regex: {}", e))
171        })?;
172
173        let word_pattern = Regex::new(r"^(?:[A-Za-z]+|[0-9]+)$").map_err(|e| {
174            LangExtractError::configuration(format!("Failed to compile word regex: {}", e))
175        })?;
176
177        let end_of_sentence_pattern = Regex::new(r"[.?!]$").map_err(|e| {
178            LangExtractError::configuration(format!("Failed to compile end of sentence regex: {}", e))
179        })?;
180
181        // Known abbreviations that should not count as sentence enders
182        let known_abbreviations = [
183            "Mr.", "Mrs.", "Ms.", "Dr.", "Prof.", "St.", "Ave.", "Blvd.", "Rd.", "Ltd.", "Inc.", "Corp.",
184            "vs.", "etc.", "et al.", "i.e.", "e.g.", "cf.", "a.m.", "p.m.", "U.S.", "U.K.", "Ph.D.",
185        ]
186        .iter()
187        .map(|s| s.to_string())
188        .collect();
189
190        Ok(Self {
191            _letters_pattern: letters_pattern,
192            digits_pattern,
193            _symbols_pattern: symbols_pattern,
194            slash_abbrev_pattern,
195            token_pattern,
196            word_pattern,
197            end_of_sentence_pattern,
198            known_abbreviations,
199        })
200    }
201
202    /// Tokenize text into tokens
203    pub fn tokenize(&self, text: &str) -> LangExtractResult<TokenizedText> {
204        let mut tokenized = TokenizedText::new(text.to_string());
205        let mut previous_end = 0;
206
207        for (token_index, token_match) in self.token_pattern.find_iter(text).enumerate() {
208            let start_pos = token_match.start();
209            let end_pos = token_match.end();
210            let matched_text = token_match.as_str();
211
212            // Check if there's a newline in the gap before this token
213            let first_token_after_newline = if token_index > 0 {
214                let gap = &text[previous_end..start_pos];
215                gap.contains('\n') || gap.contains('\r')
216            } else {
217                false
218            };
219
220            // Classify token type
221            let token_type = self.classify_token(matched_text);
222
223            let token = Token::new(
224                token_index,
225                token_type,
226                TokenCharSpan::new(start_pos, end_pos),
227                first_token_after_newline,
228            );
229
230            tokenized.tokens.push(token);
231            previous_end = end_pos;
232        }
233
234        Ok(tokenized)
235    }
236
237    /// Classify a token's type based on its content
238    fn classify_token(&self, text: &str) -> TokenType {
239        if self.digits_pattern.is_match(text) {
240            TokenType::Number
241        } else if self.slash_abbrev_pattern.is_match(text) {
242            TokenType::Acronym
243        } else if self.word_pattern.is_match(text) {
244            TokenType::Word
245        } else {
246            TokenType::Punctuation
247        }
248    }
249
250    /// Reconstruct text from a token interval
251    pub fn tokens_text(
252        &self,
253        tokenized_text: &TokenizedText,
254        token_interval: &TokenInterval,
255    ) -> LangExtractResult<String> {
256        if token_interval.start_index >= token_interval.end_index {
257            return Err(LangExtractError::invalid_input(format!(
258                "Invalid token interval: start_index={}, end_index={}",
259                token_interval.start_index, token_interval.end_index
260            )));
261        }
262
263        if token_interval.end_index > tokenized_text.tokens.len() {
264            return Err(LangExtractError::invalid_input(format!(
265                "Token interval end_index {} exceeds token count {}",
266                token_interval.end_index,
267                tokenized_text.tokens.len()
268            )));
269        }
270
271        if tokenized_text.tokens.is_empty() {
272            return Ok(String::new());
273        }
274
275        let start_token = &tokenized_text.tokens[token_interval.start_index];
276        let end_token = &tokenized_text.tokens[token_interval.end_index - 1];
277
278        let start_char = start_token.char_interval.start_pos;
279        let end_char = end_token.char_interval.end_pos;
280
281        Ok(tokenized_text.text[start_char..end_char].to_string())
282    }
283
284    /// Check if a punctuation token ends a sentence
285    pub fn is_end_of_sentence_token(
286        &self,
287        text: &str,
288        tokens: &[Token],
289        current_idx: usize,
290    ) -> bool {
291        if current_idx >= tokens.len() {
292            return false;
293        }
294
295        let current_token = &tokens[current_idx];
296        let current_token_text = &text[current_token.char_interval.start_pos..current_token.char_interval.end_pos];
297
298        if self.end_of_sentence_pattern.is_match(current_token_text) {
299            // Check if it's part of a known abbreviation
300            if current_idx > 0 {
301                let prev_token = &tokens[current_idx - 1];
302                let prev_token_text = &text[prev_token.char_interval.start_pos..prev_token.char_interval.end_pos];
303                let combined = format!("{}{}", prev_token_text, current_token_text);
304                
305                if self.known_abbreviations.contains(&combined) {
306                    return false;
307                }
308            }
309            return true;
310        }
311        false
312    }
313
314    /// Check if there's a sentence break after a newline
315    pub fn is_sentence_break_after_newline(
316        &self,
317        text: &str,
318        tokens: &[Token],
319        current_idx: usize,
320    ) -> bool {
321        if current_idx + 1 >= tokens.len() {
322            return false;
323        }
324
325        let current_token = &tokens[current_idx];
326        let next_token = &tokens[current_idx + 1];
327
328        // Check for newline in the gap between tokens
329        let gap_start = current_token.char_interval.end_pos;
330        let gap_end = next_token.char_interval.start_pos;
331        
332        if gap_start >= gap_end {
333            return false;
334        }
335
336        let gap_text = &text[gap_start..gap_end];
337        if !gap_text.contains('\n') {
338            return false;
339        }
340
341        // Check if next token starts with uppercase
342        let next_token_text = &text[next_token.char_interval.start_pos..next_token.char_interval.end_pos];
343        !next_token_text.is_empty() && next_token_text.chars().next().unwrap().is_uppercase()
344    }
345
346    /// Find sentence range starting from a given token index
347    pub fn find_sentence_range(
348        &self,
349        text: &str,
350        tokens: &[Token],
351        start_token_index: usize,
352    ) -> LangExtractResult<TokenInterval> {
353        if start_token_index >= tokens.len() {
354            return Err(LangExtractError::invalid_input(format!(
355                "start_token_index {} out of range. Total tokens: {}",
356                start_token_index,
357                tokens.len()
358            )));
359        }
360
361        let mut i = start_token_index;
362        while i < tokens.len() {
363            if tokens[i].token_type == TokenType::Punctuation {
364                if self.is_end_of_sentence_token(text, tokens, i) {
365                    return TokenInterval::new(start_token_index, i + 1);
366                }
367            }
368            if self.is_sentence_break_after_newline(text, tokens, i) {
369                return TokenInterval::new(start_token_index, i + 1);
370            }
371            i += 1;
372        }
373
374        TokenInterval::new(start_token_index, tokens.len())
375    }
376}
377
378impl Default for Tokenizer {
379    fn default() -> Self {
380        Self::new().expect("Failed to create default tokenizer")
381    }
382}
383
384#[cfg(test)]
385mod tests;
386
387/// Iterator for processing sentences in tokenized text
388pub struct SentenceIterator<'a> {
389    tokenized_text: &'a TokenizedText,
390    tokenizer: &'a Tokenizer,
391    current_token_pos: usize,
392    token_len: usize,
393}
394
395impl<'a> SentenceIterator<'a> {
396    /// Create a new sentence iterator
397    pub fn new(
398        tokenized_text: &'a TokenizedText,
399        tokenizer: &'a Tokenizer,
400        current_token_pos: usize,
401    ) -> LangExtractResult<Self> {
402        let token_len = tokenized_text.tokens.len();
403        
404        if current_token_pos > token_len {
405            return Err(LangExtractError::invalid_input(format!(
406                "Current token position {} is past the length of the document {}",
407                current_token_pos, token_len
408            )));
409        }
410
411        Ok(Self {
412            tokenized_text,
413            tokenizer,
414            current_token_pos,
415            token_len,
416        })
417    }
418}
419
420impl<'a> Iterator for SentenceIterator<'a> {
421    type Item = LangExtractResult<TokenInterval>;
422
423    fn next(&mut self) -> Option<Self::Item> {
424        if self.current_token_pos >= self.token_len {
425            return None;
426        }
427
428        // Find the sentence range starting from current position
429        match self.tokenizer.find_sentence_range(
430            &self.tokenized_text.text,
431            &self.tokenized_text.tokens,
432            self.current_token_pos,
433        ) {
434            Ok(sentence_range) => {
435                // Start the sentence from the current token position.
436                // If we are in the middle of a sentence, we should start from there.
437                let adjusted_range = match TokenInterval::new(
438                    self.current_token_pos,
439                    sentence_range.end_index,
440                ) {
441                    Ok(range) => range,
442                    Err(e) => return Some(Err(e)),
443                };
444
445                self.current_token_pos = sentence_range.end_index;
446                Some(Ok(adjusted_range))
447            }
448            Err(e) => Some(Err(e)),
449        }
450    }
451}
452
453/// Convenience function for creating a tokenizer and tokenizing text
454pub fn tokenize(text: &str) -> LangExtractResult<TokenizedText> {
455    let tokenizer = Tokenizer::new()?;
456    tokenizer.tokenize(text)
457}