tokenizers/pre_tokenizers/
metaspace.rs

1use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
2use serde::{de, Deserialize, Deserializer, Serialize};
3
/// Controls when the [`Metaspace`] replacement character is prepended to a split.
///
/// Variants are (de)serialized in `snake_case` (`"first"`, `"never"`, `"always"`).
#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
#[serde(rename_all = "snake_case")]
pub enum PrependScheme {
    /// Prepend the replacement character only to a split that starts at offset 0 of
    /// the original input, and only if it does not already start with it.
    First,
    /// Never prepend the replacement character.
    Never,
    /// Prepend the replacement character to every split that does not already start
    /// with it.
    Always,
}
15
#[derive(Debug, Clone, PartialEq, Serialize, Eq)]
/// Replaces every space character (`' '`) with the provided meta character and,
/// when `split` is enabled, splits on this character.
///
/// Serialized as an internally tagged object (`"type": "Metaspace"`); the cached
/// `str_rep` field is not part of the serialized form.
#[serde(tag = "type")]
pub struct Metaspace {
    /// Meta character substituted for spaces. Kept private so that `str_rep` can be
    /// refreshed whenever it changes (see `set_replacement`).
    replacement: char,
    /// When the replacement character is prepended to a split.
    pub prepend_scheme: PrependScheme,
    /// Whether the pre-tokenizer splits on the replacement character after substitution.
    pub split: bool,
    /// `replacement` rendered as a `String`, cached so it can be passed by reference
    /// without re-allocating; skipped during serialization.
    #[serde(skip)]
    str_rep: String,
}
27
28impl<'de> Deserialize<'de> for Metaspace {
29    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
30    where
31        D: Deserializer<'de>,
32    {
33        #[derive(Deserialize)]
34        enum Type {
35            Metaspace,
36        }
37
38        fn default_prepend_scheme_value() -> PrependScheme {
39            PrependScheme::Always
40        }
41
42        #[derive(Deserialize)]
43        pub struct MetaspaceHelper {
44            #[serde(rename = "type")]
45            _type: Type,
46            replacement: char,
47
48            pub add_prefix_space: Option<bool>,
49            #[serde(default = "default_prepend_scheme_value")]
50            pub prepend_scheme: PrependScheme,
51            pub split: Option<bool>,
52            #[serde(rename = "str_rep")]
53            _str_rep: Option<String>,
54        }
55
56        let mut helper = MetaspaceHelper::deserialize(deserializer)?;
57        if let Some(false) = helper.add_prefix_space {
58            if helper.prepend_scheme != PrependScheme::Never {
59                return Err(de::Error::custom(
60                    "add_prefix_space does not match declared prepend_scheme",
61                ));
62            }
63            helper.prepend_scheme = PrependScheme::Never;
64        }
65        let instance = Self::new(
66            helper.replacement,
67            helper.prepend_scheme,
68            helper.split.unwrap_or(true),
69        );
70        Ok(instance)
71    }
72}
73
74impl Metaspace {
75    pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self {
76        Self {
77            replacement,
78            str_rep: replacement.to_string(),
79            prepend_scheme,
80            split,
81        }
82    }
83
84    pub fn get_replacement(&self) -> char {
85        self.replacement
86    }
87
88    pub fn set_replacement(&mut self, replacement: char) {
89        self.replacement = replacement;
90        self.str_rep = replacement.to_string();
91    }
92
93    pub fn get_split(&self) -> bool {
94        self.split
95    }
96
97    pub fn set_split(&mut self, split: bool) {
98        self.split = split;
99    }
100
101    pub fn get_prepend_scheme(&self) -> PrependScheme {
102        self.prepend_scheme
103    }
104
105    pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
106        self.prepend_scheme = scheme;
107    }
108}
109
110impl Default for Metaspace {
111    fn default() -> Self {
112        Self::new('▁', PrependScheme::Always, true)
113    }
114}
115
116impl PreTokenizer for Metaspace {
117    fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
118        pretokenized.split(|_, mut normalized| {
119            normalized.replace(' ', &self.str_rep)?;
120            match self.prepend_scheme {
121                PrependScheme::Always => {
122                    if !normalized.get().starts_with(self.replacement) {
123                        normalized.prepend(&self.str_rep);
124                    }
125                }
126                PrependScheme::First => {
127                    if !normalized.get().starts_with(self.replacement)
128                        && normalized.offsets_original().0 == 0
129                    {
130                        normalized.prepend(&self.str_rep);
131                    }
132                }
133                PrependScheme::Never => {}
134            };
135            if self.split {
136                normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
137            } else {
138                Ok(vec![normalized])
139            }
140        })
141    }
142}
143
144impl Decoder for Metaspace {
145    fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
146        Ok(tokens
147            .iter()
148            .enumerate()
149            .map(|(i, token)| {
150                token
151                    .chars()
152                    .flat_map(|c| {
153                        if c == self.replacement {
154                            if i == 0 && self.prepend_scheme != PrependScheme::Never {
155                                None
156                            } else {
157                                Some(' ')
158                            }
159                        } else {
160                            Some(c)
161                        }
162                    })
163                    .collect::<String>()
164            })
165            .collect())
166    }
167}
168
#[cfg(test)]
mod tests {
    use regex::Regex;

    use super::*;
    use crate::{OffsetReferential, OffsetType};

    /// Round-trips the current JSON layout and checks how legacy layouts are handled.
    #[test]
    fn serialization() {
        let metaspace = Metaspace::new('_', PrependScheme::Always, true);
        let metaspace_s =
            r#"{"type":"Metaspace","replacement":"_","prepend_scheme":"always","split":true}"#;
        assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );

        // A legacy `add_prefix_space: false` that contradicts an explicit
        // `prepend_scheme` must be rejected.
        let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":false,"prepend_scheme":"always"}"#;
        assert!(serde_json::from_str::<Metaspace>(metaspace_s).is_err(),);

        // Previous versions that also serialized `str_rep` and `add_prefix_space`
        // must still deserialize.
        let metaspace = Metaspace::new('_', PrependScheme::Always, true);
        let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );

        // Without `prepend_scheme` and `split`, the defaults (`always`, `true`) apply.
        let metaspace_parsed: Metaspace = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
        )
        .unwrap();
        assert_eq!(metaspace_parsed, metaspace);
    }

    /// Single spaces: each word becomes one split carrying the meta character.
    #[test]
    fn basic() {
        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        let mut pretokenized = PreTokenizedString::from("Hey friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        // Offsets into the normalized string ('▁' is 3 bytes in UTF-8).
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 6)), ("▁friend!", (6, 16))]
        );
        // Offsets mapped back onto the original input.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
        );
    }

    /// Consecutive spaces: every extra space yields its own lone `▁` split.
    #[test]
    fn multiple_spaces() {
        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        let mut pretokenized = PreTokenizedString::from("Hey   friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("▁", (9, 12)),
                ("▁friend!", (12, 22)),
            ]
        );
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Original, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 3)),
                ("▁", (3, 4)),
                ("▁", (4, 5)),
                ("▁friend!", (5, 13)),
            ]
        );
    }

    /// Exercises the prepend-scheme setter and the `First` scheme on input that was
    /// already split around the `<s>` marker.
    #[test]
    fn non_legacy_meta_space() {
        // The setter must produce the same value as constructing with that scheme.
        let mut pretok = Metaspace::new('▁', PrependScheme::Always, true);
        pretok.set_prepend_scheme(PrependScheme::Always);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Always, true));

        pretok.set_prepend_scheme(PrependScheme::Never);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Never, true));

        pretok.set_prepend_scheme(PrependScheme::First);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::First, false).clone_with_split(true));

        // `First` without splitting: isolate `<s>` first so that Metaspace sees
        // three separate splits.
        let pretok = Metaspace::new('▁', PrependScheme::First, false);
        let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
        let re_ref = Regex::new(r"(<s>)").unwrap();
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");

        pretok.pre_tokenize(&mut pretokenized).unwrap();
        // Only the split starting at original offset 0 receives the prefix.
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey▁my▁friend▁", (0, 23)),
                ("<s>", (23, 26)),
                ("how▁are▁you", (26, 41))
            ]
        );
        // Re-running on the same string with `Always` and splitting enabled prefixes
        // every split that lacks one (including `<s>`) and splits on `▁`.
        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey", (0, 6)),
                ("▁my", (6, 11)),
                ("▁friend", (11, 20)),
                ("▁", (20, 23)),
                ("▁<s>", (23, 29)),
                ("▁how", (29, 35)),
                ("▁are", (35, 41)),
                ("▁you", (41, 47))
            ]
        );

        // Input that already starts with a space: the leading space becomes `▁`, so
        // no additional prefix is prepended.
        let pretok = Metaspace::new('▁', PrependScheme::First, false);
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how"); // test with prefix
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![("▁Hey▁", (0, 9)), ("<s>", (9, 12)), ("how", (12, 15))]
        );

        // Several `<s>` markers: later splits never receive a prefix; the `▁` in
        // "▁you" comes from the space that precedes "you" in the input.
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you"); // test with many splits
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            pretokenized
                .get_splits(OffsetReferential::Normalized, OffsetType::Byte)
                .into_iter()
                .map(|(s, o, _)| (s, o))
                .collect::<Vec<_>>(),
            vec![
                ("▁Hey▁", (0, 9)),
                ("<s>", (9, 12)),
                ("how▁", (12, 18)),
                ("<s>", (18, 21)),
                ("are▁", (21, 27)),
                ("<s>", (27, 30)),
                ("▁you", (30, 36))
            ]
        );
    }
    /// Decoding drops the first token's prefix unless the scheme is `Never`.
    #[test]
    fn decode() {
        let decoder = Metaspace::new('▁', PrependScheme::Always, true);
        let res = decoder
            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(res, vec!["Hey", " friend!"]);

        // With `Never`, the leading `▁` is a real space and is kept.
        let decoder = Metaspace::new('▁', PrependScheme::Never, true);
        let res = decoder
            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(res, vec![" Hey", " friend!"]);
    }
}