fn add_schema_rules(grammar: &mut Grammar, rule_name: &str, schema: &JsonSchemaType) {
match schema {
JsonSchemaType::String => {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![GrammarAlternative::new(vec![
GrammarElement::Char('"'),
GrammarElement::RuleRef("string_content".to_string()),
GrammarElement::Char('"'),
])],
));
},
JsonSchemaType::Integer => {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![
GrammarAlternative::new(vec![GrammarElement::RuleRef("digits".to_string())]),
GrammarAlternative::new(vec![
GrammarElement::Char('-'),
GrammarElement::RuleRef("digits".to_string()),
]),
],
));
},
JsonSchemaType::Number => {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![
GrammarAlternative::new(vec![GrammarElement::RuleRef("digits".to_string())]),
GrammarAlternative::new(vec![
GrammarElement::Char('-'),
GrammarElement::RuleRef("digits".to_string()),
]),
GrammarAlternative::new(vec![
GrammarElement::RuleRef("digits".to_string()),
GrammarElement::Char('.'),
GrammarElement::RuleRef("digits".to_string()),
]),
GrammarAlternative::new(vec![
GrammarElement::Char('-'),
GrammarElement::RuleRef("digits".to_string()),
GrammarElement::Char('.'),
GrammarElement::RuleRef("digits".to_string()),
]),
],
));
},
JsonSchemaType::Boolean => {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![
GrammarAlternative::new(vec![
GrammarElement::Char('t'),
GrammarElement::Char('r'),
GrammarElement::Char('u'),
GrammarElement::Char('e'),
]),
GrammarAlternative::new(vec![
GrammarElement::Char('f'),
GrammarElement::Char('a'),
GrammarElement::Char('l'),
GrammarElement::Char('s'),
GrammarElement::Char('e'),
]),
],
));
},
JsonSchemaType::Null => {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![GrammarAlternative::new(vec![
GrammarElement::Char('n'),
GrammarElement::Char('u'),
GrammarElement::Char('l'),
GrammarElement::Char('l'),
])],
));
},
JsonSchemaType::Enum(values) => {
let alternatives: Vec<GrammarAlternative> = values
.iter()
.map(|v| {
let mut elements = vec![GrammarElement::Char('"')];
for c in v.chars() {
elements.push(GrammarElement::Char(c));
}
elements.push(GrammarElement::Char('"'));
GrammarAlternative::new(elements)
})
.collect();
grammar.add_rule(GrammarRule::new(rule_name, alternatives));
},
JsonSchemaType::Array(item_schema) => {
add_array_schema_rule(grammar, rule_name, item_schema);
},
JsonSchemaType::Object(properties) => {
add_object_schema_rules(grammar, rule_name, properties);
},
JsonSchemaType::Any => {
add_any_schema_rule(grammar, rule_name);
},
}
}
fn add_object_schema_rules(
grammar: &mut Grammar,
rule_name: &str,
properties: &[(String, JsonSchemaType, bool)],
) {
if properties.is_empty() {
grammar.add_rule(GrammarRule::new(
rule_name,
vec![GrammarAlternative::new(vec![
GrammarElement::Char('{'),
GrammarElement::RuleRef("ws".to_string()),
GrammarElement::Char('}'),
])],
));
return;
}
let mut elements = vec![
GrammarElement::Char('{'),
GrammarElement::RuleRef("ws".to_string()),
];
for (i, (prop_name, prop_type, _required)) in properties.iter().enumerate() {
if i > 0 {
elements.push(GrammarElement::Char(','));
elements.push(GrammarElement::RuleRef("ws".to_string()));
}
elements.push(GrammarElement::Char('"'));
for c in prop_name.chars() {
elements.push(GrammarElement::Char(c));
}
elements.push(GrammarElement::Char('"'));
elements.push(GrammarElement::RuleRef("ws".to_string()));
elements.push(GrammarElement::Char(':'));
elements.push(GrammarElement::RuleRef("ws".to_string()));
let prop_rule = format!("{rule_name}_{prop_name}");
add_schema_rules(grammar, &prop_rule, prop_type);
elements.push(GrammarElement::RuleRef(prop_rule));
}
elements.push(GrammarElement::RuleRef("ws".to_string()));
elements.push(GrammarElement::Char('}'));
grammar.add_rule(GrammarRule::new(
rule_name,
vec![GrammarAlternative::new(elements)],
));
}
/// Set of token ids permitted at the next decoding step.
#[derive(Debug, Clone)]
pub struct TokenMask {
    /// Token ids that may be sampled.
    pub allowed: HashSet<u32>,
    /// Whether ending the sequence is permitted. Note that `is_allowed`, `apply_to_logits`
    /// and `num_allowed` look only at `allowed`; none of them consults this flag.
    pub allow_eos: bool,
}

impl TokenMask {
    /// Mask permitting every id in `0..vocab_size`, with `allow_eos` set.
    pub fn allow_all(vocab_size: usize) -> Self {
        // Plain `as` cast: vocabularies are expected to fit in u32.
        let upper = vocab_size as u32;
        let allowed: HashSet<u32> = (0..upper).collect();
        Self {
            allowed,
            allow_eos: true,
        }
    }

    /// Wraps an explicit allowed set and EOS flag.
    pub fn from_allowed(allowed: HashSet<u32>, allow_eos: bool) -> Self {
        TokenMask { allowed, allow_eos }
    }

    /// Returns whether `token_id` is in the allowed set.
    pub fn is_allowed(&self, token_id: u32) -> bool {
        self.allowed.contains(&token_id)
    }

    /// Overwrites with negative infinity every logit whose index is not an allowed token id;
    /// allowed logits are left unchanged.
    pub fn apply_to_logits(&self, logits: &mut [f32]) {
        logits
            .iter_mut()
            .enumerate()
            .filter(|(index, _)| !self.is_allowed(*index as u32))
            .for_each(|(_, logit)| *logit = f32::NEG_INFINITY);
    }

    /// Number of ids in the allowed set.
    pub fn num_allowed(&self) -> usize {
        self.allowed.len()
    }
}
/// Constrains token generation to a grammar. It computes which tokens may come next and
/// advances the grammar state as tokens are accepted.
pub struct GrammarTokenMasker {
    /// Grammar state after all tokens accepted so far.
    state_machine: GrammarStateMachine,
    /// Text of each token id; ids missing from this map are never allowed as text tokens.
    token_strings: HashMap<u32, String>,
    /// End-of-sequence token; permitted exactly when the grammar is complete.
    eos_token_id: u32,
}

impl GrammarTokenMasker {
    /// Builds a masker for `grammar` over the vocabulary described by `token_strings`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `GrammarStateMachine::new`.
    pub fn new(
        grammar: Grammar,
        token_strings: HashMap<u32, String>,
        eos_token_id: u32,
    ) -> Result<Self> {
        let state_machine = GrammarStateMachine::new(grammar)?;
        Ok(Self {
            state_machine,
            token_strings,
            eos_token_id,
        })
    }

    /// Returns whether the whole of `token_str` can be consumed from the current state.
    /// Works on a clone of the state machine, so `self` is left untouched.
    fn is_token_valid_sequence(&self, token_str: &str) -> bool {
        let mut temp_sm = self.state_machine.clone();
        token_str.chars().all(|c| temp_sm.advance(c))
    }

    /// Returns the mask of tokens that may be emitted next.
    ///
    /// A text token is allowed when its first character is in the state machine's
    /// `valid_chars()` and (for strings longer than one byte) the full string can be consumed.
    /// Empty token strings are never allowed.
    ///
    /// The EOS token is not judged by its text. It is in the allowed set, and `allow_eos` is
    /// `true`, exactly when the grammar is complete, so `TokenMask::apply_to_logits` keeps the
    /// EOS logit whenever ending is legal and masks it otherwise.
    pub fn get_mask(&self) -> TokenMask {
        let valid_chars = self.state_machine.valid_chars();
        let complete = self.state_machine.is_complete();
        let mut allowed = HashSet::new();
        for (&token_id, token_str) in &self.token_strings {
            // EOS is decided below purely from grammar completion.
            if token_id == self.eos_token_id {
                continue;
            }
            if let Some(first_char) = token_str.chars().next() {
                // Cheap first-character filter before the full (cloning) sequence check.
                if valid_chars.contains(&first_char)
                    && (token_str.len() == 1 || self.is_token_valid_sequence(token_str))
                {
                    allowed.insert(token_id);
                }
            }
        }
        if complete {
            allowed.insert(self.eos_token_id);
        }
        TokenMask::from_allowed(allowed, complete)
    }

    /// Feeds `token_id` to the grammar and reports whether it was accepted.
    ///
    /// The update is all-or-nothing. If the id has no known text, or any of its characters
    /// is rejected, the grammar state is left exactly as it was before the call. The EOS
    /// token is accepted (without changing state) exactly when the grammar is complete,
    /// matching `get_mask`.
    pub fn advance_token(&mut self, token_id: u32) -> bool {
        if token_id == self.eos_token_id {
            return self.state_machine.is_complete();
        }
        let token_str = match self.token_strings.get(&token_id) {
            Some(token_str) => token_str,
            None => return false,
        };
        // Advance a copy and commit only on success, so a rejected token cannot leave the
        // state machine stranded partway through its characters.
        let mut candidate = self.state_machine.clone();
        if token_str.chars().all(|c| candidate.advance(c)) {
            self.state_machine = candidate;
            true
        } else {
            false
        }
    }

    /// Returns whether the grammar state accepts the text consumed so far as complete.
    pub fn is_complete(&self) -> bool {
        self.state_machine.is_complete()
    }

    /// Resets the grammar state via `GrammarStateMachine::reset`.
    pub fn reset(&mut self) {
        self.state_machine.reset();
    }

    /// The end-of-sequence token id supplied at construction.
    pub fn eos_token_id(&self) -> u32 {
        self.eos_token_id
    }
}
/// Value type of a tool parameter.
///
/// Uses serde's default (externally tagged) enum representation with `snake_case` variant
/// names, e.g. `"string"` for `String` or `{"array": {"items": ...}}` for `Array`.
/// The `Default` value is `String`.
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ToolParameterType {
/// String-valued parameter (the default).
#[default]
String,
/// Integer-valued parameter.
Integer,
/// Numeric parameter.
Number,
/// Boolean parameter.
Boolean,
/// Array parameter whose elements have type `items` (boxed because the type is recursive).
Array {
items: Box<ToolParameterType>,
},
/// Object parameter described by a list of nested parameters.
Object {
properties: Vec<ToolParameter>,
},
/// Parameter restricted to the listed string values.
Enum(Vec<String>),
}
/// A single named parameter accepted by a tool.
///
/// Serde notes: `param_type` is (de)serialized under the key `"type"`; a missing `required`
/// field deserializes as `false`; `default` is omitted from serialized output when `None`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolParameter {
    /// Parameter name.
    pub name: String,
    /// Human-readable description.
    pub description: String,
    /// Value type of the parameter.
    #[serde(rename = "type")]
    pub param_type: ToolParameterType,
    /// Whether the parameter is marked as required.
    #[serde(default)]
    pub required: bool,
    /// Optional default value, stored as a string.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default: Option<String>,
}

impl ToolParameter {
    /// Shared constructor behind the typed helpers below; `default` always starts as `None`.
    fn from_parts(
        name: impl Into<String>,
        description: impl Into<String>,
        param_type: ToolParameterType,
        required: bool,
    ) -> Self {
        Self {
            name: name.into(),
            description: description.into(),
            param_type,
            required,
            default: None,
        }
    }

    /// A required `String` parameter with no default.
    pub fn required_string(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self::from_parts(name, description, ToolParameterType::String, true)
    }

    /// An optional `String` parameter with no default.
    pub fn optional_string(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self::from_parts(name, description, ToolParameterType::String, false)
    }

    /// A required `Integer` parameter with no default.
    pub fn required_int(name: impl Into<String>, description: impl Into<String>) -> Self {
        Self::from_parts(name, description, ToolParameterType::Integer, true)
    }

    /// A required `Enum` parameter restricted to `values`, with no default.
    pub fn required_enum(
        name: impl Into<String>,
        description: impl Into<String>,
        values: Vec<String>,
    ) -> Self {
        Self::from_parts(name, description, ToolParameterType::Enum(values), true)
    }

    /// Returns `self` with `default` set to `Some(default)`, replacing any previous default.
    #[must_use]
    pub fn with_default(self, default: impl Into<String>) -> Self {
        Self {
            default: Some(default.into()),
            ..self
        }
    }
}
/// Description of a tool: its name, a human-readable description, and the parameters it
/// accepts.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ToolDefinition {
/// Tool name.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Parameters accepted by the tool, as an ordered list.
pub parameters: Vec<ToolParameter>,
}