1use super::error_trait::{ConstraintError, TokenConstraint};
8
/// One state of the NFA compiled from a regex pattern.
///
/// Every `usize` payload is an index into `RegexNfa::states`. The sentinel
/// `usize::MAX` marks a transition that has not been connected yet.
#[derive(Debug, Clone)]
pub(super) enum NfaState {
 /// Consumes exactly this character, then moves to the indexed state.
 Literal(char, usize),
 /// Consumes any single character, then moves to the indexed state.
 Any(usize),
 /// Consumes no input and branches to both indexed states.
 Split(usize, usize),
 /// Consumes one character that is tested against a character class.
 Class {
 /// Individual characters listed in the class.
 chars: Vec<char>,
 /// Inclusive `(low, high)` character ranges listed in the class.
 ranges: Vec<(char, char)>,
 /// When `true`, the class matches characters that are NOT listed.
 negated: bool,
 /// Index of the state entered after a successful match.
 next: usize,
 },
 /// Terminal state; a match succeeds when this state is active.
 Accept,
}
32
/// NFA compiled from a single regex pattern.
#[derive(Debug, Clone)]
pub(super) struct RegexNfa {
 /// Arena of states; transitions refer to each other by index into this vector.
 states: Vec<NfaState>,
 /// Index of the state where matching begins.
 start: usize,
 /// Index of the `NfaState::Accept` state appended by `from_pattern`.
 accept_state: usize,
}
40
/// A partially built piece of the NFA produced while compiling a pattern.
pub(super) struct Fragment {
 /// Index of the fragment's entry state.
 start: usize,
 /// Indices of states that still have an unconnected (`usize::MAX`)
 /// transition and must be patched to whatever follows this fragment.
 outs: Vec<usize>,
}
48
49impl RegexNfa {
50 pub(super) fn from_pattern(pattern: &str) -> Result<Self, ConstraintError> {
52 let mut nfa = RegexNfa {
53 states: Vec::new(),
54 start: 0,
55 accept_state: 0,
56 };
57 let chars: Vec<char> = pattern.chars().collect();
58 let frag = nfa
59 .compile(&chars, 0)
60 .map_err(ConstraintError::InvalidPattern)?;
61 let accept = nfa.push(NfaState::Accept);
63 nfa.accept_state = accept;
64 nfa.patch(&frag.outs, accept);
65 nfa.start = frag.start;
66 Ok(nfa)
67 }
68
69 fn push(&mut self, state: NfaState) -> usize {
70 let idx = self.states.len();
71 self.states.push(state);
72 idx
73 }
74
75 fn patch(&mut self, outs: &[usize], target: usize) {
77 for &idx in outs {
78 match &mut self.states[idx] {
79 NfaState::Literal(_, ref mut n)
80 | NfaState::Any(ref mut n)
81 | NfaState::Class {
82 next: ref mut n, ..
83 } => *n = target,
84 NfaState::Split(ref mut a, ref mut b) => {
85 if *a == usize::MAX {
87 *a = target;
88 }
89 if *b == usize::MAX {
90 *b = target;
91 }
92 }
93 NfaState::Accept => {}
94 }
95 }
96 }
97
98 fn compile(&mut self, chars: &[char], mut pos: usize) -> Result<Fragment, String> {
100 let mut alt_frags: Vec<Fragment> = Vec::new();
102 let mut cur_frags: Vec<Fragment> = Vec::new();
103
104 while pos < chars.len() {
105 let ch = chars[pos];
106
107 if ch == '|' {
109 let seq = Self::concat_fragments(&mut self.states, cur_frags);
110 alt_frags.push(seq);
111 cur_frags = Vec::new();
112 pos += 1;
113 continue;
114 }
115
116 if ch == ')' {
118 break;
119 }
120
121 let (atom, new_pos) = self.parse_atom(chars, pos)?;
123 pos = new_pos;
124
125 let quantified = if pos < chars.len() {
127 match chars[pos] {
128 '?' => {
129 pos += 1;
130 self.quantifier_optional(atom)
131 }
132 '*' => {
133 pos += 1;
134 self.quantifier_star(atom)
135 }
136 '+' => {
137 pos += 1;
138 self.quantifier_plus(atom)
139 }
140 _ => atom,
141 }
142 } else {
143 atom
144 };
145
146 cur_frags.push(quantified);
147 }
148
149 let seq = Self::concat_fragments(&mut self.states, cur_frags);
151 alt_frags.push(seq);
152
153 let result = if alt_frags.len() == 1 {
155 alt_frags.remove(0)
156 } else {
157 self.alternation(alt_frags)
158 };
159
160 Ok(result)
161 }
162
163 fn parse_atom(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
165 if pos >= chars.len() {
166 return Err("Unexpected end of pattern".to_string());
167 }
168 let ch = chars[pos];
169 match ch {
170 '(' => {
171 let inner = self.compile(chars, pos + 1)?;
173 let mut depth = 1usize;
175 let mut i = pos + 1;
176 while i < chars.len() {
177 match chars[i] {
178 '(' => depth += 1,
179 ')' => {
180 depth -= 1;
181 if depth == 0 {
182 break;
183 }
184 }
185 '\\' => {
186 i += 1;
187 } _ => {}
189 }
190 i += 1;
191 }
192 let new_pos = if i < chars.len() && chars[i] == ')' {
193 i + 1
194 } else {
195 i
196 };
197 Ok((inner, new_pos))
198 }
199 '[' => {
200 let (frag, new_pos) = self.parse_class(chars, pos)?;
201 Ok((frag, new_pos))
202 }
203 '.' => {
204 let idx = self.push(NfaState::Any(usize::MAX));
205 Ok((
206 Fragment {
207 start: idx,
208 outs: vec![idx],
209 },
210 pos + 1,
211 ))
212 }
213 '\\' => {
214 let (frag, new_pos) = self.parse_escape(chars, pos)?;
215 Ok((frag, new_pos))
216 }
217 _ if ch == '*' || ch == '+' || ch == '?' => {
218 Err(format!("Unexpected quantifier '{ch}' at position {pos}"))
219 }
220 _ => {
221 let idx = self.push(NfaState::Literal(ch, usize::MAX));
222 Ok((
223 Fragment {
224 start: idx,
225 outs: vec![idx],
226 },
227 pos + 1,
228 ))
229 }
230 }
231 }
232
233 fn parse_class(&mut self, chars: &[char], start: usize) -> Result<(Fragment, usize), String> {
235 let mut pos = start + 1;
237 let negated = if pos < chars.len() && chars[pos] == '^' {
238 pos += 1;
239 true
240 } else {
241 false
242 };
243
244 let mut class_chars: Vec<char> = Vec::new();
245 let mut ranges: Vec<(char, char)> = Vec::new();
246
247 while pos < chars.len() && chars[pos] != ']' {
248 if chars[pos] == '\\' && pos + 1 < chars.len() {
249 let escaped = chars[pos + 1];
251 match escaped {
252 'd' => ranges.push(('0', '9')),
253 'w' => {
254 ranges.push(('a', 'z'));
255 ranges.push(('A', 'Z'));
256 ranges.push(('0', '9'));
257 class_chars.push('_');
258 }
259 's' => {
260 class_chars.extend_from_slice(&[' ', '\t', '\n', '\r']);
261 }
262 _ => class_chars.push(escaped),
263 }
264 pos += 2;
265 } else if pos + 2 < chars.len() && chars[pos + 1] == '-' && chars[pos + 2] != ']' {
266 ranges.push((chars[pos], chars[pos + 2]));
267 pos += 3;
268 } else {
269 class_chars.push(chars[pos]);
270 pos += 1;
271 }
272 }
273
274 let new_pos = if pos < chars.len() && chars[pos] == ']' {
275 pos + 1
276 } else {
277 pos
278 };
279
280 let idx = self.push(NfaState::Class {
281 chars: class_chars,
282 ranges,
283 negated,
284 next: usize::MAX,
285 });
286 Ok((
287 Fragment {
288 start: idx,
289 outs: vec![idx],
290 },
291 new_pos,
292 ))
293 }
294
295 fn parse_escape(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
297 if pos + 1 >= chars.len() {
298 return Err("Trailing backslash in pattern".to_string());
299 }
300 let escaped = chars[pos + 1];
301 let (class_chars, ranges): (Vec<char>, Vec<(char, char)>) = match escaped {
302 'd' => (vec![], vec![('0', '9')]),
303 'D' => {
304 let idx = self.push(NfaState::Class {
306 chars: vec![],
307 ranges: vec![('0', '9')],
308 negated: true,
309 next: usize::MAX,
310 });
311 return Ok((
312 Fragment {
313 start: idx,
314 outs: vec![idx],
315 },
316 pos + 2,
317 ));
318 }
319 'w' => (vec!['_'], vec![('a', 'z'), ('A', 'Z'), ('0', '9')]),
320 'W' => {
321 let idx = self.push(NfaState::Class {
322 chars: vec!['_'],
323 ranges: vec![('a', 'z'), ('A', 'Z'), ('0', '9')],
324 negated: true,
325 next: usize::MAX,
326 });
327 return Ok((
328 Fragment {
329 start: idx,
330 outs: vec![idx],
331 },
332 pos + 2,
333 ));
334 }
335 's' => (vec![' ', '\t', '\n', '\r'], vec![]),
336 'S' => {
337 let idx = self.push(NfaState::Class {
338 chars: vec![' ', '\t', '\n', '\r'],
339 ranges: vec![],
340 negated: true,
341 next: usize::MAX,
342 });
343 return Ok((
344 Fragment {
345 start: idx,
346 outs: vec![idx],
347 },
348 pos + 2,
349 ));
350 }
351 'n' => {
352 let idx = self.push(NfaState::Literal('\n', usize::MAX));
353 return Ok((
354 Fragment {
355 start: idx,
356 outs: vec![idx],
357 },
358 pos + 2,
359 ));
360 }
361 'r' => {
362 let idx = self.push(NfaState::Literal('\r', usize::MAX));
363 return Ok((
364 Fragment {
365 start: idx,
366 outs: vec![idx],
367 },
368 pos + 2,
369 ));
370 }
371 't' => {
372 let idx = self.push(NfaState::Literal('\t', usize::MAX));
373 return Ok((
374 Fragment {
375 start: idx,
376 outs: vec![idx],
377 },
378 pos + 2,
379 ));
380 }
381 _ => {
382 let idx = self.push(NfaState::Literal(escaped, usize::MAX));
384 return Ok((
385 Fragment {
386 start: idx,
387 outs: vec![idx],
388 },
389 pos + 2,
390 ));
391 }
392 };
393 let idx = self.push(NfaState::Class {
394 chars: class_chars,
395 ranges,
396 negated: false,
397 next: usize::MAX,
398 });
399 Ok((
400 Fragment {
401 start: idx,
402 outs: vec![idx],
403 },
404 pos + 2,
405 ))
406 }
407
408 fn quantifier_optional(&mut self, frag: Fragment) -> Fragment {
412 let split = self.push(NfaState::Split(frag.start, usize::MAX));
413 let mut outs = frag.outs;
414 outs.push(split); Fragment { start: split, outs }
416 }
417
418 fn quantifier_star(&mut self, frag: Fragment) -> Fragment {
420 let split = self.push(NfaState::Split(frag.start, usize::MAX));
421 self.patch(&frag.outs, split);
423 Fragment {
424 start: split,
425 outs: vec![split],
426 }
427 }
428
429 fn quantifier_plus(&mut self, frag: Fragment) -> Fragment {
431 let split = self.push(NfaState::Split(frag.start, usize::MAX));
432 self.patch(&frag.outs, split);
433 Fragment {
434 start: frag.start,
435 outs: vec![split],
436 }
437 }
438
439 fn alternation(&mut self, frags: Vec<Fragment>) -> Fragment {
441 if frags.is_empty() {
442 let split = self.push(NfaState::Split(usize::MAX, usize::MAX));
443 return Fragment {
444 start: split,
445 outs: vec![split],
446 };
447 }
448 let mut iter = frags.into_iter();
449 let mut current = iter.next().expect("non-empty checked above");
450 for next_frag in iter {
451 let split = self.push(NfaState::Split(current.start, next_frag.start));
452 let mut outs = current.outs;
453 outs.extend(next_frag.outs);
454 current = Fragment { start: split, outs };
455 }
456 current
457 }
458
459 fn concat_fragments(states: &mut Vec<NfaState>, frags: Vec<Fragment>) -> Fragment {
461 if frags.is_empty() {
462 let idx = states.len();
464 states.push(NfaState::Split(usize::MAX, usize::MAX));
465 return Fragment {
466 start: idx,
467 outs: vec![idx],
468 };
469 }
470 let mut iter = frags.into_iter();
471 let first = iter.next().expect("non-empty checked above");
472 iter.fold(first, |acc, next| {
473 for &idx in &acc.outs {
475 match &mut states[idx] {
476 NfaState::Literal(_, ref mut n)
477 | NfaState::Any(ref mut n)
478 | NfaState::Class {
479 next: ref mut n, ..
480 } => {
481 if *n == usize::MAX {
482 *n = next.start;
483 }
484 }
485 NfaState::Split(ref mut a, ref mut b) => {
486 if *a == usize::MAX {
487 *a = next.start;
488 } else if *b == usize::MAX {
489 *b = next.start;
490 }
491 }
492 NfaState::Accept => {}
493 }
494 }
495 Fragment {
496 start: acc.start,
497 outs: next.outs,
498 }
499 })
500 }
501
502 fn epsilon_closure(&self, states: Vec<usize>) -> Vec<usize> {
506 let mut closure: Vec<usize> = Vec::new();
507 let mut stack = states;
508 let mut visited = std::collections::HashSet::new();
509 while let Some(s) = stack.pop() {
510 if s == usize::MAX || !visited.insert(s) {
511 continue;
512 }
513 closure.push(s);
514 if let Some(NfaState::Split(a, b)) = self.states.get(s) {
515 if *a != usize::MAX {
516 stack.push(*a);
517 }
518 if *b != usize::MAX {
519 stack.push(*b);
520 }
521 }
522 }
523 closure
524 }
525
526 fn step(&self, states: &[usize], ch: char) -> Vec<usize> {
528 let mut next = Vec::new();
529 for &s in states {
530 if s == usize::MAX {
531 continue;
532 }
533 if let Some(state) = self.states.get(s) {
534 match state {
535 NfaState::Literal(c, n) => {
536 if *c == ch && *n != usize::MAX {
537 next.push(*n);
538 }
539 }
540 NfaState::Any(n) => {
541 if *n != usize::MAX {
542 next.push(*n);
543 }
544 }
545 NfaState::Class {
546 chars,
547 ranges,
548 negated,
549 next: n,
550 } => {
551 let matched = chars.contains(&ch)
552 || ranges.iter().any(|&(lo, hi)| ch >= lo && ch <= hi);
553 let effective = if *negated { !matched } else { matched };
554 if effective && *n != usize::MAX {
555 next.push(*n);
556 }
557 }
558 NfaState::Split(_, _) | NfaState::Accept => {}
559 }
560 }
561 }
562 self.epsilon_closure(next)
563 }
564
565 fn is_accepting(&self, states: &[usize]) -> bool {
567 states.contains(&self.accept_state)
568 }
569
570 fn is_full_match(&self, text: &str) -> bool {
572 let initial = self.epsilon_closure(vec![self.start]);
573 let final_states = text.chars().fold(initial, |s, ch| self.step(&s, ch));
574 self.is_accepting(&final_states)
575 }
576}
577
/// Incremental matcher that tracks how far an input has progressed through a
/// regex pattern, one character at a time.
pub struct RegexConstraint {
 /// Source text of the pattern; returned by `name()`.
 pattern: String,
 /// NFA compiled from `pattern`.
 nfa: RegexNfa,
 /// NFA states that are active after the characters accepted so far.
 current_states: Vec<usize>,
 /// Characters accepted so far, in order.
 matched_so_far: String,
}
596
597impl RegexConstraint {
598 pub fn new(pattern: &str) -> Result<Self, ConstraintError> {
600 let nfa = RegexNfa::from_pattern(pattern)?;
601 let current_states = nfa.epsilon_closure(vec![nfa.start]);
602 Ok(Self {
603 pattern: pattern.to_string(),
604 nfa,
605 current_states,
606 matched_so_far: String::new(),
607 })
608 }
609
610 pub fn is_match(pattern: &str, text: &str) -> bool {
612 match RegexNfa::from_pattern(pattern) {
613 Ok(nfa) => nfa.is_full_match(text),
614 Err(_) => false,
615 }
616 }
617
618 pub fn current_partial(&self) -> &str {
620 &self.matched_so_far
621 }
622
623 pub fn char_is_valid(&self, ch: char) -> bool {
625 let next = self.nfa.step(&self.current_states, ch);
626 !next.is_empty()
627 }
628}
629
630impl TokenConstraint for RegexConstraint {
631 fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
632 if self.current_states.is_empty() {
634 return Some(vec![false; vocab_size]);
635 }
636 None
640 }
641
642 fn advance(&mut self, token: u32) -> bool {
643 let ch = char::from_u32(token).unwrap_or('\u{FFFD}');
646 let next = self.nfa.step(&self.current_states, ch);
647 if next.is_empty() {
648 return false;
649 }
650 self.current_states = next;
651 self.matched_so_far.push(ch);
652 true
653 }
654
655 fn is_complete(&self) -> bool {
656 self.nfa.is_accepting(&self.current_states)
657 }
658
659 fn reset(&mut self) {
660 self.current_states = self.nfa.epsilon_closure(vec![self.nfa.start]);
661 self.matched_so_far.clear();
662 }
663
664 fn name(&self) -> &str {
665 &self.pattern
666 }
667}
668
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiles `pattern` and checks which inputs it fully matches.
    fn check_full_match(pattern: &str, accepted: &[&str], rejected: &[&str]) {
        let nfa = RegexNfa::from_pattern(pattern).expect("valid pattern");
        for text in accepted {
            assert!(nfa.is_full_match(text));
        }
        for text in rejected {
            assert!(!nfa.is_full_match(text));
        }
    }

    #[test]
    fn regex_nfa_literal_match() {
        check_full_match("abc", &["abc"], &["ab", "abcd"]);
    }

    #[test]
    fn regex_nfa_dot_match() {
        check_full_match("a.c", &["abc", "axc"], &["ac"]);
    }

    #[test]
    fn regex_nfa_star_quantifier() {
        check_full_match("ab*c", &["ac", "abc", "abbc"], &["xbc"]);
    }

    #[test]
    fn regex_nfa_alternation() {
        check_full_match("cat|dog", &["cat", "dog"], &["cow"]);
    }

    #[test]
    fn regex_constraint_is_match() {
        let pattern = "he+llo";
        for text in ["hello", "heeeello"].iter() {
            assert!(RegexConstraint::is_match(pattern, text));
        }
        assert!(!RegexConstraint::is_match(pattern, "hllo"));
    }

    #[test]
    fn regex_constraint_allows_valid_chars() {
        let constraint = RegexConstraint::new("abc").expect("valid");
        // Only the first character of the pattern is acceptable at the start.
        assert!(constraint.char_is_valid('a'));
        assert!(!constraint.char_is_valid('b'));
    }
}