// rust_tokenizers/vocab/t5_vocab.rs

1// Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
2// Copyright 2019-2020 Guillaume Becquin
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//     http://www.apache.org/licenses/LICENSE-2.0
7// Unless required by applicable law or agreed to in writing, software
8// distributed under the License is distributed on an "AS IS" BASIS,
9// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
10// See the License for the specific language governing permissions and
11// limitations under the License.
12
13use crate::error::TokenizerError;
14use crate::vocab::base_vocab::{
15    read_protobuf_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap,
16};
17use crate::vocab::Vocab;
18use std::collections::HashMap;
19use std::path::Path;
20
/// # T5 Vocab
/// Vocabulary for T5 tokenizer. Contains the following special values:
/// - PAD token
/// - EOS token
///
/// Expects a SentencePiece protobuf file when created from file.
#[derive(Debug, Clone)]
pub struct T5Vocab {
    /// A mapping of tokens as string to indices (i.e. the encoder base)
    pub values: HashMap<String, i64>,

    /// A mapping of token ids to strings (i.e. the decoder base)
    pub indices: HashMap<i64, String>,

    /// Special tokens used by the vocabulary. The UNK token is always present;
    /// PAD and EOS are optional and fall back to library defaults when absent.
    pub special_token_map: SpecialTokenMap,

    /// A mapping of special value tokens as strings to IDs (i.e. the encoder base for special
    /// values), special values typically include things like BOS/EOS markers, class markers, mask
    /// markers and padding markers
    pub special_values: HashMap<String, i64>,

    /// A mapping of special value tokens as IDs to strings (i.e. the decoder base for special values)
    pub special_indices: HashMap<i64, String>,
}
46
// Fallback special-token strings, applied when the caller does not supply a
// custom special-token mapping (see `Vocab::from_file` below).
const DEFAULT_UNK_TOKEN: &str = "<unk>";
const DEFAULT_PAD_TOKEN: &str = "<pad>";
const DEFAULT_EOS_TOKEN: &str = "</s>";
50
51impl T5Vocab {
52    pub fn get_pad_value(&self) -> &str {
53        self.special_token_map
54            .pad_token
55            .as_deref()
56            .unwrap_or(DEFAULT_PAD_TOKEN)
57    }
58
59    pub fn get_eos_value(&self) -> &str {
60        self.special_token_map
61            .eos_token
62            .as_deref()
63            .unwrap_or(DEFAULT_EOS_TOKEN)
64    }
65}
66
67impl Vocab for T5Vocab {
68    fn get_unknown_value(&self) -> &str {
69        &self.special_token_map.unk_token
70    }
71
72    fn values(&self) -> &HashMap<String, i64> {
73        &self.values
74    }
75
76    fn indices(&self) -> &HashMap<i64, String> {
77        &self.indices
78    }
79
80    fn special_values(&self) -> &HashMap<String, i64> {
81        &self.special_values
82    }
83
84    fn special_indices(&self) -> &HashMap<i64, String> {
85        &self.special_indices
86    }
87
88    fn values_mut(&mut self) -> &mut HashMap<String, i64> {
89        &mut self.values
90    }
91
92    fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
93        &mut self.indices
94    }
95
96    fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
97        &mut self.special_values
98    }
99
100    fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
101        &mut self.special_indices
102    }
103
104    fn from_file<P: AsRef<Path>>(path: P) -> Result<T5Vocab, TokenizerError> {
105        let values = read_protobuf_file(path)?;
106
107        let special_token_map = SpecialTokenMap {
108            unk_token: DEFAULT_UNK_TOKEN.to_string(),
109            pad_token: Some(DEFAULT_PAD_TOKEN.to_string()),
110            bos_token: None,
111            sep_token: None,
112            cls_token: None,
113            eos_token: Some(DEFAULT_EOS_TOKEN.to_string()),
114            mask_token: None,
115            additional_special_tokens: None,
116        };
117        Self::from_values_and_special_token_map(values, special_token_map)
118    }
119
120    fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
121        path: P,
122        special_token_mapping_path: S,
123    ) -> Result<Self, TokenizerError> {
124        let values = read_protobuf_file(path)?;
125        let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
126        Self::from_values_and_special_token_map(values, special_token_map)
127    }
128
129    fn from_values_and_special_token_map(
130        values: HashMap<String, i64>,
131        special_token_map: SpecialTokenMap,
132    ) -> Result<Self, TokenizerError>
133    where
134        Self: Sized,
135    {
136        let mut special_values = HashMap::new();
137        special_token_map.register_special_values(&values, &mut special_values)?;
138
139        let indices = swap_key_values(&values);
140        let special_indices = swap_key_values(&special_values);
141        Ok(Self {
142            values,
143            indices,
144            special_token_map,
145            special_values,
146            special_indices,
147        })
148    }
149
150    fn token_to_id(&self, token: &str) -> i64 {
151        self._token_to_id(
152            token,
153            &self.values,
154            &self.special_values,
155            self.get_unknown_value(),
156        )
157    }
158
159    fn id_to_token(&self, id: &i64) -> String {
160        self._id_to_token(
161            id,
162            &self.indices,
163            &self.special_indices,
164            self.get_unknown_value(),
165        )
166    }
167}