use crate::errors::{Result, TrustformersError};
use regex::Regex;
use std::collections::{HashMap, HashSet};
use super::config::GuidedGenerationConfig;
#[derive(Debug)]
pub struct ConstraintValidator {
regex: Option<Regex>,
choice_list: Option<HashSet<String>>,
json_schema: Option<JsonSchemaValidator>,
grammar: Option<GrammarValidator>,
}
impl ConstraintValidator {
pub fn new(config: &GuidedGenerationConfig) -> Result<Self> {
let regex = if let Some(pattern) = &config.regex_pattern {
Some(Regex::new(pattern).map_err(|e| {
TrustformersError::invalid_input(format!("Invalid regex pattern: {}", e))
})?)
} else {
None
};
let choice_list = config
.choice_list
.as_ref()
.map(|choices| choices.iter().cloned().collect::<HashSet<String>>());
let json_schema = if let Some(schema) = &config.json_schema {
Some(JsonSchemaValidator::new(schema)?)
} else {
None
};
let grammar = if let Some(grammar_config) = &config.grammar {
Some(GrammarValidator::new(grammar_config)?)
} else {
None
};
Ok(Self {
regex,
choice_list,
json_schema,
grammar,
})
}
pub fn validate_token(
&self,
current_text: &str,
new_token: &str,
_tokenizer_fn: Option<&dyn Fn(usize) -> String>,
) -> bool {
let potential_text = format!("{}{}", current_text, new_token);
if let Some(regex) = &self.regex {
if !regex.is_match(&potential_text) && !self.is_partial_match(regex, &potential_text) {
return false;
}
}
if let Some(choices) = &self.choice_list {
if !choices.contains(&potential_text) && !self.is_valid_prefix(&potential_text, choices)
{
return false;
}
}
if let Some(json_validator) = &self.json_schema {
if !json_validator.validate_partial(&potential_text) {
return false;
}
}
if let Some(grammar_validator) = &self.grammar {
if !grammar_validator.validate_partial(&potential_text) {
return false;
}
}
true
}
pub fn is_complete(&self, text: &str) -> bool {
if let Some(regex) = &self.regex {
if !regex.is_match(text) {
return false;
}
}
if let Some(choices) = &self.choice_list {
if !choices.contains(text) {
return false;
}
}
if let Some(json_validator) = &self.json_schema {
if !json_validator.validate_complete(text) {
return false;
}
}
if let Some(grammar_validator) = &self.grammar {
if !grammar_validator.validate_complete(text) {
return false;
}
}
true
}
fn is_partial_match(&self, regex: &Regex, text: &str) -> bool {
if text.is_empty() {
return true;
}
if regex.is_match(text) {
return true;
}
for i in 0..text.len() {
if regex.find(&text[i..]).is_some() {
return true;
}
}
let test_extensions = vec![" ", "\\s", " world", " world"];
for ext in test_extensions {
let test_text = format!("{}{}", text, ext);
if regex.is_match(&test_text) {
return true;
}
}
false
}
fn is_valid_prefix(&self, text: &str, choices: &HashSet<String>) -> bool {
choices.iter().any(|choice| choice.starts_with(text))
}
pub fn filter_valid_tokens(
&self,
current_text: &str,
token_logits: &[(usize, f32)],
tokenizer_fn: &dyn Fn(usize) -> String,
) -> Vec<(usize, f32)> {
token_logits
.iter()
.filter(|(token_id, _)| {
let token_str = tokenizer_fn(*token_id);
self.validate_token(current_text, &token_str, Some(tokenizer_fn))
})
.cloned()
.collect()
}
}
#[derive(Debug)]
pub struct JsonSchemaValidator {
#[allow(dead_code)]
schema: String,
#[allow(dead_code)]
brace_stack: Vec<char>,
}
impl JsonSchemaValidator {
pub fn new(schema: &str) -> Result<Self> {
Ok(Self {
schema: schema.to_string(),
brace_stack: Vec::new(),
})
}
pub fn validate_partial(&self, text: &str) -> bool {
let mut stack = Vec::new();
let mut in_string = false;
let mut escape_next = false;
for ch in text.chars() {
if escape_next {
escape_next = false;
continue;
}
match ch {
'\\' if in_string => escape_next = true,
'"' => in_string = !in_string,
'{' | '[' if !in_string => stack.push(ch),
'}' if !in_string => {
if stack.last() == Some(&'{') {
stack.pop();
} else {
return false;
}
},
']' if !in_string => {
if stack.last() == Some(&'[') {
stack.pop();
} else {
return false;
}
},
_ => {},
}
}
true
}
pub fn validate_complete(&self, text: &str) -> bool {
serde_json::from_str::<serde_json::Value>(text).is_ok()
}
}
#[derive(Debug)]
pub struct GrammarValidator {
#[allow(dead_code)]
rules: HashMap<String, Vec<String>>,
#[allow(dead_code)]
current_state: String,
}
impl GrammarValidator {
pub fn new(grammar: &str) -> Result<Self> {
let mut rules = HashMap::new();
for line in grammar.lines() {
if let Some((lhs, rhs)) = line.split_once("::=") {
let rule_name = lhs.trim().to_string();
let alternatives: Vec<String> =
rhs.split('|').map(|alt| alt.trim().to_string()).collect();
rules.insert(rule_name, alternatives);
}
}
Ok(Self {
rules,
current_state: "start".to_string(),
})
}
pub fn validate_partial(&self, _text: &str) -> bool {
true
}
pub fn validate_complete(&self, _text: &str) -> bool {
true
}
pub fn get_valid_next_tokens(&self, _current_state: &str) -> Vec<String> {
vec![]
}
}