tokenizers/pre_tokenizers/
split.rs

1use crate::utils::SysRegex;
2use serde::{Deserialize, Deserializer, Serialize};
3
4use crate::tokenizer::{
5    pattern::Invert, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior,
6};
7
8/// Represents the different patterns that `Split` can use
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Eq)]
10pub enum SplitPattern {
11    String(String),
12    Regex(String),
13}
14
15impl From<String> for SplitPattern {
16    fn from(v: String) -> Self {
17        Self::String(v)
18    }
19}
20
21impl From<&str> for SplitPattern {
22    fn from(v: &str) -> Self {
23        Self::String(v.to_owned())
24    }
25}
26
27#[derive(Debug, Serialize)]
28#[serde(tag = "type")]
29pub struct Split {
30    pattern: SplitPattern,
31    #[serde(skip)]
32    regex: SysRegex,
33    behavior: SplitDelimiterBehavior,
34    invert: bool,
35}
36
37impl<'de> Deserialize<'de> for Split {
38    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
39    where
40        D: Deserializer<'de>,
41    {
42        #[derive(Deserialize)]
43        enum Type {
44            Split,
45        }
46
47        #[derive(Deserialize)]
48        pub struct SplitHelper {
49            #[serde(rename = "type")]
50            _type: Type,
51            pattern: SplitPattern,
52            behavior: SplitDelimiterBehavior,
53            invert: bool,
54        }
55
56        let helper = SplitHelper::deserialize(deserializer)?;
57        Self::new(helper.pattern, helper.behavior, helper.invert).map_err(serde::de::Error::custom)
58    }
59}
60
61impl Clone for Split {
62    fn clone(&self) -> Self {
63        Self::new(self.pattern.clone(), self.behavior, self.invert).unwrap()
64    }
65}
66
67impl PartialEq for Split {
68    fn eq(&self, other: &Self) -> bool {
69        self.pattern == other.pattern
70            && self.behavior == other.behavior
71            && self.invert == other.invert
72    }
73}
74
75impl Split {
76    pub fn new<I: Into<SplitPattern>>(
77        pattern: I,
78        behavior: SplitDelimiterBehavior,
79        invert: bool,
80    ) -> Result<Self> {
81        let pattern: SplitPattern = pattern.into();
82        let regex = match &pattern {
83            SplitPattern::String(s) => SysRegex::new(&regex::escape(s))?,
84            SplitPattern::Regex(r) => SysRegex::new(r)?,
85        };
86
87        Ok(Self {
88            pattern,
89            regex,
90            behavior,
91            invert,
92        })
93    }
94}
95
96impl PreTokenizer for Split {
97    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
98        if self.invert {
99            pretokenized.split(|_, normalized| normalized.split(Invert(&self.regex), self.behavior))
100        } else {
101            pretokenized.split(|_, normalized| normalized.split(&self.regex, self.behavior))
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{OffsetReferential, OffsetType, PreTokenizer};
    use SplitDelimiterBehavior::*;

    /// Checks the pieces and byte offsets produced for each delimiter behavior.
    #[test]
    fn basic() {
        const INPUT: &str = "How are you doing?";

        let cases = vec![
            (
                Removed,
                vec![
                    ("How", (0, 3)),
                    ("are", (4, 7)),
                    ("you", (8, 11)),
                    ("doing", (12, 17)),
                    ("?", (17, 18)),
                ],
            ),
            (
                Isolated,
                vec![
                    ("How", (0, 3)),
                    (" ", (3, 4)),
                    ("are", (4, 7)),
                    (" ", (7, 8)),
                    ("you", (8, 11)),
                    (" ", (11, 12)),
                    ("doing", (12, 17)),
                    ("?", (17, 18)),
                ],
            ),
            (
                MergedWithPrevious,
                vec![
                    ("How ", (0, 4)),
                    ("are ", (4, 8)),
                    ("you ", (8, 12)),
                    ("doing", (12, 17)),
                    ("?", (17, 18)),
                ],
            ),
            (
                MergedWithNext,
                vec![
                    ("How", (0, 3)),
                    (" are", (3, 7)),
                    (" you", (7, 11)),
                    (" doing", (11, 17)),
                    ("?", (17, 18)),
                ],
            ),
            (
                Contiguous,
                vec![
                    ("How", (0, 3)),
                    (" ", (3, 4)),
                    ("are", (4, 7)),
                    (" ", (7, 8)),
                    ("you", (8, 11)),
                    (" ", (11, 12)),
                    ("doing?", (12, 18)),
                ],
            ),
        ];

        // The pattern matches runs of word characters or runs of punctuation.
        // It is used with `invert = true` below, and the expected pieces show
        // the text being cut at the whitespace between those matches.
        let pattern = SplitPattern::Regex(r"\w+|[^\w\s]+".into());

        for (behavior, expected) in cases {
            let pretok = Split::new(pattern.clone(), behavior, true).unwrap();
            let mut pretokenized = PreTokenizedString::from(INPUT);
            pretok.pre_tokenize(&mut pretokenized).unwrap();

            let actual: Vec<_> = pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(piece, offsets, _)| (piece, offsets))
                .collect();
            assert_eq!(actual, expected);
        }
    }

    /// A `\s+` regex pattern and a literal " " pattern must split this input
    /// identically.
    #[test]
    fn regex_string() {
        let mut split_by_regex = PreTokenizedString::from("Hey, man!");
        let mut split_by_string = split_by_regex.clone();

        let regex_pretok = Split::new(SplitPattern::Regex(r"\s+".into()), Removed, false).unwrap();
        regex_pretok.pre_tokenize(&mut split_by_regex).unwrap();

        let string_pretok = Split::new(" ", Removed, false).unwrap();
        string_pretok.pre_tokenize(&mut split_by_string).unwrap();

        assert_eq!(split_by_regex, split_by_string);
    }

    /// Splitting on " " must give the same result as splitting inverted on
    /// "Hello".
    #[test]
    fn invert() {
        let mut plain = PreTokenizedString::from("Hello Hello Hello");
        let mut inverted = plain.clone();

        let on_space = Split::new(" ", Removed, false).unwrap();
        on_space.pre_tokenize(&mut plain).unwrap();

        let on_not_hello = Split::new("Hello", Removed, true).unwrap();
        on_not_hello.pre_tokenize(&mut inverted).unwrap();

        assert_eq!(plain, inverted);
    }

    /// Round-trips a literal-pattern and a regex-pattern `Split` through JSON.
    #[test]
    fn serialization() {
        let cases = vec![
            (
                Split::new("Hello", Removed, true).unwrap(),
                r#"{"type":"Split","pattern":{"String":"Hello"},"behavior":"Removed","invert":true}"#,
            ),
            (
                Split::new(SplitPattern::Regex(r"\s+".into()), Isolated, false).unwrap(),
                r#"{"type":"Split","pattern":{"Regex":"\\s+"},"behavior":"Isolated","invert":false}"#,
            ),
        ];

        for (split, json) in cases {
            assert_eq!(serde_json::to_string(&split).unwrap(), json);
            assert_eq!(serde_json::from_str::<Split>(json).unwrap(), split);
        }
    }
}