Skip to main content

codec_rs/
pretok_program.rs

1// SPDX-License-Identifier: MIT
2//! Pre-tokenizer program interpreter.
3//!
4//! Executes a [`PreTokProgram`] against an input string, producing the
5//! same sequence of pieces that the legacy `pre_tokenizer_pattern` regex
6//! would have produced. Mirror of `@codecai/web`'s `pretok-program.ts`
7//! and `codecai`'s `pretok_program.py`; see
8//! [`spec/PRETOKENIZER_PROGRAM.md`](https://github.com/wdunn001/Codec/blob/main/spec/PRETOKENIZER_PROGRAM.md)
9//! for the design rationale and op set.
10//!
11//! Why this exists in the Rust client: the `regex` crate doesn't support
12//! lookaround (`\s+(?!\S)`) or ES2025 RegExp Pattern Modifiers
13//! (`(?i:...)`), both of which appear in every GPT-2-family
14//! `pre_tokenizer_pattern`. Without the program interpreter, the Rust
15//! `BPETokenizer` constructor fails before encode() runs on every
16//! shipped Qwen / Llama-3 / Phi-4 / cl100k_base map. With the
17//! interpreter, the program path bypasses regex entirely and the same
18//! maps tokenise byte-for-byte against HuggingFace.
19
20use regex::Regex;
21use serde::{Deserialize, Serialize};
22use std::sync::OnceLock;
23
24use crate::byte_encoder::METASPACE;
25
26// ── Op types ────────────────────────────────────────────────────────────────
27
/// One op in a [`PreTokProgram`].
///
/// At each input position the interpreter tries the ops in program order; the
/// first op that consumes at least one byte emits a piece. Serialized with an
/// internal `"op"` tag whose value is the snake_case variant name (e.g.
/// `{"op": "literals_ci", "patterns": [...]}`). Optional flags that are `None`
/// are omitted when serializing, default to `None` when absent on input, and
/// behave like `false` / `0` in the interpreter.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "op", rename_all = "snake_case")]
pub enum PreTokOp {
    /// `(?i:p1|p2|...)` — match the longest literal. ASCII letters compare
    /// case-insensitively; all other bytes must match exactly.
    LiteralsCi { patterns: Vec<String> },
    /// Case-sensitive literal alternatives — like `LiteralsCi` but matches
    /// case-exact (longest match wins). Used by older OpenAI tokenizers
    /// (p50k_base, r50k_base).
    Literals { patterns: Vec<String> },
    /// `\p{L}+`, `[^\r\n\p{L}\p{N}]?\p{L}+` when `lead_other`, or
    /// ` ?\p{L}+` when `lead_space`. The two lead flags are meant to be
    /// mutually exclusive — `lead_space` is the older-OpenAI shape,
    /// `lead_other` is the GPT-2 / Qwen / Llama-3 shape. If both are set,
    /// the interpreter honours only `lead_other`.
    Letters {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lead_other: Option<bool>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lead_space: Option<bool>,
    },
    /// `\p{N}+` when `max_run` is absent or 0, otherwise `\p{N}{1,max_run}`;
    /// with an optional ` ?` literal-space lead (`lead_space`) for older
    /// OpenAI tokenizers.
    Numbers {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_run: Option<u32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lead_space: Option<bool>,
    },
    /// ` ?[^\s\p{L}\p{N}]+[\r\n]*` — the leading ` ?` is enabled by
    /// `lead_space` and the trailing `[\r\n]*` by `trailing_newlines`.
    PunctRun {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lead_space: Option<bool>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        trailing_newlines: Option<bool>,
        /// Override `trailing_newlines` with an explicit charset string.
        /// Each character is accepted in the trailing run; when this is set,
        /// `trailing_newlines` is ignored. Used by o200k_base / mistral-nemo,
        /// whose trailing class is `[\r\n/]`.
        #[serde(default, skip_serializing_if = "Option::is_none")]
        trailing_chars: Option<String>,
    },
    /// Cased-letter run with an optional `[^\r\n\p{L}\p{N}]?` lead
    /// (`lead_other`) and optional trailing contractions (`trailing_ci`:
    /// longest match, ASCII case-insensitive).
    /// Used by o200k_base / mistral-nemo, which split on case boundaries.
    /// `kind: "title"` matches `[Lu Lt Lm Lo M]* [Ll Lm Lo M]+`,
    /// `kind: "upper"` matches `[Lu Lt Lm Lo M]+ [Ll Lm Lo M]*`.
    LettersCased {
        kind: CasedKind,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        lead_other: Option<bool>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        trailing_ci: Option<Vec<String>>,
    },
    /// `\s*[\r\n]+` — the whitespace run up to and including its last
    /// newline; any whitespace after that last newline is left for later ops.
    NewlineBlock {},
    /// `\s+(?!\S)` — a whitespace run that reaches end of input. When
    /// non-whitespace follows, it matches the run minus its final code point
    /// (nothing at all if the run is a single character).
    TrailingWs {},
    /// `\s+` — generic whitespace catchall (always last in GPT-2 programs).
    WsRun {},
    /// SentencePiece-style splitter — single-op programs only. Inside a
    /// multi-op program the interpreter treats it as never matching.
    MetaspaceSplit {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        prefix_first: Option<bool>,
    },
}
91
/// "Title" or "upper" cased-letter shape — see [`PreTokOp::LettersCased`].
///
/// Serialized in snake_case, i.e. as `"title"` / `"upper"`. Lm, Lo and M sit
/// in both the "upper" and "lower" classes below, so the interpreter decides
/// where the prefix ends and the suffix begins by backtracking, the way a
/// regex engine would.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum CasedKind {
    /// `[Lu Lt Lm Lo M]* [Ll Lm Lo M]+` — zero-or-more upper, then 1+ lower.
    Title,
    /// `[Lu Lt Lm Lo M]+ [Ll Lm Lo M]*` — one-or-more upper, then 0+ lower.
    Upper,
}
101
/// A compiled pre-tokenizer program. Carried alongside the legacy
/// `pre_tokenizer_pattern` on v2.1+ maps. Runtimes prefer the program
/// when present and fall back to the regex otherwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PreTokProgram {
    /// Program format version. Not read by [`run_pretok_program`];
    /// NOTE(review): confirm where (if anywhere) it is validated.
    pub version: u32,
    /// Ops in priority order; at each position the first op that consumes
    /// input wins.
    pub ops: Vec<PreTokOp>,
}
110
111// ── Class predicates ────────────────────────────────────────────────────────
112
113fn re_letter() -> &'static Regex {
114    static R: OnceLock<Regex> = OnceLock::new();
115    R.get_or_init(|| Regex::new(r"\p{L}").unwrap())
116}
117fn re_number() -> &'static Regex {
118    static R: OnceLock<Regex> = OnceLock::new();
119    R.get_or_init(|| Regex::new(r"\p{N}").unwrap())
120}
121fn re_ws() -> &'static Regex {
122    static R: OnceLock<Regex> = OnceLock::new();
123    R.get_or_init(|| Regex::new(r"\s").unwrap())
124}
125fn re_letter_upper() -> &'static Regex {
126    static R: OnceLock<Regex> = OnceLock::new();
127    R.get_or_init(|| Regex::new(r"[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]").unwrap())
128}
129fn re_letter_lower() -> &'static Regex {
130    static R: OnceLock<Regex> = OnceLock::new();
131    R.get_or_init(|| Regex::new(r"[\p{Ll}\p{Lm}\p{Lo}\p{M}]").unwrap())
132}
133
134fn is_letter(cp: char) -> bool {
135    let mut buf = [0u8; 4];
136    re_letter().is_match(cp.encode_utf8(&mut buf))
137}
138fn is_number(cp: char) -> bool {
139    let mut buf = [0u8; 4];
140    re_number().is_match(cp.encode_utf8(&mut buf))
141}
142fn is_ws(cp: char) -> bool {
143    let mut buf = [0u8; 4];
144    re_ws().is_match(cp.encode_utf8(&mut buf))
145}
146fn is_letter_upper(cp: char) -> bool {
147    let mut buf = [0u8; 4];
148    re_letter_upper().is_match(cp.encode_utf8(&mut buf))
149}
150fn is_letter_lower(cp: char) -> bool {
151    let mut buf = [0u8; 4];
152    re_letter_lower().is_match(cp.encode_utf8(&mut buf))
153}
154
155// ── Per-op matchers ────────────────────────────────────────────────────────
156//
157// Each returns the byte count consumed at position `i`, or 0 if no match.
158
/// Longest ASCII-case-insensitive literal from `patterns` that prefixes
/// `text[i..]`; returns its byte length, or 0 when none matches.
///
/// The comparison is byte-wise (`eq_ignore_ascii_case` folds only ASCII
/// letters), so `text` is never sliced in the middle of a multibyte code point.
fn match_literals_ci(patterns: &[String], text: &str, i: usize) -> usize {
    let haystack = text[i..].as_bytes();
    patterns
        .iter()
        .map(String::as_bytes)
        .filter(|needle| {
            haystack.len() >= needle.len() && haystack[..needle.len()].eq_ignore_ascii_case(needle)
        })
        .map(<[u8]>::len)
        .max()
        .unwrap_or(0)
}
187
/// Longest case-exact literal from `patterns` that prefixes `text[i..]`;
/// returns its byte length, or 0 when none matches.
///
/// `str::starts_with` compares bytes, so a pattern that would end inside a
/// multibyte code point simply fails to match instead of panicking on a
/// non-char-boundary slice.
fn match_literals(patterns: &[String], text: &str, i: usize) -> usize {
    let rest = &text[i..];
    patterns
        .iter()
        .filter(|p| rest.starts_with(p.as_str()))
        .map(String::len)
        .max()
        .unwrap_or(0)
}
206
207fn match_letters(lead_other: bool, lead_space: bool, text: &str, i: usize) -> usize {
208    let rest = &text[i..];
209    let mut chars = rest.char_indices().peekable();
210    let mut p = 0usize;
211    if lead_other {
212        // `[^\r\n\p{L}\p{N}]?` — at most one char that is none of those.
213        if let Some(&(_off, c)) = chars.peek() {
214            if c != '\r' && c != '\n' && !is_letter(c) && !is_number(c) {
215                p = c.len_utf8();
216                chars.next();
217            }
218        }
219    } else if lead_space {
220        // ` ?` — at most one literal space.
221        if let Some(&(_off, c)) = chars.peek() {
222            if c == ' ' {
223                p = c.len_utf8();
224                chars.next();
225            }
226        }
227    }
228    // `\p{L}+`
229    let run_start = p;
230    while let Some(&(_off, c)) = chars.peek() {
231        if !is_letter(c) {
232            break;
233        }
234        p += c.len_utf8();
235        chars.next();
236    }
237    if p == run_start {
238        0
239    } else {
240        p
241    }
242}
243
244fn match_numbers(max_run: u32, lead_space: bool, text: &str, i: usize) -> usize {
245    let max = if max_run == 0 { u32::MAX } else { max_run };
246    let mut p = 0usize;
247    let bytes = text.as_bytes();
248    if lead_space && i + p < bytes.len() && bytes[i + p] == b' ' {
249        p += 1;
250    }
251    let run_start = p;
252    let mut count = 0u32;
253    for c in text[i + p..].chars() {
254        if count >= max || !is_number(c) {
255            break;
256        }
257        p += c.len_utf8();
258        count += 1;
259    }
260    if p == run_start { 0 } else { p }
261}
262
263fn match_punct_run(
264    lead_space: bool,
265    trailing_newlines: bool,
266    trailing_chars: Option<&str>,
267    text: &str,
268    i: usize,
269) -> usize {
270    let bytes = text.as_bytes();
271    let mut p = i;
272    if lead_space && p < bytes.len() && bytes[p] == b' ' {
273        p += 1;
274    }
275    // `[^\s\p{L}\p{N}]+`
276    let run_start = p;
277    for c in text[p..].chars() {
278        if is_ws(c) || is_letter(c) || is_number(c) {
279            break;
280        }
281        p += c.len_utf8();
282    }
283    if p == run_start {
284        return 0;
285    }
286    // Trailing chars: prefer explicit charset when set, otherwise legacy
287    // boolean → `\r\n` only.
288    if let Some(chars) = trailing_chars {
289        loop {
290            let Some(c) = text[p..].chars().next() else { break };
291            if !chars.contains(c) {
292                break;
293            }
294            p += c.len_utf8();
295        }
296    } else if trailing_newlines {
297        while p < bytes.len() && (bytes[p] == b'\n' || bytes[p] == b'\r') {
298            p += 1;
299        }
300    }
301    p - i
302}
303
304fn match_letters_cased(
305    kind: CasedKind,
306    lead_other: bool,
307    trailing_ci: Option<&[String]>,
308    text: &str,
309    i: usize,
310) -> usize {
311    let mut p = i;
312    if lead_other {
313        if let Some(c) = text[p..].chars().next() {
314            if c != '\r' && c != '\n' && !is_letter(c) && !is_number(c) {
315                p += c.len_utf8();
316            }
317        }
318    }
319
320    // Greedy prefix run; record each step as a candidate suffix-start.
321    // Lm/Lo/M are in BOTH sets so the longest overall match may need
322    // to back off the prefix run to let the suffix consume them.
323    let mut checkpoints: Vec<usize> = vec![p];
324    while let Some(c) = text[p..].chars().next() {
325        if !is_letter_upper(c) {
326            break;
327        }
328        p += c.len_utf8();
329        checkpoints.push(p);
330    }
331
332    let (min_prefix, min_suffix): (usize, usize) = match kind {
333        CasedKind::Upper => (1, 0),
334        CasedKind::Title => (0, 1),
335    };
336
337    // Try suffix from each checkpoint, longest-prefix first. First success wins.
338    for k in (0..checkpoints.len()).rev() {
339        if k < min_prefix {
340            break;
341        }
342        let mut q = checkpoints[k];
343        let mut suffix_count = 0usize;
344        while let Some(c) = text[q..].chars().next() {
345            if !is_letter_lower(c) {
346                break;
347            }
348            q += c.len_utf8();
349            suffix_count += 1;
350        }
351        if suffix_count < min_suffix {
352            continue;
353        }
354
355        // Optional case-insensitive trailing-contractions match, longest wins.
356        if let Some(patterns) = trailing_ci {
357            let rest = &text[q..];
358            let rest_bytes = rest.as_bytes();
359            let mut best = 0usize;
360            for pat in patterns {
361                if pat.len() <= best || rest.len() < pat.len() {
362                    continue;
363                }
364                let p_bytes = pat.as_bytes();
365                let mut ok = true;
366                for k in 0..pat.len() {
367                    let a = rest_bytes[k];
368                    let b = p_bytes[k];
369                    if a == b {
370                        continue;
371                    }
372                    if a.is_ascii_uppercase() && a + 32 == b {
373                        continue;
374                    }
375                    if a.is_ascii_lowercase() && a - 32 == b {
376                        continue;
377                    }
378                    ok = false;
379                    break;
380                }
381                if ok {
382                    best = pat.len();
383                }
384            }
385            q += best;
386        }
387
388        return q - i;
389    }
390    0
391}
392
393fn match_newline_block(text: &str, i: usize) -> usize {
394    // `\s*[\r\n]+` — greedy `\s*`, then back off until the trailing run is
395    // contiguous newlines.
396    let mut p = 0usize;
397    for c in text[i..].chars() {
398        if !is_ws(c) {
399            break;
400        }
401        p += c.len_utf8();
402    }
403    let bytes = text.as_bytes();
404    // Find the first newline within [i, i+p).
405    let mut first_nl: Option<usize> = None;
406    for q in i..(i + p) {
407        if bytes[q] == b'\n' || bytes[q] == b'\r' {
408            first_nl = Some(q);
409            break;
410        }
411    }
412    let Some(first_nl) = first_nl else { return 0 };
413    // Trim back from end while we see non-newline whitespace.
414    let mut q = i + p;
415    while q > first_nl {
416        let c = bytes[q - 1];
417        if c == b'\n' || c == b'\r' {
418            break;
419        }
420        q -= 1;
421    }
422    q - i
423}
424
425fn match_trailing_ws(text: &str, i: usize) -> usize {
426    // `\s+(?!\S)`: longest whitespace run ending either at EOI or one
427    // code point before a final whitespace.
428    let mut p = i;
429    for c in text[i..].chars() {
430        if !is_ws(c) {
431            break;
432        }
433        p += c.len_utf8();
434    }
435    if p == i {
436        return 0;
437    }
438    if p == text.len() {
439        return p - i;
440    }
441    // Trailing non-ws follows; trim before the LAST whitespace code point.
442    let mut q = i;
443    let mut last_start = i;
444    while q < p {
445        last_start = q;
446        let c = text[q..].chars().next().unwrap();
447        q += c.len_utf8();
448    }
449    last_start - i
450}
451
452fn match_ws_run(text: &str, i: usize) -> usize {
453    let mut p = 0usize;
454    for c in text[i..].chars() {
455        if !is_ws(c) {
456            break;
457        }
458        p += c.len_utf8();
459    }
460    p
461}
462
463// ── Interpreter loop ────────────────────────────────────────────────────────
464
465/// Execute `program` against `text`, returning the same piece sequence
466/// the legacy regex pre-tokenizer would have emitted.
467pub fn run_pretok_program(program: &PreTokProgram, text: &str) -> Vec<String> {
468    // Single-op metaspace shortcut.
469    if program.ops.len() == 1 {
470        if let PreTokOp::MetaspaceSplit { prefix_first } = &program.ops[0] {
471            return run_metaspace(prefix_first.unwrap_or(false), text);
472        }
473    }
474
475    let mut out: Vec<String> = Vec::new();
476    let bytes = text.as_bytes();
477    let n = bytes.len();
478    let mut i = 0usize;
479    'outer: while i < n {
480        for op in &program.ops {
481            let span = match op {
482                PreTokOp::LiteralsCi { patterns } => match_literals_ci(patterns, text, i),
483                PreTokOp::Literals { patterns } => match_literals(patterns, text, i),
484                PreTokOp::Letters {
485                    lead_other,
486                    lead_space,
487                } => match_letters(
488                    lead_other.unwrap_or(false),
489                    lead_space.unwrap_or(false),
490                    text,
491                    i,
492                ),
493                PreTokOp::Numbers {
494                    max_run,
495                    lead_space,
496                } => match_numbers(
497                    max_run.unwrap_or(0),
498                    lead_space.unwrap_or(false),
499                    text,
500                    i,
501                ),
502                PreTokOp::PunctRun {
503                    lead_space,
504                    trailing_newlines,
505                    trailing_chars,
506                } => match_punct_run(
507                    lead_space.unwrap_or(false),
508                    trailing_newlines.unwrap_or(false),
509                    trailing_chars.as_deref(),
510                    text,
511                    i,
512                ),
513                PreTokOp::LettersCased {
514                    kind,
515                    lead_other,
516                    trailing_ci,
517                } => match_letters_cased(
518                    *kind,
519                    lead_other.unwrap_or(false),
520                    trailing_ci.as_deref(),
521                    text,
522                    i,
523                ),
524                PreTokOp::NewlineBlock {} => match_newline_block(text, i),
525                PreTokOp::TrailingWs {} => match_trailing_ws(text, i),
526                PreTokOp::WsRun {} => match_ws_run(text, i),
527                PreTokOp::MetaspaceSplit { .. } => 0, // mixed programs unsupported
528            };
529            if span > 0 {
530                out.push(text[i..i + span].to_string());
531                i += span;
532                continue 'outer;
533            }
534        }
535        // Defensive: no op matched. Consume one scalar value.
536        let c = text[i..].chars().next().unwrap();
537        out.push(c.to_string());
538        i += c.len_utf8();
539    }
540    out
541}
542
543fn run_metaspace(prefix_first: bool, text: &str) -> Vec<String> {
544    let mut out: Vec<String> = Vec::new();
545    let mut buf = String::new();
546    // Collapse `[ \t]+` to a single space, then split on whitespace
547    // retaining each ws char.
548    let mut prev_horiz_ws = false;
549    for c in text.chars() {
550        if c == ' ' || c == '\t' {
551            if !prev_horiz_ws {
552                buf.push(' ');
553                prev_horiz_ws = true;
554            }
555        } else {
556            buf.push(c);
557            prev_horiz_ws = false;
558        }
559    }
560    let mut is_first = true;
561    let mut piece = String::new();
562    for c in buf.chars() {
563        if c.is_whitespace() {
564            if !piece.is_empty() {
565                if prefix_first && is_first {
566                    out.push(std::mem::take(&mut piece));
567                } else {
568                    let mut s = String::with_capacity(piece.len() + 3);
569                    s.push(METASPACE);
570                    s.push_str(&piece);
571                    out.push(s);
572                    piece.clear();
573                }
574                is_first = false;
575            }
576            if c == ' ' {
577                is_first = false;
578            }
579        } else {
580            piece.push(c);
581        }
582    }
583    if !piece.is_empty() {
584        if prefix_first && is_first {
585            out.push(piece);
586        } else {
587            let mut s = String::with_capacity(piece.len() + 3);
588            s.push(METASPACE);
589            s.push_str(&piece);
590            out.push(s);
591        }
592    }
593    out
594}
595
#[cfg(test)]
mod tests {
    use super::*;

    /// GPT-2 / Qwen-shaped program (with an unbounded `numbers` op).
    fn qwen_program() -> PreTokProgram {
        PreTokProgram {
            version: 1,
            ops: vec![
                PreTokOp::LiteralsCi {
                    patterns: vec![
                        "'s".into(),
                        "'t".into(),
                        "'re".into(),
                        "'ve".into(),
                        "'m".into(),
                        "'ll".into(),
                        "'d".into(),
                    ],
                },
                PreTokOp::Letters {
                    lead_other: Some(true),
                    lead_space: None,
                },
                PreTokOp::Numbers {
                    max_run: None,
                    lead_space: None,
                },
                PreTokOp::PunctRun {
                    lead_space: Some(true),
                    trailing_newlines: Some(true),
                    trailing_chars: None,
                },
                PreTokOp::NewlineBlock {},
                PreTokOp::TrailingWs {},
                PreTokOp::WsRun {},
            ],
        }
    }

    /// Runs the Qwen-shaped program over `text`.
    fn split(text: &str) -> Vec<String> {
        run_pretok_program(&qwen_program(), text)
    }

    /// Single-op metaspace program.
    fn metaspace_program(prefix_first: Option<bool>) -> PreTokProgram {
        PreTokProgram {
            version: 1,
            ops: vec![PreTokOp::MetaspaceSplit { prefix_first }],
        }
    }

    #[test]
    fn qwen_program_splits_basic_text() {
        let p = qwen_program();
        let out = run_pretok_program(&p, "Hello, world!");
        assert_eq!(out, vec!["Hello", ",", " world", "!"]);
    }

    #[test]
    fn qwen_program_handles_contractions() {
        let p = qwen_program();
        let out = run_pretok_program(&p, "it's");
        assert_eq!(out, vec!["it", "'s"]);
    }

    #[test]
    fn qwen_program_unbounded_digits() {
        let p = qwen_program();
        // Unbounded `numbers` op consumes the whole digit run as one piece.
        let out = run_pretok_program(&p, "abc 12345 def");
        assert_eq!(out, vec!["abc", " ", "12345", " def"]);
    }

    #[test]
    fn newline_block_ends_after_last_newline() {
        // `\s*[\r\n]+` stops after the last newline; the indent that follows
        // is split by `\s+(?!\S)`, leaving one space to lead the next word.
        assert_eq!(split("a  \n\n  b"), vec!["a", "  \n\n", " ", " b"]);
    }

    #[test]
    fn trailing_whitespace_at_end_of_input_is_one_piece() {
        assert_eq!(split("hi  "), vec!["hi", "  "]);
    }

    #[test]
    fn punct_run_absorbs_trailing_newlines() {
        assert_eq!(split("x!!\n\ny"), vec!["x", "!!\n\n", "y"]);
    }

    #[test]
    fn contractions_match_case_insensitively_after_multibyte_letters() {
        assert_eq!(split("日本語'S"), vec!["日本語", "'S"]);
    }

    #[test]
    fn metaspace_marks_every_word_and_collapses_spaces() {
        let m = crate::byte_encoder::METASPACE;
        let out = run_pretok_program(&metaspace_program(None), "Hello \t world");
        assert_eq!(out, vec![format!("{}Hello", m), format!("{}world", m)]);
    }

    #[test]
    fn metaspace_prefix_first_leaves_first_word_bare() {
        let m = crate::byte_encoder::METASPACE;
        let p = metaspace_program(Some(true));
        assert_eq!(
            run_pretok_program(&p, "Hello world"),
            vec!["Hello".to_string(), format!("{}world", m)]
        );
        // A leading space means the first word is marked after all.
        assert_eq!(run_pretok_program(&p, " hi"), vec![format!("{}hi", m)]);
    }
}