Skip to main content

cmdhub_cli/
tokenizer.rs

1use flate2::read::GzDecoder;
2use std::collections::HashMap;
3use std::io::Read;
4
/// WordPiece tokenizer backed by a token-to-id vocabulary table.
///
/// Construct it with [`Tokenizer::new`], which fills the table from the
/// vocabulary asset embedded in the binary.
#[derive(Debug, Clone)]
pub struct Tokenizer {
    /// Maps each vocabulary token (including `##`-prefixed continuation
    /// pieces) to its id, i.e. the zero-based line index of the token in the
    /// vocabulary file.
    vocab: HashMap<String, u32>,
}
8
9impl Default for Tokenizer {
10    fn default() -> Self {
11        Self::new()
12    }
13}
14
15impl Tokenizer {
16    pub fn new() -> Self {
17        let compressed = include_bytes!("tokenizer/assets/vocab.txt.gz");
18        let mut decoder = GzDecoder::new(&compressed[..]);
19        let mut s = String::new();
20        decoder
21            .read_to_string(&mut s)
22            .expect("Failed to decompress vocabulary asset");
23        let mut vocab = HashMap::new();
24        for (idx, line) in s.lines().enumerate() {
25            vocab.insert(line.to_string(), idx as u32);
26        }
27        Self { vocab }
28    }
29
30    /// Preprocesses text by lowercasing and splitting punctuation to match BERT tokenization rules.
31    fn preprocess_text(&self, text: &str) -> String {
32        let mut preprocessed = String::new();
33        for c in text.chars() {
34            if c.is_ascii_punctuation() {
35                preprocessed.push(' ');
36                preprocessed.push(c);
37                preprocessed.push(' ');
38            } else {
39                preprocessed.push(c);
40            }
41        }
42        preprocessed.to_lowercase()
43    }
44
45    /// Performs WordPiece tokenization on a single word.
46    fn tokenize_word(&self, word: &str) -> Vec<i64> {
47        if word.is_empty() {
48            return vec![];
49        }
50        if let Some(&id) = self.vocab.get(word) {
51            return vec![id as i64];
52        }
53
54        let char_indices: Vec<(usize, char)> = word.char_indices().collect();
55        let mut start = 0;
56        let mut sub_tokens = Vec::new();
57
58        while start < char_indices.len() {
59            let mut end = char_indices.len();
60            let mut cur_sub_token_id = None;
61            let mut cur_end = start;
62
63            while start < end {
64                let substr = &word[char_indices[start].0..if end < char_indices.len() {
65                    char_indices[end].0
66                } else {
67                    word.len()
68                }];
69                let lookup_str = if start > 0 {
70                    format!("##{}", substr)
71                } else {
72                    substr.to_string()
73                };
74
75                if let Some(&id) = self.vocab.get(&lookup_str) {
76                    cur_sub_token_id = Some(id as i64);
77                    cur_end = end;
78                    break;
79                }
80                end -= 1;
81            }
82
83            if let Some(id) = cur_sub_token_id {
84                sub_tokens.push(id);
85                start = cur_end;
86            } else {
87                // If any sub-word cannot be resolved, return [UNK] (ID 100) for the entire word
88                return vec![100];
89            }
90        }
91        sub_tokens
92    }
93
94    pub fn tokenize_query(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
95        let prefix = "Represent this sentence for searching relevant passages: ";
96        let query = format!("{}{}", prefix, text);
97
98        let preprocessed = self.preprocess_text(&query);
99        let mut token_ids = vec![101]; // [CLS]
100
101        for word in preprocessed.split_whitespace() {
102            token_ids.extend(self.tokenize_word(word));
103        }
104
105        token_ids.push(102); // [SEP]
106
107        let len = token_ids.len();
108        let mut attention_mask = vec![1; len];
109
110        if token_ids.len() > 512 {
111            token_ids.truncate(512);
112            attention_mask.truncate(512);
113        } else {
114            while token_ids.len() < 512 {
115                token_ids.push(0); // [PAD]
116                attention_mask.push(0);
117            }
118        }
119
120        (token_ids, attention_mask)
121    }
122
123    pub fn tokenize_passage(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
124        let preprocessed = self.preprocess_text(text);
125        let mut token_ids = vec![101]; // [CLS]
126
127        for word in preprocessed.split_whitespace() {
128            token_ids.extend(self.tokenize_word(word));
129        }
130
131        token_ids.push(102); // [SEP]
132
133        let len = token_ids.len();
134        let mut attention_mask = vec![1; len];
135
136        if token_ids.len() > 512 {
137            token_ids.truncate(512);
138            attention_mask.truncate(512);
139        } else {
140            while token_ids.len() < 512 {
141                token_ids.push(0); // [PAD]
142                attention_mask.push(0);
143            }
144        }
145
146        (token_ids, attention_mask)
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the overall layout of a tokenized query: fixed length, `[CLS]`
    /// first, `[SEP]` closing the valid region, zero padding afterwards, and
    /// the query prefix tokenized ahead of the user text.
    #[test]
    fn test_tokenizer_prefix_and_padding() {
        let tokenizer = Tokenizer::new();
        let (ids, mask) = tokenizer.tokenize_query("test query");

        // Both outputs always have the fixed sequence length.
        assert_eq!(ids.len(), 512);
        assert_eq!(mask.len(), 512);

        // The sequence opens with [CLS].
        assert_eq!(ids[0], 101);

        // Number of positions the attention mask marks as real tokens.
        let valid_count = mask.iter().filter(|&&m| m == 1).count();

        // At least CLS and SEP plus some prefix/query tokens.
        assert!(valid_count > 2);
        // The last real token is [SEP].
        assert_eq!(ids[valid_count - 1], 102);

        // Everything after the valid region is padding in both vectors.
        for (id, m) in ids.iter().zip(mask.iter()).skip(valid_count) {
            assert_eq!(*id, 0);
            assert_eq!(*m, 0);
        }

        // The prefix is tokenized as part of the query, so the first tokens
        // after [CLS] are "represent", "this", "sentence".
        let expected_prefix = ["represent", "this", "sentence"];
        for (offset, token) in expected_prefix.iter().enumerate() {
            let expected = *tokenizer.vocab.get(*token).unwrap() as i64;
            assert_eq!(ids[offset + 1], expected);
        }
    }
}