use std::cell::RefCell;
use llguidance::{Matcher, ParserFactory, api::TopLevelGrammar, toktrie::TokEnv};
use serde_json::Value;
use toktrie_hf_tokenizers::{ByteTokenizer, ByteTokenizerEnv};
use smol_str::format_smolstr;
use crate::{
array::Array,
error::{
Error, LengthMismatchPayload, OutOfRangePayload, ParsePayload, RankMismatchPayload, Result,
},
lm::generate::LogitsProcessor,
ops,
tokenizer::Tokenizer,
};
/// Source description of a grammar used to constrain generation.
///
/// Each variant is converted into an llguidance `TopLevelGrammar` by the
/// matching `TopLevelGrammar::from_*` constructor.
#[derive(Debug, Clone)]
pub enum GrammarSpec {
/// JSON Schema document, passed to `TopLevelGrammar::from_json_schema`.
JsonSchema(Value),
/// Regular-expression source, passed to `TopLevelGrammar::from_regex`.
Regex(String),
/// Lark grammar source, passed to `TopLevelGrammar::from_lark`.
Lark(String),
}
impl GrammarSpec {
    /// Consumes the spec and builds the corresponding llguidance [`TopLevelGrammar`].
    ///
    /// Every variant maps onto the `TopLevelGrammar::from_*` constructor of the same
    /// flavour; no validation happens here.
    fn into_top_level(self) -> TopLevelGrammar {
        match self {
            Self::JsonSchema(schema) => TopLevelGrammar::from_json_schema(schema),
            Self::Regex(pattern) => TopLevelGrammar::from_regex(&pattern),
            Self::Lark(source) => TopLevelGrammar::from_lark(source),
        }
    }
}
fn tok_env_from_tokenizer(
tokenizer: &Tokenizer,
model_vocab_size: Option<usize>,
) -> Result<TokEnv> {
let json = serde_json::to_vec(tokenizer.hf()).map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: serialize HF tokenizer",
"HF tokenizer JSON",
Box::new(e) as Box<dyn std::error::Error + Send + Sync>,
))
})?;
let bt = ByteTokenizer::from_json_bytes(&json).map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: build ByteTokenizer",
"HF tokenizer JSON",
std::io::Error::other(e.to_string()),
))
})?;
let configured_eos: Vec<u32> = tokenizer.eos_token_ids_iter().collect();
let mut env = ByteTokenizerEnv::new(bt, model_vocab_size).map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: build ByteTokenizerEnv",
"tokenizer environment",
std::io::Error::other(e.to_string()),
))
})?;
if !configured_eos.is_empty() {
let widened_vocab = env.tok_trie.vocab_size();
for &eos in &configured_eos {
if (eos as usize) >= widened_vocab {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"llguidance: configured EOS token id",
"must be < tok_trie vocab bound",
format_smolstr!("{eos} (vocab_bound={widened_vocab})"),
)));
}
}
env.tok_trie = env.tok_trie.with_eos_tokens(&configured_eos);
}
Ok(env.to_env())
}
/// Logits processor that masks logits according to an llguidance `Matcher`.
///
/// Both fields sit behind `RefCell` because `apply` and `reset` take `&self`
/// yet need to update them.
pub struct LLGuidanceLogitsProcessor {
/// Matcher state: `apply` feeds it the last token and asks it for the allowed-token mask.
matcher: RefCell<Matcher>,
/// `true` after construction and after `reset`. The first `apply` call that gets past
/// shape validation clears it and does not consume a token.
is_first_token: RefCell<bool>,
}
impl LLGuidanceLogitsProcessor {
pub fn new(
grammar: GrammarSpec,
tokenizer: &Tokenizer,
model_vocab_size: Option<usize>,
) -> Result<Self> {
let tok_env = tok_env_from_tokenizer(tokenizer, model_vocab_size)?;
let mut factory = ParserFactory::new_simple(&tok_env).map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: ParserFactory",
"llguidance grammar factory",
std::io::Error::other(e.to_string()),
))
})?;
factory.set_stderr_log_level(0);
let top = grammar.into_top_level();
let parser = factory.create_parser(top);
let matcher = Matcher::new(parser);
if let Some(err) = matcher.get_error() {
return Err(Error::Parse(ParsePayload::new(
"llguidance: grammar compile",
"llguidance grammar",
std::io::Error::other(err),
)));
}
Ok(Self {
matcher: RefCell::new(matcher),
is_first_token: RefCell::new(true),
})
}
pub fn apply(&self, tokens: &[u32], logits: &Array) -> Result<Array> {
let shape = logits.shape();
let vocab = match shape.as_slice() {
[v] => *v,
[1, v] => *v,
other => {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"LLGuidanceLogitsProcessor: expected logits shape `[V]` or `[1, V]`",
other.len() as u32,
other.to_vec(),
)));
}
};
{
let mut first = self.is_first_token.borrow_mut();
if *first {
*first = false;
} else if let Some(&last) = tokens.last() {
self.matcher.borrow_mut().consume_token(last).map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: consume_token",
"previously-sampled token",
std::io::Error::other(format!("token={last}: {e}")),
))
})?;
}
}
let mask = self
.matcher
.borrow_mut()
.compute_mask_or_eos()
.map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: compute_mask_or_eos",
"llguidance allowed-mask",
std::io::Error::other(e.to_string()),
))
})?;
if mask.len() < vocab {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"LLGuidanceLogitsProcessor: matcher mask vs logits vocab",
vocab,
mask.len(),
)));
}
let mut disallowed: Vec<bool> = Vec::with_capacity(vocab);
for tok in 0..vocab {
disallowed.push(!mask.is_allowed(tok as u32));
}
let mask_shape: Vec<i32> = match shape.as_slice() {
[v] => vec![*v as i32],
[b, v] => vec![*b as i32, *v as i32],
_ => unreachable!("shape validated above"),
};
let bool_mask_flat = Array::from_slice::<bool>(&disallowed, &(vocab,))?;
let bool_mask = if mask_shape.len() == 1 {
bool_mask_flat
} else {
let dims: &[i32] = &mask_shape;
ops::shape::reshape(&bool_mask_flat, &dims)?
};
let neg_inf_f32 = Array::full::<f32>(&(1,), f32::NEG_INFINITY)?;
let neg_inf = ops::misc::astype(&neg_inf_f32, logits.dtype()?)?;
ops::logical::select(&bool_mask, &neg_inf, logits)
}
pub fn reset(&self) -> Result<()> {
self.matcher.borrow_mut().reset().map_err(|e| {
Error::Parse(ParsePayload::new(
"llguidance: reset",
"llguidance matcher state",
std::io::Error::other(e.to_string()),
))
})?;
*self.is_first_token.borrow_mut() = true;
Ok(())
}
pub fn into_logits_processor(self) -> LogitsProcessor {
LogitsProcessor::Custom(Box::new(move |tokens: &[u32], logits: &Array| {
self.apply(tokens, logits)
}))
}
}
/// Convenience constructor for a processor constrained by a JSON Schema.
///
/// Wraps `schema` in [`GrammarSpec::JsonSchema`] and delegates to
/// [`LLGuidanceLogitsProcessor::new`], returning whatever that call returns.
pub fn build_json_schema_logits_processor(
    schema: Value,
    tokenizer: &Tokenizer,
    model_vocab_size: Option<usize>,
) -> Result<LLGuidanceLogitsProcessor> {
    let grammar = GrammarSpec::JsonSchema(schema);
    LLGuidanceLogitsProcessor::new(grammar, tokenizer, model_vocab_size)
}
#[cfg(test)]
mod tests;