// rten_text/pre_tokenizers.rs

//! Pre-tokenizers, which split normalized text into chunks before a model
//! encodes those chunks into token IDs.

use std::error::Error;
use std::fmt;

use fancy_regex::Regex;
use unicode_categories::UnicodeCategories;

use crate::split::{SliceExt, SplitExt};

/// Errors occurring while constructing a [`PreTokenizer`] or splitting input
/// using one.
#[derive(Clone, Debug)]
pub enum PreTokenizeError {
    /// An error occurred while constructing a regex from a pattern or
    /// splitting a string using a regex.
    RegexError(Box<fancy_regex::Error>),
}

impl fmt::Display for PreTokenizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RegexError(err) => write!(f, "regex failed: {}", err),
        }
    }
}

impl Error for PreTokenizeError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::RegexError(err) => Some(err),
        }
    }
}

impl From<fancy_regex::Error> for PreTokenizeError {
    fn from(val: fancy_regex::Error) -> Self {
        PreTokenizeError::RegexError(Box::new(val))
    }
}

/// A pre-tokenizer splits input text into chunks ("words") which are then
/// tokenized by a [`Model`](crate::models::Model) individually.
pub trait PreTokenizer {
    /// Split `text` into chunks and return a vector of sub-slices.
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError>;
}

/// Split into tokens containing either digits or non-digits.
pub struct Digits {
    split: Split,
}

impl Digits {
    /// Construct a digit splitter.
    ///
    /// `individual_digits` specifies whether each digit in a run of digits
    /// becomes its own token, or the whole run becomes a single token.
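    ///
    /// A minimal usage sketch, mirroring this module's own tests (illustrative
    /// doctest; assumes this module is exported as `rten_text::pre_tokenizers`):
    ///
    /// ```
    /// use rten_text::pre_tokenizers::{Digits, PreTokenizer};
    ///
    /// let digits = Digits::new(true /* individual_digits */);
    /// let chunks = digits.pre_tokenize("Call 123 please").unwrap();
    /// assert_eq!(chunks, ["Call ", "1", "2", "3", " please"]);
    /// ```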
    pub fn new(individual_digits: bool) -> Digits {
        let pattern = if individual_digits {
            r"[0-9]|[^0-9]+"
        } else {
            r"[0-9]+|[^0-9]+"
        };

        Digits {
            split: Split::new(SplitOptions {
                pattern,
                invert: true,
                delimiter: SplitDelimiterBehavior::Remove,
            })
            .expect("pattern should be valid"),
        }
    }
}

impl PreTokenizer for Digits {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        self.split.pre_tokenize(text)
    }
}

/// Tokenization regex used by GPT-2.
///
/// See <https://github.com/openai/tiktoken/blob/main/tiktoken_ext/openai_public.py>.
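///
/// A [`Split`] built with this pattern keeps contraction suffixes, runs of
/// letters or digits (with an optional leading space), and punctuation as
/// separate chunks. A sketch (illustrative doctest; assumes this module is
/// exported as `rten_text::pre_tokenizers`):
///
/// ```
/// use rten_text::pre_tokenizers::{PreTokenizer, Split};
///
/// let splitter = Split::gpt2();
/// let chunks = splitter.pre_tokenize("Hello world!").unwrap();
/// assert_eq!(chunks, ["Hello", " world", "!"]);
/// ```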
pub const GPT2_REGEX: &str =
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";

/// Specifies how [`Split`] should handle delimiters between chunks.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub enum SplitDelimiterBehavior {
    /// Exclude the delimiter from the output.
    #[default]
    Remove,

    /// Add the delimiter to the output as its own chunk.
    Isolate,
}

/// Options for constructing a [`Split`] pre-tokenizer.
#[derive(Clone, Debug, Default)]
pub struct SplitOptions<'a> {
    /// Regex pattern used to split the input.
    pub pattern: &'a str,

    /// How delimiters between chunks are handled.
    pub delimiter: SplitDelimiterBehavior,

    /// If true, the pattern matches the chunks to keep. If false, it matches
    /// the delimiters between chunks.
    pub invert: bool,
}

/// Split input strings using a pattern.
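///
/// For example, isolating whitespace delimiters as their own chunks (a
/// minimal sketch, consistent with this module's tests; the doctest path
/// assumes this module is exported as `rten_text::pre_tokenizers`):
///
/// ```
/// use rten_text::pre_tokenizers::{PreTokenizer, Split, SplitDelimiterBehavior, SplitOptions};
///
/// let split = Split::new(SplitOptions {
///     pattern: r"\s+",
///     delimiter: SplitDelimiterBehavior::Isolate,
///     invert: false,
/// })
/// .unwrap();
/// assert_eq!(split.pre_tokenize("foo bar").unwrap(), ["foo", " ", "bar"]);
/// ```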
pub struct Split {
    regex: Regex,
    delimiter: SplitDelimiterBehavior,
    invert: bool,
}

impl Split {
    /// Construct a pre-tokenizer which splits input using a given regex
    /// pattern.
    pub fn new(opts: SplitOptions) -> Result<Self, PreTokenizeError> {
        let SplitOptions {
            pattern,
            delimiter,
            invert,
        } = opts;
        let regex = Regex::new(pattern).map_err(|err| PreTokenizeError::RegexError(err.into()))?;

        Ok(Split {
            regex,
            delimiter,
            invert,
        })
    }

    /// Split input strings into chunks using the [`GPT2_REGEX`] pattern
    /// originating from GPT-2 and subsequently used by many other models.
    ///
    /// Use [`new`](Self::new) to specify a custom pattern.
    pub fn gpt2() -> Self {
        Self::new(SplitOptions {
            pattern: GPT2_REGEX,
            delimiter: SplitDelimiterBehavior::Remove,
            invert: true,
        })
        .expect("should be a valid pattern")
    }
}

impl PreTokenizer for Split {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        let mut chunks = Vec::new();
        let mut last_match_end = 0;

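        // When `invert` is true, regex matches are the chunks to keep, and
        // the text between matches is treated as the delimiter.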
        if self.invert {
            for match_ in self.regex.find_iter(text) {
                let match_ = match_?;

                match self.delimiter {
                    SplitDelimiterBehavior::Isolate => {
                        let delim_text = &text[last_match_end..match_.range().start];
                        if !delim_text.is_empty() {
                            chunks.push(delim_text);
                        }
                    }
                    SplitDelimiterBehavior::Remove => {}
                }

                if !match_.range().is_empty() {
                    chunks.push(match_.as_str());
                }

                last_match_end = match_.range().end;
            }
        } else {
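            // When `invert` is false, regex matches are the delimiters.
            // `Regex::split` yields the text between matches, and
            // `subslice_offsets` recovers each chunk's byte range within
            // `text` so that the preceding delimiter can be located.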
            for match_ in self.regex.split(text) {
                let match_ = match_?;
                let match_range = text
                    .as_bytes()
                    .subslice_offsets(match_.as_bytes())
                    .expect("should be sub-slice");

                match self.delimiter {
                    SplitDelimiterBehavior::Isolate => {
                        let delim_text = &text[last_match_end..match_range.start];
                        if !delim_text.is_empty() {
                            chunks.push(delim_text);
                        }
                    }
                    SplitDelimiterBehavior::Remove => {}
                }

                if !match_.is_empty() {
                    chunks.push(match_);
                }

                last_match_end = match_range.end;
            }
        }

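        // Handle any trailing delimiter text after the final match.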
        match self.delimiter {
            SplitDelimiterBehavior::Isolate => {
                let delim_text = &text[last_match_end..];
                if !delim_text.is_empty() {
                    chunks.push(delim_text);
                }
            }
            SplitDelimiterBehavior::Remove => {}
        }

        Ok(chunks)
    }
}

/// Pre-tokenizer that implements the pre-tokenization rules used by BERT.
///
/// This splits the input into tokens consisting of punctuation characters,
/// whitespace characters, or runs of other characters.
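///
/// A brief sketch, matching this module's tests (illustrative doctest;
/// assumes this module is exported as `rten_text::pre_tokenizers`):
///
/// ```
/// use rten_text::pre_tokenizers::{Bert, PreTokenizer};
///
/// let bert = Bert::new();
/// let chunks = bert.pre_tokenize("foo. bar baz").unwrap();
/// assert_eq!(chunks, ["foo", ".", " ", "bar", " ", "baz"]);
/// ```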
pub struct Bert {}

impl Bert {
    pub fn new() -> Self {
        Bert {}
    }
}

impl Default for Bert {
    fn default() -> Self {
        Self::new()
    }
}

impl PreTokenizer for Bert {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        let is_punc_or_space =
            |ch: char| ch.is_ascii_punctuation() || ch.is_punctuation() || ch.is_whitespace();
        let words = text.split_keep_delimeters(is_punc_or_space).collect();
        Ok(words)
    }
}

/// Compose a sequence of pre-tokenizers.
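///
/// Each pre-tokenizer is applied to the chunks produced by the previous one.
/// A minimal sketch (illustrative doctest; assumes this module is exported as
/// `rten_text::pre_tokenizers`):
///
/// ```
/// use rten_text::pre_tokenizers::{PreTokenizer, Sequence, Split, SplitOptions};
///
/// // Split on whitespace, then split the resulting chunks on periods.
/// let ws: Box<dyn PreTokenizer> = Box::new(
///     Split::new(SplitOptions {
///         pattern: r"\s+",
///         ..Default::default()
///     })
///     .unwrap(),
/// );
/// let dot: Box<dyn PreTokenizer> = Box::new(
///     Split::new(SplitOptions {
///         pattern: r"\.",
///         ..Default::default()
///     })
///     .unwrap(),
/// );
/// let seq = Sequence::from_vec(vec![ws, dot]);
/// assert_eq!(seq.pre_tokenize("foo.bar baz").unwrap(), ["foo", "bar", "baz"]);
/// ```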
pub struct Sequence {
    pre_tokenizers: Vec<Box<dyn PreTokenizer>>,
}

impl Sequence {
    pub fn from_vec(pre_tokenizers: Vec<Box<dyn PreTokenizer>>) -> Self {
        Sequence { pre_tokenizers }
    }
}

impl PreTokenizer for Sequence {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        let mut chunks = Vec::from([text]);
        for pre_tokenizer in &self.pre_tokenizers {
            let mut next_chunks = Vec::new();
            for chunk in chunks {
                let sub_chunks = pre_tokenizer.pre_tokenize(chunk)?;
                next_chunks.extend(sub_chunks);
            }
            chunks = next_chunks;
        }
        Ok(chunks)
    }
}

#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{
        Bert, Digits, PreTokenizer, Sequence, Split, SplitDelimiterBehavior, SplitOptions,
    };

    #[test]
    fn test_bert() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [Case {
            input: "foo. bar baz, meep",
            expected: ["foo", ".", " ", "bar", " ", "baz", ",", " ", "meep"].into(),
        }];

        cases.test_each(|case| {
            let bert = Bert::new();
            let chunks = bert.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_digits() {
        #[derive(Debug)]
        struct Case<'a> {
            individual_digits: bool,
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [
            // Examples from
            // https://huggingface.co/docs/tokenizers/en/api/pre-tokenizers#tokenizers.pre_tokenizers.Digits.
            Case {
                individual_digits: false,
                input: "Call 123 please",
                expected: ["Call ", "123", " please"].into(),
            },
            Case {
                individual_digits: true,
                input: "Call 123 please",
                expected: ["Call ", "1", "2", "3", " please"].into(),
            },
        ];

        cases.test_each(|case| {
            let digits = Digits::new(case.individual_digits);
            let chunks = digits.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_split() {
        #[derive(Debug)]
        struct Case<'a> {
            opts: SplitOptions<'a>,
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [
            // Non-inverted
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    ..Default::default()
                },
                input: "foo bar   baz meep",
                expected: ["foo", "bar", "baz", "meep"].into(),
            },
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    delimiter: SplitDelimiterBehavior::Isolate,
                    ..Default::default()
                },
                input: " foo bar   baz meep ",
                expected: [" ", "foo", " ", "bar", "   ", "baz", " ", "meep", " "].into(),
            },
            // Inverted
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    invert: true,
                    ..Default::default()
                },
                input: "foo bar   baz meep",
                expected: [" ", "   ", " "].into(),
            },
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    invert: true,
                    delimiter: SplitDelimiterBehavior::Isolate,
                    ..Default::default()
                },
                input: "foo bar   baz meep",
                expected: ["foo", " ", "bar", "   ", "baz", " ", "meep"].into(),
            },
        ];

        cases.test_each(|case| {
            let split = Split::new(case.opts.clone()).unwrap();
            let chunks = split.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_sequence() {
        let split_space: Box<dyn PreTokenizer> = Box::new(
            Split::new(SplitOptions {
                pattern: r"\s+",
                ..Default::default()
            })
            .unwrap(),
        );
        let split_punct = Box::new(
            Split::new(SplitOptions {
                pattern: r"\.",
                ..Default::default()
            })
            .unwrap(),
        );
        let seq = Sequence::from_vec([split_space, split_punct].into());

        let chunks = seq.pre_tokenize("foo.bar baz meep").unwrap();

        assert_eq!(chunks, ["foo", "bar", "baz", "meep"]);
    }
}
403}