use serde::{Deserialize, Serialize};
/// Vocabulary size that `TokenizerConfig::default()` starts from.
const DEFAULT_VOCAB_SIZE: usize = 32_000;
/// Text of the five special tokens carried by a tokenizer configuration.
///
/// Every field is a plain `String` and this type performs no validation.
/// The `Default` impl fills in `<unk>`, `<s>`, `</s>`, `<pad>` and `<mask>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SpecialTokens {
/// `unk` token text; default `<unk>`. Presumably the unknown-token placeholder.
pub unk: String,
/// `bos` token text; default `<s>`. Presumably the beginning-of-sequence marker.
pub bos: String,
/// `eos` token text; default `</s>`. Presumably the end-of-sequence marker.
pub eos: String,
/// `pad` token text; default `<pad>`. Presumably used for padding.
pub pad: String,
/// `mask` token text; default `<mask>`. Presumably used for masking.
pub mask: String,
}
impl Default for SpecialTokens {
fn default() -> Self {
Self {
unk: "<unk>".to_string(),
bos: "<s>".to_string(),
eos: "</s>".to_string(),
pad: "<pad>".to_string(),
mask: "<mask>".to_string(),
}
}
}
/// Identifies which kind of tokenizer a `TokenizerConfig` describes.
///
/// The enum is `Copy` and comparable with `==`; it carries no data of its own.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TokenizerType {
/// Selected by `TokenizerConfig::bpe()` and by `TokenizerConfig::default()`.
/// Presumably byte-pair encoding.
BPE,
/// Selected by `TokenizerConfig::wordpiece()`.
WordPiece,
/// Selected by `TokenizerConfig::char()`, which also sets the vocabulary size to 256.
Char,
}
/// Settings describing a tokenizer: its kind, vocabulary size, frequency
/// setting, special tokens and case handling.
///
/// This type only stores values; how they are interpreted is up to the code
/// that consumes the config (not visible in this file). Build one with
/// `TokenizerConfig::default()` or a preset (`bpe`, `wordpiece`, `char`) and
/// refine it with the `with_*` builder methods.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenizerConfig {
/// Vocabulary size. `Default` uses `DEFAULT_VOCAB_SIZE`; the `char` preset uses 256.
pub vocab_size: usize,
/// Frequency setting; `Default` uses 2. Presumably the minimum occurrence
/// count for a token to be kept — confirm against the consumer.
pub min_frequency: usize,
/// Special token strings; `Default` uses `SpecialTokens::default()`.
pub special_tokens: SpecialTokens,
/// Lowercasing flag; `Default` uses `false`. Presumably requests lowercasing
/// of input text — confirm against the consumer.
pub lowercase: bool,
/// Which tokenizer kind this config is for; `Default` uses `TokenizerType::BPE`.
pub tokenizer_type: TokenizerType,
}
impl Default for TokenizerConfig {
fn default() -> Self {
Self {
vocab_size: DEFAULT_VOCAB_SIZE,
min_frequency: 2,
special_tokens: SpecialTokens::default(),
lowercase: false,
tokenizer_type: TokenizerType::BPE,
}
}
}
impl TokenizerConfig {
    /// Returns the default configuration with the tokenizer type set to
    /// [`TokenizerType::BPE`].
    ///
    /// The type is set explicitly so this preset stays BPE even if the
    /// `Default` impl changes its own choice.
    #[must_use]
    pub fn bpe() -> Self {
        Self {
            tokenizer_type: TokenizerType::BPE,
            ..Self::default()
        }
    }

    /// Returns the default configuration with the tokenizer type set to
    /// [`TokenizerType::WordPiece`].
    #[must_use]
    pub fn wordpiece() -> Self {
        Self {
            tokenizer_type: TokenizerType::WordPiece,
            ..Self::default()
        }
    }

    /// Returns the default configuration with the tokenizer type set to
    /// [`TokenizerType::Char`] and the vocabulary size set to 256.
    #[must_use]
    pub fn char() -> Self {
        Self {
            tokenizer_type: TokenizerType::Char,
            vocab_size: 256,
            ..Self::default()
        }
    }

    /// Consumes the config and returns it with `vocab_size` replaced by `size`.
    ///
    /// No validation is performed on `size`.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_vocab_size(mut self, size: usize) -> Self {
        self.vocab_size = size;
        self
    }

    /// Consumes the config and returns it with `min_frequency` replaced by `freq`.
    ///
    /// No validation is performed on `freq`.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_min_frequency(mut self, freq: usize) -> Self {
        self.min_frequency = freq;
        self
    }

    /// Consumes the config and returns it with the `lowercase` flag replaced
    /// by `lowercase`.
    #[must_use = "builder methods return the updated config instead of mutating in place"]
    pub fn with_lowercase(mut self, lowercase: bool) -> Self {
        self.lowercase = lowercase;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // The default config is BPE with a 32000-entry vocabulary.
    #[test]
    fn test_tokenizer_config_default() {
        let TokenizerConfig {
            vocab_size,
            tokenizer_type,
            ..
        } = TokenizerConfig::default();
        assert_eq!(vocab_size, 32000);
        assert_eq!(tokenizer_type, TokenizerType::BPE);
    }

    // The BPE preset keeps its type when the vocabulary size is overridden.
    #[test]
    fn test_tokenizer_config_bpe() {
        let cfg = TokenizerConfig::bpe().with_vocab_size(1000);
        assert_eq!(cfg.vocab_size, 1000);
        assert_eq!(cfg.tokenizer_type, TokenizerType::BPE);
    }

    // The WordPiece preset reports the WordPiece type.
    #[test]
    fn test_tokenizer_config_wordpiece() {
        let kind = TokenizerConfig::wordpiece().tokenizer_type;
        assert_eq!(kind, TokenizerType::WordPiece);
    }

    // The char preset reports the Char type and a 256-entry vocabulary.
    #[test]
    fn test_tokenizer_config_char() {
        let cfg = TokenizerConfig::char();
        assert_eq!(cfg.tokenizer_type, TokenizerType::Char);
        assert_eq!(cfg.vocab_size, 256);
    }

    // Default special tokens use the expected unk/bos/eos strings.
    #[test]
    fn test_special_tokens_default() {
        let SpecialTokens { unk, bos, eos, .. } = SpecialTokens::default();
        assert_eq!(unk, "<unk>");
        assert_eq!(bos, "<s>");
        assert_eq!(eos, "</s>");
    }
}