Skip to main content

riptoken/
lib.rs

1//! # riptoken
2//!
3//! Fast BPE tokenizer for LLMs — a drop-in compatible, faster reimplementation
4//! of OpenAI's [`tiktoken`](https://github.com/openai/tiktoken).
5//!
6//! ## Design
7//!
8//! riptoken is structured as three layers:
9//! 1. A pure-Rust core ([`CoreBPE`]) that can be used directly from Rust.
10//! 2. An optional PyO3 binding (enabled with the `python` feature).
11//! 3. A Python wrapper package shipped on PyPI.
12//!
13//! The core BPE algorithm is a Rust port of tiktoken's with several
14//! optimizations applied — see `README.md` for benchmarks and details.
15//!
16//! ## Example
17//!
18//! ```no_run
19//! use riptoken::{CoreBPE, Rank};
20//! use rustc_hash::FxHashMap;
21//!
22//! // In practice you would load `encoder` from an o200k_base / cl100k_base
23//! // vocabulary file via `riptoken::load_tiktoken_bpe`.
24//! let mut encoder: FxHashMap<Vec<u8>, Rank> = FxHashMap::default();
25//! encoder.insert(b"h".to_vec(), 0);
26//! encoder.insert(b"i".to_vec(), 1);
27//! encoder.insert(b"hi".to_vec(), 2);
28//!
29//! let specials = FxHashMap::default();
30//! let bpe = CoreBPE::new(encoder, specials, r"\w+").unwrap();
31//!
32//! let tokens = bpe.encode_ordinary("hi");
33//! assert_eq!(tokens, vec![2]);
34//! ```
35
36use fancy_regex::Regex as FancyRegex;
37use regex::Regex as FastRegex;
38use regex::RegexBuilder as FastRegexBuilder;
39use regex_automata::{
40    Input,
41    dfa::{dense, regex::Regex as DfaRegex},
42    nfa::thompson,
43    util::syntax,
44};
45use rustc_hash::{FxHashMap as HashMap, FxHasher};
46use std::collections::HashSet;
47use std::hash::{Hash, Hasher};
48
49#[cfg(feature = "python")]
50use pyo3::prelude::*;
51
52// --- Build-time pre-compiled DFAs --------------------------------------------
53
#[cfg(feature = "precompiled-dfa")]
mod prebuilt {
    use super::*;

    // The stock tiktoken split patterns, byte-for-byte. `try_load` compares
    // the caller's pattern against these by string equality. Per the original
    // author, gpt2, r50k_base, p50k_base and p50k_edit all share GPT2_RAW.
    const GPT2_RAW: &str =
        r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
    const CL100K_RAW: &str = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
    const O200K_RAW: &str = concat!(
        r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
        r"|\p{N}{1,3}",
        r"| ?[^\s\p{L}\p{N}]+[\r\n/]*",
        r"|\s*[\r\n]+",
        r"|\s+(?!\S)|\s+",
    );

    /// Wrapper that gives `bytes` the alignment of `Align`.
    ///
    /// The zero-length `[Align; 0]` field contributes no size but raises the
    /// struct's alignment. It is used with `Align = u32` so the embedded DFA
    /// bytes are suitably aligned for `dense::DFA::from_bytes`.
    #[repr(C)]
    struct AlignAs<Align, Bytes: ?Sized> {
        _align: [Align; 0],
        bytes: Bytes,
    }

    /// `include_bytes!` the file at `$path` and expose it as a `u32`-aligned
    /// `&'static [u8]`.
    macro_rules! include_dfa {
        ($path:expr) => {{
            const ALIGNED: &AlignAs<u32, [u8]> = &AlignAs {
                _align: [],
                bytes: *include_bytes!($path),
            };
            &ALIGNED.bytes
        }};
    }

    // Serialized forward/reverse DFAs read from `OUT_DIR` at compile time
    // (presumably written by the build script — confirm against build.rs).
    static GPT2_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/gpt2_fwd.dfa"));
    static GPT2_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/gpt2_rev.dfa"));
    static CL100K_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/cl100k_fwd.dfa"));
    static CL100K_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/cl100k_rev.dfa"));
    static O200K_FWD: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/o200k_fwd.dfa"));
    static O200K_REV: &[u8] = include_dfa!(concat!(env!("OUT_DIR"), "/o200k_rev.dfa"));

    /// Look up the embedded DFA pair for a stock tiktoken pattern.
    ///
    /// `pattern` must equal one of the raw stock patterns exactly. Returns
    /// `None` for any other pattern, or when either embedded DFA fails to
    /// deserialize. On success the regex is returned together with the
    /// [`ShrinkMode`] that belongs to that pattern family.
    pub(crate) fn try_load(pattern: &str) -> Option<(DfaRegex, ShrinkMode)> {
        let (forward, reverse, mode) = match pattern {
            GPT2_RAW => (GPT2_FWD, GPT2_REV, ShrinkMode::Unified),
            CL100K_RAW => (CL100K_FWD, CL100K_REV, ShrinkMode::PlainOnly),
            O200K_RAW => (O200K_FWD, O200K_REV, ShrinkMode::PlainOnly),
            _ => return None,
        };
        // `from_bytes` also reports how many bytes it consumed; only the DFA
        // itself is needed here.
        let fwd = dense::DFA::from_bytes(forward).ok()?.0;
        let rev = dense::DFA::from_bytes(reverse).ok()?.0;
        let regex = DfaRegex::builder().build_from_dfas(fwd.to_owned(), rev.to_owned());
        Some((regex, mode))
    }
}
119
/// A token's integer rank (its id) in the BPE vocabulary.
///
/// Pairs with a lower rank are merged before pairs with a higher rank.
/// [`Rank::MAX`] never names a real token: it is the sentinel used for
/// "this byte span is not in the vocabulary".
pub type Rank = u32;
125
/// Number of regex clones kept per engine, one per thread slot.
///
/// A thread picks its slot as `hash(thread id) % MAX_NUM_THREADS`. Plain `%`
/// is used (the computation happens once per thread, off the hot path), so
/// the value does not have to be a power of two.
const MAX_NUM_THREADS: usize = 128;

/// Length, in bytes, at which the merge strategy switches.
///
/// Pieces shorter than this go through the `Vec`-based merge with a
/// linear-scan minimum search; pieces of this length or longer go through
/// the heap-based merge. The linear scan is cache-friendly on short inputs,
/// while the heap avoids its `O(m·n)` cost on long ones. Per the original
/// author, the value mirrors tiktoken's threshold.
const LARGE_PIECE_THRESHOLD: usize = 500;
136
137thread_local! {
138    static THREAD_INDEX: usize = {
139        let mut h = FxHasher::default();
140        std::thread::current().id().hash(&mut h);
141        (h.finish() as usize) % MAX_NUM_THREADS
142    };
143}
144
145#[inline]
146fn thread_index() -> usize {
147    THREAD_INDEX.with(|&i| i)
148}
149
/// Reasons why constructing a [`CoreBPE`] can fail.
#[derive(Debug)]
pub enum BuildError {
    /// No available regex engine could compile the split pattern. Carries
    /// the error message of the fallback engine.
    InvalidRegex(String),
    /// The encoder and the decoder derived from it differ in size, which
    /// usually means the input vocabulary contains duplicate ranks or bytes.
    VocabularyMismatch,
}

impl std::fmt::Display for BuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InvalidRegex(reason) => write!(f, "invalid regex pattern: {reason}"),
            Self::VocabularyMismatch => {
                f.write_str("vocabulary has duplicate entries (encoder/decoder size mismatch)")
            }
        }
    }
}

impl std::error::Error for BuildError {}
174
175impl From<fancy_regex::Error> for BuildError {
176    fn from(e: fancy_regex::Error) -> Self {
177        BuildError::InvalidRegex(e.to_string())
178    }
179}
180
181/// Errors produced during decoding.
182#[derive(Debug)]
183pub enum DecodeError {
184    /// A token ID was not in the vocabulary.
185    InvalidToken(Rank),
186    /// The decoded bytes were not valid UTF-8 (only raised by [`CoreBPE::decode`]).
187    InvalidUtf8,
188}
189
190impl std::fmt::Display for DecodeError {
191    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
192        match self {
193            DecodeError::InvalidToken(t) => write!(f, "invalid token id: {t}"),
194            DecodeError::InvalidUtf8 => write!(f, "decoded bytes are not valid UTF-8"),
195        }
196    }
197}
198
199impl std::error::Error for DecodeError {}
200
201// --- Split engine dispatch ---------------------------------------------------
202
/// Selects when the fast-path engines apply the whitespace-shrink rule.
///
/// The rule emulates tiktoken's `\s+(?!\S)` alternative: a whitespace run
/// that is followed by a non-whitespace char gives up its last char, so that
/// char can attach to the following word through the word alternative's
/// optional non-word prefix.
#[derive(Clone, Copy, PartialEq, Debug)]
enum ShrinkMode {
    /// The original pattern has no `\s+(?!\S)` alternative (e.g. a toy
    /// pattern such as `\w+|\s+`), so nothing is ever shrunk.
    None,
    /// o200k / cl100k style. A dedicated `\s*[\r\n]+` alternative has
    /// priority and emits any whitespace run containing a newline as its own
    /// piece, so only plain (newline-free) whitespace runs are shrunk.
    PlainOnly,
    /// gpt2 / r50k / p50k style. There is no dedicated newline alternative,
    /// so `\s+(?!\S)` covers every whitespace run, newlines included, and
    /// every whitespace run is shrunk.
    Unified,
}
224
225/// Strip lookarounds and possessive quantifiers from a tiktoken-style pattern
226/// so it can be compiled by a pure-DFA engine (`regex` crate or
227/// `regex-automata` dense DFA).
228///
229/// Returns `Some((transformed, shrink_mode))` if the pattern is eligible for
230/// the fast (non-backtracking) path, or `None` if it contains features that
231/// require a full backtracking engine.
232fn transform_pattern(pattern: &str) -> Option<(String, ShrinkMode)> {
233    // 1. Collapse the whitespace-run tail and classify the pattern family.
234    //    Every stock tiktoken pattern ends with one of two shapes:
235    //       ...|\s*[\r\n]+|\s+(?!\S)|\s+   (o200k/cl100k-style)
236    //       ...|\s+$|\s+(?!\S)|\s          (gpt2/r50k/p50k-style)
237    //    We collapse the `\s+(?!\S)|...` pair into a single greedy `\s+`
238    //    and reproduce the "leave last char" semantics in `find_pieces`
239    //    below. The `ShrinkMode` distinguishes the two families so the
240    //    shrink rule can be applied correctly: in the o200k family the
241    //    separate `\s*[\r\n]+` alternative already handles runs that
242    //    contain a newline, so shrink must only fire on plain whitespace
243    //    runs; in the gpt2 family there is no such alternative and shrink
244    //    must fire on any whitespace run.
245    //
246    //    The shrink family is determined by the presence of a dedicated
247    //    newline alternative like `\s*[\r\n]` or `\s*[\r\n]+`: patterns
248    //    that have one (o200k/cl100k) use `PlainOnly` so that
249    //    newline-containing matches from that alternative aren't shrunk;
250    //    patterns without one (gpt2/r50k/p50k) use `Unified` so the
251    //    shrink fires on all whitespace runs including newlines.
252    let has_lookahead_ws = pattern.contains(r"\s+(?!\S)");
253    let has_newline_alt = pattern.contains(r"\s*[\r\n]");
254    let shrink_mode = if has_lookahead_ws {
255        if has_newline_alt {
256            ShrinkMode::PlainOnly
257        } else {
258            ShrinkMode::Unified
259        }
260    } else {
261        ShrinkMode::None
262    };
263    let mut stripped = pattern.replace(r"\s+(?!\S)|\s+", r"\s+");
264    stripped = stripped.replace(r"\s+(?!\S)|\s", r"\s+");
265    // 2. Reject anything that still contains a lookaround. We intentionally
266    //    don't try to be clever — any residual `(?=`, `(?!`, `(?<=`, `(?<!`
267    //    means we fall back to fancy-regex.
268    if stripped.contains("(?=")
269        || stripped.contains("(?!")
270        || stripped.contains("(?<=")
271        || stripped.contains("(?<!")
272    {
273        return None;
274    }
275    // 3. Convert possessive quantifiers to greedy. tiktoken's newer
276    //    cl100k_base / p50k_base / o200k_base patterns use `?+`, `++`,
277    //    `*+`, `{n,m}+` as backtracking-engine speed hints ("don't retry
278    //    this match"). The regex crate's parser silently accepts the
279    //    syntax but its DFA interprets it differently, producing wrong
280    //    matches.
281    //
282    //    In a DFA engine possessive markers are semantically unnecessary:
283    //    the DFA is already linear-time and has no backtracking to
284    //    disable. And in every tiktoken pattern the possessive and greedy
285    //    matches are identical by construction — the alternatives are
286    //    disjoint enough that backtracking would never change the result
287    //    — so converting possessive → greedy is safe.
288    //
289    //    Simple string replace handles `?+`, `++`, `*+` (tiktoken patterns
290    //    never put these chars inside `[...]` or escape them). `{n,m}+`
291    //    needs care: `\p{L}+` also contains the literal sequence `}+`,
292    //    but the `+` there is a greedy quantifier on the class, not a
293    //    possessive marker — so we use a precise regex to match only
294    //    ranges `{digits[,digits]}+`.
295    stripped = stripped
296        .replace("?+", "?")
297        .replace("++", "+")
298        .replace("*+", "*");
299    let range_possessive = FastRegex::new(r"(\{\d+(?:,\d*)?\})\+").ok()?;
300    let stripped = range_possessive.replace_all(&stripped, "$1").into_owned();
301    Some((stripped, shrink_mode))
302}
303
304/// Try to rewrite a tiktoken-style pattern into a form that the SIMD-accelerated
305/// `regex` crate can compile (lazy DFA).
306///
307/// Returns `Some((regex, mode))` if the pattern compiles successfully with
308/// the `regex` crate after transformation, or `None` if the pattern contains
309/// other lookarounds (or features) that regex can't handle — in which case
310/// the caller falls back to `fancy-regex`.
311fn try_transform_for_fast_regex(pattern: &str) -> Option<(FastRegex, ShrinkMode)> {
312    let (transformed, shrink_mode) = transform_pattern(pattern)?;
313    let regex = FastRegexBuilder::new(&transformed)
314        .dfa_size_limit(32 * (1 << 20))
315        .build()
316        .ok()?;
317    Some((regex, shrink_mode))
318}
319
320/// Try to build a fully pre-compiled dense DFA for the pattern.
321///
322/// Unlike the lazy DFA in `try_transform_for_fast_regex`, this materializes
323/// all states upfront — zero cold-start at search time, and the resulting
324/// `DfaRegex` has no mutable state so it can be shared across threads without
325/// cloning.
326///
327/// Returns `None` if the pattern can't be transformed or the DFA build fails
328/// (e.g. the state table is too large).
329fn try_build_precompiled_dfa(pattern: &str) -> Option<(DfaRegex, ShrinkMode)> {
330    let (transformed, shrink_mode) = transform_pattern(pattern)?;
331    let dfa = DfaRegex::builder()
332        .syntax(syntax::Config::new().unicode(true).utf8(true))
333        .thompson(thompson::Config::new())
334        .dense(dense::Config::new().start_kind(regex_automata::dfa::StartKind::Unanchored))
335        .build(&transformed)
336        .ok()?;
337    Some((dfa, shrink_mode))
338}
339
/// True if `s` is non-empty and every char is whitespace other than `\n`
/// and `\r`.
///
/// Used for [`ShrinkMode::PlainOnly`] (o200k/cl100k family). There the
/// shrink rule must not touch a match that contains a newline, because such
/// a match comes from the pattern's dedicated newline alternative and is
/// emitted whole. "Contains no newline" separates the two kinds of match
/// without capture groups.
#[inline]
fn is_plain_whitespace_run(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    // Reject on the first char that is a line break or not whitespace.
    !s.chars()
        .any(|c| matches!(c, '\n' | '\r') || !c.is_whitespace())
}
354
/// True if `s` is non-empty and every char is whitespace of any kind,
/// `\n` and `\r` included.
///
/// Used for [`ShrinkMode::Unified`] (gpt2/r50k/p50k family). Those patterns
/// have no dedicated newline alternative, so `\s+(?!\S)` also covers runs
/// that contain newlines and the shrink rule has to apply to them as well.
#[inline]
fn is_whitespace_run(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    s.chars().all(char::is_whitespace)
}
365
/// True if the char of `text` that starts at byte offset `pos` exists and is
/// not whitespace. Returns `false` when `pos` is the end of the string.
///
/// # Panics
///
/// Panics if `pos` is past the end of `text` or not on a char boundary,
/// because `text` is sliced at `pos`.
#[inline]
fn next_char_is_non_whitespace(text: &str, pos: usize) -> bool {
    let mut rest = text[pos..].chars();
    if let Some(following) = rest.next() {
        !following.is_whitespace()
    } else {
        false
    }
}
375
376/// Apply the `\s+(?!\S)` whitespace-shrink rule, returning the adjusted end.
377///
378/// If the matched piece is a whitespace run (per `shrink_mode`) and the
379/// character immediately after it is non-whitespace, shrink by one character
380/// so that trailing whitespace can attach to the following word.
381#[inline]
382fn apply_shrink(text: &str, start: usize, end: usize, shrink_mode: ShrinkMode) -> usize {
383    let piece = &text[start..end];
384    let should_shrink = match shrink_mode {
385        ShrinkMode::None => false,
386        ShrinkMode::PlainOnly => is_plain_whitespace_run(piece),
387        ShrinkMode::Unified => is_whitespace_run(piece),
388    };
389    if should_shrink && end < text.len() && next_char_is_non_whitespace(text, end) {
390        if let Some((last_i, _)) = piece.char_indices().next_back() {
391            if last_i > 0 {
392                return start + last_i;
393            }
394        }
395    }
396    end
397}
398
399/// The pattern-matching engine used to split text into pieces before BPE.
400enum SplitEngine {
401    /// Pre-compiled dense DFA — all states materialized upfront, zero
402    /// cold-start penalty. The `DfaRegex` has no mutable state, so one
403    /// instance is shared across all threads without cloning.
404    PrecompiledDfa {
405        dfa_regex: DfaRegex,
406        shrink_mode: ShrinkMode,
407    },
408    /// SIMD-accelerated lazy DFA. One clone per thread slot, plus a
409    /// [`ShrinkMode`] selecting how to emulate `\s+(?!\S)` in Rust.
410    Fast {
411        clones: Vec<FastRegex>,
412        shrink_mode: ShrinkMode,
413    },
414    /// Backtracking regex with lookaround support. One clone per thread slot.
415    Fancy(Vec<FancyRegex>),
416}
417
418impl SplitEngine {
419    /// Build an engine for `pattern`.
420    ///
421    /// Priority: pre-built DFA (stock patterns) → eager DFA build →
422    /// lazy DFA → fancy-regex.
423    fn new(pattern: &str) -> Result<Self, BuildError> {
424        // 1. Try pre-built DFA from build.rs (zero cost, stock patterns only).
425        #[cfg(feature = "precompiled-dfa")]
426        if let Some((dfa_regex, shrink_mode)) = prebuilt::try_load(pattern) {
427            return Ok(SplitEngine::PrecompiledDfa {
428                dfa_regex,
429                shrink_mode,
430            });
431        }
432        // 2. Try eager dense DFA build (non-stock patterns).
433        if let Some((dfa_regex, shrink_mode)) = try_build_precompiled_dfa(pattern) {
434            return Ok(SplitEngine::PrecompiledDfa {
435                dfa_regex,
436                shrink_mode,
437            });
438        }
439        // 3. Try lazy DFA.
440        if let Some((fast, shrink_mode)) = try_transform_for_fast_regex(pattern) {
441            let clones: Vec<FastRegex> = (0..MAX_NUM_THREADS).map(|_| fast.clone()).collect();
442            return Ok(SplitEngine::Fast {
443                clones,
444                shrink_mode,
445            });
446        }
447        // 4. Fallback: fancy-regex with lookaround support.
448        let fancy = FancyRegex::new(pattern)?;
449        let clones: Vec<FancyRegex> = (0..MAX_NUM_THREADS).map(|_| fancy.clone()).collect();
450        Ok(SplitEngine::Fancy(clones))
451    }
452
453    /// True if this engine uses a non-backtracking path (precompiled or lazy DFA).
454    #[cfg(test)]
455    fn is_fast(&self) -> bool {
456        matches!(
457            self,
458            SplitEngine::PrecompiledDfa { .. } | SplitEngine::Fast { .. }
459        )
460    }
461
462    /// True if this engine uses a pre-compiled dense DFA (zero cold-start).
463    #[cfg(all(test, feature = "precompiled-dfa"))]
464    fn is_precompiled(&self) -> bool {
465        matches!(self, SplitEngine::PrecompiledDfa { .. })
466    }
467
468    /// Iterate the pieces of `text`, invoking `f` with each piece as a `&str`.
469    ///
470    /// On the fast paths, applies the `\s+(?!\S)` whitespace-shrink rule in
471    /// Rust so that output matches tiktoken exactly.
472    #[inline]
473    fn find_pieces<F: FnMut(&str)>(&self, text: &str, mut f: F) {
474        match self {
475            SplitEngine::PrecompiledDfa {
476                dfa_regex,
477                shrink_mode,
478            } => {
479                let haystack = text.as_bytes();
480                let mut pos = 0;
481                while pos < haystack.len() {
482                    let input = Input::new(haystack).range(pos..);
483                    let m = match dfa_regex.find(input) {
484                        Some(m) => m,
485                        None => break,
486                    };
487                    if m.start() > pos {
488                        pos = m.start();
489                    }
490                    let start = m.start();
491                    let end = apply_shrink(text, start, m.end(), *shrink_mode);
492                    f(&text[start..end]);
493                    if end == pos {
494                        pos += 1;
495                    } else {
496                        pos = end;
497                    }
498                }
499            }
500            SplitEngine::Fast {
501                clones,
502                shrink_mode,
503            } => {
504                let regex = &clones[thread_index()];
505                let mut pos = 0;
506                while pos < text.len() {
507                    let m = match regex.find_at(text, pos) {
508                        Some(m) => m,
509                        None => break,
510                    };
511                    if m.start() > pos {
512                        pos = m.start();
513                    }
514                    let start = m.start();
515                    let end = apply_shrink(text, start, m.end(), *shrink_mode);
516                    f(&text[start..end]);
517                    if end == pos {
518                        pos += 1;
519                    } else {
520                        pos = end;
521                    }
522                }
523            }
524            SplitEngine::Fancy(clones) => {
525                let regex = &clones[thread_index()];
526                for mat in regex.find_iter(text) {
527                    match mat {
528                        Ok(m) => f(m.as_str()),
529                        Err(_) => continue,
530                    }
531                }
532            }
533        }
534    }
535}
536
537/// The core BPE encoder/decoder.
538///
539/// A `CoreBPE` owns the vocabulary, the compiled regex used to split text into
540/// pieces, and a pool of per-thread regex clones. It is `Send + Sync` and
541/// designed to be constructed once and shared (e.g. behind an `Arc`).
542#[cfg_attr(feature = "python", pyclass(module = "riptoken._riptoken"))]
543pub struct CoreBPE {
544    /// Byte-sequence → rank. Lookup key for encoding.
545    encoder: HashMap<Vec<u8>, Rank>,
546    /// Rank → byte-sequence. Lookup key for decoding.
547    decoder: HashMap<Rank, Vec<u8>>,
548    /// Special token string → rank.
549    special_tokens_encoder: HashMap<String, Rank>,
550    /// Rank → special token bytes.
551    special_tokens_decoder: HashMap<Rank, Vec<u8>>,
552    /// Fast or fancy engine for splitting text into pieces.
553    split_engine: SplitEngine,
554    /// Thread-local clones of the special-token regex. Empty if there are no
555    /// special tokens. Special-token patterns are always literal alternations,
556    /// so `fancy_regex` is fine here (and needed only for its `find_from_pos`
557    /// API which `regex` crate also provides but we keep symmetric).
558    special_regex_tls: Vec<FancyRegex>,
559    /// Sorted vocabulary bytes, useful for prefix queries.
560    sorted_token_bytes: Vec<Vec<u8>>,
561}
562
563// --- Core BPE merge algorithm -------------------------------------------------
564
565/// Compute the rank of a byte pair directly from a slice — no allocation.
566///
567/// `HashMap<Vec<u8>, Rank>::get` accepts any `Q: ?Sized` where
568/// `Vec<u8>: Borrow<Q>`, and `Vec<u8>` implements `Borrow<[u8]>`, so this
569/// avoids the `.to_vec()` allocation the naive implementation does.
570#[inline(always)]
571fn rank_of(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Rank {
572    ranks.get(piece).copied().unwrap_or(Rank::MAX)
573}
574
575/// `O(m·n)` linear-scan BPE merge for short pieces.
576///
577/// Returns a list of `(start_position, rank)` such that consecutive windows of
578/// two elements give the start/end of each final token:
579///
580/// ```text
581/// parts: [(0, _), (2, _), (5, _), (5, MAX)]
582/// tokens: piece[0..2], piece[2..5]
583/// ```
584#[inline]
585fn byte_pair_merge(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<(usize, Rank)> {
586    // Fast path: trivially short pieces.
587    if piece.len() < 2 {
588        return vec![(0, Rank::MAX), (piece.len(), Rank::MAX)];
589    }
590
591    let mut parts: Vec<(usize, Rank)> = Vec::with_capacity(piece.len() + 1);
592
593    // Populate initial byte pairs AND find the initial minimum in a single pass.
594    let mut min_rank: (Rank, usize) = (Rank::MAX, usize::MAX);
595    for i in 0..piece.len() - 1 {
596        let rank = rank_of(ranks, &piece[i..i + 2]);
597        if rank < min_rank.0 {
598            min_rank = (rank, i);
599        }
600        parts.push((i, rank));
601    }
602    parts.push((piece.len() - 1, Rank::MAX));
603    parts.push((piece.len(), Rank::MAX));
604
605    // Returns the rank of the merge *starting* at `parts[i]`, using the
606    // pre-remove parts vector — so it looks 3 ahead to see past the soon-to-be-
607    // removed `parts[i+1]`.
608    let get_rank = |parts: &[(usize, Rank)], i: usize| -> Rank {
609        if i + 3 < parts.len() {
610            rank_of(ranks, &piece[parts[i].0..parts[i + 3].0])
611        } else {
612            Rank::MAX
613        }
614    };
615
616    while min_rank.0 != Rank::MAX {
617        let i = min_rank.1;
618
619        // Update parts[i-1] and parts[i] BEFORE the remove. `parts.remove`
620        // shifts everything from `i+2` leftward, evicting cache lines. Reading
621        // the hot neighbours first keeps the accesses on hot memory.
622        if i > 0 {
623            parts[i - 1].1 = get_rank(&parts, i - 1);
624        }
625        parts[i].1 = get_rank(&parts, i);
626        parts.remove(i + 1);
627
628        // Rescan for new minimum. Excludes the two trailing sentinels.
629        min_rank = (Rank::MAX, usize::MAX);
630        for (j, &(_, rank)) in parts[..parts.len() - 2].iter().enumerate() {
631            if rank < min_rank.0 {
632                min_rank = (rank, j);
633            }
634        }
635    }
636
637    parts
638}
639
640/// Apply BPE to a single piece that is NOT already a full vocabulary entry.
641#[inline]
642fn byte_pair_encode(piece: &[u8], ranks: &HashMap<Vec<u8>, Rank>) -> Vec<Rank> {
643    // Single byte fast path.
644    if piece.len() == 1 {
645        // Every standard BPE vocab includes all 256 single bytes.
646        return vec![*ranks.get(piece).expect("byte fallback")];
647    }
648
649    if piece.len() < LARGE_PIECE_THRESHOLD {
650        let positions = byte_pair_merge(ranks, piece);
651        // `positions` has n+1 entries and yields n tokens (windows of 2).
652        let mut out: Vec<Rank> = Vec::with_capacity(positions.len() - 1);
653        out.extend(
654            positions
655                .windows(2)
656                .map(|w| rank_of(ranks, &piece[w[0].0..w[1].0])),
657        );
658        out
659    } else {
660        byte_pair_merge_large(ranks, piece)
661    }
662}
663
664// --- Heap-based merge for long pieces ----------------------------------------
665
666/// `O(m log n)` heap-based BPE merge for long pieces, with an intrusive
667/// doubly-linked list embedded in a flat `Vec<State>` to avoid the `O(n)`
668/// shifts of `Vec::remove`.
669///
670/// Uses lazy invalidation: we never remove stale entries from the heap —
671/// instead we bump a generation counter on the state and skip heap entries
672/// whose stored rank no longer matches.
673fn byte_pair_merge_large(ranks: &HashMap<Vec<u8>, Rank>, piece: &[u8]) -> Vec<Rank> {
674    use std::cmp::Reverse;
675    use std::collections::BinaryHeap;
676
677    // state[start].end is the exclusive end position of the current token
678    // starting at `start`. state[start].prev is the start position of the
679    // previous token (or usize::MAX if none). `cur_rank` tracks which rank
680    // this state represents in the heap; stale heap entries are discarded.
681    #[derive(Clone)]
682    struct State {
683        prev: usize,
684        end: usize,
685        cur_rank: Rank,
686    }
687
688    let n = piece.len();
689    let mut state: Vec<State> = (0..n)
690        .map(|i| State {
691            prev: if i == 0 { usize::MAX } else { i - 1 },
692            end: i + 1,
693            cur_rank: 0,
694        })
695        .collect();
696
697    // Heap entry: (Reverse(rank), start). Reverse so BinaryHeap is a min-heap.
698    let mut heap: BinaryHeap<(Reverse<Rank>, usize)> = BinaryHeap::with_capacity(n);
699
700    // Seed with all initial pair ranks.
701    for i in 0..n.saturating_sub(1) {
702        let rank = rank_of(ranks, &piece[i..state[i + 1].end]);
703        state[i].cur_rank = rank;
704        if rank != Rank::MAX {
705            heap.push((Reverse(rank), i));
706        }
707    }
708
709    while let Some((Reverse(rank), start)) = heap.pop() {
710        // Lazy invalidation: skip if this entry is stale.
711        if state[start].cur_rank != rank || rank == Rank::MAX {
712            continue;
713        }
714
715        // Absorb the next token into [start]: extend `end`, unlink [right].
716        let right = state[start].end;
717        if right >= n {
718            continue;
719        }
720        let new_end = state[right].end;
721        state[start].end = new_end;
722
723        // Patch the "prev" of whatever comes after the absorbed one.
724        if new_end < n {
725            state[new_end].prev = start;
726        }
727
728        // Invalidate the old right entry so future stale heap pops skip it.
729        state[right].cur_rank = Rank::MAX;
730
731        // Recompute rank of [start] (now a longer span).
732        let next_end = state[start].end;
733        if next_end < n {
734            let new_rank = rank_of(ranks, &piece[start..state[next_end].end]);
735            state[start].cur_rank = new_rank;
736            if new_rank != Rank::MAX {
737                heap.push((Reverse(new_rank), start));
738            }
739        } else {
740            state[start].cur_rank = Rank::MAX;
741        }
742
743        // Recompute rank of [prev] — it now points to a longer span too.
744        let prev = state[start].prev;
745        if prev != usize::MAX {
746            let prev_next_end = state[prev].end; // this is still `start` unchanged
747            debug_assert_eq!(prev_next_end, start);
748            let span_end = state[start].end;
749            let new_rank = rank_of(ranks, &piece[prev..span_end]);
750            state[prev].cur_rank = new_rank;
751            if new_rank != Rank::MAX {
752                heap.push((Reverse(new_rank), prev));
753            }
754        }
755    }
756
757    // Walk the linked list from start to collect final tokens.
758    let mut tokens = Vec::new();
759    let mut i = 0;
760    while i < n {
761        let end = state[i].end;
762        tokens.push(rank_of(ranks, &piece[i..end]));
763        i = end;
764    }
765    tokens
766}
767
768// --- Special-token regex helpers ----------------------------------------------
769
770/// Build a regex that matches any of the given special token strings.
771///
772/// Returns `None` if `specials` is empty — callers should then skip the
773/// special-token scan entirely.
774fn build_special_regex(specials: &HashMap<String, Rank>) -> Result<Option<FancyRegex>, BuildError> {
775    if specials.is_empty() {
776        return Ok(None);
777    }
778    // Escape each special token literally and join with `|`.
779    let parts: Vec<String> = specials
780        .keys()
781        .map(|s| fancy_regex::escape(s).into_owned())
782        .collect();
783    let pattern = parts.join("|");
784    Ok(Some(FancyRegex::new(&pattern)?))
785}
786
787// --- CoreBPE public API ------------------------------------------------------
788
789impl CoreBPE {
790    /// Construct a new `CoreBPE`.
791    ///
792    /// - `encoder`: byte-sequence → rank map (the vocabulary).
793    /// - `special_tokens_encoder`: special-token string → rank.
794    /// - `pattern`: a regex string used to split text into pieces before BPE.
795    ///
796    /// Returns a [`BuildError`] if the regex is invalid or the vocabulary has
797    /// duplicate entries.
798    pub fn new(
799        encoder: HashMap<Vec<u8>, Rank>,
800        special_tokens_encoder: HashMap<String, Rank>,
801        pattern: &str,
802    ) -> Result<Self, BuildError> {
803        let split_engine = SplitEngine::new(pattern)?;
804        let decoder: HashMap<Rank, Vec<u8>> =
805            encoder.iter().map(|(k, v)| (*v, k.clone())).collect();
806        if decoder.len() != encoder.len() {
807            return Err(BuildError::VocabularyMismatch);
808        }
809        let special_tokens_decoder: HashMap<Rank, Vec<u8>> = special_tokens_encoder
810            .iter()
811            .map(|(k, v)| (*v, k.as_bytes().to_vec()))
812            .collect();
813
814        let special_regex = build_special_regex(&special_tokens_encoder)?;
815        let special_regex_tls: Vec<FancyRegex> = match special_regex {
816            Some(r) => (0..MAX_NUM_THREADS).map(|_| r.clone()).collect(),
817            None => Vec::new(),
818        };
819
820        let mut sorted_token_bytes: Vec<Vec<u8>> = encoder.keys().cloned().collect();
821        sorted_token_bytes.sort();
822
823        Ok(CoreBPE {
824            encoder,
825            decoder,
826            special_tokens_encoder,
827            special_tokens_decoder,
828            split_engine,
829            special_regex_tls,
830            sorted_token_bytes,
831        })
832    }
833
834    /// Size of the vocabulary, defined as `max_rank + 1` across all tokens
835    /// (ordinary + special). This matches tiktoken's `n_vocab` semantics, so
836    /// vocabularies with reserved rank gaps (like `o200k_base`) report the
837    /// "reach" of the id space rather than the count of live tokens.
838    pub fn n_vocab(&self) -> usize {
839        let max_ordinary = self.encoder.values().copied().max().unwrap_or(0);
840        let max_special = self
841            .special_tokens_encoder
842            .values()
843            .copied()
844            .max()
845            .unwrap_or(0);
846        max_ordinary.max(max_special) as usize + 1
847    }
848
849    /// A sorted list of every (non-special) token's bytes.
850    pub fn token_byte_values(&self) -> &[Vec<u8>] {
851        &self.sorted_token_bytes
852    }
853
854    #[inline]
855    fn tl_special_regex(&self) -> Option<&FancyRegex> {
856        self.special_regex_tls.get(thread_index())
857    }
858
859    /// Emit the tokens produced by BPE-encoding one regex-split piece.
860    #[inline]
861    fn emit_piece(&self, piece: &[u8], out: &mut Vec<Rank>) {
862        // Whole-piece fast path: most regex splits are already full tokens.
863        if let Some(&token) = self.encoder.get(piece) {
864            out.push(token);
865            return;
866        }
867        out.extend(byte_pair_encode(piece, &self.encoder));
868    }
869
870    /// Encode ordinary text, ignoring special tokens entirely.
871    ///
872    /// Any special-token substrings in the input will be tokenized as regular
873    /// text (this matches tiktoken's behavior).
874    pub fn encode_ordinary(&self, text: &str) -> Vec<Rank> {
875        // Rough pre-allocation: ~4 bytes per token is the average for English
876        // + code at the o200k_base vocab size. Overshoot a little to avoid the
877        // last realloc and undershoot hurts less than it helps (reallocs are
878        // cheap for the final doubling).
879        let mut ret = Vec::with_capacity(text.len() / 3 + 1);
880        self.split_engine.find_pieces(text, |piece| {
881            self.emit_piece(piece.as_bytes(), &mut ret);
882        });
883        ret
884    }
885
886    /// Encode many ordinary texts in parallel using rayon.
887    ///
888    /// Each text is tokenized on a rayon worker thread; the returned vector
889    /// preserves input order. Uses the global rayon pool, so the level of
890    /// parallelism is controlled by `RAYON_NUM_THREADS` or
891    /// `rayon::ThreadPoolBuilder`.
892    pub fn encode_ordinary_batch(&self, texts: &[&str]) -> Vec<Vec<Rank>> {
893        use rayon::prelude::*;
894        texts.par_iter().map(|t| self.encode_ordinary(t)).collect()
895    }
896
897    /// Encode many texts in parallel using rayon, honoring special tokens.
898    ///
899    /// See [`CoreBPE::encode`] for the special-token semantics.
900    pub fn encode_batch(&self, texts: &[&str], allowed_special: &HashSet<&str>) -> Vec<Vec<Rank>> {
901        use rayon::prelude::*;
902        texts
903            .par_iter()
904            .map(|t| self.encode(t, allowed_special))
905            .collect()
906    }
907
908    /// Encode text, allowing a specific set of special tokens.
909    ///
910    /// Special tokens in `allowed_special` are emitted as their assigned ranks.
911    /// Special tokens NOT in `allowed_special` are tokenized as ordinary text.
912    pub fn encode(&self, text: &str, allowed_special: &HashSet<&str>) -> Vec<Rank> {
913        let special_regex = match self.tl_special_regex() {
914            Some(r) => r,
915            // No special tokens registered at all — just do ordinary encoding.
916            None => return self.encode_ordinary(text),
917        };
918
919        let mut ret = Vec::new();
920        let mut start = 0usize;
921        loop {
922            // Find the next *allowed* special token.
923            let mut next_special: Option<(usize, usize)> = None;
924            let mut search_from = start;
925            while search_from <= text.len() {
926                match special_regex.find_from_pos(text, search_from) {
927                    Ok(Some(m)) => {
928                        if allowed_special.contains(&text[m.start()..m.end()]) {
929                            next_special = Some((m.start(), m.end()));
930                            break;
931                        }
932                        // Skip this match — move one char forward.
933                        search_from = m.start() + 1;
934                    }
935                    _ => break,
936                }
937            }
938
939            let end = next_special.map_or(text.len(), |(s, _)| s);
940
941            // Encode the ordinary text between [start, end).
942            self.split_engine.find_pieces(&text[start..end], |piece| {
943                self.emit_piece(piece.as_bytes(), &mut ret);
944            });
945
946            // Emit the special token (if any) and advance.
947            match next_special {
948                Some((s, e)) => {
949                    let piece = &text[s..e];
950                    if let Some(&tok) = self.special_tokens_encoder.get(piece) {
951                        ret.push(tok);
952                    }
953                    start = e;
954                }
955                None => break,
956            }
957        }
958        ret
959    }
960
961    /// Look up a single token by its byte sequence.
962    pub fn encode_single_token(&self, piece: &[u8]) -> Option<Rank> {
963        if let Some(&r) = self.encoder.get(piece) {
964            return Some(r);
965        }
966        if let Ok(s) = std::str::from_utf8(piece) {
967            if let Some(&r) = self.special_tokens_encoder.get(s) {
968                return Some(r);
969            }
970        }
971        None
972    }
973
974    /// Decode a sequence of tokens into the underlying bytes.
975    ///
976    /// Unknown token IDs are silently skipped — this matches tiktoken's
977    /// `decode_bytes` behavior. Use [`CoreBPE::decode_single_token_bytes`] if
978    /// you need strict validation.
979    pub fn decode_bytes(&self, tokens: &[Rank]) -> Vec<u8> {
980        let mut ret = Vec::with_capacity(tokens.len() * 2);
981        for &token in tokens {
982            if let Some(bytes) = self.decoder.get(&token) {
983                ret.extend_from_slice(bytes);
984            } else if let Some(bytes) = self.special_tokens_decoder.get(&token) {
985                ret.extend_from_slice(bytes);
986            }
987        }
988        ret
989    }
990
991    /// Decode tokens as a UTF-8 string.
992    ///
993    /// Returns [`DecodeError::InvalidUtf8`] if the concatenated bytes are not
994    /// valid UTF-8. This can happen mid-stream when a multi-byte character
995    /// spans a token boundary; prefer [`CoreBPE::decode_bytes`] for streaming.
996    pub fn decode(&self, tokens: &[Rank]) -> Result<String, DecodeError> {
997        String::from_utf8(self.decode_bytes(tokens)).map_err(|_| DecodeError::InvalidUtf8)
998    }
999
1000    /// Look up the bytes of a single token. Returns an error if the token is
1001    /// not in the vocabulary.
1002    pub fn decode_single_token_bytes(&self, token: Rank) -> Result<Vec<u8>, DecodeError> {
1003        if let Some(bytes) = self.decoder.get(&token) {
1004            return Ok(bytes.clone());
1005        }
1006        if let Some(bytes) = self.special_tokens_decoder.get(&token) {
1007            return Ok(bytes.clone());
1008        }
1009        Err(DecodeError::InvalidToken(token))
1010    }
1011}
1012
1013// --- PyO3 bindings ------------------------------------------------------------
1014
1015#[cfg(feature = "python")]
1016#[pymethods]
1017impl CoreBPE {
1018    #[new]
1019    #[pyo3(signature = (encoder, special_tokens_encoder, pattern))]
1020    fn py_new(
1021        encoder: HashMap<Vec<u8>, Rank>,
1022        special_tokens_encoder: HashMap<String, Rank>,
1023        pattern: &str,
1024    ) -> PyResult<Self> {
1025        Self::new(encoder, special_tokens_encoder, pattern)
1026            .map_err(|e| pyo3::exceptions::PyValueError::new_err(e.to_string()))
1027    }
1028
1029    #[pyo3(name = "encode_ordinary")]
1030    fn py_encode_ordinary(&self, py: Python<'_>, text: &str) -> Vec<Rank> {
1031        py.detach(|| self.encode_ordinary(text))
1032    }
1033
1034    #[pyo3(name = "encode", signature = (text, allowed_special = None))]
1035    fn py_encode(
1036        &self,
1037        py: Python<'_>,
1038        text: &str,
1039        allowed_special: Option<HashSet<String>>,
1040    ) -> Vec<Rank> {
1041        py.detach(|| {
1042            let allowed = allowed_special.unwrap_or_default();
1043            let allowed_refs: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
1044            self.encode(text, &allowed_refs)
1045        })
1046    }
1047
1048    #[pyo3(name = "encode_ordinary_batch")]
1049    fn py_encode_ordinary_batch(&self, py: Python<'_>, texts: Vec<String>) -> Vec<Vec<Rank>> {
1050        py.detach(|| {
1051            let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1052            self.encode_ordinary_batch(&refs)
1053        })
1054    }
1055
1056    #[pyo3(name = "encode_batch", signature = (texts, allowed_special = None))]
1057    fn py_encode_batch(
1058        &self,
1059        py: Python<'_>,
1060        texts: Vec<String>,
1061        allowed_special: Option<HashSet<String>>,
1062    ) -> Vec<Vec<Rank>> {
1063        py.detach(|| {
1064            let refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
1065            let allowed = allowed_special.unwrap_or_default();
1066            let allowed_refs: HashSet<&str> = allowed.iter().map(|s| s.as_str()).collect();
1067            self.encode_batch(&refs, &allowed_refs)
1068        })
1069    }
1070
1071    #[pyo3(name = "encode_single_token")]
1072    fn py_encode_single_token(&self, piece: &[u8]) -> PyResult<Rank> {
1073        self.encode_single_token(piece)
1074            .ok_or_else(|| pyo3::exceptions::PyKeyError::new_err("token not found"))
1075    }
1076
1077    #[pyo3(name = "decode_bytes")]
1078    fn py_decode_bytes<'py>(
1079        &self,
1080        py: Python<'py>,
1081        tokens: Vec<Rank>,
1082    ) -> pyo3::Bound<'py, pyo3::types::PyBytes> {
1083        let bytes = py.detach(|| self.decode_bytes(&tokens));
1084        pyo3::types::PyBytes::new(py, &bytes)
1085    }
1086
1087    /// Decode tokens into a Python `str`, matching `tiktoken.Encoding.decode`.
1088    ///
1089    /// Invalid UTF-8 sequences (which can occur mid-stream when a multi-byte
1090    /// character spans a token boundary) are replaced with U+FFFD, matching
1091    /// tiktoken's default `errors="replace"` behavior. For strict decoding or
1092    /// raw bytes, use [`decode_bytes`].
1093    #[pyo3(name = "decode")]
1094    fn py_decode(&self, py: Python<'_>, tokens: Vec<Rank>) -> String {
1095        py.detach(|| {
1096            let bytes = self.decode_bytes(&tokens);
1097            String::from_utf8_lossy(&bytes).into_owned()
1098        })
1099    }
1100
1101    #[pyo3(name = "decode_single_token_bytes")]
1102    fn py_decode_single_token_bytes<'py>(
1103        &self,
1104        py: Python<'py>,
1105        token: Rank,
1106    ) -> PyResult<pyo3::Bound<'py, pyo3::types::PyBytes>> {
1107        let bytes = self
1108            .decode_single_token_bytes(token)
1109            .map_err(|e| pyo3::exceptions::PyKeyError::new_err(e.to_string()))?;
1110        Ok(pyo3::types::PyBytes::new(py, &bytes))
1111    }
1112
1113    #[pyo3(name = "n_vocab")]
1114    fn py_n_vocab(&self) -> usize {
1115        self.n_vocab()
1116    }
1117
1118    #[pyo3(name = "token_byte_values")]
1119    fn py_token_byte_values<'py>(
1120        &self,
1121        py: Python<'py>,
1122    ) -> Vec<pyo3::Bound<'py, pyo3::types::PyBytes>> {
1123        self.sorted_token_bytes
1124            .iter()
1125            .map(|b| pyo3::types::PyBytes::new(py, b))
1126            .collect()
1127    }
1128}
1129
// Python extension module entry point: registers `CoreBPE` as the module's
// only class.
#[cfg(feature = "python")]
#[pymodule]
fn _riptoken(_py: Python, m: &Bound<PyModule>) -> PyResult<()> {
    m.add_class::<CoreBPE>()
}
1136
1137// --- Unit tests ---------------------------------------------------------------
1138
1139#[cfg(test)]
1140mod tests {
1141    use super::*;
1142
1143    fn toy_bpe() -> CoreBPE {
1144        let mut encoder = HashMap::default();
1145        for (i, b) in b"helo ".iter().enumerate() {
1146            encoder.insert(vec![*b], i as Rank);
1147        }
1148        encoder.insert(b"he".to_vec(), 100);
1149        encoder.insert(b"ll".to_vec(), 101);
1150        CoreBPE::new(encoder, HashMap::default(), r"\w+| ").unwrap()
1151    }
1152
1153    #[test]
1154    fn merge_empty_piece() {
1155        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
1156        let result = byte_pair_merge(&ranks, b"");
1157        assert_eq!(result, vec![(0, Rank::MAX), (0, Rank::MAX)]);
1158    }
1159
1160    #[test]
1161    fn merge_single_byte() {
1162        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
1163        let result = byte_pair_merge(&ranks, b"a");
1164        assert_eq!(result, vec![(0, Rank::MAX), (1, Rank::MAX)]);
1165    }
1166
1167    #[test]
1168    fn merge_two_byte_exact_match() {
1169        let mut ranks = HashMap::default();
1170        ranks.insert(b"ab".to_vec(), 5);
1171        let result = byte_pair_merge(&ranks, b"ab");
1172        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
1173        assert_eq!(positions, vec![0, 2]);
1174    }
1175
1176    #[test]
1177    fn merge_no_vocab_matches() {
1178        let ranks: HashMap<Vec<u8>, Rank> = HashMap::default();
1179        let result = byte_pair_merge(&ranks, b"abcd");
1180        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
1181        // No merges possible — each byte is its own token.
1182        assert_eq!(positions, vec![0, 1, 2, 3, 4]);
1183    }
1184
1185    #[test]
1186    fn merge_cascade() {
1187        let mut ranks = HashMap::default();
1188        ranks.insert(b"ab".to_vec(), 0);
1189        ranks.insert(b"cd".to_vec(), 1);
1190        let result = byte_pair_merge(&ranks, b"abcd");
1191        let positions: Vec<usize> = result.iter().map(|&(p, _)| p).collect();
1192        assert_eq!(positions, vec![0, 2, 4]);
1193    }
1194
1195    #[test]
1196    fn encode_toy() {
1197        let bpe = toy_bpe();
1198        let tokens = bpe.encode_ordinary("hello");
1199        // "he"=100, "ll"=101, "o"=3
1200        assert_eq!(tokens, vec![100, 101, 3]);
1201    }
1202
1203    #[test]
1204    fn roundtrip_toy() {
1205        let bpe = toy_bpe();
1206        let text = "hello";
1207        let tokens = bpe.encode_ordinary(text);
1208        let decoded = bpe.decode_bytes(&tokens);
1209        assert_eq!(decoded, text.as_bytes());
1210        assert_eq!(bpe.decode(&tokens).unwrap(), text);
1211    }
1212
1213    #[test]
1214    fn encode_single_token_and_lookup() {
1215        let bpe = toy_bpe();
1216        assert_eq!(bpe.encode_single_token(b"he"), Some(100));
1217        assert_eq!(bpe.encode_single_token(b"zz"), None);
1218        assert_eq!(bpe.decode_single_token_bytes(100).unwrap(), b"he".to_vec());
1219        assert!(bpe.decode_single_token_bytes(9999).is_err());
1220    }
1221
1222    #[test]
1223    fn n_vocab_counts_everything() {
1224        let mut encoder = HashMap::default();
1225        encoder.insert(b"a".to_vec(), 0);
1226        encoder.insert(b"b".to_vec(), 1);
1227        let mut specials = HashMap::default();
1228        specials.insert("<|endoftext|>".to_string(), 2);
1229        let bpe = CoreBPE::new(encoder, specials, r"\w+").unwrap();
1230        assert_eq!(bpe.n_vocab(), 3);
1231    }
1232
1233    #[test]
1234    fn encode_with_allowed_special() {
1235        let mut encoder = HashMap::default();
1236        for b in b"abcdefghijklmnopqrstuvwxyz <>|" {
1237            encoder.insert(vec![*b], *b as Rank);
1238        }
1239        let mut specials = HashMap::default();
1240        specials.insert("<|eot|>".to_string(), 999);
1241        let bpe = CoreBPE::new(encoder, specials, r"\w+|[<|>]").unwrap();
1242
1243        let allowed: HashSet<&str> = std::iter::once("<|eot|>").collect();
1244        let tokens = bpe.encode("ab<|eot|>cd", &allowed);
1245        assert!(tokens.contains(&999));
1246
1247        // When not allowed, the special string is tokenized as ordinary text
1248        // and the special rank does NOT appear.
1249        let empty: HashSet<&str> = HashSet::new();
1250        let tokens = bpe.encode("ab<|eot|>cd", &empty);
1251        assert!(!tokens.contains(&999));
1252    }
1253
1254    #[test]
1255    fn fast_engine_kicks_in_on_tiktoken_patterns() {
1256        // The o200k_base pattern contains `\s+(?!\S)` which fancy-regex can
1257        // compile but regex crate cannot. The transformation should strip it
1258        // and produce a Fast engine (precompiled or lazy DFA).
1259        let o200k = r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+";
1260        let engine = SplitEngine::new(o200k).unwrap();
1261        assert!(engine.is_fast(), "o200k_base should use fast engine");
1262
1263        // Simple pattern with no lookarounds should also use fast.
1264        let simple = SplitEngine::new(r"\w+|\s+").unwrap();
1265        assert!(simple.is_fast());
1266    }
1267
1268    #[test]
1269    #[cfg(feature = "precompiled-dfa")]
1270    fn prebuilt_dfa_used_for_stock_patterns() {
1271        // Stock tiktoken patterns should hit the prebuilt DFA path.
1272        let o200k_raw = concat!(
1273            r"[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
1274            r"|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*(?i:'s|'t|'re|'ve|'m|'ll|'d)?",
1275            r"|\p{N}{1,3}",
1276            r"| ?[^\s\p{L}\p{N}]+[\r\n/]*",
1277            r"|\s*[\r\n]+",
1278            r"|\s+(?!\S)|\s+",
1279        );
1280        let engine = SplitEngine::new(o200k_raw).unwrap();
1281        assert!(
1282            engine.is_precompiled(),
1283            "o200k_base stock pattern should use prebuilt DFA"
1284        );
1285
1286        let gpt2_raw =
1287            r"'(?:[sdmt]|ll|ve|re)| ?\p{L}++| ?\p{N}++| ?[^\s\p{L}\p{N}]++|\s++$|\s+(?!\S)|\s";
1288        let engine = SplitEngine::new(gpt2_raw).unwrap();
1289        assert!(
1290            engine.is_precompiled(),
1291            "gpt2 stock pattern should use prebuilt DFA"
1292        );
1293
1294        let cl100k_raw = r"'(?i:[sdmt]|ll|ve|re)|[^\r\n\p{L}\p{N}]?+\p{L}++|\p{N}{1,3}+| ?[^\s\p{L}\p{N}]++[\r\n]*+|\s++$|\s*[\r\n]|\s+(?!\S)|\s";
1295        let engine = SplitEngine::new(cl100k_raw).unwrap();
1296        assert!(
1297            engine.is_precompiled(),
1298            "cl100k_base stock pattern should use prebuilt DFA"
1299        );
1300
1301        // Non-stock pattern still uses the fast path (eager DFA build).
1302        let custom = SplitEngine::new(r"\w+|\s+").unwrap();
1303        assert!(custom.is_fast(), "custom pattern should use fast engine");
1304    }
1305
1306    #[test]
1307    fn whitespace_shrink_matches_tiktoken_behavior() {
1308        // Build a toy BPE with a pattern that mimics the structure tiktoken
1309        // uses: word with optional leading non-word char, or whitespace.
1310        let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
1311        for b in 0u8..=255 {
1312            encoder.insert(vec![b], b as Rank);
1313        }
1314        // Inject a token for " hello" so we can detect the whitespace-attach
1315        // behavior via the whole-piece fast path.
1316        encoder.insert(b" hello".to_vec(), 1000);
1317        encoder.insert(b"hello".to_vec(), 1001);
1318
1319        // Pattern mirrors tiktoken's shape: optional non-word + letters,
1320        // then the `\s+(?!\S)|\s+` whitespace tail that exercises the
1321        // fast-path shrink rule.
1322        let pattern = r"[^\r\n\p{L}\p{N}]?\p{L}+|\s+(?!\S)|\s+";
1323        let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
1324        assert!(bpe.split_engine.is_fast());
1325
1326        // "  hello" should tokenize as [" ", " hello"] — the first whitespace
1327        // is a standalone piece (one space) and the second attaches to the
1328        // following word as " hello".
1329        let tokens = bpe.encode_ordinary("  hello");
1330        assert_eq!(
1331            tokens,
1332            vec![b' ' as Rank, 1000],
1333            "fast path should replicate `\\s+(?!\\S)` whitespace-shrink behavior"
1334        );
1335
1336        // "hello " (trailing whitespace) should tokenize as ["hello", " "]
1337        // — the trailing whitespace has no following word so it stays full.
1338        let tokens = bpe.encode_ordinary("hello ");
1339        assert_eq!(tokens, vec![1001, b' ' as Rank]);
1340    }
1341
1342    #[test]
1343    fn whitespace_shrink_unified_mode_includes_newlines() {
1344        // gpt2 / r50k / p50k-style pattern: no separate `\s*[\r\n]+`
1345        // alternative, so `\s+(?!\S)` must absorb runs that contain a
1346        // newline and the shrink rule must fire on them. Regression test
1347        // for parity with tiktoken on gpt2-family encodings.
1348        let mut encoder: HashMap<Vec<u8>, Rank> = HashMap::default();
1349        for b in 0u8..=255 {
1350            encoder.insert(vec![b], b as Rank);
1351        }
1352        encoder.insert(b" hello".to_vec(), 1000);
1353        encoder.insert(b"hello".to_vec(), 1001);
1354
1355        // Note the bare `\s` at the tail — this is the gpt2 shape.
1356        let pattern = r" ?\p{L}+|\s+$|\s+(?!\S)|\s";
1357        let bpe = CoreBPE::new(encoder, HashMap::default(), pattern).unwrap();
1358        assert!(bpe.split_engine.is_fast());
1359
1360        // "\n  hello" — the whitespace run `\n  ` contains a newline but
1361        // must still be shrunk to `\n ` so the trailing space can attach
1362        // to `hello` as ` hello`. Expected pieces: ["\n ", " hello"].
1363        let tokens = bpe.encode_ordinary("\n  hello");
1364        assert_eq!(
1365            tokens,
1366            vec![b'\n' as Rank, b' ' as Rank, 1000],
1367            "unified shrink mode must fire on whitespace runs that include newlines"
1368        );
1369
1370        // "\n" at end of input should stay whole — `\s+$` consumes it, and
1371        // with no following non-whitespace there's nothing to shrink for.
1372        let tokens = bpe.encode_ordinary("hi\n");
1373        assert_eq!(tokens, vec![b'h' as Rank, b'i' as Rank, b'\n' as Rank]);
1374    }
1375
1376    #[test]
1377    fn batch_encode_matches_sequential() {
1378        let bpe = toy_bpe();
1379        let texts = vec!["hello", "hello world", "the lazy fox"];
1380        let batch = bpe.encode_ordinary_batch(&texts);
1381        let seq: Vec<Vec<Rank>> = texts.iter().map(|t| bpe.encode_ordinary(t)).collect();
1382        assert_eq!(batch, seq);
1383
1384        // encode_batch with empty allowed set should equal encode_ordinary_batch
1385        let empty: HashSet<&str> = HashSet::new();
1386        let batch_sp = bpe.encode_batch(&texts, &empty);
1387        assert_eq!(batch_sp, seq);
1388    }
1389
1390    #[test]
1391    fn large_piece_matches_small_piece() {
1392        // Cross-validation: the heap path should produce the same tokens
1393        // as the Vec path on pieces that exercise both.
1394        let mut ranks = HashMap::default();
1395        // Byte fallback
1396        for b in 0u8..=255 {
1397            ranks.insert(vec![b], b as Rank);
1398        }
1399        // A few merges
1400        ranks.insert(b"ab".to_vec(), 300);
1401        ranks.insert(b"cd".to_vec(), 301);
1402        ranks.insert(b"abcd".to_vec(), 302);
1403
1404        let piece = b"abcdabcdabcdabcd";
1405        let small = {
1406            let pos = byte_pair_merge(&ranks, piece);
1407            pos.windows(2)
1408                .map(|w| rank_of(&ranks, &piece[w[0].0..w[1].0]))
1409                .collect::<Vec<_>>()
1410        };
1411        let large = byte_pair_merge_large(&ranks, piece);
1412        assert_eq!(small, large, "heap and vec paths disagree");
1413    }
1414}