Skip to main content

entrenar/tokenizer/
bpe.rs

1//! BPE (Byte Pair Encoding) tokenizer implementation.
2
3use std::collections::HashMap;
4
5use serde::{Deserialize, Serialize};
6use unicode_normalization::UnicodeNormalization;
7
8use super::config::{Normalization, TokenizerConfig};
9use super::error::{Result, TokenizerError};
10use super::traits::{TokenId, Tokenizer};
11
12/// BPE (Byte Pair Encoding) tokenizer
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct BPETokenizer {
15    config: TokenizerConfig,
16    /// Token to ID mapping
17    vocab: HashMap<String, TokenId>,
18    /// ID to token mapping
19    id_to_token_map: HashMap<TokenId, String>,
20    /// Merge rules (pair -> merged token)
21    merges: Vec<(String, String)>,
22    /// Whether the tokenizer is trained
23    trained: bool,
24}
25
26impl BPETokenizer {
27    /// Create a new BPE tokenizer
28    pub fn new(config: TokenizerConfig) -> Self {
29        Self {
30            config,
31            vocab: HashMap::new(),
32            id_to_token_map: HashMap::new(),
33            merges: Vec::new(),
34            trained: false,
35        }
36    }
37
38    /// Initialize vocabulary with special tokens and bytes
39    fn init_vocab(&mut self) {
40        let mut id: TokenId = 0;
41
42        // Add special tokens
43        let special = [
44            &self.config.special_tokens.unk,
45            &self.config.special_tokens.bos,
46            &self.config.special_tokens.eos,
47            &self.config.special_tokens.pad,
48            &self.config.special_tokens.mask,
49        ];
50
51        for token in special {
52            self.vocab.insert(token.clone(), id);
53            self.id_to_token_map.insert(id, token.clone());
54            id += 1;
55        }
56
57        // Add all single bytes as base vocabulary
58        for byte in 0..=255u8 {
59            let token = format!("{byte:02x}");
60            if !self.vocab.contains_key(&token) {
61                self.vocab.insert(token.clone(), id);
62                self.id_to_token_map.insert(id, token);
63                id += 1;
64            }
65        }
66    }
67
68    /// Get pair frequencies from tokenized corpus
69    #[cfg(test)]
70    fn get_pair_freqs(&self, tokenized: &[Vec<String>]) -> HashMap<(String, String), usize> {
71        let mut freqs = HashMap::new();
72
73        for tokens in tokenized {
74            for pair in tokens.windows(2) {
75                let key = (pair[0].clone(), pair[1].clone());
76                *freqs.entry(key).or_insert(0) += 1;
77            }
78        }
79
80        freqs
81    }
82
83    /// Merge the most frequent pair
84    #[cfg(test)]
85    fn merge_pair(&self, tokenized: &mut [Vec<String>], pair: &(String, String), merged: &str) {
86        for tokens in tokenized.iter_mut() {
87            let mut i = 0;
88            while i < tokens.len().saturating_sub(1) {
89                if tokens[i] == pair.0 && tokens[i + 1] == pair.1 {
90                    tokens[i] = merged.to_string();
91                    tokens.remove(i + 1);
92                }
93                i += 1;
94            }
95        }
96    }
97
98    /// Apply the configured Unicode normalization, then optional lowercasing.
99    ///
100    /// NFC is applied BEFORE lowercasing because `char::to_lowercase()` is not
101    /// closed over non-NFC input for every grapheme — normalizing first makes
102    /// the pipeline deterministic for composed/decomposed variants of the
103    /// same visible text.
104    fn preprocess(&self, text: &str) -> String {
105        let normalized = match self.config.normalization {
106            Normalization::None => text.to_string(),
107            Normalization::NFC => text.nfc().collect(),
108        };
109        if self.config.lowercase {
110            normalized.to_lowercase()
111        } else {
112            normalized
113        }
114    }
115
116    /// Tokenize text to bytes (initial tokenization)
117    fn to_bytes(&self, text: &str) -> Vec<String> {
118        text.as_bytes().iter().map(|b| format!("{b:02x}")).collect()
119    }
120
121    /// Apply all learned merges
122    fn apply_merges(&self, mut tokens: Vec<String>) -> Vec<String> {
123        for (a, b) in &self.merges {
124            let merged = format!("{a}{b}");
125            let mut i = 0;
126            while i < tokens.len().saturating_sub(1) {
127                if &tokens[i] == a && &tokens[i + 1] == b {
128                    tokens[i] = merged.clone();
129                    tokens.remove(i + 1);
130                } else {
131                    i += 1;
132                }
133            }
134        }
135        tokens
136    }
137
138    /// Borrow the learned `token → id` vocabulary map.
139    ///
140    /// Exposed so callers (e.g. `apr tokenize train`) can emit the HuggingFace
141    /// `vocab.json` artifact mandated by `contracts/tokenizer-bpe-v1.yaml` without
142    /// serializing the whole `BPETokenizer` struct. Read-only by design — training
143    /// and encoding continue to own the `HashMap`.
144    pub fn vocab(&self) -> &HashMap<String, TokenId> {
145        &self.vocab
146    }
147
148    /// Borrow the ordered list of learned merge rules (`(left, right)` pairs in
149    /// merge order).
150    ///
151    /// Exposed so callers can write the HuggingFace `merges.txt` artifact. The
152    /// order is load-bearing: `merges.txt` consumers apply pairs top-to-bottom.
153    pub fn merges(&self) -> &[(String, String)] {
154        &self.merges
155    }
156
157    /// Save tokenizer to file
158    pub fn save(&self, path: &str) -> Result<()> {
159        let json = serde_json::to_string_pretty(self)
160            .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
161        std::fs::write(path, json)?;
162        Ok(())
163    }
164
165    /// Load tokenizer from file
166    pub fn load(path: &str) -> Result<Self> {
167        let json = std::fs::read_to_string(path)?;
168        serde_json::from_str(&json).map_err(|e| TokenizerError::Serialization(e.to_string()))
169    }
170
171    /// Reconstruct a trained `BPETokenizer` from the HuggingFace-style pair of
172    /// `vocab.json` + `merges.txt` emitted by `apr tokenize train`.
173    ///
174    /// # Format
175    /// - `vocab.json`: JSON object mapping token string → token id (u32). Order
176    ///   is informational; the loader inverts the map to build `id_to_token`.
177    /// - `merges.txt`: leading `#version: 0.2\n` header line, then one merge per
178    ///   line in apply order. Each line is `"<left> <right>"` with a single
179    ///   ASCII space separator (tokens in the aprender-train hex
180    ///   representation never contain spaces, so space-split is unambiguous).
181    ///
182    /// # Parameters
183    /// - `vocab_path`: path to `vocab.json`
184    /// - `merges_path`: path to `merges.txt`
185    /// - `config`: caller-supplied config (normalization, special tokens, etc.)
186    ///   since those fields are not recorded in the HF-style files. MUST match
187    ///   the config used at training time — wrong normalization here produces
188    ///   silently-wrong encodings.
189    ///
190    /// # Invariants
191    /// - C-PRETOK-BIN INV-PRETOK-001: every loaded vocab id < returned
192    ///   tokenizer's `vocab_size()`.
193    /// - Every merge's `(left, right)` concatenation is present in the loaded
194    ///   vocab (otherwise applying the merge during encode would produce a
195    ///   token the vocab cannot resolve). Enforced; mismatch returns an error.
196    pub fn from_vocab_merges(
197        vocab_path: &str,
198        merges_path: &str,
199        config: TokenizerConfig,
200    ) -> Result<Self> {
201        let vocab_json = std::fs::read_to_string(vocab_path)?;
202        let vocab: HashMap<String, TokenId> = serde_json::from_str(&vocab_json)
203            .map_err(|e| TokenizerError::Serialization(e.to_string()))?;
204
205        let id_to_token_map: HashMap<TokenId, String> =
206            vocab.iter().map(|(tok, &id)| (id, tok.clone())).collect();
207
208        if id_to_token_map.len() != vocab.len() {
209            return Err(TokenizerError::Serialization(
210                "vocab.json contains duplicate token ids (collision detected after inverting map)"
211                    .to_string(),
212            ));
213        }
214
215        let merges_text = std::fs::read_to_string(merges_path)?;
216        let mut merges: Vec<(String, String)> = Vec::new();
217        for (line_no, line) in merges_text.lines().enumerate() {
218            if line.is_empty() || line.starts_with("#") {
219                continue;
220            }
221            let mut parts = line.splitn(2, ' ');
222            let left = parts
223                .next()
224                .ok_or_else(|| {
225                    TokenizerError::Serialization(format!(
226                        "merges.txt line {}: missing left token",
227                        line_no + 1
228                    ))
229                })?
230                .to_string();
231            let right = parts
232                .next()
233                .ok_or_else(|| {
234                    TokenizerError::Serialization(format!(
235                        "merges.txt line {}: missing right token (expected '<left> <right>')",
236                        line_no + 1
237                    ))
238                })?
239                .to_string();
240
241            let merged = format!("{left}{right}");
242            if !vocab.contains_key(&merged) {
243                return Err(TokenizerError::Serialization(format!(
244                    "merges.txt line {}: merged token {:?} not present in vocab.json",
245                    line_no + 1,
246                    merged
247                )));
248            }
249            merges.push((left, right));
250        }
251
252        Ok(Self { config, vocab, id_to_token_map, merges, trained: true })
253    }
254}
255
256impl Tokenizer for BPETokenizer {
257    fn train(&mut self, corpus: &[&str]) -> Result<()> {
258        train_fast(self, corpus)
259    }
260
261    fn encode(&self, text: &str) -> Result<Vec<TokenId>> {
262        if !self.trained {
263            return Err(TokenizerError::NotTrained);
264        }
265
266        let tokens = self.to_bytes(&self.preprocess(text));
267        let tokens = self.apply_merges(tokens);
268
269        let unk_id = *self
270            .vocab
271            .get(&self.config.special_tokens.unk)
272            .expect("UNK token must exist in trained vocabulary");
273
274        let ids: Vec<TokenId> =
275            tokens.iter().map(|t| *self.vocab.get(t).unwrap_or(&unk_id)).collect();
276
277        Ok(ids)
278    }
279
280    fn decode(&self, ids: &[TokenId]) -> Result<String> {
281        if !self.trained {
282            return Err(TokenizerError::NotTrained);
283        }
284
285        let mut hex_string = String::new();
286
287        for &id in ids {
288            if let Some(token) = self.id_to_token_map.get(&id) {
289                // Skip special tokens
290                if token.starts_with('<') && token.ends_with('>') {
291                    continue;
292                }
293                hex_string.push_str(token);
294            }
295        }
296
297        // Convert hex string back to bytes
298        let bytes: Vec<u8> = (0..hex_string.len())
299            .step_by(2)
300            .filter_map(|i| {
301                if i + 2 <= hex_string.len() {
302                    u8::from_str_radix(&hex_string[i..i + 2], 16).ok()
303                } else {
304                    None
305                }
306            })
307            .collect();
308
309        String::from_utf8(bytes).map_err(|e| TokenizerError::Training(e.to_string()))
310    }
311
312    fn vocab_size(&self) -> usize {
313        self.vocab.len()
314    }
315
316    fn is_trained(&self) -> bool {
317        self.trained
318    }
319
320    fn id_to_token(&self, id: TokenId) -> Option<&str> {
321        self.id_to_token_map.get(&id).map(String::as_str)
322    }
323
324    fn token_to_id(&self, token: &str) -> Option<TokenId> {
325        self.vocab.get(token).copied()
326    }
327}
328
329// ─────────────────────────────────────────────────────────────
330// Priority-queue + inverted-index BPE training.
331//
332// Contract: contracts/bpe-training-perf-v1.yaml (v1.1.0).
333//   - Algorithm: Sennrich 2016 / HuggingFace tokenizers style.
334//   - Tie-breaker: lex-min on (left_id, right_id) for cross-run
335//     determinism (INV-BPE-006, FALSIFY-BPE-TRAIN-PERF-002).
336//   - Complexity: O((V + E) log V) amortized where E is total pair-
337//     count updates. Replaces a naive O(V · N · L) loop that did not
338//     complete a 50K-vocab × 127 MB training run in 25 h 40 m.
339//   - Observability: periodic `[bpe]` stderr reports every 1 000
340//     merges (FALSIFY-BPE-TRAIN-PERF-004).
341// ─────────────────────────────────────────────────────────────
342
343#[derive(Clone, Eq, PartialEq)]
344struct HeapEntry {
345    count: i64,
346    pair: (TokenId, TokenId),
347}
348
349impl Ord for HeapEntry {
350    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
351        // Primary: higher count wins (BinaryHeap is a max-heap).
352        // Tie-breaker: smaller (left_id, right_id) tuple wins — invert
353        // the pair comparison so the smaller pair is "greater" and
354        // therefore popped first.
355        self.count.cmp(&other.count).then_with(|| other.pair.cmp(&self.pair))
356    }
357}
358
359impl PartialOrd for HeapEntry {
360    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
361        Some(self.cmp(other))
362    }
363}
364
365/// Fast priority-queue + inverted-index BPE training.
366///
367/// Invoked via `<BPETokenizer as Tokenizer>::train`. Exposed as a free
368/// function so tests can compare against `train_naive_reference` for
369/// FALSIFY-BPE-TRAIN-PERF-001 (parity) and -005 (speedup).
370pub(crate) fn train_fast(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
371    use std::collections::{BinaryHeap, HashMap, HashSet};
372    use std::time::Instant;
373
374    let start = Instant::now();
375    let target = tok.config.vocab_size;
376    let min_frequency = tok.config.min_frequency.max(1) as i64;
377
378    tok.init_vocab();
379
380    eprintln!("[bpe-setup] ingest start: {} docs", corpus.len());
381    use std::io::Write;
382    let _ = std::io::stderr().flush();
383
384    // Byte-tokenize every document, fold duplicates into (Vec<TokenId>, multiplicity) pairs.
385    let t0 = Instant::now();
386    let mut word_counts: HashMap<Vec<TokenId>, u64> = HashMap::new();
387    for doc in corpus {
388        let text = tok.preprocess(doc);
389        let hex_tokens = tok.to_bytes(&text);
390        if hex_tokens.is_empty() {
391            continue;
392        }
393        let ids: Vec<TokenId> = hex_tokens
394            .iter()
395            .map(|t| *tok.vocab.get(t).expect("byte hex token must be in init_vocab"))
396            .collect();
397        *word_counts.entry(ids).or_insert(0) += 1;
398    }
399    eprintln!(
400        "[bpe-setup] ingest done: {} unique words in {:.1}s",
401        word_counts.len(),
402        t0.elapsed().as_secs_f64()
403    );
404    let _ = std::io::stderr().flush();
405
406    let mut words: Vec<(Vec<TokenId>, u64)> = word_counts.into_iter().collect();
407
408    // Build pair indexes.
409    let t1 = Instant::now();
410    let mut pair_counts: HashMap<(TokenId, TokenId), i64> = HashMap::new();
411    let mut pair_words: HashMap<(TokenId, TokenId), HashSet<usize>> = HashMap::new();
412    for (word_ix, (ids, mult)) in words.iter().enumerate() {
413        let m = *mult as i64;
414        for w in ids.windows(2) {
415            let p = (w[0], w[1]);
416            *pair_counts.entry(p).or_insert(0) += m;
417            pair_words.entry(p).or_default().insert(word_ix);
418        }
419    }
420    eprintln!(
421        "[bpe-setup] pair indexes: {} unique pairs in {:.1}s",
422        pair_counts.len(),
423        t1.elapsed().as_secs_f64()
424    );
425    let _ = std::io::stderr().flush();
426
427    // Seed heap.
428    let t2 = Instant::now();
429    let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(pair_counts.len());
430    for (p, c) in &pair_counts {
431        if *c > 0 {
432            heap.push(HeapEntry { count: *c, pair: *p });
433        }
434    }
435    eprintln!(
436        "[bpe-setup] heap seeded: {} entries in {:.1}s; entering merge loop",
437        heap.len(),
438        t2.elapsed().as_secs_f64()
439    );
440    let _ = std::io::stderr().flush();
441
442    let mut merges_emitted: usize = 0;
443
444    // Scratch buffers hoisted OUT of the per-word loop. Each early common-pair
445    // merge affects ~100K words; allocating transient containers per word cost
446    // ~400K mallocs per merge (observed as 1+ s/merge at merge 400, PID
447    // 1568187, 2026-04-20). HashSet reuse was tried and FALSIFIED (27% slower
448    // on PID 1638021, 2026-04-20) because `HashSet::clear()` walks the backing
449    // array (up to 4096 slots) per call. Vec + sort_unstable + merge-pass for
450    // set ops wins on (u32, u32) keys: cheaper to sort than to hash.
451    let mut old_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
452    let mut new_pairs_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(512);
453    let mut pairs_touched_buf: Vec<(TokenId, TokenId)> = Vec::with_capacity(1 << 16);
454    let mut affected_buf: Vec<usize> = Vec::with_capacity(1 << 16);
455
456    while tok.vocab.len() < target {
457        let entry = match heap.pop() {
458            Some(e) => e,
459            None => break,
460        };
461        // Drop stale entries — a pair's count was updated after this entry was pushed.
462        let current = *pair_counts.get(&entry.pair).unwrap_or(&0);
463        if current != entry.count {
464            continue;
465        }
466        if current < min_frequency {
467            break;
468        }
469
470        let (a, b) = entry.pair;
471        let a_str = tok.id_to_token_map[&a].clone();
472        let b_str = tok.id_to_token_map[&b].clone();
473        let merged_str = format!("{a_str}{b_str}");
474        let new_id: TokenId = tok.vocab.len() as TokenId;
475        tok.vocab.insert(merged_str.clone(), new_id);
476        tok.id_to_token_map.insert(new_id, merged_str);
477        tok.merges.push((a_str, b_str));
478        merges_emitted += 1;
479
480        // Apply merge in every word containing (a, b). Snapshot the set first so we
481        // can mutate pair_words during the sweep.
482        affected_buf.clear();
483        if let Some(ws) = pair_words.get(&(a, b)) {
484            affected_buf.extend(ws.iter().copied());
485        }
486
487        // Aggregate the set of pairs whose count changed across ALL affected
488        // words. Pushing heap entries once per pair per merge (rather than
489        // once per (pair, word) tuple) is load-bearing: early merges of
490        // common byte-pairs can touch 10⁵+ words, and pushing per-word
491        // produced 10⁸+ stale heap entries / merge, OOM-killing the run
492        // (observed 2026-04-20, PID 1387417 hit 29 GB RSS).
493        pairs_touched_buf.clear();
494
495        for &word_ix in &affected_buf {
496            let (ids, mult) = &mut words[word_ix];
497            let m = *mult as i64;
498
499            // Collect old pairs into reused buffer (zero alloc).
500            old_pairs_buf.clear();
501            old_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));
502
503            // In-place greedy left-to-right merge of (a, b) → new_id.
504            // Since the merge only shrinks the Vec, read ≥ write holds, so
505            // the single-buffer two-pointer walk is safe.
506            let mut write = 0;
507            let mut read = 0;
508            while read < ids.len() {
509                if read + 1 < ids.len() && ids[read] == a && ids[read + 1] == b {
510                    ids[write] = new_id;
511                    write += 1;
512                    read += 2;
513                } else {
514                    ids[write] = ids[read];
515                    write += 1;
516                    read += 1;
517                }
518            }
519            ids.truncate(write);
520
521            // Collect new pairs into reused buffer.
522            new_pairs_buf.clear();
523            new_pairs_buf.extend(ids.windows(2).map(|w| (w[0], w[1])));
524
525            // Multiset deltas on pair_counts (duplicates matter for counts).
526            for p in &old_pairs_buf {
527                *pair_counts.entry(*p).or_insert(0) -= m;
528            }
529            for p in &new_pairs_buf {
530                *pair_counts.entry(*p).or_insert(0) += m;
531            }
532
533            // Set deltas on pair_words via sort + linear merge-pass.
534            // HashSet alternative was FALSIFIED (27% slower) because
535            // HashSet::clear walks the backing array (4096 slots) per call.
536            // sort_unstable on (u32, u32) is branch-predictable + LLVM
537            // auto-vectorizes; no hashing cost for POD keys.
538            old_pairs_buf.sort_unstable();
539            old_pairs_buf.dedup();
540            new_pairs_buf.sort_unstable();
541            new_pairs_buf.dedup();
542
543            let mut i = 0usize;
544            let mut j = 0usize;
545            while i < old_pairs_buf.len() && j < new_pairs_buf.len() {
546                match old_pairs_buf[i].cmp(&new_pairs_buf[j]) {
547                    std::cmp::Ordering::Less => {
548                        if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
549                            ws.remove(&word_ix);
550                        }
551                        pairs_touched_buf.push(old_pairs_buf[i]);
552                        i += 1;
553                    }
554                    std::cmp::Ordering::Greater => {
555                        pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
556                        pairs_touched_buf.push(new_pairs_buf[j]);
557                        j += 1;
558                    }
559                    std::cmp::Ordering::Equal => {
560                        // Present in both — no pair_words delta, but still
561                        // touched (multiplicity / top-pair ordering may shift).
562                        pairs_touched_buf.push(old_pairs_buf[i]);
563                        i += 1;
564                        j += 1;
565                    }
566                }
567            }
568            while i < old_pairs_buf.len() {
569                if let Some(ws) = pair_words.get_mut(&old_pairs_buf[i]) {
570                    ws.remove(&word_ix);
571                }
572                pairs_touched_buf.push(old_pairs_buf[i]);
573                i += 1;
574            }
575            while j < new_pairs_buf.len() {
576                pair_words.entry(new_pairs_buf[j]).or_default().insert(word_ix);
577                pairs_touched_buf.push(new_pairs_buf[j]);
578                j += 1;
579            }
580        }
581
582        // Dedup aggregated pairs across all affected words, then push ONE
583        // refreshed heap entry per affected pair (not per word).
584        pairs_touched_buf.sort_unstable();
585        pairs_touched_buf.dedup();
586        for p in &pairs_touched_buf {
587            let c = *pair_counts.get(p).unwrap_or(&0);
588            if c > 0 {
589                heap.push(HeapEntry { count: c, pair: *p });
590            }
591        }
592
593        // The merged pair itself is fully consumed — purge its entries.
594        pair_counts.remove(&(a, b));
595        pair_words.remove(&(a, b));
596
597        if merges_emitted == 1 || merges_emitted.is_multiple_of(100) {
598            let elapsed = start.elapsed().as_secs_f64();
599            let top_count = heap.peek().map(|e| e.count).unwrap_or(0);
600            eprintln!(
601                "[bpe] merges={} vocab={} elapsed={:.1}s top_count={} heap={} pairs={}",
602                merges_emitted,
603                tok.vocab.len(),
604                elapsed,
605                top_count,
606                heap.len(),
607                pair_counts.len()
608            );
609            let _ = std::io::stderr().flush();
610        }
611    }
612
613    let elapsed = start.elapsed().as_secs_f64();
614    eprintln!(
615        "[bpe] DONE merges={} vocab={} elapsed={:.1}s",
616        merges_emitted,
617        tok.vocab.len(),
618        elapsed
619    );
620    let _ = std::io::stderr().flush();
621
622    tok.trained = true;
623    Ok(())
624}
625
626/// Naive reference implementation — the pre-task-#118 algorithm, verbatim
627/// except that the tie-breaker is forced to lex-min on (left_id, right_id)
628/// so its output is a deterministic baseline for FALSIFY-BPE-TRAIN-PERF-001
629/// (parity) and -005 (speedup measurement). Retained ONLY for tests — the
630/// shipped training path is `train_fast`.
631#[cfg(test)]
632#[doc(hidden)]
633pub(crate) fn train_naive_reference(tok: &mut BPETokenizer, corpus: &[&str]) -> Result<()> {
634    let target = tok.config.vocab_size;
635    let min_frequency = tok.config.min_frequency.max(1);
636
637    tok.init_vocab();
638
639    let mut tokenized: Vec<Vec<String>> =
640        corpus.iter().map(|s| tok.to_bytes(&tok.preprocess(s))).collect();
641
642    while tok.vocab.len() < target {
643        let freqs = tok.get_pair_freqs(&tokenized);
644
645        // Pick pair with max count, lex-min on (left_id, right_id) on ties.
646        let mut best: Option<(usize, (TokenId, TokenId), (String, String))> = None;
647        for (pair_str, count) in &freqs {
648            if *count < min_frequency {
649                continue;
650            }
651            let left_id = *tok.vocab.get(&pair_str.0).expect("left must be in vocab");
652            let right_id = *tok.vocab.get(&pair_str.1).expect("right must be in vocab");
653            match &best {
654                None => best = Some((*count, (left_id, right_id), pair_str.clone())),
655                Some((bc, bp, _)) => {
656                    if *count > *bc || (*count == *bc && (left_id, right_id) < *bp) {
657                        best = Some((*count, (left_id, right_id), pair_str.clone()));
658                    }
659                }
660            }
661        }
662
663        let (_count, _ids, pair_str) = match best {
664            Some(b) => b,
665            None => break,
666        };
667
668        let merged = format!("{}{}", pair_str.0, pair_str.1);
669        let new_id: TokenId = tok.vocab.len() as TokenId;
670        tok.vocab.insert(merged.clone(), new_id);
671        tok.id_to_token_map.insert(new_id, merged.clone());
672        tok.merges.push(pair_str.clone());
673        tok.merge_pair(&mut tokenized, &pair_str, &merged);
674    }
675
676    tok.trained = true;
677    Ok(())
678}
679
680#[cfg(test)]
681mod tests {
682    use super::*;
683
684    #[test]
685    fn test_bpe_new() {
686        let config = TokenizerConfig::bpe();
687        let tokenizer = BPETokenizer::new(config);
688        assert!(!tokenizer.is_trained());
689    }
690
691    #[test]
692    fn test_bpe_train() {
693        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
694        let mut tokenizer = BPETokenizer::new(config);
695
696        let corpus = vec!["hello hello", "hello world", "world hello"];
697        tokenizer.train(&corpus).expect("operation should succeed");
698
699        assert!(tokenizer.is_trained());
700        assert!(tokenizer.vocab_size() > 256); // Base bytes + some merges
701    }
702
703    #[test]
704    fn test_bpe_encode_not_trained() {
705        let config = TokenizerConfig::bpe();
706        let tokenizer = BPETokenizer::new(config);
707
708        let result = tokenizer.encode("hello");
709        assert!(result.is_err());
710    }
711
712    #[test]
713    fn test_bpe_encode_decode() {
714        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
715        let mut tokenizer = BPETokenizer::new(config);
716
717        let corpus = vec!["hello world", "hello there"];
718        tokenizer.train(&corpus).expect("operation should succeed");
719
720        let text = "hello";
721        let encoded = tokenizer.encode(text).expect("encoding should succeed");
722        let decoded = tokenizer.decode(&encoded).expect("encoding should succeed");
723
724        assert_eq!(decoded, text);
725    }
726
727    #[test]
728    fn test_bpe_lowercase() {
729        let config =
730            TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1).with_lowercase(true);
731        let mut tokenizer = BPETokenizer::new(config);
732
733        let corpus = vec!["Hello World"];
734        tokenizer.train(&corpus).expect("operation should succeed");
735
736        let encoded = tokenizer.encode("HELLO").expect("encoding should succeed");
737        let decoded = tokenizer.decode(&encoded).expect("encoding should succeed");
738
739        assert_eq!(decoded, "hello");
740    }
741
742    #[test]
743    fn test_bpe_id_to_token() {
744        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
745        let mut tokenizer = BPETokenizer::new(config);
746
747        let corpus = vec!["test"];
748        tokenizer.train(&corpus).expect("operation should succeed");
749
750        // ID 0 should be <unk>
751        assert_eq!(tokenizer.id_to_token(0), Some("<unk>"));
752    }
753
754    #[test]
755    fn test_bpe_token_to_id() {
756        let config = TokenizerConfig::bpe().with_vocab_size(300).with_min_frequency(1);
757        let mut tokenizer = BPETokenizer::new(config);
758
759        let corpus = vec!["test"];
760        tokenizer.train(&corpus).expect("operation should succeed");
761
762        assert_eq!(tokenizer.token_to_id("<unk>"), Some(0));
763    }
764
765    // C-TOK-BPE-001 INV-TOK-003: NFC normalization makes composed and decomposed
766    // variants of the same grapheme hash to identical byte sequences, so a
767    // tokenizer trained on one form encodes the other form identically.
768    #[test]
769    fn test_bpe_nfc_composed_decomposed_parity() {
770        let composed = "café"; // U+00E9
771        let decomposed = "cafe\u{0301}"; // e + combining acute
772
773        let config = TokenizerConfig::bpe()
774            .with_vocab_size(300)
775            .with_min_frequency(1)
776            .with_normalization(Normalization::NFC);
777        let mut tokenizer = BPETokenizer::new(config);
778        tokenizer.train(&[composed]).expect("operation should succeed");
779
780        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
781        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");
782
783        assert_eq!(
784            ids_composed, ids_decomposed,
785            "NFC must map composed and decomposed café to identical token IDs"
786        );
787
788        let decoded = tokenizer.decode(&ids_composed).expect("decoding should succeed");
789        assert_eq!(decoded, composed, "NFC round-trip must recover composed form");
790    }
791
792    // Without NFC, composed and decomposed café MUST diverge — this is the
793    // exact drift INV-TOK-003 is defending against at training/inference boundary.
794    #[test]
795    fn test_bpe_without_nfc_composed_decomposed_diverge() {
796        let composed = "café";
797        let decomposed = "cafe\u{0301}";
798
799        let config = TokenizerConfig::bpe()
800            .with_vocab_size(300)
801            .with_min_frequency(1)
802            .with_normalization(Normalization::None);
803        let mut tokenizer = BPETokenizer::new(config);
804        tokenizer.train(&[composed]).expect("operation should succeed");
805
806        let ids_composed = tokenizer.encode(composed).expect("encoding should succeed");
807        let ids_decomposed = tokenizer.encode(decomposed).expect("encoding should succeed");
808
809        assert_ne!(
810            ids_composed, ids_decomposed,
811            "Without NFC, composed and decomposed café MUST diverge (falsification witness for INV-TOK-003)"
812        );
813    }
814
815    // C-PRETOK-BIN GATE-PRETOK-003 prerequisite: reloading a trained
816    // tokenizer from its emitted vocab.json + merges.txt MUST yield
817    // byte-identical encodings vs the original in-memory tokenizer.
818    // Any drift here means `apr tokenize encode-corpus` (which loads
819    // via from_vocab_merges) would produce shards that differ from
820    // what the tokenizer intended — ShardBatchIter round-trip fails.
821    #[test]
822    fn test_bpe_from_vocab_merges_roundtrip() {
823        use std::fmt::Write;
824        let config = TokenizerConfig::bpe()
825            .with_vocab_size(400)
826            .with_min_frequency(1)
827            .with_normalization(Normalization::NFC);
828        let mut original = BPETokenizer::new(config.clone());
829        let corpus = vec!["def hello():\n    return 1\n", "def world():\n    return 2\n"];
830        original.train(&corpus).expect("training should succeed");
831
832        let tmp = std::env::temp_dir().join(format!(
833            "bpe_roundtrip_{}_{}",
834            std::process::id(),
835            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()
836        ));
837        std::fs::create_dir_all(&tmp).unwrap();
838        let vocab_path = tmp.join("vocab.json");
839        let merges_path = tmp.join("merges.txt");
840
841        let mut entries: Vec<(&String, &TokenId)> = original.vocab().iter().collect();
842        entries.sort_by_key(|(_, id)| *id);
843        let ordered: serde_json::Map<String, serde_json::Value> = entries
844            .into_iter()
845            .map(|(k, v)| (k.clone(), serde_json::Value::Number((*v).into())))
846            .collect();
847        let vocab_json = serde_json::to_string_pretty(&ordered).unwrap();
848        std::fs::write(&vocab_path, vocab_json).unwrap();
849
850        let mut merges_content = String::from("#version: 0.2\n");
851        for (left, right) in original.merges() {
852            writeln!(merges_content, "{left} {right}").unwrap();
853        }
854        std::fs::write(&merges_path, merges_content).unwrap();
855
856        let reloaded = BPETokenizer::from_vocab_merges(
857            vocab_path.to_str().unwrap(),
858            merges_path.to_str().unwrap(),
859            config,
860        )
861        .expect("from_vocab_merges should succeed");
862
863        assert_eq!(reloaded.vocab_size(), original.vocab_size(), "reloaded vocab size must match");
864
865        for text in &corpus {
866            let original_ids = original.encode(text).expect("original encode");
867            let reloaded_ids = reloaded.encode(text).expect("reloaded encode");
868            assert_eq!(
869                original_ids, reloaded_ids,
870                "reloaded encoding must byte-equal original encoding for {text:?}"
871            );
872        }
873
874        let _ = std::fs::remove_dir_all(&tmp);
875    }
876
877    // Negative: from_vocab_merges must reject a merges.txt with a merged
878    // token not present in vocab.json — that's a corrupt pair, and encoding
879    // would silently emit <unk> instead of the intended token.
880    #[test]
881    fn test_bpe_from_vocab_merges_rejects_orphan_merge() {
882        let tmp = std::env::temp_dir().join(format!(
883            "bpe_orphan_{}_{}",
884            std::process::id(),
885            std::time::SystemTime::now().duration_since(std::time::UNIX_EPOCH).unwrap().as_nanos()
886        ));
887        std::fs::create_dir_all(&tmp).unwrap();
888        let vocab_path = tmp.join("vocab.json");
889        let merges_path = tmp.join("merges.txt");
890
891        std::fs::write(&vocab_path, r#"{"<unk>": 0, "aa": 1, "bb": 2}"#).unwrap();
892        std::fs::write(&merges_path, "#version: 0.2\naa bb\n").unwrap();
893
894        let result = BPETokenizer::from_vocab_merges(
895            vocab_path.to_str().unwrap(),
896            merges_path.to_str().unwrap(),
897            TokenizerConfig::bpe(),
898        );
899
900        assert!(
901            result.is_err(),
902            "from_vocab_merges must reject merges.txt with merged token not in vocab.json"
903        );
904        let err_msg = format!("{:?}", result.unwrap_err());
905        assert!(
906            err_msg.contains("aabb"),
907            "error should name the offending merged token, got: {err_msg}"
908        );
909
910        let _ = std::fs::remove_dir_all(&tmp);
911    }
912
913    // Synthetic Python-like corpus builder for perf / parity tests. Deterministic.
914    fn synthetic_python_corpus(n_docs: usize) -> Vec<String> {
915        let templates: &[&str] = &[
916            "def fn_{i}(x):\n    return x * {i}\n",
917            "class C_{i}:\n    def __init__(self):\n        self.x = {i}\n",
918            "for i in range({i}):\n    print(i * {i})\n",
919            "def add_{i}(a, b):\n    return a + b + {i}\n",
920            "import math\nprint(math.sqrt({i}))\n",
921            "if x == {i}:\n    return True\nelse:\n    return False\n",
922            "xs = [{i}, {i}, {i}]\nfor x in xs:\n    print(x)\n",
923            "def process_{i}(data):\n    result = []\n    for item in data:\n        result.append(item + {i})\n    return result\n",
924        ];
925        (0..n_docs).map(|i| templates[i % templates.len()].replace("{i}", &i.to_string())).collect()
926    }
927
928    // FALSIFY-BPE-TRAIN-PERF-001: fast and naive produce identical output under lex-min.
929    #[test]
930    fn bpe_fast_vs_naive_parity() {
931        let config = TokenizerConfig::bpe()
932            .with_vocab_size(512)
933            .with_min_frequency(1)
934            .with_normalization(Normalization::NFC);
935
936        let corpus_owned = synthetic_python_corpus(20);
937        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();
938
939        let mut fast = BPETokenizer::new(config.clone());
940        super::train_fast(&mut fast, &corpus).expect("fast train should succeed");
941
942        let mut naive = BPETokenizer::new(config);
943        super::train_naive_reference(&mut naive, &corpus).expect("naive train should succeed");
944
945        assert_eq!(
946            fast.vocab_size(),
947            naive.vocab_size(),
948            "vocab sizes must match between fast and naive"
949        );
950        assert_eq!(fast.merges(), naive.merges(), "merge sequence must be identical");
951
952        let mut fast_entries: Vec<(&String, &TokenId)> = fast.vocab().iter().collect();
953        let mut naive_entries: Vec<(&String, &TokenId)> = naive.vocab().iter().collect();
954        fast_entries.sort_by_key(|(_, id)| *id);
955        naive_entries.sort_by_key(|(_, id)| *id);
956        assert_eq!(
957            fast_entries, naive_entries,
958            "vocab (id → token) must be identical between fast and naive"
959        );
960    }
961
962    // FALSIFY-BPE-TRAIN-PERF-002: same corpus + same config → byte-identical output.
963    #[test]
964    fn bpe_fast_is_deterministic() {
965        let config = TokenizerConfig::bpe()
966            .with_vocab_size(400)
967            .with_min_frequency(1)
968            .with_normalization(Normalization::NFC);
969
970        let corpus_owned = synthetic_python_corpus(15);
971        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();
972
973        let mut a = BPETokenizer::new(config.clone());
974        super::train_fast(&mut a, &corpus).expect("run A");
975        let mut b = BPETokenizer::new(config);
976        super::train_fast(&mut b, &corpus).expect("run B");
977
978        assert_eq!(a.merges(), b.merges(), "merges must be byte-identical across runs");
979        assert_eq!(a.vocab_size(), b.vocab_size(), "vocab size must match");
980
981        let mut a_entries: Vec<(&String, &TokenId)> = a.vocab().iter().collect();
982        let mut b_entries: Vec<(&String, &TokenId)> = b.vocab().iter().collect();
983        a_entries.sort_by_key(|(_, id)| *id);
984        b_entries.sort_by_key(|(_, id)| *id);
985        assert_eq!(a_entries, b_entries, "vocab map must be byte-identical across runs");
986    }
987
988    // FALSIFY-BPE-TRAIN-PERF-005: fast ≥ 1.5× faster than the naive it replaces.
989    // Org policy: any replacement must clear 1.5× or it is rejected.
990    //
991    // Uses a 500-doc / vocab=2048 / min_frequency=1 representative workload
992    // per contract bpe-training-perf-v1.yaml v1.1.0. min_frequency=1 forces
993    // the full 1787 merges (rather than early-stopping when counts fall
994    // below 2), which is what exposes the quadratic cost of the naïve loop.
995    //
996    // In debug builds the constant-factor noise swamps the signal, so we
997    // only assert in release — but we DO run the test in debug to catch
998    // regressions in the fast path that explode its runtime beyond reason.
999    #[test]
1000    fn bpe_fast_meets_1_5x_parity_replacement_rule() {
1001        use std::time::Instant;
1002
1003        let config = TokenizerConfig::bpe()
1004            .with_vocab_size(2048)
1005            .with_min_frequency(1)
1006            .with_normalization(Normalization::NFC);
1007
1008        let corpus_owned = synthetic_python_corpus(500);
1009        let corpus: Vec<&str> = corpus_owned.iter().map(String::as_str).collect();
1010
1011        let mut naive = BPETokenizer::new(config.clone());
1012        let t0 = Instant::now();
1013        super::train_naive_reference(&mut naive, &corpus).expect("naive train");
1014        let naive_secs = t0.elapsed().as_secs_f64();
1015
1016        let mut fast = BPETokenizer::new(config);
1017        let t0 = Instant::now();
1018        super::train_fast(&mut fast, &corpus).expect("fast train");
1019        let fast_secs = t0.elapsed().as_secs_f64();
1020
1021        let ratio = naive_secs / fast_secs;
1022        eprintln!(
1023            "[bpe-speedup] naive={naive_secs:.3}s fast={fast_secs:.3}s ratio={ratio:.2}× \
1024             vocab_naive={} vocab_fast={}",
1025            naive.vocab_size(),
1026            fast.vocab_size()
1027        );
1028
1029        // Correctness-floor: parity must hold at this scale too.
1030        assert_eq!(
1031            fast.merges(),
1032            naive.merges(),
1033            "at perf-workload scale, fast and naive merges MUST still match"
1034        );
1035
1036        if cfg!(debug_assertions) {
1037            // Debug mode: assert fast is not worse than 1.0× (i.e. not slower).
1038            // The real 1.5× bar is enforced in release mode below.
1039            assert!(
1040                fast_secs < naive_secs * 1.5,
1041                "even in debug, fast must not be dramatically slower than naive \
1042                 (ratio={ratio:.2}×)"
1043            );
1044        } else {
1045            assert!(
1046                ratio >= 1.5,
1047                "org policy: replacement must be ≥1.5× faster than the replaced \
1048                 algorithm — got {ratio:.2}× (naive={naive_secs:.3}s, fast={fast_secs:.3}s)"
1049            );
1050        }
1051    }
1052}
1053
1054#[cfg(test)]
1055mod property_tests {
1056    use super::*;
1057    use proptest::prelude::*;
1058
1059    proptest! {
1060        #![proptest_config(ProptestConfig::with_cases(50))]
1061
1062        #[test]
1063        fn prop_bpe_encode_produces_valid_ids(text in "[a-zA-Z ]{1,20}") {
1064            let config = TokenizerConfig::bpe()
1065                .with_vocab_size(300)
1066                .with_min_frequency(1);
1067            let mut tokenizer = BPETokenizer::new(config);
1068            tokenizer.train(&[&text]).expect("operation should succeed");
1069
1070            let encoded = tokenizer.encode(&text).expect("encoding should succeed");
1071
1072            for id in encoded {
1073                prop_assert!(tokenizer.id_to_token(id).is_some());
1074            }
1075        }
1076
1077        #[test]
1078        fn prop_vocab_size_bounded(target_size in 261usize..500) {
1079            let config = TokenizerConfig::bpe()
1080                .with_vocab_size(target_size)
1081                .with_min_frequency(1);
1082            let mut tokenizer = BPETokenizer::new(config);
1083
1084            let corpus = vec!["hello world hello world test test"];
1085            tokenizer.train(&corpus).expect("operation should succeed");
1086
1087            prop_assert!(tokenizer.vocab_size() <= target_size);
1088        }
1089    }
1090}