1use flate2::read::GzDecoder;
2use std::collections::HashMap;
3use std::io::Read;
4
/// Text tokenizer backed by an in-memory vocabulary table.
///
/// The table maps each vocabulary entry (a whole word or a `##`-prefixed
/// continuation piece) to its numeric id.
#[derive(Debug, Clone)]
pub struct Tokenizer {
    /// Vocabulary entry -> token id.
    vocab: HashMap<String, u32>,
}
8
9impl Default for Tokenizer {
10 fn default() -> Self {
11 Self::new()
12 }
13}
14
15impl Tokenizer {
16 pub fn new() -> Self {
17 let compressed = include_bytes!("tokenizer/assets/vocab.txt.gz");
18 let mut decoder = GzDecoder::new(&compressed[..]);
19 let mut s = String::new();
20 decoder
21 .read_to_string(&mut s)
22 .expect("Failed to decompress vocabulary asset");
23 let mut vocab = HashMap::new();
24 for (idx, line) in s.lines().enumerate() {
25 vocab.insert(line.to_string(), idx as u32);
26 }
27 Self { vocab }
28 }
29
30 fn preprocess_text(&self, text: &str) -> String {
32 let mut preprocessed = String::new();
33 for c in text.chars() {
34 if c.is_ascii_punctuation() {
35 preprocessed.push(' ');
36 preprocessed.push(c);
37 preprocessed.push(' ');
38 } else {
39 preprocessed.push(c);
40 }
41 }
42 preprocessed.to_lowercase()
43 }
44
45 fn tokenize_word(&self, word: &str) -> Vec<i64> {
47 if word.is_empty() {
48 return vec![];
49 }
50 if let Some(&id) = self.vocab.get(word) {
51 return vec![id as i64];
52 }
53
54 let char_indices: Vec<(usize, char)> = word.char_indices().collect();
55 let mut start = 0;
56 let mut sub_tokens = Vec::new();
57
58 while start < char_indices.len() {
59 let mut end = char_indices.len();
60 let mut cur_sub_token_id = None;
61 let mut cur_end = start;
62
63 while start < end {
64 let substr = &word[char_indices[start].0..if end < char_indices.len() {
65 char_indices[end].0
66 } else {
67 word.len()
68 }];
69 let lookup_str = if start > 0 {
70 format!("##{}", substr)
71 } else {
72 substr.to_string()
73 };
74
75 if let Some(&id) = self.vocab.get(&lookup_str) {
76 cur_sub_token_id = Some(id as i64);
77 cur_end = end;
78 break;
79 }
80 end -= 1;
81 }
82
83 if let Some(id) = cur_sub_token_id {
84 sub_tokens.push(id);
85 start = cur_end;
86 } else {
87 return vec![100];
89 }
90 }
91 sub_tokens
92 }
93
94 pub fn tokenize_query(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
95 let prefix = "Represent this sentence for searching relevant passages: ";
96 let query = format!("{}{}", prefix, text);
97
98 let preprocessed = self.preprocess_text(&query);
99 let mut token_ids = vec![101]; for word in preprocessed.split_whitespace() {
102 token_ids.extend(self.tokenize_word(word));
103 }
104
105 token_ids.push(102); let len = token_ids.len();
108 let mut attention_mask = vec![1; len];
109
110 if token_ids.len() > 512 {
111 token_ids.truncate(512);
112 attention_mask.truncate(512);
113 } else {
114 while token_ids.len() < 512 {
115 token_ids.push(0); attention_mask.push(0);
117 }
118 }
119
120 (token_ids, attention_mask)
121 }
122
123 pub fn tokenize_passage(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
124 let preprocessed = self.preprocess_text(text);
125 let mut token_ids = vec![101]; for word in preprocessed.split_whitespace() {
128 token_ids.extend(self.tokenize_word(word));
129 }
130
131 token_ids.push(102); let len = token_ids.len();
134 let mut attention_mask = vec![1; len];
135
136 if token_ids.len() > 512 {
137 token_ids.truncate(512);
138 attention_mask.truncate(512);
139 } else {
140 while token_ids.len() < 512 {
141 token_ids.push(0); attention_mask.push(0);
143 }
144 }
145
146 (token_ids, attention_mask)
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// A short query must start with id 101, end its unmasked part with id 102,
    /// begin with the instruction prefix words, and be zero-padded to 512 entries.
    #[test]
    fn test_tokenizer_prefix_and_padding() {
        let tokenizer = Tokenizer::new();
        let (ids, mask) = tokenizer.tokenize_query("test query");

        // Both outputs are always padded to the fixed sequence length.
        assert_eq!(ids.len(), 512);
        assert_eq!(mask.len(), 512);

        assert_eq!(ids[0], 101);

        // Number of positions the mask marks as real tokens.
        let valid_count = mask.iter().filter(|&&m| m == 1).count();

        assert!(valid_count > 2);
        assert_eq!(ids[valid_count - 1], 102);

        // Everything after the real tokens must be padding in both vectors.
        for (&id, &m) in ids[valid_count..].iter().zip(&mask[valid_count..]) {
            assert_eq!(id, 0);
            assert_eq!(m, 0);
        }

        // The first content tokens come from the lowercased instruction prefix.
        let expected_id = |token: &str| *tokenizer.vocab.get(token).unwrap() as i64;
        assert_eq!(ids[1], expected_id("represent"));
        assert_eq!(ids[2], expected_id("this"));
        assert_eq!(ids[3], expected_id("sentence"));
    }
}