entrenar/tokenizer/
config.rs1use serde::{Deserialize, Serialize};
4
/// Vocabulary size that `TokenizerConfig::default()` starts from.
const DEFAULT_VOCAB_SIZE: usize = 32_000;
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SpecialTokens {
11 pub unk: String,
13 pub bos: String,
15 pub eos: String,
17 pub pad: String,
19 pub mask: String,
21}
22
23impl Default for SpecialTokens {
24 fn default() -> Self {
25 Self {
26 unk: "<unk>".to_string(),
27 bos: "<s>".to_string(),
28 eos: "</s>".to_string(),
29 pad: "<pad>".to_string(),
30 mask: "<mask>".to_string(),
31 }
32 }
33}
34
35#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
37pub enum TokenizerType {
38 BPE,
40 WordPiece,
42 Char,
44}
45
46#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
53pub enum Normalization {
54 #[default]
56 None,
57 NFC,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
63pub struct TokenizerConfig {
64 pub vocab_size: usize,
66 pub min_frequency: usize,
68 pub special_tokens: SpecialTokens,
70 pub lowercase: bool,
72 pub tokenizer_type: TokenizerType,
74 #[serde(default)]
76 pub normalization: Normalization,
77}
78
79impl Default for TokenizerConfig {
80 fn default() -> Self {
81 Self {
82 vocab_size: DEFAULT_VOCAB_SIZE,
83 min_frequency: 2,
84 special_tokens: SpecialTokens::default(),
85 lowercase: false,
86 tokenizer_type: TokenizerType::BPE,
87 normalization: Normalization::default(),
88 }
89 }
90}
91
92impl TokenizerConfig {
93 pub fn bpe() -> Self {
95 Self { tokenizer_type: TokenizerType::BPE, ..Default::default() }
96 }
97
98 pub fn wordpiece() -> Self {
100 Self { tokenizer_type: TokenizerType::WordPiece, ..Default::default() }
101 }
102
103 pub fn char() -> Self {
105 Self { tokenizer_type: TokenizerType::Char, vocab_size: 256, ..Default::default() }
106 }
107
108 pub fn with_vocab_size(mut self, size: usize) -> Self {
110 self.vocab_size = size;
111 self
112 }
113
114 pub fn with_min_frequency(mut self, freq: usize) -> Self {
116 self.min_frequency = freq;
117 self
118 }
119
120 pub fn with_lowercase(mut self, lowercase: bool) -> Self {
122 self.lowercase = lowercase;
123 self
124 }
125
126 pub fn with_normalization(mut self, normalization: Normalization) -> Self {
128 self.normalization = normalization;
129 self
130 }
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// The default configuration carries the stock value for every scalar field.
    #[test]
    fn test_tokenizer_config_default() {
        let config = TokenizerConfig::default();
        assert_eq!(config.vocab_size, 32000);
        assert_eq!(config.tokenizer_type, TokenizerType::BPE);
        assert_eq!(config.min_frequency, 2);
        assert!(!config.lowercase);
        assert_eq!(config.normalization, Normalization::None);
    }

    /// `bpe()` keeps the BPE type and composes with `with_vocab_size`.
    #[test]
    fn test_tokenizer_config_bpe() {
        let config = TokenizerConfig::bpe().with_vocab_size(1000);
        assert_eq!(config.vocab_size, 1000);
        assert_eq!(config.tokenizer_type, TokenizerType::BPE);
    }

    /// `wordpiece()` only changes the tokenizer type.
    #[test]
    fn test_tokenizer_config_wordpiece() {
        let config = TokenizerConfig::wordpiece();
        assert_eq!(config.tokenizer_type, TokenizerType::WordPiece);
        assert_eq!(config.vocab_size, 32000);
    }

    /// `char()` switches the type and shrinks the vocabulary to 256.
    #[test]
    fn test_tokenizer_config_char() {
        let config = TokenizerConfig::char();
        assert_eq!(config.tokenizer_type, TokenizerType::Char);
        assert_eq!(config.vocab_size, 256);
    }

    /// Every default special token is checked, including `pad` and `mask`.
    #[test]
    fn test_special_tokens_default() {
        let special = SpecialTokens::default();
        assert_eq!(special.unk, "<unk>");
        assert_eq!(special.bos, "<s>");
        assert_eq!(special.eos, "</s>");
        assert_eq!(special.pad, "<pad>");
        assert_eq!(special.mask, "<mask>");
    }

    /// The remaining builders each replace their own field and leave the rest alone.
    #[test]
    fn test_tokenizer_config_builders() {
        let config = TokenizerConfig::wordpiece()
            .with_min_frequency(5)
            .with_lowercase(true)
            .with_normalization(Normalization::NFC);
        assert_eq!(config.min_frequency, 5);
        assert!(config.lowercase);
        assert_eq!(config.normalization, Normalization::NFC);
        assert_eq!(config.tokenizer_type, TokenizerType::WordPiece);
        assert_eq!(config.vocab_size, 32000);
    }
}