use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use thiserror::Error;
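/// Errors surfaced while compiling grammars or computing token masks.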
#[derive(Error, Debug)]
pub enum ConstrainedError {
#[error("Invalid grammar: {0}")]
InvalidGrammar(String),
#[error("Grammar compilation failed: {0}")]
CompilationError(String),
#[error("Token mask error: {0}")]
MaskError(String),
#[error("Trie construction error: {0}")]
TrieError(String),
#[error("JSON schema error: {0}")]
JsonSchemaError(String),
#[error("Regex error: {0}")]
RegexError(String),
#[error("State machine error: {0}")]
StateError(String),
#[error("No valid tokens: grammar has no accepting continuation")]
NoValidTokens,
}
pub type Result<T> = std::result::Result<T, ConstrainedError>;
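/// A bitset over the vocabulary: bit `i` is set when token `i` is allowed.
/// Bits are packed into `u64` words, so a 50k-token vocabulary costs about
/// 6 KiB and set/test operations are O(1).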
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TokenMask {
bits: Vec<u64>,
len: usize,
}
impl TokenMask {
#[inline]
pub fn new(vocab_size: usize) -> Self {
let num_words = vocab_size.div_ceil(64);
Self {
bits: vec![0u64; num_words],
len: vocab_size,
}
}
#[inline]
pub fn new_all_allowed(vocab_size: usize) -> Self {
let num_words = vocab_size.div_ceil(64);
let mut bits = vec![!0u64; num_words];
let excess = vocab_size % 64;
if excess != 0 && !bits.is_empty() {
let last_idx = bits.len() - 1;
bits[last_idx] = (1u64 << excess) - 1;
}
Self {
bits,
len: vocab_size,
}
}
#[inline]
pub fn set(&mut self, idx: usize, allowed: bool) {
debug_assert!(
idx < self.len,
"Token index {} out of bounds (vocab_size={})",
idx,
self.len
);
let word_idx = idx / 64;
let bit_idx = idx % 64;
let bit_mask = 1u64 << bit_idx;
if allowed {
self.bits[word_idx] |= bit_mask;
} else {
self.bits[word_idx] &= !bit_mask;
}
}
#[inline]
pub fn is_allowed(&self, idx: usize) -> bool {
if idx >= self.len {
return false;
}
let word_idx = idx / 64;
let bit_idx = idx % 64;
let bit_mask = 1u64 << bit_idx;
(self.bits[word_idx] & bit_mask) != 0
}
#[inline]
pub fn as_slice(&self) -> &[u64] {
&self.bits
}
#[inline]
pub fn as_mut_slice(&mut self) -> &mut [u64] {
&mut self.bits
}
#[inline]
pub fn count_allowed(&self) -> usize {
self.bits.iter().map(|w| w.count_ones() as usize).sum()
}
#[inline]
pub fn vocab_size(&self) -> usize {
self.len
}
#[inline]
pub fn has_any_allowed(&self) -> bool {
self.bits.iter().any(|&w| w != 0)
}
#[inline]
pub fn all_allowed(&self) -> bool {
self.count_allowed() == self.len
}
#[inline]
pub fn intersect(&self, other: &Self) -> Self {
debug_assert_eq!(self.len, other.len, "Mask sizes must match");
let bits: Vec<u64> = self
.bits
.iter()
.zip(other.bits.iter())
.map(|(&a, &b)| a & b)
.collect();
Self {
bits,
len: self.len,
}
}
#[inline]
pub fn union(&self, other: &Self) -> Self {
debug_assert_eq!(self.len, other.len, "Mask sizes must match");
let bits: Vec<u64> = self
.bits
.iter()
.zip(other.bits.iter())
.map(|(&a, &b)| a | b)
.collect();
Self {
bits,
len: self.len,
}
}
#[inline]
pub fn invert(&self) -> Self {
let mut bits: Vec<u64> = self.bits.iter().map(|&w| !w).collect();
let excess = self.len % 64;
if excess != 0 && !bits.is_empty() {
let last_idx = bits.len() - 1;
bits[last_idx] &= (1u64 << excess) - 1;
}
Self {
bits,
len: self.len,
}
}
#[inline]
pub fn clear(&mut self) {
self.bits.fill(0);
}
#[inline]
pub fn allow_all(&mut self) {
self.bits.fill(!0u64);
let excess = self.len % 64;
if excess != 0 && !self.bits.is_empty() {
let last_idx = self.bits.len() - 1;
self.bits[last_idx] = (1u64 << excess) - 1;
}
}
#[inline]
pub fn allowed_tokens(&self) -> impl Iterator<Item = usize> + '_ {
let len = self.len;
self.bits
.iter()
.enumerate()
.flat_map(move |(word_idx, &word)| {
(0..64).filter_map(move |bit_idx| {
let token_idx = word_idx * 64 + bit_idx;
if token_idx < len && (word & (1u64 << bit_idx)) != 0 {
Some(token_idx)
} else {
None
}
})
})
}
}
impl Default for TokenMask {
fn default() -> Self {
Self::new(0)
}
}
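/// One node of the byte-level trie; `token_id` is set when the path from the
/// root to this node spells out a complete token.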
#[derive(Debug, Clone, Default)]
struct TrieNode {
children: HashMap<u8, usize>,
token_id: Option<usize>,
is_terminal: bool,
}
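/// A byte-level trie over the tokenizer vocabulary. It maps token byte
/// sequences to token ids and answers prefix queries such as "which tokens
/// start with these bytes?", the core primitive for building token masks.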
#[derive(Debug, Clone)]
pub struct TokTrie {
nodes: Vec<TrieNode>,
vocab_size: usize,
token_bytes: HashMap<usize, Vec<u8>>,
}
impl TokTrie {
pub fn from_vocab<I, B>(vocab: I) -> Self
where
I: IntoIterator<Item = (usize, B)>,
B: AsRef<[u8]>,
{
let mut nodes = vec![TrieNode::default()]; // index 0 is the root
let mut token_bytes = HashMap::new();
// Track the maximum id as an Option so an empty vocabulary yields
// vocab_size 0 rather than 1.
let mut max_token_id: Option<usize> = None;
for (token_id, bytes) in vocab {
let bytes = bytes.as_ref();
token_bytes.insert(token_id, bytes.to_vec());
max_token_id = Some(max_token_id.map_or(token_id, |m| m.max(token_id)));
let mut current = 0;
for &byte in bytes {
let next = nodes[current].children.get(&byte).copied();
current = match next {
Some(idx) => idx,
None => {
let new_idx = nodes.len();
nodes.push(TrieNode::default());
nodes[current].children.insert(byte, new_idx);
new_idx
}
};
}
nodes[current].token_id = Some(token_id);
nodes[current].is_terminal = true;
}
Self {
nodes,
vocab_size: max_token_id.map_or(0, |id| id + 1),
token_bytes,
}
}
#[inline]
pub fn vocab_size(&self) -> usize {
self.vocab_size
}
pub fn find_tokens_with_prefix(&self, prefix: &[u8]) -> Vec<usize> {
let mut current = 0;
for &byte in prefix {
match self.nodes[current].children.get(&byte) {
Some(&next) => current = next,
None => return Vec::new(),
}
}
let mut result = Vec::new();
self.collect_tokens(current, &mut result);
result
}
fn collect_tokens(&self, node_idx: usize, result: &mut Vec<usize>) {
let node = &self.nodes[node_idx];
if let Some(token_id) = node.token_id {
result.push(token_id);
}
for &child_idx in node.children.values() {
self.collect_tokens(child_idx, result);
}
}
#[inline]
pub fn get_token_bytes(&self, token_id: usize) -> Option<&[u8]> {
self.token_bytes.get(&token_id).map(|v| v.as_slice())
}
pub fn is_valid_token(&self, bytes: &[u8]) -> Option<usize> {
let mut current = 0;
for &byte in bytes {
match self.nodes[current].children.get(&byte) {
Some(&next) => current = next,
None => return None,
}
}
self.nodes[current].token_id
}
pub fn valid_next_bytes(&self, prefix: &[u8]) -> Vec<u8> {
let mut current = 0;
for &byte in prefix {
match self.nodes[current].children.get(&byte) {
Some(&next) => current = next,
None => return Vec::new(),
}
}
self.nodes[current].children.keys().copied().collect()
}
}
impl Default for TokTrie {
fn default() -> Self {
Self {
nodes: vec![TrieNode::default()],
vocab_size: 0,
token_bytes: HashMap::new(),
}
}
}
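/// The constraint formalisms a generator can be configured with. GBNF and
/// CFG are accepted but not yet enforced (see `is_valid_continuation`).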
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum GrammarType {
#[serde(rename = "json_schema")]
JsonSchema(serde_json::Value),
#[serde(rename = "regex")]
Regex(String),
#[serde(rename = "gbnf")]
GBNF(String),
#[serde(rename = "cfg")]
CFG(String),
}
impl GrammarType {
pub fn description(&self) -> &'static str {
match self {
Self::JsonSchema(_) => "JSON Schema",
Self::Regex(_) => "Regular Expression",
Self::GBNF(_) => "GBNF Grammar",
Self::CFG(_) => "Context-Free Grammar",
}
}
}
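/// Builder-style configuration for a `ConstrainedGenerator`.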
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConstraintConfig {
pub grammar_type: GrammarType,
pub max_tokens: Option<usize>,
pub strict_mode: bool,
pub allow_partial: bool,
}
impl ConstraintConfig {
pub fn new(grammar_type: GrammarType) -> Self {
Self {
grammar_type,
max_tokens: None,
strict_mode: true,
allow_partial: false,
}
}
pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
self.max_tokens = Some(max_tokens);
self
}
pub fn with_strict_mode(mut self, strict: bool) -> Self {
self.strict_mode = strict;
self
}
pub fn with_allow_partial(mut self, allow: bool) -> Self {
self.allow_partial = allow;
self
}
}
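/// Minimal incremental parser state. `derivative_cache` is reserved for a
/// future Brzozowski-derivative implementation and is currently unused.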
#[derive(Debug, Clone, Default)]
struct GrammarState {
position: usize,
is_accepting: bool,
#[allow(dead_code)]
derivative_cache: HashMap<u8, Box<GrammarState>>,
}
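/// Drives constrained decoding: given a vocabulary trie and a grammar, it
/// computes per-step token masks and tracks the bytes generated so far.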
#[derive(Debug)]
pub struct ConstrainedGenerator {
trie: TokTrie,
state: GrammarState,
config: ConstraintConfig,
generated_bytes: Vec<u8>,
token_count: usize,
#[allow(dead_code)]
compiled_regex: Option<regex::Regex>,
}
impl ConstrainedGenerator {
pub fn new(config: ConstraintConfig) -> Result<Self> {
let compiled_regex = match &config.grammar_type {
GrammarType::Regex(pattern) => {
let re = regex::Regex::new(pattern)
.map_err(|e| ConstrainedError::RegexError(e.to_string()))?;
Some(re)
}
_ => None,
};
Ok(Self {
trie: TokTrie::default(),
state: GrammarState::default(),
config,
generated_bytes: Vec::new(),
token_count: 0,
compiled_regex,
})
}
pub fn set_vocabulary<I, B>(&mut self, vocab: I)
where
I: IntoIterator<Item = (usize, B)>,
B: AsRef<[u8]>,
{
self.trie = TokTrie::from_vocab(vocab);
}
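/// Computes the set of tokens that may be emitted next. A sampler would
/// typically set the logits of disallowed tokens to `-inf` before softmax.
/// `_current_tokens` is reserved for stateful grammar backends. Note that
/// this re-validates every vocabulary entry against the full generated
/// prefix, so it is O(vocab_size * prefix_len) per step; production engines
/// memoize this walk through the trie instead. An empty mask while the
/// state is accepting means the caller should emit EOS.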
pub fn compute_mask(&self, _current_tokens: &[usize]) -> Result<TokenMask> {
let vocab_size = self.trie.vocab_size();
if vocab_size == 0 {
return Err(ConstrainedError::MaskError(
"Vocabulary not set - call set_vocabulary() first".to_string(),
));
}
let mut mask = TokenMask::new(vocab_size);
for (token_id, token_bytes) in &self.trie.token_bytes {
if self.is_valid_continuation(token_bytes) {
mask.set(*token_id, true);
}
}
if !mask.has_any_allowed() && !self.state.is_accepting {
return Err(ConstrainedError::NoValidTokens);
}
Ok(mask)
}
fn is_valid_continuation(&self, bytes: &[u8]) -> bool {
let mut test_bytes = self.generated_bytes.clone();
test_bytes.extend_from_slice(bytes);
match &self.config.grammar_type {
GrammarType::JsonSchema(_schema) => {
self.is_valid_json_prefix(&test_bytes)
}
GrammarType::Regex(pattern) => {
self.is_valid_regex_prefix(&test_bytes, pattern)
}
// GBNF and CFG parsing are not implemented yet; allow every
// continuation (an over-approximation) rather than blocking decoding.
GrammarType::GBNF(_grammar) | GrammarType::CFG(_grammar) => true,
}
}
fn is_valid_json_prefix(&self, bytes: &[u8]) -> bool {
let s = match std::str::from_utf8(bytes) {
Ok(s) => s,
// An incomplete multi-byte sequence at the tail may still become valid
// UTF-8 once more token bytes arrive, so only reject when the invalid
// bytes are not at the end.
Err(e) if e.error_len().is_none() => {
std::str::from_utf8(&bytes[..e.valid_up_to()]).unwrap_or("")
}
Err(_) => return false,
};
if s.is_empty() {
return true;
}
if serde_json::from_str::<serde_json::Value>(s).is_ok() {
return true;
}
self.is_plausible_json_prefix(s)
}
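// Shallow heuristic: checks only that delimiters are balanced and the text
// starts like a JSON value; it does not consult the schema, so it can admit
// prefixes the schema would ultimately reject.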
fn is_plausible_json_prefix(&self, s: &str) -> bool {
let trimmed = s.trim_start();
if trimmed.is_empty() {
return true;
}
let first = trimmed.chars().next().unwrap();
match first {
'{' | '[' | '"' | 't' | 'f' | 'n' | '-' | '0'..='9' => {}
_ => return false,
}
let mut brace_depth = 0i32;
let mut bracket_depth = 0i32;
let mut in_string = false;
let mut escape_next = false;
for c in trimmed.chars() {
if escape_next {
escape_next = false;
continue;
}
match c {
'\\' if in_string => escape_next = true,
'"' => in_string = !in_string,
'{' if !in_string => brace_depth += 1,
'}' if !in_string => brace_depth -= 1,
'[' if !in_string => bracket_depth += 1,
']' if !in_string => bracket_depth -= 1,
_ => {}
}
if brace_depth < 0 || bracket_depth < 0 {
return false;
}
}
true
}
fn is_valid_regex_prefix(&self, bytes: &[u8], _pattern: &str) -> bool {
let s = match std::str::from_utf8(bytes) {
Ok(s) => s,
Err(_) => return false,
};
if let Some(re) = &self.compiled_regex {
// Over-approximation: `is_match` already subsumes `find(..).is_some()`,
// and a true prefix test would walk a DFA (e.g. via regex-automata) to
// check whether some extension of `s` can still reach a match.
s.is_empty() || re.is_match(s)
} else {
true
}
}
pub fn advance(&mut self, token_id: usize) -> Result<()> {
if let Some(max) = self.config.max_tokens {
if self.token_count >= max {
return Err(ConstrainedError::StateError(
"Maximum token limit reached".to_string(),
));
}
}
let bytes = self
.trie
.get_token_bytes(token_id)
.ok_or_else(|| ConstrainedError::StateError(format!("Unknown token ID: {}", token_id)))?
.to_vec();
self.generated_bytes.extend_from_slice(&bytes);
self.token_count += 1;
self.update_state(&bytes)?;
Ok(())
}
fn update_state(&mut self, bytes: &[u8]) -> Result<()> {
self.state.position += bytes.len();
self.state.is_accepting = self.check_accepting();
Ok(())
}
fn check_accepting(&self) -> bool {
match &self.config.grammar_type {
GrammarType::JsonSchema(schema) => {
let s = match std::str::from_utf8(&self.generated_bytes) {
Ok(s) => s,
Err(_) => return false,
};
if let Ok(value) = serde_json::from_str::<serde_json::Value>(s) {
self.validate_json_schema(&value, schema)
} else {
false
}
}
GrammarType::Regex(_) => {
if let Some(re) = &self.compiled_regex {
let s = std::str::from_utf8(&self.generated_bytes).unwrap_or("");
// `is_match` accepts a match anywhere in the text; acceptance
// requires the whole generated output to match the pattern.
re.find(s).is_some_and(|m| m.start() == 0 && m.end() == s.len())
} else {
false
}
}
GrammarType::GBNF(_) | GrammarType::CFG(_) => {
// No parser is wired up for these grammar types yet, so they never
// reach an accepting state.
false
}
}
}
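// Validates only the `type` and `required` keywords, a small subset of
// JSON Schema; nested `properties`, `enum`, bounds, etc. are not enforced.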
fn validate_json_schema(&self, value: &serde_json::Value, schema: &serde_json::Value) -> bool {
if let Some(type_constraint) = schema.get("type") {
let type_str = type_constraint.as_str().unwrap_or("");
let matches = match type_str {
"object" => value.is_object(),
"array" => value.is_array(),
"string" => value.is_string(),
"number" => value.is_number(),
"integer" => value.is_i64() || value.is_u64(),
"boolean" => value.is_boolean(),
"null" => value.is_null(),
_ => true,
};
if !matches {
return false;
}
}
if let (Some(obj), Some(required)) = (value.as_object(), schema.get("required")) {
if let Some(required_arr) = required.as_array() {
for req in required_arr {
if let Some(key) = req.as_str() {
if !obj.contains_key(key) {
return false;
}
}
}
}
}
true
}
#[inline]
pub fn is_complete(&self) -> bool {
self.state.is_accepting
}
#[inline]
pub fn generated_bytes(&self) -> &[u8] {
&self.generated_bytes
}
#[inline]
pub fn token_count(&self) -> usize {
self.token_count
}
pub fn reset(&mut self) {
self.state = GrammarState::default();
self.generated_bytes.clear();
self.token_count = 0;
}
#[inline]
pub fn config(&self) -> &ConstraintConfig {
&self.config
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_token_mask_new() {
let mask = TokenMask::new(1000);
assert_eq!(mask.vocab_size(), 1000);
assert_eq!(mask.count_allowed(), 0);
assert!(!mask.has_any_allowed());
}
#[test]
fn test_token_mask_all_allowed() {
let mask = TokenMask::new_all_allowed(100);
assert_eq!(mask.count_allowed(), 100);
assert!(mask.all_allowed());
}
#[test]
fn test_token_mask_set_get() {
let mut mask = TokenMask::new(1000);
mask.set(0, true);
mask.set(42, true);
mask.set(999, true);
assert!(mask.is_allowed(0));
assert!(mask.is_allowed(42));
assert!(mask.is_allowed(999));
assert!(!mask.is_allowed(1));
assert!(!mask.is_allowed(500));
assert_eq!(mask.count_allowed(), 3);
}
#[test]
fn test_token_mask_out_of_bounds() {
let mask = TokenMask::new(100);
assert!(!mask.is_allowed(100));
assert!(!mask.is_allowed(1000));
}
#[test]
fn test_token_mask_intersect() {
let mut mask1 = TokenMask::new(100);
let mut mask2 = TokenMask::new(100);
mask1.set(1, true);
mask1.set(2, true);
mask1.set(3, true);
mask2.set(2, true);
mask2.set(3, true);
mask2.set(4, true);
let intersection = mask1.intersect(&mask2);
assert!(!intersection.is_allowed(1));
assert!(intersection.is_allowed(2));
assert!(intersection.is_allowed(3));
assert!(!intersection.is_allowed(4));
}
#[test]
fn test_token_mask_union() {
let mut mask1 = TokenMask::new(100);
let mut mask2 = TokenMask::new(100);
mask1.set(1, true);
mask2.set(2, true);
let union = mask1.union(&mask2);
assert!(union.is_allowed(1));
assert!(union.is_allowed(2));
assert_eq!(union.count_allowed(), 2);
}
#[test]
fn test_token_mask_invert() {
let mut mask = TokenMask::new(100);
mask.set(0, true);
mask.set(50, true);
let inverted = mask.invert();
assert!(!inverted.is_allowed(0));
assert!(!inverted.is_allowed(50));
assert!(inverted.is_allowed(1));
assert!(inverted.is_allowed(99));
assert_eq!(inverted.count_allowed(), 98);
}
#[test]
fn test_token_mask_iterator() {
let mut mask = TokenMask::new(200);
mask.set(5, true);
mask.set(64, true);
mask.set(128, true);
let allowed: Vec<_> = mask.allowed_tokens().collect();
assert_eq!(allowed, vec![5, 64, 128]);
}
#[test]
fn test_tok_trie_construction() {
let vocab = vec![
(0, b"hello".to_vec()),
(1, b"world".to_vec()),
(2, b"hel".to_vec()),
];
let trie = TokTrie::from_vocab(vocab);
assert_eq!(trie.vocab_size(), 3);
}
#[test]
fn test_tok_trie_prefix_search() {
let vocab = vec![
(0, b"hello".to_vec()),
(1, b"help".to_vec()),
(2, b"world".to_vec()),
(3, b"hel".to_vec()),
];
let trie = TokTrie::from_vocab(vocab);
let matches = trie.find_tokens_with_prefix(b"hel");
assert!(matches.contains(&0));
assert!(matches.contains(&1));
assert!(matches.contains(&3));
assert!(!matches.contains(&2));
}
#[test]
fn test_tok_trie_token_lookup() {
let vocab = vec![(42, b"test".to_vec())];
let trie = TokTrie::from_vocab(vocab);
assert_eq!(trie.get_token_bytes(42), Some(b"test".as_slice()));
assert_eq!(trie.get_token_bytes(0), None);
}
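// Exercises the two trie queries not covered above. The toy vocabulary
// here is illustrative, not taken from any real tokenizer.
#[test]
fn test_tok_trie_next_bytes_and_exact_match() {
let vocab = vec![(0usize, b"cat".to_vec()), (1, b"car".to_vec()), (2, b"ca".to_vec())];
let trie = TokTrie::from_vocab(vocab);
// After "ca", the trie can continue with 'r' or 't'.
let mut next = trie.valid_next_bytes(b"ca");
next.sort_unstable();
assert_eq!(next, vec![b'r', b't']);
// Exact byte sequences resolve to their token ids; "c" is only an
// interior node, so it is not a token.
assert_eq!(trie.is_valid_token(b"ca"), Some(2));
assert_eq!(trie.is_valid_token(b"c"), None);
}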
#[test]
fn test_constraint_config_builder() {
let config = ConstraintConfig::new(GrammarType::Regex(r"\d+".to_string()))
.with_max_tokens(100)
.with_strict_mode(false);
assert_eq!(config.max_tokens, Some(100));
assert!(!config.strict_mode);
}
#[test]
fn test_constrained_generator_regex() {
let config = ConstraintConfig::new(GrammarType::Regex(r"[0-9]+".to_string()));
let generator = ConstrainedGenerator::new(config);
assert!(generator.is_ok());
}
#[test]
fn test_constrained_generator_json_schema() {
let schema = serde_json::json!({
"type": "object",
"properties": {
"name": { "type": "string" }
}
});
let config = ConstraintConfig::new(GrammarType::JsonSchema(schema));
let generator = ConstrainedGenerator::new(config);
assert!(generator.is_ok());
}
#[test]
fn test_json_prefix_validation() {
let schema = serde_json::json!({"type": "object"});
let config = ConstraintConfig::new(GrammarType::JsonSchema(schema));
let generator = ConstrainedGenerator::new(config).unwrap();
assert!(generator.is_valid_json_prefix(b""));
assert!(generator.is_valid_json_prefix(b"{"));
assert!(generator.is_valid_json_prefix(b"{\""));
assert!(generator.is_valid_json_prefix(b"{\"key\":"));
assert!(!generator.is_valid_json_prefix(b"}"));
assert!(!generator.is_valid_json_prefix(b"}{"));
}
#[test]
fn test_compute_mask_no_vocab() {
let config = ConstraintConfig::new(GrammarType::Regex(r".+".to_string()));
let generator = ConstrainedGenerator::new(config).unwrap();
let result = generator.compute_mask(&[]);
assert!(matches!(result, Err(ConstrainedError::MaskError(_))));
}
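// End-to-end sketch of the intended decode loop over a three-token toy
// vocabulary: compute the mask, "sample" (here: pick token 0 directly),
// advance, and check acceptance against the digit regex.
#[test]
fn test_constrained_generation_end_to_end() {
let config = ConstraintConfig::new(GrammarType::Regex(r"[0-9]+".to_string()));
let mut generator = ConstrainedGenerator::new(config).unwrap();
generator.set_vocabulary(vec![(0usize, b"1".to_vec()), (1, b"2".to_vec()), (2, b"a".to_vec())]);
let mask = generator.compute_mask(&[]).unwrap();
assert!(mask.is_allowed(0));
assert!(mask.is_allowed(1));
assert!(!mask.is_allowed(2));
generator.advance(0).unwrap();
assert_eq!(generator.generated_bytes(), b"1");
assert_eq!(generator.token_count(), 1);
assert!(generator.is_complete());
}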
}