// rust_tokenizers/vocab/deberta_v2_vocab.rs
use crate::error::TokenizerError;
14use crate::vocab::base_vocab::{
15 read_protobuf_file, read_special_token_mapping_file, swap_key_values, SpecialTokenMap,
16};
17use crate::vocab::Vocab;
18use std::collections::HashMap;
19use std::path::Path;
20
/// # DeBERTa-v2 vocabulary
/// Holds the token/id mappings for a DeBERTa-v2 tokenizer along with the
/// special-token bookkeeping required by the `Vocab` trait.
#[derive(Debug, Clone)]
pub struct DeBERTaV2Vocab {
    /// Mapping of tokens to ids (the encoder side of the vocabulary)
    pub values: HashMap<String, i64>,

    /// Mapping of ids to tokens (the decoder side of the vocabulary)
    pub indices: HashMap<i64, String>,

    /// Special tokens (unk, pad, bos, sep, cls, eos, mask, …) used by this vocabulary
    pub special_token_map: SpecialTokenMap,

    /// Mapping of special tokens to ids (encoder side, special tokens only)
    pub special_values: HashMap<String, i64>,

    /// Mapping of ids to special tokens (decoder side, special tokens only)
    pub special_indices: HashMap<i64, String>,
}
51
// Default special-token strings, used when no custom special-token mapping is
// supplied. Note that DeBERTa-v2 reuses "[CLS]" as its BOS token and "[SEP]"
// as its EOS token — the duplicated values below are intentional.
const DEFAULT_UNK_TOKEN: &str = "[UNK]";
const DEFAULT_PAD_TOKEN: &str = "[PAD]";
const DEFAULT_BOS_TOKEN: &str = "[CLS]";
const DEFAULT_SEP_TOKEN: &str = "[SEP]";
const DEFAULT_CLS_TOKEN: &str = "[CLS]";
const DEFAULT_EOS_TOKEN: &str = "[SEP]";
const DEFAULT_MASK_TOKEN: &str = "[MASK]";
59
60impl DeBERTaV2Vocab {
61 pub fn get_pad_value(&self) -> &str {
62 self.special_token_map
63 .pad_token
64 .as_deref()
65 .unwrap_or(DEFAULT_PAD_TOKEN)
66 }
67
68 pub fn get_bos_value(&self) -> &str {
69 self.special_token_map
70 .bos_token
71 .as_deref()
72 .unwrap_or(DEFAULT_BOS_TOKEN)
73 }
74
75 pub fn get_sep_value(&self) -> &str {
76 self.special_token_map
77 .sep_token
78 .as_deref()
79 .unwrap_or(DEFAULT_SEP_TOKEN)
80 }
81
82 pub fn get_cls_value(&self) -> &str {
83 self.special_token_map
84 .cls_token
85 .as_deref()
86 .unwrap_or(DEFAULT_CLS_TOKEN)
87 }
88
89 pub fn get_eos_value(&self) -> &str {
90 self.special_token_map
91 .eos_token
92 .as_deref()
93 .unwrap_or(DEFAULT_EOS_TOKEN)
94 }
95
96 pub fn get_mask_value(&self) -> &str {
97 self.special_token_map
98 .mask_token
99 .as_deref()
100 .unwrap_or(DEFAULT_MASK_TOKEN)
101 }
102}
103
104impl Vocab for DeBERTaV2Vocab {
105 fn get_unknown_value(&self) -> &str {
106 &self.special_token_map.unk_token
107 }
108
109 fn values(&self) -> &HashMap<String, i64> {
110 &self.values
111 }
112
113 fn indices(&self) -> &HashMap<i64, String> {
114 &self.indices
115 }
116
117 fn special_values(&self) -> &HashMap<String, i64> {
118 &self.special_values
119 }
120
121 fn special_indices(&self) -> &HashMap<i64, String> {
122 &self.special_indices
123 }
124
125 fn values_mut(&mut self) -> &mut HashMap<String, i64> {
126 &mut self.values
127 }
128
129 fn indices_mut(&mut self) -> &mut HashMap<i64, String> {
130 &mut self.indices
131 }
132
133 fn special_values_mut(&mut self) -> &mut HashMap<String, i64> {
134 &mut self.special_values
135 }
136
137 fn special_indices_mut(&mut self) -> &mut HashMap<i64, String> {
138 &mut self.special_indices
139 }
140
141 fn from_file<P: AsRef<Path>>(path: P) -> Result<DeBERTaV2Vocab, TokenizerError> {
142 let mut values = read_protobuf_file(path)?;
143
144 let special_token_map = SpecialTokenMap {
145 unk_token: DEFAULT_UNK_TOKEN.to_string(),
146 pad_token: Some(DEFAULT_PAD_TOKEN.to_string()),
147 bos_token: Some(DEFAULT_BOS_TOKEN.to_string()),
148 sep_token: Some(DEFAULT_SEP_TOKEN.to_string()),
149 cls_token: Some(DEFAULT_CLS_TOKEN.to_string()),
150 eos_token: Some(DEFAULT_EOS_TOKEN.to_string()),
151 mask_token: Some(DEFAULT_MASK_TOKEN.to_string()),
152 additional_special_tokens: None,
153 };
154 if !values.contains_key(special_token_map.mask_token.as_ref().unwrap()) {
155 values.insert(
156 special_token_map.mask_token.as_ref().unwrap().clone(),
157 values.len() as i64,
158 );
159 }
160 Self::from_values_and_special_token_map(values, special_token_map)
161 }
162
163 fn from_file_with_special_token_mapping<P: AsRef<Path>, S: AsRef<Path>>(
164 path: P,
165 special_token_mapping_path: S,
166 ) -> Result<Self, TokenizerError> {
167 let mut values = read_protobuf_file(path)?;
168 let special_token_map = read_special_token_mapping_file(special_token_mapping_path)?;
169
170 if let Some(mask_token) = &special_token_map.mask_token {
171 values.insert(mask_token.clone(), values.len() as i64);
172 }
173 Self::from_values_and_special_token_map(values, special_token_map)
174 }
175
176 fn from_values_and_special_token_map(
177 values: HashMap<String, i64>,
178 special_token_map: SpecialTokenMap,
179 ) -> Result<Self, TokenizerError>
180 where
181 Self: Sized,
182 {
183 let mut special_values = HashMap::new();
184 special_token_map.register_special_values(&values, &mut special_values)?;
185
186 let indices = swap_key_values(&values);
187 let special_indices = swap_key_values(&special_values);
188 Ok(Self {
189 values,
190 indices,
191 special_token_map,
192 special_values,
193 special_indices,
194 })
195 }
196
197 fn token_to_id(&self, token: &str) -> i64 {
198 self._token_to_id(
199 token,
200 &self.values,
201 &self.special_values,
202 self.get_unknown_value(),
203 )
204 }
205
206 fn id_to_token(&self, id: &i64) -> String {
207 self._id_to_token(
208 id,
209 &self.indices,
210 &self.special_indices,
211 self.get_unknown_value(),
212 )
213 }
214}