use crate::tokenizers::benchmark::dev::traits::TokenizerTrainingMetrics;
use crate::tokenizers::special_tokens::UNKNOWN;
use crate::tokenizers::traits::{Tokenizer, TokenizerError, TokenIdType};
use crate::utils::corpus_tracker::CorpusTracker;
use std::collections::HashMap;
use std::io::BufRead;
use regex::Regex;
use serde_json;
use std::time::Instant;
/// A baseline tokenizer whose vocabulary is the set of whitespace-separated words
/// seen during training, plus a single-space token `" "` and an unknown token.
///
/// The whole value is written to / read from JSON by `Tokenizer::save` and
/// `Tokenizer::load`.
///
/// `T` is the token-id type (defaults to `u64`).
#[derive(Debug, serde_derive::Serialize, serde_derive::Deserialize)]
// Explicit bounds replace the ones serde_derive would infer: serializing only
// requires `T: Serialize` and deserializing only `T: DeserializeOwned`.
#[serde(bound(serialize = "T: serde::Serialize", deserialize = "T: serde::de::DeserializeOwned"))]
pub struct NaiveTokenizer<T: TokenIdType = u64> {
/// Vocabulary string -> token id. Updated together with `id_to_string`.
/// `encode` looks up the lowercased text of each piece in this map.
pub string_to_id: HashMap<String, T>,
/// Token id -> vocabulary string, used by `decode`.
pub id_to_string: HashMap<T, String>,
/// Record of the corpora trained on; `train` adds the corpus name and file size.
pub corpus: CorpusTracker,
/// Id of the unknown token; `None` until `train` has run.
/// `encode` panics while this is `None`.
pub unknown_token_id: Option<T>,
/// Optional path of a saved tokenizer file. `size` reports that file's length in
/// bytes (0 when unset or unreadable). `save`/`load` do not use it; they take their
/// path as an argument.
pub tokenizer_path: Option<String>,
}
impl<T: TokenIdType> NaiveTokenizer<T> {
    /// Creates an empty, untrained tokenizer.
    ///
    /// `tokenizer_path` is stored as-is; nothing is read from or written to it here.
    pub fn new(tokenizer_path: Option<String>) -> Self {
        Self {
            string_to_id: HashMap::new(),
            id_to_string: HashMap::new(),
            corpus: CorpusTracker::new(),
            unknown_token_id: None,
            tokenizer_path,
        }
    }

    /// Ensures the single-space token `" "` is in the vocabulary.
    ///
    /// On an empty vocabulary `" "` receives id `0`. If `" "` is already present
    /// this is a no-op. If the vocabulary is non-empty but lacks `" "`, the token
    /// gets id `len()` (the next free id, assuming ids are dense `0..len()`).
    /// Either way an existing entry of `id_to_string` is never overwritten, so the
    /// two maps stay consistent.
    fn init(&mut self) {
        if self.string_to_id.contains_key(" ") {
            return;
        }
        let id = T::from_usize(self.string_to_id.len());
        self.string_to_id.insert(" ".to_string(), id);
        self.id_to_string.insert(id, " ".to_string());
    }
}
impl<T> Tokenizer for NaiveTokenizer<T>
where
T: TokenIdType + serde::Serialize + serde::de::DeserializeOwned + std::fmt::Debug + Eq + std::hash::Hash + Clone,
{
type TokenId = T;
fn name(&self) -> &'static str {
"NaiveTokenizer"
}
fn corpus(&self) -> &CorpusTracker {
&self.corpus
}
fn corpus_mut(&mut self) -> &mut CorpusTracker {
&mut self.corpus
}
fn vocab_size(&self) -> usize {
self.string_to_id.len()
}
fn train(&mut self, data_path: &str, corpus_name: &str) -> Result<TokenizerTrainingMetrics, TokenizerError> {
let start_time = Instant::now();
self.init();
if !std::path::Path::new(data_path).is_file() {
return Err(format!("'{}' is not a valid file", data_path));
}
let file = std::fs::File::open(data_path)
.map_err(|e| format!("Failed to open '{}': {}", data_path, e))?;
let reader = std::io::BufReader::new(file);
let size = std::fs::metadata(data_path)
.map_err(|e| format!("Failed to get metadata for '{}': {}", data_path, e))?
.len();
self.corpus.add(corpus_name.to_string(), size);
for line in reader.lines() {
let line = line.map_err(|e| format!("Failed to read line: {}", e))?;
let tokens: Vec<&str> = line.split_whitespace().collect();
for token in tokens {
if self.string_to_id.contains_key(token) {
continue;
}
let id = T::from_usize(self.string_to_id.len());
self.string_to_id.insert(token.to_string(), id);
self.id_to_string.insert(id, token.to_string());
}
}
let unkown_token_id = T::from_usize(self.string_to_id.len());
self.id_to_string.insert(unkown_token_id, UNKNOWN.to_owned());
self.string_to_id.insert(UNKNOWN.to_owned(), unkown_token_id);
self.unknown_token_id = Some(unkown_token_id);
Ok(TokenizerTrainingMetrics {
training_time_sec: start_time.elapsed().as_secs(),
})
}
fn encode(&self, text: &str) -> Vec<Self::TokenId> {
assert!(self.unknown_token_id.is_some(), "Tokenizer must be trained before encoding");
let mut res: Vec<Self::TokenId> = vec![];
let re = Regex::new(r"\S+|\s+").unwrap();
let text_tokens: Vec<&str> = re.find_iter(text).map(|mat| mat.as_str()).collect();
for token in text_tokens {
if let Some(&id) = self.string_to_id.get(token.to_lowercase().as_str()) {
res.push(id);
} else {
res.push(self.unknown_token_id.unwrap());
}
}
res
}
fn decode(&self, tokens: &[Self::TokenId]) -> String {
let mut res = String::new();
for &token_id in tokens {
if let Some(token) = self.id_to_string.get(&token_id) {
res.push_str(token);
} else {
res.push_str(UNKNOWN);
}
}
res
}
fn save(&self, path: &str) -> Result<(), TokenizerError> {
let json = serde_json::to_string(self)
.map_err(|e| format!("Failed to serialize tokenizer: {}", e))?;
std::fs::write(path, json)
.map_err(|e| format!("Failed to write tokenizer to '{}': {}", path, e))?;
Ok(())
}
fn load(&mut self, path: &str) -> Result<(), TokenizerError> {
let json = std::fs::read_to_string(path)
.map_err(|e| format!("Failed to read tokenizer from '{}': {}", path, e))?;
*self = serde_json::from_str(&json)
.map_err(|e| format!("Failed to deserialize tokenizer: {}", e))?;
Ok(())
}
fn vocab_tokens(&self) -> Vec<String> {
self.string_to_id.keys().cloned().collect()
}
fn size(&self) -> usize {
if let Some(ref path) = self.tokenizer_path {
std::fs::metadata(path)
.map(|meta| meta.len() as usize)
.unwrap_or(0)
} else {
0
}
}
}