lean_ctx/core/embeddings/
tokenizer.rs1use std::collections::HashMap;
13use std::path::Path;
14
/// Greedy WordPiece tokenizer backed by an in-memory vocabulary.
///
/// Derives `Debug` and `Clone` so the tokenizer can be logged, embedded in
/// other `Debug` types, and duplicated without re-reading the vocabulary.
#[derive(Debug, Clone)]
pub struct WordPieceTokenizer {
    /// Token text -> vocabulary id.
    vocab: HashMap<String, i32>,
    /// Id of the `[CLS]` token, placed first in every encoded sequence.
    cls_id: i32,
    /// Id of the `[SEP]` token, placed last in every encoded sequence.
    sep_id: i32,
    /// Id of the `[PAD]` token, exposed through `pad_id()`.
    pad_id: i32,
    /// Id of the `[UNK]` token, emitted where no vocabulary piece matches.
    unk_id: i32,
    /// Words with more characters than this are encoded as a single `[UNK]`.
    max_word_chars: usize,
}
23
/// One encoded sequence, stored as three parallel vectors.
#[derive(Debug, Clone)]
pub struct TokenizedInput {
    /// Vocabulary ids of the tokens.
    pub input_ids: Vec<i32>,
    /// Per-position mask; `pad_to` appends `0` for every padding slot it adds.
    pub attention_mask: Vec<i32>,
    /// Per-position segment ids; `pad_to` appends `0` for every padding slot.
    pub token_type_ids: Vec<i32>,
}

impl TokenizedInput {
    /// Grows the sequence to `target_len` entries by appending padding.
    ///
    /// `input_ids` receives `pad_id`; `attention_mask` and `token_type_ids`
    /// receive `0`. The number of entries appended is derived from the length
    /// of `input_ids` alone and the same count is appended to all three
    /// vectors. A sequence that is already `target_len` or longer is left
    /// untouched (it is never truncated).
    pub fn pad_to(&mut self, target_len: usize, pad_id: i32) {
        let missing = target_len.saturating_sub(self.input_ids.len());
        self.input_ids
            .extend(std::iter::repeat(pad_id).take(missing));
        self.attention_mask
            .extend(std::iter::repeat(0).take(missing));
        self.token_type_ids
            .extend(std::iter::repeat(0).take(missing));
    }
}
41
42impl WordPieceTokenizer {
43 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
45 let content = std::fs::read_to_string(path)
46 .map_err(|e| anyhow::anyhow!("Failed to read vocab file {}: {}", path.display(), e))?;
47 Self::from_vocab_str(&content)
48 }
49
50 pub fn from_vocab_str(vocab_str: &str) -> anyhow::Result<Self> {
52 let vocab: HashMap<String, i32> = vocab_str
53 .lines()
54 .enumerate()
55 .map(|(i, line)| (line.to_string(), i as i32))
56 .collect();
57
58 let cls_id = *vocab
59 .get("[CLS]")
60 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [CLS] token"))?;
61 let sep_id = *vocab
62 .get("[SEP]")
63 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [SEP] token"))?;
64 let pad_id = *vocab
65 .get("[PAD]")
66 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [PAD] token"))?;
67 let unk_id = *vocab
68 .get("[UNK]")
69 .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [UNK] token"))?;
70
71 Ok(Self {
72 vocab,
73 cls_id,
74 sep_id,
75 pad_id,
76 unk_id,
77 max_word_chars: 200,
78 })
79 }
80
81 pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
83 let words = self.pre_tokenize(text);
84 let mut ids = vec![self.cls_id];
85
86 for word in &words {
87 if ids.len() >= max_len - 1 {
88 break;
89 }
90 let subword_ids = self.wordpiece_encode(word);
91 for id in subword_ids {
92 if ids.len() >= max_len - 1 {
93 break;
94 }
95 ids.push(id);
96 }
97 }
98
99 ids.push(self.sep_id);
100
101 let len = ids.len();
102 TokenizedInput {
103 input_ids: ids,
104 attention_mask: vec![1; len],
105 token_type_ids: vec![0; len],
106 }
107 }
108
109 pub fn pad_id(&self) -> i32 {
110 self.pad_id
111 }
112
113 pub fn vocab_size(&self) -> usize {
114 self.vocab.len()
115 }
116
117 fn pre_tokenize(&self, text: &str) -> Vec<String> {
121 let mut words = Vec::new();
122 let mut current = String::new();
123
124 for ch in text.chars() {
125 if ch.is_whitespace() {
126 if !current.is_empty() {
127 words.extend(self.split_identifier(¤t));
128 current.clear();
129 }
130 } else if is_bert_punctuation(ch) {
131 if !current.is_empty() {
132 words.extend(self.split_identifier(¤t));
133 current.clear();
134 }
135 words.push(ch.to_string());
136 } else {
137 current.push(ch);
138 }
139 }
140 if !current.is_empty() {
141 words.extend(self.split_identifier(¤t));
142 }
143
144 words.iter().map(|w| w.to_lowercase()).collect()
145 }
146
147 fn split_identifier(&self, word: &str) -> Vec<String> {
150 let lower = word.to_lowercase();
151 if self.vocab.contains_key(&lower) {
152 return vec![word.to_string()];
153 }
154
155 let mut parts = Vec::new();
156 let mut current = String::new();
157 let chars: Vec<char> = word.chars().collect();
158
159 for (i, &ch) in chars.iter().enumerate() {
160 if ch == '_' || ch == '-' {
161 if !current.is_empty() {
162 parts.push(current.clone());
163 current.clear();
164 }
165 } else if i > 0 && ch.is_ascii_uppercase() && chars[i - 1].is_ascii_lowercase() {
166 if !current.is_empty() {
167 parts.push(current.clone());
168 current.clear();
169 }
170 current.push(ch);
171 } else {
172 current.push(ch);
173 }
174 }
175 if !current.is_empty() {
176 parts.push(current);
177 }
178
179 if parts.is_empty() {
180 vec![word.to_string()]
181 } else {
182 parts
183 }
184 }
185
186 fn wordpiece_encode(&self, word: &str) -> Vec<i32> {
188 if word.chars().count() > self.max_word_chars {
189 return vec![self.unk_id];
190 }
191
192 let chars: Vec<char> = word.chars().collect();
193 let mut tokens = Vec::new();
194 let mut start = 0;
195
196 while start < chars.len() {
197 let mut end = chars.len();
198 let mut matched = false;
199
200 while start < end {
201 let substr: String = chars[start..end].iter().collect();
202 let candidate = if start > 0 {
203 format!("##{substr}")
204 } else {
205 substr
206 };
207
208 if let Some(&id) = self.vocab.get(&candidate) {
209 tokens.push(id);
210 matched = true;
211 start = end;
212 break;
213 }
214 end -= 1;
215 }
216
217 if !matched {
218 tokens.push(self.unk_id);
219 start += 1;
220 }
221 }
222
223 tokens
224 }
225}
226
/// Returns `true` if `ch` is one of the 32 ASCII punctuation characters
/// (`!"#$%&'()*+,-./:;<=>?@[\]^_` + "`" + `{|}~`).
///
/// Non-ASCII characters are never reported as punctuation, so Unicode
/// punctuation such as `—` or `。` yields `false`.
fn is_bert_punctuation(ch: char) -> bool {
    ch.is_ascii_punctuation()
}
268
269pub struct HfTokenizerWrapper {
275 inner: WordPieceTokenizer,
276}
277
278impl HfTokenizerWrapper {
279 pub fn from_file(path: &Path) -> anyhow::Result<Self> {
281 let content = std::fs::read_to_string(path).map_err(|e| {
282 anyhow::anyhow!("Failed to read tokenizer.json {}: {}", path.display(), e)
283 })?;
284 Self::from_json(&content)
285 }
286
287 fn from_json(json_str: &str) -> anyhow::Result<Self> {
288 let parsed: serde_json::Value = serde_json::from_str(json_str)
289 .map_err(|e| anyhow::anyhow!("Invalid tokenizer.json: {e}"))?;
290
291 let vocab_obj = parsed
292 .get("model")
293 .and_then(|m| m.get("vocab"))
294 .and_then(|v| v.as_object())
295 .ok_or_else(|| anyhow::anyhow!("tokenizer.json missing model.vocab object"))?;
296
297 let mut vocab_lines: Vec<(String, i32)> = vocab_obj
298 .iter()
299 .filter_map(|(token, id)| id.as_i64().map(|id| (token.clone(), id as i32)))
300 .collect();
301 vocab_lines.sort_by_key(|(_, id)| *id);
302
303 let vocab_str: String = vocab_lines
304 .into_iter()
305 .map(|(token, _)| token)
306 .collect::<Vec<_>>()
307 .join("\n");
308
309 let inner = WordPieceTokenizer::from_vocab_str(&vocab_str)?;
310 Ok(Self { inner })
311 }
312
313 pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
314 self.inner.encode(text, max_len)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Tokenizer over a small fixed vocabulary (24 entries, id = line index).
    fn test_vocab() -> WordPieceTokenizer {
        const VOCAB: &str = "[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\nfn\nvalidate\ntoken\n##s\n##ing\nauth\n##enticate\nuser\nhandle\nrequest\n##er\nprocess\ndata\n.\n,\n(\n)\n{";
        WordPieceTokenizer::from_vocab_str(VOCAB).unwrap()
    }

    #[test]
    fn encode_basic() {
        let tokenizer = test_vocab();
        let encoded = tokenizer.encode("hello world", 512);
        // [CLS] hello world [SEP]
        assert_eq!(encoded.input_ids.first(), Some(&tokenizer.cls_id));
        assert_eq!(encoded.input_ids.last(), Some(&tokenizer.sep_id));
        assert!(encoded.input_ids.len() >= 4);
    }

    #[test]
    fn encode_attention_mask() {
        let tokenizer = test_vocab();
        let encoded = tokenizer.encode("hello", 512);
        assert!(encoded.attention_mask.iter().all(|&m| m == 1));
        assert_eq!(encoded.attention_mask.len(), encoded.input_ids.len());
    }

    #[test]
    fn encode_token_type_ids_are_zero() {
        let tokenizer = test_vocab();
        let encoded = tokenizer.encode("hello", 512);
        assert!(encoded.token_type_ids.iter().all(|&t| t == 0));
    }

    #[test]
    fn encode_respects_max_len() {
        let tokenizer = test_vocab();
        let encoded = tokenizer.encode("hello world hello world hello world", 6);
        assert!(encoded.input_ids.len() <= 6);
        assert_eq!(encoded.input_ids.first(), Some(&tokenizer.cls_id));
        assert_eq!(encoded.input_ids.last(), Some(&tokenizer.sep_id));
    }

    #[test]
    fn wordpiece_subwords() {
        let tokenizer = test_vocab();
        let pieces = tokenizer.wordpiece_encode("tokens");
        // "tokens" -> "token" + "##s"
        assert_eq!(pieces.len(), 2);
        assert_eq!(pieces[0], tokenizer.vocab["token"]);
        assert_eq!(pieces[1], tokenizer.vocab["##s"]);
    }

    #[test]
    fn wordpiece_unknown() {
        let tokenizer = test_vocab();
        let pieces = tokenizer.wordpiece_encode("xyzzyplugh");
        assert!(pieces.contains(&tokenizer.unk_id));
    }

    #[test]
    fn pre_tokenize_camel_case() {
        let tokenizer = test_vocab();
        let words = tokenizer.pre_tokenize("handleRequest");
        assert!(words.iter().any(|w| w == "handle"));
        assert!(words.iter().any(|w| w == "request"));
    }

    #[test]
    fn pre_tokenize_snake_case() {
        let tokenizer = test_vocab();
        let words = tokenizer.pre_tokenize("validate_token");
        assert!(words.iter().any(|w| w == "validate"));
        assert!(words.iter().any(|w| w == "token"));
    }

    #[test]
    fn pre_tokenize_punctuation() {
        let tokenizer = test_vocab();
        let words = tokenizer.pre_tokenize("fn(x)");
        assert!(words.iter().any(|w| w == "fn"));
        assert!(words.iter().any(|w| w == "("));
        assert!(words.iter().any(|w| w == ")"));
    }

    #[test]
    fn pad_to_extends() {
        let tokenizer = test_vocab();
        let mut encoded = tokenizer.encode("hello", 512);
        let unpadded_len = encoded.input_ids.len();
        encoded.pad_to(10, tokenizer.pad_id);
        assert_eq!(encoded.input_ids.len(), 10);
        assert_eq!(encoded.attention_mask[unpadded_len], 0);
    }

    #[test]
    fn vocab_size() {
        let tokenizer = test_vocab();
        assert_eq!(tokenizer.vocab_size(), 24);
    }

    #[test]
    fn empty_input() {
        let tokenizer = test_vocab();
        let encoded = tokenizer.encode("", 512);
        // Only [CLS] and [SEP].
        assert_eq!(encoded.input_ids.len(), 2);
    }

    #[test]
    fn bert_punctuation_detection() {
        assert!(is_bert_punctuation('.'));
        assert!(is_bert_punctuation('('));
        assert!(is_bert_punctuation('{'));
        assert!(!is_bert_punctuation('a'));
        assert!(!is_bert_punctuation('0'));
    }

    #[test]
    fn hf_tokenizer_from_json() {
        let json = r#"{
            "version": "1.0",
            "model": {
                "type": "WordPiece",
                "vocab": {
                    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
                    "hello": 4, "world": 5, "fn": 6
                }
            }
        }"#;
        let tokenizer = HfTokenizerWrapper::from_json(json).unwrap();
        let encoded = tokenizer.encode("hello world", 512);
        // [CLS] = 2 first, [SEP] = 3 last.
        assert_eq!(encoded.input_ids.first(), Some(&2));
        assert_eq!(encoded.input_ids.last(), Some(&3));
        assert!(encoded.input_ids.len() >= 4);
    }

    #[test]
    fn hf_tokenizer_invalid_json() {
        let result = HfTokenizerWrapper::from_json("not json");
        assert!(result.is_err());
    }

    #[test]
    fn hf_tokenizer_missing_vocab() {
        let json = r#"{"model": {"type": "WordPiece"}}"#;
        let result = HfTokenizerWrapper::from_json(json);
        assert!(result.is_err());
    }
}