Skip to main content

rust_transformers/preprocessing/vocab/
base_vocab.rs

// Copyright 2019 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
11
12
13use std::collections::HashMap;
14use std::fs::File;
15use std::io::{BufReader, BufRead};
16use std::error::Error;
17use std::process;
18use std::hash::Hash;
19
/// Builds the inverse of a map: every `(key, value)` entry of `input_hashmap`
/// becomes a `(value, key)` entry of the returned map.
///
/// Keys are cloned and values are copied, so the input map is left untouched.
/// If several keys share the same value only one of them survives in the
/// result; which one is unspecified because `HashMap` iteration order is
/// arbitrary.
pub fn swap_key_values<T: Clone, U: Hash + Eq + Copy>(input_hashmap: &HashMap<T, U>) -> HashMap<U, T> {
    input_hashmap
        .iter()
        // `U: Copy`, so the value is copied out by the `&value` pattern; only the key needs a clone.
        .map(|(key, &value)| (value, key.clone()))
        .collect()
}
26
27
/// Common behaviour for vocabularies that map string tokens to integer ids and back.
///
/// Implementors supply the lookup tables (`values`, `indices`, `special_values`,
/// `special_indices`), the unknown-token marker and the public conversion entry points
/// (`token_to_id`, `id_to_token`). The trait provides vocabulary file loading and the
/// shared lookup helpers (prefixed with `_`) that those entry points can delegate to.
pub trait Vocab {
    /// Returns the marker token used for out-of-vocabulary entries.
    fn unknown_value() -> &'static str;

    /// Returns the token -> id table.
    fn values(&self) -> &HashMap<String, i64>;

    /// Returns the id -> token table.
    fn indices(&self) -> &HashMap<i64, String>;

    /// Returns the token -> id table of the special values.
    fn special_values(&self) -> &HashMap<String, i64>;

    /// Returns the id -> token table of the special values.
    fn special_indices(&self) -> &HashMap<i64, String>;

    /// Builds a vocabulary from the file located at `path`.
    fn from_file(path: &str) -> Self;

    /// Reads a vocabulary file containing one token per line.
    ///
    /// Each line is trimmed of surrounding whitespace and mapped to its zero-based
    /// line number. If the same (trimmed) token appears on several lines, the last
    /// occurrence determines its id.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be opened or if a line cannot be read
    /// (for example because it is not valid UTF-8).
    fn read_vocab_file(path: &str) -> HashMap<String, i64> {
        let f = File::open(path).expect("Could not open vocabulary file.");
        BufReader::new(f)
            .lines()
            // Pair every line with an `i64` counter so no integer cast is needed.
            .zip(0_i64..)
            .map(|(line, index)| {
                let line = line.expect("Could not read line from vocabulary file.");
                (line.trim().to_owned(), index)
            })
            .collect()
    }

    /// Looks up the id of `token`.
    ///
    /// The lookup order is: `special_values`, then `values`, then the entry of
    /// `unknown_value` in `values`.
    ///
    /// # Errors
    ///
    /// Returns an error when the token is unknown and `unknown_value` itself is
    /// missing from `values`.
    fn _token_to_id(&self,
                    token: &str,
                    values: &HashMap<String, i64>,
                    special_values: &HashMap<String, i64>,
                    unknown_value: &str) -> Result<i64, Box<dyn Error>> {
        special_values
            .get(token)
            .or_else(|| values.get(token))
            .or_else(|| values.get(unknown_value))
            .copied()
            .ok_or_else(|| "Could not decode token".into())
    }

    /// Looks up the token stored for `id`.
    ///
    /// The lookup order is: `special_indices`, then `indices`. When the id is found
    /// in neither table, `unknown_value` is returned. This default implementation
    /// never returns `Err`; the `Result` is part of the signature only.
    fn _id_to_token(&self,
                    id: &i64,
                    indices: &HashMap<i64, String>,
                    special_indices: &HashMap<i64, String>,
                    unknown_value: &str) -> Result<String, Box<dyn Error>> {
        Ok(special_indices
            .get(id)
            .or_else(|| indices.get(id))
            .cloned()
            .unwrap_or_else(|| unknown_value.to_owned()))
    }

    /// Copies the entry of `token` from `values` into `special_values`.
    ///
    /// # Panics
    ///
    /// Panics if `token` is not present in `values`.
    fn _register_as_special_value(token: &str,
                                  values: &HashMap<String, i64>,
                                  special_values: &mut HashMap<String, i64>) {
        let token_id = match values.get(token) {
            Some(index) => *index,
            None => panic!("The special value {} could not be found in the vocabulary", token)
        };
        special_values.insert(token.to_owned(), token_id);
    }

    /// Converts a single token to its id.
    fn token_to_id(&self, token: &str) -> i64;

    /// Converts a single id to its token.
    fn id_to_token(&self, id: &i64) -> String;

    /// Converts every token of `tokens` with `token_to_id`, preserving order.
    fn convert_tokens_to_ids(&self, tokens: Vec<&str>) -> Vec<i64> {
        tokens.into_iter().map(|token| self.token_to_id(token)).collect()
    }
}
103
104
/// Basic vocabulary holding forward and reverse lookup tables for regular and
/// special values, together with the token used for unknown entries.
#[derive(Debug, Clone)]
pub struct BaseVocab {
    /// Token -> id table.
    pub values: HashMap<String, i64>,
    /// Id -> token table.
    pub indices: HashMap<i64, String>,
    /// Token used for out-of-vocabulary entries.
    pub unknown_value: &'static str,
    /// Token -> id table of the special values.
    pub special_values: HashMap<String, i64>,
    /// Id -> token table of the special values.
    pub special_indices: HashMap<i64, String>,
}
112
113impl Vocab for BaseVocab {
114    fn unknown_value() -> &'static str { "[UNK]" }
115
116    fn values(&self) -> &HashMap<String, i64> {
117        &self.values
118    }
119
120    fn indices(&self) -> &HashMap<i64, String> {
121        &self.indices
122    }
123
124    fn special_values(&self) -> &HashMap<String, i64> {
125        &self.special_values
126    }
127
128    fn special_indices(&self) -> &HashMap<i64, String> {
129        &self.special_indices
130    }
131
132    fn from_file(path: &str) -> BaseVocab {
133        let values = BaseVocab::read_vocab_file(path);
134        let mut special_values = HashMap::new();
135        let unknown_value = BaseVocab::unknown_value();
136        BaseVocab::_register_as_special_value(unknown_value, &values, &mut special_values);
137
138        let indices = swap_key_values(&values);
139        let special_indices = swap_key_values(&special_values);
140
141        BaseVocab { values, indices, unknown_value, special_values, special_indices }
142    }
143
144    fn token_to_id(&self, token: &str) -> i64 {
145        match self._token_to_id(token, &self.values, &self.special_values, &self.unknown_value) {
146            Ok(index) => index,
147            Err(err) => {
148                println!("{}", err);
149                process::exit(1);
150            }
151        }
152    }
153
154    fn id_to_token(&self, id: &i64) -> String {
155        match self._id_to_token(&id, &self.indices, &self.special_indices, &self.unknown_value) {
156            Ok(token) => token,
157            Err(err) => {
158                println!("{}", err);
159                process::exit(1);
160            }
161        }
162    }
163}
164
//==============================
// Unit tests
//==============================
#[cfg(test)]
mod tests {
    use super::*;
    use std::io;
    use std::io::Write;

    /// Vocabulary content shared by the tests; the padding around the tokens
    /// exercises the whitespace trimming done while reading the file.
    const VOCAB_CONTENT: &str = "hello \n world \n [UNK] \n !";

    /// Writes `content` to a new temporary file and returns its path.
    /// The file is deleted when the returned `TempPath` is dropped.
    fn write_vocab_file(content: &str) -> Result<tempfile::TempPath, io::Error> {
        let mut vocab_file = tempfile::NamedTempFile::new()?;
        write!(vocab_file, "{}", content)?;
        Ok(vocab_file.into_temp_path())
    }

    #[test]
    fn test_create_object() {
//        Given
        let values: HashMap<String, i64> = HashMap::new();
        let special_values: HashMap<String, i64> = HashMap::new();
        let indices: HashMap<i64, String> = HashMap::new();
        let special_indices: HashMap<i64, String> = HashMap::new();
        let unknown_value = BaseVocab::unknown_value();

//        When
        let base_vocab = BaseVocab {
            values,
            indices,
            unknown_value,
            special_values,
            special_indices,
        };

//        Then
        assert_eq!(base_vocab.unknown_value, "[UNK]");
        assert_eq!(base_vocab.unknown_value, BaseVocab::unknown_value());
        assert_eq!(base_vocab.values, *base_vocab.values());
        assert_eq!(base_vocab.indices, *base_vocab.indices());
        assert_eq!(base_vocab.special_values, *base_vocab.special_values());
        assert_eq!(base_vocab.special_indices, *base_vocab.special_indices());
    }

    #[test]
    fn test_create_object_from_file() -> Result<(), io::Error> {
//        Given
        let path = write_vocab_file(VOCAB_CONTENT)?;
        let target_values: HashMap<String, i64> = [
            ("hello".to_owned(), 0),
            ("world".to_owned(), 1),
            ("[UNK]".to_owned(), 2),
            ("!".to_owned(), 3)
        ].iter().cloned().collect();

        let special_values: HashMap<String, i64> = [
            ("[UNK]".to_owned(), 2)
        ].iter().cloned().collect();

//        When
        let base_vocab = BaseVocab::from_file(path.to_str().unwrap());

//        Then
        assert_eq!(base_vocab.unknown_value, "[UNK]");
        assert_eq!(base_vocab.values, target_values);
        assert_eq!(base_vocab.special_values, special_values);
        // The reverse tables must mirror the forward ones.
        assert_eq!(base_vocab.indices, swap_key_values(&target_values));
        assert_eq!(base_vocab.special_indices, swap_key_values(&special_values));
        Ok(())
    }

    #[test]
    #[should_panic]
    fn test_create_object_from_file_without_unknown_token() {
//        Given
        let path = write_vocab_file("hello \n world \n !").unwrap();

//        When & Then
        let _base_vocab = BaseVocab::from_file(path.to_str().unwrap());
    }

    #[test]
    fn test_encode_tokens() -> Result<(), io::Error> {
//        Given
        let path = write_vocab_file(VOCAB_CONTENT)?;
        let base_vocab = BaseVocab::from_file(path.to_str().unwrap());

//        When & Then
        assert_eq!(base_vocab.token_to_id("hello"), 0);
        assert_eq!(base_vocab.token_to_id("world"), 1);
        assert_eq!(base_vocab.token_to_id("!"), 3);
        assert_eq!(base_vocab.token_to_id("[UNK]"), 2);
        assert_eq!(base_vocab.token_to_id("oov_value"), 2);
        assert_eq!(
            base_vocab.convert_tokens_to_ids(vec!["hello", "oov_value", "!"]),
            vec![0, 2, 3]
        );
        Ok(())
    }

    #[test]
    fn test_decode_tokens() -> Result<(), io::Error> {
//        Given
        let path = write_vocab_file(VOCAB_CONTENT)?;
        let base_vocab = BaseVocab::from_file(path.to_str().unwrap());

//        When & Then
        assert_eq!(base_vocab.id_to_token(&0), "hello");
        assert_eq!(base_vocab.id_to_token(&1), "world");
        assert_eq!(base_vocab.id_to_token(&3), "!");
        assert_eq!(base_vocab.id_to_token(&2), "[UNK]");
        // Ids missing from the vocabulary decode to the unknown token.
        assert_eq!(base_vocab.id_to_token(&42), "[UNK]");
        Ok(())
    }
}