Skip to main content

oxibonsai_runtime/constrained_decoding/
regex.rs

1//! Regex-based [`TokenConstraint`] implementation backed by a minimal NFA
2//! engine.
3//!
4//! This sub-module hosts the NFA compiler/simulator (`NfaState`, `RegexNfa`,
5//! `Fragment`) plus the public [`RegexConstraint`] type.
6
7use super::error_trait::{ConstraintError, TokenConstraint};
8
9// ─────────────────────────────────────────────────────────────────────────────
10// Minimal NFA-based regex engine
11// ─────────────────────────────────────────────────────────────────────────────
12
13/// One NFA state.
14#[derive(Debug, Clone)]
15pub(super) enum NfaState {
16    /// Matches a specific character then transitions to `next`.
17    Literal(char, usize),
18    /// Matches any character then transitions to `next`.
19    Any(usize),
20    /// ε-transition fork (used for `|`, `?`, `*`, `+`).
21    Split(usize, usize),
22    /// Character class `[...]`.  `negated` inverts the match.
23    Class {
24        chars: Vec<char>,
25        ranges: Vec<(char, char)>,
26        negated: bool,
27        next: usize,
28    },
29    /// The accepting state.
30    Accept,
31}
32
33/// Simple NFA compiled from a regex pattern.
34#[derive(Debug, Clone)]
35pub(super) struct RegexNfa {
36    states: Vec<NfaState>,
37    start: usize,
38    accept_state: usize,
39}
40
41/// A fragment of NFA states returned by the compiler — holds start index and
42/// a list of "dangling" out-arrows that must be patched to the next fragment.
43pub(super) struct Fragment {
44    start: usize,
45    /// Indices of states whose outgoing arrow is "open" (needs patching).
46    outs: Vec<usize>,
47}
48
49impl RegexNfa {
50    /// Build an NFA from a regex pattern.
51    pub(super) fn from_pattern(pattern: &str) -> Result<Self, ConstraintError> {
52        let mut nfa = RegexNfa {
53            states: Vec::new(),
54            start: 0,
55            accept_state: 0,
56        };
57        let chars: Vec<char> = pattern.chars().collect();
58        let frag = nfa
59            .compile(&chars, 0)
60            .map_err(ConstraintError::InvalidPattern)?;
61        // Add accept state.
62        let accept = nfa.push(NfaState::Accept);
63        nfa.accept_state = accept;
64        nfa.patch(&frag.outs, accept);
65        nfa.start = frag.start;
66        Ok(nfa)
67    }
68
69    fn push(&mut self, state: NfaState) -> usize {
70        let idx = self.states.len();
71        self.states.push(state);
72        idx
73    }
74
75    /// Patch all dangling out-arrows in `outs` to point to `target`.
76    fn patch(&mut self, outs: &[usize], target: usize) {
77        for &idx in outs {
78            match &mut self.states[idx] {
79                NfaState::Literal(_, ref mut n)
80                | NfaState::Any(ref mut n)
81                | NfaState::Class {
82                    next: ref mut n, ..
83                } => *n = target,
84                NfaState::Split(ref mut a, ref mut b) => {
85                    // Patch every open slot (usize::MAX means "unset").
86                    if *a == usize::MAX {
87                        *a = target;
88                    }
89                    if *b == usize::MAX {
90                        *b = target;
91                    }
92                }
93                NfaState::Accept => {}
94            }
95        }
96    }
97
98    /// Recursive-descent compiler; returns a Fragment.
99    fn compile(&mut self, chars: &[char], mut pos: usize) -> Result<Fragment, String> {
100        // Parse a sequence of alternation alternatives: e1 | e2 | ...
101        let mut alt_frags: Vec<Fragment> = Vec::new();
102        let mut cur_frags: Vec<Fragment> = Vec::new();
103
104        while pos < chars.len() {
105            let ch = chars[pos];
106
107            // Handle alternation `|`
108            if ch == '|' {
109                let seq = Self::concat_fragments(&mut self.states, cur_frags);
110                alt_frags.push(seq);
111                cur_frags = Vec::new();
112                pos += 1;
113                continue;
114            }
115
116            // End of group
117            if ch == ')' {
118                break;
119            }
120
121            // Parse one atom (possibly followed by a quantifier)
122            let (atom, new_pos) = self.parse_atom(chars, pos)?;
123            pos = new_pos;
124
125            // Check for quantifier
126            let quantified = if pos < chars.len() {
127                match chars[pos] {
128                    '?' => {
129                        pos += 1;
130                        self.quantifier_optional(atom)
131                    }
132                    '*' => {
133                        pos += 1;
134                        self.quantifier_star(atom)
135                    }
136                    '+' => {
137                        pos += 1;
138                        self.quantifier_plus(atom)
139                    }
140                    _ => atom,
141                }
142            } else {
143                atom
144            };
145
146            cur_frags.push(quantified);
147        }
148
149        // Concatenate remaining sequence
150        let seq = Self::concat_fragments(&mut self.states, cur_frags);
151        alt_frags.push(seq);
152
153        // Build alternation if needed
154        let result = if alt_frags.len() == 1 {
155            alt_frags.remove(0)
156        } else {
157            self.alternation(alt_frags)
158        };
159
160        Ok(result)
161    }
162
163    /// Parse one atom starting at `pos`, return (Fragment, new_pos).
164    fn parse_atom(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
165        if pos >= chars.len() {
166            return Err("Unexpected end of pattern".to_string());
167        }
168        let ch = chars[pos];
169        match ch {
170            '(' => {
171                // Grouped sub-expression
172                let inner = self.compile(chars, pos + 1)?;
173                // Find matching ')'
174                let mut depth = 1usize;
175                let mut i = pos + 1;
176                while i < chars.len() {
177                    match chars[i] {
178                        '(' => depth += 1,
179                        ')' => {
180                            depth -= 1;
181                            if depth == 0 {
182                                break;
183                            }
184                        }
185                        '\\' => {
186                            i += 1;
187                        } // skip escaped
188                        _ => {}
189                    }
190                    i += 1;
191                }
192                let new_pos = if i < chars.len() && chars[i] == ')' {
193                    i + 1
194                } else {
195                    i
196                };
197                Ok((inner, new_pos))
198            }
199            '[' => {
200                let (frag, new_pos) = self.parse_class(chars, pos)?;
201                Ok((frag, new_pos))
202            }
203            '.' => {
204                let idx = self.push(NfaState::Any(usize::MAX));
205                Ok((
206                    Fragment {
207                        start: idx,
208                        outs: vec![idx],
209                    },
210                    pos + 1,
211                ))
212            }
213            '\\' => {
214                let (frag, new_pos) = self.parse_escape(chars, pos)?;
215                Ok((frag, new_pos))
216            }
217            _ if ch == '*' || ch == '+' || ch == '?' => {
218                Err(format!("Unexpected quantifier '{ch}' at position {pos}"))
219            }
220            _ => {
221                let idx = self.push(NfaState::Literal(ch, usize::MAX));
222                Ok((
223                    Fragment {
224                        start: idx,
225                        outs: vec![idx],
226                    },
227                    pos + 1,
228                ))
229            }
230        }
231    }
232
233    /// Parse a character class `[...]`.
234    fn parse_class(&mut self, chars: &[char], start: usize) -> Result<(Fragment, usize), String> {
235        // start points to '['
236        let mut pos = start + 1;
237        let negated = if pos < chars.len() && chars[pos] == '^' {
238            pos += 1;
239            true
240        } else {
241            false
242        };
243
244        let mut class_chars: Vec<char> = Vec::new();
245        let mut ranges: Vec<(char, char)> = Vec::new();
246
247        while pos < chars.len() && chars[pos] != ']' {
248            if chars[pos] == '\\' && pos + 1 < chars.len() {
249                // Escape inside class
250                let escaped = chars[pos + 1];
251                match escaped {
252                    'd' => ranges.push(('0', '9')),
253                    'w' => {
254                        ranges.push(('a', 'z'));
255                        ranges.push(('A', 'Z'));
256                        ranges.push(('0', '9'));
257                        class_chars.push('_');
258                    }
259                    's' => {
260                        class_chars.extend_from_slice(&[' ', '\t', '\n', '\r']);
261                    }
262                    _ => class_chars.push(escaped),
263                }
264                pos += 2;
265            } else if pos + 2 < chars.len() && chars[pos + 1] == '-' && chars[pos + 2] != ']' {
266                ranges.push((chars[pos], chars[pos + 2]));
267                pos += 3;
268            } else {
269                class_chars.push(chars[pos]);
270                pos += 1;
271            }
272        }
273
274        let new_pos = if pos < chars.len() && chars[pos] == ']' {
275            pos + 1
276        } else {
277            pos
278        };
279
280        let idx = self.push(NfaState::Class {
281            chars: class_chars,
282            ranges,
283            negated,
284            next: usize::MAX,
285        });
286        Ok((
287            Fragment {
288                start: idx,
289                outs: vec![idx],
290            },
291            new_pos,
292        ))
293    }
294
295    /// Parse a backslash escape at `pos` (e.g., `\d`, `\w`, `\s`).
296    fn parse_escape(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
297        if pos + 1 >= chars.len() {
298            return Err("Trailing backslash in pattern".to_string());
299        }
300        let escaped = chars[pos + 1];
301        let (class_chars, ranges): (Vec<char>, Vec<(char, char)>) = match escaped {
302            'd' => (vec![], vec![('0', '9')]),
303            'D' => {
304                // non-digit — represented as negated class [^0-9]
305                let idx = self.push(NfaState::Class {
306                    chars: vec![],
307                    ranges: vec![('0', '9')],
308                    negated: true,
309                    next: usize::MAX,
310                });
311                return Ok((
312                    Fragment {
313                        start: idx,
314                        outs: vec![idx],
315                    },
316                    pos + 2,
317                ));
318            }
319            'w' => (vec!['_'], vec![('a', 'z'), ('A', 'Z'), ('0', '9')]),
320            'W' => {
321                let idx = self.push(NfaState::Class {
322                    chars: vec!['_'],
323                    ranges: vec![('a', 'z'), ('A', 'Z'), ('0', '9')],
324                    negated: true,
325                    next: usize::MAX,
326                });
327                return Ok((
328                    Fragment {
329                        start: idx,
330                        outs: vec![idx],
331                    },
332                    pos + 2,
333                ));
334            }
335            's' => (vec![' ', '\t', '\n', '\r'], vec![]),
336            'S' => {
337                let idx = self.push(NfaState::Class {
338                    chars: vec![' ', '\t', '\n', '\r'],
339                    ranges: vec![],
340                    negated: true,
341                    next: usize::MAX,
342                });
343                return Ok((
344                    Fragment {
345                        start: idx,
346                        outs: vec![idx],
347                    },
348                    pos + 2,
349                ));
350            }
351            'n' => {
352                let idx = self.push(NfaState::Literal('\n', usize::MAX));
353                return Ok((
354                    Fragment {
355                        start: idx,
356                        outs: vec![idx],
357                    },
358                    pos + 2,
359                ));
360            }
361            'r' => {
362                let idx = self.push(NfaState::Literal('\r', usize::MAX));
363                return Ok((
364                    Fragment {
365                        start: idx,
366                        outs: vec![idx],
367                    },
368                    pos + 2,
369                ));
370            }
371            't' => {
372                let idx = self.push(NfaState::Literal('\t', usize::MAX));
373                return Ok((
374                    Fragment {
375                        start: idx,
376                        outs: vec![idx],
377                    },
378                    pos + 2,
379                ));
380            }
381            _ => {
382                // Treat as literal escape (e.g., `\.`)
383                let idx = self.push(NfaState::Literal(escaped, usize::MAX));
384                return Ok((
385                    Fragment {
386                        start: idx,
387                        outs: vec![idx],
388                    },
389                    pos + 2,
390                ));
391            }
392        };
393        let idx = self.push(NfaState::Class {
394            chars: class_chars,
395            ranges,
396            negated: false,
397            next: usize::MAX,
398        });
399        Ok((
400            Fragment {
401                start: idx,
402                outs: vec![idx],
403            },
404            pos + 2,
405        ))
406    }
407
408    // ── Quantifiers ──────────────────────────────────────────────────────────
409
410    /// `e?` — zero or one.
411    fn quantifier_optional(&mut self, frag: Fragment) -> Fragment {
412        let split = self.push(NfaState::Split(frag.start, usize::MAX));
413        let mut outs = frag.outs;
414        outs.push(split); // the second arm of Split is still open
415        Fragment { start: split, outs }
416    }
417
418    /// `e*` — zero or more.
419    fn quantifier_star(&mut self, frag: Fragment) -> Fragment {
420        let split = self.push(NfaState::Split(frag.start, usize::MAX));
421        // Patch all fragment outs back to the split (loop).
422        self.patch(&frag.outs, split);
423        Fragment {
424            start: split,
425            outs: vec![split],
426        }
427    }
428
429    /// `e+` — one or more.
430    fn quantifier_plus(&mut self, frag: Fragment) -> Fragment {
431        let split = self.push(NfaState::Split(frag.start, usize::MAX));
432        self.patch(&frag.outs, split);
433        Fragment {
434            start: frag.start,
435            outs: vec![split],
436        }
437    }
438
439    /// Build alternation from multiple fragments (`e1 | e2 | ...`).
440    fn alternation(&mut self, frags: Vec<Fragment>) -> Fragment {
441        if frags.is_empty() {
442            let split = self.push(NfaState::Split(usize::MAX, usize::MAX));
443            return Fragment {
444                start: split,
445                outs: vec![split],
446            };
447        }
448        let mut iter = frags.into_iter();
449        let mut current = iter.next().expect("non-empty checked above");
450        for next_frag in iter {
451            let split = self.push(NfaState::Split(current.start, next_frag.start));
452            let mut outs = current.outs;
453            outs.extend(next_frag.outs);
454            current = Fragment { start: split, outs };
455        }
456        current
457    }
458
459    /// Concatenate a sequence of fragments into one.
460    fn concat_fragments(states: &mut Vec<NfaState>, frags: Vec<Fragment>) -> Fragment {
461        if frags.is_empty() {
462            // ε-fragment: a split pointing nowhere used as a placeholder
463            let idx = states.len();
464            states.push(NfaState::Split(usize::MAX, usize::MAX));
465            return Fragment {
466                start: idx,
467                outs: vec![idx],
468            };
469        }
470        let mut iter = frags.into_iter();
471        let first = iter.next().expect("non-empty checked above");
472        iter.fold(first, |acc, next| {
473            // Patch all open outs of acc to point to start of next
474            for &idx in &acc.outs {
475                match &mut states[idx] {
476                    NfaState::Literal(_, ref mut n)
477                    | NfaState::Any(ref mut n)
478                    | NfaState::Class {
479                        next: ref mut n, ..
480                    } => {
481                        if *n == usize::MAX {
482                            *n = next.start;
483                        }
484                    }
485                    NfaState::Split(ref mut a, ref mut b) => {
486                        if *a == usize::MAX {
487                            *a = next.start;
488                        } else if *b == usize::MAX {
489                            *b = next.start;
490                        }
491                    }
492                    NfaState::Accept => {}
493                }
494            }
495            Fragment {
496                start: acc.start,
497                outs: next.outs,
498            }
499        })
500    }
501
502    // ── Simulation ───────────────────────────────────────────────────────────
503
504    /// Compute the ε-closure of a set of states.
505    fn epsilon_closure(&self, states: Vec<usize>) -> Vec<usize> {
506        let mut closure: Vec<usize> = Vec::new();
507        let mut stack = states;
508        let mut visited = std::collections::HashSet::new();
509        while let Some(s) = stack.pop() {
510            if s == usize::MAX || !visited.insert(s) {
511                continue;
512            }
513            closure.push(s);
514            if let Some(NfaState::Split(a, b)) = self.states.get(s) {
515                if *a != usize::MAX {
516                    stack.push(*a);
517                }
518                if *b != usize::MAX {
519                    stack.push(*b);
520                }
521            }
522        }
523        closure
524    }
525
526    /// Advance the NFA by consuming character `ch` from state set `states`.
527    fn step(&self, states: &[usize], ch: char) -> Vec<usize> {
528        let mut next = Vec::new();
529        for &s in states {
530            if s == usize::MAX {
531                continue;
532            }
533            if let Some(state) = self.states.get(s) {
534                match state {
535                    NfaState::Literal(c, n) => {
536                        if *c == ch && *n != usize::MAX {
537                            next.push(*n);
538                        }
539                    }
540                    NfaState::Any(n) => {
541                        if *n != usize::MAX {
542                            next.push(*n);
543                        }
544                    }
545                    NfaState::Class {
546                        chars,
547                        ranges,
548                        negated,
549                        next: n,
550                    } => {
551                        let matched = chars.contains(&ch)
552                            || ranges.iter().any(|&(lo, hi)| ch >= lo && ch <= hi);
553                        let effective = if *negated { !matched } else { matched };
554                        if effective && *n != usize::MAX {
555                            next.push(*n);
556                        }
557                    }
558                    NfaState::Split(_, _) | NfaState::Accept => {}
559                }
560            }
561        }
562        self.epsilon_closure(next)
563    }
564
565    /// Returns `true` if any of `states` is the accept state.
566    fn is_accepting(&self, states: &[usize]) -> bool {
567        states.contains(&self.accept_state)
568    }
569
570    /// Check whether `text` is fully matched by the NFA.
571    fn is_full_match(&self, text: &str) -> bool {
572        let initial = self.epsilon_closure(vec![self.start]);
573        let final_states = text.chars().fold(initial, |s, ch| self.step(&s, ch));
574        self.is_accepting(&final_states)
575    }
576}
577
578// ─────────────────────────────────────────────────────────────────────────────
579// RegexConstraint
580// ─────────────────────────────────────────────────────────────────────────────
581
582/// Constrains generation to strings that match a regular expression.
583///
584/// Uses a minimal NFA engine (no external crate). Supported syntax:
585/// - Literals, `.` (any char), `*`, `+`, `?`
586/// - Alternation `|`
587/// - Grouping `(...)`
588/// - Character classes `[abc]`, `[a-z]`, `[^x]`
589/// - Escapes: `\d`, `\D`, `\w`, `\W`, `\s`, `\S`, `\n`, `\r`, `\t`
590pub struct RegexConstraint {
591    pattern: String,
592    nfa: RegexNfa,
593    current_states: Vec<usize>,
594    matched_so_far: String,
595}
596
597impl RegexConstraint {
598    /// Build a new constraint from `pattern`.
599    pub fn new(pattern: &str) -> Result<Self, ConstraintError> {
600        let nfa = RegexNfa::from_pattern(pattern)?;
601        let current_states = nfa.epsilon_closure(vec![nfa.start]);
602        Ok(Self {
603            pattern: pattern.to_string(),
604            nfa,
605            current_states,
606            matched_so_far: String::new(),
607        })
608    }
609
610    /// Test whether `text` fully matches `pattern`.
611    pub fn is_match(pattern: &str, text: &str) -> bool {
612        match RegexNfa::from_pattern(pattern) {
613            Ok(nfa) => nfa.is_full_match(text),
614            Err(_) => false,
615        }
616    }
617
618    /// The text matched so far.
619    pub fn current_partial(&self) -> &str {
620        &self.matched_so_far
621    }
622
623    /// Check whether character `ch` would keep the NFA in a live (non-dead) state.
624    pub fn char_is_valid(&self, ch: char) -> bool {
625        let next = self.nfa.step(&self.current_states, ch);
626        !next.is_empty()
627    }
628}
629
630impl TokenConstraint for RegexConstraint {
631    fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
632        // If already in a dead state, nothing is allowed.
633        if self.current_states.is_empty() {
634            return Some(vec![false; vocab_size]);
635        }
636        // We cannot map token ids to characters without a real vocabulary table,
637        // so we return None (allow all) as a safe conservative choice.
638        // The constraint is enforced via `advance` which rejects invalid tokens.
639        None
640    }
641
642    fn advance(&mut self, token: u32) -> bool {
643        // Treat the token id as a codepoint for demonstration purposes.
644        // In a real integration the caller would pass token bytes/text.
645        let ch = char::from_u32(token).unwrap_or('\u{FFFD}');
646        let next = self.nfa.step(&self.current_states, ch);
647        if next.is_empty() {
648            return false;
649        }
650        self.current_states = next;
651        self.matched_so_far.push(ch);
652        true
653    }
654
655    fn is_complete(&self) -> bool {
656        self.nfa.is_accepting(&self.current_states)
657    }
658
659    fn reset(&mut self) {
660        self.current_states = self.nfa.epsilon_closure(vec![self.nfa.start]);
661        self.matched_so_far.clear();
662    }
663
664    fn name(&self) -> &str {
665        &self.pattern
666    }
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    // ── RegexNfa ─────────────────────────────────────────────────────────────

    #[test]
    fn regex_nfa_literal_match() {
        let nfa = RegexNfa::from_pattern("abc").expect("valid pattern");
        assert!(nfa.is_full_match("abc"));
        assert!(!nfa.is_full_match("ab"));
        assert!(!nfa.is_full_match("abcd"));
    }

    #[test]
    fn regex_nfa_dot_match() {
        let nfa = RegexNfa::from_pattern("a.c").expect("valid pattern");
        assert!(nfa.is_full_match("abc"));
        assert!(nfa.is_full_match("axc"));
        assert!(!nfa.is_full_match("ac"));
    }

    #[test]
    fn regex_nfa_star_quantifier() {
        let nfa = RegexNfa::from_pattern("ab*c").expect("valid pattern");
        assert!(nfa.is_full_match("ac"));
        assert!(nfa.is_full_match("abc"));
        assert!(nfa.is_full_match("abbc"));
        assert!(!nfa.is_full_match("xbc"));
    }

    #[test]
    fn regex_nfa_alternation() {
        let nfa = RegexNfa::from_pattern("cat|dog").expect("valid pattern");
        assert!(nfa.is_full_match("cat"));
        assert!(nfa.is_full_match("dog"));
        assert!(!nfa.is_full_match("cow"));
    }

    #[test]
    fn regex_nfa_optional_quantifier() {
        let nfa = RegexNfa::from_pattern("ab?c").expect("valid pattern");
        assert!(nfa.is_full_match("ac"));
        assert!(nfa.is_full_match("abc"));
        assert!(!nfa.is_full_match("abbc"));
    }

    #[test]
    fn regex_nfa_character_classes() {
        let range = RegexNfa::from_pattern("[a-c]+").expect("valid pattern");
        assert!(range.is_full_match("abc"));
        assert!(!range.is_full_match(""));
        assert!(!range.is_full_match("abd"));

        let negated = RegexNfa::from_pattern("[^0-9]").expect("valid pattern");
        assert!(negated.is_full_match("x"));
        assert!(!negated.is_full_match("5"));
    }

    #[test]
    fn regex_nfa_escapes() {
        let digits = RegexNfa::from_pattern("\\d\\d").expect("valid pattern");
        assert!(digits.is_full_match("42"));
        assert!(!digits.is_full_match("4a"));

        let dot = RegexNfa::from_pattern("a\\.b").expect("valid pattern");
        assert!(dot.is_full_match("a.b"));
        assert!(!dot.is_full_match("axb"));
    }

    #[test]
    fn regex_nfa_group_with_quantifier() {
        let nfa = RegexNfa::from_pattern("(ab)+c").expect("valid pattern");
        assert!(nfa.is_full_match("abc"));
        assert!(nfa.is_full_match("ababc"));
        assert!(!nfa.is_full_match("c"));
    }

    // ── RegexConstraint ──────────────────────────────────────────────────────

    #[test]
    fn regex_constraint_is_match() {
        assert!(RegexConstraint::is_match("he+llo", "hello"));
        assert!(RegexConstraint::is_match("he+llo", "heeeello"));
        assert!(!RegexConstraint::is_match("he+llo", "hllo"));
    }

    #[test]
    fn regex_constraint_allows_valid_chars() {
        let rc = RegexConstraint::new("abc").expect("valid");
        // 'a' (97) should be valid as first char
        assert!(rc.char_is_valid('a'));
        assert!(!rc.char_is_valid('b')); // 'b' is not valid before 'a'
    }

    #[test]
    fn regex_constraint_rejects_invalid_patterns() {
        // A leading quantifier and a trailing backslash are both compile errors.
        assert!(RegexConstraint::new("*a").is_err());
        assert!(RegexConstraint::new("a\\").is_err());
        assert!(!RegexConstraint::is_match("*a", "a"));
    }

    #[test]
    fn regex_constraint_advance_and_reset() {
        let mut rc = RegexConstraint::new("ab").expect("valid");
        assert!(!rc.is_complete());

        assert!(rc.advance(u32::from('a')));
        assert_eq!(rc.current_partial(), "a");
        assert!(!rc.is_complete());

        // A rejected token must leave the constraint untouched.
        assert!(!rc.advance(u32::from('x')));
        assert_eq!(rc.current_partial(), "a");

        assert!(rc.advance(u32::from('b')));
        assert_eq!(rc.current_partial(), "ab");
        assert!(rc.is_complete());

        rc.reset();
        assert_eq!(rc.current_partial(), "");
        assert!(!rc.is_complete());
        assert_eq!(rc.name(), "ab");
    }
}