tokenizers/tokenizer/
pattern.rs

1use crate::utils::SysRegex;
2use crate::{Offsets, Result};
3use regex::Regex;
4
/// Pattern used to split a NormalizedString
pub trait Pattern {
    /// Slice the given string into a list of pattern match positions, each
    /// paired with a boolean indicating whether that slice is a match or not.
    ///
    /// This method *must* cover the whole string in its outputs, with
    /// contiguous ordered slices.
    ///
    /// The regex- and closure-based implementations in this module report
    /// offsets as byte positions into `inside`.
    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>>;
}
14
15impl Pattern for char {
16    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
17        let is_char = |c: char| -> bool { c == *self };
18        is_char.find_matches(inside)
19    }
20}
21
22impl Pattern for &str {
23    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
24        if self.is_empty() {
25            // If we try to find the matches with an empty string, just don't match anything
26            return Ok(vec![((0, inside.chars().count()), false)]);
27        }
28
29        let re = Regex::new(&regex::escape(self))?;
30        (&re).find_matches(inside)
31    }
32}
33
34impl Pattern for &String {
35    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
36        let s: &str = self;
37        s.find_matches(inside)
38    }
39}
40
41impl Pattern for &Regex {
42    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
43        if inside.is_empty() {
44            return Ok(vec![((0, 0), false)]);
45        }
46
47        let mut prev = 0;
48        let mut splits = Vec::with_capacity(inside.len());
49        for m in self.find_iter(inside) {
50            if prev != m.start() {
51                splits.push(((prev, m.start()), false));
52            }
53            splits.push(((m.start(), m.end()), true));
54            prev = m.end();
55        }
56        if prev != inside.len() {
57            splits.push(((prev, inside.len()), false))
58        }
59        Ok(splits)
60    }
61}
62
63impl Pattern for &SysRegex {
64    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
65        if inside.is_empty() {
66            return Ok(vec![((0, 0), false)]);
67        }
68
69        let mut prev = 0;
70        let mut splits = Vec::with_capacity(inside.len());
71        for (start, end) in self.find_iter(inside) {
72            if prev != start {
73                splits.push(((prev, start), false));
74            }
75            splits.push(((start, end), true));
76            prev = end;
77        }
78        if prev != inside.len() {
79            splits.push(((prev, inside.len()), false))
80        }
81        Ok(splits)
82    }
83}
84
85impl<F> Pattern for F
86where
87    F: Fn(char) -> bool,
88{
89    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
90        if inside.is_empty() {
91            return Ok(vec![((0, 0), false)]);
92        }
93
94        let mut last_offset = 0;
95        let mut last_seen = 0;
96
97        let mut matches = inside
98            .char_indices()
99            .flat_map(|(b, c)| {
100                last_seen = b + c.len_utf8();
101                if self(c) {
102                    let mut events = Vec::with_capacity(2);
103                    if last_offset < b {
104                        // We need to emit what was before this match
105                        events.push(((last_offset, b), false));
106                    }
107                    events.push(((b, b + c.len_utf8()), true));
108                    last_offset = b + c.len_utf8();
109                    events
110                } else {
111                    vec![]
112                }
113            })
114            .collect::<Vec<_>>();
115
116        // Do not forget the last potential split
117        if last_seen > last_offset {
118            matches.push(((last_offset, last_seen), false));
119        }
120
121        Ok(matches)
122    }
123}
124
125/// Invert the `is_match` flags for the wrapped Pattern. This is useful
126/// for example when we use a regex that matches words instead of a delimiter,
127/// and we want to match the delimiter.
128pub struct Invert<P: Pattern>(pub P);
129impl<P: Pattern> Pattern for Invert<P> {
130    fn find_matches(&self, inside: &str) -> Result<Vec<(Offsets, bool)>> {
131        Ok(self
132            .0
133            .find_matches(inside)?
134            .into_iter()
135            .map(|(offsets, flag)| (offsets, !flag))
136            .collect())
137    }
138}
139
#[cfg(test)]
mod tests {
    use super::*;
    use regex::Regex;

    /// Checks a pattern against an input, and checks that wrapping the same
    /// pattern in `Invert` yields the same offsets with every flag negated.
    ///
    /// The `@ERROR` form asserts that `find_matches` returns an error instead.
    macro_rules! do_test {
        ($inside: expr, $pattern: expr => @ERROR) => {
            assert!($pattern.find_matches($inside).is_err());
        };
        ($inside: expr, $pattern: expr => $result: expr) => {
            assert_eq!($pattern.find_matches($inside).unwrap(), $result);
            assert_eq!(
                Invert($pattern).find_matches($inside).unwrap(),
                $result
                    .into_iter()
                    .map(|v: (Offsets, bool)| (v.0, !v.1))
                    .collect::<Vec<_>>()
            );
        };
    }

    #[test]
    fn char() {
        do_test!("aba", 'a' => vec![((0, 1), true), ((1, 2), false), ((2, 3), true)]);
        do_test!("bbbba", 'a' => vec![((0, 4), false), ((4, 5), true)]);
        do_test!("aabbb", 'a' => vec![((0, 1), true), ((1, 2), true), ((2, 5), false)]);
        do_test!("", 'a' => vec![((0, 0), false)]);
        do_test!("aaa", 'b' => vec![((0, 3), false)]);
        // Offsets are byte positions: U+00E9 takes two bytes in UTF-8
        do_test!("\u{e9}a", 'a' => vec![((0, 2), false), ((2, 3), true)]);
    }

    #[test]
    fn str() {
        do_test!("aba", "a" => vec![((0, 1), true), ((1, 2), false), ((2, 3), true)]);
        do_test!("bbbba", "a" => vec![((0, 4), false), ((4, 5), true)]);
        do_test!("aabbb", "a" => vec![((0, 1), true), ((1, 2), true), ((2, 5), false)]);
        do_test!("aabbb", "ab" => vec![((0, 1), false), ((1, 3), true), ((3, 5), false)]);
        do_test!("aabbab", "ab" =>
            vec![((0, 1), false), ((1, 3), true), ((3, 4), false), ((4, 6), true)]
        );
        do_test!("", "" => vec![((0, 0), false)]);
        do_test!("aaa", "" => vec![((0, 3), false)]);
        do_test!("aaa", "b" => vec![((0, 3), false)]);
        // Offsets are byte positions: U+00E9 takes two bytes in UTF-8
        do_test!("\u{e9}a", "a" => vec![((0, 2), false), ((2, 3), true)]);
    }

    #[test]
    fn functions() {
        let is_b = |c| c == 'b';
        do_test!("aba", is_b => vec![((0, 1), false), ((1, 2), true), ((2, 3), false)]);
        do_test!("aaaab", is_b => vec![((0, 4), false), ((4, 5), true)]);
        do_test!("bbaaa", is_b => vec![((0, 1), true), ((1, 2), true), ((2, 5), false)]);
        do_test!("", is_b => vec![((0, 0), false)]);
        do_test!("aaa", is_b => vec![((0, 3), false)]);
        // Offsets are byte positions: U+00E9 takes two bytes in UTF-8
        do_test!("\u{e9}b", is_b => vec![((0, 2), false), ((2, 3), true)]);
    }

    #[test]
    fn regex() {
        let is_whitespace = Regex::new(r"\s+").unwrap();
        do_test!("a   b", &is_whitespace => vec![((0, 1), false), ((1, 4), true), ((4, 5), false)]);
        do_test!("   a   b   ", &is_whitespace =>
            vec![((0, 3), true), ((3, 4), false), ((4, 7), true), ((7, 8), false), ((8, 11), true)]
        );
        do_test!("", &is_whitespace => vec![((0, 0), false)]);
        // Each double-struck letter is 4 bytes: 4 * 4 = 16, one space, then 7 * 4 = 28
        do_test!("𝔾𝕠𝕠𝕕 𝕞𝕠𝕣𝕟𝕚𝕟𝕘", &is_whitespace =>
            vec![((0, 16), false), ((16, 17), true), ((17, 45), false)]
        );
        do_test!("aaa", &is_whitespace => vec![((0, 3), false)]);
    }

    #[test]
    fn sys_regex() {
        let is_whitespace = SysRegex::new(r"\s+").unwrap();
        do_test!("a   b", &is_whitespace => vec![((0, 1), false), ((1, 4), true), ((4, 5), false)]);
        do_test!("   a   b   ", &is_whitespace =>
            vec![((0, 3), true), ((3, 4), false), ((4, 7), true), ((7, 8), false), ((8, 11), true)]
        );
        do_test!("", &is_whitespace => vec![((0, 0), false)]);
        // Each double-struck letter is 4 bytes: 4 * 4 = 16, one space, then 7 * 4 = 28
        do_test!("𝔾𝕠𝕠𝕕 𝕞𝕠𝕣𝕟𝕚𝕟𝕘", &is_whitespace =>
            vec![((0, 16), false), ((16, 17), true), ((17, 45), false)]
        );
        do_test!("aaa", &is_whitespace => vec![((0, 3), false)]);
    }
}