Skip to main content

lean_ctx/core/embeddings/
tokenizer.rs

1//! Minimal WordPiece tokenizer for BERT-style embedding models.
2//!
3//! Implements the standard BERT tokenization pipeline:
4//! 1. Lowercase + accent stripping
5//! 2. Whitespace + punctuation splitting
6//! 3. WordPiece subword tokenization
7//! 4. Special token insertion ([CLS], [SEP])
8//!
9//! Optimized for code search: handles camelCase, snake_case, and common
10//! programming punctuation correctly.
11
12use std::collections::HashMap;
13use std::path::Path;
14
/// WordPiece tokenizer with a BERT-style vocabulary.
///
/// Holds the token -> id map plus the ids of the four required special
/// tokens, resolved once at construction time.
pub struct WordPieceTokenizer {
    // Maps token text (including "##"-prefixed continuation pieces) to id.
    vocab: HashMap<String, i32>,
    // Id of [CLS], prepended to every encoded sequence.
    cls_id: i32,
    // Id of [SEP], appended to every encoded sequence.
    sep_id: i32,
    // Id of [PAD], used by callers to extend sequences to a fixed length.
    pad_id: i32,
    // Id of [UNK], emitted for unmatchable input.
    unk_id: i32,
    // Words longer than this many chars map straight to [UNK]
    // (see wordpiece_encode) — bounds the quadratic subword scan.
    max_word_chars: usize,
}
23
/// Model-ready encoding of one text: token ids plus the parallel
/// attention-mask and segment-id vectors expected by BERT-style models.
#[derive(Debug, Clone)]
pub struct TokenizedInput {
    pub input_ids: Vec<i32>,
    pub attention_mask: Vec<i32>,
    pub token_type_ids: Vec<i32>,
}

impl TokenizedInput {
    /// Pad the input to a fixed length.
    ///
    /// Extends all three vectors by the same count until `input_ids`
    /// reaches `target_len`: padding positions receive `pad_id` in
    /// `input_ids` and 0 in both `attention_mask` and `token_type_ids`.
    /// Inputs already at or beyond `target_len` are left untouched
    /// (this never truncates).
    pub fn pad_to(&mut self, target_len: usize, pad_id: i32) {
        // Number of padding slots still needed (0 if already long enough).
        let extra = target_len.saturating_sub(self.input_ids.len());
        self.input_ids.extend(std::iter::repeat(pad_id).take(extra));
        self.attention_mask.extend(std::iter::repeat(0).take(extra));
        self.token_type_ids.extend(std::iter::repeat(0).take(extra));
    }
}
41
42impl WordPieceTokenizer {
43    /// Load vocabulary from a standard vocab.txt file (one token per line).
44    pub fn from_file(path: &Path) -> anyhow::Result<Self> {
45        let content = std::fs::read_to_string(path)
46            .map_err(|e| anyhow::anyhow!("Failed to read vocab file {}: {}", path.display(), e))?;
47        Self::from_vocab_str(&content)
48    }
49
50    /// Build tokenizer from vocabulary string (one token per line).
51    pub fn from_vocab_str(vocab_str: &str) -> anyhow::Result<Self> {
52        let vocab: HashMap<String, i32> = vocab_str
53            .lines()
54            .enumerate()
55            .map(|(i, line)| (line.to_string(), i as i32))
56            .collect();
57
58        let cls_id = *vocab
59            .get("[CLS]")
60            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [CLS] token"))?;
61        let sep_id = *vocab
62            .get("[SEP]")
63            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [SEP] token"))?;
64        let pad_id = *vocab
65            .get("[PAD]")
66            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [PAD] token"))?;
67        let unk_id = *vocab
68            .get("[UNK]")
69            .ok_or_else(|| anyhow::anyhow!("Vocabulary missing [UNK] token"))?;
70
71        Ok(Self {
72            vocab,
73            cls_id,
74            sep_id,
75            pad_id,
76            unk_id,
77            max_word_chars: 200,
78        })
79    }
80
81    /// Encode text into token IDs with [CLS] prefix and [SEP] suffix.
82    pub fn encode(&self, text: &str, max_len: usize) -> TokenizedInput {
83        let words = self.pre_tokenize(text);
84        let mut ids = vec![self.cls_id];
85
86        for word in &words {
87            if ids.len() >= max_len - 1 {
88                break;
89            }
90            let subword_ids = self.wordpiece_encode(word);
91            for id in subword_ids {
92                if ids.len() >= max_len - 1 {
93                    break;
94                }
95                ids.push(id);
96            }
97        }
98
99        ids.push(self.sep_id);
100
101        let len = ids.len();
102        TokenizedInput {
103            input_ids: ids,
104            attention_mask: vec![1; len],
105            token_type_ids: vec![0; len],
106        }
107    }
108
109    pub fn pad_id(&self) -> i32 {
110        self.pad_id
111    }
112
113    pub fn vocab_size(&self) -> usize {
114        self.vocab.len()
115    }
116
117    /// Split text into word-level tokens.
118    /// Handles: whitespace splitting, punctuation splitting,
119    /// camelCase splitting, underscore-separated identifiers, then lowercase.
120    fn pre_tokenize(&self, text: &str) -> Vec<String> {
121        let mut words = Vec::new();
122        let mut current = String::new();
123
124        for ch in text.chars() {
125            if ch.is_whitespace() {
126                if !current.is_empty() {
127                    words.extend(self.split_identifier(&current));
128                    current.clear();
129                }
130            } else if is_bert_punctuation(ch) {
131                if !current.is_empty() {
132                    words.extend(self.split_identifier(&current));
133                    current.clear();
134                }
135                words.push(ch.to_string());
136            } else {
137                current.push(ch);
138            }
139        }
140        if !current.is_empty() {
141            words.extend(self.split_identifier(&current));
142        }
143
144        words.iter().map(|w| w.to_lowercase()).collect()
145    }
146
147    /// Split programming identifiers (camelCase, snake_case) into subwords.
148    /// Called BEFORE lowercasing so case boundaries are still visible.
149    fn split_identifier(&self, word: &str) -> Vec<String> {
150        let lower = word.to_lowercase();
151        if self.vocab.contains_key(&lower) {
152            return vec![word.to_string()];
153        }
154
155        let mut parts = Vec::new();
156        let mut current = String::new();
157        let chars: Vec<char> = word.chars().collect();
158
159        for (i, &ch) in chars.iter().enumerate() {
160            if ch == '_' || ch == '-' {
161                if !current.is_empty() {
162                    parts.push(current.clone());
163                    current.clear();
164                }
165            } else if i > 0 && ch.is_ascii_uppercase() && chars[i - 1].is_ascii_lowercase() {
166                if !current.is_empty() {
167                    parts.push(current.clone());
168                    current.clear();
169                }
170                current.push(ch);
171            } else {
172                current.push(ch);
173            }
174        }
175        if !current.is_empty() {
176            parts.push(current);
177        }
178
179        if parts.is_empty() {
180            vec![word.to_string()]
181        } else {
182            parts
183        }
184    }
185
186    /// Apply WordPiece algorithm to a single word.
187    fn wordpiece_encode(&self, word: &str) -> Vec<i32> {
188        if word.chars().count() > self.max_word_chars {
189            return vec![self.unk_id];
190        }
191
192        let chars: Vec<char> = word.chars().collect();
193        let mut tokens = Vec::new();
194        let mut start = 0;
195
196        while start < chars.len() {
197            let mut end = chars.len();
198            let mut matched = false;
199
200            while start < end {
201                let substr: String = chars[start..end].iter().collect();
202                let candidate = if start > 0 {
203                    format!("##{substr}")
204                } else {
205                    substr
206                };
207
208                if let Some(&id) = self.vocab.get(&candidate) {
209                    tokens.push(id);
210                    matched = true;
211                    start = end;
212                    break;
213                }
214                end -= 1;
215            }
216
217            if !matched {
218                tokens.push(self.unk_id);
219                start += 1;
220            }
221        }
222
223        tokens
224    }
225}
226
/// BERT-style punctuation detection.
///
/// The previous version enumerated a 32-character match arm that was
/// exactly the `char::is_ascii_punctuation` set, and its non-ASCII branch
/// also called `is_ascii_punctuation` — which is always false for
/// non-ASCII — so the whole function reduces to this single check.
///
/// NOTE(review): full BERT additionally splits on Unicode punctuation
/// categories (e.g. '。', '—'); those are intentionally NOT matched here,
/// preserving the existing behavior (they flow into the word accumulator
/// in `pre_tokenize` instead).
fn is_bert_punctuation(ch: char) -> bool {
    ch.is_ascii_punctuation()
}
268
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a tokenizer over a tiny 24-entry vocabulary: the four special
    // tokens, whole words, "##"-prefixed continuation pieces, and a few
    // punctuation tokens.
    fn test_vocab() -> WordPieceTokenizer {
        let vocab = "[PAD]\n[UNK]\n[CLS]\n[SEP]\nhello\nworld\nfn\nvalidate\ntoken\n##s\n##ing\nauth\n##enticate\nuser\nhandle\nrequest\n##er\nprocess\ndata\n.\n,\n(\n)\n{";
        WordPieceTokenizer::from_vocab_str(vocab).unwrap()
    }

    #[test]
    fn encode_basic() {
        let tok = test_vocab();
        let input = tok.encode("hello world", 512);
        // Every encoding is wrapped in [CLS] ... [SEP].
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
        assert!(input.input_ids.len() >= 4); // [CLS] hello world [SEP]
    }

    #[test]
    fn encode_attention_mask() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        // encode() never pads, so the mask is all ones and parallel to ids.
        assert!(input.attention_mask.iter().all(|&m| m == 1));
        assert_eq!(input.attention_mask.len(), input.input_ids.len());
    }

    #[test]
    fn encode_token_type_ids_are_zero() {
        let tok = test_vocab();
        let input = tok.encode("hello", 512);
        // Single-sequence input: every token belongs to segment 0.
        assert!(input.token_type_ids.iter().all(|&t| t == 0));
    }

    #[test]
    fn encode_respects_max_len() {
        let tok = test_vocab();
        let input = tok.encode("hello world hello world hello world", 6);
        // Truncation must keep both special tokens within the budget.
        assert!(input.input_ids.len() <= 6);
        assert_eq!(input.input_ids[0], tok.cls_id);
        assert_eq!(*input.input_ids.last().unwrap(), tok.sep_id);
    }

    #[test]
    fn wordpiece_subwords() {
        let tok = test_vocab();
        // "tokens" should split into "token" + "##s"
        let ids = tok.wordpiece_encode("tokens");
        assert_eq!(ids.len(), 2);
        assert_eq!(ids[0], *tok.vocab.get("token").unwrap());
        assert_eq!(ids[1], *tok.vocab.get("##s").unwrap());
    }

    #[test]
    fn wordpiece_unknown() {
        let tok = test_vocab();
        // No prefix of this word is in the vocab, so [UNK] must appear.
        let ids = tok.wordpiece_encode("xyzzyplugh");
        assert!(ids.contains(&tok.unk_id));
    }

    #[test]
    fn pre_tokenize_camel_case() {
        let tok = test_vocab();
        // Lower->upper boundary splits the word; output is lowercased.
        let words = tok.pre_tokenize("handleRequest");
        assert!(words.contains(&"handle".to_string()));
        assert!(words.contains(&"request".to_string()));
    }

    #[test]
    fn pre_tokenize_snake_case() {
        let tok = test_vocab();
        // '_' counts as punctuation, so it separates the two words.
        let words = tok.pre_tokenize("validate_token");
        assert!(words.contains(&"validate".to_string()));
        assert!(words.contains(&"token".to_string()));
    }

    #[test]
    fn pre_tokenize_punctuation() {
        let tok = test_vocab();
        // Punctuation chars become standalone tokens.
        let words = tok.pre_tokenize("fn(x)");
        assert!(words.contains(&"fn".to_string()));
        assert!(words.contains(&"(".to_string()));
        assert!(words.contains(&")".to_string()));
    }

    #[test]
    fn pad_to_extends() {
        let tok = test_vocab();
        let mut input = tok.encode("hello", 512);
        let original_len = input.input_ids.len();
        input.pad_to(10, tok.pad_id);
        // Padded positions are masked out (attention 0).
        assert_eq!(input.input_ids.len(), 10);
        assert_eq!(input.attention_mask[original_len], 0);
    }

    #[test]
    fn vocab_size() {
        let tok = test_vocab();
        // 24 newline-separated entries in the test_vocab literal above.
        assert_eq!(tok.vocab_size(), 24);
    }

    #[test]
    fn empty_input() {
        let tok = test_vocab();
        let input = tok.encode("", 512);
        assert_eq!(input.input_ids.len(), 2); // [CLS] [SEP]
    }

    #[test]
    fn bert_punctuation_detection() {
        assert!(is_bert_punctuation('.'));
        assert!(is_bert_punctuation('('));
        assert!(is_bert_punctuation('{'));
        assert!(!is_bert_punctuation('a'));
        assert!(!is_bert_punctuation('0'));
    }
}