rust_tokenizers/vocab/
gpt2_vocab.rs

1// Copyright 2018 The Open AI Team Authors
2// Copyright 2018 The HuggingFace Inc. team.
3// Copyright 2019 Guillaume Becquin
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//     http://www.apache.org/licenses/LICENSE-2.0
8// Unless required by applicable law or agreed to in writing, software
9// distributed under the License is distributed on an "AS IS" BASIS,
10// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11// See the License for the specific language governing permissions and
12// limitations under the License.
13
14use crate::error::TokenizerError;
15use crate::vocab::base_vocab::{
16    read_json_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap, Vocab,
17};
18use std::collections::HashMap;
19use std::path::Path;
20
1/// # GPT2 Vocab
2/// Vocabulary for GPT2 tokenizer. Contains the following special values:
3/// - BOS token
4/// - EOS token
5///
6/// Expects a JSON-format vocabulary when created from file.
7///
8/// `values`/`indices` and `special_values`/`special_indices` are mirror pairs
9/// built via `swap_key_values` in `from_values_and_special_token_map`, so each
10/// id/token mapping stays consistent with its inverse.
#[derive(Debug, Clone)]
pub struct Gpt2Vocab {
    /// A mapping of tokens as string to indices (i.e. the encoder base)
    pub values: HashMap<String, i64>,

    /// A mapping of token ids to strings (i.e. the decoder base)
    pub indices: HashMap<i64, String>,

    /// Special tokens used by the vocabulary (UNK/BOS/EOS and optional extras)
    pub special_token_map: SpecialTokenMap,

    /// A mapping of special value tokens as strings to IDs (i.e. the encoder base for special
    /// values), special values typically include things like BOS/EOS markers, class markers, mask
    /// markers and padding markers
    pub special_values: HashMap<String, i64>,

    /// A mapping of special value tokens as IDs to strings (i.e. the decoder base for special values)
    pub special_indices: HashMap<i64, String>,
}
46
// GPT2 does not define dedicated BOS/EOS markers: the single `<|endoftext|>`
// token serves as the unknown, beginning-of-sequence and end-of-sequence token.
const DEFAULT_UNK_TOKEN: &str = "<|endoftext|>";
const DEFAULT_BOS_TOKEN: &str = DEFAULT_UNK_TOKEN;
const DEFAULT_EOS_TOKEN: &str = DEFAULT_UNK_TOKEN;
50
51impl Gpt2Vocab {
52    pub fn get_bos_value(&self) -> &str {
53        self.special_token_map
54            .bos_token
55            .as_deref()
56            .unwrap_or(DEFAULT_BOS_TOKEN)
57    }
58
59    pub fn get_eos_value(&self) -> &str {
60        self.special_token_map
61            .eos_token
62            .as_deref()
63            .unwrap_or(DEFAULT_EOS_TOKEN)
64    }
65}
66
impl Vocab for Gpt2Vocab {
    /// Returns the unknown-token string stored in the special token map.
    fn get_unknown_value(&self) -> &str {
        &self.special_token_map.unk_token
    }

    /// Borrow the token -> id map (encoder base).
    fn values(&self) -> &HashMap<String, i64> {
        &self.values
    }

    /// Borrow the id -> token map (decoder base).
    fn indices(&self) -> &HashMap<i64, String> {
        &self.indices
    }

    /// Borrow the special token -> id map.
    fn special_values(&self) -> &HashMap<String, i64> {
        &self.special_values
    }

    /// Borrow the special id -> token map.
    fn special_indices(&self) -> &HashMap<i64, String> {
        &self.special_indices
    }

    /// Mutably borrow the token -> id map.
    fn values_mut(&mut self) -> &mut HashMap<String, i64> {
        &mut self.values
    }

    /// Mutably borrow the id -> token map.
    fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
        &mut self.indices
    }

    /// Mutably borrow the special token -> id map.
    fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
        &mut self.special_values
    }

    /// Mutably borrow the special id -> token map.
    fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
        &mut self.special_indices
    }

    /// Build a vocabulary from a JSON vocab file, using the default GPT2
    /// special tokens (`<|endoftext|>` for UNK/BOS/EOS, no PAD/SEP/CLS/MASK).
    ///
    /// # Errors
    /// Returns `TokenizerError` if the file cannot be read/parsed, or if a
    /// referenced special token is missing from the vocabulary.
    fn from_file<P: AsRef<Path>>(path: P) -> Result<Gpt2Vocab, TokenizerError> {
        let values = read_json_file(path)?;

        // GPT2 convention: one shared `<|endoftext|>` token for UNK, BOS and EOS.
        let special_token_map = SpecialTokenMap {
            unk_token: DEFAULT_UNK_TOKEN.to_string(),
            pad_token: None,
            bos_token: Some(DEFAULT_BOS_TOKEN.to_string()),
            sep_token: None,
            cls_token: None,
            eos_token: Some(DEFAULT_EOS_TOKEN.to_string()),
            mask_token: None,
            additional_special_tokens: None,
        };
        Self::from_values_and_special_token_map(values, special_token_map)
    }

    /// Build a vocabulary from a JSON vocab file plus a separate JSON file
    /// describing the special token mapping (overrides the GPT2 defaults).
    fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
        path: P,
        special_token_mapping_path: S,
    ) -> Result<Self, TokenizerError> {
        let values = read_json_file(path)?;
        let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
        Self::from_values_and_special_token_map(values, special_token_map)
    }
    /// Assemble the full vocabulary from an already-parsed token map and
    /// special token map, deriving the inverse (id -> token) maps.
    ///
    /// # Errors
    /// Fails if a special token named in `special_token_map` is absent from
    /// `values` (surfaced by `register_special_values`).
    fn from_values_and_special_token_map(
        values: HashMap<String, i64>,
        special_token_map: SpecialTokenMap,
    ) -> Result<Self, TokenizerError>
    where
        Self: Sized,
    {
        let mut special_values = HashMap::new();
        special_token_map.register_special_values(&values, &mut special_values)?;

        // Derive the decoder maps as inverses of the encoder maps.
        let indices = swap_key_values(&values);
        let special_indices = swap_key_values(&special_values);
        Ok(Self {
            values,
            indices,
            special_token_map,
            special_values,
            special_indices,
        })
    }
    /// Convert a token string to its id, falling back to the unknown-token id
    /// for out-of-vocabulary input.
    fn token_to_id(&self, token: &str) -> i64 {
        self._token_to_id(
            token,
            &self.values,
            &self.special_values,
            self.get_unknown_value(),
        )
    }

    /// Convert an id back to its token string, falling back to the
    /// unknown-token string for ids absent from the vocabulary.
    fn id_to_token(&self, id: &i64) -> String {
        self._id_to_token(
            id,
            &self.indices,
            &self.special_indices,
            self.get_unknown_value(),
        )
    }
}
166
//==============================
// Unit tests
//==============================
#[cfg(test)]
mod tests {
    extern crate anyhow;
    use super::*;
    use std::io::Write;

    #[test]
    fn test_create_vocab() {
        // Given: empty maps and a special token map following the GPT2
        // convention of a single `<|endoftext|>` token for UNK/BOS/EOS.
        let values: HashMap<String, i64> = HashMap::new();
        let special_values: HashMap<String, i64> = HashMap::new();
        let indices: HashMap<i64, String> = HashMap::new();
        let special_indices: HashMap<i64, String> = HashMap::new();
        let special_token_map = SpecialTokenMap {
            unk_token: "<|endoftext|>".to_string(),
            pad_token: None,
            bos_token: Some("<|endoftext|>".to_string()),
            sep_token: None,
            cls_token: None,
            eos_token: Some("<|endoftext|>".to_string()),
            mask_token: None,
            additional_special_tokens: None,
        };

        // When: the vocabulary is constructed directly from its fields
        let gpt2_vocab = Gpt2Vocab {
            values,
            indices,
            special_token_map,
            special_values,
            special_indices,
        };

        // Then: special token accessors and map accessors reflect the inputs
        assert_eq!(gpt2_vocab.get_unknown_value(), "<|endoftext|>");
        assert_eq!(
            gpt2_vocab.special_token_map.bos_token.as_ref().unwrap(),
            "<|endoftext|>"
        );
        assert_eq!(
            gpt2_vocab.special_token_map.eos_token.as_ref().unwrap(),
            "<|endoftext|>"
        );
        assert_eq!(gpt2_vocab.values, *gpt2_vocab.values());
        assert_eq!(gpt2_vocab.special_values, *gpt2_vocab.special_values());
    }

    #[test]
    fn test_create_object_from_file() -> anyhow::Result<()> {
        // Given: a temporary JSON vocab file containing `<|endoftext|>`
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();
        let target_values: HashMap<String, i64> = [
            ("hello".to_owned(), 1),
            ("world".to_owned(), 0),
            ("<|endoftext|>".to_owned(), 2),
            ("!".to_owned(), 3),
        ]
        .iter()
        .cloned()
        .collect();

        let special_values: HashMap<String, i64> =
            [("<|endoftext|>".to_owned(), 2)].iter().cloned().collect();

        // When: the vocabulary is loaded from file
        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        // Then: values and special values match the file content
        assert_eq!(gpt2_vocab.special_token_map.unk_token, "<|endoftext|>");
        assert_eq!(gpt2_vocab.values, target_values);
        assert_eq!(gpt2_vocab.special_values, special_values);
        drop(path);
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_create_object_from_file_without_unknown_token() {
        // Given: a vocab file missing the `<|endoftext|>` unknown token
        let mut vocab_file = tempfile::NamedTempFile::new().unwrap();
        write!(vocab_file, "{{\"hello\": 1,\n \"world\": 0,\n \"!\": 3\n}}").unwrap();
        let path = vocab_file.into_temp_path();

        // When & Then: loading must fail because the UNK token cannot be registered
        // (was `_ctrl_vocab` — a copy/paste leftover from the CTRL vocab tests)
        let _gpt2_vocab = Gpt2Vocab::from_file(&path).unwrap();
    }

    #[test]
    fn test_encode_tokens() -> anyhow::Result<()> {
        // Given: a vocabulary loaded from a temporary JSON file
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();
        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        // When & Then: known tokens map to their ids; OOV maps to the UNK id (2)
        assert_eq!(gpt2_vocab.token_to_id("hello"), 1);
        assert_eq!(gpt2_vocab.token_to_id("world"), 0);
        assert_eq!(gpt2_vocab.token_to_id("!"), 3);
        assert_eq!(gpt2_vocab.token_to_id("<|endoftext|>"), 2);
        assert_eq!(gpt2_vocab.token_to_id("oov_value"), 2);

        drop(path);
        Ok(())
    }

    #[test]
    fn test_decode_tokens() -> anyhow::Result<()> {
        // Given: a vocabulary loaded from a temporary JSON file
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(
            vocab_file,
            "{{\"hello\": 1,\n \"world\": 0,\n \"<|endoftext|>\": 2,\n \"!\": 3\n}}"
        )?;
        let path = vocab_file.into_temp_path();
        let gpt2_vocab = Gpt2Vocab::from_file(&path)?;

        // When & Then: ids round-trip back to their token strings
        assert_eq!(gpt2_vocab.id_to_token(&(1_i64)), "hello");
        assert_eq!(gpt2_vocab.id_to_token(&(0_i64)), "world");
        assert_eq!(gpt2_vocab.id_to_token(&(3_i64)), "!");
        assert_eq!(gpt2_vocab.id_to_token(&(2_i64)), "<|endoftext|>");
        drop(path);
        Ok(())
    }
}