Skip to main content

oxibonsai_runtime/grammar/
regex_compiler.rs

1//! Regex → BNF Grammar compiler.
2//!
3//! Compiles a regular expression pattern string into a [`Grammar`] that is
4//! usable with [`super::constraint::GrammarConstraint`] for constrained decoding.
5//!
6//! # Algorithm
7//!
8//! 1. **Regex parser → Thompson NFA**: Builds a Nondeterministic Finite
9//!    Automaton via Thompson construction from the parsed regex AST.
10//! 2. **Subset DFA construction**: Converts the NFA to a DFA via powerset
11//!    construction (ε-closure + transition computation).
12//! 3. **DFA → Grammar**: Each DFA state becomes a non-terminal; transitions
13//!    become single-byte terminal rules; accept states emit ε-productions.
14//!
15//! # Supported regex features
16//!
17//! - Literals: any byte literal
18//! - `.` — any byte except `\n` (0x0A)
19//! - `[abc]`, `[a-z]`, `[^abc]` — character classes
20//! - `*`, `+`, `?` — greedy quantifiers
21//! - `{n}`, `{n,}`, `{n,m}` — counted quantifiers
22//! - `|` — alternation
23//! - `(...)` — grouping (non-capturing)
24//! - Anchors: `^` (start) and `$` (end) are silently ignored
//! - Escape sequences: `\d`, `\w`, `\s`, `\D`, `\W`, `\S`, `\n`, `\r`, `\t`,
//!   `\.`, `\\`, `\[`, `\]`, `\(`, `\)`, `\*`, `\+`, `\?`, `\{`, `\}`, `\|`,
//!   `\^`, `\$`
27//!
28//! # Unsupported (returns `RegexCompileError::UnsupportedFeature`)
29//!
30//! - Backreferences `\1`, `\2`, …
31//! - Lookahead/lookbehind: `(?=...)`, `(?!...)`, `(?<=...)`, `(?<!...)`
32//! - Named groups: `(?P<name>...)`, `(?<name>...)`
33//! - Atomic groups, possessive quantifiers
34//! - Unicode properties `\p{Letter}`
35
36use std::collections::{BTreeSet, HashMap, VecDeque};
37
38use super::ast::{Grammar, NonTerminalId, Rule, Symbol};
39
40// ─────────────────────────────────────────────────────────────────────────────
41// Public error type
42// ─────────────────────────────────────────────────────────────────────────────
43
/// Failure modes of compiling a regex pattern into a [`Grammar`].
#[derive(Debug, Clone, PartialEq)]
pub enum RegexCompileError {
    /// The pattern is not well-formed; the payload describes the problem.
    InvalidSyntax(String),
    /// The pattern is well-formed but uses a construct this compiler rejects.
    UnsupportedFeature(String),
    /// An internal size limit (state count or repetition count) was hit.
    DepthExceeded {
        /// The limit that was exceeded.
        limit: usize,
    },
    /// The pattern was the empty string.
    EmptyPattern,
    /// The pattern contains bytes that are not valid UTF-8 where UTF-8 is required.
    InvalidUtf8(String),
}

impl std::fmt::Display for RegexCompileError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Variants carrying a free-text detail share the "<label>: <detail>" shape;
        // the remaining variants are rendered directly.
        let (label, detail) = match self {
            Self::InvalidSyntax(msg) => ("regex syntax error", msg),
            Self::UnsupportedFeature(feat) => ("unsupported regex feature", feat),
            Self::InvalidUtf8(msg) => ("invalid UTF-8 in regex pattern", msg),
            Self::DepthExceeded { limit } => {
                return write!(f, "regex complexity limit exceeded (limit: {limit})");
            }
            Self::EmptyPattern => return f.write_str("regex pattern is empty"),
        };
        write!(f, "{label}: {detail}")
    }
}

impl std::error::Error for RegexCompileError {}
79
80// ─────────────────────────────────────────────────────────────────────────────
81// ByteSet — 256-bit bitset for byte ranges
82// ─────────────────────────────────────────────────────────────────────────────
83
/// Fixed-size set of byte values, one bit per possible `u8`.
///
/// The 256 bits live in four `u64` words: byte `b` is bit `b % 64` of word
/// `b / 64`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Default)]
struct ByteSet([u64; 4]);

impl ByteSet {
    /// Word index and single-bit mask that address byte `b`.
    fn slot(b: u8) -> (usize, u64) {
        (usize::from(b >> 6), 1u64 << (b & 63))
    }

    /// The set with no members.
    fn empty() -> Self {
        Self::default()
    }

    /// The set containing all 256 byte values.
    fn full() -> Self {
        ByteSet([u64::MAX; 4])
    }

    /// Every byte except `\n` (0x0A) — the meaning of `.`.
    fn any_except_newline() -> Self {
        let mut set = Self::full();
        set.remove(b'\n');
        set
    }

    /// Add `b` to the set.
    fn insert(&mut self, b: u8) {
        let (word, mask) = Self::slot(b);
        self.0[word] |= mask;
    }

    /// Remove `b` from the set.
    fn remove(&mut self, b: u8) {
        let (word, mask) = Self::slot(b);
        self.0[word] &= !mask;
    }

    /// `true` if `b` is a member.
    fn contains(&self, b: u8) -> bool {
        let (word, mask) = Self::slot(b);
        self.0[word] & mask != 0
    }

    /// Every byte that is NOT in `self`.
    fn complement(&self) -> Self {
        ByteSet(self.0.map(|word| !word))
    }

    /// Every byte in `self`, `other`, or both.
    fn union(&self, other: &Self) -> Self {
        let mut words = self.0;
        for (mine, theirs) in words.iter_mut().zip(other.0.iter()) {
            *mine |= *theirs;
        }
        ByteSet(words)
    }

    /// Members in ascending byte order.
    fn iter(&self) -> impl Iterator<Item = u8> + '_ {
        (0..=u8::MAX).filter(move |&b| self.contains(b))
    }

    /// `true` if the set has no members.
    // Kept even when no caller in this module uses it; the allow silences the
    // dead-code lint in that case.
    #[allow(dead_code)]
    fn is_empty(&self) -> bool {
        self.0.iter().all(|&word| word == 0)
    }
}
156
157// ─────────────────────────────────────────────────────────────────────────────
158// NFA representation (Thompson construction)
159// ─────────────────────────────────────────────────────────────────────────────
160
/// Cap on the number of NFA states built for one pattern, so that
/// pathological patterns fail with an error instead of exhausting memory.
const MAX_NFA_STATES: usize = 16 * 1024;

/// Cap on the number of DFA states allowed.
const MAX_DFA_STATES: usize = 2_048;

/// Largest bound accepted in a counted quantifier (`{n}`, `{n,}`, `{n,m}`).
const MAX_REPETITION: usize = 64;
169
170/// One state in the NFA.
171#[derive(Debug, Clone)]
172struct NfaState {
173    /// Labeled (byte-set) transitions: (label, target_state_id).
174    transitions: Vec<(ByteSet, usize)>,
175    /// Epsilon transitions to other state ids.
176    epsilon: Vec<usize>,
177    /// Whether this is an accept state.
178    is_accept: bool,
179}
180
181impl NfaState {
182    fn new() -> Self {
183        Self {
184            transitions: Vec::new(),
185            epsilon: Vec::new(),
186            is_accept: false,
187        }
188    }
189}
190
/// A partially built piece of NFA produced by the Thompson construction helpers.
///
/// `start` is the state where the piece is entered. `end` is its single exit
/// state, to which the combinators (`build_concat`, `build_star`, …) attach
/// further ε-edges. For a fragment matching only the empty string the two ids
/// may be the same state.
struct NfaFrag {
    /// Entry state id.
    start: usize,
    /// Exit state id.
    end: usize,
}
197
198/// The full NFA builder — holds all states and a counter for fresh ids.
199struct Nfa {
200    states: Vec<NfaState>,
201}
202
203impl Nfa {
204    fn new() -> Self {
205        Self { states: Vec::new() }
206    }
207
208    /// Allocate a fresh NFA state and return its id.
209    fn alloc(&mut self) -> Result<usize, RegexCompileError> {
210        if self.states.len() >= MAX_NFA_STATES {
211            return Err(RegexCompileError::DepthExceeded {
212                limit: MAX_NFA_STATES,
213            });
214        }
215        let id = self.states.len();
216        self.states.push(NfaState::new());
217        Ok(id)
218    }
219
220    /// Add an epsilon transition from `from` → `to`.
221    fn add_epsilon(&mut self, from: usize, to: usize) {
222        self.states[from].epsilon.push(to);
223    }
224
225    /// Add a labeled transition from `from` →[label]→ `to`.
226    fn add_transition(&mut self, from: usize, label: ByteSet, to: usize) {
227        self.states[from].transitions.push((label, to));
228    }
229
230    /// Compute the ε-closure of a set of NFA states.
231    fn epsilon_closure(&self, seeds: impl IntoIterator<Item = usize>) -> BTreeSet<usize> {
232        let mut closure: BTreeSet<usize> = BTreeSet::new();
233        let mut worklist: VecDeque<usize> = VecDeque::new();
234
235        for s in seeds {
236            if closure.insert(s) {
237                worklist.push_back(s);
238            }
239        }
240
241        while let Some(state) = worklist.pop_front() {
242            for &target in &self.states[state].epsilon {
243                if closure.insert(target) {
244                    worklist.push_back(target);
245                }
246            }
247        }
248
249        closure
250    }
251
252    /// Build Thompson fragment for a single ByteSet (character class or literal).
253    fn build_byte_set(&mut self, label: ByteSet) -> Result<NfaFrag, RegexCompileError> {
254        let start = self.alloc()?;
255        let end = self.alloc()?;
256        self.add_transition(start, label, end);
257        Ok(NfaFrag { start, end })
258    }
259
260    /// Concatenate two fragments: `a` · `b`.
261    fn build_concat(&mut self, a: NfaFrag, b: NfaFrag) -> NfaFrag {
262        // Connect end(a) to start(b) via epsilon.
263        self.add_epsilon(a.end, b.start);
264        NfaFrag {
265            start: a.start,
266            end: b.end,
267        }
268    }
269
270    /// Alternation: `a | b`.
271    fn build_alternation(&mut self, a: NfaFrag, b: NfaFrag) -> Result<NfaFrag, RegexCompileError> {
272        let start = self.alloc()?;
273        let end = self.alloc()?;
274        self.add_epsilon(start, a.start);
275        self.add_epsilon(start, b.start);
276        self.add_epsilon(a.end, end);
277        self.add_epsilon(b.end, end);
278        Ok(NfaFrag { start, end })
279    }
280
281    /// Kleene star: `a*`.
282    fn build_star(&mut self, a: NfaFrag) -> Result<NfaFrag, RegexCompileError> {
283        let start = self.alloc()?;
284        let end = self.alloc()?;
285        // start → a.start (enter loop)
286        self.add_epsilon(start, a.start);
287        // start → end (skip entirely)
288        self.add_epsilon(start, end);
289        // a.end → a.start (repeat)
290        self.add_epsilon(a.end, a.start);
291        // a.end → end (exit)
292        self.add_epsilon(a.end, end);
293        Ok(NfaFrag { start, end })
294    }
295
296    /// Plus: `a+` = `a · a*`.
297    fn build_plus(&mut self, a: NfaFrag) -> Result<NfaFrag, RegexCompileError> {
298        // We need two copies; just wire up manually: start→a.start, a.end loops.
299        let loop_start = self.alloc()?;
300        let loop_end = self.alloc()?;
301        // After completing `a` once, we can repeat or exit.
302        self.add_epsilon(a.end, loop_start);
303        // loop_start → a.start (repeat)
304        self.add_epsilon(loop_start, a.start);
305        // loop_start → loop_end (exit)
306        self.add_epsilon(loop_start, loop_end);
307        Ok(NfaFrag {
308            start: a.start,
309            end: loop_end,
310        })
311    }
312
313    /// Optional: `a?`.
314    fn build_optional(&mut self, a: NfaFrag) -> Result<NfaFrag, RegexCompileError> {
315        let start = self.alloc()?;
316        let end = self.alloc()?;
317        self.add_epsilon(start, a.start);
318        self.add_epsilon(start, end);
319        self.add_epsilon(a.end, end);
320        Ok(NfaFrag { start, end })
321    }
322}
323
324// ─────────────────────────────────────────────────────────────────────────────
325// Regex AST
326// ─────────────────────────────────────────────────────────────────────────────
327
/// Node of the parsed regex syntax tree; turned into NFA fragments by
/// `build_nfa_frag`.
#[derive(Debug, Clone)]
enum RegexNode {
    /// One byte drawn from the set (a literal, `.`, an escape class, or `[...]`).
    ByteClass(ByteSet),
    /// Every child, in order.
    Concat(Vec<RegexNode>),
    /// Any one of the children.
    Alternation(Vec<RegexNode>),
    /// `x*`: zero or more repetitions.
    Star(Box<RegexNode>),
    /// `x+`: one or more repetitions.
    Plus(Box<RegexNode>),
    /// `x?`: zero or one occurrence.
    Optional(Box<RegexNode>),
    /// `x{n}`: exactly `n` repetitions.
    CountedExact(Box<RegexNode>, usize),
    /// `x{n,m}` (`Some(m)`) or `x{n,}` (`None`): at least `n`, at most `m` repetitions.
    CountedRange(Box<RegexNode>, usize, Option<usize>),
    /// Matches only the empty string; produced for anchors and for empty
    /// sequences such as `()` or an empty alternation branch.
    Empty,
}
350
351// ─────────────────────────────────────────────────────────────────────────────
352// Regex parser
353// ─────────────────────────────────────────────────────────────────────────────
354
355/// Parser state for the regex string.
356struct RegexParser<'a> {
357    input: &'a [u8],
358    pos: usize,
359}
360
361impl<'a> RegexParser<'a> {
362    fn new(input: &'a str) -> Self {
363        Self {
364            input: input.as_bytes(),
365            pos: 0,
366        }
367    }
368
369    fn peek(&self) -> Option<u8> {
370        self.input.get(self.pos).copied()
371    }
372
373    fn advance(&mut self) -> Option<u8> {
374        let b = self.input.get(self.pos).copied()?;
375        self.pos += 1;
376        Some(b)
377    }
378
379    fn expect(&mut self, expected: u8) -> Result<(), RegexCompileError> {
380        match self.peek() {
381            Some(b) if b == expected => {
382                self.pos += 1;
383                Ok(())
384            }
385            Some(b) => Err(RegexCompileError::InvalidSyntax(format!(
386                "expected '{}' at position {}, got '{}'",
387                expected as char, self.pos, b as char
388            ))),
389            None => Err(RegexCompileError::InvalidSyntax(format!(
390                "expected '{}' at position {} but got end of pattern",
391                expected as char, self.pos
392            ))),
393        }
394    }
395
396    fn is_at_end(&self) -> bool {
397        self.pos >= self.input.len()
398    }
399
400    /// Top-level: parse an alternation expression.
401    fn parse_alternation(&mut self) -> Result<RegexNode, RegexCompileError> {
402        let mut branches: Vec<RegexNode> = Vec::new();
403        branches.push(self.parse_concat()?);
404        while self.peek() == Some(b'|') {
405            self.pos += 1; // consume '|'
406            branches.push(self.parse_concat()?);
407        }
408        if branches.len() == 1 {
409            Ok(branches.remove(0))
410        } else {
411            Ok(RegexNode::Alternation(branches))
412        }
413    }
414
415    /// Parse a concatenation of atoms.
416    fn parse_concat(&mut self) -> Result<RegexNode, RegexCompileError> {
417        let mut atoms: Vec<RegexNode> = Vec::new();
418        loop {
419            match self.peek() {
420                None | Some(b')') | Some(b'|') => break,
421                _ => {
422                    let atom = self.parse_quantified_atom()?;
423                    match atom {
424                        RegexNode::Empty => {}
425                        other => atoms.push(other),
426                    }
427                }
428            }
429        }
430        if atoms.is_empty() {
431            Ok(RegexNode::Empty)
432        } else if atoms.len() == 1 {
433            Ok(atoms.remove(0))
434        } else {
435            Ok(RegexNode::Concat(atoms))
436        }
437    }
438
439    /// Parse an atom followed by an optional quantifier.
440    fn parse_quantified_atom(&mut self) -> Result<RegexNode, RegexCompileError> {
441        let atom = self.parse_atom()?;
442        match self.peek() {
443            Some(b'*') => {
444                self.pos += 1;
445                Ok(RegexNode::Star(Box::new(atom)))
446            }
447            Some(b'+') => {
448                self.pos += 1;
449                Ok(RegexNode::Plus(Box::new(atom)))
450            }
451            Some(b'?') => {
452                self.pos += 1;
453                Ok(RegexNode::Optional(Box::new(atom)))
454            }
455            Some(b'{') => self.parse_counted_quantifier(atom),
456            _ => Ok(atom),
457        }
458    }
459
460    /// Parse `{n}`, `{n,}`, or `{n,m}` quantifier.
461    fn parse_counted_quantifier(
462        &mut self,
463        atom: RegexNode,
464    ) -> Result<RegexNode, RegexCompileError> {
465        self.pos += 1; // consume '{'
466        let n = self.parse_decimal_number()?;
467        match self.peek() {
468            Some(b'}') => {
469                self.pos += 1;
470                if n > MAX_REPETITION {
471                    return Err(RegexCompileError::DepthExceeded {
472                        limit: MAX_REPETITION,
473                    });
474                }
475                Ok(RegexNode::CountedExact(Box::new(atom), n))
476            }
477            Some(b',') => {
478                self.pos += 1; // consume ','
479                match self.peek() {
480                    Some(b'}') => {
481                        self.pos += 1;
482                        if n > MAX_REPETITION {
483                            return Err(RegexCompileError::DepthExceeded {
484                                limit: MAX_REPETITION,
485                            });
486                        }
487                        Ok(RegexNode::CountedRange(Box::new(atom), n, None))
488                    }
489                    _ => {
490                        let m = self.parse_decimal_number()?;
491                        self.expect(b'}')?;
492                        if n > m {
493                            return Err(RegexCompileError::InvalidSyntax(format!(
494                                "{{n,m}} quantifier has n={n} > m={m}"
495                            )));
496                        }
497                        if m > MAX_REPETITION {
498                            return Err(RegexCompileError::DepthExceeded {
499                                limit: MAX_REPETITION,
500                            });
501                        }
502                        Ok(RegexNode::CountedRange(Box::new(atom), n, Some(m)))
503                    }
504                }
505            }
506            Some(other) => Err(RegexCompileError::InvalidSyntax(format!(
507                "unexpected '{other}' inside {{}} quantifier at position {}",
508                self.pos
509            ))),
510            None => Err(RegexCompileError::InvalidSyntax(
511                "unterminated '{' quantifier".to_string(),
512            )),
513        }
514    }
515
516    /// Parse a decimal integer.
517    fn parse_decimal_number(&mut self) -> Result<usize, RegexCompileError> {
518        let start = self.pos;
519        while matches!(self.peek(), Some(b'0'..=b'9')) {
520            self.pos += 1;
521        }
522        if self.pos == start {
523            return Err(RegexCompileError::InvalidSyntax(format!(
524                "expected decimal number at position {start}"
525            )));
526        }
527        let digits = &self.input[start..self.pos];
528        // SAFETY: we verified all bytes are ASCII digits.
529        let s = std::str::from_utf8(digits)
530            .map_err(|e| RegexCompileError::InvalidUtf8(format!("non-UTF8 in decimal: {e}")))?;
531        s.parse::<usize>().map_err(|e| {
532            RegexCompileError::InvalidSyntax(format!("overflow in decimal number: {e}"))
533        })
534    }
535
536    /// Parse a single atom: literal, escape, class, group, or anchor.
537    fn parse_atom(&mut self) -> Result<RegexNode, RegexCompileError> {
538        match self.peek() {
539            Some(b'^') => {
540                // Anchor: silently ignore.
541                self.pos += 1;
542                Ok(RegexNode::Empty)
543            }
544            Some(b'$') => {
545                // Anchor: silently ignore.
546                self.pos += 1;
547                Ok(RegexNode::Empty)
548            }
549            Some(b'.') => {
550                self.pos += 1;
551                Ok(RegexNode::ByteClass(ByteSet::any_except_newline()))
552            }
553            Some(b'[') => {
554                self.pos += 1;
555                self.parse_char_class()
556            }
557            Some(b'(') => {
558                self.pos += 1;
559                self.parse_group()
560            }
561            Some(b'\\') => {
562                self.pos += 1;
563                self.parse_escape()
564            }
565            Some(b) => {
566                self.pos += 1;
567                let mut set = ByteSet::empty();
568                set.insert(b);
569                Ok(RegexNode::ByteClass(set))
570            }
571            None => Err(RegexCompileError::InvalidSyntax(
572                "unexpected end of pattern in atom".to_string(),
573            )),
574        }
575    }
576
577    /// Parse a character class `[...]`.
578    fn parse_char_class(&mut self) -> Result<RegexNode, RegexCompileError> {
579        let mut set = ByteSet::empty();
580        let negated = if self.peek() == Some(b'^') {
581            self.pos += 1;
582            true
583        } else {
584            false
585        };
586
587        // First char can be `]` without closing (treated as literal `]`).
588        let mut first = true;
589        loop {
590            match self.peek() {
591                None => {
592                    return Err(RegexCompileError::InvalidSyntax(
593                        "unterminated character class '['".to_string(),
594                    ));
595                }
596                Some(b']') if !first => {
597                    self.pos += 1;
598                    break;
599                }
600                Some(b'\\') => {
601                    self.pos += 1;
602                    let escaped_set = self.parse_escape_to_set()?;
603                    set = set.union(&escaped_set);
604                }
605                Some(b) => {
606                    self.pos += 1;
607                    // Check for range `x-y`.
608                    if self.peek() == Some(b'-') && self.input.get(self.pos + 1) != Some(&b']') {
609                        self.pos += 1; // consume '-'
610                        match self.peek() {
611                            Some(end_b) => {
612                                self.pos += 1;
613                                if end_b < b {
614                                    return Err(RegexCompileError::InvalidSyntax(format!(
615                                        "character class range end '{end_b}' < start '{b}'"
616                                    )));
617                                }
618                                for c in b..=end_b {
619                                    set.insert(c);
620                                }
621                            }
622                            None => {
623                                return Err(RegexCompileError::InvalidSyntax(
624                                    "unterminated character class range".to_string(),
625                                ));
626                            }
627                        }
628                    } else {
629                        set.insert(b);
630                    }
631                }
632            }
633            first = false;
634        }
635
636        if negated {
637            set = set.complement();
638        }
639
640        Ok(RegexNode::ByteClass(set))
641    }
642
643    /// Parse a group `(...)`.
644    fn parse_group(&mut self) -> Result<RegexNode, RegexCompileError> {
645        // Check for special group prefixes.
646        if self.peek() == Some(b'?') {
647            // Look ahead to determine the kind.
648            match self.input.get(self.pos + 1) {
649                Some(b'=') | Some(b'!') => {
650                    return Err(RegexCompileError::UnsupportedFeature(
651                        "lookahead assertions (?=...) and (?!...) are not supported".to_string(),
652                    ));
653                }
654                Some(b'<') => {
655                    // Could be lookbehind (?<=...) / (?<!...) or named group (?<name>...).
656                    match self.input.get(self.pos + 2) {
657                        Some(b'=') | Some(b'!') => {
658                            return Err(RegexCompileError::UnsupportedFeature(
659                                "lookbehind assertions (?<=...) and (?<!...) are not supported"
660                                    .to_string(),
661                            ));
662                        }
663                        _ => {
664                            return Err(RegexCompileError::UnsupportedFeature(
665                                "named groups (?<name>...) are not supported".to_string(),
666                            ));
667                        }
668                    }
669                }
670                Some(b'P') => {
671                    return Err(RegexCompileError::UnsupportedFeature(
672                        "named groups (?P<name>...) are not supported".to_string(),
673                    ));
674                }
675                Some(b':') => {
676                    // Non-capturing group `(?:...)` — consume the `?:` prefix and proceed.
677                    self.pos += 2;
678                }
679                _ => {
680                    return Err(RegexCompileError::UnsupportedFeature(format!(
681                        "unsupported group type starting with '(?{}' at position {}",
682                        self.input
683                            .get(self.pos + 1)
684                            .map(|&b| b as char)
685                            .unwrap_or('?'),
686                        self.pos
687                    )));
688                }
689            }
690        }
691
692        let inner = self.parse_alternation()?;
693        self.expect(b')')?;
694        Ok(inner)
695    }
696
697    /// Parse an escape sequence at the current position (after consuming `\`).
698    fn parse_escape(&mut self) -> Result<RegexNode, RegexCompileError> {
699        let set = self.parse_escape_to_set()?;
700        Ok(RegexNode::ByteClass(set))
701    }
702
703    /// Parse an escape sequence and return its ByteSet.
704    fn parse_escape_to_set(&mut self) -> Result<ByteSet, RegexCompileError> {
705        match self.advance() {
706            None => Err(RegexCompileError::InvalidSyntax(
707                "trailing backslash in pattern".to_string(),
708            )),
709            Some(b'd') => Ok(digit_set()),
710            Some(b'D') => Ok(digit_set().complement()),
711            Some(b'w') => Ok(word_set()),
712            Some(b'W') => Ok(word_set().complement()),
713            Some(b's') => Ok(space_set()),
714            Some(b'S') => Ok(space_set().complement()),
715            Some(b'n') => {
716                let mut s = ByteSet::empty();
717                s.insert(b'\n');
718                Ok(s)
719            }
720            Some(b'r') => {
721                let mut s = ByteSet::empty();
722                s.insert(b'\r');
723                Ok(s)
724            }
725            Some(b't') => {
726                let mut s = ByteSet::empty();
727                s.insert(b'\t');
728                Ok(s)
729            }
730            Some(b) if is_meta_escapable(b) => {
731                let mut s = ByteSet::empty();
732                s.insert(b);
733                Ok(s)
734            }
735            Some(b'1'..=b'9') => Err(RegexCompileError::UnsupportedFeature(
736                "backreferences (\\1, \\2, ...) are not supported".to_string(),
737            )),
738            Some(b'p') => Err(RegexCompileError::UnsupportedFeature(
739                "Unicode properties (\\p{...}) are not supported".to_string(),
740            )),
741            Some(other) => Err(RegexCompileError::InvalidSyntax(format!(
742                "unknown escape sequence '\\{}'",
743                other as char
744            ))),
745        }
746    }
747}
748
749// ─────────────────────────────────────────────────────────────────────────────
750// Character class helpers
751// ─────────────────────────────────────────────────────────────────────────────
752
753/// `\d`: ASCII decimal digits 0–9.
754fn digit_set() -> ByteSet {
755    let mut s = ByteSet::empty();
756    for b in b'0'..=b'9' {
757        s.insert(b);
758    }
759    s
760}
761
762/// `\w`: word characters: `[A-Za-z0-9_]`.
763fn word_set() -> ByteSet {
764    let mut s = ByteSet::empty();
765    for b in b'A'..=b'Z' {
766        s.insert(b);
767    }
768    for b in b'a'..=b'z' {
769        s.insert(b);
770    }
771    for b in b'0'..=b'9' {
772        s.insert(b);
773    }
774    s.insert(b'_');
775    s
776}
777
778/// `\s`: whitespace characters: space, `\t`, `\n`, `\r`, `\x0B`, `\x0C`.
779fn space_set() -> ByteSet {
780    let mut s = ByteSet::empty();
781    s.insert(b' ');
782    s.insert(b'\t');
783    s.insert(b'\n');
784    s.insert(b'\r');
785    s.insert(0x0B); // vertical tab
786    s.insert(0x0C); // form feed
787    s
788}
789
/// Return `true` if a backslash followed by byte `b` is an identity escape,
/// i.e. a valid escape that may stand for the literal byte `b`.
///
/// Any ASCII punctuation character may be escaped, as in PCRE and Python. This
/// covers every regex metacharacter (`. \ [ ] ( ) * + ? { } | ^ $`) and also
/// characters that commonly appear escaped in real-world patterns, such as `-`,
/// `/`, `"` and `#` (e.g. `[A-Za-z0-9\-_]`).
///
/// Exceptions:
/// - `<` and `>` are excluded because `\<` / `\>` are word-boundary assertions
///   in several engines (GNU regex, Rust `regex`). Treating them as literals
///   would silently change the pattern's meaning.
/// - `0` is included so that `\0` is never rejected as a backreference
///   (`\1`–`\9`). Callers that give `\0` its own meaning must check for it
///   before consulting this function.
///
/// Letters, the digits `1`–`9`, whitespace and non-ASCII bytes return `false`.
fn is_meta_escapable(b: u8) -> bool {
    b == b'0' || (b.is_ascii_punctuation() && !matches!(b, b'<' | b'>'))
}
810
811// ─────────────────────────────────────────────────────────────────────────────
812// Thompson NFA construction from RegexNode
813// ─────────────────────────────────────────────────────────────────────────────
814
815/// Recursively build an NFA fragment for the given [`RegexNode`].
816fn build_nfa_frag(nfa: &mut Nfa, node: &RegexNode) -> Result<NfaFrag, RegexCompileError> {
817    match node {
818        RegexNode::Empty => {
819            // Empty match: single state that is both start and end.
820            let s = nfa.alloc()?;
821            Ok(NfaFrag { start: s, end: s })
822        }
823        RegexNode::ByteClass(set) => nfa.build_byte_set(set.clone()),
824        RegexNode::Concat(nodes) => {
825            if nodes.is_empty() {
826                let s = nfa.alloc()?;
827                return Ok(NfaFrag { start: s, end: s });
828            }
829            let mut frag = build_nfa_frag(nfa, &nodes[0])?;
830            for node in &nodes[1..] {
831                let next = build_nfa_frag(nfa, node)?;
832                frag = nfa.build_concat(frag, next);
833            }
834            Ok(frag)
835        }
836        RegexNode::Alternation(nodes) => {
837            if nodes.is_empty() {
838                let s = nfa.alloc()?;
839                return Ok(NfaFrag { start: s, end: s });
840            }
841            let mut frag = build_nfa_frag(nfa, &nodes[0])?;
842            for node in &nodes[1..] {
843                let next = build_nfa_frag(nfa, node)?;
844                frag = nfa.build_alternation(frag, next)?;
845            }
846            Ok(frag)
847        }
848        RegexNode::Star(inner) => {
849            let inner_frag = build_nfa_frag(nfa, inner)?;
850            nfa.build_star(inner_frag)
851        }
852        RegexNode::Plus(inner) => {
853            let inner_frag = build_nfa_frag(nfa, inner)?;
854            nfa.build_plus(inner_frag)
855        }
856        RegexNode::Optional(inner) => {
857            let inner_frag = build_nfa_frag(nfa, inner)?;
858            nfa.build_optional(inner_frag)
859        }
860        RegexNode::CountedExact(inner, n) => {
861            // Expand to n concatenated copies.
862            if *n == 0 {
863                let s = nfa.alloc()?;
864                return Ok(NfaFrag { start: s, end: s });
865            }
866            let first = build_nfa_frag(nfa, inner)?;
867            let mut frag = first;
868            for _ in 1..*n {
869                let next = build_nfa_frag(nfa, inner)?;
870                frag = nfa.build_concat(frag, next);
871            }
872            Ok(frag)
873        }
874        RegexNode::CountedRange(inner, n, m_opt) => {
875            // Mandatory part: n copies.
876            // Optional part: (m-n) optional copies (or unlimited if m = None).
877            if let Some(m) = m_opt {
878                // Build n mandatory copies.
879                let mandatory = if *n == 0 {
880                    let s = nfa.alloc()?;
881                    NfaFrag { start: s, end: s }
882                } else {
883                    let first = build_nfa_frag(nfa, inner)?;
884                    let mut frag = first;
885                    for _ in 1..*n {
886                        let next = build_nfa_frag(nfa, inner)?;
887                        frag = nfa.build_concat(frag, next);
888                    }
889                    frag
890                };
891
892                // Build (m - n) optional copies.
893                if *m == *n {
894                    return Ok(mandatory);
895                }
896                let optional_count = m - n;
897                let first_opt = build_nfa_frag(nfa, inner)?;
898                let mut opt_frag = nfa.build_optional(first_opt)?;
899                for _ in 1..optional_count {
900                    let next = build_nfa_frag(nfa, inner)?;
901                    let next_opt = nfa.build_optional(next)?;
902                    opt_frag = nfa.build_concat(opt_frag, next_opt);
903                }
904                Ok(nfa.build_concat(mandatory, opt_frag))
905            } else {
906                // `{n,}`: n mandatory copies followed by `*`.
907                let mandatory = if *n == 0 {
908                    let s = nfa.alloc()?;
909                    NfaFrag { start: s, end: s }
910                } else {
911                    let first = build_nfa_frag(nfa, inner)?;
912                    let mut frag = first;
913                    for _ in 1..*n {
914                        let next = build_nfa_frag(nfa, inner)?;
915                        frag = nfa.build_concat(frag, next);
916                    }
917                    frag
918                };
919                let star_inner = build_nfa_frag(nfa, inner)?;
920                let star_frag = nfa.build_star(star_inner)?;
921                Ok(nfa.build_concat(mandatory, star_frag))
922            }
923        }
924    }
925}
926
927// ─────────────────────────────────────────────────────────────────────────────
928// Subset DFA construction
929// ─────────────────────────────────────────────────────────────────────────────
930
/// One state of the subset-construction DFA.
///
/// Each state corresponds to one ε-closed set of NFA states discovered by
/// [`build_dfa`]; its index in the DFA state vector is its id.
#[derive(Debug, Clone)]
struct DfaState {
    /// Outgoing edges: input byte → id (index) of the target DFA state.
    transitions: HashMap<u8, usize>,
    /// `true` when the underlying NFA state set contains the NFA accept state,
    /// i.e. input consumed so far is a complete match.
    is_accept: bool,
}
938
939/// Construct the subset DFA from the NFA.
940///
941/// Returns `(dfa_states, start_state_id)`.
942fn build_dfa(
943    nfa: &Nfa,
944    nfa_accept: usize,
945    nfa_start: usize,
946) -> Result<(Vec<DfaState>, usize), RegexCompileError> {
947    // Powerset construction: each DFA state is a frozenset (BTreeSet) of NFA state ids.
948    let start_closure = nfa.epsilon_closure([nfa_start]);
949    let start_is_accept = start_closure.contains(&nfa_accept);
950
951    let mut dfa_states: Vec<DfaState> = Vec::new();
952    // Map from NFA state set → DFA state index.
953    let mut set_to_dfa: HashMap<BTreeSet<usize>, usize> = HashMap::new();
954    let mut worklist: VecDeque<(BTreeSet<usize>, usize)> = VecDeque::new();
955
956    let start_idx = 0usize;
957    dfa_states.push(DfaState {
958        transitions: HashMap::new(),
959        is_accept: start_is_accept,
960    });
961    set_to_dfa.insert(start_closure.clone(), start_idx);
962    worklist.push_back((start_closure, start_idx));
963
964    while let Some((nfa_set, dfa_id)) = worklist.pop_front() {
965        // Collect all distinct byte transitions from this NFA state set.
966        // Build a mapping: byte → set of target NFA states.
967        let mut byte_targets: HashMap<u8, BTreeSet<usize>> = HashMap::new();
968
969        for &nfa_state in &nfa_set {
970            for (label, target) in &nfa.states[nfa_state].transitions {
971                for b in label.iter() {
972                    byte_targets.entry(b).or_default().insert(*target);
973                }
974            }
975        }
976
977        // For each unique byte b, compute ε-closure of the target set.
978        for (b, targets) in byte_targets {
979            let closure = nfa.epsilon_closure(targets);
980            if closure.is_empty() {
981                continue;
982            }
983
984            let next_dfa_id = if let Some(&existing) = set_to_dfa.get(&closure) {
985                existing
986            } else {
987                // Allocate new DFA state.
988                if dfa_states.len() >= MAX_DFA_STATES {
989                    return Err(RegexCompileError::DepthExceeded {
990                        limit: MAX_DFA_STATES,
991                    });
992                }
993                let new_id = dfa_states.len();
994                let is_accept = closure.contains(&nfa_accept);
995                dfa_states.push(DfaState {
996                    transitions: HashMap::new(),
997                    is_accept,
998                });
999                set_to_dfa.insert(closure.clone(), new_id);
1000                worklist.push_back((closure, new_id));
1001                new_id
1002            };
1003
1004            dfa_states[dfa_id].transitions.insert(b, next_dfa_id);
1005        }
1006    }
1007
1008    Ok((dfa_states, start_idx))
1009}
1010
1011// ─────────────────────────────────────────────────────────────────────────────
1012// DFA → Grammar
1013// ─────────────────────────────────────────────────────────────────────────────
1014
1015/// Convert a DFA into a [`Grammar`].
1016///
1017/// Each DFA state maps to a non-terminal `__regex_s{i}`.
1018/// For each transition `s →[b]→ t`, we emit:
1019///   `<__regex_s{s}> ::= Terminal([b]) <__regex_s{t}>`
1020/// For each accept state `s`, we emit an ε-production:
1021///   `<__regex_s{s}> ::=`
1022fn dfa_to_grammar(dfa_states: &[DfaState], start_idx: usize) -> Result<Grammar, RegexCompileError> {
1023    let num_states = dfa_states.len();
1024
1025    // We must pre-allocate all NT ids first, then set the start.
1026    // Grammar::new(start) takes a start id, but we build the NTs via alloc_nt.
1027    // To work around this: create grammar with a placeholder start=0, then
1028    // alloc NTs (which assigns ids 0, 1, ..., num_states-1), then set start.
1029    let mut grammar = Grammar::new(0);
1030
1031    // Allocate one NT per DFA state.
1032    let mut nt_ids: Vec<NonTerminalId> = Vec::with_capacity(num_states);
1033    for i in 0..num_states {
1034        let nt = grammar.alloc_nt(format!("__regex_s{i}"));
1035        nt_ids.push(nt);
1036    }
1037
1038    // Set the actual start symbol.
1039    grammar.start = nt_ids[start_idx];
1040
1041    // Emit rules.
1042    for (state_idx, dfa_state) in dfa_states.iter().enumerate() {
1043        let lhs_nt = nt_ids[state_idx];
1044
1045        // ε-production for accept states.
1046        if dfa_state.is_accept {
1047            grammar.add_rule(Rule::new(lhs_nt, vec![]));
1048        }
1049
1050        // Byte transition rules.
1051        // Group transitions by target state to consolidate, but since Grammar
1052        // only supports single-byte terminals, we emit one rule per byte.
1053        for (&byte_val, &target_idx) in &dfa_state.transitions {
1054            let target_nt = nt_ids[target_idx];
1055            grammar.add_rule(Rule::new(
1056                lhs_nt,
1057                vec![
1058                    Symbol::Terminal(vec![byte_val]),
1059                    Symbol::NonTerminal(target_nt),
1060                ],
1061            ));
1062        }
1063    }
1064
1065    Ok(grammar)
1066}
1067
1068// ─────────────────────────────────────────────────────────────────────────────
1069// Public API
1070// ─────────────────────────────────────────────────────────────────────────────
1071
1072/// Compile a regex pattern string into a [`Grammar`].
1073///
1074/// The returned grammar is ready to be passed to
1075/// [`GrammarConstraint::new`](super::constraint::GrammarConstraint::new) for
1076/// constrained token generation with the given regex.
1077///
1078/// # Errors
1079///
1080/// - [`RegexCompileError::EmptyPattern`] — the pattern is the empty string
1081/// - [`RegexCompileError::InvalidSyntax`] — the regex has a syntax error
1082/// - [`RegexCompileError::UnsupportedFeature`] — a feature not supported by
1083///   this compiler was used (backreferences, lookahead, etc.)
1084/// - [`RegexCompileError::DepthExceeded`] — the DFA exceeded 2048 states or
1085///   a counted quantifier exceeded 64
1086/// - [`RegexCompileError::InvalidUtf8`] — internal (should not occur for
1087///   well-formed Rust strings)
1088///
1089/// # Example
1090///
1091/// ```rust
1092/// use oxibonsai_runtime::grammar::compile_regex;
1093///
1094/// let grammar = compile_regex(r"\d{4}-\d{2}-\d{2}").expect("valid regex");
1095/// assert!(!grammar.rules.is_empty());
1096/// ```
1097pub fn compile_regex(pattern: &str) -> Result<Grammar, RegexCompileError> {
1098    if pattern.is_empty() {
1099        return Err(RegexCompileError::EmptyPattern);
1100    }
1101
1102    // ── Step 1: Parse regex → AST ────────────────────────────────────────────
1103    let mut parser = RegexParser::new(pattern);
1104    let ast = parser.parse_alternation()?;
1105    if !parser.is_at_end() {
1106        return Err(RegexCompileError::InvalidSyntax(format!(
1107            "unexpected character '{}' at position {} (unmatched ')'?)",
1108            parser.input[parser.pos] as char, parser.pos
1109        )));
1110    }
1111
1112    // ── Step 2: Build Thompson NFA ────────────────────────────────────────────
1113    let mut nfa = Nfa::new();
1114    let frag = build_nfa_frag(&mut nfa, &ast)?;
1115
1116    // Mark the NFA accept state.
1117    nfa.states[frag.end].is_accept = true;
1118
1119    // ── Step 3: Subset DFA construction ──────────────────────────────────────
1120    let (dfa_states, start_idx) = build_dfa(&nfa, frag.end, frag.start)?;
1121
1122    // ── Step 4: DFA → Grammar ────────────────────────────────────────────────
1123    let grammar = dfa_to_grammar(&dfa_states, start_idx)?;
1124
1125    Ok(grammar)
1126}
1127
1128// ─────────────────────────────────────────────────────────────────────────────
1129// Internal unit tests
1130// ─────────────────────────────────────────────────────────────────────────────
1131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn byte_set_insert_contains() {
        let mut s = ByteSet::empty();
        s.insert(b'A');
        assert!(s.contains(b'A'));
        assert!(!s.contains(b'B'));
    }

    #[test]
    fn byte_set_empty_has_no_members() {
        let s = ByteSet::empty();
        assert_eq!(s.iter().count(), 0);
        assert!(!s.contains(0));
        assert!(!s.contains(255));
    }

    #[test]
    fn byte_set_complement() {
        let mut s = ByteSet::empty();
        s.insert(b'a');
        let c = s.complement();
        assert!(!c.contains(b'a'));
        assert!(c.contains(b'b'));
    }

    #[test]
    fn byte_set_union() {
        let mut a = ByteSet::empty();
        a.insert(b'x');
        let mut b = ByteSet::empty();
        b.insert(b'y');
        let u = a.union(&b);
        assert!(u.contains(b'x'));
        assert!(u.contains(b'y'));
        assert!(!u.contains(b'z'));
    }

    #[test]
    fn byte_set_any_except_newline_has_255_bytes() {
        let s = ByteSet::any_except_newline();
        let count = s.iter().count();
        assert_eq!(count, 255);
        assert!(!s.contains(b'\n'));
    }

    #[test]
    fn digit_set_is_ten_bytes() {
        let s = digit_set();
        let count = s.iter().count();
        assert_eq!(count, 10);
        for d in b'0'..=b'9' {
            assert!(s.contains(d));
        }
    }

    #[test]
    fn word_set_contains_alnum_underscore() {
        let s = word_set();
        assert!(s.contains(b'A'));
        assert!(s.contains(b'z'));
        assert!(s.contains(b'5'));
        assert!(s.contains(b'_'));
        assert!(!s.contains(b'!'));
        assert!(!s.contains(b' '));
    }

    #[test]
    fn space_set_contains_whitespace() {
        let s = space_set();
        assert!(s.contains(b' '));
        assert!(s.contains(b'\t'));
        assert!(s.contains(b'\n'));
        assert!(s.contains(b'\r'));
        assert!(!s.contains(b'a'));
    }

    #[test]
    fn parser_literal_parses() {
        let mut p = RegexParser::new("abc");
        let node = p.parse_alternation().unwrap();
        assert!(matches!(node, RegexNode::Concat(_)));
    }

    #[test]
    fn parser_alternation_parses() {
        let mut p = RegexParser::new("a|b");
        let node = p.parse_alternation().unwrap();
        assert!(matches!(node, RegexNode::Alternation(_)));
    }

    #[test]
    fn parser_counted_exact_parses() {
        let mut p = RegexParser::new("a{3}");
        let node = p.parse_alternation().unwrap();
        assert!(matches!(node, RegexNode::CountedExact(_, 3)));
    }

    #[test]
    fn parser_counted_range_parses() {
        let mut p = RegexParser::new("a{2,5}");
        let node = p.parse_alternation().unwrap();
        assert!(matches!(node, RegexNode::CountedRange(_, 2, Some(5))));
    }

    #[test]
    fn parser_unmatched_paren_fails() {
        let result = compile_regex("(abc");
        assert!(matches!(result, Err(RegexCompileError::InvalidSyntax(_))));
    }

    #[test]
    fn compile_empty_pattern_is_rejected() {
        let result = compile_regex("");
        assert!(matches!(result, Err(RegexCompileError::EmptyPattern)));
    }

    #[test]
    fn compile_stray_close_paren_is_rejected() {
        let result = compile_regex("ab)");
        assert!(matches!(result, Err(RegexCompileError::InvalidSyntax(_))));
    }

    #[test]
    fn compile_date_pattern_produces_rules() {
        let grammar = compile_regex(r"\d{4}-\d{2}-\d{2}").expect("valid regex");
        assert!(!grammar.rules.is_empty());
    }
}