use std::collections::HashSet;
/// A decoding constraint applied during token sampling.
///
/// `None` (the default) disables grammar-constrained decoding entirely.
#[derive(Debug, Clone, Default)]
pub enum Grammar {
    /// Constrain output to JSON, optionally guided by a schema.
    Json(JsonGrammar),
    /// Constrain output with a (simplified) regular expression.
    Regex(RegexGrammar),
    /// Constrain output with a GBNF-style BNF grammar.
    Gbnf(GbnfGrammar),
    /// Constrain output to exactly one of a fixed set of strings.
    Choice(Vec<String>),
    /// No constraint: every token is always allowed.
    #[default]
    None,
}
/// JSON output constraint: either "any well-formed JSON" or schema-guided.
#[derive(Debug, Clone)]
pub struct JsonGrammar {
    /// Optional JSON schema text the output should conform to.
    pub schema: Option<String>,
    /// When `true`, any well-formed JSON value is acceptable.
    pub allow_any: bool,
    /// Field names that must appear in the generated object.
    pub required_fields: Vec<String>,
}

impl Default for JsonGrammar {
    /// The default is the permissive configuration: no schema,
    /// no required fields, anything well-formed allowed.
    fn default() -> Self {
        Self {
            schema: None,
            allow_any: true,
            required_fields: Vec::default(),
        }
    }
}

impl JsonGrammar {
    /// Permissive constraint: any well-formed JSON is allowed.
    pub fn any() -> Self {
        Self::default()
    }

    /// Schema-guided constraint; turns off the "any JSON" escape hatch.
    pub fn with_schema(schema: impl Into<String>) -> Self {
        Self {
            schema: Some(schema.into()),
            allow_any: false,
            ..Self::default()
        }
    }
}
/// Character-level regex constraint.
///
/// This is a deliberately small approximation of a regex engine: it walks
/// the pattern one position at a time and answers "could `c` appear next?".
/// Quantifiers (`*`, `+`, `?`) and character classes (`[...]`) are only
/// recognized, not interpreted, and `advance` always moves exactly one
/// pattern position — so two-character escapes like `\d` fall out of sync
/// after the first accepted character. NOTE(review): this looks like an
/// intentional placeholder; confirm before relying on it for real patterns.
#[derive(Debug, Clone)]
pub struct RegexGrammar {
    /// The raw regex pattern being matched against.
    pub pattern: String,
    /// Matcher progress through `pattern`.
    state: RegexState,
}

/// Internal cursor state for `RegexGrammar`.
#[derive(Debug, Clone, Default)]
struct RegexState {
    /// Char index of the pattern position to be matched next.
    position: usize,
    _in_class: bool,
    _min_remaining: usize,
}

impl RegexGrammar {
    /// Creates a matcher positioned at the start of `pattern`.
    pub fn new(pattern: impl Into<String>) -> Self {
        Self {
            pattern: pattern.into(),
            state: RegexState::default(),
        }
    }

    /// Returns whether `c` is acceptable at the current pattern position.
    ///
    /// Fixed: the previous version compared the char-indexed cursor against
    /// the pattern's *byte* length and collected the whole pattern into a
    /// `Vec<char>` on every call. Streaming over `chars()` removes the O(n)
    /// allocation per call and treats end-of-pattern consistently; any
    /// position at or past the end simply rejects (returns `false`), which
    /// matches the old observable behavior.
    pub fn allows_char(&self, c: char) -> bool {
        let mut rest = self.pattern.chars().skip(self.state.position);
        match rest.next() {
            // '.' matches any character.
            Some('.') => true,
            // Escape: interpret the character following the backslash.
            Some('\\') => match rest.next() {
                Some('d') => c.is_ascii_digit(),
                Some('w') => c.is_alphanumeric() || c == '_',
                Some('s') => c.is_whitespace(),
                Some(next) => c == next,
                // Dangling backslash at the end of the pattern.
                None => false,
            },
            // Character classes are accepted wholesale (not interpreted).
            Some('[') => true,
            // Literal pattern character must match exactly.
            Some(pc) if pc == c => true,
            // A quantifier position accepts any character.
            Some('*' | '+' | '?') => true,
            // Past the end of the pattern, or a literal mismatch: reject.
            _ => false,
        }
    }

    /// Advances the cursor one pattern position (the consumed character is
    /// currently ignored).
    pub fn advance(&mut self, _c: char) {
        self.state.position += 1;
    }

    /// Rewinds the matcher to the start of the pattern.
    pub fn reset(&mut self) {
        self.state = RegexState::default();
    }
}
/// GBNF (llama.cpp-style BNF) grammar constraint.
#[derive(Debug, Clone)]
pub struct GbnfGrammar {
    /// All rules parsed from the grammar, in source order.
    pub rules: Vec<GbnfRule>,
    /// Name of the start rule (the first rule defined).
    pub root: String,
    /// Matcher state (placeholder; not consulted yet).
    state: GbnfState,
}

/// One named production: `name ::= alt1 | alt2 | ...`.
#[derive(Debug, Clone)]
pub struct GbnfRule {
    pub name: String,
    pub alternatives: Vec<GbnfAlternative>,
}

/// A single alternative: a sequence of elements matched in order.
#[derive(Debug, Clone)]
pub struct GbnfAlternative {
    pub elements: Vec<GbnfElement>,
}

/// An atom inside an alternative.
#[derive(Debug, Clone)]
pub enum GbnfElement {
    /// A quoted string literal.
    Literal(String),
    /// A reference to another rule by name.
    RuleRef(String),
    /// An inclusive character range, e.g. `[a-z]`.
    CharRange(char, char),
    /// An explicit set of characters, e.g. `[abc]`.
    CharClass(Vec<char>),
    Optional(Box<GbnfElement>),
    ZeroOrMore(Box<GbnfElement>),
    OneOrMore(Box<GbnfElement>),
}

/// Placeholder for a future pushdown-matcher state.
#[derive(Debug, Clone, Default)]
struct GbnfState {
    _stack: Vec<(String, usize, usize)>,
}

impl GbnfGrammar {
    /// Builds a grammar from pre-parsed rules; `root` names the start rule.
    pub fn new(rules: Vec<GbnfRule>, root: impl Into<String>) -> Self {
        Self {
            rules,
            root: root.into(),
            state: GbnfState::default(),
        }
    }

    /// Parses GBNF text, one `name ::= body` rule per line.
    ///
    /// Blank lines and `#` comment lines are skipped. The first rule defined
    /// becomes the root. Returns `Err` if no rules are found.
    pub fn parse(input: &str) -> Result<Self, String> {
        let mut rules = Vec::new();
        let mut root = String::new();
        for line in input.lines() {
            let line = line.trim();
            if line.is_empty() || line.starts_with('#') {
                continue;
            }
            if let Some(pos) = line.find("::=") {
                let name = line[..pos].trim().to_string();
                let body = line[pos + 3..].trim();
                if root.is_empty() {
                    root = name.clone();
                }
                let alternatives = Self::parse_alternatives(body)?;
                rules.push(GbnfRule { name, alternatives });
            }
        }
        if rules.is_empty() {
            return Err("No rules found in grammar".to_string());
        }
        Ok(Self::new(rules, root))
    }

    /// Splits a rule body on `|` into alternatives.
    ///
    /// NOTE(review): a `|` inside a quoted literal would also split here;
    /// confirm the grammars fed to this never quote pipes.
    fn parse_alternatives(body: &str) -> Result<Vec<GbnfAlternative>, String> {
        let mut alternatives = Vec::new();
        for alt in body.split('|') {
            let elements = Self::parse_elements(alt.trim())?;
            alternatives.push(GbnfAlternative { elements });
        }
        Ok(alternatives)
    }

    /// Maps a GBNF escape character to the character it denotes.
    fn unescape(c: char) -> char {
        match c {
            'n' => '\n',
            't' => '\t',
            'r' => '\r',
            // `\"`, `\\`, etc. denote themselves.
            other => other,
        }
    }

    /// Tokenizes one alternative into literals, char classes, and rule refs.
    fn parse_elements(body: &str) -> Result<Vec<GbnfElement>, String> {
        let mut elements = Vec::new();
        let mut chars = body.chars().peekable();
        while let Some(c) = chars.next() {
            match c {
                '"' => {
                    // Quoted literal; runs to the next unescaped quote.
                    let mut literal = String::new();
                    while let Some(&next) = chars.peek() {
                        if next == '"' {
                            chars.next();
                            break;
                        }
                        if next == '\\' {
                            chars.next(); // consume the backslash
                            if let Some(escaped) = chars.next() {
                                // Fixed: translate C-style escapes. The old
                                // code pushed the raw character, so `"\n"`
                                // produced the letter 'n' instead of a
                                // newline.
                                literal.push(Self::unescape(escaped));
                            }
                        } else {
                            literal.push(chars.next().unwrap());
                        }
                    }
                    elements.push(GbnfElement::Literal(literal));
                }
                '[' => {
                    // Character class; exactly `x-y` becomes a range.
                    // NOTE(review): escapes inside `[...]` are not handled.
                    let mut class_chars = Vec::new();
                    while let Some(&next) = chars.peek() {
                        if next == ']' {
                            chars.next();
                            break;
                        }
                        class_chars.push(chars.next().unwrap());
                    }
                    if class_chars.len() == 3 && class_chars[1] == '-' {
                        elements.push(GbnfElement::CharRange(class_chars[0], class_chars[2]));
                    } else {
                        elements.push(GbnfElement::CharClass(class_chars));
                    }
                }
                ' ' | '\t' => {
                    // Whitespace between elements is insignificant.
                }
                _ if c.is_alphabetic() || c == '_' => {
                    // Bare identifier: a reference to another rule.
                    let mut name = String::from(c);
                    while let Some(&next) = chars.peek() {
                        if next.is_alphanumeric() || next == '_' || next == '-' {
                            name.push(chars.next().unwrap());
                        } else {
                            break;
                        }
                    }
                    elements.push(GbnfElement::RuleRef(name));
                }
                // Anything else (quantifiers, parentheses, ...) is ignored.
                _ => {}
            }
        }
        Ok(elements)
    }

    /// Characters the grammar currently permits.
    ///
    /// Conservative placeholder: returns all printable ASCII regardless of
    /// grammar state, so GBNF filtering is effectively a charset whitelist.
    pub fn allowed_chars(&self) -> HashSet<char> {
        let mut allowed = HashSet::new();
        for c in ' '..='~' {
            allowed.insert(c);
        }
        allowed
    }

    /// Clears the (placeholder) matcher state.
    pub fn reset(&mut self) {
        self.state = GbnfState::default();
    }
}
/// Applies a [`Grammar`] to constrain which vocabulary tokens may be sampled.
#[derive(Debug)]
pub struct GrammarSampler {
    /// Active constraint.
    grammar: Grammar,
    /// Text generated so far (all recorded tokens concatenated).
    generated: String,
    /// Tokenizer vocabulary, indexed by token id.
    vocab: Vec<String>,
}

impl GrammarSampler {
    /// Creates a sampler with an empty generation history.
    pub fn new(grammar: Grammar, vocab: Vec<String>) -> Self {
        Self {
            grammar,
            generated: String::new(),
            vocab,
        }
    }

    /// Returns one `bool` per vocab entry: `true` if that token is currently
    /// allowed under the grammar.
    pub fn get_token_mask(&self) -> Vec<bool> {
        let mut mask = vec![true; self.vocab.len()];
        match &self.grammar {
            // No constraint: everything stays allowed.
            Grammar::None => {}
            Grammar::Json(_) => self.filter_json_tokens(&mut mask),
            Grammar::Regex(regex) => self.filter_regex_tokens(&mut mask, regex),
            Grammar::Gbnf(gbnf) => self.filter_gbnf_tokens(&mut mask, gbnf),
            Grammar::Choice(choices) => self.filter_choice_tokens(&mut mask, choices),
        }
        mask
    }

    /// Heuristic JSON filter based on bracket depth (not a full parser).
    fn filter_json_tokens(&self, mask: &mut [bool]) {
        let current = &self.generated;
        // Net open-bracket depth of the text generated so far.
        let depth = current.chars().filter(|&c| c == '{' || c == '[').count() as i32
            - current.chars().filter(|&c| c == '}' || c == ']').count() as i32;
        for (i, token) in self.vocab.iter().enumerate() {
            let would_be = format!("{}{}", current, token);
            let valid = if current.is_empty() {
                // The first token must open a JSON container (or be whitespace).
                token.trim_start().starts_with('{')
                    || token.trim_start().starts_with('[')
                    || token.trim().is_empty()
            } else if depth <= 0 && !current.trim().is_empty() {
                // Document already balanced/closed: only whitespace may follow.
                token.trim().is_empty()
            } else {
                // Otherwise just forbid closing more brackets than were opened.
                let new_depth = would_be.chars().filter(|&c| c == '{' || c == '[').count() as i32
                    - would_be.chars().filter(|&c| c == '}' || c == ']').count() as i32;
                new_depth >= 0
            };
            mask[i] = valid;
        }
    }

    /// Regex filter.
    ///
    /// Fixed: the previous version tested every character of a multi-character
    /// token against the SAME pattern position, so e.g. pattern "ab" rejected
    /// the valid token "ab". We now walk a cloned matcher through the token,
    /// advancing after each accepted character, leaving `self` untouched.
    fn filter_regex_tokens(&self, mask: &mut [bool], regex: &RegexGrammar) {
        for (i, token) in self.vocab.iter().enumerate() {
            let mut matcher = regex.clone();
            let mut allowed = true;
            for c in token.chars() {
                if !matcher.allows_char(c) {
                    allowed = false;
                    break;
                }
                matcher.advance(c);
            }
            mask[i] = allowed;
        }
    }

    /// GBNF filter; currently only a character-level whitelist check
    /// (see `GbnfGrammar::allowed_chars`).
    fn filter_gbnf_tokens(&self, mask: &mut [bool], gbnf: &GbnfGrammar) {
        let allowed_chars = gbnf.allowed_chars();
        for (i, token) in self.vocab.iter().enumerate() {
            mask[i] = token.chars().all(|c| allowed_chars.contains(&c));
        }
    }

    /// Choice filter: a token is allowed while generated-text-plus-token is
    /// still a prefix of (or already extends past) one of the choices.
    fn filter_choice_tokens(&self, mask: &mut [bool], choices: &[String]) {
        for (i, token) in self.vocab.iter().enumerate() {
            let would_be = format!("{}{}", self.generated, token);
            mask[i] = choices
                .iter()
                .any(|choice| choice.starts_with(&would_be) || would_be.starts_with(choice));
        }
    }

    /// Sets the logits of disallowed tokens to negative infinity.
    /// Extra logits beyond the vocab (or vice versa) are ignored.
    pub fn apply_mask(&self, logits: &mut [f32]) {
        let mask = self.get_token_mask();
        for (i, &allowed) in mask.iter().enumerate() {
            if !allowed && i < logits.len() {
                logits[i] = f32::NEG_INFINITY;
            }
        }
    }

    /// Records a token that was emitted.
    ///
    /// Fixed: the regex matcher's cursor was never advanced here, so regex
    /// filtering stayed stuck at the first pattern position for the whole
    /// generation. The cursor now moves one position per recorded character.
    pub fn record_token(&mut self, token: &str) {
        self.generated.push_str(token);
        if let Grammar::Regex(r) = &mut self.grammar {
            for c in token.chars() {
                r.advance(c);
            }
        }
    }

    /// Clears the generation history and any stateful grammar cursors.
    pub fn reset(&mut self) {
        self.generated.clear();
        match &mut self.grammar {
            Grammar::Regex(r) => r.reset(),
            Grammar::Gbnf(g) => g.reset(),
            _ => {}
        }
    }

    /// Whether the generated text already satisfies the grammar.
    /// Completion tracking for Regex and Gbnf is not implemented (always
    /// `false`), as is `None` (generation is never "done").
    pub fn is_complete(&self) -> bool {
        match &self.grammar {
            Grammar::None => false,
            Grammar::Json(_) => {
                let trimmed = self.generated.trim();
                (trimmed.starts_with('{') && trimmed.ends_with('}'))
                    || (trimmed.starts_with('[') && trimmed.ends_with(']'))
            }
            Grammar::Choice(choices) => choices.iter().any(|c| c == &self.generated),
            _ => false,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_json_grammar() {
        // The permissive constructor must leave `allow_any` enabled.
        let json = JsonGrammar::any();
        assert!(json.allow_any);
    }

    #[test]
    fn test_regex_grammar() {
        // A digit is acceptable at the start of `\d+`.
        let digits = RegexGrammar::new(r"\d+");
        assert!(digits.allows_char('5'));
    }

    #[test]
    fn test_gbnf_parse() {
        let source = r#"
root ::= "hello" | "world"
"#;
        let parsed = GbnfGrammar::parse(source).unwrap();
        assert_eq!(parsed.root, "root");
        assert_eq!(parsed.rules.len(), 1);
    }

    #[test]
    fn test_grammar_sampler_json() {
        let vocab: Vec<String> = ["{", "}", "hello"].iter().map(|s| s.to_string()).collect();
        let sampler = GrammarSampler::new(Grammar::Json(JsonGrammar::any()), vocab);
        // An opening brace must be permitted as the very first token.
        assert!(sampler.get_token_mask()[0]);
    }

    #[test]
    fn test_grammar_sampler_choice() {
        let choices = vec!["yes".to_string(), "no".to_string()];
        let vocab: Vec<String> = ["y", "n", "x"].iter().map(|s| s.to_string()).collect();
        let sampler = GrammarSampler::new(Grammar::Choice(choices), vocab);
        let mask = sampler.get_token_mask();
        // "y" and "n" are prefixes of choices; "x" is not.
        assert!(mask[0]);
        assert!(mask[1]);
        assert!(!mask[2]);
    }
}