//! Tokenizers that split text into string tokens and, via [`Vocab`],
//! map those tokens to vocabulary indices.

use crate::vocab::Vocab;
use std::collections::HashMap;

/// Common interface for all tokenizers in this module.
pub trait Tokenizer: Send + Sync {
    /// Splits `text` into tokens.
    fn tokenize(&self, text: &str) -> Vec<String>;

    /// Tokenizes `text` and looks up each token's index in `vocab`.
    fn encode(&self, text: &str, vocab: &Vocab) -> Vec<usize> {
        let tokens = self.tokenize(text);
        let token_refs: Vec<&str> = tokens.iter().map(std::string::String::as_str).collect();
        vocab.encode(&token_refs)
    }
}

/// Splits text on Unicode whitespace, optionally lowercasing each token.
#[derive(Debug, Clone, Default)]
pub struct WhitespaceTokenizer {
    lowercase: bool,
}

impl WhitespaceTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases every token.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WhitespaceTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        text.split_whitespace()
            .map(|s| {
                if self.lowercase {
                    s.to_lowercase()
                } else {
                    s.to_string()
                }
            })
            .collect()
    }
}

/// Splits text into individual characters, optionally dropping whitespace.
#[derive(Debug, Clone)]
pub struct CharTokenizer {
    include_whitespace: bool,
}

impl CharTokenizer {
    /// Creates a tokenizer that keeps whitespace characters as tokens.
    #[must_use]
    pub fn new() -> Self {
        Self {
            include_whitespace: true,
        }
    }

    /// Creates a tokenizer that skips whitespace characters.
    #[must_use]
    pub fn no_whitespace() -> Self {
        Self {
            include_whitespace: false,
        }
    }
}

impl Default for CharTokenizer {
    /// Equivalent to [`CharTokenizer::new`]: whitespace characters are kept.
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for CharTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.include_whitespace {
            text.chars().map(|c| c.to_string()).collect()
        } else {
            text.chars()
                .filter(|c| !c.is_whitespace())
                .map(|c| c.to_string())
                .collect()
        }
    }
}

/// Splits text into alphanumeric word tokens and single punctuation tokens,
/// optionally lowercasing the word tokens.
#[derive(Debug, Clone, Default)]
pub struct WordPunctTokenizer {
    lowercase: bool,
}

impl WordPunctTokenizer {
    /// Creates a tokenizer that preserves the original casing.
    #[must_use]
    pub fn new() -> Self {
        Self { lowercase: false }
    }

    /// Creates a tokenizer that lowercases every word token.
    #[must_use]
    pub fn lowercase() -> Self {
        Self { lowercase: true }
    }
}

impl Tokenizer for WordPunctTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let mut current = String::new();

        for c in text.chars() {
            if c.is_alphanumeric() {
                // Alphanumeric characters accumulate into the current word.
                current.push(c);
            } else {
                // Any other character ends the current word...
                if !current.is_empty() {
                    tokens.push(if self.lowercase {
                        current.to_lowercase()
                    } else {
                        current.clone()
                    });
                    current.clear();
                }
                // ...and punctuation becomes its own token, while whitespace is dropped.
                if !c.is_whitespace() {
                    tokens.push(c.to_string());
                }
            }
        }

        // Flush a trailing word, if any.
        if !current.is_empty() {
            tokens.push(if self.lowercase {
                current.to_lowercase()
            } else {
                current
            });
        }

        tokens
    }
}

/// Produces overlapping n-grams, either over characters or over
/// whitespace-separated words.
#[derive(Debug, Clone)]
pub struct NGramTokenizer {
    n: usize,
    char_level: bool,
}

impl NGramTokenizer {
    /// Creates a word-level n-gram tokenizer. `n` is clamped to at least 1.
    #[must_use]
    pub fn word_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: false,
        }
    }

    /// Creates a character-level n-gram tokenizer. `n` is clamped to at least 1.
    #[must_use]
    pub fn char_ngrams(n: usize) -> Self {
        Self {
            n: n.max(1),
            char_level: true,
        }
    }
}

impl Tokenizer for NGramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        if self.char_level {
            let chars: Vec<char> = text.chars().collect();
            // Inputs shorter than n are returned as a single token.
            if chars.len() < self.n {
                return vec![text.to_string()];
            }

            chars
                .windows(self.n)
                .map(|w| w.iter().collect::<String>())
                .collect()
        } else {
            let words: Vec<&str> = text.split_whitespace().collect();
            // Inputs with fewer than n words are returned as a single token.
            if words.len() < self.n {
                return vec![text.to_string()];
            }

            words.windows(self.n).map(|w| w.join(" ")).collect()
        }
    }
}

/// A minimal byte-pair-encoding style tokenizer trained with character-level
/// merges and an explicit `</w>` end-of-word marker.
#[derive(Debug, Clone)]
pub struct BasicBPETokenizer {
    merges: HashMap<(String, String), String>,
    vocab: Vec<String>,
}

impl BasicBPETokenizer {
    /// Creates an untrained tokenizer with no merges.
    #[must_use]
    pub fn new() -> Self {
        Self {
            merges: HashMap::new(),
            vocab: Vec::new(),
        }
    }

    /// Learns up to `num_merges` merges from `text`.
    pub fn train(&mut self, text: &str, num_merges: usize) {
        // Count word frequencies, representing each word as space-separated
        // characters followed by an end-of-word marker.
        let mut vocab: HashMap<String, usize> = HashMap::new();
        for word in text.split_whitespace() {
            let word_with_end = format!("{word}</w>");
            let chars: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();
            *vocab.entry(chars.join(" ")).or_insert(0) += 1;
        }

        for _ in 0..num_merges {
            // Count every adjacent symbol pair across the current vocabulary.
            let mut pairs: HashMap<(String, String), usize> = HashMap::new();
            for (word, count) in &vocab {
                let symbols: Vec<&str> = word.split(' ').collect();
                for i in 0..symbols.len().saturating_sub(1) {
                    let pair = (symbols[i].to_string(), symbols[i + 1].to_string());
                    *pairs.entry(pair).or_insert(0) += count;
                }
            }

            if pairs.is_empty() {
                break;
            }

            // Pick the most frequent pair (ties are broken arbitrarily).
            let best_pair = pairs
                .into_iter()
                .max_by_key(|(_, count)| *count)
                .map(|(pair, _)| pair);

            if let Some((a, b)) = best_pair {
                let merged = format!("{a}{b}");
                self.merges.insert((a.clone(), b.clone()), merged.clone());

                // Apply the merge to every word in the vocabulary.
                let pattern = format!("{a} {b}");
                let mut new_vocab = HashMap::new();
                for (word, count) in vocab {
                    let new_word = word.replace(&pattern, &merged);
                    *new_vocab.entry(new_word).or_insert(0) += count;
                }
                vocab = new_vocab;
            }
        }

        // The symbol set remaining after all merges becomes the vocabulary.
        let mut all_symbols: std::collections::HashSet<String> = std::collections::HashSet::new();
        for word in vocab.keys() {
            for symbol in word.split(' ') {
                all_symbols.insert(symbol.to_string());
            }
        }
        self.vocab = all_symbols.into_iter().collect();
        self.vocab.sort();
    }

    /// Returns the learned symbol vocabulary, sorted lexicographically.
    #[must_use]
    pub fn get_vocab(&self) -> &[String] {
        &self.vocab
    }

    /// Applies the learned merges to a single word.
    fn apply_bpe(&self, word: &str) -> Vec<String> {
        // Start from individual characters plus the end-of-word marker.
        let word_with_end = format!("{word}</w>");
        let mut symbols: Vec<String> = word_with_end.chars().map(|c| c.to_string()).collect();

        loop {
            // Find the first adjacent pair that has a learned merge. This is a
            // simplification: merges are not applied in the order they were learned.
            let mut next_merge: Option<(usize, &str)> = None;

            for i in 0..symbols.len().saturating_sub(1) {
                let pair = (symbols[i].clone(), symbols[i + 1].clone());
                if let Some(merged) = self.merges.get(&pair) {
                    next_merge = Some((i, merged));
                    break;
                }
            }

            match next_merge {
                Some((i, merged)) => {
                    symbols[i] = merged.to_string();
                    symbols.remove(i + 1);
                }
                None => break,
            }
        }

        symbols
    }
}

impl Default for BasicBPETokenizer {
    fn default() -> Self {
        Self::new()
    }
}

impl Tokenizer for BasicBPETokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.apply_bpe(word);
            tokens.extend(word_tokens);
        }

        tokens
    }
}

/// A unigram-style tokenizer that greedily matches the longest vocabulary
/// token at each position; unknown characters fall back to single-character tokens.
#[derive(Debug, Clone)]
pub struct UnigramTokenizer {
    vocab: HashMap<String, f32>,
    max_token_length: usize,
}

impl UnigramTokenizer {
    /// Creates a tokenizer from a map of tokens to scores.
    #[must_use]
    pub fn new(vocab: HashMap<String, f32>) -> Self {
        // Track the longest token length (in characters) to bound the greedy search.
        let max_len = vocab
            .keys()
            .map(|t| t.chars().count())
            .max()
            .unwrap_or(1);
        Self {
            vocab,
            max_token_length: max_len,
        }
    }

    /// Creates a tokenizer from a list of tokens, all with equal score.
    #[must_use]
    pub fn from_tokens(tokens: &[&str]) -> Self {
        let vocab: HashMap<String, f32> = tokens.iter().map(|&t| (t.to_string(), 1.0)).collect();
        Self::new(vocab)
    }

    /// Greedy longest-match tokenization of a single word.
    fn greedy_tokenize(&self, text: &str) -> Vec<String> {
        let mut tokens = Vec::new();
        let chars: Vec<char> = text.chars().collect();
        let mut i = 0;

        while i < chars.len() {
            // Fall back to the single character if nothing in the vocabulary matches.
            let mut best_len = 1;
            let mut best_token = chars[i].to_string();

            // Try every candidate length and keep the longest vocabulary match.
            for len in 1..=self.max_token_length.min(chars.len() - i) {
                let candidate: String = chars[i..i + len].iter().collect();
                if self.vocab.contains_key(&candidate) {
                    best_len = len;
                    best_token = candidate;
                }
            }

            tokens.push(best_token);
            i += best_len;
        }

        tokens
    }
}

impl Tokenizer for UnigramTokenizer {
    fn tokenize(&self, text: &str) -> Vec<String> {
        let mut all_tokens = Vec::new();

        for word in text.split_whitespace() {
            let word_tokens = self.greedy_tokenize(word);
            all_tokens.extend(word_tokens);
        }

        all_tokens
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_whitespace_tokenizer() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["Hello", "World"]);
    }

    #[test]
    fn test_whitespace_tokenizer_lowercase() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello World");

        assert_eq!(tokens, vec!["hello", "world"]);
    }
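
    // Illustrative extra check (not from the original tests): `str::to_lowercase`
    // is Unicode-aware, so non-ASCII letters are lowercased as well.
    #[test]
    fn test_whitespace_tokenizer_lowercase_unicode() {
        let tokenizer = WhitespaceTokenizer::lowercase();
        let tokens = tokenizer.tokenize("HÉLLO Wörld");

        assert_eq!(tokens, vec!["héllo", "wörld"]);
    }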

    #[test]
    fn test_char_tokenizer() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("Hi!");

        assert_eq!(tokens, vec!["H", "i", "!"]);
    }

    #[test]
    fn test_char_tokenizer_no_whitespace() {
        let tokenizer = CharTokenizer::no_whitespace();
        let tokens = tokenizer.tokenize("Hi there!");

        assert_eq!(tokens, vec!["H", "i", "t", "h", "e", "r", "e", "!"]);
    }
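
    // Illustrative extra check (not from the original tests): `chars()` iterates
    // Unicode scalar values, so a multi-byte character such as 'é' stays one token.
    #[test]
    fn test_char_tokenizer_multibyte() {
        let tokenizer = CharTokenizer::new();
        let tokens = tokenizer.tokenize("héi");

        assert_eq!(tokens, vec!["h", "é", "i"]);
    }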

    #[test]
    fn test_word_punct_tokenizer() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["Hello", ",", "World", "!"]);
    }

    #[test]
    fn test_word_punct_tokenizer_lowercase() {
        let tokenizer = WordPunctTokenizer::lowercase();
        let tokens = tokenizer.tokenize("Hello, World!");

        assert_eq!(tokens, vec!["hello", ",", "world", "!"]);
    }
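
    // Illustrative extra check (not from the original tests): digits count as
    // alphanumeric, while apostrophes and hyphens become standalone tokens.
    #[test]
    fn test_word_punct_tokenizer_digits_and_punct() {
        let tokenizer = WordPunctTokenizer::new();
        let tokens = tokenizer.tokenize("it's 2-3");

        assert_eq!(tokens, vec!["it", "'", "s", "2", "-", "3"]);
    }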

    #[test]
    fn test_ngram_word_tokenizer() {
        let tokenizer = NGramTokenizer::word_ngrams(2);
        let tokens = tokenizer.tokenize("one two three");

        assert_eq!(tokens, vec!["one two", "two three"]);
    }

    #[test]
    fn test_ngram_char_tokenizer() {
        let tokenizer = NGramTokenizer::char_ngrams(3);
        let tokens = tokenizer.tokenize("hello");

        assert_eq!(tokens, vec!["hel", "ell", "llo"]);
    }
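
    // Illustrative extra check (not from the original tests): inputs shorter than
    // n are returned unchanged as a single token rather than yielding no n-grams.
    #[test]
    fn test_ngram_shorter_than_n() {
        let tokenizer = NGramTokenizer::char_ngrams(5);
        let tokens = tokenizer.tokenize("hi");

        assert_eq!(tokens, vec!["hi"]);
    }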

    #[test]
    fn test_bpe_tokenizer_basic() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low lower lowest", 10);

        assert!(!tokenizer.get_vocab().is_empty());
    }

    #[test]
    fn test_bpe_tokenizer_apply() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low lower lowest", 5);

        let tokens = tokenizer.tokenize("low");
        assert!(!tokens.is_empty());
    }
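
    // Illustrative extra check (not from the original tests): with enough merges
    // over a single repeated word, training should collapse that word into one
    // symbol, so tokenizing it again yields a single token ending in `</w>`.
    #[test]
    fn test_bpe_tokenizer_full_merge() {
        let mut tokenizer = BasicBPETokenizer::new();
        tokenizer.train("low low low low", 10);

        let tokens = tokenizer.tokenize("low");
        assert_eq!(tokens, vec!["low</w>"]);
    }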

    #[test]
    fn test_unigram_tokenizer() {
        let tokenizer = UnigramTokenizer::from_tokens(&[
            "hel", "lo", "wor", "ld", "h", "e", "l", "o", "w", "r", "d",
        ]);
        let tokens = tokenizer.tokenize("hello world");

        assert!(!tokens.is_empty());
    }
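
    // Illustrative extra check (not from the original tests): the greedy matcher
    // prefers the longest vocabulary entry at each position, and characters not
    // covered by the vocabulary fall back to single-character tokens.
    #[test]
    fn test_unigram_tokenizer_longest_match() {
        let tokenizer = UnigramTokenizer::from_tokens(&["ab", "abc", "c"]);

        assert_eq!(tokenizer.tokenize("abc"), vec!["abc"]);
        assert_eq!(tokenizer.tokenize("xabc"), vec!["x", "abc"]);
    }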

    #[test]
    fn test_tokenizer_encode() {
        let tokenizer = WhitespaceTokenizer::new();
        let mut vocab = Vocab::new();
        vocab.add_token("hello");
        vocab.add_token("world");

        let indices = tokenizer.encode("hello world", &vocab);
        assert_eq!(indices, vec![0, 1]);
    }

    #[test]
    fn test_tokenizer_with_multiple_spaces() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello   world");

        assert_eq!(tokens, vec!["hello", "world"]);
    }
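
    // Illustrative extra check (not from the original tests): `split_whitespace`
    // also treats tabs and newlines as token separators.
    #[test]
    fn test_tokenizer_with_tabs_and_newlines() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("hello\tworld\nagain");

        assert_eq!(tokens, vec!["hello", "world", "again"]);
    }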

    #[test]
    fn test_empty_text() {
        let tokenizer = WhitespaceTokenizer::new();
        let tokens = tokenizer.tokenize("");

        assert!(tokens.is_empty());
    }
}