rust_tokenizers/vocab/
deberta_v2_vocab.rs

// Copyright 2020 Microsoft and the HuggingFace Inc. team.
// Copyright 2019-2020 Guillaume Becquin
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
12
13use crate::error::TokenizerError;
14use crate::vocab::base_vocab::{
15    read_protobuf_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap,
16};
17use crate::vocab::Vocab;
18use std::collections::HashMap;
19use std::path::Path;
20
/// # DeBERTaV2Vocab
/// Vocabulary for DeBERTa (v2) tokenizer. Contains the following special values:
/// - BOS token
/// - EOS token
/// - CLS token
/// - SEP token
/// - UNK token
/// - PAD token
/// - MASK token
///
/// Expects a SentencePiece protobuf file when created from file.
#[derive(Debug, Clone)]
pub struct DeBERTaV2Vocab {
    /// A mapping of tokens as string to indices (i.e. the encoder base)
    pub values: HashMap<String, i64>,

    /// A mapping of token ids to strings (i.e. the decoder base)
    pub indices: HashMap<i64, String>,

    /// Special tokens used by the vocabulary
    pub special_token_map: SpecialTokenMap,

    /// A mapping of special value tokens as strings to IDs (i.e. the encoder base for special
    /// values), special values typically include things like BOS/EOS markers, class markers, mask
    /// markers and padding markers
    pub special_values: HashMap<String, i64>,

    /// A mapping of special value tokens as IDs to strings (i.e. the decoder base for special values)
    pub special_indices: HashMap<i64, String>,
}
51
// Default special-token strings used when no special-token mapping file is provided.
// Note: DeBERTa (v2) reuses "[CLS]" as its BOS token and "[SEP]" as its EOS token,
// so the apparent duplication below is intentional.
const DEFAULT_UNK_TOKEN: &str = "[UNK]";
const DEFAULT_PAD_TOKEN: &str = "[PAD]";
const DEFAULT_BOS_TOKEN: &str = "[CLS]";
const DEFAULT_SEP_TOKEN: &str = "[SEP]";
const DEFAULT_CLS_TOKEN: &str = "[CLS]";
const DEFAULT_EOS_TOKEN: &str = "[SEP]";
const DEFAULT_MASK_TOKEN: &str = "[MASK]";
59
60impl DeBERTaV2Vocab {
61    pub fn get_pad_value(&self) -> &str {
62        self.special_token_map
63            .pad_token
64            .as_deref()
65            .unwrap_or(DEFAULT_PAD_TOKEN)
66    }
67
68    pub fn get_bos_value(&self) -> &str {
69        self.special_token_map
70            .bos_token
71            .as_deref()
72            .unwrap_or(DEFAULT_BOS_TOKEN)
73    }
74
75    pub fn get_sep_value(&self) -> &str {
76        self.special_token_map
77            .sep_token
78            .as_deref()
79            .unwrap_or(DEFAULT_SEP_TOKEN)
80    }
81
82    pub fn get_cls_value(&self) -> &str {
83        self.special_token_map
84            .cls_token
85            .as_deref()
86            .unwrap_or(DEFAULT_CLS_TOKEN)
87    }
88
89    pub fn get_eos_value(&self) -> &str {
90        self.special_token_map
91            .eos_token
92            .as_deref()
93            .unwrap_or(DEFAULT_EOS_TOKEN)
94    }
95
96    pub fn get_mask_value(&self) -> &str {
97        self.special_token_map
98            .mask_token
99            .as_deref()
100            .unwrap_or(DEFAULT_MASK_TOKEN)
101    }
102}
103
104impl Vocab for DeBERTaV2Vocab {
105    fn get_unknown_value(&self) -> &str {
106        &self.special_token_map.unk_token
107    }
108
109    fn values(&self) -> &HashMap<String, i64> {
110        &self.values
111    }
112
113    fn indices(&self) -> &HashMap<i64, String> {
114        &self.indices
115    }
116
117    fn special_values(&self) -> &HashMap<String, i64> {
118        &self.special_values
119    }
120
121    fn special_indices(&self) -> &HashMap<i64, String> {
122        &self.special_indices
123    }
124
125    fn values_mut(&mut self) -> &mut HashMap<String, i64> {
126        &mut self.values
127    }
128
129    fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
130        &mut self.indices
131    }
132
133    fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
134        &mut self.special_values
135    }
136
137    fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
138        &mut self.special_indices
139    }
140
141    fn from_file<P: AsRef<Path>>(path: P) -> Result<DeBERTaV2Vocab, TokenizerError> {
142        let mut values = read_protobuf_file(path)?;
143
144        let special_token_map = SpecialTokenMap {
145            unk_token: DEFAULT_UNK_TOKEN.to_string(),
146            pad_token: Some(DEFAULT_PAD_TOKEN.to_string()),
147            bos_token: Some(DEFAULT_BOS_TOKEN.to_string()),
148            sep_token: Some(DEFAULT_SEP_TOKEN.to_string()),
149            cls_token: Some(DEFAULT_CLS_TOKEN.to_string()),
150            eos_token: Some(DEFAULT_EOS_TOKEN.to_string()),
151            mask_token: Some(DEFAULT_MASK_TOKEN.to_string()),
152            additional_special_tokens: None,
153        };
154        if !values.contains_key(special_token_map.mask_token.as_ref().unwrap()) {
155            values.insert(
156                special_token_map.mask_token.as_ref().unwrap().clone(),
157                values.len() as i64,
158            );
159        }
160        Self::from_values_and_special_token_map(values, special_token_map)
161    }
162
163    fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
164        path: P,
165        special_token_mapping_path: S,
166    ) -> Result<Self, TokenizerError> {
167        let mut values = read_protobuf_file(path)?;
168        let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
169
170        if let Some(mask_token) = &special_token_map.mask_token {
171            values.insert(mask_token.clone(), values.len() as i64);
172        }
173        Self::from_values_and_special_token_map(values, special_token_map)
174    }
175
176    fn from_values_and_special_token_map(
177        values: HashMap<String, i64>,
178        special_token_map: SpecialTokenMap,
179    ) -> Result<Self, TokenizerError>
180    where
181        Self: Sized,
182    {
183        let mut special_values = HashMap::new();
184        special_token_map.register_special_values(&values, &mut special_values)?;
185
186        let indices = swap_key_values(&values);
187        let special_indices = swap_key_values(&special_values);
188        Ok(Self {
189            values,
190            indices,
191            special_token_map,
192            special_values,
193            special_indices,
194        })
195    }
196
197    fn token_to_id(&self, token: &str) -> i64 {
198        self._token_to_id(
199            token,
200            &self.values,
201            &self.special_values,
202            self.get_unknown_value(),
203        )
204    }
205
206    fn id_to_token(&self, id: &i64) -> String {
207        self._id_to_token(
208            id,
209            &self.indices,
210            &self.special_indices,
211            self.get_unknown_value(),
212        )
213    }
214}