1pub mod bert;
2pub mod byte_level;
3pub mod delimiter;
4pub mod digits;
5pub mod metaspace;
6pub mod punctuation;
7pub mod sequence;
8pub mod split;
9pub mod unicode_scripts;
10pub mod whitespace;
11
12use serde::{Deserialize, Deserializer, Serialize};
13
14use crate::pre_tokenizers::bert::BertPreTokenizer;
15use crate::pre_tokenizers::byte_level::ByteLevel;
16use crate::pre_tokenizers::delimiter::CharDelimiterSplit;
17use crate::pre_tokenizers::digits::Digits;
18use crate::pre_tokenizers::metaspace::Metaspace;
19use crate::pre_tokenizers::punctuation::Punctuation;
20use crate::pre_tokenizers::sequence::Sequence;
21use crate::pre_tokenizers::split::Split;
22use crate::pre_tokenizers::unicode_scripts::UnicodeScripts;
23use crate::pre_tokenizers::whitespace::{Whitespace, WhitespaceSplit};
24use crate::{PreTokenizedString, PreTokenizer};
25
/// A wrapper around every concrete [`PreTokenizer`] this crate provides.
///
/// Serialization is `untagged`: each inner type is responsible for emitting
/// its own identifying `"type"` field (see the matching hand-written
/// `Deserialize` impl below, which handles both the tagged and the legacy
/// untagged wire formats).
#[derive(Serialize, Clone, Debug, PartialEq)]
#[serde(untagged)]
pub enum PreTokenizerWrapper {
    BertPreTokenizer(BertPreTokenizer),
    ByteLevel(ByteLevel),
    Delimiter(CharDelimiterSplit),
    Metaspace(Metaspace),
    Whitespace(Whitespace),
    Sequence(Sequence),
    Split(Split),
    Punctuation(Punctuation),
    WhitespaceSplit(WhitespaceSplit),
    Digits(Digits),
    UnicodeScripts(UnicodeScripts),
}
41
42impl PreTokenizer for PreTokenizerWrapper {
43 fn pre_tokenize(&self, normalized: &mut PreTokenizedString) -> crate::Result<()> {
44 match self {
45 Self::BertPreTokenizer(bpt) => bpt.pre_tokenize(normalized),
46 Self::ByteLevel(bpt) => bpt.pre_tokenize(normalized),
47 Self::Delimiter(dpt) => dpt.pre_tokenize(normalized),
48 Self::Metaspace(mspt) => mspt.pre_tokenize(normalized),
49 Self::Whitespace(wspt) => wspt.pre_tokenize(normalized),
50 Self::Punctuation(tok) => tok.pre_tokenize(normalized),
51 Self::Sequence(tok) => tok.pre_tokenize(normalized),
52 Self::Split(tok) => tok.pre_tokenize(normalized),
53 Self::WhitespaceSplit(wspt) => wspt.pre_tokenize(normalized),
54 Self::Digits(wspt) => wspt.pre_tokenize(normalized),
55 Self::UnicodeScripts(us) => us.pre_tokenize(normalized),
56 }
57 }
58}
59
impl<'de> Deserialize<'de> for PreTokenizerWrapper {
    /// Deserializes a pre-tokenizer from either of two wire formats:
    ///
    /// 1. the current *tagged* format: `{"type": "<Variant>", ...fields}`;
    /// 2. a *legacy* untagged format, where the variant is inferred purely
    ///    from the shape of the fields.
    ///
    /// The tagged path is tried first (via the untagged `PreTokenizerHelper`
    /// below); anything without a recognizable `"type"` key falls through to
    /// the legacy path.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Mirror of the tagged layout: capture the `"type"` discriminant and
        // stash every remaining field in `rest` for a second-pass parse.
        #[derive(Deserialize)]
        pub struct Tagged {
            #[serde(rename = "type")]
            variant: EnumType,
            #[serde(flatten)]
            rest: serde_json::Value,
        }
        // The set of valid `"type"` discriminants. Must stay in sync with the
        // variants of `PreTokenizerWrapper`.
        #[derive(Deserialize, Serialize)]
        pub enum EnumType {
            BertPreTokenizer,
            ByteLevel,
            Delimiter,
            Metaspace,
            Whitespace,
            Sequence,
            Split,
            Punctuation,
            WhitespaceSplit,
            Digits,
            UnicodeScripts,
        }

        // First-stage dispatcher: prefer the tagged form; keep the raw value
        // around for legacy handling when no valid tag is present.
        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum PreTokenizerHelper {
            Tagged(Tagged),
            Legacy(serde_json::Value),
        }

        // Legacy fallback: try each concrete pre-tokenizer in declaration
        // order until one accepts the value's field shape.
        #[derive(Deserialize)]
        #[serde(untagged)]
        pub enum PreTokenizerUntagged {
            BertPreTokenizer(BertPreTokenizer),
            ByteLevel(ByteLevel),
            Delimiter(CharDelimiterSplit),
            Metaspace(Metaspace),
            Whitespace(Whitespace),
            Sequence(Sequence),
            Split(Split),
            Punctuation(Punctuation),
            WhitespaceSplit(WhitespaceSplit),
            Digits(Digits),
            UnicodeScripts(UnicodeScripts),
        }

        let helper = PreTokenizerHelper::deserialize(deserializer)?;

        Ok(match helper {
            PreTokenizerHelper::Tagged(pretok) => {
                // Re-assemble the full object — flattening stripped the
                // `"type"` key, but the inner types' own Deserialize impls
                // presumably expect it, so it is inserted back before the
                // second-pass parse.
                let mut values: serde_json::Map<String, serde_json::Value> =
                    serde_json::from_value(pretok.rest).map_err(serde::de::Error::custom)?;
                values.insert(
                    "type".to_string(),
                    serde_json::to_value(&pretok.variant).map_err(serde::de::Error::custom)?,
                );
                let values = serde_json::Value::Object(values);
                // Dispatch on the already-validated tag; each arm parses the
                // reconstructed object into the matching concrete type.
                match pretok.variant {
                    EnumType::BertPreTokenizer => PreTokenizerWrapper::BertPreTokenizer(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::ByteLevel => PreTokenizerWrapper::ByteLevel(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Delimiter => PreTokenizerWrapper::Delimiter(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Metaspace => PreTokenizerWrapper::Metaspace(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Whitespace => PreTokenizerWrapper::Whitespace(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Sequence => PreTokenizerWrapper::Sequence(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Split => PreTokenizerWrapper::Split(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Punctuation => PreTokenizerWrapper::Punctuation(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::WhitespaceSplit => PreTokenizerWrapper::WhitespaceSplit(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::Digits => PreTokenizerWrapper::Digits(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                    EnumType::UnicodeScripts => PreTokenizerWrapper::UnicodeScripts(
                        serde_json::from_value(values).map_err(serde::de::Error::custom)?,
                    ),
                }
            }

            PreTokenizerHelper::Legacy(value) => {
                // No usable `"type"` tag: fall back to shape-based matching.
                // On failure this surfaces serde's generic "data did not match
                // any variant of untagged enum PreTokenizerUntagged" error.
                let untagged = serde_json::from_value(value).map_err(serde::de::Error::custom)?;
                match untagged {
                    PreTokenizerUntagged::BertPreTokenizer(bert) => {
                        PreTokenizerWrapper::BertPreTokenizer(bert)
                    }
                    PreTokenizerUntagged::ByteLevel(byte_level) => {
                        PreTokenizerWrapper::ByteLevel(byte_level)
                    }
                    PreTokenizerUntagged::Delimiter(delimiter) => {
                        PreTokenizerWrapper::Delimiter(delimiter)
                    }
                    PreTokenizerUntagged::Metaspace(metaspace) => {
                        PreTokenizerWrapper::Metaspace(metaspace)
                    }
                    PreTokenizerUntagged::Whitespace(whitespace) => {
                        PreTokenizerWrapper::Whitespace(whitespace)
                    }
                    PreTokenizerUntagged::Sequence(sequence) => {
                        PreTokenizerWrapper::Sequence(sequence)
                    }
                    PreTokenizerUntagged::Split(split) => PreTokenizerWrapper::Split(split),
                    PreTokenizerUntagged::Punctuation(punctuation) => {
                        PreTokenizerWrapper::Punctuation(punctuation)
                    }
                    PreTokenizerUntagged::WhitespaceSplit(whitespace_split) => {
                        PreTokenizerWrapper::WhitespaceSplit(whitespace_split)
                    }
                    PreTokenizerUntagged::Digits(digits) => PreTokenizerWrapper::Digits(digits),
                    PreTokenizerUntagged::UnicodeScripts(unicode_scripts) => {
                        PreTokenizerWrapper::UnicodeScripts(unicode_scripts)
                    }
                }
            }
        })
    }
}
195
// Generate `From<Concrete> for PreTokenizerWrapper` conversions for every
// wrapped pre-tokenizer, so each concrete type can be used wherever the
// wrapper is expected. One invocation per enum variant.
impl_enum_from!(BertPreTokenizer, PreTokenizerWrapper, BertPreTokenizer);
impl_enum_from!(ByteLevel, PreTokenizerWrapper, ByteLevel);
impl_enum_from!(CharDelimiterSplit, PreTokenizerWrapper, Delimiter);
impl_enum_from!(Whitespace, PreTokenizerWrapper, Whitespace);
impl_enum_from!(Punctuation, PreTokenizerWrapper, Punctuation);
impl_enum_from!(Sequence, PreTokenizerWrapper, Sequence);
impl_enum_from!(Split, PreTokenizerWrapper, Split);
impl_enum_from!(Metaspace, PreTokenizerWrapper, Metaspace);
impl_enum_from!(WhitespaceSplit, PreTokenizerWrapper, WhitespaceSplit);
impl_enum_from!(Digits, PreTokenizerWrapper, Digits);
impl_enum_from!(UnicodeScripts, PreTokenizerWrapper, UnicodeScripts);
207
#[cfg(test)]
mod tests {
    use super::metaspace::PrependScheme;
    use super::*;

    /// Tagged-format deserialization, including tolerance for obsolete
    /// fields and `prepend_scheme` defaulting.
    #[test]
    fn test_deserialize() {
        // Tagged `Sequence` whose Metaspace payload carries the obsolete
        // `str_rep` field — it must be accepted and ignored.
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","str_rep":"▁","add_prefix_space":true}]}"#).unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Sequence(Sequence::new(vec![
                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
            ]))
        );

        // Plain tagged Metaspace; omitted `prepend_scheme` defaults to
        // `Always`.
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
        );

        // Same Sequence as above but without the obsolete `str_rep` field.
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(r#"{"type":"Sequence","pretokenizers":[{"type":"WhitespaceSplit"},{"type":"Metaspace","replacement":"▁","add_prefix_space":true}]}"#).unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Sequence(Sequence::new(vec![
                PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {}),
                PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
            ]))
        );

        // Explicit `prepend_scheme` values round-trip correctly.
        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"first"}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::First, true))
        );

        let pre_tokenizer: PreTokenizerWrapper = serde_json::from_str(
            r#"{"type":"Metaspace","replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#,
        )
        .unwrap();

        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::Metaspace(Metaspace::new('▁', PrependScheme::Always, true))
        );
    }

    /// A unit-struct variant with only a `type` tag deserializes cleanly.
    #[test]
    fn test_deserialize_whitespace_split() {
        let pre_tokenizer: PreTokenizerWrapper =
            serde_json::from_str(r#"{"type":"WhitespaceSplit"}"#).unwrap();
        assert_eq!(
            pre_tokenizer,
            PreTokenizerWrapper::WhitespaceSplit(WhitespaceSplit {})
        );
    }

    /// Inputs missing a valid `type` tag fall back to the legacy untagged
    /// path, and fail with serde's untagged-enum error when no variant
    /// matches. (The original fixture here had a stray trailing `}` making
    /// it invalid JSON; the untagged error still fired before serde_json's
    /// trailing-character check, but the fixture is now valid JSON so the
    /// test exercises exactly the intended failure mode.)
    #[test]
    fn pre_tokenizer_deserialization_no_type() {
        let json = r#"{"replacement":"▁","add_prefix_space":true, "prepend_scheme":"always"}"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(
                err.to_string(),
                "data did not match any variant of untagged enum PreTokenizerUntagged"
            ),
            _ => panic!("Expected an error here"),
        }

        // Missing optional fields take their defaults.
        let json = r#"{"type":"Metaspace", "replacement":"▁" }"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json).unwrap();
        assert_eq!(
            reconstructed,
            PreTokenizerWrapper::Metaspace(Metaspace::default())
        );

        // Missing *required* fields are a hard error.
        let json = r#"{"type":"Metaspace", "add_prefix_space":true }"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(err.to_string(), "missing field `replacement`"),
            _ => panic!("Expected an error here"),
        }

        // An unrecognized field shape with no `type` tag matches nothing.
        let json = r#"{"behavior":"default_split"}"#;
        let reconstructed = serde_json::from_str::<PreTokenizerWrapper>(json);
        match reconstructed {
            Err(err) => assert_eq!(
                err.to_string(),
                "data did not match any variant of untagged enum PreTokenizerUntagged"
            ),
            _ => panic!("Expected an error here"),
        }
    }
}