1use crate::tokenizer::{Decoder, PreTokenizedString, PreTokenizer, Result, SplitDelimiterBehavior};
2use serde::{de, Deserialize, Deserializer, Serialize};
3
4#[derive(Debug, Clone, PartialEq, Serialize, Eq, Deserialize, Copy)]
6#[serde(rename_all = "snake_case")]
7pub enum PrependScheme {
8 First,
10 Never,
12 Always,
14}
15
16#[derive(Debug, Clone, PartialEq, Serialize, Eq)]
17#[serde(tag = "type")]
20pub struct Metaspace {
21 replacement: char,
22 pub prepend_scheme: PrependScheme,
23 pub split: bool,
24 #[serde(skip)]
25 str_rep: String,
26}
27
28impl<'de> Deserialize<'de> for Metaspace {
29 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
30 where
31 D: Deserializer<'de>,
32 {
33 #[derive(Deserialize)]
34 enum Type {
35 Metaspace,
36 }
37
38 fn default_prepend_scheme_value() -> PrependScheme {
39 PrependScheme::Always
40 }
41
42 #[derive(Deserialize)]
43 pub struct MetaspaceHelper {
44 #[serde(rename = "type")]
45 _type: Type,
46 replacement: char,
47
48 pub add_prefix_space: Option<bool>,
49 #[serde(default = "default_prepend_scheme_value")]
50 pub prepend_scheme: PrependScheme,
51 pub split: Option<bool>,
52 #[serde(rename = "str_rep")]
53 _str_rep: Option<String>,
54 }
55
56 let mut helper = MetaspaceHelper::deserialize(deserializer)?;
57 if let Some(false) = helper.add_prefix_space {
58 if helper.prepend_scheme != PrependScheme::Never {
59 return Err(de::Error::custom(
60 "add_prefix_space does not match declared prepend_scheme",
61 ));
62 }
63 helper.prepend_scheme = PrependScheme::Never;
64 }
65 let instance = Self::new(
66 helper.replacement,
67 helper.prepend_scheme,
68 helper.split.unwrap_or(true),
69 );
70 Ok(instance)
71 }
72}
73
74impl Metaspace {
75 pub fn new(replacement: char, prepend_scheme: PrependScheme, split: bool) -> Self {
76 Self {
77 replacement,
78 str_rep: replacement.to_string(),
79 prepend_scheme,
80 split,
81 }
82 }
83
84 pub fn get_replacement(&self) -> char {
85 self.replacement
86 }
87
88 pub fn set_replacement(&mut self, replacement: char) {
89 self.replacement = replacement;
90 self.str_rep = replacement.to_string();
91 }
92
93 pub fn get_split(&self) -> bool {
94 self.split
95 }
96
97 pub fn set_split(&mut self, split: bool) {
98 self.split = split;
99 }
100
101 pub fn get_prepend_scheme(&self) -> PrependScheme {
102 self.prepend_scheme
103 }
104
105 pub fn set_prepend_scheme(&mut self, scheme: PrependScheme) {
106 self.prepend_scheme = scheme;
107 }
108}
109
110impl Default for Metaspace {
111 fn default() -> Self {
112 Self::new('▁', PrependScheme::Always, true)
113 }
114}
115
116impl PreTokenizer for Metaspace {
117 fn pre_tokenize(&self, pretokenized: &mut PreTokenizedString) -> Result<()> {
118 pretokenized.split(|_, mut normalized| {
119 normalized.replace(' ', &self.str_rep)?;
120 match self.prepend_scheme {
121 PrependScheme::Always => {
122 if !normalized.get().starts_with(self.replacement) {
123 normalized.prepend(&self.str_rep);
124 }
125 }
126 PrependScheme::First => {
127 if !normalized.get().starts_with(self.replacement)
128 && normalized.offsets_original().0 == 0
129 {
130 normalized.prepend(&self.str_rep);
131 }
132 }
133 PrependScheme::Never => {}
134 };
135 if self.split {
136 normalized.split(self.replacement, SplitDelimiterBehavior::MergedWithNext)
137 } else {
138 Ok(vec![normalized])
139 }
140 })
141 }
142}
143
144impl Decoder for Metaspace {
145 fn decode_chain(&self, tokens: Vec<String>) -> Result<Vec<String>> {
146 Ok(tokens
147 .iter()
148 .enumerate()
149 .map(|(i, token)| {
150 token
151 .chars()
152 .flat_map(|c| {
153 if c == self.replacement {
154 if i == 0 && self.prepend_scheme != PrependScheme::Never {
155 None
156 } else {
157 Some(' ')
158 }
159 } else {
160 Some(c)
161 }
162 })
163 .collect::<String>()
164 })
165 .collect())
166 }
167}
168
#[cfg(test)]
mod tests {
    use regex::Regex;

    use super::*;
    use crate::{OffsetReferential, OffsetType};

    /// Collects `(text, byte offsets)` for every split, dropping the tokens.
    fn splits(
        pretokenized: &PreTokenizedString,
        offset_ref: OffsetReferential,
    ) -> Vec<(&str, (usize, usize))> {
        pretokenized
            .get_splits(offset_ref, OffsetType::Byte)
            .into_iter()
            .map(|(s, o, _)| (s, o))
            .collect()
    }

    #[test]
    fn serialization() {
        let metaspace = Metaspace::new('_', PrependScheme::Always, true);
        let metaspace_s =
            r#"{"type":"Metaspace","replacement":"_","prepend_scheme":"always","split":true}"#;
        assert_eq!(serde_json::to_string(&metaspace).unwrap(), metaspace_s);
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );

        // A legacy flag that contradicts the declared scheme must be rejected.
        let metaspace_s = r#"{"type":"Metaspace","replacement":"_","add_prefix_space":false,"prepend_scheme":"always"}"#;
        assert!(serde_json::from_str::<Metaspace>(metaspace_s).is_err());

        // A serialized `str_rep` is tolerated and ignored.
        let metaspace = Metaspace::new('_', PrependScheme::Always, true);
        let metaspace_s = r#"{"type":"Metaspace","str_rep":"_","replacement":"_","add_prefix_space":true,"prepend_scheme":"always"}"#;
        assert_eq!(
            serde_json::from_str::<Metaspace>(metaspace_s).unwrap(),
            metaspace
        );

        // Legacy payload without `prepend_scheme`.
        let metaspace_parsed: Metaspace = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"_","add_prefix_space":true}"#,
        )
        .unwrap();
        assert_eq!(metaspace_parsed, metaspace);
    }

    #[test]
    fn basic() {
        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        let mut pretokenized = PreTokenizedString::from("Hey friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![("▁Hey", (0, 6)), ("▁friend!", (6, 16))]
        );
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![("▁Hey", (0, 3)), ("▁friend!", (3, 11))]
        );
    }

    #[test]
    fn multiple_spaces() {
        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        // Three consecutive spaces: the expectations below (two standalone
        // `▁` pieces, original length 13) only hold for this input.
        let mut pretokenized = PreTokenizedString::from("Hey   friend!");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![
                ("▁Hey", (0, 6)),
                ("▁", (6, 9)),
                ("▁", (9, 12)),
                ("▁friend!", (12, 22)),
            ]
        );
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Original),
            vec![
                ("▁Hey", (0, 3)),
                ("▁", (3, 4)),
                ("▁", (4, 5)),
                ("▁friend!", (5, 13)),
            ]
        );
    }

    #[test]
    fn non_legacy_meta_space() {
        let mut pretok = Metaspace::new('▁', PrependScheme::Always, true);
        pretok.set_prepend_scheme(PrependScheme::Always);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Always, true));

        pretok.set_prepend_scheme(PrependScheme::Never);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::Never, true));

        pretok.set_prepend_scheme(PrependScheme::First);
        assert_eq!(pretok, Metaspace::new('▁', PrependScheme::First, false).split_enabled());

        let pretok = Metaspace::new('▁', PrependScheme::First, false);
        let mut pretokenized = PreTokenizedString::from("Hey my friend <s>how▁are you");
        let re_ref = Regex::new(r"(<s>)").unwrap();
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");

        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![
                ("▁Hey▁my▁friend▁", (0, 23)),
                ("<s>", (23, 26)),
                ("how▁are▁you", (26, 41))
            ]
        );

        let pretok = Metaspace::new('▁', PrependScheme::Always, true);
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![
                ("▁Hey", (0, 6)),
                ("▁my", (6, 11)),
                ("▁friend", (11, 20)),
                ("▁", (20, 23)),
                ("▁<s>", (23, 29)),
                ("▁how", (29, 35)),
                ("▁are", (35, 41)),
                ("▁you", (41, 47))
            ]
        );

        // Input that already starts with a space.
        let pretok = Metaspace::new('▁', PrependScheme::First, false);
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how");
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![("▁Hey▁", (0, 9)), ("<s>", (9, 12)), ("how", (12, 15))]
        );

        // Input with several isolated `<s>` markers.
        let mut pretokenized = PreTokenizedString::from(" Hey <s>how <s>are <s> you");
        pretokenized
            .split(|_, sequence| sequence.split(&re_ref, SplitDelimiterBehavior::Isolated))
            .expect("Bad split");
        pretok.pre_tokenize(&mut pretokenized).unwrap();
        assert_eq!(
            splits(&pretokenized, OffsetReferential::Normalized),
            vec![
                ("▁Hey▁", (0, 9)),
                ("<s>", (9, 12)),
                ("how▁", (12, 18)),
                ("<s>", (18, 21)),
                ("are▁", (21, 27)),
                ("<s>", (27, 30)),
                ("▁you", (30, 36))
            ]
        );
    }

    #[test]
    fn decode() {
        let decoder = Metaspace::new('▁', PrependScheme::Always, true);
        let res = decoder
            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(res, vec!["Hey", " friend!"]);

        let decoder = Metaspace::new('▁', PrependScheme::Never, true);
        let res = decoder
            .decode_chain(vec!["▁Hey".into(), "▁friend!".into()])
            .unwrap();
        assert_eq!(res, vec![" Hey", " friend!"]);
    }
}