1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use unicode_normalization::UnicodeNormalization;
7
8use super::config::{Normalization, TokenizerConfig};
9use super::error::{Result, TokenizerError};
10use super::traits::{TokenId, Tokenizer};
11
/// Byte-level BPE tokenizer.
///
/// Input text is split into UTF-8 bytes, each spelled as a two-digit lowercase
/// hex token; the rules in `merges` are then applied in order to join adjacent
/// tokens by string concatenation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BPETokenizer {
    /// Settings read by this file: special tokens, normalization, lowercasing,
    /// target vocabulary size and minimum pair frequency.
    config: TokenizerConfig,
    /// Token string -> id.
    vocab: HashMap<String, TokenId>,
    /// Id -> token string (the inverse of `vocab`).
    id_to_token_map: HashMap<TokenId, String>,
    /// Merge rules `(left, right)`, in the order they are applied during encoding.
    merges: Vec<(String, String)>,
    /// `encode`/`decode` return `NotTrained` while this is `false`.
    trained: bool,
}
25
26impl BPETokenizer {
27 pub fn new(config: TokenizerConfig) -> Self {
29 Self {
30 config,
31 vocab: HashMap::new(),
32 id_to_token_map: HashMap::new(),
33 merges: Vec::new(),
34 trained: false,
35 }
36 }
37
38 fn init_vocab(&mut self) {
40 let mut id: TokenId = 0;
41
42 let special = [
44 &self.config.special_tokens.unk,
45 &self.config.special_tokens.bos,
46 &self.config.special_tokens.eos,
47 &self.config.special_tokens.pad,
48 &self.config.special_tokens.mask,
49 ];
50
51 for token in special {
52 self.vocab.insert(token.clone(), id);
53 self.id_to_token_map.insert(id, token.clone());
54 id += 1;
55 }
56
57 for byte in 0..=255u8 {
59 let token = format!("{byte:02x}");
60 if !self.vocab.contains_key(&token) {
61 self.vocab.insert(token.clone(), id);
62 self.id_to_token_map.insert(id, token);
63 id += 1;
64 }
65 }
66 }
67
68 #[cfg(test)]
70 fn get_pair_freqs(&self, tokenized: &[Vec<String>]) -> HashMap<(String, String), usize> {
71 let mut freqs = HashMap::new();
72
73 for tokens in tokenized {
74 for pair in tokens.windows(2) {
75 let key = (pair[0].clone(), pair[1].clone());
76 *freqs.entry(key).or_insert(0) += 1;
77 }
78 }
79
80 freqs
81 }
82
83 #[cfg(test)]
85 fn merge_pair(&self, tokenized: &mut [Vec<String>], pair: &(String, String), merged: &str) {
86 for tokens in tokenized.iter_mut() {
87 let mut i = 0;
88 while i < tokens.len().saturating_sub(1) {
89 if tokens[i] == pair.0 && tokens[i + 1] == pair.1 {
90 tokens[i] = merged.to_string();
91 tokens.remove(i + 1);
92 }
93 i += 1;
94 }
95 }
96 }
97
98 fn preprocess(&self, text: &str) -> String {
105 let normalized = match self.config.normalization {
106 Normalization::None => text.to_string(),
107 Normalization::NFC => text.nfc().collect(),
108 };
109 if self.config.lowercase {
110 normalized.to_lowercase()
111 } else {
112 normalized
113 }
114 }
115
116 fn to_bytes(&self, text: &str) -> Vec<String> {
118 text.as_bytes().iter().map(|b| format!("{b:02x}")).collect()
119 }
120
121 fn apply_merges(&self, mut tokens: Vec<String>) -> Vec<String> {
123 for (a, b) in &self.merges {
124 let merged = format!("{a}{b}");
125 let mut i = 0;
126 while i < tokens.len().saturating_sub(1) {
127 if &tokens[i] == a && &tokens[i + 1] == b {
128 tokens[i] = merged.clone();
129 tokens.remove(i + 1);
130 } else {
131 i += 1;
132 }
133 }
134 }
135 tokens
136 }
137
    /// Returns the token-string -> id map.
    pub fn vocab(&self) -> &HashMap<String, TokenId> {
        &self.vocab
    }
147
    /// Returns the merge rules `(left, right)` in application order.
    pub fn merges(&self) -> &[(String, String)] {
        &self.merges
    }
156
157 pub fn save(&self, path: &str) -> Result<()> {
159 let json = serde_json::to_string_pretty(self)
160 .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
161 std::fs::write(path, json)?;
162 Ok(())
163 }
164
165 pub fn load(path: &str) -> Result<Self> {
167 let json = std::fs::read_to_string(path)?;
168 serde_json::from_str(&json).map_err(|e| TokenizerError::Serialization(e.to_string()))
169 }
170
171 pub fn from_vocab_merges(
197 vocab_path: &str,
198 merges_path: &str,
199 config: TokenizerConfig,
200 ) -> Result<Self> {
201 let vocab_json = std::fs::read_to_string(vocab_path)?;
202 let vocab: HashMap<String, TokenId> = serde_json::from_str(&vocab_json)
203 .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
204
205 let id_to_token_map: HashMap<TokenId, String> =
206 vocab.iter().map(|(tok, &id)| (id, tok.clone())).collect();
207
208 if id_to_token_map.len() != vocab.len() {
209 return Err(TokenizerError::Serialization(
210 "vocab.json contains duplicate token ids (collision detected after inverting map)"
211 .to_string(),
212 ));
213 }
214
215 let merges_text = std::fs::read_to_string(merges_path)?;
216 let mut merges: Vec<(String, String)> = Vec::new();
217 for (line_no, line) in merges_text.lines().enumerate() {
218 if line.is_empty() || line.starts_with("#") {
219 continue;
220 }
221 let mut parts = line.splitn(2, ' ');
222 let left = parts
223 .next()
224 .ok_or_else(|| {
225 TokenizerError::Serialization(format!(
226 "merges.txt line {}: missing left token",
227 line_no + 1
228 ))
229 })?
230 .to_string();
231 let right = parts
232 .next()
233 .ok_or_else(|| {
234 TokenizerError::Serialization(format!(
235 "merges.txt line {}: missing right token (expected '<left> <right>')",
236 line_no + 1
237 ))
238 })?
239 .to_string();
240
241 let merged = format!("{left}{right}");
242 if !vocab.contains_key(&merged) {
243 return Err(TokenizerError::Serialization(format!(
244 "merges.txt line {}: merged token {:?} not present in vocab.json",
245 line_no + 1,
246 merged
247 )));
248 }
249 merges.push((left, right));
250 }
251
252 Ok(Self { config, vocab, id_to_token_map, merges, trained: true })
253 }
254}
255
256impl Tokenizer for BPETokenizer {
    /// Trains the tokenizer on `corpus` by delegating to `train_fast`.
    fn train(&mut self, corpus: &[&str]) -> Result<()> {
        train_fast(self, corpus)
    }
260
261 fn encode(&self, text: &str) -> Result<Vec<TokenId>> {
262 if !self.trained {
263 return Err(TokenizerError::NotTrained);
264 }
265
266 let tokens = self.to_bytes(&self.preprocess(text));
267 let tokens = self.apply_merges(tokens);
268
269 let unk_id = *self
270 .vocab
271 .get(&self.config.special_tokens.unk)
272 .expect("UNK token must exist in trained vocabulary");
273
274 let ids: Vec<TokenId> =
275 tokens.iter().map(|t| *self.vocab.get(t).unwrap_or(&unk_id)).collect();
276
277 Ok(ids)
278 }
279
280 fn decode(&self, ids: &[TokenId]) -> Result<String> {
281 if !self.trained {
282 return Err(TokenizerError::NotTrained);
283 }
284
285 let mut hex_string = String::new();
286
287 for &id in ids {
288 if let Some(token) = self.id_to_token_map.get(&id) {
289 if token.starts_with('<') && token.ends_with('>') {
291 continue;
292 }
293 hex_string.push_str(token);
294 }
295 }
296
297 let bytes: Vec<u8> = (0..hex_string.len())
299 .step_by(2)
300 .filter_map(|i| {
301 if i + 2 <= hex_string.len() {
302 u8::from_str_radix(&hex_string[i..i + 2], 16).ok()
303 } else {
304 None
305 }
306 })
307 .collect();
308
309 String::from_utf8(bytes).map_err(|e| TokenizerError::Training(e.to_string()))
310 }
311
    /// Number of entries currently in the vocabulary.
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
315
    /// Whether the tokenizer has been trained or loaded and can encode/decode.
    fn is_trained(&self) -> bool {
        self.trained
    }
319
320 fn id_to_token(&self, id: TokenId) -> Option<&str> {
321 self.id_to_token_map.get(&id).map(String::as_str)
322 }
323
    /// Looks up the id for `token`, or `None` if it is not in the vocabulary.
    fn token_to_id(&self, token: &str) -> Option<TokenId> {
        self.vocab.get(token).copied()
    }
327}
328
/// Priority-queue entry used by `train_fast`: a candidate pair together with
/// the count it had when it was pushed.
#[derive(Clone, Eq, PartialEq)]
struct HeapEntry {
    /// Pair count at push time; may be stale by the time the entry is popped.
    count: i64,
    /// The `(left, right)` token ids of the candidate pair.
    pair: (TokenId, TokenId),
}
348
349impl Ord for HeapEntry {
350 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
351 self.count.cmp(&other.count).then_with(|| other.pair.cmp(&self.pair))
356 }
357}
358
/// Delegates to `Ord::cmp` so the partial and total orders always agree.
impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
364
/// Trains `tok` on `corpus` with incrementally maintained pair counts.
///
/// Each corpus entry is preprocessed and turned into one sequence of byte
/// token ids; identical sequences are stored once with a multiplicity. Pair
/// counts, a pair -> sequence index, and a max-heap of candidates are built
/// once, and each merge then only revisits the sequences that contain the
/// merged pair. Merging stops when the vocabulary reaches
/// `config.vocab_size`, the heap runs dry, or the best remaining pair occurs
/// fewer than `config.min_frequency` times (floored at 1).
///
/// Progress is logged to stderr. The function always returns `Ok(())` and
/// marks the tokenizer as trained.
pub(crate) fn train_fast(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
    use std::collections::{BinaryHeap, HashMap, HashSet};
    use std::time::Instant;

    let start = Instant::now();
    let target = tok.config.vocab_size;
    let min_frequency = tok.config.min_frequency.max(1) as i64;

    tok.init_vocab();

    eprintln!("[bpe-setup] ingest start: {} docs", corpus.len());
    use std::io::Write;
    let _ = std::io::stderr().flush();

    // Phase 1: convert every document to byte-token ids, collapsing duplicates.
    let t0 = Instant::now();
    let mut word_counts: HashMap<Vec<TokenId>, u64> = HashMap::new();
    for doc in corpus {
        let text = tok.preprocess(doc);
        let hex_tokens = tok.to_bytes(&text);
        if hex_tokens.is_empty() {
            continue;
        }
        let ids: Vec<TokenId> = hex_tokens
            .iter()
            .map(|t| *tok.vocab.get(t).expect("byte hex token must be in init_vocab"))
            .collect();
        *word_counts.entry(ids).or_insert(0) += 1;
    }
    eprintln!(
        "[bpe-setup] ingest done: {} unique words in {:.1}s",
        word_counts.len(),
        t0.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    // From here on a sequence is addressed by its index in `words`.
    let mut words: Vec<(Vec<TokenId>, u64)> = word_counts.into_iter().collect();

    // Phase 2: multiplicity-weighted pair counts plus, per pair, the set of
    // sequences that contain it.
    let t1 = Instant::now();
    let mut pair_counts: HashMap<(TokenId, TokenId), i64> = HashMap::new();
    let mut pair_words: HashMap<(TokenId, TokenId), HashSet<usize>> = HashMap::new();
    for (word_ix, (ids, mult)) in words.iter().enumerate() {
        let m = *mult as i64;
        for w in ids.windows(2) {
            let p = (w[0], w[1]);
            *pair_counts.entry(p).or_insert(0) += m;
            pair_words.entry(p).or_default().insert(word_ix);
        }
    }
    eprintln!(
        "[bpe-setup] pair indexes: {} unique pairs in {:.1}s",
        pair_counts.len(),
        t1.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    // Phase 3: seed the candidate heap with every pair that occurs.
    let t2 = Instant::now();
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pair_counts.len());
    for (p, c) in &pair_counts {
        if *c > 0 {
            heap.push(HeapEntry { count: *c, pair: *p });
        }
    }
    eprintln!(
        "[bpe-setup] heap seeded: {} entries in {:.1}s; entering merge loop",
        heap.len(),
        t2.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    let mut merges_emitted: usize = 0;

    // Scratch buffers reused across merges to avoid reallocating per iteration.
    let mut old_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
    let mut new_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
    let mut pairs_touched_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(1 << 16);
    let mut affected_buf: Vec<usize> = Vec::with_capacity(1 << 16);

    while tok.vocab.len() < target {
        let entry = match heap.pop() {
            Some(e) => e,
            None => break,
        };
        // Heap entries are never updated in place; an entry whose recorded
        // count no longer matches the live count is stale and is dropped.
        let current = *pair_counts.get(&entry.pair).unwrap_or(&0);
        if current != entry.count {
            continue;
        }
        // This is the best live candidate, so nothing else can qualify either.
        if current < min_frequency {
            break;
        }

        // Register the merged token and record the rule.
        let (a, b) = entry.pair;
        let a_str = tok.id_to_token_map[&a].clone();
        let b_str = tok.id_to_token_map[&b].clone();
        let merged_str = format!("{a_str}{b_str}");
        let new_id: TokenId = tok.vocab.len() as TokenId;
        tok.vocab.insert(merged_str.clone(), new_id);
        tok.id_to_token_map.insert(new_id, merged_str);
        tok.merges.push((a_str, b_str));
        merges_emitted += 1;

        // Snapshot the affected sequences so `pair_words` can be mutated below.
        affected_buf.clear();
        if let Some(ws) = pair_words.get(&(a, b)) {
            affected_buf.extend(ws.iter().copied());
        }

        pairs_touched_buf.clear();

        for &word_ix in &affected_buf {
            let (ids, mult) = &mut words[word_ix];
            let m = *mult as i64;

            // Pairs present before the merge (with repeats, for counting).
            old_pairs_buf.clear();
            old_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));

            // Rewrite the sequence in place: `read` scans, `write` compacts.
            let mut write = 0;
            let mut read = 0;
            while read < ids.len() {
                if read + 1 < ids.len() && ids[read] == a && ids[read + 1] == b {
                    ids[write] = new_id;
                    write += 1;
                    read += 2;
                } else {
                    ids[write] = ids[read];
                    write += 1;
                    read += 1;
                }
            }
            ids.truncate(write);

            // Pairs present after the merge.
            new_pairs_buf.clear();
            new_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));

            // Swap this sequence's old contribution to the counts for its new one.
            for p in &old_pairs_buf {
                *pair_counts.entry(*p).or_insert(0) -= m;
            }
            for p in &new_pairs_buf {
                *pair_counts.entry(*p).or_insert(0) += m;
            }

            // Reduce both lists to sorted sets so they can be diffed in one pass.
            old_pairs_buf.sort_unstable();
            old_pairs_buf.dedup();
            new_pairs_buf.sort_unstable();
            new_pairs_buf.dedup();

            // Sorted-merge diff: pairs only in `old` leave this sequence's index
            // entry, pairs only in `new` join it, shared pairs keep membership.
            let mut i = 0usize;
            let mut j = 0usize;
            while i < old_pairs_buf.len() && j < new_pairs_buf.len() {
                match old_pairs_buf[i].cmp(&new_pairs_buf[j]) {
                    std::cmp::Ordering::Less => {
                        if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
                            ws.remove(&word_ix);
                        }
                        pairs_touched_buf.push(old_pairs_buf[i]);
                        i += 1;
                    }
                    std::cmp::Ordering::Greater => {
                        pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
                        pairs_touched_buf.push(new_pairs_buf[j]);
                        j += 1;
                    }
                    std::cmp::Ordering::Equal => {
                        pairs_touched_buf.push(old_pairs_buf[i]);
                        i += 1;
                        j += 1;
                    }
                }
            }
            // Leftovers exist on at most one side.
            while i < old_pairs_buf.len() {
                if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
                    ws.remove(&word_ix);
                }
                pairs_touched_buf.push(old_pairs_buf[i]);
                i += 1;
            }
            while j < new_pairs_buf.len() {
                pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
                pairs_touched_buf.push(new_pairs_buf[j]);
                j += 1;
            }
        }

        // Push one fresh heap entry per touched pair that still occurs; older
        // entries for those pairs are filtered by the staleness check above.
        pairs_touched_buf.sort_unstable();
        pairs_touched_buf.dedup();
        for p in &pairs_touched_buf {
            let c = *pair_counts.get(p).unwrap_or(&0);
            if c > 0 {
                heap.push(HeapEntry { count: c, pair: *p });
            }
        }

        // The merged pair is finished; drop its bookkeeping.
        pair_counts.remove(&(a, b));
        pair_words.remove(&(a, b));

        // Progress line after the first merge and then every 100 merges.
        if merges_emitted == 1 || merges_emitted.is_multiple_of(100) {
            let elapsed = start.elapsed().as_secs_f64();
            let top_count = heap.peek().map(|e| e.count).unwrap_or(0);
            eprintln!(
                "[bpe] merges={} vocab={} elapsed={:.1}s top_count={} heap={} pairs={}",
                merges_emitted,
                tok.vocab.len(),
                elapsed,
                top_count,
                heap.len(),
                pair_counts.len()
            );
            let _ = std::io::stderr().flush();
        }
    }

    let elapsed = start.elapsed().as_secs_f64();
    eprintln!(
        "[bpe] DONE merges={} vocab={} elapsed={:.1}s",
        merges_emitted,
        tok.vocab.len(),
        elapsed
    );
    let _ = std::io::stderr().flush();

    tok.trained = true;
    Ok(())
}
625
/// Straightforward BPE trainer that recounts every pair on each iteration.
///
/// Test-only reference for `train_fast`: it picks the most frequent pair with
/// at least `min_frequency` occurrences, breaking ties towards the smaller
/// `(left_id, right_id)`, and applies it to every sequence.
#[cfg(test)]
#[doc(hidden)]
pub(crate) fn train_naive_reference(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
    let vocab_limit = tok.config.vocab_size;
    let floor = tok.config.min_frequency.max(1);

    tok.init_vocab();

    let mut sequences: Vec<Vec<String>> = Vec::with_capacity(corpus.len());
    for doc in corpus {
        sequences.push(tok.to_bytes(&tok.preprocess(doc)));
    }

    while tok.vocab.len() < vocab_limit {
        let freqs = tok.get_pair_freqs(&sequences);

        // Highest count wins; among equal counts the smaller id pair wins.
        let winner = freqs
            .iter()
            .filter(|(_, &count)| count >= floor)
            .map(|(pair, &count)| {
                let left_id = *tok.vocab.get(&pair.0).expect("left must be in vocab");
                let right_id = *tok.vocab.get(&pair.1).expect("right must be in vocab");
                (count, (left_id, right_id), pair)
            })
            .max_by(|x, y| x.0.cmp(&y.0).then_with(|| y.1.cmp(&x.1)));

        let chosen: (String, String) = match winner {
            Some((_, _, pair)) => pair.clone(),
            None => break,
        };

        let merged = format!("{}{}", chosen.0, chosen.1);
        let new_id: TokenId = tok.vocab.len() as TokenId;
        tok.vocab.insert(merged.clone(), new_id);
        tok.id_to_token_map.insert(new_id, merged.clone());
        tok.merges.push(chosen.clone());
        tok.merge_pair(&mut sequences, &chosen, &merged);
    }

    tok.trained = true;
    Ok(())
}
679
680#[cfg(test)]
681mod tests {
682 use super::*;
683
684 #[test]
685 fn test_bpe_new() {
686 let config = TokenizerConfig::bpe();
687 let tokenizer = BPETokenizer::new(config);
688 assert!(!tokenizer.is_trained());
689 }
690
691 #[test]
692 fn test_bpe_train() {
693 let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
694 let mut tokenizer = BPETokenizer::new(config);
695
696 let corpus = vec!["hello hello", "hello world", "world hello"];
697 tokenizer.train(&corpus).expect("operation should succeed");
698
699 assert!(tokenizer.is_trained());
700 assert!(tokenizer.vocab_size() > 256); }
702
703 #[test]
704 fn test_bpe_encode_not_trained() {
705 let config = TokenizerConfig::bpe();
706 let tokenizer = BPETokenizer::new(config);
707
708 let result = tokenizer.encode("hello");
709 assert!(result.is_err());
710 }
711
712 #[test]
713 fn test_bpe_encode_decode() {
714 let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
715 let mut tokenizer = BPETokenizer::new(config);
716
717 let corpus = vec!["hello world", "hello there"];
718 tokenizer.train(&corpus).expect("operation should succeed");
719
720 let text = "hello";
721 let encoded = tokenizer.encode(text).expect("encoding should succeed");
722 let decoded = tokenizer.decode(&encoded).expect("encoding should succeed");
723
724 assert_eq!(decoded, text);
725 }
726
727 #[test]
728 fn test_bpe_lowercase() {
729 let config =
730 TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1).with_lowercase(true);
731 let mut tokenizer = BPETokenizer::new(config);
732
733 let corpus = vec!["Hello World"];
734 tokenizer.train(&corpus).expect("operation should succeed");
735
736 let encoded = tokenizer.encode("HELLO").expect("encoding should succeed");
737 let decoded = tokenizer.decode(&encoded).expect("encoding should succeed");
738
739 assert_eq!(decoded, "hello");
740 }
741
742 #[test]
743 fn test_bpe_id_to_token() {
744 let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
745 let mut tokenizer = BPETokenizer::new(config);
746
747 let corpus = vec!["test"];
748 tokenizer.train(&corpus).expect("operation should succeed");
749
750 assert_eq!(tokenizer.id_to_token(0), Some("<unk>"));
752 }
753
754 #[test]
755 fn test_bpe_token_to_id() {
756 let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
757 let mut tokenizer = BPETokenizer::new(config);
758
759 let corpus = vec!["test"];
760 tokenizer.train(&corpus).expect("operation should succeed");
761
762 assert_eq!(tokenizer.token_to_id("<unk>"), Some(0));
763 }
764
    /// With NFC normalization, the precomposed and decomposed spellings of
    /// "café" encode to the same ids and decode to the composed form.
    #[test]
    fn test_bpe_nfc_composed_decomposed_parity() {
        // `decomposed` spells the accent as "e" followed by U+0301.
        let composed = "café";
        let decomposed = "cafe\u{0301}";
        let config = TokenizerConfig::bpe()
            .with_vocab_size(300)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);
        let mut tokenizer = BPETokenizer::new(config);
        tokenizer.train(&[composed]).expect("operation should succeed");

        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");

        assert_eq!(
            ids_composed, ids_decomposed,
            "NFC must map composed and decomposed café to identical token IDs"
        );

        let decoded = tokenizer.decode(&ids_composed).expect("decoding should succeed");
        assert_eq!(decoded, composed, "NFC round-trip must recover composed form");
    }
791
    /// Counterpart to the NFC parity test: with normalization disabled the two
    /// spellings of "café" must encode differently.
    #[test]
    fn test_bpe_without_nfc_composed_decomposed_diverge() {
        let composed = "café";
        let decomposed = "cafe\u{0301}";

        let config = TokenizerConfig::bpe()
            .with_vocab_size(300)
            .with_min_frequency(1)
            .with_normalization(Normalization::None);
        let mut tokenizer = BPETokenizer::new(config);
        tokenizer.train(&[composed]).expect("operation should succeed");

        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");

        assert_ne!(
            ids_composed, ids_decomposed,
            "Without NFC, composed and decomposed café MUST diverge (falsification witness for INV-TOK-003)"
        );
    }
814
    /// Exports a trained tokenizer as `vocab.json` + `merges.txt`, reloads it
    /// with `from_vocab_merges`, and checks that both encode the corpus identically.
    #[test]
    fn test_bpe_from_vocab_merges_roundtrip() {
        use std::fmt::Write;
        let config = TokenizerConfig::bpe()
            .with_vocab_size(400)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);
        let mut original = BPETokenizer::new(config.clone());
        let corpus = vec!["def hello():\n return 1\n", "def world():\n return 2\n"];
        original.train(&corpus).expect("training should succeed");

        // Scratch directory named from the process id and the current time.
        let tmp = std::env::temp_dir().join(format!(
            "bpe_roundtrip_{}_{}",
            std::process::id(),
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()
        ));
        std::fs::create_dir_all(&tmp).unwrap();
        let vocab_path = tmp.join("vocab.json");
        let merges_path = tmp.join("merges.txt");

        // Write the vocabulary as a JSON object, entries sorted by id.
        let mut entries: Vec<(&String, &TokenId)> = original.vocab().iter().collect();
        entries.sort_by_key(|(_, id)| *id);
        let ordered: serde_json::Map<String, serde_json::Value> = entries
            .into_iter()
            .map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into())))
            .collect();
        let vocab_json = serde_json::to_string_pretty(&ordered).unwrap();
        std::fs::write(&vocab_path, vocab_json).unwrap();

        // Write the merges, one "<left> <right>" line each, after a "#" header line.
        let mut merges_content = String::from("#version: 0.2\n");
        for (left, right) in original.merges() {
            writeln!(merges_content, "{left} {right}").unwrap();
        }
        std::fs::write(&merges_path, merges_content).unwrap();

        let reloaded = BPETokenizer::from_vocab_merges(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            config,
        )
        .expect("from_vocab_merges should succeed");

        assert_eq!(reloaded.vocab_size(), original.vocab_size(), "reloaded vocab size must match");

        for text in &corpus {
            let original_ids = original.encode(text).expect("original encode");
            let reloaded_ids = reloaded.encode(text).expect("reloaded encode");
            assert_eq!(
                original_ids, reloaded_ids,
                "reloaded encoding must byte-equal original encoding for {text:?}"
            );
        }

        // Best-effort cleanup; not reached if an assertion above panics.
        let _ = std::fs::remove_dir_all(&tmp);
    }
876
    /// A merge whose concatenated token is absent from `vocab.json` must be
    /// rejected, and the error must name that token.
    #[test]
    fn test_bpe_from_vocab_merges_rejects_orphan_merge() {
        let tmp = std::env::temp_dir().join(format!(
            "bpe_orphan_{}_{}",
            std::process::id(),
            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()
        ));
        std::fs::create_dir_all(&tmp).unwrap();
        let vocab_path = tmp.join("vocab.json");
        let merges_path = tmp.join("merges.txt");

        // "aa" and "bb" exist, but their merge result "aabb" does not.
        std::fs::write(&vocab_path, r#"{"<unk>": 0, "aa": 1, "bb": 2}"#).unwrap();
        std::fs::write(&merges_path, "#version: 0.2\naa bb\n").unwrap();

        let result = BPETokenizer::from_vocab_merges(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            TokenizerConfig::bpe(),
        );

        assert!(
            result.is_err(),
            "from_vocab_merges must reject merges.txt with merged token not in vocab.json"
        );
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(
            err_msg.contains("aabb"),
            "error should name the offending merged token, got: {err_msg}"
        );

        // Best-effort cleanup; not reached if an assertion above panics.
        let _ = std::fs::remove_dir_all(&tmp);
    }
912
913 fn synthetic_python_corpus(n_docs: usize) -> Vec<String> {
915 let templates: &[&str] = &[
916 "def fn_{i}(x):\n return x * {i}\n",
917 "class C_{i}:\n def __init__(self):\n self.x = {i}\n",
918 "for i in range({i}):\n print(i * {i})\n",
919 "def add_{i}(a, b):\n return a + b + {i}\n",
920 "import math\nprint(math.sqrt({i}))\n",
921 "if x == {i}:\n return True\nelse:\n return False\n",
922 "xs = [{i}, {i}, {i}]\nfor x in xs:\n print(x)\n",
923 "def process_{i}(data):\n result = []\n for item in data:\n result.append(item + {i})\n return result\n",
924 ];
925 (0..n_docs).map(|i| templates[i % templates.len()].replace("{i}", &i.to_string())).collect()
926 }
927
    /// `train_fast` and `train_naive_reference` must produce the same merge
    /// sequence and the same id -> token vocabulary on a small corpus.
    #[test]
    fn bpe_fast_vs_naive_parity() {
        let config = TokenizerConfig::bpe()
            .with_vocab_size(512)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(20);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut fast = BPETokenizer::new(config.clone());
        super::train_fast(&mut fast, &corpus).expect("fast train should succeed");

        let mut naive = BPETokenizer::new(config);
        super::train_naive_reference(&mut naive, &corpus).expect("naive train should succeed");

        assert_eq!(
            fast.vocab_size(),
            naive.vocab_size(),
            "vocab sizes must match between fast and naive"
        );
        assert_eq!(fast.merges(), naive.merges(), "merge sequence must be identical");

        // Hash maps iterate in arbitrary order; sort by id before comparing.
        let mut fast_entries: Vec<(&String, &TokenId)> = fast.vocab().iter().collect();
        let mut naive_entries: Vec<(&String, &TokenId)> = naive.vocab().iter().collect();
        fast_entries.sort_by_key(|(_, id)| *id);
        naive_entries.sort_by_key(|(_, id)| *id);
        assert_eq!(
            fast_entries, naive_entries,
            "vocab (id → token) must be identical between fast and naive"
        );
    }
961
    /// Two independent `train_fast` runs on the same corpus must yield
    /// identical merges and vocabularies.
    #[test]
    fn bpe_fast_is_deterministic() {
        let config = TokenizerConfig::bpe()
            .with_vocab_size(400)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(15);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut a = BPETokenizer::new(config.clone());
        super::train_fast(&mut a, &corpus).expect("run A");
        let mut b = BPETokenizer::new(config);
        super::train_fast(&mut b, &corpus).expect("run B");

        assert_eq!(a.merges(), b.merges(), "merges must be byte-identical across runs");
        assert_eq!(a.vocab_size(), b.vocab_size(), "vocab size must match");

        // Hash maps iterate in arbitrary order; sort by id before comparing.
        let mut a_entries: Vec<(&String, &TokenId)> = a.vocab().iter().collect();
        let mut b_entries: Vec<(&String, &TokenId)> = b.vocab().iter().collect();
        a_entries.sort_by_key(|(_, id)| *id);
        b_entries.sort_by_key(|(_, id)| *id);
        assert_eq!(a_entries, b_entries, "vocab map must be byte-identical across runs");
    }
987
    /// Times both trainers on a 500-document corpus, checks that they still
    /// agree on the merges, and asserts a wall-clock bound that depends on the
    /// build profile.
    ///
    /// NOTE(review): the assertions compare wall-clock timings, so the result
    /// may vary with machine load.
    #[test]
    fn bpe_fast_meets_1_5x_parity_replacement_rule() {
        use std::time::Instant;

        let config = TokenizerConfig::bpe()
            .with_vocab_size(2048)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(500);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut naive = BPETokenizer::new(config.clone());
        let t0 = Instant::now();
        super::train_naive_reference(&mut naive, &corpus).expect("naive train");
        let naive_secs = t0.elapsed().as_secs_f64();

        let mut fast = BPETokenizer::new(config);
        let t0 = Instant::now();
        super::train_fast(&mut fast, &corpus).expect("fast train");
        let fast_secs = t0.elapsed().as_secs_f64();

        // Speed-up factor: values above 1 mean the fast trainer was quicker.
        let ratio = naive_secs / fast_secs;
        eprintln!(
            "[bpe-speedup] naive={naive_secs:.3}s fast={fast_secs:.3}s ratio={ratio:.2}× \
             vocab_naive={} vocab_fast={}",
            naive.vocab_size(),
            fast.vocab_size()
        );

        assert_eq!(
            fast.merges(),
            naive.merges(),
            "at perf-workload scale, fast and naive merges MUST still match"
        );

        if cfg!(debug_assertions) {
            // Debug builds: only require fast to take less than 1.5x the naive time.
            assert!(
                fast_secs < naive_secs * 1.5,
                "even in debug, fast must not be dramatically slower than naive \
                 (ratio={ratio:.2}×)"
            );
        } else {
            // Builds without debug assertions: require at least a 1.5x speed-up.
            assert!(
                ratio >= 1.5,
                "org policy: replacement must be ≥1.5× faster than the replaced \
                 algorithm — got {ratio:.2}× (naive={naive_secs:.3}s, fast={fast_secs:.3}s)"
            );
        }
    }
1052}
1053
/// Property-based tests, 50 generated cases per property.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        // Every id produced by `encode` must resolve back to a token.
        #[test]
        fn prop_bpe_encode_produces_valid_ids(text in "[a-zA-Z ]{1,20}") {
            let config = TokenizerConfig::bpe()
                .with_vocab_size(300)
                .with_min_frequency(1);
            let mut tokenizer = BPETokenizer::new(config);
            tokenizer.train(&[&text]).expect("operation should succeed");

            let encoded = tokenizer.encode(&text).expect("encoding should succeed");

            for id in encoded {
                prop_assert!(tokenizer.id_to_token(id).is_some());
            }
        }

        // Training never grows the vocabulary beyond the configured target.
        // The range starts at 261, the size of the base vocabulary
        // (5 special tokens + 256 byte tokens).
        #[test]
        fn prop_vocab_size_bounded(target_size in 261usize..500) {
            let config = TokenizerConfig::bpe()
                .with_vocab_size(target_size)
                .with_min_frequency(1);
            let mut tokenizer = BPETokenizer::new(config);

            let corpus = vec!["hello world hello world test test"];
            tokenizer.train(&corpus).expect("operation should succeed");

            prop_assert!(tokenizer.vocab_size() <= target_size);
        }
    }
}