#[derive(Debug, thiserror::Error)]
pub enum ConstraintError {
#[error("Invalid regex pattern: {0}")]
InvalidPattern(String),
#[error("Invalid JSON schema: {0}")]
InvalidSchema(String),
#[error("Constraint violated at token {token}: {reason}")]
Violated { token: u32, reason: String },
}
pub trait TokenConstraint: Send + Sync {
fn allowed_tokens(&self, generated: &[u32], vocab_size: usize) -> Option<Vec<bool>>;
fn advance(&mut self, token: u32) -> bool;
fn is_complete(&self) -> bool;
fn reset(&mut self);
fn name(&self) -> &str;
}
pub struct NoConstraint;
impl TokenConstraint for NoConstraint {
fn allowed_tokens(&self, _generated: &[u32], _vocab_size: usize) -> Option<Vec<bool>> {
None
}
fn advance(&mut self, _token: u32) -> bool {
true
}
fn is_complete(&self) -> bool {
true
}
fn reset(&mut self) {}
fn name(&self) -> &str {
"NoConstraint"
}
}
#[derive(Debug, Clone)]
enum NfaState {
Literal(char, usize),
Any(usize),
Split(usize, usize),
Class {
chars: Vec<char>,
ranges: Vec<(char, char)>,
negated: bool,
next: usize,
},
Accept,
}
#[derive(Debug, Clone)]
struct RegexNfa {
states: Vec<NfaState>,
start: usize,
accept_state: usize,
}
struct Fragment {
start: usize,
outs: Vec<usize>,
}
impl RegexNfa {
fn from_pattern(pattern: &str) -> Result<Self, ConstraintError> {
let mut nfa = RegexNfa {
states: Vec::new(),
start: 0,
accept_state: 0,
};
let chars: Vec<char> = pattern.chars().collect();
let frag = nfa
.compile(&chars, 0)
.map_err(ConstraintError::InvalidPattern)?;
let accept = nfa.push(NfaState::Accept);
nfa.accept_state = accept;
nfa.patch(&frag.outs, accept);
nfa.start = frag.start;
Ok(nfa)
}
fn push(&mut self, state: NfaState) -> usize {
let idx = self.states.len();
self.states.push(state);
idx
}
fn patch(&mut self, outs: &[usize], target: usize) {
for &idx in outs {
match &mut self.states[idx] {
NfaState::Literal(_, ref mut n)
| NfaState::Any(ref mut n)
| NfaState::Class {
next: ref mut n, ..
} => *n = target,
NfaState::Split(ref mut a, ref mut b) => {
if *a == usize::MAX {
*a = target;
}
if *b == usize::MAX {
*b = target;
}
}
NfaState::Accept => {}
}
}
}
fn compile(&mut self, chars: &[char], mut pos: usize) -> Result<Fragment, String> {
let mut alt_frags: Vec<Fragment> = Vec::new();
let mut cur_frags: Vec<Fragment> = Vec::new();
while pos < chars.len() {
let ch = chars[pos];
if ch == '|' {
let seq = Self::concat_fragments(&mut self.states, cur_frags);
alt_frags.push(seq);
cur_frags = Vec::new();
pos += 1;
continue;
}
if ch == ')' {
break;
}
let (atom, new_pos) = self.parse_atom(chars, pos)?;
pos = new_pos;
let quantified = if pos < chars.len() {
match chars[pos] {
'?' => {
pos += 1;
self.quantifier_optional(atom)
}
'*' => {
pos += 1;
self.quantifier_star(atom)
}
'+' => {
pos += 1;
self.quantifier_plus(atom)
}
_ => atom,
}
} else {
atom
};
cur_frags.push(quantified);
}
let seq = Self::concat_fragments(&mut self.states, cur_frags);
alt_frags.push(seq);
let result = if alt_frags.len() == 1 {
alt_frags.remove(0)
} else {
self.alternation(alt_frags)
};
Ok(result)
}
fn parse_atom(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
if pos >= chars.len() {
return Err("Unexpected end of pattern".to_string());
}
let ch = chars[pos];
match ch {
'(' => {
let inner = self.compile(chars, pos + 1)?;
let mut depth = 1usize;
let mut i = pos + 1;
while i < chars.len() {
match chars[i] {
'(' => depth += 1,
')' => {
depth -= 1;
if depth == 0 {
break;
}
}
'\\' => {
i += 1;
} _ => {}
}
i += 1;
}
let new_pos = if i < chars.len() && chars[i] == ')' {
i + 1
} else {
i
};
Ok((inner, new_pos))
}
'[' => {
let (frag, new_pos) = self.parse_class(chars, pos)?;
Ok((frag, new_pos))
}
'.' => {
let idx = self.push(NfaState::Any(usize::MAX));
Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 1,
))
}
'\\' => {
let (frag, new_pos) = self.parse_escape(chars, pos)?;
Ok((frag, new_pos))
}
_ if ch == '*' || ch == '+' || ch == '?' => {
Err(format!("Unexpected quantifier '{ch}' at position {pos}"))
}
_ => {
let idx = self.push(NfaState::Literal(ch, usize::MAX));
Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 1,
))
}
}
}
fn parse_class(&mut self, chars: &[char], start: usize) -> Result<(Fragment, usize), String> {
let mut pos = start + 1;
let negated = if pos < chars.len() && chars[pos] == '^' {
pos += 1;
true
} else {
false
};
let mut class_chars: Vec<char> = Vec::new();
let mut ranges: Vec<(char, char)> = Vec::new();
while pos < chars.len() && chars[pos] != ']' {
if chars[pos] == '\\' && pos + 1 < chars.len() {
let escaped = chars[pos + 1];
match escaped {
'd' => ranges.push(('0', '9')),
'w' => {
ranges.push(('a', 'z'));
ranges.push(('A', 'Z'));
ranges.push(('0', '9'));
class_chars.push('_');
}
's' => {
class_chars.extend_from_slice(&[' ', '\t', '\n', '\r']);
}
_ => class_chars.push(escaped),
}
pos += 2;
} else if pos + 2 < chars.len() && chars[pos + 1] == '-' && chars[pos + 2] != ']' {
ranges.push((chars[pos], chars[pos + 2]));
pos += 3;
} else {
class_chars.push(chars[pos]);
pos += 1;
}
}
let new_pos = if pos < chars.len() && chars[pos] == ']' {
pos + 1
} else {
pos
};
let idx = self.push(NfaState::Class {
chars: class_chars,
ranges,
negated,
next: usize::MAX,
});
Ok((
Fragment {
start: idx,
outs: vec![idx],
},
new_pos,
))
}
fn parse_escape(&mut self, chars: &[char], pos: usize) -> Result<(Fragment, usize), String> {
if pos + 1 >= chars.len() {
return Err("Trailing backslash in pattern".to_string());
}
let escaped = chars[pos + 1];
let (class_chars, ranges): (Vec<char>, Vec<(char, char)>) = match escaped {
'd' => (vec![], vec![('0', '9')]),
'D' => {
let idx = self.push(NfaState::Class {
chars: vec![],
ranges: vec![('0', '9')],
negated: true,
next: usize::MAX,
});
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
'w' => (vec!['_'], vec![('a', 'z'), ('A', 'Z'), ('0', '9')]),
'W' => {
let idx = self.push(NfaState::Class {
chars: vec!['_'],
ranges: vec![('a', 'z'), ('A', 'Z'), ('0', '9')],
negated: true,
next: usize::MAX,
});
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
's' => (vec![' ', '\t', '\n', '\r'], vec![]),
'S' => {
let idx = self.push(NfaState::Class {
chars: vec![' ', '\t', '\n', '\r'],
ranges: vec![],
negated: true,
next: usize::MAX,
});
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
'n' => {
let idx = self.push(NfaState::Literal('\n', usize::MAX));
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
'r' => {
let idx = self.push(NfaState::Literal('\r', usize::MAX));
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
't' => {
let idx = self.push(NfaState::Literal('\t', usize::MAX));
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
_ => {
let idx = self.push(NfaState::Literal(escaped, usize::MAX));
return Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
));
}
};
let idx = self.push(NfaState::Class {
chars: class_chars,
ranges,
negated: false,
next: usize::MAX,
});
Ok((
Fragment {
start: idx,
outs: vec![idx],
},
pos + 2,
))
}
fn quantifier_optional(&mut self, frag: Fragment) -> Fragment {
let split = self.push(NfaState::Split(frag.start, usize::MAX));
let mut outs = frag.outs;
outs.push(split); Fragment { start: split, outs }
}
fn quantifier_star(&mut self, frag: Fragment) -> Fragment {
let split = self.push(NfaState::Split(frag.start, usize::MAX));
self.patch(&frag.outs, split);
Fragment {
start: split,
outs: vec![split],
}
}
fn quantifier_plus(&mut self, frag: Fragment) -> Fragment {
let split = self.push(NfaState::Split(frag.start, usize::MAX));
self.patch(&frag.outs, split);
Fragment {
start: frag.start,
outs: vec![split],
}
}
fn alternation(&mut self, frags: Vec<Fragment>) -> Fragment {
if frags.is_empty() {
let split = self.push(NfaState::Split(usize::MAX, usize::MAX));
return Fragment {
start: split,
outs: vec![split],
};
}
let mut iter = frags.into_iter();
let mut current = iter.next().expect("non-empty checked above");
for next_frag in iter {
let split = self.push(NfaState::Split(current.start, next_frag.start));
let mut outs = current.outs;
outs.extend(next_frag.outs);
current = Fragment { start: split, outs };
}
current
}
fn concat_fragments(states: &mut Vec<NfaState>, frags: Vec<Fragment>) -> Fragment {
if frags.is_empty() {
let idx = states.len();
states.push(NfaState::Split(usize::MAX, usize::MAX));
return Fragment {
start: idx,
outs: vec![idx],
};
}
let mut iter = frags.into_iter();
let first = iter.next().expect("non-empty checked above");
iter.fold(first, |acc, next| {
for &idx in &acc.outs {
match &mut states[idx] {
NfaState::Literal(_, ref mut n)
| NfaState::Any(ref mut n)
| NfaState::Class {
next: ref mut n, ..
} => {
if *n == usize::MAX {
*n = next.start;
}
}
NfaState::Split(ref mut a, ref mut b) => {
if *a == usize::MAX {
*a = next.start;
} else if *b == usize::MAX {
*b = next.start;
}
}
NfaState::Accept => {}
}
}
Fragment {
start: acc.start,
outs: next.outs,
}
})
}
fn epsilon_closure(&self, states: Vec<usize>) -> Vec<usize> {
let mut closure: Vec<usize> = Vec::new();
let mut stack = states;
let mut visited = std::collections::HashSet::new();
while let Some(s) = stack.pop() {
if s == usize::MAX || !visited.insert(s) {
continue;
}
closure.push(s);
if let Some(NfaState::Split(a, b)) = self.states.get(s) {
if *a != usize::MAX {
stack.push(*a);
}
if *b != usize::MAX {
stack.push(*b);
}
}
}
closure
}
fn step(&self, states: &[usize], ch: char) -> Vec<usize> {
let mut next = Vec::new();
for &s in states {
if s == usize::MAX {
continue;
}
if let Some(state) = self.states.get(s) {
match state {
NfaState::Literal(c, n) => {
if *c == ch && *n != usize::MAX {
next.push(*n);
}
}
NfaState::Any(n) => {
if *n != usize::MAX {
next.push(*n);
}
}
NfaState::Class {
chars,
ranges,
negated,
next: n,
} => {
let matched = chars.contains(&ch)
|| ranges.iter().any(|&(lo, hi)| ch >= lo && ch <= hi);
let effective = if *negated { !matched } else { matched };
if effective && *n != usize::MAX {
next.push(*n);
}
}
NfaState::Split(_, _) | NfaState::Accept => {}
}
}
}
self.epsilon_closure(next)
}
fn is_accepting(&self, states: &[usize]) -> bool {
states.contains(&self.accept_state)
}
fn is_full_match(&self, text: &str) -> bool {
let initial = self.epsilon_closure(vec![self.start]);
let final_states = text.chars().fold(initial, |s, ch| self.step(&s, ch));
self.is_accepting(&final_states)
}
}
pub struct RegexConstraint {
pattern: String,
nfa: RegexNfa,
current_states: Vec<usize>,
matched_so_far: String,
}
impl RegexConstraint {
pub fn new(pattern: &str) -> Result<Self, ConstraintError> {
let nfa = RegexNfa::from_pattern(pattern)?;
let current_states = nfa.epsilon_closure(vec![nfa.start]);
Ok(Self {
pattern: pattern.to_string(),
nfa,
current_states,
matched_so_far: String::new(),
})
}
pub fn is_match(pattern: &str, text: &str) -> bool {
match RegexNfa::from_pattern(pattern) {
Ok(nfa) => nfa.is_full_match(text),
Err(_) => false,
}
}
pub fn current_partial(&self) -> &str {
&self.matched_so_far
}
pub fn char_is_valid(&self, ch: char) -> bool {
let next = self.nfa.step(&self.current_states, ch);
!next.is_empty()
}
}
impl TokenConstraint for RegexConstraint {
fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
if self.current_states.is_empty() {
return Some(vec![false; vocab_size]);
}
None
}
fn advance(&mut self, token: u32) -> bool {
let ch = char::from_u32(token).unwrap_or('\u{FFFD}');
let next = self.nfa.step(&self.current_states, ch);
if next.is_empty() {
return false;
}
self.current_states = next;
self.matched_so_far.push(ch);
true
}
fn is_complete(&self) -> bool {
self.nfa.is_accepting(&self.current_states)
}
fn reset(&mut self) {
self.current_states = self.nfa.epsilon_closure(vec![self.nfa.start]);
self.matched_so_far.clear();
}
fn name(&self) -> &str {
&self.pattern
}
}
#[derive(Debug, Clone, PartialEq)]
pub enum JsonParseState {
Start,
InObject,
InObjectKey,
AfterKey,
InObjectValue,
InArray,
InArrayValue,
InString,
InStringEscape,
InNumber,
InBool,
InNull,
Complete,
Error,
}
pub struct JsonConstraint {
state: JsonParseState,
depth: usize,
buffer: String,
expecting_comma_or_close: bool,
keyword_buf: String,
context_stack: Vec<char>,
}
impl JsonConstraint {
pub fn new() -> Self {
Self {
state: JsonParseState::Start,
depth: 0,
buffer: String::new(),
expecting_comma_or_close: false,
keyword_buf: String::new(),
context_stack: Vec::new(),
}
}
pub fn current_state(&self) -> &JsonParseState {
&self.state
}
pub fn depth(&self) -> usize {
self.depth
}
pub fn is_in_string(&self) -> bool {
matches!(
self.state,
JsonParseState::InString | JsonParseState::InStringEscape
)
}
pub fn valid_next_chars(&self) -> Vec<char> {
match &self.state {
JsonParseState::Start => {
vec![
'{', '[', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't', 'f',
'n', ' ', '\t', '\n',
]
}
JsonParseState::InObject => {
if self.expecting_comma_or_close {
vec![',', '}', ' ', '\t', '\n']
} else {
vec!['"', '}', ' ', '\t', '\n']
}
}
JsonParseState::InObjectKey => {
let mut v: Vec<char> = (0x20u8..0x7fu8)
.filter(|&c| c != b'"')
.map(|c| c as char)
.collect();
v.push('"'); v.push('\\');
v
}
JsonParseState::AfterKey => vec![':', ' ', '\t'],
JsonParseState::InObjectValue
| JsonParseState::InArrayValue
| JsonParseState::InArray => {
if self.expecting_comma_or_close {
if self.context_stack.last() == Some(&'o') {
vec![',', '}', ' ', '\t', '\n']
} else {
vec![',', ']', ' ', '\t', '\n']
}
} else {
vec![
'{', '[', '"', '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 't',
'f', 'n', ' ', '\t', '\n',
]
}
}
JsonParseState::InString => {
let mut v: Vec<char> = (0x20u8..0x7fu8)
.filter(|&c| c != b'"')
.map(|c| c as char)
.collect();
v.push('"');
v.push('\\');
v
}
JsonParseState::InStringEscape => {
vec!['"', '\\', '/', 'b', 'f', 'n', 'r', 't', 'u']
}
JsonParseState::InNumber => {
vec![
'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', 'e', 'E', '+', '-', ',',
'}', ']', ' ', '\t', '\n',
]
}
JsonParseState::InBool | JsonParseState::InNull => {
vec![
'r', 'u', 'e', 'a', 'l', 's', 'i', 'o', 'n', 't', 'f', ',', '}', ']', ' ',
'\t', '\n',
]
}
JsonParseState::Complete => {
vec![' ', '\t', '\n']
}
JsonParseState::Error => vec![],
}
}
fn feed_char(&mut self, ch: char) {
match &self.state.clone() {
JsonParseState::Error | JsonParseState::Complete => {
if self.state == JsonParseState::Complete && !ch.is_whitespace() {
self.state = JsonParseState::Error;
}
return;
}
JsonParseState::Start => {
if ch.is_whitespace() {
return;
}
match ch {
'{' => {
self.depth += 1;
self.context_stack.push('o');
self.state = JsonParseState::InObject;
self.expecting_comma_or_close = false;
}
'[' => {
self.depth += 1;
self.context_stack.push('a');
self.state = JsonParseState::InArray;
self.expecting_comma_or_close = false;
}
'"' => {
self.state = JsonParseState::InString;
}
'-' | '0'..='9' => {
self.state = JsonParseState::InNumber;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
}
't' | 'f' => {
self.state = JsonParseState::InBool;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
}
'n' => {
self.state = JsonParseState::InNull;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
}
_ => {
self.state = JsonParseState::Error;
}
}
}
JsonParseState::InObject => {
if ch.is_whitespace() {
return;
}
if self.expecting_comma_or_close {
match ch {
',' => {
self.expecting_comma_or_close = false;
}
'}' => {
self.close_context();
}
_ => {
self.state = JsonParseState::Error;
}
}
} else {
match ch {
'"' => {
self.state = JsonParseState::InObjectKey;
}
'}' => {
self.close_context();
}
_ => {
self.state = JsonParseState::Error;
}
}
}
}
JsonParseState::InObjectKey => {
match ch {
'"' => {
self.state = JsonParseState::AfterKey;
}
'\\' => {
self.state = JsonParseState::InStringEscape;
}
_ => {} }
}
JsonParseState::AfterKey => {
if ch.is_whitespace() {
return;
}
if ch == ':' {
self.state = JsonParseState::InObjectValue;
self.expecting_comma_or_close = false;
} else {
self.state = JsonParseState::Error;
}
}
JsonParseState::InObjectValue => {
if ch.is_whitespace() {
return;
}
self.start_value(ch, 'o');
}
JsonParseState::InArray => {
if ch.is_whitespace() {
return;
}
if self.expecting_comma_or_close {
match ch {
',' => {
self.expecting_comma_or_close = false;
}
']' => {
self.close_context();
}
_ => {
self.state = JsonParseState::Error;
}
}
} else {
match ch {
']' => {
self.close_context();
}
_ => {
self.start_value(ch, 'a');
}
}
}
}
JsonParseState::InArrayValue => {
if ch.is_whitespace() {
return;
}
if self.expecting_comma_or_close {
if self.context_stack.last() == Some(&'a') {
match ch {
',' => {
self.expecting_comma_or_close = false;
self.state = JsonParseState::InArray;
}
']' => {
self.close_context();
}
_ => {
self.state = JsonParseState::Error;
}
}
} else {
match ch {
',' => {
self.expecting_comma_or_close = false;
self.state = JsonParseState::InObject;
}
'}' => {
self.close_context();
}
_ => {
self.state = JsonParseState::Error;
}
}
}
} else {
self.start_value(ch, *self.context_stack.last().unwrap_or(&'a'));
}
}
JsonParseState::InString => {
match ch {
'"' => {
self.finish_string();
}
'\\' => {
self.state = JsonParseState::InStringEscape;
}
_ => {} }
}
JsonParseState::InStringEscape => {
self.state = JsonParseState::InString;
}
JsonParseState::InNumber => {
match ch {
'0'..='9' | '.' | 'e' | 'E' | '+' | '-' => {
self.keyword_buf.push(ch);
}
_ => {
self.finish_value();
self.feed_char(ch);
}
}
}
JsonParseState::InBool => {
self.keyword_buf.push(ch);
let kb = self.keyword_buf.clone();
if kb == "true" || kb == "false" {
self.keyword_buf.clear();
self.finish_value();
} else if !"true".starts_with(kb.as_str()) && !"false".starts_with(kb.as_str()) {
self.state = JsonParseState::Error;
}
}
JsonParseState::InNull => {
self.keyword_buf.push(ch);
let kb = self.keyword_buf.clone();
if kb == "null" {
self.keyword_buf.clear();
self.finish_value();
} else if !"null".starts_with(kb.as_str()) {
self.state = JsonParseState::Error;
}
}
}
self.buffer.push(ch);
}
fn start_value(&mut self, ch: char, ctx: char) {
match ch {
'{' => {
self.depth += 1;
self.context_stack.push('o');
self.state = JsonParseState::InObject;
self.expecting_comma_or_close = false;
}
'[' => {
self.depth += 1;
self.context_stack.push('a');
self.state = JsonParseState::InArray;
self.expecting_comma_or_close = false;
}
'"' => {
self.state = JsonParseState::InString;
}
'-' | '0'..='9' => {
self.state = JsonParseState::InNumber;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
let _ = ctx; }
't' | 'f' => {
self.state = JsonParseState::InBool;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
}
'n' => {
self.state = JsonParseState::InNull;
self.keyword_buf.clear();
self.keyword_buf.push(ch);
}
_ => {
self.state = JsonParseState::Error;
}
}
}
fn finish_value(&mut self) {
self.expecting_comma_or_close = true;
match self.context_stack.last() {
Some(&'o') => {
self.state = JsonParseState::InObject;
}
Some(&'a') => {
self.state = JsonParseState::InArray;
}
None => {
self.state = JsonParseState::Complete;
}
_ => {
self.state = JsonParseState::Error;
}
}
}
fn finish_string(&mut self) {
match self.context_stack.last() {
Some(&'o') => {
self.state = JsonParseState::InObject;
self.expecting_comma_or_close = true;
}
Some(&'a') => {
self.state = JsonParseState::InArray;
self.expecting_comma_or_close = true;
}
None => {
self.state = JsonParseState::Complete;
}
_ => {
self.state = JsonParseState::Error;
}
}
}
fn close_context(&mut self) {
if let Some(ctx) = self.context_stack.pop() {
if ctx == 'o' || ctx == 'a' {
self.depth = self.depth.saturating_sub(1);
}
}
self.expecting_comma_or_close = true;
match self.context_stack.last() {
Some(&'o') => {
self.state = JsonParseState::InObject;
}
Some(&'a') => {
self.state = JsonParseState::InArray;
}
None => {
self.state = JsonParseState::Complete;
}
_ => {
self.state = JsonParseState::Error;
}
}
}
}
impl Default for JsonConstraint {
fn default() -> Self {
Self::new()
}
}
impl TokenConstraint for JsonConstraint {
fn allowed_tokens(&self, _generated: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
if self.state == JsonParseState::Error {
return Some(vec![false; vocab_size]);
}
let valid = self.valid_next_chars();
let mask: Vec<bool> = (0..vocab_size)
.map(|id| {
let ch = char::from_u32(id as u32).unwrap_or('\u{FFFD}');
ch as u32 > 127 || valid.contains(&ch)
})
.collect();
Some(mask)
}
fn advance(&mut self, token: u32) -> bool {
if self.state == JsonParseState::Error {
return false;
}
if let Some(ch) = char::from_u32(token) {
self.feed_char(ch);
}
self.state != JsonParseState::Error
}
fn is_complete(&self) -> bool {
self.state == JsonParseState::Complete
}
fn reset(&mut self) {
*self = Self::new();
}
fn name(&self) -> &str {
"JsonConstraint"
}
}
pub struct ConstrainedSampler {
inner: crate::sampling_advanced::SamplerChain,
constraint: Box<dyn TokenConstraint>,
generated: Vec<u32>,
vocab_size: usize,
}
impl ConstrainedSampler {
pub fn new(
sampler: crate::sampling_advanced::SamplerChain,
constraint: Box<dyn TokenConstraint>,
vocab_size: usize,
) -> Self {
Self {
inner: sampler,
constraint,
generated: Vec::new(),
vocab_size,
}
}
pub fn sample(&mut self, logits: &mut Vec<f32>) -> u32 {
if let Some(mask) = self
.constraint
.allowed_tokens(&self.generated, self.vocab_size)
{
for (i, allowed) in mask.iter().enumerate() {
if i < logits.len() && !allowed {
logits[i] = -1e9;
}
}
}
let token = self.inner.sample(logits) as u32;
self.constraint.advance(token);
self.generated.push(token);
token
}
pub fn is_complete(&self) -> bool {
self.constraint.is_complete()
}
pub fn reset(&mut self) {
self.generated.clear();
self.constraint.reset();
}
pub fn generated_text_len(&self) -> usize {
self.generated.len()
}
pub fn constraint_name(&self) -> &str {
self.constraint.name()
}
}
pub struct ConstrainedSamplerBuilder {
vocab_size: usize,
seed: u64,
}
impl ConstrainedSamplerBuilder {
pub fn new(vocab_size: usize, seed: u64) -> Self {
Self { vocab_size, seed }
}
fn default_chain(&self) -> crate::sampling_advanced::SamplerChain {
crate::sampling_advanced::SamplerChain::new(self.seed)
}
pub fn with_json_constraint(self) -> ConstrainedSampler {
ConstrainedSampler::new(
self.default_chain(),
Box::new(JsonConstraint::new()),
self.vocab_size,
)
}
pub fn with_regex_constraint(
self,
pattern: &str,
) -> Result<ConstrainedSampler, ConstraintError> {
let constraint = RegexConstraint::new(pattern)?;
let chain = self.default_chain();
Ok(ConstrainedSampler::new(
chain,
Box::new(constraint),
self.vocab_size,
))
}
pub fn unconstrained(self) -> ConstrainedSampler {
ConstrainedSampler::new(
self.default_chain(),
Box::new(NoConstraint),
self.vocab_size,
)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn no_constraint_allows_all() {
let nc = NoConstraint;
assert!(nc.allowed_tokens(&[], 10).is_none());
}
#[test]
fn json_constraint_initial_state() {
let jc = JsonConstraint::new();
assert_eq!(*jc.current_state(), JsonParseState::Start);
assert_eq!(jc.depth(), 0);
}
#[test]
fn json_constraint_valid_object_chars() {
let jc = JsonConstraint::new();
let valid = jc.valid_next_chars();
assert!(valid.contains(&'{'));
assert!(valid.contains(&'['));
assert!(valid.contains(&'"'));
}
#[test]
fn json_constraint_tracks_depth() {
let mut jc = JsonConstraint::new();
jc.advance('{' as u32);
assert_eq!(jc.depth(), 1);
jc.advance('"' as u32);
jc.advance('k' as u32);
jc.advance('"' as u32);
jc.advance(':' as u32);
jc.advance('{' as u32);
assert_eq!(jc.depth(), 2);
jc.advance('}' as u32);
assert_eq!(jc.depth(), 1);
}
#[test]
fn json_constraint_detects_completion() {
let mut jc = JsonConstraint::new();
assert!(!jc.is_complete());
jc.advance('{' as u32);
jc.advance('}' as u32);
assert!(jc.is_complete());
}
#[test]
fn json_constraint_in_string_state() {
let mut jc = JsonConstraint::new();
jc.advance('"' as u32);
assert!(jc.is_in_string());
jc.advance('"' as u32);
assert!(!jc.is_in_string());
}
#[test]
fn regex_nfa_literal_match() {
let nfa = RegexNfa::from_pattern("abc").expect("valid pattern");
assert!(nfa.is_full_match("abc"));
assert!(!nfa.is_full_match("ab"));
assert!(!nfa.is_full_match("abcd"));
}
#[test]
fn regex_nfa_dot_match() {
let nfa = RegexNfa::from_pattern("a.c").expect("valid pattern");
assert!(nfa.is_full_match("abc"));
assert!(nfa.is_full_match("axc"));
assert!(!nfa.is_full_match("ac"));
}
#[test]
fn regex_nfa_star_quantifier() {
let nfa = RegexNfa::from_pattern("ab*c").expect("valid pattern");
assert!(nfa.is_full_match("ac"));
assert!(nfa.is_full_match("abc"));
assert!(nfa.is_full_match("abbc"));
assert!(!nfa.is_full_match("xbc"));
}
#[test]
fn regex_nfa_alternation() {
let nfa = RegexNfa::from_pattern("cat|dog").expect("valid pattern");
assert!(nfa.is_full_match("cat"));
assert!(nfa.is_full_match("dog"));
assert!(!nfa.is_full_match("cow"));
}
#[test]
fn regex_constraint_is_match() {
assert!(RegexConstraint::is_match("he+llo", "hello"));
assert!(RegexConstraint::is_match("he+llo", "heeeello"));
assert!(!RegexConstraint::is_match("he+llo", "hllo"));
}
#[test]
fn regex_constraint_allows_valid_chars() {
let rc = RegexConstraint::new("abc").expect("valid");
assert!(rc.char_is_valid('a'));
assert!(!rc.char_is_valid('b')); }
#[test]
fn constrained_sampler_masks_logits() {
struct AllowEvens;
impl TokenConstraint for AllowEvens {
fn allowed_tokens(&self, _: &[u32], vocab_size: usize) -> Option<Vec<bool>> {
Some((0..vocab_size).map(|i| i % 2 == 0).collect())
}
fn advance(&mut self, _: u32) -> bool {
true
}
fn is_complete(&self) -> bool {
true
}
fn reset(&mut self) {}
fn name(&self) -> &str {
"AllowEvens"
}
}
let chain = crate::sampling_advanced::SamplerChain::greedy();
let mut sampler = ConstrainedSampler::new(chain, Box::new(AllowEvens), 4);
let mut logits = vec![2.0_f32, 10.0, 1.0, 0.5];
let tok = sampler.sample(&mut logits);
assert_eq!(tok, 0);
}
#[test]
fn constrained_sampler_greedy_json() {
let chain = crate::sampling_advanced::SamplerChain::greedy();
let mut sampler = ConstrainedSampler::new(chain, Box::new(JsonConstraint::new()), 256);
assert!(!sampler.is_complete());
let mut logits_open = vec![0.0_f32; 256];
logits_open['{' as usize] = 100.0;
sampler.sample(&mut logits_open);
let mut logits_close = vec![0.0_f32; 256];
logits_close['}' as usize] = 100.0;
sampler.sample(&mut logits_close);
assert!(sampler.is_complete());
assert_eq!(sampler.generated_text_len(), 2);
}
#[test]
fn constrained_sampler_reset() {
let chain = crate::sampling_advanced::SamplerChain::greedy();
let mut sampler = ConstrainedSampler::new(chain, Box::new(JsonConstraint::new()), 256);
let mut logits = vec![0.0_f32; 256];
logits['{' as usize] = 100.0;
sampler.sample(&mut logits);
assert_eq!(sampler.generated_text_len(), 1);
sampler.reset();
assert_eq!(sampler.generated_text_len(), 0);
}
#[test]
fn constrained_sampler_builder_json() {
let sampler = ConstrainedSamplerBuilder::new(256, 42).with_json_constraint();
assert_eq!(sampler.constraint_name(), "JsonConstraint");
}
#[test]
fn constrained_sampler_builder_unconstrained() {
let sampler = ConstrainedSamplerBuilder::new(256, 42).unconstrained();
assert_eq!(sampler.constraint_name(), "NoConstraint");
assert!(sampler.is_complete());
}
#[test]
fn constraint_error_display() {
let e = ConstraintError::InvalidPattern("bad".to_string());
assert!(e.to_string().contains("bad"));
let e2 = ConstraintError::Violated {
token: 5,
reason: "oops".to_string(),
};
assert!(e2.to_string().contains("5"));
assert!(e2.to_string().contains("oops"));
}
}