use flate2::read::GzDecoder;
use std::collections::HashMap;
use std::io::Read;
/// BERT-style WordPiece tokenizer backed by a vocabulary embedded in the binary.
///
/// Each vocabulary token maps to its id, which is the token's zero-based line
/// number in the bundled vocabulary file.
#[derive(Debug, Clone)]
pub struct Tokenizer {
    /// Vocabulary token -> token id.
    vocab: HashMap<String, u32>,
}
impl Default for Tokenizer {
fn default() -> Self {
Self::new()
}
}
impl Tokenizer {
pub fn new() -> Self {
let compressed = include_bytes!("tokenizer/assets/vocab.txt.gz");
let mut decoder = GzDecoder::new(&compressed[..]);
let mut s = String::new();
decoder
.read_to_string(&mut s)
.expect("Failed to decompress vocabulary asset");
let mut vocab = HashMap::new();
for (idx, line) in s.lines().enumerate() {
vocab.insert(line.to_string(), idx as u32);
}
Self { vocab }
}
fn preprocess_text(&self, text: &str) -> String {
let mut preprocessed = String::new();
for c in text.chars() {
if c.is_ascii_punctuation() {
preprocessed.push(' ');
preprocessed.push(c);
preprocessed.push(' ');
} else {
preprocessed.push(c);
}
}
preprocessed.to_lowercase()
}
fn tokenize_word(&self, word: &str) -> Vec<i64> {
if word.is_empty() {
return vec![];
}
if let Some(&id) = self.vocab.get(word) {
return vec![id as i64];
}
let char_indices: Vec<(usize, char)> = word.char_indices().collect();
let mut start = 0;
let mut sub_tokens = Vec::new();
while start < char_indices.len() {
let mut end = char_indices.len();
let mut cur_sub_token_id = None;
let mut cur_end = start;
while start < end {
let substr = &word[char_indices[start].0..if end < char_indices.len() {
char_indices[end].0
} else {
word.len()
}];
let lookup_str = if start > 0 {
format!("##{}", substr)
} else {
substr.to_string()
};
if let Some(&id) = self.vocab.get(&lookup_str) {
cur_sub_token_id = Some(id as i64);
cur_end = end;
break;
}
end -= 1;
}
if let Some(id) = cur_sub_token_id {
sub_tokens.push(id);
start = cur_end;
} else {
return vec![100];
}
}
sub_tokens
}
pub fn tokenize_query(&self, text: &str) -> (Vec<i64>, Vec<i64>) {
let prefix = "Represent this sentence for searching relevant passages: ";
let query = format!("{}{}", prefix, text);
let preprocessed = self.preprocess_text(&query);
let mut token_ids = vec![101];
for word in preprocessed.split_whitespace() {
token_ids.extend(self.tokenize_word(word));
}
token_ids.push(102);
let len = token_ids.len();
let mut attention_mask = vec![1; len];
if token_ids.len() > 512 {
token_ids.truncate(512);
attention_mask.truncate(512);
} else {
while token_ids.len() < 512 {
token_ids.push(0); attention_mask.push(0);
}
}
(token_ids, attention_mask)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tokenizer_prefix_and_padding() {
        let tokenizer = Tokenizer::new();
        let (ids, mask) = tokenizer.tokenize_query("test query");

        // Both outputs have the fixed model length, and the sequence opens with 101.
        assert_eq!(ids.len(), 512);
        assert_eq!(mask.len(), 512);
        assert_eq!(ids[0], 101);

        // Real tokens form a prefix of 1s in the mask; the last real token is 102.
        let valid_count = mask.iter().filter(|&&m| m == 1).count();
        assert!(valid_count > 2);
        assert_eq!(ids[valid_count - 1], 102);

        // Every position after the real tokens is padding in both outputs.
        assert!(ids[valid_count..].iter().all(|&id| id == 0));
        assert!(mask[valid_count..].iter().all(|&m| m == 0));

        // The query instruction is tokenized first, right after the opening id.
        let vocab_id = |token: &str| -> i64 {
            i64::from(*tokenizer.vocab.get(token).unwrap())
        };
        for (offset, token) in ["represent", "this", "sentence"].iter().copied().enumerate() {
            assert_eq!(ids[offset + 1], vocab_id(token));
        }
    }
}