use std::path::Path;
use tokenizers::models::bpe::{BpeTrainer, BPE};
use tokenizers::normalizers::Sequence as NormalizerSequence;
use tokenizers::normalizers::{Lowercase, StripAccents, NFD};
use tokenizers::pre_tokenizers::whitespace::Whitespace;
use tokenizers::{AddedToken, Tokenizer};
/// Errors produced while training, using, or persisting a subword tokenizer.
#[derive(Debug)]
pub enum SubwordError {
    /// An error reported by the underlying `tokenizers` crate, kept as its message.
    Tokenizer(String),
    /// A standard-library I/O error.
    Io(std::io::Error),
    /// A problem related to training, e.g. no training texts were provided.
    Training(String),
}

impl std::fmt::Display for SubwordError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Tokenizer(msg) => write!(f, "Tokenizer error: {}", msg),
            Self::Io(e) => write!(f, "I/O error: {}", e),
            Self::Training(msg) => write!(f, "Training error: {}", msg),
        }
    }
}

impl std::error::Error for SubwordError {
    /// Exposes the wrapped `io::Error` so error-chain walkers (loggers, `anyhow`,
    /// etc.) can reach the root cause. The string-based variants carry no
    /// underlying error object, so they report no source.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::Io(e) => Some(e),
            Self::Tokenizer(_) | Self::Training(_) => None,
        }
    }
}

impl From<std::io::Error> for SubwordError {
    fn from(e: std::io::Error) -> Self {
        Self::Io(e)
    }
}

/// Convenience alias for results whose error type is [`SubwordError`].
pub type Result<T> = std::result::Result<T, SubwordError>;
/// String forms of the special tokens registered by [`BpeConfig::default`].
pub mod special_tokens {
    /// `<unk>` — conventionally the unknown / out-of-vocabulary token.
    pub const UNK: &str = "<unk>";
    /// `<pad>` — conventionally the padding token.
    pub const PAD: &str = "<pad>";
    /// `<s>` — conventionally the beginning-of-sequence marker.
    pub const BOS: &str = "<s>";
    /// `</s>` — conventionally the end-of-sequence marker.
    pub const EOS: &str = "</s>";
}

/// Settings consumed by `SubwordTokenizer::train_bpe`.
#[derive(Debug, Clone)]
pub struct BpeConfig {
    /// Target vocabulary size handed to the BPE trainer.
    pub vocab_size: usize,
    /// Minimum frequency handed to the BPE trainer.
    pub min_frequency: u64,
    /// Whether a lowercasing normalizer is added to the pipeline.
    pub lowercase: bool,
    /// Whether an accent-stripping normalizer is added to the pipeline.
    pub strip_accents: bool,
    /// Tokens registered with the trainer as special tokens.
    pub special_tokens: Vec<String>,
    /// Whether the trainer shows a progress indicator.
    pub show_progress: bool,
}

impl Default for BpeConfig {
    /// Defaults:
    /// - `vocab_size` 30 000 and `min_frequency` 2;
    /// - lowercasing on, accent stripping off, progress display on;
    /// - special tokens UNK, PAD, BOS, EOS, in that order.
    fn default() -> Self {
        let specials: Vec<String> = [
            special_tokens::UNK,
            special_tokens::PAD,
            special_tokens::BOS,
            special_tokens::EOS,
        ]
        .iter()
        .map(|tok| tok.to_string())
        .collect();

        Self {
            show_progress: true,
            special_tokens: specials,
            strip_accents: false,
            lowercase: true,
            min_frequency: 2,
            vocab_size: 30_000,
        }
    }
}
pub struct SubwordTokenizer {
tokenizer: Tokenizer,
}
impl SubwordTokenizer {
pub fn new(tokenizer: Tokenizer) -> Self {
Self { tokenizer }
}
pub fn train_bpe<I, S>(texts: I, config: BpeConfig) -> Result<Self>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
use tokenizers::models::TrainerWrapper;
let trainer = BpeTrainer::builder()
.vocab_size(config.vocab_size)
.min_frequency(config.min_frequency)
.show_progress(config.show_progress)
.special_tokens(
config
.special_tokens
.iter()
.map(|s| AddedToken::from(s.clone(), true))
.collect(),
)
.build();
let mut tokenizer = Tokenizer::new(BPE::default());
let mut normalizers = Vec::new();
normalizers.push(tokenizers::NormalizerWrapper::NFD(NFD));
if config.strip_accents {
normalizers.push(tokenizers::NormalizerWrapper::StripAccents(StripAccents));
}
if config.lowercase {
normalizers.push(tokenizers::NormalizerWrapper::Lowercase(Lowercase));
}
if !normalizers.is_empty() {
tokenizer.with_normalizer(Some(NormalizerSequence::new(normalizers)));
}
tokenizer.with_pre_tokenizer(Some(Whitespace::default()));
let text_refs: Vec<String> = texts.into_iter().map(|s| s.as_ref().to_string()).collect();
if text_refs.is_empty() {
return Err(SubwordError::Training(
"No training texts provided".to_string(),
));
}
let mut wrapper_trainer = TrainerWrapper::BpeTrainer(trainer);
tokenizer
.train(&mut wrapper_trainer, text_refs.iter().map(|s| s.as_str()))
.map_err(|e| SubwordError::Tokenizer(e.to_string()))?;
Ok(Self { tokenizer })
}
pub fn train_bpe_default<I, S>(texts: I, vocab_size: usize) -> Result<Self>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
let config = BpeConfig {
vocab_size,
..Default::default()
};
Self::train_bpe(texts, config)
}
pub fn load<P: AsRef<Path>>(path: P) -> Result<Self> {
let tokenizer = Tokenizer::from_file(path.as_ref())
.map_err(|e| SubwordError::Tokenizer(e.to_string()))?;
Ok(Self { tokenizer })
}
pub fn save<P: AsRef<Path>>(&self, path: P) -> Result<()> {
self.tokenizer
.save(path.as_ref(), false)
.map_err(|e| SubwordError::Tokenizer(e.to_string()))?;
Ok(())
}
pub fn encode(&self, text: &str) -> Vec<String> {
match self.tokenizer.encode(text, false) {
Ok(encoding) => encoding.get_tokens().to_vec(),
Err(_) => vec![],
}
}
pub fn encode_ids(&self, text: &str) -> Vec<u32> {
match self.tokenizer.encode(text, false) {
Ok(encoding) => encoding.get_ids().to_vec(),
Err(_) => vec![],
}
}
pub fn encode_with_special(&self, text: &str) -> Vec<String> {
match self.tokenizer.encode(text, true) {
Ok(encoding) => encoding.get_tokens().to_vec(),
Err(_) => vec![],
}
}
pub fn decode(&self, ids: &[u32]) -> String {
match self.tokenizer.decode(ids, true) {
Ok(text) => text,
Err(_) => String::new(),
}
}
pub fn decode_tokens(&self, tokens: &[String]) -> String {
let ids: Vec<u32> = tokens
.iter()
.filter_map(|t| self.tokenizer.token_to_id(t))
.collect();
self.decode(&ids)
}
pub fn vocab_size(&self) -> usize {
self.tokenizer.get_vocab_size(true)
}
pub fn token_to_id(&self, token: &str) -> Option<u32> {
self.tokenizer.token_to_id(token)
}
pub fn id_to_token(&self, id: u32) -> Option<String> {
self.tokenizer.id_to_token(id)
}
pub fn contains(&self, token: &str) -> bool {
self.tokenizer.token_to_id(token).is_some()
}
pub fn inner(&self) -> &Tokenizer {
&self.tokenizer
}
pub fn inner_mut(&mut self) -> &mut Tokenizer {
&mut self.tokenizer
}
}
/// Lazy iterator adaptor that encodes each sentence from an inner iterator with a
/// [`SubwordTokenizer`], yielding one `Vec` of token strings per sentence.
///
/// Each sentence is encoded with [`SubwordTokenizer::encode`], so a sentence that
/// fails to encode yields an empty `Vec`.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct TokenizedSentences<'a, I> {
    sentences: I,
    tokenizer: &'a SubwordTokenizer,
}

impl<'a, I> TokenizedSentences<'a, I> {
    /// Wraps `sentences` so that each item is encoded with `tokenizer` when pulled.
    pub fn new(sentences: I, tokenizer: &'a SubwordTokenizer) -> Self {
        Self {
            sentences,
            tokenizer,
        }
    }
}

impl<'a, I, S> Iterator for TokenizedSentences<'a, I>
where
    I: Iterator<Item = S>,
    S: AsRef<str>,
{
    type Item = Vec<String>;

    fn next(&mut self) -> Option<Self::Item> {
        let sentence = self.sentences.next()?;
        Some(self.tokenizer.encode(sentence.as_ref()))
    }

    // Exactly one output per input item, so the inner iterator's bounds carry over
    // unchanged. This lets `collect` and `extend` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.sentences.size_hint()
    }
}

// One-to-one mapping: if the input knows its exact length, so does the output.
impl<'a, I, S> ExactSizeIterator for TokenizedSentences<'a, I>
where
    I: ExactSizeIterator<Item = S>,
    S: AsRef<str>,
{
}

/// Extension trait adding `.tokenize(&tokenizer)` to any iterator of string-like items.
pub trait TokenizeExt<'a>: Sized {
    /// Wraps `self` in a [`TokenizedSentences`] adaptor that encodes each item with `tokenizer`.
    fn tokenize(self, tokenizer: &'a SubwordTokenizer) -> TokenizedSentences<'a, Self>;
}

impl<'a, I, S> TokenizeExt<'a> for I
where
    I: Iterator<Item = S>,
    S: AsRef<str>,
{
    fn tokenize(self, tokenizer: &'a SubwordTokenizer) -> TokenizedSentences<'a, Self> {
        TokenizedSentences::new(self, tokenizer)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small mixed-case English corpus shared by the training tests.
    fn sample_texts() -> Vec<&'static str> {
        vec![
            "The quick brown fox jumps over the lazy dog.",
            "Hello world! How are you today?",
            "Machine learning is transforming technology.",
            "Natural language processing enables computers to understand text.",
            "The fox jumped over the fence.",
            "Hello there, how is everything going?",
            "Deep learning models learn representations.",
            "Text processing is fundamental to NLP.",
        ]
    }

    /// Small, quiet configuration that trains quickly on `sample_texts`.
    fn test_config() -> BpeConfig {
        BpeConfig {
            vocab_size: 100,
            min_frequency: 1,
            show_progress: false,
            ..Default::default()
        }
    }

    /// Trains a tokenizer on the sample corpus, panicking with context on failure.
    fn trained_tokenizer() -> SubwordTokenizer {
        SubwordTokenizer::train_bpe(sample_texts(), test_config())
            .expect("training on the sample corpus should succeed")
    }

    #[test]
    fn test_train_bpe() {
        let tokenizer = trained_tokenizer();
        assert!(tokenizer.vocab_size() > 0);
    }

    #[test]
    fn test_encode_decode() {
        let tokenizer = trained_tokenizer();
        let text = "hello world";

        let tokens = tokenizer.encode(text);
        let ids = tokenizer.encode_ids(text);
        assert!(!tokens.is_empty());
        assert_eq!(tokens.len(), ids.len());
        // Token strings and IDs must describe the same encoding.
        for (token, &id) in tokens.iter().zip(&ids) {
            assert_eq!(tokenizer.token_to_id(token), Some(id));
        }

        // No decoder is configured, so pieces may be re-joined with spaces;
        // compare the content with all whitespace removed.
        let decoded = tokenizer.decode(&ids);
        let squashed: String = decoded.split_whitespace().collect();
        assert_eq!(squashed, "helloworld");
        assert_eq!(tokenizer.decode_tokens(&tokens), decoded);
    }

    #[test]
    fn test_lowercase_normalization() {
        // `lowercase` defaults to true, so case must not affect the encoding.
        let tokenizer = trained_tokenizer();
        assert_eq!(
            tokenizer.encode("HELLO World"),
            tokenizer.encode("hello world")
        );
    }

    #[test]
    fn test_vocab_operations() {
        let tokenizer = trained_tokenizer();
        let specials = [
            special_tokens::UNK,
            special_tokens::PAD,
            special_tokens::BOS,
            special_tokens::EOS,
        ];
        for special in specials {
            let id = tokenizer
                .token_to_id(special)
                .unwrap_or_else(|| panic!("special token {} missing from vocabulary", special));
            assert!(tokenizer.contains(special));
            assert_eq!(tokenizer.id_to_token(id).as_deref(), Some(special));
        }
        assert!(!tokenizer.contains("definitely-not-a-token"));
    }

    #[test]
    fn test_tokenize_iterator() {
        let tokenizer = trained_tokenizer();
        let sentences = vec!["hello world", "test sentence"];
        let tokenized: Vec<Vec<String>> = sentences.iter().tokenize(&tokenizer).collect();
        assert_eq!(tokenized.len(), 2);
        assert!(!tokenized[0].is_empty());
        assert!(!tokenized[1].is_empty());
        // The adaptor must agree with direct encoding.
        assert_eq!(tokenized[0], tokenizer.encode("hello world"));
        assert_eq!(tokenized[1], tokenizer.encode("test sentence"));
    }

    #[test]
    fn test_empty_training() {
        let result = SubwordTokenizer::train_bpe(Vec::<&str>::new(), test_config());
        assert!(matches!(result, Err(SubwordError::Training(_))));
    }

    #[test]
    fn test_error_display() {
        let io_err: SubwordError =
            std::io::Error::new(std::io::ErrorKind::NotFound, "gone").into();
        assert_eq!(io_err.to_string(), "I/O error: gone");
        assert_eq!(
            SubwordError::Training("x".to_string()).to_string(),
            "Training error: x"
        );
        assert_eq!(
            SubwordError::Tokenizer("y".to_string()).to_string(),
            "Tokenizer error: y"
        );
    }

    #[test]
    fn test_save_load() {
        let tokenizer = trained_tokenizer();
        let temp_dir = tempfile::tempdir().unwrap();
        let path = temp_dir.path().join("tokenizer.json");
        tokenizer.save(&path).unwrap();
        assert!(path.exists());

        let loaded = SubwordTokenizer::load(&path).unwrap();
        assert_eq!(loaded.vocab_size(), tokenizer.vocab_size());
        let text = "hello world";
        assert_eq!(loaded.encode(text), tokenizer.encode(text));
        assert_eq!(loaded.encode_ids(text), tokenizer.encode_ids(text));
    }
}