use crate::HashMap;
use anyhow::{bail, ensure, Result};
use derivre::RegexAst;
use std::{ops::RangeInclusive, sync::atomic::AtomicU32};
use crate::api::{
GenGrammarOptions, GenOptions, GrammarWithLexer, Node, NodeId, NodeProps, RegexId, RegexNode,
RegexSpec, TopLevelGrammar,
};
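/// A handle to a node added to the current grammar. The `grammar_id` is
/// checked whenever the reference is used, so nodes from different grammars
/// (or different builders) cannot be mixed by accident.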
#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)]
pub struct NodeRef {
idx: usize,
grammar_id: u32,
}
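/// Branching factor used when decomposing bounded repetitions in
/// `at_most()` and `repeat_exact()`.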
const K: usize = 4;
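/// Incrementally builds a `TopLevelGrammar`. Nodes are accumulated in `nodes`
/// and moved into the most recently added grammar by `shift_nodes()` when the
/// next grammar is added or the builder is finalized. Structurally identical
/// nodes are deduplicated, and recursive rules are tied together through
/// placeholder nodes (see `placeholder()` and `set_placeholder()`).
///
/// A minimal usage sketch (illustrative only; assumes `GrammarWithLexer`
/// provides a `Default` implementation):
///
/// ```ignore
/// let mut b = GrammarBuilder::new();
/// b.add_grammar(GrammarWithLexer::default());
/// let a = b.string("a");
/// let a_plus = b.one_or_more(a); // a+
/// b.set_start_node(a_plus);
/// let grammar = b.finalize()?;
/// ```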
pub struct GrammarBuilder {
pub top_grammar: TopLevelGrammar,
placeholder: Node,
strings: HashMap<String, NodeRef>,
curr_grammar_id: u32,
node_refs: HashMap<String, NodeRef>,
nodes: Vec<Node>,
pub regex: RegexBuilder,
at_most_cache: HashMap<(NodeRef, usize), NodeRef>,
repeat_exact_cache: HashMap<(NodeRef, usize), NodeRef>,
}
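/// Deduplicating builder for the `rx_nodes` table of a grammar; structurally
/// identical `RegexNode`s share a single `RegexId`.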
pub struct RegexBuilder {
node_ids: HashMap<RegexNode, RegexId>,
nodes: Vec<RegexNode>,
}
impl RegexBuilder {
pub fn new() -> Self {
Self {
nodes: vec![],
node_ids: HashMap::default(),
}
}
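    /// Recursively translate a `derivre::RegexAst` into the serializable
    /// `RegexNode` representation, returning the id of the root node.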
pub fn add_ast(&mut self, ast: RegexAst) -> Result<RegexId> {
let id = match ast {
RegexAst::And(asts) => {
let ids = self.add_asts(asts)?;
self.and(ids)
}
RegexAst::Or(asts) => {
let ids = self.add_asts(asts)?;
self.add_node(RegexNode::Or(ids))
}
RegexAst::Concat(asts) => {
let ids = self.add_asts(asts)?;
self.concat(ids)
}
RegexAst::LookAhead(ast) => {
let id = self.add_ast(*ast)?;
self.add_node(RegexNode::LookAhead(id))
}
RegexAst::Not(ast) => {
let id = self.add_ast(*ast)?;
self.not(id)
}
RegexAst::Repeat(ast, min, max) => {
let id = self.add_ast(*ast)?;
self.repeat(id, min, Some(max))
}
RegexAst::EmptyString => self.add_node(RegexNode::EmptyString),
RegexAst::NoMatch => self.add_node(RegexNode::NoMatch),
RegexAst::Regex(rx) => self.regex(rx),
RegexAst::Literal(s) => self.literal(s),
RegexAst::ByteLiteral(bytes) => self.add_node(RegexNode::ByteLiteral(bytes)),
RegexAst::Byte(b) => self.add_node(RegexNode::Byte(b)),
RegexAst::ByteSet(bs) => self.add_node(RegexNode::ByteSet(bs)),
RegexAst::JsonQuote(ast, opts) => {
let regex = self.add_ast(*ast)?;
self.add_node(RegexNode::JsonQuote {
regex,
raw_mode: opts.raw_mode,
allowed_escapes: Some(opts.allowed_escapes.clone()),
})
}
RegexAst::MultipleOf(d, s) => self.add_node(RegexNode::MultipleOf(d, s)),
RegexAst::ExprRef(_) => {
bail!("ExprRef not supported")
}
};
Ok(id)
}
fn add_asts(&mut self, asts: Vec<RegexAst>) -> Result<Vec<RegexId>> {
asts.into_iter().map(|ast| self.add_ast(ast)).collect()
}
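    /// Add a regex node, returning the existing id if an identical node was
    /// already added.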
pub fn add_node(&mut self, node: RegexNode) -> RegexId {
if let Some(id) = self.node_ids.get(&node) {
return *id;
}
let id = RegexId(self.nodes.len());
self.nodes.push(node.clone());
self.node_ids.insert(node, id);
id
}
pub fn regex(&mut self, rx: String) -> RegexId {
self.add_node(RegexNode::Regex(rx))
}
pub fn literal(&mut self, s: String) -> RegexId {
self.add_node(RegexNode::Literal(s))
}
pub fn concat(&mut self, nodes: Vec<RegexId>) -> RegexId {
if nodes.len() == 1 {
return nodes[0];
}
        if nodes.is_empty() {
return self.add_node(RegexNode::NoMatch);
}
self.add_node(RegexNode::Concat(nodes))
}
pub fn select(&mut self, nodes: Vec<RegexId>) -> RegexId {
if nodes.len() == 1 {
return nodes[0];
}
        if nodes.is_empty() {
return self.add_node(RegexNode::NoMatch);
}
self.add_node(RegexNode::Or(nodes))
}
pub fn zero_or_more(&mut self, node: RegexId) -> RegexId {
self.repeat(node, 0, None)
}
pub fn one_or_more(&mut self, node: RegexId) -> RegexId {
self.repeat(node, 1, None)
}
pub fn optional(&mut self, node: RegexId) -> RegexId {
self.repeat(node, 0, Some(1))
}
pub fn repeat(&mut self, node: RegexId, min: u32, max: Option<u32>) -> RegexId {
self.add_node(RegexNode::Repeat(node, min, max))
}
pub fn not(&mut self, node: RegexId) -> RegexId {
self.add_node(RegexNode::Not(node))
}
pub fn and(&mut self, nodes: Vec<RegexId>) -> RegexId {
self.add_node(RegexNode::And(nodes))
}
pub fn or(&mut self, nodes: Vec<RegexId>) -> RegexId {
self.add_node(RegexNode::Or(nodes))
}
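    /// Return all accumulated regex nodes and reset the builder for the next
    /// grammar.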
pub fn finalize(&mut self) -> Vec<RegexNode> {
let r = std::mem::take(&mut self.nodes);
*self = Self::new();
r
}
}
impl GrammarBuilder {
pub fn new() -> Self {
Self {
top_grammar: TopLevelGrammar {
grammars: vec![],
max_tokens: None,
test_trace: false,
},
placeholder: Node::String {
literal: "__placeholder__: do not use this string in grammars".to_string(),
props: NodeProps {
max_tokens: Some(usize::MAX - 108),
capture_name: Some("$$$placeholder$$$".to_string()),
..NodeProps::default()
},
},
strings: HashMap::default(),
curr_grammar_id: 0,
node_refs: HashMap::default(),
nodes: vec![],
regex: RegexBuilder::new(),
at_most_cache: HashMap::default(),
repeat_exact_cache: HashMap::default(),
}
}
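    /// Move the nodes and regex nodes accumulated so far into the most
    /// recently added grammar.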
fn shift_nodes(&mut self) {
        if self.top_grammar.grammars.is_empty() {
assert!(self.nodes.is_empty(), "nodes added before add_grammar()");
} else {
let nodes = std::mem::take(&mut self.nodes);
            assert!(
                !nodes.is_empty(),
                "no nodes added before add_grammar() or finalize()"
            );
self.top_grammar.grammars.last_mut().unwrap().nodes = nodes;
self.top_grammar.grammars.last_mut().unwrap().rx_nodes = self.regex.finalize();
}
}
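    /// Grammar ids come from a process-wide counter, so `NodeRef`s from
    /// different builders or grammars can never be confused with each other.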
fn next_grammar_id(&mut self) {
static COUNTER: AtomicU32 = AtomicU32::new(1);
self.curr_grammar_id = COUNTER.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
}
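    /// Start a new grammar: flush the nodes of the previous grammar, assign a
    /// fresh grammar id, and reserve node 0 as a placeholder for the start
    /// node (later resolved by `set_start_node()`).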
pub fn add_grammar(&mut self, grammar: GrammarWithLexer) {
assert!(grammar.nodes.is_empty(), "Grammar already has nodes");
self.shift_nodes();
self.next_grammar_id();
self.top_grammar.grammars.push(grammar);
self.strings.clear();
let id = self.placeholder();
assert!(id.idx == 0);
}
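    /// Add a node, deduplicating by its JSON serialization. Placeholder nodes
    /// are never deduplicated, so each call to `placeholder()` gets its own
    /// slot.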
pub fn add_node(&mut self, node: Node) -> NodeRef {
let key = (node != self.placeholder)
.then(|| serde_json::to_string(&node).ok())
.flatten();
if let Some(ref key) = key {
if let Some(node_ref) = self.node_refs.get(key) {
return *node_ref;
}
}
let r = NodeRef {
idx: self.nodes.len(),
grammar_id: self.curr_grammar_id,
};
self.nodes.push(node);
if let Some(key) = key {
self.node_refs.insert(key, r);
}
r
}
pub fn string(&mut self, s: &str) -> NodeRef {
if let Some(r) = self.strings.get(s) {
return *r;
}
let r = self.add_node(Node::String {
literal: s.to_string(),
props: NodeProps::default(),
});
self.strings.insert(s.to_string(), r);
r
}
pub fn token_ranges(&mut self, token_ranges: Vec<RangeInclusive<u32>>) -> NodeRef {
self.add_node(Node::TokenRanges {
token_ranges,
props: NodeProps::default(),
})
}
pub fn special_token(&mut self, name: &str) -> NodeRef {
self.add_node(Node::SpecialToken {
token: name.to_string(),
props: NodeProps::default(),
})
}
pub fn gen_grammar(&mut self, data: GenGrammarOptions, props: NodeProps) -> NodeRef {
self.add_node(Node::GenGrammar { data, props })
}
pub fn gen_rx(&mut self, regex: &str, stop_regex: &str) -> NodeRef {
self.gen(
GenOptions {
body_rx: RegexSpec::Regex(regex.to_string()),
stop_rx: RegexSpec::Regex(stop_regex.to_string()),
..Default::default()
},
NodeProps::default(),
)
}
pub fn gen(&mut self, data: GenOptions, props: NodeProps) -> NodeRef {
self.add_node(Node::Gen { data, props })
}
pub fn lexeme(&mut self, rx: RegexSpec) -> NodeRef {
self.add_node(Node::Lexeme {
rx,
contextual: None,
temperature: None,
json_string: None,
json_raw: None,
json_allowed_escapes: None,
props: NodeProps::default(),
})
}
fn child_nodes(&mut self, options: &[NodeRef]) -> Vec<NodeId> {
options
.iter()
.map(|e| {
assert!(e.grammar_id == self.curr_grammar_id);
NodeId(e.idx)
})
.collect()
}
pub fn select(&mut self, options: &[NodeRef]) -> NodeRef {
        let ch = self.child_nodes(options);
self.add_node(Node::Select {
among: ch,
props: NodeProps::default(),
})
}
pub fn max_tokens(&mut self, node: NodeRef, max_tokens: usize) -> NodeRef {
self.join_props(
&[node],
NodeProps {
max_tokens: Some(max_tokens),
..Default::default()
},
)
}
pub fn join(&mut self, values: &[NodeRef]) -> NodeRef {
self.join_props(values, NodeProps::default())
}
pub fn join_props(&mut self, values: &[NodeRef], props: NodeProps) -> NodeRef {
        let mut ch = self.child_nodes(values);
let empty = NodeId(self.empty().idx);
ch.retain(|&n| n != empty);
        if ch.is_empty() {
return self.empty();
}
if ch.len() == 1 && props == NodeProps::default() {
return NodeRef {
idx: ch[0].0,
grammar_id: self.curr_grammar_id,
};
}
self.add_node(Node::Join {
sequence: ch,
props,
})
}
pub fn empty(&mut self) -> NodeRef {
self.string("")
}
pub fn optional(&mut self, value: NodeRef) -> NodeRef {
let empty = self.empty();
self.select(&[value, empty])
}
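    /// `elt+`, expressed as the recursive rule `p ::= elt | p elt` and tied
    /// together with a placeholder.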
pub fn one_or_more(&mut self, elt: NodeRef) -> NodeRef {
let p = self.placeholder();
let p_elt = self.join(&[p, elt]);
let inner = self.select(&[elt, p_elt]);
self.set_placeholder(p, inner);
p
}
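    /// `elt*`, expressed as the recursive rule `p ::= "" | p elt`.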
pub fn zero_or_more(&mut self, elt: NodeRef) -> NodeRef {
let p = self.placeholder();
let empty = self.empty();
let p_elt = self.join(&[p, elt]);
let inner = self.select(&[empty, p_elt]);
self.set_placeholder(p, inner);
p
}
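    /// At most `n` repetitions of `elt`. Small counts are expanded directly;
    /// larger counts are split into full groups of `K` plus a remainder:
    /// either exactly `(n / K) * K` elements followed by at most `n % K` more,
    /// or at most `(n / K) - 1` full groups followed by at most `K - 1` more.
    /// Results are memoized in `at_most_cache`.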
fn at_most(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if let Some(r) = self.at_most_cache.get(&(elt, n)) {
return *r;
}
let r = if n == 0 {
self.empty()
} else if n == 1 {
self.optional(elt)
} else if n < 3 * K {
let options = (0..=n)
.map(|k| self.simple_repeat(elt, k))
.collect::<Vec<_>>();
self.select(&options)
} else {
let elt_k = self.simple_repeat(elt, K);
let elt_max_nk = self.at_most(elt_k, (n / K) - 1);
let elt_max_k = self.at_most(elt, K - 1);
let elt_max_nk = self.join(&[elt_max_nk, elt_max_k]);
let elt_nk = self.repeat_exact(elt_k, n / K);
let left = self.at_most(elt, n % K);
let elt_n = self.join(&[elt_nk, left]);
self.select(&[elt_n, elt_max_nk])
};
self.at_most_cache.insert((elt, n), r);
r
}
fn simple_repeat(&mut self, elt: NodeRef, n: usize) -> NodeRef {
let elt_n = (0..n).map(|_| elt).collect::<Vec<_>>();
self.join(&elt_n)
}
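    /// Exactly `n` repetitions of `elt`. For `n > 2 * K` the count is factored
    /// as `(n / K) * K + n % K` and the group-of-`K` part is built
    /// recursively; results are memoized in `repeat_exact_cache`.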
fn repeat_exact(&mut self, elt: NodeRef, n: usize) -> NodeRef {
if let Some(r) = self.repeat_exact_cache.get(&(elt, n)) {
return *r;
}
let r = if n > 2 * K {
let elt_k = self.simple_repeat(elt, K);
let inner = self.repeat_exact(elt_k, n / K);
let left = n % K;
let mut elt_left = (0..left).map(|_| elt).collect::<Vec<_>>();
elt_left.push(inner);
self.join(&elt_left)
} else {
self.simple_repeat(elt, n)
};
self.repeat_exact_cache.insert((elt, n), r);
r
}
fn at_least(&mut self, elt: NodeRef, n: usize) -> NodeRef {
let z = self.zero_or_more(elt);
if n == 0 {
z
} else {
let r = self.repeat_exact(elt, n);
self.join(&[r, z])
}
}
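    /// Between `min` and `max` repetitions of `elt` (unbounded if `max` is
    /// `None`).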
pub fn repeat(&mut self, elt: NodeRef, min: usize, max: Option<usize>) -> NodeRef {
if max.is_none() {
return self.at_least(elt, min);
}
let max = max.unwrap();
assert!(min <= max);
if min == max {
self.repeat_exact(elt, min)
} else if min == 0 {
self.at_most(elt, max)
} else {
let d = max - min;
let common = self.repeat_exact(elt, min);
let extra = self.at_most(elt, d);
self.join(&[common, extra])
}
}
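    /// Create a fresh placeholder node; it must later be resolved with
    /// `set_placeholder()`.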
pub fn placeholder(&mut self) -> NodeRef {
self.add_node(self.placeholder.clone())
}
pub fn is_placeholder(&self, node: NodeRef) -> bool {
assert!(node.grammar_id == self.curr_grammar_id);
self.nodes[node.idx] == self.placeholder
}
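    /// Resolve `placeholder` to refer to `node`. If `node` is itself a
    /// placeholder, the slot becomes a single-element `Join` pointing at it;
    /// otherwise the definition of `node` is moved into the placeholder's
    /// slot and `node`'s old slot becomes a `Join` pointing back at the
    /// placeholder, so both references resolve to the same rule.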
pub fn set_placeholder(&mut self, placeholder: NodeRef, node: NodeRef) {
        let ch = self.child_nodes(&[placeholder, node]);
        if !self.is_placeholder(placeholder) {
panic!(
"placeholder already set at {} to {:?}",
placeholder.idx, self.nodes[placeholder.idx]
);
}
if self.is_placeholder(node) {
self.nodes[placeholder.idx] = Node::Join {
sequence: vec![ch[1]],
props: NodeProps::default(),
};
} else {
let prev_placeholder_link = Node::Join {
sequence: vec![ch[0]],
props: NodeProps::default(),
};
self.nodes[placeholder.idx] =
std::mem::replace(&mut self.nodes[node.idx], prev_placeholder_link);
}
}
pub fn set_start_node(&mut self, node: NodeRef) {
self.set_placeholder(
NodeRef {
idx: 0,
grammar_id: self.curr_grammar_id,
},
node,
);
}
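    /// Flush pending nodes into the last grammar and verify that every
    /// placeholder has been resolved.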
pub fn finalize(mut self) -> Result<TopLevelGrammar> {
        ensure!(
            !self.top_grammar.grammars.is_empty(),
            "No grammars added to the top level grammar"
        );
self.shift_nodes();
for grammar in &self.top_grammar.grammars {
for node in &grammar.nodes {
ensure!(node != &self.placeholder, "Unresolved placeholder");
}
}
        Ok(self.top_grammar)
}
}