use super::{ConflictSolution, Node, Symbol};
use crate::{
indices::{ProdIdxRaw, SymbolIdx},
structures::Map,
};
/// Precedence level assigned to a production within a single precedence
/// family (see `Production::set_left_precedence` and
/// `Production::set_right_precedence`).
pub type PrecedenceLevel = u16;
/// Precedence levels of one production, indexed by precedence family.
///
/// Slot `i` of each vector holds the level for the family whose
/// `PrecedenceFamilyToken` wraps `i`; `None` means no level was set for that
/// family. The vectors grow on demand when a level is set, so they may be
/// shorter than the number of families in the grammar — a missing slot also
/// means "no level". How "left" and "right" levels are interpreted is decided
/// by code not visible here (presumably conflict resolution — confirm).
#[derive(Clone, Debug, Default)]
pub(crate) struct PrecedenceLevels {
/// Levels written by `Production::set_left_precedence`.
pub(crate) left: Vec<Option<PrecedenceLevel>>,
/// Levels written by `Production::set_right_precedence`.
pub(crate) right: Vec<Option<PrecedenceLevel>>,
}
/// A single production: the right-hand side of one alternative of a node,
/// together with its precedence levels and forbidden derivations.
#[derive(Clone, Debug)]
pub struct Production {
/// Right-hand-side symbols, in order.
pub symbols: Box<[Symbol]>,
/// Precedence levels of this production, per precedence family.
pub(crate) precedence: PrecedenceLevels,
/// Forbidden derivations, recorded by `Production::forbid_derivation`.
/// Each entry is `(position in symbols, index of a production whose
/// left-hand side is the node found at that position)`.
pub(crate) forbidden_derivations: Vec<(SymbolIdx, ProdIdxRaw)>,
}
impl Production {
pub fn set_left_precedence(
&mut self,
precedence_family: PrecedenceFamilyToken,
level: PrecedenceLevel,
) -> &mut Self {
while self.precedence.left.len() <= precedence_family.0 {
self.precedence.left.push(None);
}
self.precedence.left[precedence_family.0] = Some(level);
self
}
pub fn set_right_precedence(
&mut self,
precedence_family: PrecedenceFamilyToken,
level: PrecedenceLevel,
) -> &mut Self {
while self.precedence.right.len() <= precedence_family.0 {
self.precedence.right.push(None);
}
self.precedence.right[precedence_family.0] = Some(level);
self
}
pub fn forbid_derivation(
&mut self,
symbol_index: usize,
forbidden_production: ProdIdx,
) -> Result<&mut Self, ForbidDerivationError> {
match self.symbols.get(symbol_index) {
None | Some(Symbol::Token(_)) => Err(ForbidDerivationError),
Some(Symbol::Node(n)) => {
if *n == forbidden_production.lhs {
self.forbidden_derivations
.push((symbol_index as SymbolIdx, forbidden_production.index));
Ok(self)
} else {
Err(ForbidDerivationError)
}
}
}
}
}
/// Opaque handle to a precedence family, obtained from
/// `Grammar::add_precedence_family`.
///
/// The wrapped value is the family's slot index in each production's
/// per-family precedence vectors.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct PrecedenceFamilyToken(usize);
/// A grammar: productions grouped by left-hand-side node, plus the precedence
/// families and conflict solutions registered on it.
#[derive(Clone, Debug)]
pub struct Grammar {
/// Productions of each node. A production's position in its node's vector is
/// the `index` of its `ProdIdx`.
productions: Map<Node, Vec<Production>>,
/// Index that the next `add_precedence_family` call will hand out.
next_precedence_family: usize,
/// Conflict solutions registered via `add_conflict_solution`.
pub(crate) conflict_solutions: Vec<ConflictSolution>,
}
/// Identifies a production within a `Grammar`: its left-hand-side node plus
/// its position among that node's productions, as returned by
/// `Grammar::add_production`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ProdIdx {
/// Left-hand-side node of the production.
pub lhs: Node,
/// Position of the production among `lhs`'s productions.
pub index: ProdIdxRaw,
}
/// Reasons `Grammar::add_production` can reject a production.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AddProductionError {
    /// The right-hand side has more symbols than a `SymbolIdx` can address.
    RhsTooLong,
    /// The left-hand-side node already has the maximum number of productions.
    TooManyProductions,
}

impl std::fmt::Display for AddProductionError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::RhsTooLong => "production right-hand side is too long",
            Self::TooManyProductions => "too many productions for this left-hand side",
        })
    }
}

impl std::error::Error for AddProductionError {}
/// Returned by `Production::forbid_derivation` when the requested derivation
/// cannot be forbidden, e.g. because the symbol position is out of range,
/// holds a token, or holds a node other than the forbidden production's
/// left-hand side.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ForbidDerivationError;

impl std::fmt::Display for ForbidDerivationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("invalid forbidden derivation: symbol index does not refer to a matching node")
    }
}

impl std::error::Error for ForbidDerivationError {}
impl Grammar {
    /// Creates an empty grammar with no productions, precedence families, or
    /// conflict solutions.
    pub fn new() -> Self {
        Self {
            productions: Map::default(),
            next_precedence_family: 0,
            conflict_solutions: Vec::new(),
        }
    }

    /// Adds the production `lhs -> rhs` and returns its index.
    ///
    /// Indices are assigned in insertion order per node, starting at 0.
    ///
    /// # Errors
    ///
    /// * [`AddProductionError::RhsTooLong`] if `rhs` has more than
    ///   `SymbolIdx::MAX` symbols.
    /// * [`AddProductionError::TooManyProductions`] if `lhs` already has
    ///   `ProdIdxRaw::MAX` productions.
    ///
    /// On error the grammar is left unchanged. In particular, `lhs` is not
    /// registered as a node if it had no productions before the call.
    pub fn add_production(
        &mut self,
        lhs: Node,
        rhs: Vec<Symbol>,
    ) -> Result<ProdIdx, AddProductionError> {
        // Validate before touching the map: `entry(..).or_default()` inserts an
        // empty production list, which a rejected call must not leave behind
        // (it would otherwise show up in `get_all_nodes`).
        if rhs.len() > SymbolIdx::MAX as usize {
            return Err(AddProductionError::RhsTooLong);
        }
        let productions = self.productions.entry(lhs).or_default();
        // At most `ProdIdxRaw::MAX` productions per node, so every handed-out
        // index lies in `0..ProdIdxRaw::MAX` and the cast below cannot truncate.
        if productions.len() >= ProdIdxRaw::MAX as usize {
            return Err(AddProductionError::TooManyProductions);
        }
        let index = productions.len() as ProdIdxRaw;
        productions.push(Production {
            symbols: rhs.into(),
            precedence: PrecedenceLevels::default(),
            forbidden_derivations: Vec::new(),
        });
        Ok(ProdIdx { lhs, index })
    }

    /// Appends a conflict solution to the grammar.
    pub fn add_conflict_solution(&mut self, conflict_solution: ConflictSolution) {
        self.conflict_solutions.push(conflict_solution);
    }

    /// Allocates a new precedence family. Each call returns a distinct token,
    /// numbered consecutively from 0.
    pub fn add_precedence_family(&mut self) -> PrecedenceFamilyToken {
        let family = PrecedenceFamilyToken(self.next_precedence_family);
        self.next_precedence_family += 1;
        family
    }

    /// Iterates over every node that has at least one production.
    pub fn get_all_nodes(&self) -> impl Iterator<Item = Node> + '_ {
        self.productions.keys().copied()
    }

    /// Returns all registered conflict solutions.
    pub fn get_conflict_solutions(&self) -> &[ConflictSolution] {
        &self.conflict_solutions
    }

    /// Returns all registered conflict solutions, mutably.
    pub fn get_conflict_solutions_mut(&mut self) -> &mut [ConflictSolution] {
        &mut self.conflict_solutions
    }

    /// Looks up a production, returning `None` if `prod_idx` does not refer to
    /// one in this grammar.
    pub fn get_production(&self, prod_idx: ProdIdx) -> Option<&Production> {
        self.productions
            .get(&prod_idx.lhs)?
            .get(prod_idx.index as usize)
    }

    /// Mutable counterpart of [`Grammar::get_production`].
    pub fn get_production_mut(&mut self, prod_idx: ProdIdx) -> Option<&mut Production> {
        self.productions
            .get_mut(&prod_idx.lhs)?
            .get_mut(prod_idx.index as usize)
    }

    /// Returns the right-hand side of a production, or `None` if `prod_idx`
    /// does not refer to one in this grammar.
    pub fn get_rhs(&self, prod_idx: ProdIdx) -> Option<&[Symbol]> {
        self.get_production(prod_idx)
            .map(|production| &*production.symbols)
    }

    /// Iterates over every production in the grammar together with its index.
    pub fn get_all_productions(&self) -> impl Iterator<Item = (ProdIdx, &'_ [Symbol])> + '_ {
        self.productions.iter().flat_map(|(&lhs, productions)| {
            productions.iter().enumerate().map(move |(index, production)| {
                // `add_production` keeps every node below `ProdIdxRaw::MAX`
                // productions, so this cast cannot truncate.
                let prod_idx = ProdIdx {
                    lhs,
                    index: index as ProdIdxRaw,
                };
                (prod_idx, &*production.symbols)
            })
        })
    }

    /// Iterates over the productions of `node` in index order. The iterator is
    /// empty if `node` has no productions.
    pub fn get_node_productions(
        &self,
        node: Node,
    ) -> impl Iterator<Item = (ProdIdx, &'_ Production)> + '_ {
        self.productions
            .get(&node)
            .into_iter()
            .flat_map(move |productions| {
                productions
                    .iter()
                    .enumerate()
                    .map(move |(index, production)| {
                        let prod_idx = ProdIdx {
                            lhs: node,
                            index: index as ProdIdxRaw,
                        };
                        (prod_idx, production)
                    })
            })
    }
}
impl Default for Grammar {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::test_common::{nodes::START, tokens::T1};

    /// Exercises both `AddProductionError` variants: first an over-long
    /// right-hand side, then one production past the per-node limit.
    #[test]
    fn grammar_errors() {
        let mut grammar = Grammar::default();

        let oversized_rhs = vec![Symbol::Token(T1); SymbolIdx::MAX as usize + 1];
        let rejected = grammar.add_production(START, oversized_rhs);
        assert_eq!(rejected.unwrap_err(), AddProductionError::RhsTooLong);

        // Fill START right up to the limit; every insertion must succeed.
        (0..ProdIdxRaw::MAX).for_each(|_| {
            assert!(grammar.add_production(START, vec![]).is_ok());
        });

        let overflow = grammar.add_production(START, vec![]);
        assert_eq!(overflow.unwrap_err(), AddProductionError::TooManyProductions);
    }
}