//! Tokenizers: strategies for splitting raw text into string tokens, from
//! simple whitespace splitting to basic BPE and unigram segmentation.

use crate::vocab::Vocab;
use std::collections::HashMap;

/// Common interface for all tokenizers in this module.
pub trait Tokenizer: Send + Sync {
    /// Splits `text` into a sequence of string tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes `text` and maps each token to its index in `vocab`.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

/// Splits text on Unicode whitespace, optionally lowercasing each token.
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases every token.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

/// Splits text into single-character tokens (Unicode scalar values).
#[derive(Debug, Clone)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a tokenizer that keeps whitespace characters as tokens.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that drops whitespace characters.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

// `derive(Default)` would set `include_whitespace: false`, silently
// disagreeing with `new()`, so `Default` is implemented manually.
impl Default for CharTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

/// Splits text into alphanumeric runs and single punctuation characters;
/// whitespace is discarded.
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases every word token.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                current.push(c);
            } else {
                // A non-alphanumeric character ends the current word.
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                // Punctuation becomes its own token; whitespace is dropped.
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        // Flush a trailing word, if any.
        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

/// Produces overlapping n-grams at either the word or the character level.
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer; `n` is clamped to at least 1.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer; `n` is clamped to at least 1.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            let chars: Vec<char> = text.chars().collect();
            // Text shorter than `n` yields the whole text as one token.
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            let words: Vec<&str> = text.split_whitespace().collect();
            // Fewer than `n` words yields the whole text as one token.
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

/// A minimal byte-pair-encoding (BPE) tokenizer.
///
/// Merges are learned by [`Self::train`] and applied in learned order;
/// each word is terminated with the `</w>` marker before merging.
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    merge_priority: HashMap<(String, String), usize>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates an untrained tokenizer with no merges.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            merge_priority: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Learns up to `num_merges` BPE merges from the whitespace-split words
    /// in `text`, then records the resulting symbol vocabulary.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Word frequencies, with each word stored as space-separated symbols
        // (initially single characters) plus a `</w>` end marker.
        let mut vocab: HashMap<String, usize> = HashMap::new();
        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for merge_idx in 0..num_merges {
            // Count every adjacent symbol pair across the corpus.
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Pick the most frequent pair; ties are broken arbitrarily
            // because `HashMap` iteration order is unspecified.
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                let pair_key = (a.clone(), b.clone());
                self.merges.insert(pair_key.clone(), merged.clone());
                self.merge_priority.insert(pair_key, merge_idx);

                // Apply the merge symbol-by-symbol rather than with a raw
                // substring replace, which could match across a symbol
                // boundary (e.g. the pair ("o", "w") inside "lo w").
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let symbols: Vec<&str> = word.split(' ').collect();
                    let mut out: Vec<String> = Vec::with_capacity(symbols.len());
                    let mut i = 0;
                    while i < symbols.len() {
                        if i + 1 < symbols.len() && symbols[i] == a && symbols[i + 1] == b {
                            out.push(merged.clone());
                            i += 2;
                        } else {
                            out.push(symbols[i].to_string());
                            i += 1;
                        }
                    }
                    *new_vocab.entry(out.join(" ")).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // The final symbol vocabulary is every distinct symbol left in the
        // corpus, sorted for determinism.
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the symbol vocabulary produced by the last call to `train`.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies the learned merges to a single word, earliest-learned
    /// (lowest priority index) first, and returns its BPE symbols.
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            // Find the mergeable adjacent pair with the best (lowest)
            // priority; `None` means no more merges apply.
            let mut best: Option<(usize, usize, &str)> = None;
            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    let priority = self
                        .merge_priority
                        .get(&pair)
                        .copied()
                        .unwrap_or(usize::MAX);
                    if best.map_or(true, |(_, p, _)| priority < p) {
                        best = Some((i, priority, merged));
                    }
                }
            }

            match best {
                Some((i, _, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

/// A unigram tokenizer that segments each word with a Viterbi search over
/// per-token scores.
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a tokenizer from a map of token to score; the score is
    /// treated as a probability, and its natural log is used during search.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Longest token length in characters (not bytes), since the
        // Viterbi search windows over characters.
        let max_len = vocab
            .keys()
            .map(|t| t.chars().count())
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer that assigns every token a uniform score of 1.0.
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Segments `text` into the highest-scoring sequence of vocabulary
    /// tokens using Viterbi dynamic programming over characters.
    fn viterbi_tokenize(&self, text: &str) -> Vec<String> {
        let chars: Vec<char> = text.chars().collect();
        let n = chars.len();
        if n == 0 {
            return Vec::new();
        }

        // dp[i] = best log-score for chars[..i]; back[i] = length of the
        // final token in that best segmentation.
        let mut dp = vec![f32::NEG_INFINITY; n + 1];
        let mut back = vec![1usize; n + 1];
        dp[0] = 0.0;

        for i in 1..=n {
            for len in 1..=self.max_token_length.min(i) {
                let start = i - len;
                let candidate: String = chars[start..i].iter().collect();
                if let Some(&score) = self.vocab.get(&candidate) {
                    let log_score = score.ln();
                    let total = dp[start] + log_score;
                    if total > dp[i] {
                        dp[i] = total;
                        back[i] = len;
                    }
                }
            }
            // No vocabulary token ends here: fall back to a single
            // character with a fixed penalty so segmentation never fails.
            if dp[i] == f32::NEG_INFINITY {
                dp[i] = dp[i - 1] - 10.0;
                back[i] = 1;
            }
        }

        // Recover the token boundaries by walking `back` from the end.
        let mut tokens = Vec::new();
        let mut pos = n;
        while pos > 0 {
            let len = back[pos];
            let token: String = chars[pos - len..pos].iter().collect();
            tokens.push(token);
            pos -= len;
        }
        tokens.reverse();
        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.viterbi_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

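    // `Tokenizer` is object-safe and `Send + Sync`, so different strategies
    // can be swapped behind a trait object.
    #[test]
    fn test_tokenizer_trait_object() {
        let tokenizers: Vec<Box<dyn Tokenizer>> = vec![
            Box::new(WhitespaceTokenizer::new()),
            Box::new(CharTokenizer::new()),
        ];
        let lens: Vec<usize> = tokenizers.iter().map(|t| t.tokenize("a b").len()).collect();

        assert_eq!(lens, vec![2, 3]);
    }
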
    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }

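    // `chars()` iterates Unicode scalar values, so a multi-byte character
    // stays intact as a single token.
    #[test]
    fn test_char_tokenizer_unicode() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("né!");

        assert_eq!(tokens, vec!["n", "é", "!"]);
    }
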
    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }

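    // Digits count as alphanumeric, so they stay attached to letters, while
    // an apostrophe splits like any other punctuation.
    #[test]
    fn test_word_punct_tokenizer_digits() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("it's v2");

        assert_eq!(tokens, vec!["it", "'", "s", "v2"]);
    }
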
    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }

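    // Input shorter than `n` falls back to a single token containing the
    // whole text, at both the character and the word level.
    #[test]
    fn test_ngram_shorter_than_n() {
        let char_tok = NGramTokenizer::char_ngrams(5);
        assert_eq!(char_tok.tokenize("hi"), vec!["hi"]);

        let word_tok = NGramTokenizer::word_ngrams(3);
        assert_eq!(word_tok.tokenize("one two"), vec!["one two"]);
    }
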
    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }

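    // Whatever merges are learned, the BPE symbols of a word always
    // concatenate back to the word plus its `</w>` end marker.
    #[test]
    fn test_bpe_tokens_reconstruct_word() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert_eq!(tokens.concat(), "low</w>");
    }
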
    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        assert!(!tokens.is_empty());
    }

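    // With "ab" in the vocabulary and "x" out of it, the Viterbi search
    // uses the known token and falls back to single characters elsewhere.
    #[test]
    fn test_unigram_fallback_segmentation() {
        let tokenizer = UnigramTokenizer::from_tokens(&["ab"]);
        let tokens = tokenizer.tokenize("abx");

        assert_eq!(tokens, vec!["ab", "x"]);
    }
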
    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello   world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}