use std::error::Error;
use std::fmt;

use fancy_regex::Regex;
use unicode_categories::UnicodeCategories;

use crate::split::{SliceExt, SplitExt};

/// Errors that can occur while pre-tokenizing text.
#[derive(Clone, Debug)]
pub enum PreTokenizeError {
    /// Splitting text with a regex failed.
    RegexError(Box<fancy_regex::Error>),
}

impl fmt::Display for PreTokenizeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::RegexError(err) => write!(f, "regex failed: {}", err),
        }
    }
}

impl Error for PreTokenizeError {
    fn source(&self) -> Option<&(dyn Error + 'static)> {
        match self {
            Self::RegexError(err) => Some(err),
        }
    }
}

impl From<fancy_regex::Error> for PreTokenizeError {
    fn from(val: fancy_regex::Error) -> Self {
        PreTokenizeError::RegexError(Box::new(val))
    }
}

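/// Common interface for pre-tokenizers, which split text into smaller chunks
/// (such as words) ahead of tokenization proper.
///
/// A minimal usage sketch; the example is `ignore`d since the crate path is
/// assumed, and the expected chunks are derived by hand from [`GPT2_REGEX`]:
///
/// ```ignore
/// let splitter = Split::gpt2();
/// let chunks = splitter.pre_tokenize("Hello world!")?;
/// assert_eq!(chunks, ["Hello", " world", "!"]);
/// ```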
pub trait PreTokenizer {
    /// Split `text` into chunks, returning sub-slices of the input string.
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError>;
}

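/// Pre-tokenizer that separates digits from other text.
///
/// A usage sketch mirroring the `test_digits` case below (`ignore`d since
/// the crate path is assumed):
///
/// ```ignore
/// let digits = Digits::new(true /* individual_digits */);
/// let chunks = digits.pre_tokenize("Call 123 please")?;
/// assert_eq!(chunks, ["Call ", "1", "2", "3", " please"]);
/// ```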
pub struct Digits {
    split: Split,
}

impl Digits {
    /// Create a pre-tokenizer that separates runs of digits from other text.
    ///
    /// If `individual_digits` is true, each digit becomes its own chunk.
    /// Otherwise consecutive digits form a single chunk.
    pub fn new(individual_digits: bool) -> Digits {
        // Match either digits or a run of non-digits; with `invert: true`
        // the matches themselves become the output chunks.
        let pattern = if individual_digits {
            r"[0-9]|[^0-9]+"
        } else {
            r"[0-9]+|[^0-9]+"
        };

        Digits {
            split: Split::new(SplitOptions {
                pattern,
                invert: true,
                delimiter: SplitDelimiterBehavior::Remove,
            })
            .expect("pattern should be valid"),
        }
    }
}

impl PreTokenizer for Digits {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        self.split.pre_tokenize(text)
    }
}

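/// Regex pattern used by GPT-2 to split text into chunks ahead of
/// tokenization.
///
/// In order, the alternatives match: a common English contraction suffix
/// (`'s`, `'t`, `'re`, ...), a run of letters, a run of numeric characters,
/// or a run of other non-space characters (each of the last three optionally
/// preceded by one space), a run of whitespace not followed by a non-space
/// character, and any remaining whitespace.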
pub const GPT2_REGEX: &str =
    r"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+";

/// Determines how the text that separates chunks is handled.
#[derive(Copy, Clone, Debug, Default, PartialEq)]
pub enum SplitDelimiterBehavior {
    /// Drop the delimiter from the output.
    #[default]
    Remove,

    /// Emit the delimiter as a separate chunk.
    Isolate,
}

/// Options for constructing a [`Split`] pre-tokenizer.
#[derive(Clone, Debug, Default)]
pub struct SplitOptions<'a> {
    /// Regex pattern matching the delimiters to split on, or the chunks to
    /// keep if `invert` is set.
    pub pattern: &'a str,
    /// How delimiters are handled.
    pub delimiter: SplitDelimiterBehavior,
    /// If true, `pattern` matches the chunks rather than the delimiters.
    pub invert: bool,
}

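/// Pre-tokenizer that splits text using a regex pattern.
///
/// A usage sketch mirroring a `test_split` case below (`ignore`d since the
/// crate path is assumed):
///
/// ```ignore
/// let splitter = Split::new(SplitOptions {
///     pattern: r"\s+",
///     delimiter: SplitDelimiterBehavior::Isolate,
///     ..Default::default()
/// })?;
/// let chunks = splitter.pre_tokenize("foo bar")?;
/// assert_eq!(chunks, ["foo", " ", "bar"]);
/// ```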
pub struct Split {
    regex: Regex,
    delimiter: SplitDelimiterBehavior,
    invert: bool,
}

impl Split {
    /// Create a pre-tokenizer that splits text with the regex pattern in
    /// `opts`.
    pub fn new(opts: SplitOptions) -> Result<Self, PreTokenizeError> {
        let SplitOptions {
            pattern,
            delimiter,
            invert,
        } = opts;
        let regex = Regex::new(pattern).map_err(|err| PreTokenizeError::RegexError(err.into()))?;

        Ok(Split {
            regex,
            delimiter,
            invert,
        })
    }

    /// Create a pre-tokenizer using the GPT-2 pattern ([`GPT2_REGEX`]).
    pub fn gpt2() -> Self {
        // The pattern matches the chunks to keep, hence `invert: true`; the
        // text between matches is dropped.
        Self::new(SplitOptions {
            pattern: GPT2_REGEX,
            delimiter: SplitDelimiterBehavior::Remove,
            invert: true,
        })
        .expect("should be a valid pattern")
    }
}

impl PreTokenizer for Split {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        let mut chunks = Vec::new();
        let mut last_match_end = 0;

        if self.invert {
            // Inverted: regex matches are the chunks and the text between
            // matches is the delimiter.
            for match_ in self.regex.find_iter(text) {
                let match_ = match_?;

                match self.delimiter {
                    SplitDelimiterBehavior::Isolate => {
                        let delim_text = &text[last_match_end..match_.range().start];
                        if !delim_text.is_empty() {
                            chunks.push(delim_text);
                        }
                    }
                    SplitDelimiterBehavior::Remove => {}
                }

                if !match_.range().is_empty() {
                    chunks.push(match_.as_str());
                }

                last_match_end = match_.range().end;
            }
        } else {
            // Non-inverted: regex matches are the delimiters and `split`
            // yields the chunks between them.
            for match_ in self.regex.split(text) {
                let match_ = match_?;
                let match_range = text
                    .as_bytes()
                    .subslice_offsets(match_.as_bytes())
                    .expect("should be sub-slice");

                match self.delimiter {
                    SplitDelimiterBehavior::Isolate => {
                        let delim_text = &text[last_match_end..match_range.start];
                        if !delim_text.is_empty() {
                            chunks.push(delim_text);
                        }
                    }
                    SplitDelimiterBehavior::Remove => {}
                }

                if !match_.is_empty() {
                    chunks.push(match_);
                }

                last_match_end = match_range.end;
            }
        }

        // Handle any trailing delimiter after the final match.
        match self.delimiter {
            SplitDelimiterBehavior::Isolate => {
                let delim_text = &text[last_match_end..];
                if !delim_text.is_empty() {
                    chunks.push(delim_text);
                }
            }
            SplitDelimiterBehavior::Remove => {}
        }

        Ok(chunks)
    }
}

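/// Pre-tokenizer that splits text on punctuation and whitespace, keeping the
/// delimiters as separate chunks, as in BERT-style tokenization.
///
/// A usage sketch mirroring the `test_bert` case below (`ignore`d since the
/// crate path is assumed):
///
/// ```ignore
/// let bert = Bert::new();
/// let chunks = bert.pre_tokenize("foo. bar")?;
/// assert_eq!(chunks, ["foo", ".", " ", "bar"]);
/// ```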
pub struct Bert {}

impl Bert {
    pub fn new() -> Self {
        Bert {}
    }
}

impl Default for Bert {
    fn default() -> Self {
        Self::new()
    }
}

impl PreTokenizer for Bert {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        // `is_punctuation` comes from the `UnicodeCategories` trait and
        // covers punctuation outside the ASCII range.
        let is_punc_or_space =
            |ch: char| ch.is_ascii_punctuation() || ch.is_punctuation() || ch.is_whitespace();
        let words = text.split_keep_delimeters(is_punc_or_space).collect();
        Ok(words)
    }
}

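/// Pre-tokenizer that chains several pre-tokenizers, feeding the chunks
/// produced by each into the next.
///
/// A usage sketch mirroring `test_sequence` below (`ignore`d since the crate
/// path is assumed):
///
/// ```ignore
/// let seq = Sequence::from_vec(vec![
///     Box::new(Split::new(SplitOptions { pattern: r"\s+", ..Default::default() })?),
///     Box::new(Split::new(SplitOptions { pattern: r"\.", ..Default::default() })?),
/// ]);
/// assert_eq!(seq.pre_tokenize("foo.bar baz")?, ["foo", "bar", "baz"]);
/// ```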
pub struct Sequence {
    pre_tokenizers: Vec<Box<dyn PreTokenizer>>,
}

impl Sequence {
    /// Create a pre-tokenizer which applies each of `pre_tokenizers` in turn.
    pub fn from_vec(pre_tokenizers: Vec<Box<dyn PreTokenizer>>) -> Self {
        Sequence { pre_tokenizers }
    }
}

impl PreTokenizer for Sequence {
    fn pre_tokenize<'a>(&self, text: &'a str) -> Result<Vec<&'a str>, PreTokenizeError> {
        // Feed the chunks produced by each pre-tokenizer into the next.
        let mut chunks = Vec::from([text]);
        for pre_tokenizer in &self.pre_tokenizers {
            let mut next_chunks = Vec::new();
            for chunk in chunks {
                let sub_chunks = pre_tokenizer.pre_tokenize(chunk)?;
                next_chunks.extend(sub_chunks);
            }
            chunks = next_chunks;
        }
        Ok(chunks)
    }
}

#[cfg(test)]
mod tests {
    use rten_testing::TestCases;

    use super::{
        Bert, Digits, PreTokenizer, Sequence, Split, SplitDelimiterBehavior, SplitOptions,
    };

    #[test]
    fn test_bert() {
        #[derive(Debug)]
        struct Case<'a> {
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [Case {
            input: "foo. bar baz, meep",
            expected: ["foo", ".", " ", "bar", " ", "baz", ",", " ", "meep"].into(),
        }];

        cases.test_each(|case| {
            let bert = Bert::new();
            let chunks = bert.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_digits() {
        #[derive(Debug)]
        struct Case<'a> {
            individual_digits: bool,
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [
            Case {
                // Runs of digits are kept together.
                individual_digits: false,
                input: "Call 123 please",
                expected: ["Call ", "123", " please"].into(),
            },
            Case {
                // Each digit becomes its own chunk.
                individual_digits: true,
                input: "Call 123 please",
                expected: ["Call ", "1", "2", "3", " please"].into(),
            },
        ];

        cases.test_each(|case| {
            let digits = Digits::new(case.individual_digits);
            let chunks = digits.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_split() {
        #[derive(Debug)]
        struct Case<'a> {
            opts: SplitOptions<'a>,
            input: &'a str,
            expected: Vec<&'a str>,
        }

        let cases = [
            Case {
                // Default behavior: delimiters are removed.
                opts: SplitOptions {
                    pattern: r"\s+",
                    ..Default::default()
                },
                input: "foo bar baz meep",
                expected: ["foo", "bar", "baz", "meep"].into(),
            },
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    delimiter: SplitDelimiterBehavior::Isolate,
                    ..Default::default()
                },
                input: " foo bar baz meep ",
                expected: [" ", "foo", " ", "bar", " ", "baz", " ", "meep", " "].into(),
            },
            Case {
                // Inverted: the matches become the chunks.
                opts: SplitOptions {
                    pattern: r"\s+",
                    invert: true,
                    ..Default::default()
                },
                input: "foo bar baz meep",
                expected: [" ", " ", " "].into(),
            },
            Case {
                opts: SplitOptions {
                    pattern: r"\s+",
                    invert: true,
                    delimiter: SplitDelimiterBehavior::Isolate,
                    ..Default::default()
                },
                input: "foo bar baz meep",
                expected: ["foo", " ", "bar", " ", "baz", " ", "meep"].into(),
            },
        ];

        cases.test_each(|case| {
            let split = Split::new(case.opts.clone()).unwrap();
            let chunks = split.pre_tokenize(case.input).unwrap();
            assert_eq!(chunks, case.expected);
        })
    }

    #[test]
    fn test_sequence() {
        let split_space: Box<dyn PreTokenizer> = Box::new(
            Split::new(SplitOptions {
                pattern: r"\s+",
                ..Default::default()
            })
            .unwrap(),
        );
        let split_punct = Box::new(
            Split::new(SplitOptions {
                pattern: r"\.",
                ..Default::default()
            })
            .unwrap(),
        );
        let seq = Sequence::from_vec([split_space, split_punct].into());

        let chunks = seq.pre_tokenize("foo.bar baz meep").unwrap();

        assert_eq!(chunks, ["foo", "bar", "baz", "meep"]);
    }
}