//! Tokenizers that split raw text into string tokens, plus a `Tokenizer` trait
//! whose default `encode` method maps tokens to vocabulary indices.

use crate::vocab::Vocab;
use std::collections::HashMap;

/// Converts raw text into a sequence of string tokens.
pub trait Tokenizer: Send + Sync {
    /// Splits `text` into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes `text` and encodes each token as its index in `vocab`.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

/// Splits text on Unicode whitespace, optionally lowercasing each token.
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases every token.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

/// Splits text into individual characters, optionally dropping whitespace.
#[derive(Debug, Clone, Default)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a tokenizer that keeps whitespace characters as tokens.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that skips whitespace characters.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

/// Splits text into alphanumeric word tokens and single-character punctuation tokens.
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases word tokens.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                // Flush the pending word before handling the non-alphanumeric character.
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                // Punctuation becomes its own token; whitespace is discarded.
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        // Flush any trailing word.
        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

/// Produces overlapping n-grams at either the word or the character level.
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer; `n` is clamped to at least 1.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer; `n` is clamped to at least 1.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            let chars: Vec<char> = text.chars().collect();
            // Inputs shorter than n are returned unchanged as a single token.
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

/// A minimal byte-pair-encoding (BPE) tokenizer trained on whitespace-split words.
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates an untrained tokenizer with no merge rules.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Learns up to `num_merges` merge rules from `text` based on pair frequencies.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Count each whitespace-split word, represented as space-separated
        // single-character symbols followed by an end-of-word marker.
        let mut vocab: HashMap<String, usize> = HashMap::new();

        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for _ in 0..num_merges {
            // Count how often each adjacent symbol pair occurs across all words.
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Record the most frequent pair as a merge rule and apply it everywhere.
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    // Rebuild each word symbol by symbol so the merge only joins
                    // exact adjacent symbols, never substrings of larger symbols.
                    let symbols: Vec<&str> = word.split(' ').collect();
                    let mut new_symbols: Vec<String> = Vec::new();
                    let mut i = 0;
                    while i < symbols.len() {
                        if i + 1 < symbols.len() && symbols[i] == a && symbols[i + 1] == b {
                            new_symbols.push(merged.clone());
                            i += 2;
                        } else {
                            new_symbols.push(symbols[i].to_string());
                            i += 1;
                        }
                    }
                    *new_vocab.entry(new_symbols.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // Collect every surviving symbol into a sorted vocabulary.
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the learned symbol vocabulary, sorted lexicographically.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies the learned merges to a single word, scanning left to right until
    /// no adjacent pair has a merge rule.
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            // Find the leftmost adjacent pair that has a learned merge.
            let mut best_pair: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    best_pair = Some((i, merged));
                    break;
                }
            }

            match best_pair {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

/// A greedy longest-match tokenizer over a fixed vocabulary of scored tokens.
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a tokenizer from a token-to-score map.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Longest key length in bytes; used only as an upper bound on the scan window.
        let max_len = vocab.keys().map(String::len).max().unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens, all with equal score.
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Greedily picks the longest vocabulary match at each position; unknown
    /// characters fall back to single-character tokens.
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }
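
    // Additional illustrative test: with the default CharTokenizer configuration,
    // whitespace characters are kept as their own tokens.
    #[test]
    fn test_char_tokenizer_keeps_whitespace() {
        let tokenizer = CharTokenizer::new();
        assert_eq!(tokenizer.tokenize("a b"), vec!["a", " ", "b"]);
    }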

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }
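
    // Additional illustrative test: every non-alphanumeric, non-whitespace
    // character becomes its own token, so contractions and decimal numbers are
    // split apart.
    #[test]
    fn test_word_punct_splits_symbols() {
        let tokenizer = WordPunctTokenizer::new();
        assert_eq!(
            tokenizer.tokenize("it's 3.14"),
            vec!["it", "'", "s", "3", ".", "14"]
        );
    }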

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }
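
    // Additional illustrative test: inputs with fewer units than n fall back to a
    // single token containing the whole input.
    #[test]
    fn test_ngram_shorter_than_n() {
        let char_tokenizer = NGramTokenizer::char_ngrams(5);
        assert_eq!(char_tokenizer.tokenize("hi"), vec!["hi"]);

        let word_tokenizer = NGramTokenizer::word_ngrams(3);
        assert_eq!(word_tokenizer.tokenize("one two"), vec!["one two"]);
    }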

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        assert!(!tokenizer.get_vocab().is_empty());
    }
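
    // Additional illustrative test: training on the single word "aaaa" makes
    // ("a", "a") the strictly most frequent pair, so one merge learns the symbol
    // "aa"; the "</w>" marker stays split because no further merges are learned.
    #[test]
    fn test_bpe_single_merge() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("aaaa", 1);

        assert_eq!(
            tokenizer.tokenize("aaaa"),
            vec!["aa", "aa", "<", "/", "w", ">"]
        );
    }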

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }
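
    // Additional illustrative test: apply_bpe appends the "</w>" end-of-word
    // marker before merging, so the concatenated tokens of a word always spell
    // the word followed by "</w>", whatever merges were learned.
    #[test]
    fn test_bpe_tokens_reconstruct_word() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert_eq!(tokens.concat(), "low</w>");
    }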

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        assert!(!tokens.is_empty());
    }
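
    // Additional illustrative test: the greedy pass prefers the longest
    // vocabulary match at each position, so "hello" splits as ["hel", "lo"]
    // rather than into single characters.
    #[test]
    fn test_unigram_longest_match() {
        let tokenizer = UnigramTokenizer::from_tokens(&["hel", "he", "lo", "l", "o"]);
        assert_eq!(tokenizer.tokenize("hello"), vec!["hel", "lo"]);
    }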

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }
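
    // Additional illustrative test: encode() is a default trait method, so any
    // Tokenizer gets it for free; exercised here through CharTokenizer. This
    // assumes Vocab assigns ids in insertion order starting at 0, as
    // test_tokenizer_encode above demonstrates.
    #[test]
    fn test_default_encode_with_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("h");
        vocab.add_token("i");

        let indices = tokenizer.encode("hi", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }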

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello   world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}