1use crate::error::TokenizerError;
15use crate::vocab::base_vocab::{
16 read_json_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap, Vocab,
17};
18use std::collections::HashMap;
19use std::path::Path;
20
/// # GPT2 vocabulary
/// Vocabulary for the GPT-2 tokenizer: token/id mappings plus the special
/// token configuration used by the model.
#[derive(Debug, Clone)]
pub struct Gpt2Vocab {
    /// Mapping from tokens to ids.
    pub values: HashMap<String, i64>,

    /// Reverse mapping: ids to tokens.
    pub indices: HashMap<i64, String>,

    /// Special token configuration (unk/bos/eos, etc.).
    pub special_token_map: SpecialTokenMap,

    /// Mapping from special tokens to ids (subset of `values`).
    pub special_values: HashMap<String, i64>,

    /// Reverse mapping: ids to special tokens.
    pub special_indices: HashMap<i64, String>,
}
46
// GPT-2 uses a single special token, "<|endoftext|>", to play the role of the
// unknown, beginning-of-sequence and end-of-sequence markers.
const DEFAULT_UNK_TOKEN: &str = "<|endoftext|>";
const DEFAULT_BOS_TOKEN: &str = DEFAULT_UNK_TOKEN;
const DEFAULT_EOS_TOKEN: &str = DEFAULT_UNK_TOKEN;
50
51impl Gpt2Vocab {
52 pub fn get_bos_value(&self) -> &str {
53 self.special_token_map
54 .bos_token
55 .as_deref()
56 .unwrap_or(DEFAULT_BOS_TOKEN)
57 }
58
59 pub fn get_eos_value(&self) -> &str {
60 self.special_token_map
61 .eos_token
62 .as_deref()
63 .unwrap_or(DEFAULT_EOS_TOKEN)
64 }
65}
66
67impl Vocab for Gpt2Vocab {
68 fn get_unknown_value(&self) -> &str {
69 &self.special_token_map.unk_token
70 }
71
72 fn values(&self) -> &HashMap<String, i64> {
73 &self.values
74 }
75
76 fn indices(&self) -> &HashMap<i64, String> {
77 &self.indices
78 }
79
80 fn special_values(&self) -> &HashMap<String, i64> {
81 &self.special_values
82 }
83
84 fn special_indices(&self) -> &HashMap<i64, String> {
85 &self.special_indices
86 }
87
88 fn values_mut(&mut self) -> &mut HashMap<String, i64> {
89 &mut self.values
90 }
91
92 fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
93 &mut self.indices
94 }
95
96 fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
97 &mut self.special_values
98 }
99
100 fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
101 &mut self.special_indices
102 }
103
104 fn from_file<P: AsRef<Path>>(path: P) -> Result<Gpt2Vocab, TokenizerError> {
105 let values = read_json_file(path)?;
106
107 let special_token_map = SpecialTokenMap {
108 unk_token: DEFAULT_UNK_TOKEN.to_string(),
109 pad_token: None,
110 bos_token: Some(DEFAULT_BOS_TOKEN.to_string()),
111 sep_token: None,
112 cls_token: None,
113 eos_token: Some(DEFAULT_EOS_TOKEN.to_string()),
114 mask_token: None,
115 additional_special_tokens: None,
116 };
117 Self::from_values_and_special_token_map(values, special_token_map)
118 }
119
120 fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
121 path: P,
122 special_token_mapping_path: S,
123 ) -> Result<Self, TokenizerError> {
124 let values = read_json_file(path)?;
125 let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
126 Self::from_values_and_special_token_map(values, special_token_map)
127 }
128 fn from_values_and_special_token_map(
129 values: HashMap<String, i64>,
130 special_token_map: SpecialTokenMap,
131 ) -> Result<Self, TokenizerError>
132 where
133 Self: Sized,
134 {
135 let mut special_values = HashMap::new();
136 special_token_map.register_special_values(&values, &mut special_values)?;
137
138 let indices = swap_key_values(&values);
139 let special_indices = swap_key_values(&special_values);
140 Ok(Self {
141 values,
142 indices,
143 special_token_map,
144 special_values,
145 special_indices,
146 })
147 }
148 fn token_to_id(&self, token: &str) -> i64 {
149 self._token_to_id(
150 token,
151 &self.values,
152 &self.special_values,
153 self.get_unknown_value(),
154 )
155 }
156
157 fn id_to_token(&self, id: &i64) -> String {
158 self._id_to_token(
159 id,
160 &self.indices,
161 &self.special_indices,
162 self.get_unknown_value(),
163 )
164 }
165}
166
#[cfg(test)]
mod tests {
    extern crate anyhow;
    use super::*;
    use std::io::Write;

    #[test]
    fn test_create_vocab() {
        // An empty vocabulary carrying the GPT-2 default special token map.
        let special_token_map = SpecialTokenMap {
            unk_token: "<|endoftext|>".to_string(),
            pad_token: None,
            bos_token: Some("<|endoftext|>".to_string()),
            sep_token: None,
            cls_token: None,
            eos_token: Some("<|endoftext|>".to_string()),
            mask_token: None,
            additional_special_tokens: None,
        };

        let gpt2_vocab = Gpt2Vocab {
            values: HashMap::new(),
            indices: HashMap::new(),
            special_token_map,
            special_values: HashMap::new(),
            special_indices: HashMap::new(),
        };

        assert_eq!(gpt2_vocab.get_unknown_value(), "<|endoftext|>");
        assert_eq!(
            gpt2_vocab.special_token_map.bos_token.as_deref(),
            Some("<|endoftext|>")
        );
        assert_eq!(
            gpt2_vocab.special_token_map.eos_token.as_deref(),
            Some("<|endoftext|>")
        );
        assert_eq!(gpt2_vocab.values, *gpt2_vocab.values());
        assert_eq!(gpt2_vocab.special_values, *gpt2_vocab.special_values());
    }

    #[test]
    fn test_create_object_from_file() -> anyhow::Result<()> {
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();

        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        let expected_values: HashMap<String, i64> = HashMap::from([
            ("hello".to_owned(), 1),
            ("world".to_owned(), 0),
            ("<|endoftext|>".to_owned(), 2),
            ("!".to_owned(), 3),
        ]);
        let expected_special_values: HashMap<String, i64> =
            HashMap::from([("<|endoftext|>".to_owned(), 2)]);

        assert_eq!(gpt2_vocab.special_token_map.unk_token, "<|endoftext|>");
        assert_eq!(gpt2_vocab.values, expected_values);
        assert_eq!(gpt2_vocab.special_values, expected_special_values);
        drop(path);
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_create_object_from_file_without_unknown_token() {
        // The vocabulary below lacks "<|endoftext|>": registering the unknown
        // token must fail, and the unwrap below must panic.
        let mut vocab_file = tempfile::NamedTempFile::new().unwrap();
        write!(vocab_file, "{{\"hello\": 1,\n \"world\": 0,\n \"!\": 3\n}}").unwrap();
        let path = vocab_file.into_temp_path();

        let _ctrl_vocab = Gpt2Vocab::from_file(&path).unwrap();
    }

    #[test]
    fn test_encode_tokens() -> anyhow::Result<()> {
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();
        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        // Out-of-vocabulary tokens map to the unknown token's id (2).
        for (token, expected_id) in [
            ("hello", 1),
            ("world", 0),
            ("!", 3),
            ("<|endoftext|>", 2),
            ("oov_value", 2),
        ] {
            assert_eq!(gpt2_vocab.token_to_id(token), expected_id);
        }

        drop(path);
        Ok(())
    }

    #[test]
    fn test_decode_tokens() -> anyhow::Result<()> {
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();
        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        for (id, expected_token) in
            [(1_i64, "hello"), (0, "world"), (3, "!"), (2, "<|endoftext|>")]
        {
            assert_eq!(gpt2_vocab.id_to_token(&id), expected_token);
        }
        drop(path);
        Ok(())
    }
}