use std::{
collections::HashMap,
io::{Read, Seek},
};
use byteorder::{LittleEndian, ReadBytesExt};
use crate::{errors::PllmError, gguf::GgufFile};
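/// A SentencePiece-style tokenizer: a vocabulary with per-token merge scores,
/// a reverse token-to-id map, and the ids of the BOS/EOS special tokens.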
#[derive(Debug)]
pub struct Tokenizer {
#[allow(dead_code)]
pub(crate) max_token_length: u32,
vocab: Vec<String>,
scores: Vec<f32>,
token_id_map: HashMap<String, usize>,
pub(crate) eos_token: u32,
pub(crate) bos_token: u32,
}
impl Tokenizer {
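    /// Returns the token string at `index`, or `None` if it is out of range.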
pub fn get_token(&self, index: usize) -> Option<String> {
self.vocab.get(index).cloned()
}
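    /// Returns the number of entries in the vocabulary.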
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
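    /// Builds a tokenizer from the `tokenizer.ggml.*` metadata of a parsed
    /// GGUF file.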
pub fn from_gguf<R: Read + Seek>(gf: &GgufFile<R>) -> Result<Self, PllmError> {
let md = gf.metadata();
let vocab = md.get_string_array_result("tokenizer.ggml.tokens")?;
let token_id_map = vocab
.iter()
.enumerate()
.map(|(i, v)| (v.clone(), i))
.collect::<HashMap<_, _>>();
let scores = md.get_f32_array_result("tokenizer.ggml.scores")?;
let bos_token = md.get_u32_result("tokenizer.ggml.bos_token_id")?;
let eos_token = md.get_u32_result("tokenizer.ggml.eos_token_id")?;
Ok(Self {
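            // GGUF metadata carries no max token length; left as 0 (unused).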
max_token_length: 0,
vocab,
scores,
token_id_map,
eos_token,
bos_token,
})
}
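    /// Reads a llama2.c-style binary tokenizer: a little-endian `u32` maximum
    /// token length, followed by `vocab_size` records of `f32` score,
    /// `i32` byte length, and the token's UTF-8 bytes.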
pub fn from_reader(vocab_size: usize, mut reader: impl Read) -> Result<Self, PllmError> {
let max_token_length = reader.read_u32::<LittleEndian>()?;
let mut vocab = Vec::with_capacity(vocab_size);
let mut scores = Vec::with_capacity(vocab_size);
let mut token_id_map = HashMap::with_capacity(vocab_size);
for i in 0..vocab_size {
let score = reader.read_f32::<LittleEndian>()?;
            let len = reader.read_i32::<LittleEndian>()?;
            if len < 0 {
                return Err(PllmError::Other(format!(
                    "negative token length {} at i={}",
                    len, i
                )));
            }
            let mut buf = vec![0u8; len as usize];
            reader.read_exact(&mut buf)?;
            let value = String::from_utf8_lossy(&buf).to_string();
            if len as u32 > max_token_length {
                println!(
                    "Warning: token {} has length {}, exceeding the declared maximum of {}",
                    i, len, max_token_length,
                );
            }
scores.push(score);
vocab.push(value.clone());
token_id_map.insert(value, i);
}
Ok(Self {
max_token_length,
vocab,
scores,
token_id_map,
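            // This binary format does not store special-token ids; these are
            // the defaults assumed by this loader.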
eos_token: 0,
bos_token: 1,
})
}
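    /// Encodes `text` into token ids: every character is first mapped to its
    /// own vocab entry, then adjacent tokens are greedily merged by score
    /// until no merge is possible.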
pub fn bpe_encode(&self, text: String) -> Result<Vec<u32>, PllmError> {
let mut tokens = Vec::with_capacity(text.len() + 2);
if self.bos_token != 0 {
tokens.push(self.bos_token);
}
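        // SentencePiece-style vocabs use "▁" (U+2581) as the word-boundary
        // marker; if present, prepend it and rewrite spaces to match.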
let text = if let Some(&dummy_prefix) = self.token_id_map.get("▁") {
tokens.push(dummy_prefix as u32);
text.replace(' ', "▁")
} else {
text
};
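        // Seed the sequence with one token per character.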
for c in text.chars() {
            let id = *self
                .token_id_map
                .get(&c.to_string())
                .ok_or_else(|| PllmError::Other(format!("{} not found in vocab", c)))?;
tokens.push(id as u32);
}
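        // Greedy BPE merge loop: repeatedly merge the adjacent pair whose
        // concatenation is in the vocab with the highest score.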
loop {
            let mut best_score = f32::NEG_INFINITY;
            let mut best_id = 0;
            let mut best_idx = None;
            for i in 0..tokens.len().saturating_sub(1) {
                let merge_token = format!(
                    "{}{}",
                    self.vocab[tokens[i] as usize],
                    self.vocab[tokens[i + 1] as usize]
                );
                if let Some(&id) = self.token_id_map.get(&merge_token) {
                    if self.scores[id] > best_score {
                        best_score = self.scores[id];
                        best_id = id as u32;
                        best_idx = Some(i);
                    }
                }
            }
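            // Apply the best merge in place, or exit the loop when none was found.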
match best_idx {
Some(idx) => {
tokens[idx] = best_id;
tokens.remove(idx + 1);
}
None => break,
}
}
Ok(tokens)
}
}
#[cfg(test)]
mod tests {
use std::{fs::File, io::BufReader};
use crate::{gguf::GgufFile, Tokenizer};
#[test]
fn test_tokenizer_from_gguf() {
let f = File::open("testdata/gemma2b").unwrap();
let mut reader = BufReader::new(f);
let gguf_file = GgufFile::from_reader(&mut reader).unwrap();
let tokenizer = Tokenizer::from_gguf(&gguf_file).unwrap();
        assert!(tokenizer.vocab_size() > 0);
}
}