use std::collections::{HashMap, HashSet};
use crate::{
CompilationError, SourceSpan,
ast::{
self, ASTVisitor, Add, ArithmeticExpression, Binding, Constant,
CustomDice, DiceExpression, Div, DropHighest, DropLowest, Exp,
Expression, Group, Mod, Mul, Neg, Range, StandardDice, Sub, Variable
}
};
/// Stateless semantic checker for a parsed [`ast::Function`].
///
/// Validation is driven through the associated function
/// [`Validator::validate`]; the unit value exists so the type can also be used
/// wherever a validator instance is expected.
#[derive(Copy, Clone, Debug, Default)]
pub struct Validator;

impl Validator
{
    /// Returns the (field-less) validator value.
    #[inline]
    pub const fn new() -> Self
    {
        Validator
    }

    /// Runs every validation pass over `ast` and stops at the first failure.
    ///
    /// Passes, in order:
    /// 1. parameter names must be unique;
    /// 2. binding names must not collide with parameters or with each other;
    /// 3. a bound name must not be referenced before its binding is evaluated.
    ///
    /// # Errors
    /// Returns the [`CompilationError`] produced by the first failing pass.
    pub fn validate<'src>(
        ast: &ast::Function<'src>
    ) -> Result<(), CompilationError<'src>>
    {
        check_duplicate_parameters(ast)
            .and_then(|()| collect_bindings_and_check_collisions(ast))
            .and_then(|bindings| check_use_before_bind(&ast.body, &bindings))
    }
}
fn check_duplicate_parameters<'src>(
ast: &ast::Function<'src>
) -> Result<(), CompilationError<'src>>
{
if let Some(ref parameters) = ast.parameters
{
let mut seen: HashMap<&'src str, SourceSpan> =
HashMap::with_capacity(parameters.len());
for param in parameters
{
if let Some(&first) = seen.get(param.name)
{
return Err(CompilationError::DuplicateParameter {
name: param.name,
first,
duplicate: param.span
});
}
seen.insert(param.name, param.span);
}
}
Ok(())
}
/// Collects every binding in the function body into a map from binding name
/// to the span of that name.
///
/// The function's parameters are gathered first, so that bindings reusing a
/// parameter name can be reported.
///
/// # Errors
/// Propagates [`CompilationError::BindingCollidesWithParameter`] and
/// [`CompilationError::DuplicateBinding`] from [`gather_bindings`].
fn collect_bindings_and_check_collisions<'src>(
    ast: &ast::Function<'src>
) -> Result<HashMap<&'src str, SourceSpan>, CompilationError<'src>>
{
    // Parameter name -> declaration span. Duplicate parameters were already
    // rejected by `check_duplicate_parameters` when called via `validate`.
    let parameter_spans: HashMap<&'src str, SourceSpan> = match ast.parameters
    {
        Some(ref parameters) => parameters
            .iter()
            .map(|p| (p.name, p.span))
            .collect(),
        None => HashMap::new()
    };
    let mut bindings: HashMap<&'src str, SourceSpan> = HashMap::new();
    gather_bindings(&ast.body, &parameter_spans, &mut bindings)?;
    Ok(bindings)
}
/// Recursively walks `expr`, recording each binding's name and name span in
/// `bindings`.
///
/// # Errors
/// * [`CompilationError::BindingCollidesWithParameter`] when a binding reuses
///   a name present in `parameters`.
/// * [`CompilationError::DuplicateBinding`] when a binding reuses a name that
///   is already recorded, including a same-named binding nested inside the
///   bound expression of another binding.
fn gather_bindings<'src>(
    expr: &Expression<'src>,
    parameters: &HashMap<&'src str, SourceSpan>,
    bindings: &mut HashMap<&'src str, SourceSpan>
) -> Result<(), CompilationError<'src>>
{
    match expr
    {
        Expression::Binding(b) =>
        {
            if let Some(&parameter) = parameters.get(b.name)
            {
                return Err(CompilationError::BindingCollidesWithParameter {
                    name: b.name,
                    parameter,
                    binding: b.name_span
                });
            }
            // Register the name BEFORE descending into the bound expression.
            // Registering it afterwards lets a same-named binding nested in
            // that expression go undetected: the inner one would be inserted
            // first and then silently overwritten by the outer one.
            // `insert` returns any span already stored for this name; on that
            // path we return an error immediately, so the overwrite is moot.
            if let Some(first) = bindings.insert(b.name, b.name_span)
            {
                return Err(CompilationError::DuplicateBinding {
                    name: b.name,
                    first,
                    duplicate: b.name_span
                });
            }
            gather_bindings(&b.expression, parameters, bindings)?;
        },
        Expression::Group(g) =>
        {
            gather_bindings(&g.expression, parameters, bindings)?;
        },
        Expression::Range(r) =>
        {
            gather_bindings(&r.start, parameters, bindings)?;
            gather_bindings(&r.end, parameters, bindings)?;
        },
        Expression::Dice(d) => gather_dice_bindings(d, parameters, bindings)?,
        Expression::Arithmetic(a) =>
        {
            gather_arithmetic_bindings(a, parameters, bindings)?;
        },
        // Leaves introduce no bindings.
        Expression::Variable(_) | Expression::Constant(_) =>
        {}
    }
    Ok(())
}
/// Dice-node counterpart of [`gather_bindings`].
///
/// Standard dice visit their count, then their faces. Custom dice visit only
/// their count. Drop-lowest/drop-highest visit the wrapped dice, then the
/// optional drop amount. Errors from the traversal are returned unchanged.
fn gather_dice_bindings<'src>(
    dice: &DiceExpression<'src>,
    parameters: &HashMap<&'src str, SourceSpan>,
    bindings: &mut HashMap<&'src str, SourceSpan>
) -> Result<(), CompilationError<'src>>
{
    match dice
    {
        DiceExpression::Standard(standard) =>
        {
            gather_bindings(&standard.count, parameters, bindings)?;
            gather_bindings(&standard.faces, parameters, bindings)
        },
        DiceExpression::Custom(custom) =>
        {
            gather_bindings(&custom.count, parameters, bindings)
        },
        DiceExpression::DropLowest(lowest) =>
        {
            gather_dice_bindings(&lowest.dice, parameters, bindings)?;
            match lowest.drop
            {
                Some(ref amount) => gather_bindings(amount, parameters, bindings),
                None => Ok(())
            }
        },
        DiceExpression::DropHighest(highest) =>
        {
            gather_dice_bindings(&highest.dice, parameters, bindings)?;
            match highest.drop
            {
                Some(ref amount) => gather_bindings(amount, parameters, bindings),
                None => Ok(())
            }
        }
    }
}
/// Arithmetic-node counterpart of [`gather_bindings`].
///
/// Binary operators visit the left operand, then the right one. Negation
/// visits its single operand. Errors from the traversal are returned unchanged.
fn gather_arithmetic_bindings<'src>(
    arith: &ArithmeticExpression<'src>,
    parameters: &HashMap<&'src str, SourceSpan>,
    bindings: &mut HashMap<&'src str, SourceSpan>
) -> Result<(), CompilationError<'src>>
{
    match arith
    {
        ArithmeticExpression::Add(add) =>
        {
            gather_bindings(&add.left, parameters, bindings)?;
            gather_bindings(&add.right, parameters, bindings)
        },
        ArithmeticExpression::Sub(sub) =>
        {
            gather_bindings(&sub.left, parameters, bindings)?;
            gather_bindings(&sub.right, parameters, bindings)
        },
        ArithmeticExpression::Mul(mul) =>
        {
            gather_bindings(&mul.left, parameters, bindings)?;
            gather_bindings(&mul.right, parameters, bindings)
        },
        ArithmeticExpression::Div(div) =>
        {
            gather_bindings(&div.left, parameters, bindings)?;
            gather_bindings(&div.right, parameters, bindings)
        },
        ArithmeticExpression::Mod(rem) =>
        {
            gather_bindings(&rem.left, parameters, bindings)?;
            gather_bindings(&rem.right, parameters, bindings)
        },
        ArithmeticExpression::Exp(pow) =>
        {
            gather_bindings(&pow.left, parameters, bindings)?;
            gather_bindings(&pow.right, parameters, bindings)
        },
        ArithmeticExpression::Neg(neg) =>
        {
            gather_bindings(&neg.operand, parameters, bindings)
        }
    }
}
/// Verifies that no bound name is referenced before its binding is evaluated.
///
/// `bindings` is the full name -> span map of every binding in `body`; the
/// walk tracks which of those names have been bound so far.
///
/// # Errors
/// [`CompilationError::UseBeforeBind`] for the first premature reference.
fn check_use_before_bind<'src>(
    body: &Expression<'src>,
    bindings: &HashMap<&'src str, SourceSpan>
) -> Result<(), CompilationError<'src>>
{
    // Nothing has been bound when the walk starts.
    walk_use_before_bind(body, bindings, &mut HashSet::new())
}
/// Recursive worker for [`check_use_before_bind`].
///
/// A binding's name joins `seen` only after its bound expression has been
/// walked, so that expression cannot refer to the name it introduces.
/// Variables whose name is not a key of `bindings` (for example, parameters)
/// are accepted without further checks.
///
/// # Errors
/// [`CompilationError::UseBeforeBind`] with the reference span and the span
/// of the binding it refers to.
fn walk_use_before_bind<'src>(
    expr: &Expression<'src>,
    bindings: &HashMap<&'src str, SourceSpan>,
    seen: &mut HashSet<&'src str>
) -> Result<(), CompilationError<'src>>
{
    match expr
    {
        Expression::Variable(v) => match bindings.get(v.name)
        {
            Some(&binding_span) if !seen.contains(v.name) =>
            {
                Err(CompilationError::UseBeforeBind {
                    name: v.name,
                    reference: v.span,
                    binding: binding_span
                })
            },
            _ => Ok(())
        },
        Expression::Binding(b) =>
        {
            walk_use_before_bind(&b.expression, bindings, seen)?;
            seen.insert(b.name);
            Ok(())
        },
        Expression::Group(g) => walk_use_before_bind(&g.expression, bindings, seen),
        Expression::Range(r) =>
        {
            walk_use_before_bind(&r.start, bindings, seen)?;
            walk_use_before_bind(&r.end, bindings, seen)
        },
        Expression::Dice(d) => walk_dice_use_before_bind(d, bindings, seen),
        Expression::Arithmetic(a) =>
        {
            walk_arithmetic_use_before_bind(a, bindings, seen)
        },
        Expression::Constant(_) => Ok(())
    }
}
/// Dice-node counterpart of [`walk_use_before_bind`].
///
/// The visiting order matches [`gather_dice_bindings`]: count before faces,
/// wrapped dice before the optional drop amount.
fn walk_dice_use_before_bind<'src>(
    dice: &DiceExpression<'src>,
    bindings: &HashMap<&'src str, SourceSpan>,
    seen: &mut HashSet<&'src str>
) -> Result<(), CompilationError<'src>>
{
    match dice
    {
        DiceExpression::Standard(standard) =>
        {
            walk_use_before_bind(&standard.count, bindings, seen)?;
            walk_use_before_bind(&standard.faces, bindings, seen)
        },
        DiceExpression::Custom(custom) =>
        {
            walk_use_before_bind(&custom.count, bindings, seen)
        },
        DiceExpression::DropLowest(lowest) =>
        {
            walk_dice_use_before_bind(&lowest.dice, bindings, seen)?;
            match lowest.drop
            {
                Some(ref amount) => walk_use_before_bind(amount, bindings, seen),
                None => Ok(())
            }
        },
        DiceExpression::DropHighest(highest) =>
        {
            walk_dice_use_before_bind(&highest.dice, bindings, seen)?;
            match highest.drop
            {
                Some(ref amount) => walk_use_before_bind(amount, bindings, seen),
                None => Ok(())
            }
        }
    }
}
/// Arithmetic-node counterpart of [`walk_use_before_bind`].
///
/// Binary operators visit the left operand, then the right one. Negation
/// visits its single operand.
fn walk_arithmetic_use_before_bind<'src>(
    arith: &ArithmeticExpression<'src>,
    bindings: &HashMap<&'src str, SourceSpan>,
    seen: &mut HashSet<&'src str>
) -> Result<(), CompilationError<'src>>
{
    match arith
    {
        ArithmeticExpression::Add(add) =>
        {
            walk_use_before_bind(&add.left, bindings, seen)?;
            walk_use_before_bind(&add.right, bindings, seen)
        },
        ArithmeticExpression::Sub(sub) =>
        {
            walk_use_before_bind(&sub.left, bindings, seen)?;
            walk_use_before_bind(&sub.right, bindings, seen)
        },
        ArithmeticExpression::Mul(mul) =>
        {
            walk_use_before_bind(&mul.left, bindings, seen)?;
            walk_use_before_bind(&mul.right, bindings, seen)
        },
        ArithmeticExpression::Div(div) =>
        {
            walk_use_before_bind(&div.left, bindings, seen)?;
            walk_use_before_bind(&div.right, bindings, seen)
        },
        ArithmeticExpression::Mod(rem) =>
        {
            walk_use_before_bind(&rem.left, bindings, seen)?;
            walk_use_before_bind(&rem.right, bindings, seen)
        },
        ArithmeticExpression::Exp(pow) =>
        {
            walk_use_before_bind(&pow.left, bindings, seen)?;
            walk_use_before_bind(&pow.right, bindings, seen)
        },
        ArithmeticExpression::Neg(neg) =>
        {
            walk_use_before_bind(&neg.operand, bindings, seen)
        }
    }
}
impl<'src> ASTVisitor<'src> for Validator
{
type Error = CompilationError<'src>;
type Output = ();
fn visit_function(
&mut self,
node: &'src ast::Function<'src>
) -> Result<(), Self::Error>
{
check_duplicate_parameters(node)
}
fn visit_group(
&mut self,
_node: &'src Group<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_constant(&mut self, _node: &Constant) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_variable(
&mut self,
_node: &'src Variable<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_binding(
&mut self,
_node: &'src Binding<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_range(
&mut self,
_node: &'src Range<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_standard_dice(
&mut self,
_node: &'src StandardDice<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_custom_dice(
&mut self,
_node: &'src CustomDice<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_drop_lowest(
&mut self,
_node: &'src DropLowest<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_drop_highest(
&mut self,
_node: &'src DropHighest<'src>
) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_add(&mut self, _node: &'src Add<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_sub(&mut self, _node: &'src Sub<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_mul(&mut self, _node: &'src Mul<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_div(&mut self, _node: &'src Div<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_mod(&mut self, _node: &'src Mod<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_exp(&mut self, _node: &'src Exp<'src>) -> Result<(), Self::Error>
{
Ok(())
}
fn visit_neg(&mut self, _node: &'src Neg<'src>) -> Result<(), Self::Error>
{
Ok(())
}
}