use crate::vocab::Vocab;
use std::collections::HashMap;

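/// Converts raw text into a sequence of string tokens.
///
/// Implementors only need to provide `tokenize`; the provided `encode` method
/// maps the resulting tokens to vocabulary indices via `Vocab::encode`.
///
/// A minimal usage sketch (kept out of doctests with `ignore` because the
/// crate path is assumed):
///
/// ```ignore
/// let tokenizer = WhitespaceTokenizer::new();
/// let mut vocab = Vocab::new();
/// vocab.add_token("hello");
/// vocab.add_token("world");
/// assert_eq!(tokenizer.encode("hello world", &vocab), vec![0, 1]);
/// ```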
pub trait Tokenizer: Send + Sync {
    fn tokenize(&self, text: &str) -> Vec<String>;

    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

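/// Splits text on Unicode whitespace, optionally lowercasing each token.
///
/// For example, `WhitespaceTokenizer::new().tokenize("Hello World")` yields
/// `["Hello", "World"]`, while the `lowercase()` constructor yields
/// `["hello", "world"]` for the same input.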
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

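/// Splits text into individual characters, optionally dropping whitespace.
///
/// For example, `CharTokenizer::new().tokenize("Hi!")` yields `["H", "i", "!"]`;
/// the `no_whitespace()` constructor filters out whitespace characters first.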
#[derive(Debug, Clone)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

// Keep `Default` consistent with `new()`: the derived `Default` would set
// `include_whitespace` to `false`, silently behaving like `no_whitespace()`.
impl Default for CharTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

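/// Splits text into alphanumeric words and single punctuation characters,
/// optionally lowercasing the word tokens (punctuation is emitted as-is).
///
/// For example, `WordPunctTokenizer::new().tokenize("Hello, World!")` yields
/// `["Hello", ",", "World", "!"]`.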
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

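/// Produces overlapping n-grams over either whitespace-separated words or
/// characters. Inputs shorter than `n` are returned as a single token.
///
/// For example, `NGramTokenizer::word_ngrams(2).tokenize("one two three")`
/// yields `["one two", "two three"]`, and
/// `NGramTokenizer::char_ngrams(3).tokenize("hello")` yields
/// `["hel", "ell", "llo"]`.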
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            let chars: Vec<char> = text.chars().collect();
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            let words: Vec<&str> = text.split_whitespace().collect();
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

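/// A simplified byte-pair-encoding (BPE) tokenizer.
///
/// `train` learns up to `num_merges` merge rules by repeatedly fusing the most
/// frequent adjacent symbol pair, starting from character-level symbols with a
/// `</w>` end-of-word marker. Merge rules are kept in a `HashMap` without
/// priorities, so `tokenize` applies them greedily left to right rather than in
/// the order they were learned.
///
/// A minimal usage sketch (kept out of doctests with `ignore` because the
/// crate path is assumed):
///
/// ```ignore
/// let mut tokenizer = BasicBPETokenizer::new();
/// tokenizer.train("low low low lower lowest", 5);
/// let tokens = tokenizer.tokenize("low"); // subword pieces of "low</w>"
/// assert!(!tokens.is_empty());
/// ```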
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    pub fn train(&mut self, text: &str, num_merges: usize) {
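        // Count word frequencies, representing each word as space-separated
        // characters terminated by the `</w>` end-of-word marker.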
        let mut vocab: HashMap<String, usize> = HashMap::new();

        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

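        // Repeatedly find the most frequent adjacent symbol pair, record it as
        // a merge rule, and apply it across the whole working vocabulary.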
        for _ in 0..num_merges {
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                let pattern = format!("{a} {b}");
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let new_word = word.replace(&pattern, &merged);
                    *new_vocab.entry(new_word).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

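        // Collect every remaining symbol as the learned subword vocabulary.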
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

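        // Apply learned merges greedily left to right until no adjacent pair
        // has a rule. This is a simplification: merges are not replayed in the
        // order they were learned.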
        loop {
            let mut best_pair: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    best_pair = Some((i, merged));
                    break;
                }
            }

            match best_pair {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

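/// A greedy longest-match tokenizer over a fixed vocabulary.
///
/// The `f32` scores are stored but not consulted: at each position the longest
/// matching vocabulary entry is taken, falling back to a single character when
/// nothing matches. For example, with `["hel", "lo", "wor", "ld"]` in the
/// vocabulary, `"hello"` is split into `["hel", "lo"]`.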
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Measure the longest token in characters so the bound matches the
        // char-indexed greedy search below (`String::len` would count bytes).
        let max_len = vocab
            .keys()
            .map(|token| token.chars().count())
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

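    // Greedy longest-match segmentation: at each position, take the longest
    // vocabulary entry that matches, or a single character if none does.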
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        assert!(!tokens.is_empty());
    }

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello   world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}