Skip to main content

lean_ctx/core/embeddings/
tokenizer.rs

1//! Minimal WordPiece tokenizer for BERT-style embedding models.
2//!
3//! Implements the standard BERT tokenization pipeline:
//! 1. Lowercasing (accent stripping is not implemented yet)
5//! 2. Whitespace + punctuation splitting
6//! 3. WordPiece subword tokenization
7//! 4. Special token insertion (`[CLS]`, `[SEP]`)
8//!
9//! Optimized for code search: handles camelCase, snake_case, and common
10//! programming punctuation correctly.
11
12use std::collections::HashMap;
13use std::path::Path;
14
/// WordPiece tokenizer backed by a token → id vocabulary.
///
/// The ids of the four special tokens are looked up once at construction
/// time and cached next to the vocabulary.
#[derive(Debug, Clone)]
pub struct WordPieceTokenizer {
    /// Maps every vocabulary entry (including `##`-prefixed continuation
    /// pieces) to its id.
    vocab: HashMap<String, i32>,
    /// Id of `[CLS]`, placed at the start of every encoded sequence.
    cls_id: i32,
    /// Id of `[SEP]`, placed at the end of every encoded sequence.
    sep_id: i32,
    /// Id of `[PAD]`, exposed through `pad_id()` for padding batches.
    pad_id: i32,
    /// Id of `[UNK]`, emitted for text that matches no vocabulary entry.
    unk_id: i32,
    /// Words with more characters than this are encoded as a single `[UNK]`.
    max_word_chars: usize,
}
23
/// Model-ready encoding of one piece of text.
///
/// The three vectors are parallel: index `i` of each describes the same
/// token position.
#[derive(Debug, Clone)]
pub struct TokenizedInput {
    /// Vocabulary ids, including the special tokens.
    pub input_ids: Vec<i32>,
    /// `1` for positions produced by encoding, `0` for padding.
    pub attention_mask: Vec<i32>,
    /// Segment ids; `encode` and `pad_to` only ever write `0`.
    pub token_type_ids: Vec<i32>,
}

impl TokenizedInput {
    /// Grow the encoding to `target_len` positions by appending padding.
    ///
    /// `pad_id` is appended to `input_ids`, `0` to the two other vectors.
    /// Nothing happens when `input_ids` already has `target_len` or more
    /// entries; the input is never truncated.
    pub fn pad_to(&mut self, target_len: usize, pad_id: i32) {
        // The shortfall is measured on `input_ids` only and then applied to
        // all three vectors, so each grows by the same number of slots.
        let missing = target_len.saturating_sub(self.input_ids.len());

        let ids_len = self.input_ids.len() + missing;
        self.input_ids.resize(ids_len, pad_id);

        let mask_len = self.attention_mask.len() + missing;
        self.attention_mask.resize(mask_len, 0);

        let types_len = self.token_type_ids.len() + missing;
        self.token_type_ids.resize(types_len, 0);
    }
}
41
42impl WordPieceTokenizer {
43    /// Load vocabulary from a standard vocab.txt file (one token per line).
44    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
45        let content = std::fs::read_to_string(path)
46            .map_err(|e| anyhow::anyhow!("Failed to read vocab file {}: {}", path.display(), e))?;
47        Self::from_vocab_str(&content)
48    }
49
50    /// Build tokenizer from vocabulary string (one token per line).
51    pub fn from_vocab_str(vocab_str: &str) -> anyhow::Result<Self> {
52        let vocab: HashMap<String, i32> = vocab_str
53            .lines()
54            .enumerate()
55            .map(|(i, line)| (line.to_string(), i as i32))
56            .collect();
57
58        let cls_id = *vocab
59            .get("[CLS]")
60            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [CLS] token"))?;
61        let sep_id = *vocab
62            .get("[SEP]")
63            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [SEP] token"))?;
64        let pad_id = *vocab
65            .get("[PAD]")
66            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [PAD] token"))?;
67        let unk_id = *vocab
68            .get("[UNK]")
69            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [UNK] token"))?;
70
71        Ok(Self {
72            vocab,
73            cls_id,
74            sep_id,
75            pad_id,
76            unk_id,
77            max_word_chars: 200,
78        })
79    }
80
81    /// Encode text into token IDs with `[CLS]` prefix and `[SEP]` suffix.
82    pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
83        let words = self.pre_tokenize(text);
84        let mut ids = vec![self.cls_id];
85
86        for word in &words {
87            if ids.len() >= max_len - 1 {
88                break;
89            }
90            let subword_ids = self.wordpiece_encode(word);
91            for id in subword_ids {
92                if ids.len() >= max_len - 1 {
93                    break;
94                }
95                ids.push(id);
96            }
97        }
98
99        ids.push(self.sep_id);
100
101        let len = ids.len();
102        TokenizedInput {
103            input_ids: ids,
104            attention_mask: vec![1; len],
105            token_type_ids: vec![0; len],
106        }
107    }
108
109    pub fn pad_id(&self) -> i32 {
110        self.pad_id
111    }
112
113    pub fn vocab_size(&self) -> usize {
114        self.vocab.len()
115    }
116
117    /// Split text into word-level tokens.
118    /// Handles: whitespace splitting, punctuation splitting,
119    /// camelCase splitting, underscore-separated identifiers, then lowercase.
120    fn pre_tokenize(&self, text: &str) -> Vec<String> {
121        let mut words = Vec::new();
122        let mut current = String::new();
123
124        for ch in text.chars() {
125            if ch.is_whitespace() {
126                if !current.is_empty() {
127                    words.extend(self.split_identifier(&current));
128                    current.clear();
129                }
130            } else if is_bert_punctuation(ch) {
131                if !current.is_empty() {
132                    words.extend(self.split_identifier(&current));
133                    current.clear();
134                }
135                words.push(ch.to_string());
136            } else {
137                current.push(ch);
138            }
139        }
140        if !current.is_empty() {
141            words.extend(self.split_identifier(&current));
142        }
143
144        words.iter().map(|w| w.to_lowercase()).collect()
145    }
146
147    /// Split programming identifiers (camelCase, snake_case) into subwords.
148    /// Called BEFORE lowercasing so case boundaries are still visible.
149    fn split_identifier(&self, word: &str) -> Vec<String> {
150        let lower = word.to_lowercase();
151        if self.vocab.contains_key(&lower) {
152            return vec![word.to_string()];
153        }
154
155        let mut parts = Vec::new();
156        let mut current = String::new();
157        let chars: Vec<char> = word.chars().collect();
158
159        for (i, &ch) in chars.iter().enumerate() {
160            if ch == '_' || ch == '-' {
161                if !current.is_empty() {
162                    parts.push(current.clone());
163                    current.clear();
164                }
165            } else if i > 0 && ch.is_ascii_uppercase() && chars[i - 1].is_ascii_lowercase() {
166                if !current.is_empty() {
167                    parts.push(current.clone());
168                    current.clear();
169                }
170                current.push(ch);
171            } else {
172                current.push(ch);
173            }
174        }
175        if !current.is_empty() {
176            parts.push(current);
177        }
178
179        if parts.is_empty() {
180            vec![word.to_string()]
181        } else {
182            parts
183        }
184    }
185
186    /// Apply WordPiece algorithm to a single word.
187    fn wordpiece_encode(&self, word: &str) -> Vec<i32> {
188        if word.chars().count() > self.max_word_chars {
189            return vec![self.unk_id];
190        }
191
192        let chars: Vec<char> = word.chars().collect();
193        let mut tokens = Vec::new();
194        let mut start = 0;
195
196        while start < chars.len() {
197            let mut end = chars.len();
198            let mut matched = false;
199
200            while start < end {
201                let substr: String = chars[start..end].iter().collect();
202                let candidate = if start > 0 {
203                    format!("##{substr}")
204                } else {
205                    substr
206                };
207
208                if let Some(&id) = self.vocab.get(&candidate) {
209                    tokens.push(id);
210                    matched = true;
211                    start = end;
212                    break;
213                }
214                end -= 1;
215            }
216
217            if !matched {
218                tokens.push(self.unk_id);
219                start += 1;
220            }
221        }
222
223        tokens
224    }
225}
226
/// BERT-style punctuation detection.
///
/// Returns `true` exactly for the 32 ASCII punctuation characters
/// (`!"#$%&'()*+,-./:;<=>?@[\]^_` `` ` `` `{|}~`). Non-ASCII characters are
/// never reported as punctuation, so Unicode punctuation such as `—` or
/// `。` stays inside the surrounding word.
fn is_bert_punctuation(ch: char) -> bool {
    ch.is_ascii_punctuation()
}
268
269/// Wrapper for HuggingFace `tokenizer.json` files.
270///
271/// Parses the JSON to extract the vocabulary map, then delegates to
272/// the existing `WordPieceTokenizer` for actual tokenization.
273/// Supports both WordPiece and BPE vocab formats (both map tokens → IDs).
274pub struct HfTokenizerWrapper {
275    inner: WordPieceTokenizer,
276}
277
278impl HfTokenizerWrapper {
279    /// Load from a HuggingFace `tokenizer.json` file.
280    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
281        let content = std::fs::read_to_string(path).map_err(|e| {
282            anyhow::anyhow!("Failed to read tokenizer.json {}: {}", path.display(), e)
283        })?;
284        Self::from_json(&content)
285    }
286
287    fn from_json(json_str: &str) -> anyhow::Result<Self> {
288        let parsed: serde_json::Value = serde_json::from_str(json_str)
289            .map_err(|e| anyhow::anyhow!("Invalid tokenizer.json: {e}"))?;
290
291        let vocab_obj = parsed
292            .get("model")
293            .and_then(|m| m.get("vocab"))
294            .and_then(|v| v.as_object())
295            .ok_or_else(|| anyhow::anyhow!("tokenizer.json missing model.vocab object"))?;
296
297        let mut vocab_lines: Vec<(String, i32)> = vocab_obj
298            .iter()
299            .filter_map(|(token, id)| id.as_i64().map(|id| (token.clone(), id as i32)))
300            .collect();
301        vocab_lines.sort_by_key(|(_, id)| *id);
302
303        let vocab_str: String = vocab_lines
304            .into_iter()
305            .map(|(token, _)| token)
306            .collect::<Vec<_>>()
307            .join("\n");
308
309        let inner = WordPieceTokenizer::from_vocab_str(&vocab_str)?;
310        Ok(Self { inner })
311    }
312
313    pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
314        self.inner.encode(text, max_len)
315    }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Small 24-entry vocabulary shared by the tests; ids are line indices,
    /// so `[PAD]`=0, `[UNK]`=1, `[CLS]`=2, `[SEP]`=3, `hello`=4, `world`=5.
    fn test_vocab() -> WordPieceTokenizer {
        let vocab = "[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\nfn\nvalidate\ntoken\n##s\n##ing\nauth\n##enticate\nuser\nhandle\nrequest\n##er\nprocess\ndata\n.\n,\n(\n)\n{";
        WordPieceTokenizer::from_vocab_str(vocab).unwrap()
    }

    #[test]
    fn encode_basic() {
        let tok = test_vocab();
        let input = tok.encode("hello world", 512);
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
        // Exact ids, not just a length bound: [CLS] hello world [SEP]
        assert_eq!(input.input_ids, vec![2, 4, 5, 3]);
    }

    #[test]
    fn encode_attention_mask() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        assert!(input.attention_mask.iter().all(|&m| m == 1));
        assert_eq!(input.attention_mask.len(), input.input_ids.len());
    }

    #[test]
    fn encode_token_type_ids_are_zero() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        assert!(input.token_type_ids.iter().all(|&t| t == 0));
    }

    #[test]
    fn encode_respects_max_len() {
        let tok = test_vocab();
        let input = tok.encode("hello world hello world hello world", 6);
        assert!(input.input_ids.len() <= 6);
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
        // Four content tokens survive truncation between the special tokens.
        assert_eq!(input.input_ids, vec![2, 4, 5, 4, 5, 3]);
    }

    #[test]
    fn wordpiece_subwords() {
        let tok = test_vocab();
        // "tokens" should split into "token" + "##s"
        let ids = tok.wordpiece_encode("tokens");
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], *tok.vocab.get("token").unwrap());
        assert_eq!(ids[1], *tok.vocab.get("##s").unwrap());
    }

    #[test]
    fn wordpiece_unknown() {
        let tok = test_vocab();
        let ids = tok.wordpiece_encode("xyzzyplugh");
        assert!(ids.contains(&tok.unk_id));
    }

    #[test]
    fn pre_tokenize_camel_case() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("handleRequest");
        assert!(words.contains(&"handle".to_string()));
        assert!(words.contains(&"request".to_string()));
    }

    #[test]
    fn pre_tokenize_snake_case() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("validate_token");
        assert!(words.contains(&"validate".to_string()));
        assert!(words.contains(&"token".to_string()));
    }

    #[test]
    fn pre_tokenize_punctuation() {
        let tok = test_vocab();
        let words = tok.pre_tokenize("fn(x)");
        assert!(words.contains(&"fn".to_string()));
        assert!(words.contains(&"(".to_string()));
        assert!(words.contains(&")".to_string()));
    }

    #[test]
    fn pad_to_extends() {
        let tok = test_vocab();
        let mut input = tok.encode("hello", 512);
        let original_len = input.input_ids.len();
        input.pad_to(10, tok.pad_id);
        assert_eq!(input.input_ids.len(), 10);
        assert_eq!(input.attention_mask[original_len], 0);
    }

    #[test]
    fn vocab_size() {
        let tok = test_vocab();
        assert_eq!(tok.vocab_size(), 24);
    }

    #[test]
    fn empty_input() {
        let tok = test_vocab();
        let input = tok.encode("", 512);
        assert_eq!(input.input_ids.len(), 2); // [CLS] [SEP]
    }

    #[test]
    fn bert_punctuation_detection() {
        assert!(is_bert_punctuation('.'));
        assert!(is_bert_punctuation('('));
        assert!(is_bert_punctuation('{'));
        assert!(!is_bert_punctuation('a'));
        assert!(!is_bert_punctuation('0'));
    }

    #[test]
    fn hf_tokenizer_from_json() {
        let json = r#"{
            "version": "1.0",
            "model": {
                "type": "WordPiece",
                "vocab": {
                    "[PAD]": 0, "[UNK]": 1, "[CLS]": 2, "[SEP]": 3,
                    "hello": 4, "world": 5, "fn": 6
                }
            }
        }"#;
        let tok = HfTokenizerWrapper::from_json(json).unwrap();
        let input = tok.encode("hello world", 512);
        assert_eq!(input.input_ids[0], 2); // [CLS]
        assert_eq!(*input.input_ids.last().unwrap(), 3); // [SEP]
        // Ids must be the ones declared in the JSON vocabulary.
        assert_eq!(input.input_ids, vec![2, 4, 5, 3]);
    }

    #[test]
    fn hf_tokenizer_invalid_json() {
        assert!(HfTokenizerWrapper::from_json("not json").is_err());
    }

    #[test]
    fn hf_tokenizer_missing_vocab() {
        let json = r#"{"model": {"type": "WordPiece"}}"#;
        assert!(HfTokenizerWrapper::from_json(json).is_err());
    }
}