lean_ctx/core/embeddings/
tokenizer.rs1use std::collections::HashMap;
13use std::path::Path;
14
/// Greedy longest-match-first WordPiece tokenizer backed by a plain
/// line-per-token vocabulary (BERT `vocab.txt` style).
pub struct WordPieceTokenizer {
    // Token string -> vocabulary id (0-based line index in the vocab file).
    vocab: HashMap<String, i32>,
    // Id of the `[CLS]` sequence-start token.
    cls_id: i32,
    // Id of the `[SEP]` sequence-end token.
    sep_id: i32,
    // Id of the `[PAD]` padding token.
    pad_id: i32,
    // Id of the `[UNK]` unknown-token fallback.
    unk_id: i32,
    // Words longer than this many chars are mapped straight to `[UNK]`.
    max_word_chars: usize,
}
23
/// Model-ready encoding of one text: three parallel vectors of equal length.
#[derive(Debug, Clone)]
pub struct TokenizedInput {
    // Vocabulary ids, starting with `[CLS]` and ending with `[SEP]`.
    pub input_ids: Vec<i32>,
    // 1 for real tokens, 0 for padding appended by `pad_to`.
    pub attention_mask: Vec<i32>,
    // All zeros as produced by `encode` (single-segment input).
    pub token_type_ids: Vec<i32>,
}
30
31impl TokenizedInput {
32 pub fn pad_to(&mut self, target_len: usize, pad_id: i32) {
34 while self.input_ids.len() < target_len {
35 self.input_ids.push(pad_id);
36 self.attention_mask.push(0);
37 self.token_type_ids.push(0);
38 }
39 }
40}
41
42impl WordPieceTokenizer {
43 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
45 let content = std::fs::read_to_string(path)
46 .map_err(|e| anyhow::anyhow!("Failed to read vocab file {}: {}", path.display(), e))?;
47 Self::from_vocab_str(&content)
48 }
49
50 pub fn from_vocab_str(vocab_str: &str) -> anyhow::Result<Self> {
52 let vocab: HashMap<String, i32> = vocab_str
53 .lines()
54 .enumerate()
55 .map(|(i, line)| (line.to_string(), i as i32))
56 .collect();
57
58 let cls_id = *vocab
59 .get("[CLS]")
60 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [CLS] token"))?;
61 let sep_id = *vocab
62 .get("[SEP]")
63 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [SEP] token"))?;
64 let pad_id = *vocab
65 .get("[PAD]")
66 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [PAD] token"))?;
67 let unk_id = *vocab
68 .get("[UNK]")
69 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [UNK] token"))?;
70
71 Ok(Self {
72 vocab,
73 cls_id,
74 sep_id,
75 pad_id,
76 unk_id,
77 max_word_chars: 200,
78 })
79 }
80
81 pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
83 let words = self.pre_tokenize(text);
84 let mut ids = vec![self.cls_id];
85
86 for word in &words {
87 if ids.len() >= max_len - 1 {
88 break;
89 }
90 let subword_ids = self.wordpiece_encode(word);
91 for id in subword_ids {
92 if ids.len() >= max_len - 1 {
93 break;
94 }
95 ids.push(id);
96 }
97 }
98
99 ids.push(self.sep_id);
100
101 let len = ids.len();
102 TokenizedInput {
103 input_ids: ids,
104 attention_mask: vec![1; len],
105 token_type_ids: vec![0; len],
106 }
107 }
108
109 pub fn pad_id(&self) -> i32 {
110 self.pad_id
111 }
112
113 pub fn vocab_size(&self) -> usize {
114 self.vocab.len()
115 }
116
117 fn pre_tokenize(&self, text: &str) -> Vec<String> {
121 let mut words = Vec::new();
122 let mut current = String::new();
123
124 for ch in text.chars() {
125 if ch.is_whitespace() {
126 if !current.is_empty() {
127 words.extend(self.split_identifier(¤t));
128 current.clear();
129 }
130 } else if is_bert_punctuation(ch) {
131 if !current.is_empty() {
132 words.extend(self.split_identifier(¤t));
133 current.clear();
134 }
135 words.push(ch.to_string());
136 } else {
137 current.push(ch);
138 }
139 }
140 if !current.is_empty() {
141 words.extend(self.split_identifier(¤t));
142 }
143
144 words.iter().map(|w| w.to_lowercase()).collect()
145 }
146
147 fn split_identifier(&self, word: &str) -> Vec<String> {
150 let lower = word.to_lowercase();
151 if self.vocab.contains_key(&lower) {
152 return vec![word.to_string()];
153 }
154
155 let mut parts = Vec::new();
156 let mut current = String::new();
157 let chars: Vec<char> = word.chars().collect();
158
159 for (i, &ch) in chars.iter().enumerate() {
160 if ch == '_' || ch == '-' {
161 if !current.is_empty() {
162 parts.push(current.clone());
163 current.clear();
164 }
165 } else if i > 0 && ch.is_ascii_uppercase() && chars[i - 1].is_ascii_lowercase() {
166 if !current.is_empty() {
167 parts.push(current.clone());
168 current.clear();
169 }
170 current.push(ch);
171 } else {
172 current.push(ch);
173 }
174 }
175 if !current.is_empty() {
176 parts.push(current);
177 }
178
179 if parts.is_empty() {
180 vec![word.to_string()]
181 } else {
182 parts
183 }
184 }
185
186 fn wordpiece_encode(&self, word: &str) -> Vec<i32> {
188 if word.chars().count() > self.max_word_chars {
189 return vec![self.unk_id];
190 }
191
192 let chars: Vec<char> = word.chars().collect();
193 let mut tokens = Vec::new();
194 let mut start = 0;
195
196 while start < chars.len() {
197 let mut end = chars.len();
198 let mut matched = false;
199
200 while start < end {
201 let substr: String = chars[start..end].iter().collect();
202 let candidate = if start > 0 {
203 format!("##{substr}")
204 } else {
205 substr
206 };
207
208 if let Some(&id) = self.vocab.get(&candidate) {
209 tokens.push(id);
210 matched = true;
211 start = end;
212 break;
213 }
214 end -= 1;
215 }
216
217 if !matched {
218 tokens.push(self.unk_id);
219 start += 1;
220 }
221 }
222
223 tokens
224 }
225}
226
/// Returns true for the characters the tokenizer treats as punctuation.
///
/// The previous version branched on `ch.is_ascii()`, but the ASCII arm
/// enumerated exactly the 32 ASCII punctuation characters and the
/// non-ASCII arm called `is_ascii_punctuation()` (always false for a
/// non-ASCII char), so the whole function reduced to this single call.
/// NOTE(review): non-ASCII Unicode punctuation is still NOT detected —
/// this matches the previous behavior exactly.
fn is_bert_punctuation(ch: char) -> bool {
    ch.is_ascii_punctuation()
}
268
#[cfg(test)]
mod tests {
    use super::*;

    // Tiny 24-entry vocabulary (4 specials + code-ish tokens and
    // punctuation) shared by every test below.
    fn test_vocab() -> WordPieceTokenizer {
        let vocab = "[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\nfn\nvalidate\ntoken\n##s\n##ing\nauth\n##enticate\nuser\nhandle\nrequest\n##er\nprocess\ndata\n.\n,\n(\n)\n{";
        WordPieceTokenizer::from_vocab_str(vocab).unwrap()
    }

    // Output is wrapped as [CLS] … [SEP] with real tokens in between.
    #[test]
    fn encode_basic() {
        let tok = test_vocab();
        let input = tok.encode("hello world", 512);
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
        // [CLS] hello world [SEP]
        assert!(input.input_ids.len() >= 4);
    }

    // No padding yet, so the mask is all ones and parallel to the ids.
    #[test]
    fn encode_attention_mask() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        assert!(input.attention_mask.iter().all(|&m| m == 1));
        assert_eq!(input.attention_mask.len(), input.input_ids.len());
    }

    #[test]
    fn encode_token_type_ids_are_zero() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        assert!(input.token_type_ids.iter().all(|&t| t == 0));
    }

    // Truncation keeps [CLS] first and [SEP] last within the budget.
    #[test]
    fn encode_respects_max_len() {
        let tok = test_vocab();
        let input = tok.encode("hello world hello world hello world", 6);
        assert!(input.input_ids.len() <= 6);
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
    }

    // "tokens" = "token" + continuation piece "##s".
    #[test]
    fn wordpiece_subwords() {
        let tok = test_vocab();
        let ids = tok.wordpiece_encode("tokens");
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], *tok.vocab.get("token").unwrap());
        assert_eq!(ids[1], *tok.vocab.get("##s").unwrap());
    }

    #[test]
    fn wordpiece_unknown() {
        let tok = test_vocab();
        let ids = tok.wordpiece_encode("xyzzyplugh");
        assert!(ids.contains(&tok.unk_id));
    }

    #[test]
    fn pre_tokenize_camel_case() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("handleRequest");
        assert!(words.contains(&"handle".to_string()));
        assert!(words.contains(&"request".to_string()));
    }

    #[test]
    fn pre_tokenize_snake_case() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("validate_token");
        assert!(words.contains(&"validate".to_string()));
        assert!(words.contains(&"token".to_string()));
    }

    // Punctuation is isolated into standalone words.
    #[test]
    fn pre_tokenize_punctuation() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("fn(x)");
        assert!(words.contains(&"fn".to_string()));
        assert!(words.contains(&"(".to_string()));
        assert!(words.contains(&")".to_string()));
    }

    // Padding extends ids with pad_id and the mask with zeros.
    #[test]
    fn pad_to_extends() {
        let tok = test_vocab();
        let mut input = tok.encode("hello", 512);
        let original_len = input.input_ids.len();
        input.pad_to(10, tok.pad_id);
        assert_eq!(input.input_ids.len(), 10);
        assert_eq!(input.attention_mask[original_len], 0);
    }

    #[test]
    fn vocab_size() {
        let tok = test_vocab();
        assert_eq!(tok.vocab_size(), 24);
    }

    // Empty text still yields the [CLS]/[SEP] wrapper pair.
    #[test]
    fn empty_input() {
        let tok = test_vocab();
        let input = tok.encode("", 512);
        assert_eq!(input.input_ids.len(), 2);
    }

    #[test]
    fn bert_punctuation_detection() {
        assert!(is_bert_punctuation('.'));
        assert!(is_bert_punctuation('('));
        assert!(is_bert_punctuation('{'));
        assert!(!is_bert_punctuation('a'));
        assert!(!is_bert_punctuation('0'));
    }
}