use crate::error::{CuttleError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
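/// Configuration for a [`Tokenizer`]: target vocabulary size and the literal
/// strings used for the special tokens.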
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
pub vocab_size: usize,
pub unk_token: String,
pub bos_token: String,
pub eos_token: String,
pub pad_token: String,
}
impl Default for TokenizerConfig {
fn default() -> Self {
Self {
vocab_size: 32000,
unk_token: "<unk>".to_string(),
bos_token: "<s>".to_string(),
eos_token: "</s>".to_string(),
pad_token: "<pad>".to_string(),
}
}
}
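/// A simple word-level tokenizer with a frequency-built vocabulary and
/// reserved special tokens (pad=0, unk=1, bos=2, eos=3).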
#[derive(Debug, Clone)]
pub struct Tokenizer {
config: TokenizerConfig,
vocab: HashMap<String, usize>,
id_to_token: HashMap<usize, String>,
special_tokens: HashMap<String, usize>,
}
impl Tokenizer {
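    /// Creates an empty tokenizer and registers the four special tokens at ids 0-3.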
pub fn new(config: TokenizerConfig) -> Self {
let mut tokenizer = Self {
config: config.clone(),
vocab: HashMap::new(),
id_to_token: HashMap::new(),
special_tokens: HashMap::new(),
};
tokenizer.add_special_token(&config.pad_token, 0);
tokenizer.add_special_token(&config.unk_token, 1);
tokenizer.add_special_token(&config.bos_token, 2);
tokenizer.add_special_token(&config.eos_token, 3);
tokenizer
}
fn add_special_token(&mut self, token: &str, id: usize) {
self.vocab.insert(token.to_string(), id);
self.id_to_token.insert(id, token.to_string());
self.special_tokens.insert(token.to_string(), id);
}
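    /// Builds a word-level vocabulary from `texts`: words are lowercased, stripped of
    /// non-alphanumeric characters, and assigned ids by descending frequency until
    /// `config.vocab_size` is reached. Ids below the special-token count stay reserved.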
pub fn build_vocab(&mut self, texts: &[String]) -> Result<()> {
let mut word_freq = HashMap::new();
for text in texts {
let words = self.simple_tokenize(text);
for word in words {
*word_freq.entry(word).or_insert(0) += 1;
}
}
        let mut sorted_words: Vec<_> = word_freq.into_iter().collect();
        // Descending frequency with an alphabetical tie-break so ids are deterministic across runs.
        sorted_words.sort_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
let mut current_id = self.special_tokens.len();
for (word, _freq) in sorted_words {
if current_id >= self.config.vocab_size {
break;
}
if !self.vocab.contains_key(&word) {
self.vocab.insert(word.clone(), current_id);
self.id_to_token.insert(current_id, word);
current_id += 1;
}
}
Ok(())
}
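    // Lowercase, split on whitespace, and keep only alphanumeric characters.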
fn simple_tokenize(&self, text: &str) -> Vec<String> {
text.to_lowercase()
.split_whitespace()
.map(|word| {
word.chars()
.filter(|c| c.is_alphanumeric())
.collect::<String>()
})
.filter(|word| !word.is_empty())
.collect()
}
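    /// Encodes `text` into token ids, wrapping the sequence in BOS/EOS and mapping
    /// out-of-vocabulary words to the `<unk>` id.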
pub fn encode(&self, text: &str) -> Result<Vec<usize>> {
let words = self.simple_tokenize(text);
let mut token_ids = Vec::new();
if let Some(&bos_id) = self.special_tokens.get(&self.config.bos_token) {
token_ids.push(bos_id);
}
for word in words {
            let token_id = self
                .vocab
                .get(&word)
                .copied()
                .or_else(|| self.special_tokens.get(&self.config.unk_token).copied())
                // Fall back to the conventional <unk> id if it was never registered.
                .unwrap_or(1);
token_ids.push(token_id);
}
if let Some(&eos_id) = self.special_tokens.get(&self.config.eos_token) {
token_ids.push(eos_id);
}
Ok(token_ids)
}
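    /// Decodes token ids into a space-joined string, skipping special tokens.
    /// Returns a `TokenizerError` for ids that are not in the vocabulary.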
pub fn decode(&self, token_ids: &[usize]) -> Result<String> {
let mut words = Vec::new();
for &token_id in token_ids {
if let Some(token) = self.id_to_token.get(&token_id) {
if !self.special_tokens.contains_key(token) {
words.push(token.clone());
}
} else {
return Err(CuttleError::TokenizerError(format!(
"Unknown token ID: {}",
token_id
)));
}
}
Ok(words.join(" "))
}
pub fn vocab_size(&self) -> usize {
self.vocab.len()
}
pub fn get_special_token_id(&self, token: &str) -> Option<usize> {
self.special_tokens.get(token).copied()
}
pub fn bos_token_id(&self) -> Option<usize> {
self.get_special_token_id(&self.config.bos_token)
}
pub fn eos_token_id(&self) -> Option<usize> {
self.get_special_token_id(&self.config.eos_token)
}
pub fn pad_token_id(&self) -> Option<usize> {
self.get_special_token_id(&self.config.pad_token)
}
pub fn unk_token_id(&self) -> Option<usize> {
self.get_special_token_id(&self.config.unk_token)
}
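    /// Serializes the config, vocabulary, and special tokens as pretty-printed JSON at `path`.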
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
let tokenizer_data = TokenizerData {
config: self.config.clone(),
vocab: self.vocab.clone(),
special_tokens: self.special_tokens.clone(),
};
let serialized = serde_json::to_string_pretty(&tokenizer_data).map_err(|e| {
CuttleError::SerializationError(format!("Failed to serialize tokenizer: {}", e))
})?;
        std::fs::write(path, serialized).map_err(CuttleError::IoError)?;
Ok(())
}
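    /// Loads a tokenizer from JSON at `path`. Files with a top-level `model` object are
    /// treated as HuggingFace `tokenizer.json` exports; anything else is expected to be
    /// the format written by [`Tokenizer::save`].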
    pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
        let content = std::fs::read_to_string(path).map_err(CuttleError::IoError)?;
        // A HuggingFace tokenizer.json carries a top-level "model" object; only those files
        // go through the HuggingFace importer. Without this check, any valid JSON (including
        // our own saved format) would be routed to from_huggingface_json.
        if let Ok(value) = serde_json::from_str::<serde_json::Value>(&content) {
            if value.get("model").is_some() {
                return Self::from_huggingface_json(&value);
            }
        }
        let tokenizer_data: TokenizerData = serde_json::from_str(&content).map_err(|e| {
            CuttleError::SerializationError(format!("Failed to deserialize tokenizer: {}", e))
        })?;
let mut id_to_token = HashMap::new();
for (token, id) in &tokenizer_data.vocab {
id_to_token.insert(*id, token.clone());
}
Ok(Self {
config: tokenizer_data.config,
vocab: tokenizer_data.vocab,
id_to_token,
special_tokens: tokenizer_data.special_tokens,
})
}
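    // Imports only the `model.vocab` map from a HuggingFace `tokenizer.json`;
    // merge rules and added-token metadata are ignored.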
fn from_huggingface_json(json: &serde_json::Value) -> Result<Self> {
let mut vocab = HashMap::new();
let mut id_to_token = HashMap::new();
let mut special_tokens = HashMap::new();
if let Some(model) = json.get("model") {
if let Some(vocab_obj) = model.get("vocab") {
if let Some(vocab_map) = vocab_obj.as_object() {
for (token, id) in vocab_map {
if let Some(id_num) = id.as_u64() {
let id_usize = id_num as usize;
vocab.insert(token.clone(), id_usize);
id_to_token.insert(id_usize, token.clone());
}
}
}
}
}
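        // Assumes a GPT-2-style tokenizer where "<|endoftext|>" fills every special role;
        // other tokenizer families may use different special tokens.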
let config = TokenizerConfig {
vocab_size: vocab.len(),
unk_token: "<|endoftext|>".to_string(),
bos_token: "<|endoftext|>".to_string(),
eos_token: "<|endoftext|>".to_string(),
pad_token: "<|endoftext|>".to_string(),
};
if let Some(unk_id) = vocab.get(&config.unk_token) {
special_tokens.insert(config.unk_token.clone(), *unk_id);
}
Ok(Self {
config,
vocab,
id_to_token,
special_tokens,
})
}
pub fn encode_batch(&self, texts: &[String]) -> Result<Vec<Vec<usize>>> {
texts.iter().map(|text| self.encode(text)).collect()
}
pub fn decode_batch(&self, token_ids_batch: &[Vec<usize>]) -> Result<Vec<String>> {
token_ids_batch
.iter()
.map(|token_ids| self.decode(token_ids))
.collect()
}
}
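/// On-disk representation written by [`Tokenizer::save`]; `id_to_token` is rebuilt on load.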
#[derive(Serialize, Deserialize)]
struct TokenizerData {
config: TokenizerConfig,
vocab: HashMap<String, usize>,
special_tokens: HashMap<String, usize>,
}
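/// Convenience constructor using [`TokenizerConfig::default`]
/// (32k vocab, `<unk>`/`<s>`/`</s>`/`<pad>`).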
pub fn create_default_tokenizer() -> Tokenizer {
Tokenizer::new(TokenizerConfig::default())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_tokenizer_encode_decode() {
let mut tokenizer = create_default_tokenizer();
let texts = vec!["hello world".to_string(), "this is a test".to_string()];
tokenizer.build_vocab(&texts).unwrap();
let encoded = tokenizer.encode("hello world").unwrap();
assert!(!encoded.is_empty());
let decoded = tokenizer.decode(&encoded).unwrap();
assert_eq!(decoded, "hello world");
}
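    #[test]
    fn test_unknown_word_maps_to_unk() {
        // Sketch of the expected out-of-vocabulary behavior: words never seen during
        // build_vocab should encode to the <unk> id rather than erroring.
        let mut tokenizer = create_default_tokenizer();
        tokenizer.build_vocab(&["hello world".to_string()]).unwrap();
        let unk_id = tokenizer.unk_token_id().unwrap();
        let encoded = tokenizer.encode("hello unseenword").unwrap();
        assert!(encoded.contains(&unk_id));
    }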
}