use crate::sampler::PossibleTokensResult;
use crate::sampler::Sampler;
use crate::trie::TerminalsTrie;
use crate::trie::TrieNodeID;
use crate::utils;
use crate::utils::NonterminalID;
use crate::utils::U8ArrayWrapper;
use crate::vocabulary::Vocabulary;
use anyhow::{anyhow, ensure, Error};
use bit_set::BitSet;
use bnf::Production;
use bnf::Term;
use itertools::Itertools;
use memchr::memmem;
use regex::Regex;
use rustc_hash::FxHashMap;
use rustc_hash::FxHashSet;
use std::sync::Arc;
/// One symbol of a simplified production right-hand side.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub(crate) enum U8Term {
    /// Literal bytes. `Grammar::new` concatenates adjacent BNF terminals and passes them
    /// through `utils::fix_utf8_escape` before storing them here.
    Terminal(Vec<u8>),
    /// Reference to another nonterminal, by name.
    Nonterminal(String),
}
/// A BNF grammar compiled against a token vocabulary.
///
/// Instances are produced by [`Grammar::new`], which hands them out behind an [`Arc`].
#[derive(Debug)]
#[derive(Clone)]
pub struct Grammar {
    /// Expansion of each nonterminal, keyed by its ID.
    pub(crate) nonterminal_id_to_expression: FxHashMap<NonterminalID, SimplifiedExpressions>,
    /// Nonterminal name -> its [`NonterminalID`] (the keys are nonterminal names, despite the
    /// field name).
    pub(crate) nonterminal_to_terminal_id: FxHashMap<String, NonterminalID>,
    /// Trie of terminal byte strings; per-nonterminal roots are looked up by `NonterminalID`.
    pub(crate) terminals_trie: TerminalsTrie,
    /// Vocabulary token IDs accepted by the nonterminals that `Grammar::new` builds directly
    /// from the vocabulary (the "any" and `except!` nonterminals).
    pub(crate) nonterminal_to_token_ids: FxHashMap<NonterminalID, BitSet<u32>>,
}
/// Right-hand side of one nonterminal after simplification.
#[derive(Debug, Clone)]
pub(crate) enum SimplifiedExpressions {
    /// Alternatives kept as sequences; each `Vec` is one alternative made of terminals and
    /// nonterminals.
    Expressions(FxHashSet<Vec<U8Term>>),
    /// Every alternative was a single terminal. They are stored in the terminals trie under
    /// this root node.
    Terminals(TrieNodeID),
}
impl Grammar {
    /// Builds a [`Grammar`] from BNF source text, specialised to `vocabulary`.
    ///
    /// Besides plain `bnf` parsing, two extensions are detected by scanning `input`:
    ///
    /// * The nonterminal named `utils::ANY_NONTERMINAL_NAME`. When it is referenced, it is
    ///   defined to match every token in `vocabulary`.
    /// * `except!` nonterminals (detected with `utils::EXCEPTS_REGEX`):
    ///   - The literal form (`utils::EXCEPT_LITERAL_REGEX`) matches every vocabulary token that
    ///     neither equals nor contains the excluded bytes.
    ///   - The nonterminal form (`utils::EXCEPT_NONTERMINAL_REGEX`) does the same, but excludes
    ///     each token that a temporary [`Sampler`] over the referenced nonterminal reports from
    ///     `all_possible_next_tokens(None)`.
    ///
    /// `stack_arena_capacity` is only forwarded to those temporary samplers.
    ///
    /// # Errors
    ///
    /// Returns an error in any of these cases:
    /// - `input` does not parse as BNF.
    /// - An `except!` has empty brackets.
    /// - `except!([x])` names an undefined nonterminal.
    /// - `except!([x])` names a nonterminal whose sampler does not return
    ///   `PossibleTokensResult::Continue`.
    /// - Building that sampler fails.
    /// - A production references a nonterminal that is never defined.
    pub fn new(
        input: &str,
        vocabulary: Arc<Vocabulary>,
        stack_arena_capacity: usize,
    ) -> Result<Arc<Self>, Error> {
        let except_present = utils::EXCEPTS_REGEX.is_match(input);
        let any_present = input.contains(&format!("<{}>", utils::ANY_NONTERMINAL_NAME));
        let mut grammar: bnf::Grammar = input.parse()?;
        if any_present {
            // Empty placeholder production. It gives the "any" nonterminal an entry (and thus
            // an ID) below; its expansions are filled in from the vocabulary by `add_tokens`.
            let mut any_prod = Production::new();
            any_prod.lhs = Term::Nonterminal(utils::ANY_NONTERMINAL_NAME.to_string());
            grammar.add_production(any_prod);
        }
        let mut nonterminal_to_token_ids: FxHashMap<NonterminalID, BitSet<u32>> =
            FxHashMap::default();
        let mut excepts: FxHashSet<String> = FxHashSet::default();
        if except_present {
            // Literal excepts get the same kind of placeholder production as "any".
            for i in utils::EXCEPT_LITERAL_REGEX.find_iter(input) {
                let temp = i.as_str().to_string();
                let mut any_prod = Production::new();
                excepts.insert(temp.clone());
                any_prod.lhs = Term::Nonterminal(temp);
                grammar.add_production(any_prod);
            }
            // Nonterminal excepts are only recorded here. They get an ID once the `Grammar`
            // exists, because resolving them needs a `Sampler` over that grammar.
            for i in utils::EXCEPT_NONTERMINAL_REGEX.find_iter(input) {
                let temp = i.as_str().to_string();
                excepts.insert(temp);
            }
        }
        // Flatten every alternative into `U8Term`s. Runs of adjacent terminals are merged into
        // one byte string (after `utils::fix_utf8_escape`).
        let mut simplified_grammar: FxHashMap<String, FxHashSet<Vec<U8Term>>> =
            FxHashMap::default();
        for i in grammar.productions_iter() {
            let key = match &i.lhs {
                Term::Terminal(x) => x,
                Term::Nonterminal(x) => x,
            };
            simplified_grammar
                .entry(key.clone())
                .or_default()
                .extend(i.rhs_iter().map(|x| {
                    let mut temp_vec: Vec<U8Term> = vec![];
                    let mut temp_string: Option<String> = None;
                    for i in x.terms_iter() {
                        match i {
                            Term::Terminal(x) => match temp_string {
                                Some(value) => temp_string = Some(value + x),
                                None => temp_string = Some(x.clone()),
                            },
                            Term::Nonterminal(nonterminal) => {
                                if let Some(value) = temp_string {
                                    temp_vec
                                        .push(U8Term::Terminal(utils::fix_utf8_escape(&value)));
                                    temp_string = None;
                                }
                                temp_vec.push(U8Term::Nonterminal(nonterminal.clone()));
                            }
                        }
                    }
                    if let Some(value) = temp_string {
                        temp_vec.push(U8Term::Terminal(utils::fix_utf8_escape(&value)));
                    }
                    temp_vec
                }));
        }
        // IDs are assigned in the map's iteration order.
        let nonterminal_to_terminal_id: FxHashMap<String, NonterminalID> = simplified_grammar
            .iter()
            .enumerate()
            .map(|(i, (key, _))| (key.clone(), NonterminalID(i)))
            .collect();
        let mut terminals_arena = TerminalsTrie::new();
        // Replaces `nonterminal`'s expansions with vocabulary tokens:
        // - all of them when `excepted_literal` is `None`;
        // - otherwise only tokens that neither equal nor contain any of the given byte strings.
        // In both cases every vocabulary token is added to `terminals_arena` under
        // `nonterminal`, and the IDs of the kept tokens are stored in `nonterminal_to_token_ids`.
        let add_tokens = |simplified_grammar: &mut FxHashMap<String, FxHashSet<Vec<U8Term>>>,
                          terminals_arena: &mut TerminalsTrie,
                          nonterminal_to_terminal_id: &FxHashMap<String, NonterminalID>,
                          nonterminal_to_token_ids: &mut FxHashMap<NonterminalID, BitSet>,
                          nonterminal: &str,
                          excepted_literal: Option<&Vec<&[u8]>>| {
            simplified_grammar.remove(nonterminal);
            let predicate = |haystack: &&U8ArrayWrapper| {
                excepted_literal.is_none()
                    || excepted_literal.is_some_and(|x| {
                        x.iter().all(|x| {
                            &haystack.0[..] != *x && memmem::find(&haystack.0, x).is_none()
                        })
                    })
            };
            match excepted_literal {
                Some(_) => {
                    simplified_grammar.insert(
                        nonterminal.to_string(),
                        vocabulary
                            .token_to_id
                            .keys()
                            .filter(predicate)
                            .map(|k| vec![U8Term::Terminal(k.0.to_vec())])
                            .collect(),
                    );
                    for (key, _) in vocabulary.token_to_id.iter() {
                        terminals_arena.add(&key.0, nonterminal_to_terminal_id[nonterminal], false)
                    }
                    let mut bit_set = BitSet::new();
                    bit_set.extend(vocabulary.token_to_id.iter().filter_map(|(k, token_id)| {
                        if predicate(&k) {
                            Some(*(token_id) as usize)
                        } else {
                            None
                        }
                    }));
                    nonterminal_to_token_ids
                        .insert(nonterminal_to_terminal_id[nonterminal], bit_set);
                }
                None => {
                    simplified_grammar.insert(
                        nonterminal.to_string(),
                        vocabulary
                            .token_to_id
                            .keys()
                            .map(|k| vec![U8Term::Terminal(k.0.to_vec())])
                            .collect(),
                    );
                    let mut bit_set = BitSet::new();
                    for (key, token_id) in vocabulary.token_to_id.iter() {
                        bit_set.insert((*token_id) as usize);
                        terminals_arena.add(&key.0, nonterminal_to_terminal_id[nonterminal], false)
                    }
                    nonterminal_to_token_ids
                        .insert(nonterminal_to_terminal_id[nonterminal], bit_set);
                }
            }
        };
        if any_present {
            add_tokens(
                &mut simplified_grammar,
                &mut terminals_arena,
                &nonterminal_to_terminal_id,
                &mut nonterminal_to_token_ids,
                utils::ANY_NONTERMINAL_NAME,
                None,
            );
        }
        // Runs `process` on the text inside `nonterminal`'s except brackets when `regex`
        // matches it. Empty brackets are an error; a non-matching `nonterminal` is a no-op.
        fn process_valid_excepts<F: FnOnce(&str) -> Result<(), Error>>(
            regex: &Regex,
            nonterminal: &str,
            process: F,
        ) -> Result<(), Error> {
            let extracted = utils::extract_excepted(regex, nonterminal);
            if let Some(extracted) = extracted {
                if extracted.is_empty() {
                    return Err(anyhow::anyhow!("{nonterminal} is invalid except!() nonterminal because the brackets contain nothing."));
                }
                process(extracted)?;
            }
            Ok(())
        }
        if except_present {
            for nonterminal in excepts.iter() {
                process_valid_excepts(&utils::EXCEPT_LITERAL_REGEX, nonterminal, |extracted| {
                    let bytes = utils::fix_utf8_escape(extracted);
                    add_tokens(
                        &mut simplified_grammar,
                        &mut terminals_arena,
                        &nonterminal_to_terminal_id,
                        &mut nonterminal_to_token_ids,
                        nonterminal,
                        Some(&vec![&bytes]),
                    );
                    terminals_arena.except_literal(&bytes, nonterminal_to_terminal_id[nonterminal]);
                    Ok(())
                })?;
            }
        }
        // Adds every single-terminal alternative of `k` to `terminals_arena` (flag `true`) and
        // returns `k` with its trie root.
        // Panics if an alternative is empty or ends in a nonterminal.
        fn convert_u8terms_to_simplified_expressions(
            k: &str,
            v: FxHashSet<Vec<U8Term>>,
            terminals_arena: &mut TerminalsTrie,
            nonterminal_to_terminal_id: &FxHashMap<String, NonterminalID>,
        ) -> (String, SimplifiedExpressions) {
            for i in v.into_iter() {
                let value = match i.last().unwrap() {
                    U8Term::Terminal(value) => value,
                    _ => panic!("There should only be terminals."),
                };
                terminals_arena.add(value, nonterminal_to_terminal_id[k], true);
            }
            let v = SimplifiedExpressions::Terminals(
                terminals_arena.roots[&nonterminal_to_terminal_id[k]],
            );
            (k.to_string(), v)
        }
        // `simplified_grammar` is not used after this point, so consume it rather than
        // deep-cloning each set (the "any"/except sets hold one entry per vocabulary token).
        // `into_iter` visits entries in the same order `iter` would, so trie insertion order
        // is unchanged.
        let mut new_simplified_grammar: FxHashMap<String, SimplifiedExpressions> =
            simplified_grammar
                .into_iter()
                .map(|(k, v)| {
                    let only_single_terminals = v.iter().all(|terms| {
                        terms.len() == 1 && matches!(terms.last(), Some(U8Term::Terminal(_)))
                    });
                    if only_single_terminals {
                        convert_u8terms_to_simplified_expressions(
                            &k,
                            v,
                            &mut terminals_arena,
                            &nonterminal_to_terminal_id,
                        )
                    } else {
                        (k, SimplifiedExpressions::Expressions(v))
                    }
                })
                .collect();
        if any_present {
            new_simplified_grammar.insert(
                utils::ANY_NONTERMINAL_NAME.to_string(),
                SimplifiedExpressions::Terminals(
                    terminals_arena.roots[&nonterminal_to_terminal_id[utils::ANY_NONTERMINAL_NAME]],
                ),
            );
        }
        if except_present {
            for nonterminal in excepts.iter() {
                if utils::EXCEPT_LITERAL_REGEX.is_match(nonterminal) {
                    new_simplified_grammar.insert(
                        nonterminal.to_string(),
                        SimplifiedExpressions::Terminals(
                            terminals_arena.roots[&nonterminal_to_terminal_id[nonterminal]],
                        ),
                    );
                }
            }
        }
        let nonterminal_id_to_expression: FxHashMap<NonterminalID, SimplifiedExpressions> =
            new_simplified_grammar
                .into_iter()
                .map(|(key, value)| (nonterminal_to_terminal_id[&key], value))
                .collect();
        let grammar = Arc::new(Grammar {
            nonterminal_to_terminal_id,
            nonterminal_id_to_expression,
            terminals_trie: terminals_arena,
            nonterminal_to_token_ids,
        });
        // NOTE(review): this is unsound.
        // - It casts the `Arc`'s shared pointer to `&mut Grammar` and mutates through it while
        //   other handles can still observe the data: `grammar` itself, and the clone handed
        //   to each temporary `Sampler` below.
        // - A `&mut` to data behind a non-uniquely-owned `Arc` is undefined behaviour under
        //   Rust's aliasing rules, even in single-threaded code.
        // A sound version would make these updates while the `Grammar` is uniquely owned, e.g.
        // via `Arc::get_mut` after each temporary `Sampler` is dropped. That depends on
        // `Sampler` internals not visible here, so it is left as-is. All mutable state below
        // is at least accessed through `mut_grammar` only.
        let mut_grammar = unsafe { &mut *(Arc::as_ptr(&grammar) as *mut Grammar) };
        if except_present {
            for nonterminal in excepts.iter() {
                process_valid_excepts(
                    &utils::EXCEPT_NONTERMINAL_REGEX,
                    nonterminal,
                    |extracted| {
                        ensure!(
                            mut_grammar
                                .nonterminal_to_terminal_id
                                .contains_key(extracted),
                            "except!([{extracted}]) is invalid because [{extracted}] is not a valid nonterminal."
                        );
                        // Next unused ID, assuming the existing IDs are exactly 0..len (one
                        // expression entry per ID).
                        let new_id =
                            NonterminalID(mut_grammar.nonterminal_id_to_expression.len());
                        mut_grammar
                            .nonterminal_to_terminal_id
                            .insert(nonterminal.to_string(), new_id);
                        let mut temp_machine = Sampler::new(
                            grammar.clone(),
                            extracted.to_string(),
                            vocabulary.clone(),
                            stack_arena_capacity,
                            false,
                        )?;
                        let mut except_expressions: FxHashMap<String, FxHashSet<Vec<U8Term>>> =
                            FxHashMap::default();
                        match temp_machine.all_possible_next_tokens(None)? {
                            PossibleTokensResult::Continue(tokens) => {
                                let excluded_tokens = vocabulary
                                    .get_token_from_token_ids(tokens)
                                    .collect_vec();
                                add_tokens(
                                    &mut except_expressions,
                                    &mut mut_grammar.terminals_trie,
                                    &mut_grammar.nonterminal_to_terminal_id,
                                    &mut mut_grammar.nonterminal_to_token_ids,
                                    nonterminal,
                                    Some(&excluded_tokens),
                                );
                                for token in excluded_tokens {
                                    mut_grammar.terminals_trie.except_literal(
                                        token,
                                        mut_grammar.nonterminal_to_terminal_id[nonterminal],
                                    );
                                }
                                // `add_tokens` always inserts an entry for `nonterminal`; take
                                // it by value instead of cloning the token-sized set.
                                let expressions = except_expressions
                                    .remove(nonterminal)
                                    .expect("add_tokens inserts an entry for this nonterminal");
                                let (_, new_v) = convert_u8terms_to_simplified_expressions(
                                    nonterminal,
                                    expressions,
                                    &mut mut_grammar.terminals_trie,
                                    &mut_grammar.nonterminal_to_terminal_id,
                                );
                                let new_k = mut_grammar.nonterminal_to_terminal_id[nonterminal];
                                mut_grammar
                                    .nonterminal_id_to_expression
                                    .insert(new_k, new_v);
                            }
                            _ => return Err(anyhow!("except!([{extracted}]) is invalid because [{extracted}] does not produce valid terminals.")),
                        }
                        Ok(())
                    },
                )?;
            }
        }
        // Every nonterminal referenced on a right-hand side must now have an ID, either from
        // its own production or as an except!([..]) nonterminal registered above.
        for expressions in grammar.nonterminal_id_to_expression.values() {
            if let SimplifiedExpressions::Expressions(expressions) = expressions {
                for term in expressions.iter().flatten() {
                    if let U8Term::Nonterminal(nonterminal) = term {
                        ensure!(
                            grammar.nonterminal_to_terminal_id.contains_key(nonterminal),
                            "Nonterminal string <{nonterminal}> is not defined."
                        );
                    }
                }
            }
        }
        Ok(grammar)
    }
}