use super::Tokenizer;
use crate::error::{DatasetError, DatasetResult};
/// Tokenizer backed by OpenAI's BPE encodings via the `tiktoken_rs` crate.
///
/// Construct with [`TiktokenTokenizer::for_model`], `cl100k_base`, or
/// `o200k_base`; implements the crate-local `Tokenizer` trait.
pub struct TiktokenTokenizer {
    // Underlying byte-pair-encoding engine from `tiktoken_rs`.
    bpe: tiktoken_rs::CoreBPE,
    // Size of the token-id space (regular + special tokens) for the
    // loaded encoding; set by the constructors below.
    vocab_size: usize,
}
impl TiktokenTokenizer {
    /// Token-id space of the `cl100k_base` encoding: the highest special
    /// token id is 100276 (`<|endofprompt|>`), so 100277 ids in total.
    const CL100K_VOCAB_SIZE: usize = 100277;
    /// Token-id space of the `o200k_base` encoding: the highest special
    /// token id is 200018 (`<|endofprompt|>`), so 200019 ids in total.
    const O200K_VOCAB_SIZE: usize = 200019;

    /// Builds a tokenizer for a named OpenAI model (e.g. "gpt-4", "gpt-4o").
    ///
    /// # Errors
    /// Returns `DatasetError::Tokenizer` if `tiktoken_rs` does not recognize
    /// the model name.
    pub fn for_model(model: &str) -> DatasetResult<Self> {
        let bpe = tiktoken_rs::get_bpe_from_model(model).map_err(|e| DatasetError::Tokenizer {
            message: format!("Failed to load tiktoken for model '{}': {}", model, e),
        })?;
        // `get_bpe_from_model` maps the "gpt-4o" and "o1" model families to
        // the o200k_base encoding; the other models loadable here (gpt-4,
        // gpt-3.5, ...) use cl100k_base. The previous match returned 100277
        // for every arm, under-reporting the o200k vocabulary.
        // NOTE(review): keep this prefix list in sync with tiktoken_rs's
        // model-to-encoding table.
        let vocab_size = if model.starts_with("gpt-4o") || model.starts_with("o1") {
            Self::O200K_VOCAB_SIZE
        } else {
            Self::CL100K_VOCAB_SIZE
        };
        Ok(Self { bpe, vocab_size })
    }

    /// Builds a tokenizer using the `cl100k_base` encoding directly
    /// (the encoding used by gpt-4 / gpt-3.5-turbo).
    ///
    /// # Errors
    /// Returns `DatasetError::Tokenizer` if the encoding fails to load.
    pub fn cl100k_base() -> DatasetResult<Self> {
        let bpe = tiktoken_rs::cl100k_base().map_err(|e| DatasetError::Tokenizer {
            message: format!("Failed to load cl100k_base: {}", e),
        })?;
        Ok(Self {
            bpe,
            vocab_size: Self::CL100K_VOCAB_SIZE,
        })
    }

    /// Builds a tokenizer using the `o200k_base` encoding directly
    /// (the encoding used by the gpt-4o family).
    ///
    /// # Errors
    /// Returns `DatasetError::Tokenizer` if the encoding fails to load.
    pub fn o200k_base() -> DatasetResult<Self> {
        let bpe = tiktoken_rs::o200k_base().map_err(|e| DatasetError::Tokenizer {
            message: format!("Failed to load o200k_base: {}", e),
        })?;
        Ok(Self {
            bpe,
            vocab_size: Self::O200K_VOCAB_SIZE,
        })
    }
}
impl Tokenizer for TiktokenTokenizer {
    /// Encodes `text` into token ids; special-token sequences present in the
    /// input (e.g. "<|endoftext|>") map to their reserved ids.
    fn encode(&self, text: &str) -> DatasetResult<Vec<u32>> {
        // `encode_with_special_tokens` already returns `Vec<u32>`; the
        // previous `into_iter().collect()` round-trip was a no-op copy.
        Ok(self.bpe.encode_with_special_tokens(text))
    }

    /// Decodes token ids back into a string.
    ///
    /// # Errors
    /// Returns `DatasetError::Tokenizer` when `tiktoken_rs` cannot decode
    /// the id sequence (e.g. invalid UTF-8 or unknown ids).
    fn decode(&self, ids: &[u32]) -> DatasetResult<String> {
        self.bpe
            .decode(ids.to_vec())
            .map_err(|e| DatasetError::Tokenizer {
                message: format!("Decoding error: {}", e),
            })
    }

    /// Size of the token-id space (regular + special) of the loaded encoding.
    fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Known special tokens for the loaded encoding, as (name, id) pairs.
    fn special_tokens(&self) -> Vec<(String, u32)> {
        // The special-token table differs per encoding. The previous
        // implementation always returned the cl100k ids; under o200k_base
        // those ids (100257..=100260) are ordinary tokens, and the
        // `id < vocab_size` filter let them all through. Select the table by
        // vocab size, which uniquely identifies the encoding among the ones
        // this file's constructors produce.
        let known: &[(&str, u32)] = if self.vocab_size == 200019 {
            // o200k_base special tokens.
            &[("<|endoftext|>", 199999), ("<|endofprompt|>", 200018)]
        } else {
            // cl100k_base special tokens.
            &[
                ("<|endoftext|>", 100257),
                ("<|fim_prefix|>", 100258),
                ("<|fim_middle|>", 100259),
                ("<|fim_suffix|>", 100260),
            ]
        };
        known
            .iter()
            // Defensive: never report an id outside the declared id space.
            .filter(|&&(_, id)| (id as usize) < self.vocab_size)
            .map(|&(name, id)| (name.to_string(), id))
            .collect()
    }
}