1use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use unicode_normalization::UnicodeNormalization;
7
8use super::config::{Normalization, TokenizerConfig};
9use super::error::{Result, TokenizerError};
10use super::traits::{TokenId, Tokenizer};
11
/// Byte-level BPE tokenizer whose tokens are lowercase hex strings.
///
/// Every input byte is rendered as a two-character hex token (for example
/// `"64"`), and each learned merge concatenates the hex strings of its two
/// sides. The struct derives `Serialize`/`Deserialize`, which `save` and
/// `load` use to persist it as JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BPETokenizer {
    /// Settings read by this tokenizer (special tokens, normalization,
    /// lowercase flag, target vocab size, minimum pair frequency).
    config: TokenizerConfig,
    /// Token string -> token id.
    vocab: HashMap<String, TokenId>,
    /// Token id -> token string (inverse of `vocab`).
    id_to_token_map: HashMap<TokenId, String>,
    /// Learned merge rules `(left, right)`, in the order they are applied.
    merges: Vec<(String, String)>,
    /// Set to `true` once training finished or a vocab/merges pair was loaded.
    trained: bool,
}
25
26impl BPETokenizer {
27 pub fn new(config: TokenizerConfig) -> Self {
29 Self {
30 config,
31 vocab: HashMap::new(),
32 id_to_token_map: HashMap::new(),
33 merges: Vec::new(),
34 trained: false,
35 }
36 }
37
38 fn init_vocab(&mut self) {
40 let mut id: TokenId = 0;
41
42 let special = [
44 &self.config.special_tokens.unk,
45 &self.config.special_tokens.bos,
46 &self.config.special_tokens.eos,
47 &self.config.special_tokens.pad,
48 &self.config.special_tokens.mask,
49 ];
50
51 for token in special {
52 self.vocab.insert(token.clone(), id);
53 self.id_to_token_map.insert(id, token.clone());
54 id += 1;
55 }
56
57 for byte in 0..=255u8 {
59 let token = format!("{byte:02x}");
60 if !self.vocab.contains_key(&token) {
61 self.vocab.insert(token.clone(), id);
62 self.id_to_token_map.insert(id, token);
63 id += 1;
64 }
65 }
66 }
67
68 #[cfg(test)]
70 fn get_pair_freqs(&self, tokenized: &[Vec<String>]) -> HashMap<(String, String), usize> {
71 let mut freqs = HashMap::new();
72
73 for tokens in tokenized {
74 for pair in tokens.windows(2) {
75 let key = (pair[0].clone(), pair[1].clone());
76 *freqs.entry(key).or_insert(0) += 1;
77 }
78 }
79
80 freqs
81 }
82
83 #[cfg(test)]
85 fn merge_pair(&self, tokenized: &mut [Vec<String>], pair: &(String, String), merged: &str) {
86 for tokens in tokenized.iter_mut() {
87 let mut i = 0;
88 while i < tokens.len().saturating_sub(1) {
89 if tokens[i] == pair.0 && tokens[i + 1] == pair.1 {
90 tokens[i] = merged.to_string();
91 tokens.remove(i + 1);
92 }
93 i += 1;
94 }
95 }
96 }
97
98 fn preprocess(&self, text: &str) -> String {
105 let normalized = match self.config.normalization {
106 Normalization::None => text.to_string(),
107 Normalization::NFC => text.nfc().collect(),
108 };
109 if self.config.lowercase {
110 normalized.to_lowercase()
111 } else {
112 normalized
113 }
114 }
115
116 fn to_bytes(&self, text: &str) -> Vec<String> {
118 text.as_bytes().iter().map(|b| format!("{b:02x}")).collect()
119 }
120
121 fn apply_merges(&self, mut tokens: Vec<String>) -> Vec<String> {
123 for (a, b) in &self.merges {
124 let merged = format!("{a}{b}");
125 let mut i = 0;
126 while i < tokens.len().saturating_sub(1) {
127 if &tokens[i] == a && &tokens[i + 1] == b {
128 tokens[i] = merged.clone();
129 tokens.remove(i + 1);
130 } else {
131 i += 1;
132 }
133 }
134 }
135 tokens
136 }
137
    /// Returns the token-string -> token-id table.
    pub fn vocab(&self) -> &HashMap<String, TokenId> {
        &self.vocab
    }
147
    /// Returns the learned merge rules `(left, right)` in application order.
    pub fn merges(&self) -> &[(String, String)] {
        &self.merges
    }
156
157 pub fn save(&self, path: &str) -> Result<()> {
159 let json = serde_json::to_string_pretty(self)
160 .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
161 std::fs::write(path, json)?;
162 Ok(())
163 }
164
165 pub fn load(path: &str) -> Result<Self> {
167 let json = std::fs::read_to_string(path)?;
168 serde_json::from_str(&json).map_err(|e| TokenizerError::Serialization(e.to_string()))
169 }
170
171 pub fn from_vocab_merges(
197 vocab_path: &str,
198 merges_path: &str,
199 config: TokenizerConfig,
200 ) -> Result<Self> {
201 let vocab_json = std::fs::read_to_string(vocab_path)?;
202 let vocab: HashMap<String, TokenId> = serde_json::from_str(&vocab_json)
203 .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
204
205 let id_to_token_map: HashMap<TokenId, String> =
206 vocab.iter().map(|(tok, &id)| (id, tok.clone())).collect();
207
208 if id_to_token_map.len() != vocab.len() {
209 return Err(TokenizerError::Serialization(
210 "vocab.json contains duplicate token ids (collision detected after inverting map)"
211 .to_string(),
212 ));
213 }
214
215 let hex_byte_count =
233 (0u8..=255).map(|b| format!("{b:02x}")).filter(|hex| vocab.contains_key(hex)).count();
234 const MIN_HEX_BYTES: usize = 200;
235 if hex_byte_count < MIN_HEX_BYTES {
236 return Err(TokenizerError::Serialization(format!(
237 "FALSIFY-BPE-FORMAT-MISMATCH-001: vocab.json at {} contains \
238 only {hex_byte_count}/256 canonical hex-byte tokens (\"00\"..\"ff\"), \
239 below the {MIN_HEX_BYTES} threshold. aprender-train's BPETokenizer \
240 uses HEX-BYTE format internally (to_bytes emits \"64\" for byte 'd', \
241 etc.); loading a HuggingFace GPT-2 byte-level vocab (e.g., from \
242 `apr tokenize import-hf` of Qwen2/Llama2/Mistral, which use \
243 Ġ-prefix + raw chars) would silently produce 99.99%% `<unk>` \
244 tokens during encode (root cause of SHIP-TWO §60 val_loss=0.00081 \
245 anomaly). Fix scope: implement Ġ-prefix encoding path in \
246 BPETokenizer (multi-PR), OR use a different tokenizer for HF \
247 byte-level vocabs. For now, this fail-fast prevents silent corpus \
248 corruption. Tracking: PMAT-CODE-TOKENIZE-BPE-FORMAT-001.",
249 vocab_path
250 )));
251 }
252
253 let merges_text = std::fs::read_to_string(merges_path)?;
254 let mut merges: Vec<(String, String)> = Vec::new();
255 for (line_no, line) in merges_text.lines().enumerate() {
256 if line.is_empty() || line.starts_with("#") {
257 continue;
258 }
259 let mut parts = line.splitn(2, ' ');
260 let left = parts
261 .next()
262 .ok_or_else(|| {
263 TokenizerError::Serialization(format!(
264 "merges.txt line {}: missing left token",
265 line_no + 1
266 ))
267 })?
268 .to_string();
269 let right = parts
270 .next()
271 .ok_or_else(|| {
272 TokenizerError::Serialization(format!(
273 "merges.txt line {}: missing right token (expected '<left> <right>')",
274 line_no + 1
275 ))
276 })?
277 .to_string();
278
279 let merged = format!("{left}{right}");
280 if !vocab.contains_key(&merged) {
281 return Err(TokenizerError::Serialization(format!(
282 "merges.txt line {}: merged token {:?} not present in vocab.json",
283 line_no + 1,
284 merged
285 )));
286 }
287 merges.push((left, right));
288 }
289
290 Ok(Self { config, vocab, id_to_token_map, merges, trained: true })
291 }
292}
293
294impl Tokenizer for BPETokenizer {
    /// Trains the tokenizer on `corpus` by delegating to `train_fast`.
    fn train(&mut self, corpus: &[&str]) -> Result<()> {
        train_fast(self, corpus)
    }
298
299 fn encode(&self, text: &str) -> Result<Vec<TokenId>> {
300 if !self.trained {
301 return Err(TokenizerError::NotTrained);
302 }
303
304 let tokens = self.to_bytes(&self.preprocess(text));
305 let tokens = self.apply_merges(tokens);
306
307 let unk_id = *self
308 .vocab
309 .get(&self.config.special_tokens.unk)
310 .expect("UNK token must exist in trained vocabulary");
311
312 let ids: Vec<TokenId> =
313 tokens.iter().map(|t| *self.vocab.get(t).unwrap_or(&unk_id)).collect();
314
315 Ok(ids)
316 }
317
318 fn decode(&self, ids: &[TokenId]) -> Result<String> {
319 if !self.trained {
320 return Err(TokenizerError::NotTrained);
321 }
322
323 let mut hex_string = String::new();
324
325 for &id in ids {
326 if let Some(token) = self.id_to_token_map.get(&id) {
327 if token.starts_with('<') && token.ends_with('>') {
329 continue;
330 }
331 hex_string.push_str(token);
332 }
333 }
334
335 let bytes: Vec<u8> = (0..hex_string.len())
337 .step_by(2)
338 .filter_map(|i| {
339 if i + 2 <= hex_string.len() {
340 u8::from_str_radix(&hex_string[i..i + 2], 16).ok()
341 } else {
342 None
343 }
344 })
345 .collect();
346
347 String::from_utf8(bytes).map_err(|e| TokenizerError::Training(e.to_string()))
348 }
349
    /// Number of entries currently in the vocabulary.
    fn vocab_size(&self) -> usize {
        self.vocab.len()
    }
353
    /// Whether the tokenizer has been trained or loaded from vocab/merges files.
    fn is_trained(&self) -> bool {
        self.trained
    }
357
    /// Looks up the token string for `id`, or `None` if the id is unknown.
    fn id_to_token(&self, id: TokenId) -> Option<&str> {
        self.id_to_token_map.get(&id).map(String::as_str)
    }
361
    /// Looks up the id for `token`, or `None` if the token is not in the vocabulary.
    fn token_to_id(&self, token: &str) -> Option<TokenId> {
        self.vocab.get(token).copied()
    }
365}
366
/// One candidate merge in the trainer's priority queue: a token pair and the
/// frequency it had at the moment the entry was pushed.
#[derive(Clone, Eq, PartialEq)]
struct HeapEntry {
    /// Pair frequency recorded when the entry was created.
    count: i64,
    /// The `(left, right)` token ids of the candidate pair.
    pair: (TokenId, TokenId),
}
386
387impl Ord for HeapEntry {
388 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
389 self.count.cmp(&other.count).then_with(|| other.pair.cmp(&self.pair))
394 }
395}
396
impl PartialOrd for HeapEntry {
    /// Delegates to the total order defined by `Ord::cmp`.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
402
/// Trains BPE merges on `corpus` using incremental pair counts.
///
/// Instead of recounting all pairs after every merge, this keeps
/// * `pair_counts`: the current weighted frequency of every adjacent pair,
/// * `pair_words`: for every pair, the indices of the words containing it,
/// * a max-heap of `HeapEntry` candidates that may hold stale counts and is
///   validated against `pair_counts` when an entry is popped.
///
/// Each preprocessed document is treated as one "word"; identical documents
/// are stored once with a multiplicity. Merging stops when the vocabulary
/// reaches `config.vocab_size`, when the heap is empty, or when the best
/// remaining pair is rarer than `config.min_frequency` (clamped to at least 1).
///
/// Progress is written to stderr. On return `tok.trained` is `true`.
pub(crate) fn train_fast(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
    use std::collections::{BinaryHeap, HashMap, HashSet};
    use std::time::Instant;

    let start = Instant::now();
    let target = tok.config.vocab_size;
    let min_frequency = tok.config.min_frequency.max(1) as i64;

    tok.init_vocab();

    eprintln!("[bpe-setup] ingest start: {} docs", corpus.len());
    use std::io::Write;
    let _ = std::io::stderr().flush();

    // Phase 1: turn every document into a sequence of byte-token ids and
    // collapse identical sequences into a single entry with a count.
    let t0 = Instant::now();
    let mut word_counts: HashMap<Vec<TokenId>, u64> = HashMap::new();
    for doc in corpus {
        let text = tok.preprocess(doc);
        let hex_tokens = tok.to_bytes(&text);
        if hex_tokens.is_empty() {
            continue;
        }
        let ids: Vec<TokenId> = hex_tokens
            .iter()
            .map(|t| *tok.vocab.get(t).expect("byte hex token must be in init_vocab"))
            .collect();
        *word_counts.entry(ids).or_insert(0) += 1;
    }
    eprintln!(
        "[bpe-setup] ingest done: {} unique words in {:.1}s",
        word_counts.len(),
        t0.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    // From here on a word is addressed by its index in this vector.
    let mut words: Vec<(Vec<TokenId>, u64)> = word_counts.into_iter().collect();

    // Phase 2: initial pair frequencies (weighted by word multiplicity) and
    // the inverted index from pair to the words that contain it.
    let t1 = Instant::now();
    let mut pair_counts: HashMap<(TokenId, TokenId), i64> = HashMap::new();
    let mut pair_words: HashMap<(TokenId, TokenId), HashSet<usize>> = HashMap::new();
    for (word_ix, (ids, mult)) in words.iter().enumerate() {
        let m = *mult as i64;
        for w in ids.windows(2) {
            let p = (w[0], w[1]);
            *pair_counts.entry(p).or_insert(0) += m;
            pair_words.entry(p).or_default().insert(word_ix);
        }
    }
    eprintln!(
        "[bpe-setup] pair indexes: {} unique pairs in {:.1}s",
        pair_counts.len(),
        t1.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    // Phase 3: seed the max-heap with every pair that has a positive count.
    let t2 = Instant::now();
    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pair_counts.len());
    for (p, c) in &pair_counts {
        if *c > 0 {
            heap.push(HeapEntry { count: *c, pair: *p });
        }
    }
    eprintln!(
        "[bpe-setup] heap seeded: {} entries in {:.1}s; entering merge loop",
        heap.len(),
        t2.elapsed().as_secs_f64()
    );
    let _ = std::io::stderr().flush();

    let mut merges_emitted: usize = 0;

    // Scratch buffers reused across iterations to avoid reallocating.
    let mut old_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
    let mut new_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
    let mut pairs_touched_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(1 << 16);
    let mut affected_buf: Vec<usize> = Vec::with_capacity(1 << 16);

    // Phase 4: merge loop.
    while tok.vocab.len() < target {
        let entry = match heap.pop() {
            Some(e) => e,
            None => break,
        };
        // Lazy invalidation: skip entries whose recorded count no longer
        // matches the live count (the pair changed or was already merged).
        let current = *pair_counts.get(&entry.pair).unwrap_or(&0);
        if current != entry.count {
            continue;
        }
        // The popped entry is the current maximum, so nothing left in the
        // heap can reach the threshold either.
        if current < min_frequency {
            break;
        }

        // Register the merged token; its id is the next dense index.
        let (a, b) = entry.pair;
        let a_str = tok.id_to_token_map[&a].clone();
        let b_str = tok.id_to_token_map[&b].clone();
        let merged_str = format!("{a_str}{b_str}");
        let new_id: TokenId = tok.vocab.len() as TokenId;
        tok.vocab.insert(merged_str.clone(), new_id);
        tok.id_to_token_map.insert(new_id, merged_str);
        tok.merges.push((a_str, b_str));
        merges_emitted += 1;

        // Snapshot the words containing (a, b); `pair_words` is mutated below.
        affected_buf.clear();
        if let Some(ws) = pair_words.get(&(a, b)) {
            affected_buf.extend(ws.iter().copied());
        }

        pairs_touched_buf.clear();

        for &word_ix in &affected_buf {
            let (ids, mult) = &mut words[word_ix];
            let m = *mult as i64;

            // Pairs of the word before the merge (with repetitions).
            old_pairs_buf.clear();
            old_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));

            // Rewrite the word in place: every left-to-right, non-overlapping
            // `a b` becomes `new_id`; everything else is shifted down.
            let mut write = 0;
            let mut read = 0;
            while read < ids.len() {
                if read + 1 < ids.len() && ids[read] == a && ids[read + 1] == b {
                    ids[write] = new_id;
                    write += 1;
                    read += 2;
                } else {
                    ids[write] = ids[read];
                    write += 1;
                    read += 1;
                }
            }
            ids.truncate(write);

            // Pairs of the word after the merge (with repetitions).
            new_pairs_buf.clear();
            new_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));

            // Swap this word's contribution: subtract all old pair
            // occurrences, then add all new ones.
            for p in &old_pairs_buf {
                *pair_counts.entry(*p).or_insert(0) -= m;
            }
            for p in &new_pairs_buf {
                *pair_counts.entry(*p).or_insert(0) += m;
            }

            // Reduce both lists to sorted distinct pairs so they can be
            // compared with a single merge-style walk.
            old_pairs_buf.sort_unstable();
            old_pairs_buf.dedup();
            new_pairs_buf.sort_unstable();
            new_pairs_buf.dedup();

            // Walk both sorted lists to update the inverted index:
            // only-old -> the word no longer contains the pair,
            // only-new -> the word now contains the pair,
            // in both  -> membership unchanged.
            // Every pair seen is recorded as touched.
            let mut i = 0usize;
            let mut j = 0usize;
            while i < old_pairs_buf.len() && j < new_pairs_buf.len() {
                match old_pairs_buf[i].cmp(&new_pairs_buf[j]) {
                    std::cmp::Ordering::Less => {
                        if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
                            ws.remove(&word_ix);
                        }
                        pairs_touched_buf.push(old_pairs_buf[i]);
                        i += 1;
                    }
                    std::cmp::Ordering::Greater => {
                        pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
                        pairs_touched_buf.push(new_pairs_buf[j]);
                        j += 1;
                    }
                    std::cmp::Ordering::Equal => {
                        pairs_touched_buf.push(old_pairs_buf[i]);
                        i += 1;
                        j += 1;
                    }
                }
            }
            // Leftover old pairs: all removed from this word.
            while i < old_pairs_buf.len() {
                if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
                    ws.remove(&word_ix);
                }
                pairs_touched_buf.push(old_pairs_buf[i]);
                i += 1;
            }
            // Leftover new pairs: all added to this word.
            while j < new_pairs_buf.len() {
                pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
                pairs_touched_buf.push(new_pairs_buf[j]);
                j += 1;
            }
        }

        // Push one fresh heap entry per touched pair that still has a
        // positive count; older entries for it are filtered out on pop.
        pairs_touched_buf.sort_unstable();
        pairs_touched_buf.dedup();
        for p in &pairs_touched_buf {
            let c = *pair_counts.get(p).unwrap_or(&0);
            if c > 0 {
                heap.push(HeapEntry { count: c, pair: *p });
            }
        }

        // The merged pair is done: drop its bookkeeping entries.
        pair_counts.remove(&(a, b));
        pair_words.remove(&(a, b));

        // Progress line on the first merge and then every 100 merges.
        if merges_emitted == 1 || merges_emitted.is_multiple_of(100) {
            let elapsed = start.elapsed().as_secs_f64();
            let top_count = heap.peek().map(|e| e.count).unwrap_or(0);
            eprintln!(
                "[bpe] merges={} vocab={} elapsed={:.1}s top_count={} heap={} pairs={}",
                merges_emitted,
                tok.vocab.len(),
                elapsed,
                top_count,
                heap.len(),
                pair_counts.len()
            );
            let _ = std::io::stderr().flush();
        }
    }

    let elapsed = start.elapsed().as_secs_f64();
    eprintln!(
        "[bpe] DONE merges={} vocab={} elapsed={:.1}s",
        merges_emitted,
        tok.vocab.len(),
        elapsed
    );
    let _ = std::io::stderr().flush();

    tok.trained = true;
    Ok(())
}
663
/// Straightforward reference trainer (test builds only): recounts every
/// adjacent pair from scratch before each merge.
///
/// The winning pair is the one with the highest count; ties go to the pair
/// with the smaller `(left_id, right_id)`. Pairs below `min_frequency`
/// (clamped to at least 1) are never chosen. Sets `tok.trained` on return.
#[cfg(test)]
#[doc(hidden)]
pub(crate) fn train_naive_reference(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
    let min_frequency = tok.config.min_frequency.max(1);
    let target = tok.config.vocab_size;

    tok.init_vocab();

    let mut tokenized: Vec<Vec<String>> = Vec::with_capacity(corpus.len());
    for doc in corpus {
        tokenized.push(tok.to_bytes(&tok.preprocess(doc)));
    }

    while tok.vocab.len() < target {
        let freqs = tok.get_pair_freqs(&tokenized);

        // Highest count wins; on equal counts the smaller id pair wins,
        // which is why the id comparison is reversed for `max_by`.
        let winner = freqs
            .iter()
            .filter(|(_, count)| **count >= min_frequency)
            .map(|(pair, count)| {
                let left_id = *tok.vocab.get(&pair.0).expect("left must be in vocab");
                let right_id = *tok.vocab.get(&pair.1).expect("right must be in vocab");
                (*count, (left_id, right_id), pair)
            })
            .max_by(|x, y| x.0.cmp(&y.0).then_with(|| y.1.cmp(&x.1)));

        let pair_str = match winner {
            Some((_, _, pair)) => pair.clone(),
            None => break,
        };

        let merged = format!("{}{}", pair_str.0, pair_str.1);
        let new_id = tok.vocab.len() as TokenId;
        tok.vocab.insert(merged.clone(), new_id);
        tok.id_to_token_map.insert(new_id, merged.clone());
        tok.merge_pair(&mut tokenized, &pair_str, &merged);
        tok.merges.push(pair_str);
    }

    tok.trained = true;
    Ok(())
}
717
/// Unit tests for `BPETokenizer`: basic API behavior, NFC handling,
/// vocab/merges loading, and parity between the fast and naive trainers.
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed tokenizer reports itself as untrained.
    #[test]
    fn test_bpe_new() {
        let config = TokenizerConfig::bpe();
        let tokenizer = BPETokenizer::new(config);
        assert!(!tokenizer.is_trained());
    }

    /// Training marks the tokenizer as trained and fills the vocabulary.
    #[test]
    fn test_bpe_train() {
        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
        let mut tokenizer = BPETokenizer::new(config);

        let corpus = vec!["hello hello", "hello world", "world hello"];
        tokenizer.train(&corpus).expect("training should succeed");

        assert!(tokenizer.is_trained());
        assert!(tokenizer.vocab_size() > 256);
    }

    /// Encoding before training is an error.
    #[test]
    fn test_bpe_encode_not_trained() {
        let config = TokenizerConfig::bpe();
        let tokenizer = BPETokenizer::new(config);

        let result = tokenizer.encode("hello");
        assert!(result.is_err());
    }

    /// Encoding followed by decoding returns the original text.
    #[test]
    fn test_bpe_encode_decode() {
        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
        let mut tokenizer = BPETokenizer::new(config);

        let corpus = vec!["hello world", "hello there"];
        tokenizer.train(&corpus).expect("training should succeed");

        let text = "hello";
        let encoded = tokenizer.encode(text).expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, text);
    }

    /// With lowercasing enabled, upper-case input round-trips to lower case.
    #[test]
    fn test_bpe_lowercase() {
        let config =
            TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1).with_lowercase(true);
        let mut tokenizer = BPETokenizer::new(config);

        let corpus = vec!["Hello World"];
        tokenizer.train(&corpus).expect("training should succeed");

        let encoded = tokenizer.encode("HELLO").expect("encoding should succeed");
        let decoded = tokenizer.decode(&encoded).expect("decoding should succeed");

        assert_eq!(decoded, "hello");
    }

    /// Id 0 maps to the unknown token.
    #[test]
    fn test_bpe_id_to_token() {
        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
        let mut tokenizer = BPETokenizer::new(config);

        let corpus = vec!["test"];
        tokenizer.train(&corpus).expect("training should succeed");

        assert_eq!(tokenizer.id_to_token(0), Some("<unk>"));
    }

    /// The unknown token maps to id 0.
    #[test]
    fn test_bpe_token_to_id() {
        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
        let mut tokenizer = BPETokenizer::new(config);

        let corpus = vec!["test"];
        tokenizer.train(&corpus).expect("training should succeed");

        assert_eq!(tokenizer.token_to_id("<unk>"), Some(0));
    }

    /// With NFC normalization, composed and decomposed spellings of the same
    /// text encode to identical ids.
    #[test]
    fn test_bpe_nfc_composed_decomposed_parity() {
        let composed = "café";
        let decomposed = "cafe\u{0301}";

        let config = TokenizerConfig::bpe()
            .with_vocab_size(300)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);
        let mut tokenizer = BPETokenizer::new(config);
        tokenizer.train(&[composed]).expect("training should succeed");

        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");

        assert_eq!(
            ids_composed, ids_decomposed,
            "NFC must map composed and decomposed café to identical token IDs"
        );

        let decoded = tokenizer.decode(&ids_composed).expect("decoding should succeed");
        assert_eq!(decoded, composed, "NFC round-trip must recover composed form");
    }

    /// Without normalization the two spellings must encode differently.
    #[test]
    fn test_bpe_without_nfc_composed_decomposed_diverge() {
        let composed = "café";
        let decomposed = "cafe\u{0301}";

        let config = TokenizerConfig::bpe()
            .with_vocab_size(300)
            .with_min_frequency(1)
            .with_normalization(Normalization::None);
        let mut tokenizer = BPETokenizer::new(config);
        tokenizer.train(&[composed]).expect("training should succeed");

        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");

        assert_ne!(
            ids_composed, ids_decomposed,
            "Without NFC, composed and decomposed café MUST diverge (falsification witness for INV-TOK-003)"
        );
    }

    /// A tokenizer written out as vocab.json + merges.txt and reloaded with
    /// `from_vocab_merges` encodes the training corpus identically.
    #[test]
    fn test_bpe_from_vocab_merges_roundtrip() {
        use std::fmt::Write;
        let config = TokenizerConfig::bpe()
            .with_vocab_size(400)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);
        let mut original = BPETokenizer::new(config.clone());
        let corpus = vec!["def hello():\n return 1\n", "def world():\n return 2\n"];
        original.train(&corpus).expect("training should succeed");

        // `TempDir` removes the directory on drop, including when an
        // assertion below panics.
        let tmp = tempfile::TempDir::new().expect("tempdir");
        let vocab_path = tmp.path().join("vocab.json");
        let merges_path = tmp.path().join("merges.txt");

        let mut entries: Vec<(&String, &TokenId)> = original.vocab().iter().collect();
        entries.sort_by_key(|(_, id)| *id);
        let ordered: serde_json::Map<String, serde_json::Value> = entries
            .into_iter()
            .map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into())))
            .collect();
        let vocab_json = serde_json::to_string_pretty(&ordered).unwrap();
        std::fs::write(&vocab_path, vocab_json).unwrap();

        let mut merges_content = String::from("#version: 0.2\n");
        for (left, right) in original.merges() {
            writeln!(merges_content, "{left} {right}").unwrap();
        }
        std::fs::write(&merges_path, merges_content).unwrap();

        let reloaded = BPETokenizer::from_vocab_merges(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            config,
        )
        .expect("from_vocab_merges should succeed");

        assert_eq!(reloaded.vocab_size(), original.vocab_size(), "reloaded vocab size must match");

        for text in &corpus {
            let original_ids = original.encode(text).expect("original encode");
            let reloaded_ids = reloaded.encode(text).expect("reloaded encode");
            assert_eq!(
                original_ids, reloaded_ids,
                "reloaded encoding must byte-equal original encoding for {text:?}"
            );
        }
    }

    /// A merge whose concatenated token is absent from vocab.json is rejected
    /// and the error names that token.
    #[test]
    fn test_bpe_from_vocab_merges_rejects_orphan_merge() {
        let tmp = tempfile::TempDir::new().expect("tempdir");
        let vocab_path = tmp.path().join("vocab.json");
        let merges_path = tmp.path().join("merges.txt");

        // All 256 hex bytes are present so the format guard passes and the
        // orphan-merge check is the one that fires.
        let mut vocab_obj = serde_json::Map::new();
        vocab_obj.insert("<unk>".to_string(), serde_json::json!(0));
        vocab_obj.insert("aa".to_string(), serde_json::json!(1));
        vocab_obj.insert("bb".to_string(), serde_json::json!(2));
        for b in 0u32..256 {
            vocab_obj.insert(format!("{b:02x}"), serde_json::json!(3 + b));
        }
        std::fs::write(&vocab_path, serde_json::to_string(&vocab_obj).unwrap()).unwrap();
        std::fs::write(&merges_path, "#version: 0.2\naa bb\n").unwrap();

        let result = BPETokenizer::from_vocab_merges(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            TokenizerConfig::bpe(),
        );

        assert!(
            result.is_err(),
            "from_vocab_merges must reject merges.txt with merged token not in vocab.json"
        );
        let err_msg = format!("{:?}", result.unwrap_err());
        assert!(
            err_msg.contains("aabb"),
            "error should name the offending merged token, got: {err_msg}"
        );
    }

    /// Generates `n_docs` small Python-like snippets by cycling through a
    /// fixed set of templates and substituting the document index.
    fn synthetic_python_corpus(n_docs: usize) -> Vec<String> {
        let templates: &[&str] = &[
            "def fn_{i}(x):\n return x * {i}\n",
            "class C_{i}:\n def __init__(self):\n self.x = {i}\n",
            "for i in range({i}):\n print(i * {i})\n",
            "def add_{i}(a, b):\n return a + b + {i}\n",
            "import math\nprint(math.sqrt({i}))\n",
            "if x == {i}:\n return True\nelse:\n return False\n",
            "xs = [{i}, {i}, {i}]\nfor x in xs:\n print(x)\n",
            "def process_{i}(data):\n result = []\n for item in data:\n result.append(item + {i})\n return result\n",
        ];
        (0..n_docs).map(|i| templates[i % templates.len()].replace("{i}", &i.to_string())).collect()
    }

    /// The fast trainer must produce exactly the merges and vocabulary of
    /// the naive reference trainer.
    #[test]
    fn bpe_fast_vs_naive_parity() {
        let config = TokenizerConfig::bpe()
            .with_vocab_size(512)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(20);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut fast = BPETokenizer::new(config.clone());
        super::train_fast(&mut fast, &corpus).expect("fast train should succeed");

        let mut naive = BPETokenizer::new(config);
        super::train_naive_reference(&mut naive, &corpus).expect("naive train should succeed");

        assert_eq!(
            fast.vocab_size(),
            naive.vocab_size(),
            "vocab sizes must match between fast and naive"
        );
        assert_eq!(fast.merges(), naive.merges(), "merge sequence must be identical");

        let mut fast_entries: Vec<(&String, &TokenId)> = fast.vocab().iter().collect();
        let mut naive_entries: Vec<(&String, &TokenId)> = naive.vocab().iter().collect();
        fast_entries.sort_by_key(|(_, id)| *id);
        naive_entries.sort_by_key(|(_, id)| *id);
        assert_eq!(
            fast_entries, naive_entries,
            "vocab (id → token) must be identical between fast and naive"
        );
    }

    /// Two fast-trainer runs on the same input give identical results.
    #[test]
    fn bpe_fast_is_deterministic() {
        let config = TokenizerConfig::bpe()
            .with_vocab_size(400)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(15);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut a = BPETokenizer::new(config.clone());
        super::train_fast(&mut a, &corpus).expect("run A");
        let mut b = BPETokenizer::new(config);
        super::train_fast(&mut b, &corpus).expect("run B");

        assert_eq!(a.merges(), b.merges(), "merges must be byte-identical across runs");
        assert_eq!(a.vocab_size(), b.vocab_size(), "vocab size must match");

        let mut a_entries: Vec<(&String, &TokenId)> = a.vocab().iter().collect();
        let mut b_entries: Vec<(&String, &TokenId)> = b.vocab().iter().collect();
        a_entries.sort_by_key(|(_, id)| *id);
        b_entries.sort_by_key(|(_, id)| *id);
        assert_eq!(a_entries, b_entries, "vocab map must be byte-identical across runs");
    }

    /// On a larger workload the fast trainer must still match the naive
    /// merges; in release builds it must also be at least 1.5x faster, while
    /// debug builds only require that it is not more than 1.5x slower.
    #[test]
    fn bpe_fast_meets_1_5x_parity_replacement_rule() {
        use std::time::Instant;

        let config = TokenizerConfig::bpe()
            .with_vocab_size(2048)
            .with_min_frequency(1)
            .with_normalization(Normalization::NFC);

        let corpus_owned = synthetic_python_corpus(500);
        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();

        let mut naive = BPETokenizer::new(config.clone());
        let t0 = Instant::now();
        super::train_naive_reference(&mut naive, &corpus).expect("naive train");
        let naive_secs = t0.elapsed().as_secs_f64();

        let mut fast = BPETokenizer::new(config);
        let t0 = Instant::now();
        super::train_fast(&mut fast, &corpus).expect("fast train");
        let fast_secs = t0.elapsed().as_secs_f64();

        let ratio = naive_secs / fast_secs;
        eprintln!(
            "[bpe-speedup] naive={naive_secs:.3}s fast={fast_secs:.3}s ratio={ratio:.2}× \
             vocab_naive={} vocab_fast={}",
            naive.vocab_size(),
            fast.vocab_size()
        );

        assert_eq!(
            fast.merges(),
            naive.merges(),
            "at perf-workload scale, fast and naive merges MUST still match"
        );

        if cfg!(debug_assertions) {
            assert!(
                fast_secs < naive_secs * 1.5,
                "even in debug, fast must not be dramatically slower than naive \
                 (ratio={ratio:.2}×)"
            );
        } else {
            assert!(
                ratio >= 1.5,
                "org policy: replacement must be ≥1.5× faster than the replaced \
                 algorithm — got {ratio:.2}× (naive={naive_secs:.3}s, fast={fast_secs:.3}s)"
            );
        }
    }
}
1102
/// Property-based tests plus the load-time format-mismatch falsifier.
#[cfg(test)]
mod property_tests {
    use super::*;
    use proptest::prelude::*;

    proptest! {
        #![proptest_config(ProptestConfig::with_cases(50))]

        /// Every id produced by `encode` resolves back to a token.
        #[test]
        fn prop_bpe_encode_produces_valid_ids(text in "[a-zA-Z ]{1,20}") {
            let mut bpe = BPETokenizer::new(
                TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1),
            );
            bpe.train(&[&text]).expect("operation should succeed");

            let ids = bpe.encode(&text).expect("encoding should succeed");

            for token_id in ids {
                prop_assert!(bpe.id_to_token(token_id).is_some());
            }
        }

        /// Training never grows the vocabulary past the configured target.
        #[test]
        fn prop_vocab_size_bounded(target_size in 261usize..500) {
            let mut bpe = BPETokenizer::new(
                TokenizerConfig::bpe().with_vocab_size(target_size).with_min_frequency(1),
            );

            let docs = vec!["hello world hello world test test"];
            bpe.train(&docs).expect("operation should succeed");

            prop_assert!(bpe.vocab_size() <= target_size);
        }
    }

    /// Loading a vocabulary made of raw characters and `Ġ`-prefixed words
    /// (no hex-byte tokens) must fail, and the error must carry the
    /// falsifier id, the term "hex-byte" and the importing command name.
    #[test]
    fn falsify_bpe_format_mismatch_gpt2_vocab_load_fails_fast() {
        let scratch = tempfile::TempDir::new().expect("tempdir");
        let vocab_path = scratch.path().join("vocab.json");
        let merges_path = scratch.path().join("merges.txt");

        let raw_chars = "abcdefghijklmnopqrstuvwxyz0123456789()[]{}";
        let prefixed_words = ["Ġdef", "Ġreturn", "Ġfor", "Ġif"];

        let mut vocab_obj = serde_json::Map::new();
        vocab_obj.insert("<unk>".to_string(), serde_json::json!(0));
        for (ch, id) in raw_chars.chars().zip(1usize..) {
            vocab_obj.insert(ch.to_string(), serde_json::json!(id));
        }
        for (word, id) in prefixed_words.iter().zip(100usize..) {
            vocab_obj.insert((*word).to_string(), serde_json::json!(id));
        }
        let vocab_json = serde_json::to_string(&vocab_obj).unwrap();
        std::fs::write(&vocab_path, vocab_json).expect("write vocab");
        std::fs::write(&merges_path, "#version: 0.2\n").expect("write merges");

        let outcome = BPETokenizer::from_vocab_merges(
            vocab_path.to_str().unwrap(),
            merges_path.to_str().unwrap(),
            TokenizerConfig::bpe(),
        );

        assert!(
            outcome.is_err(),
            "FALSIFY-BPE-FORMAT-MISMATCH-001 (load-time fail-fast): \
             from_vocab_merges accepted a GPT-2 byte-level vocab.json \
             that does NOT contain hex-byte tokens. Pre-this-fix, this \
             load succeeded silently and subsequent encode() calls \
             produced 100% `<unk>` tokens — the root cause of SHIP-TWO \
             §60's val_loss=0.00081 anomaly (shards became 99.99% \
             `<unk>` from Qwen vocab). The load MUST refuse so encode-\
             corpus cannot silently corrupt the corpus."
        );
        let err_msg = format!("{:?}", outcome.unwrap_err());
        assert!(
            err_msg.contains("FALSIFY-BPE-FORMAT-MISMATCH-001"),
            "Err message MUST cite the falsifier id (auditability): {err_msg}"
        );
        assert!(
            err_msg.contains("hex-byte"),
            "Err message MUST mention the canonical 'hex-byte' format \
             so operators recognize the cause: {err_msg}"
        );
        assert!(
            err_msg.contains("apr tokenize import-hf"),
            "Err message MUST name `apr tokenize import-hf` so operators \
             know which command produces the incompatible vocab format: \
             {err_msg}"
        );
    }
}