use super::pcfg::{Symbol, WeightedCFG};
use std::collections::{HashMap, HashSet};
/// Options controlling grammar-constrained decoding.
#[derive(Debug, Clone)]
pub struct ConstrainedDecodingConfig {
    /// `GrammarConstraint::initialize` passes `max_lookahead * 2` to
    /// `EarleyChart::new` when it builds its chart.
    pub max_lookahead: usize,
    /// When `true`, `GrammarConstraint::valid_tokens` memoizes its result until
    /// the constraint advances or is reset.
    pub cache_states: bool,
    /// NOTE(review): not read by any code visible in this module — confirm
    /// whether it is consumed elsewhere.
    pub min_rule_probability: f64,
    /// NOTE(review): not read by any code visible in this module — confirm
    /// whether it is consumed elsewhere.
    pub allow_partial: bool,
}

impl Default for ConstrainedDecodingConfig {
    /// Lookahead 3, caching on, probability floor `1e-10`, partial input allowed.
    fn default() -> Self {
        ConstrainedDecodingConfig {
            allow_partial: true,
            cache_states: true,
            max_lookahead: 3,
            min_rule_probability: 1e-10,
        }
    }
}
/// One Earley item: a rule, how far into its right-hand side the match has
/// progressed, and the input position where the match began.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EarleyState {
    /// Index of the rule in the parser's rule table.
    pub rule_idx: usize,
    /// Number of right-hand-side symbols already matched (the "dot").
    pub dot_pos: usize,
    /// Input position at which this item's match started.
    pub start_pos: usize,
}

impl EarleyState {
    /// Assembles an item from its three components.
    pub fn new(rule_idx: usize, dot_pos: usize, start_pos: usize) -> Self {
        EarleyState {
            rule_idx,
            dot_pos,
            start_pos,
        }
    }

    /// `true` once the dot has reached (or passed) the end of a right-hand
    /// side containing `rhs_len` symbols.
    pub fn is_complete(&self, rhs_len: usize) -> bool {
        rhs_len <= self.dot_pos
    }
}
/// Flattened, index-addressable view of a `WeightedCFG` for Earley parsing.
pub struct EarleyParser {
    /// The source grammar; kept so its start symbol (and the grammar itself)
    /// stay reachable.
    grammar: WeightedCFG,
    /// Left-hand-side nonterminal -> indices of the rules it heads.
    rules_by_lhs: HashMap<String, Vec<usize>>,
    /// Rule table of `(lhs, rhs, weight)` triples, addressed by rule index.
    rules: Vec<(String, Vec<Symbol>, f64)>,
}

impl EarleyParser {
    /// Indexes every rule yielded by `grammar.iter_rules()`. A rule's index is
    /// its position in that iteration order.
    pub fn new(grammar: WeightedCFG) -> Self {
        let mut rules: Vec<(String, Vec<Symbol>, f64)> = Vec::new();
        let mut rules_by_lhs: HashMap<String, Vec<usize>> = HashMap::new();
        for (production, &weight) in grammar.iter_rules() {
            let lhs = production.lhs.clone();
            rules_by_lhs.entry(lhs.clone()).or_default().push(rules.len());
            rules.push((lhs, production.rhs.clone(), weight));
        }
        EarleyParser {
            grammar,
            rules_by_lhs,
            rules,
        }
    }

    /// Start symbol of the underlying grammar.
    pub fn start_symbol(&self) -> &str {
        self.grammar.start_symbol()
    }

    /// Indices of every rule whose left-hand side is `nt`, in indexing order.
    /// The iterator is empty when `nt` heads no rule.
    pub fn rules_for(&self, nt: &str) -> impl Iterator<Item = usize> + '_ {
        self.rules_by_lhs
            .get(nt)
            .into_iter()
            .flat_map(|indices| indices.iter().copied())
    }

    /// Looks up rule `idx` as `(lhs, rhs, weight)`, or `None` when out of range.
    pub fn rule(&self, idx: usize) -> Option<(&str, &[Symbol], f64)> {
        let (lhs, rhs, weight) = self.rules.get(idx)?;
        Some((lhs.as_str(), rhs.as_slice(), *weight))
    }
}
/// Per-position Earley item sets with constant-time duplicate detection.
#[derive(Debug, Clone, Default)]
pub struct EarleyChart {
    /// Items at each position in insertion order. The order matters because
    /// callers walk these lists as work queues while appending to them.
    states: Vec<Vec<EarleyState>>,
    /// Mirror of `states`, used to reject an item already present at a position.
    seen: Vec<HashSet<EarleyState>>,
}

impl EarleyChart {
    /// Creates a chart whose positions `0..=size` are all empty.
    pub fn new(size: usize) -> Self {
        let positions = size + 1;
        Self {
            states: (0..positions).map(|_| Vec::new()).collect(),
            seen: (0..positions).map(|_| HashSet::new()).collect(),
        }
    }

    /// Appends `state` to position `pos`. Returns `false` without changing
    /// anything when `pos` is beyond the chart or the item is already there.
    pub fn add(&mut self, pos: usize, state: EarleyState) -> bool {
        if pos >= self.states.len() || !self.seen[pos].insert(state.clone()) {
            return false;
        }
        self.states[pos].push(state);
        true
    }

    /// Items at `pos` in insertion order; empty for an out-of-range position.
    pub fn states_at(&self, pos: usize) -> impl Iterator<Item = &EarleyState> {
        self.states.get(pos).into_iter().flatten()
    }

    /// Number of items at `pos` (0 for an out-of-range position).
    pub fn state_count_at(&self, pos: usize) -> usize {
        self.states.get(pos).map_or(0, Vec::len)
    }

    /// The `idx`-th item inserted at `pos`, if both indices are in range.
    pub fn state_at_index(&self, pos: usize, idx: usize) -> Option<&EarleyState> {
        self.states.get(pos)?.get(idx)
    }

    /// Number of positions in the chart.
    pub fn len(&self) -> usize {
        self.states.len()
    }

    /// `true` when the chart has no positions at all.
    pub fn is_empty(&self) -> bool {
        self.states.is_empty()
    }
}
pub struct GrammarConstraint {
parser: EarleyParser,
config: ConstrainedDecodingConfig,
chart: EarleyChart,
position: usize,
valid_tokens_cache: Option<HashSet<String>>,
}
impl GrammarConstraint {
pub fn new(grammar: WeightedCFG, config: ConstrainedDecodingConfig) -> Self {
let mut constraint = Self {
parser: EarleyParser::new(grammar),
config,
chart: EarleyChart::new(0),
position: 0,
valid_tokens_cache: None,
};
constraint.initialize();
constraint
}
pub fn with_default_config(grammar: WeightedCFG) -> Self {
Self::new(grammar, ConstrainedDecodingConfig::default())
}
pub fn reset(&mut self) {
self.chart = EarleyChart::new(0);
self.position = 0;
self.valid_tokens_cache = None;
self.initialize();
}
pub fn initialize(&mut self) {
self.chart = EarleyChart::new(self.config.max_lookahead * 2);
let start = self.parser.start_symbol().to_string();
for rule_idx in self.parser.rules_for(&start) {
self.chart.add(0, EarleyState::new(rule_idx, 0, 0));
}
self.complete_chart(0);
}
fn complete_chart(&mut self, pos: usize) {
let mut state_idx = 0;
while state_idx < self.chart.state_count_at(pos) {
let (rule_idx, dot_pos, start_pos) = {
let state = match self.chart.state_at_index(pos, state_idx) {
Some(s) => s,
None => {
state_idx += 1;
continue;
}
};
(state.rule_idx, state.dot_pos, state.start_pos)
};
if let Some((_, rhs, _)) = self.parser.rule(rule_idx) {
if dot_pos < rhs.len() {
if let Symbol::NonTerminal(nt) = &rhs[dot_pos] {
for pred_rule_idx in self.parser.rules_for(nt) {
self.chart.add(pos, EarleyState::new(pred_rule_idx, 0, pos));
}
}
} else {
let (lhs, _, _) = self.parser.rule(rule_idx).expect("rule exists");
let completed_nt = lhs.to_string();
let waiting_count = self.chart.state_count_at(start_pos);
for waiting_idx in 0..waiting_count {
let (w_rule_idx, w_dot_pos, w_start_pos) = {
let waiting_state =
match self.chart.state_at_index(start_pos, waiting_idx) {
Some(s) => s,
None => continue,
};
(
waiting_state.rule_idx,
waiting_state.dot_pos,
waiting_state.start_pos,
)
};
if let Some((_, w_rhs, _)) = self.parser.rule(w_rule_idx) {
if w_dot_pos < w_rhs.len() {
if let Symbol::NonTerminal(nt) = &w_rhs[w_dot_pos] {
if *nt == completed_nt {
let new_state = EarleyState::new(
w_rule_idx,
w_dot_pos + 1,
w_start_pos,
);
self.chart.add(pos, new_state);
}
}
}
}
}
}
}
state_idx += 1;
}
}
pub fn is_valid_token(&self, token: &str) -> bool {
if let Some(valid) = &self.valid_tokens_cache {
return valid.contains(token);
}
for state in self.chart.states_at(self.position) {
if let Some((_, rhs, _)) = self.parser.rule(state.rule_idx) {
if state.dot_pos < rhs.len() {
if let Symbol::Terminal(t) = &rhs[state.dot_pos] {
if t == token {
return true;
}
}
}
}
}
false
}
pub fn valid_tokens(&mut self) -> HashSet<String> {
if let Some(cached) = &self.valid_tokens_cache {
return cached.clone();
}
let mut valid = HashSet::new();
for state in self.chart.states_at(self.position) {
if let Some((_, rhs, _)) = self.parser.rule(state.rule_idx) {
if state.dot_pos < rhs.len() {
if let Symbol::Terminal(t) = &rhs[state.dot_pos] {
valid.insert(t.clone());
}
}
}
}
if self.config.cache_states {
self.valid_tokens_cache = Some(valid.clone());
}
valid
}
pub fn advance(&mut self, token: &str) -> bool {
if !self.is_valid_token(token) {
return false;
}
let current_pos = self.position;
let next_pos = current_pos + 1;
if next_pos >= self.chart.len() {
return false;
}
let state_count = self.chart.state_count_at(current_pos);
for state_idx in 0..state_count {
let (rule_idx, dot_pos, start_pos) = {
let state = match self.chart.state_at_index(current_pos, state_idx) {
Some(s) => s,
None => continue,
};
(state.rule_idx, state.dot_pos, state.start_pos)
};
if let Some((_, rhs, _)) = self.parser.rule(rule_idx) {
if dot_pos < rhs.len() {
if let Symbol::Terminal(t) = &rhs[dot_pos] {
if t == token {
let new_state = EarleyState::new(rule_idx, dot_pos + 1, start_pos);
self.chart.add(next_pos, new_state);
}
}
}
}
}
self.position = next_pos;
self.valid_tokens_cache = None;
self.complete_chart(next_pos);
true
}
pub fn can_complete(&self) -> bool {
let start = self.parser.start_symbol();
for state in self.chart.states_at(self.position) {
if let Some((lhs, rhs, _)) = self.parser.rule(state.rule_idx) {
if lhs == start && state.start_pos == 0 && state.dot_pos >= rhs.len() {
return true;
}
}
}
false
}
pub fn position(&self) -> usize {
self.position
}
pub fn grammar(&self) -> &WeightedCFG {
&self.parser.grammar
}
}
/// Set of vocabulary indices permitted at a decoding step.
#[derive(Debug, Clone)]
pub struct TokenMask {
    /// Permitted indices; may include values >= `vocab_size`.
    allowed: HashSet<usize>,
    /// Length of the dense mask produced by `to_bool_vec`.
    vocab_size: usize,
}

impl TokenMask {
    /// Mask that permits every index in `0..vocab_size`.
    pub fn allow_all(vocab_size: usize) -> Self {
        let every_index: HashSet<usize> = (0..vocab_size).collect();
        Self::from_allowed(every_index, vocab_size)
    }

    /// Wraps an explicit set of permitted indices. The set is stored as given:
    /// indices >= `vocab_size` still count in `count_allowed` and `is_allowed`,
    /// but are dropped by `to_bool_vec`.
    pub fn from_allowed(allowed: HashSet<usize>, vocab_size: usize) -> Self {
        TokenMask {
            allowed,
            vocab_size,
        }
    }

    /// Whether index `idx` is permitted.
    pub fn is_allowed(&self, idx: usize) -> bool {
        self.allowed.contains(&idx)
    }

    /// Permitted indices, in unspecified (hash-set) order.
    pub fn allowed_indices(&self) -> impl Iterator<Item = usize> + '_ {
        self.allowed.iter().copied()
    }

    /// Number of permitted indices in the set.
    pub fn count_allowed(&self) -> usize {
        self.allowed.len()
    }

    /// Dense boolean mask of length `vocab_size`: `true` where permitted.
    pub fn to_bool_vec(&self) -> Vec<bool> {
        (0..self.vocab_size)
            .map(|idx| self.allowed.contains(&idx))
            .collect()
    }

    /// Sets every logit whose index is not permitted to `f32::NEG_INFINITY`.
    /// Permitted logits are left untouched.
    pub fn apply_to_logits(&self, logits: &mut [f32]) {
        logits
            .iter_mut()
            .enumerate()
            .filter(|(idx, _)| !self.allowed.contains(idx))
            .for_each(|(_, logit)| *logit = f32::NEG_INFINITY);
    }
}
/// Bidirectional mapping between token strings and dense vocabulary indices.
pub struct DecodingVocabulary {
    /// Token text -> index.
    token_to_idx: HashMap<String, usize>,
    /// Index -> token text; an index is its token's insertion order.
    idx_to_token: Vec<String>,
}

impl DecodingVocabulary {
    /// Creates an empty vocabulary.
    pub fn new() -> Self {
        DecodingVocabulary {
            idx_to_token: Vec::new(),
            token_to_idx: HashMap::new(),
        }
    }

    /// Returns the index of `token`, assigning the next free index if the
    /// token has not been seen before.
    pub fn add_token(&mut self, token: impl Into<String>) -> usize {
        let token: String = token.into();
        match self.token_to_idx.get(&token) {
            Some(&existing) => existing,
            None => {
                let fresh = self.idx_to_token.len();
                self.token_to_idx.insert(token.clone(), fresh);
                self.idx_to_token.push(token);
                fresh
            }
        }
    }

    /// Index of `token`, if registered.
    pub fn get_idx(&self, token: &str) -> Option<usize> {
        self.token_to_idx.get(token).copied()
    }

    /// Token text at `idx`, if in range.
    pub fn get_token(&self, idx: usize) -> Option<&str> {
        self.idx_to_token.get(idx).map(String::as_str)
    }

    /// Number of registered tokens.
    pub fn len(&self) -> usize {
        self.idx_to_token.len()
    }

    /// `true` when no token has been registered.
    pub fn is_empty(&self) -> bool {
        self.idx_to_token.is_empty()
    }

    /// Builds a mask permitting the indices of every token in `valid_tokens`
    /// that is registered; unknown tokens are skipped.
    pub fn create_mask(&self, valid_tokens: &HashSet<String>) -> TokenMask {
        let mut allowed = HashSet::new();
        for token in valid_tokens {
            if let Some(idx) = self.get_idx(token) {
                allowed.insert(idx);
            }
        }
        TokenMask::from_allowed(allowed, self.len())
    }
}

impl Default for DecodingVocabulary {
    /// Same as [`DecodingVocabulary::new`].
    fn default() -> Self {
        DecodingVocabulary::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::code::pcfg::Production;

    /// Shorthand for a terminal symbol.
    fn terminal(text: &str) -> Symbol {
        Symbol::Terminal(text.to_string())
    }

    /// Shorthand for a nonterminal symbol.
    fn nonterminal(name: &str) -> Symbol {
        Symbol::NonTerminal(name.to_string())
    }

    /// Grammar `S -> A B`, `A -> "a"`, `B -> "b"`, each rule weighted 1.0.
    fn create_simple_grammar() -> WeightedCFG {
        let mut grammar = WeightedCFG::new("S");
        grammar.add_rule(
            Production::new("S", vec![nonterminal("A"), nonterminal("B")]),
            1.0,
        );
        grammar.add_rule(Production::new("A", vec![terminal("a")]), 1.0);
        grammar.add_rule(Production::new("B", vec![terminal("b")]), 1.0);
        grammar
    }

    #[test]
    fn test_earley_parser_creation() {
        let parser = EarleyParser::new(create_simple_grammar());
        assert_eq!(parser.start_symbol(), "S");
        for lhs in &["S", "A", "B"] {
            assert_eq!(parser.rules_for(lhs).count(), 1);
        }
    }

    #[test]
    fn test_grammar_constraint_basic() {
        let mut constraint = GrammarConstraint::with_default_config(create_simple_grammar());
        constraint.reset();

        let first = constraint.valid_tokens();
        assert!(first.contains("a"));
        assert!(!first.contains("b"));
        assert!(constraint.advance("a"));

        let second = constraint.valid_tokens();
        assert!(second.contains("b"));
        assert!(!second.contains("a"));
        assert!(constraint.advance("b"));

        assert!(constraint.can_complete());
    }

    #[test]
    fn test_token_mask() {
        let allowed: HashSet<usize> = [1, 3, 5].iter().copied().collect();
        let mask = TokenMask::from_allowed(allowed, 10);
        for &idx in &[1, 3, 5] {
            assert!(mask.is_allowed(idx));
        }
        for &idx in &[0, 2] {
            assert!(!mask.is_allowed(idx));
        }
        let flags = mask.to_bool_vec();
        assert_eq!(flags.len(), 10);
        assert!(flags[1]);
        assert!(!flags[0]);
    }

    #[test]
    fn test_decoding_vocabulary() {
        let mut vocab = DecodingVocabulary::new();
        let a = vocab.add_token("a");
        let b = vocab.add_token("b");
        let a_again = vocab.add_token("a");
        assert_eq!(a, a_again);
        assert_ne!(a, b);
        assert_eq!(vocab.get_token(a), Some("a"));
        assert_eq!(vocab.get_idx("b"), Some(b));
        assert_eq!(vocab.len(), 2);
    }
}