use num_traits::sign::Signed;
use num_traits::Pow;
use crate::ir;
use crate::Span;
/// An evaluation failure: the kind of error (field `0`) paired with the
/// source span of the expression that produced it (field `1`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Error(pub ErrorKind, pub Span);
/// The reasons evaluating an expression can fail.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
/// A cast was rejected by `ir::Num::cast`; this covers both explicit cast
/// expressions and the cast applied to the result of a binary operation.
CastWidthOutOfRange,
/// Division or remainder by zero, or zero raised to a negative power.
DivideByZero,
/// An identifier was not present in the namespace.
NameNotFound,
/// The right-hand operand of a shift or power was not an integer.
OperandNonInteger,
/// A shift amount, or the base or exponent of a power, exceeded
/// `ir::W::MAX` in magnitude.
OperandTooLarge,
/// Expression nesting went past `super::RECURSION_LIMIT`.
RecursionLimitExceeded,
/// `ir::Num::new` rejected the computed value.
ValueTooLarge,
}
impl From<Error> for super::Error {
fn from(e: Error) -> Self {
Self {
kind: super::ErrorKind::Eval(e.0),
span: e.1,
}
}
}
impl std::fmt::Display for ErrorKind {
    /// Writes the fixed, lower-case description of this error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        // Select the message first so the formatter is touched in one place.
        let message = match self {
            Self::CastWidthOutOfRange => "casting to out of range width",
            Self::DivideByZero => "divide by zero",
            Self::NameNotFound => "name not found",
            Self::OperandNonInteger => "operand non-integer",
            Self::OperandTooLarge => "operand too large",
            Self::RecursionLimitExceeded => "recursion limit exceeded",
            Self::ValueTooLarge => "value too large",
        };
        f.write_str(message)
    }
}
/// Evaluates IR expressions against a mutably borrowed namespace.
pub struct Evaluator<'a> {
// Nesting level of the `eval` call currently running; 0 when idle.
depth: usize,
// Namespace read by identifier lookups and written by assignments and by
// the `_` binding that records the last top-level result.
ns: &'a mut super::Namespace,
}
impl<'a> Evaluator<'a> {
pub(crate) fn new(ns: &'a mut super::Namespace) -> Self {
Self { depth: 0, ns }
}
pub(crate) fn eval(&mut self, expr: ir::Expr) -> Result<ir::Num, Error> {
self.depth += 1;
let res = if self.depth <= super::RECURSION_LIMIT {
match expr.kind {
ir::ExprKind::Num(i) => Ok(i),
ir::ExprKind::Ident(s) => self.ns.get(&s).cloned().ok_or(ErrorKind::NameNotFound),
ir::ExprKind::Un(op, e) => self.eval(*e)?.un_op(op),
ir::ExprKind::Bin(op, e0, e1) => self.eval(*e0)?.bin_op(&self.eval(*e1)?, op),
ir::ExprKind::Cast { ty, expr } => self
.eval(*expr)?
.cast(ty)
.ok_or(ErrorKind::CastWidthOutOfRange),
ir::ExprKind::Assign(s, e) => {
let n = self.eval(*e)?;
self.ns.insert(s, n.clone());
Ok(n)
}
}
} else {
Err(ErrorKind::RecursionLimitExceeded)
}
.map_err(|kind| Error(kind, expr.span));
self.depth -= 1;
if self.depth == 0 {
if let Ok(num) = &res {
self.ns.insert(super::Sym::UNDERSCORE, num.clone());
}
}
res
}
}
impl ir::Ty {
    /// Combines two types field by field.
    ///
    /// A field present on only one side is taken as is. When both sides have
    /// it, the result is signed if either is, and takes the larger of the two
    /// widths and of the two fractional widths.
    fn max(self, rhs: Self) -> Self {
        let signed = zip_opt(self.signed, rhs.signed, |lhs, rhs| lhs | rhs);
        let width = zip_opt(self.width, rhs.width, ir::W::max);
        let width_frac = zip_opt(self.width_frac, rhs.width_frac, ir::W::max);
        Self {
            signed,
            width,
            width_frac,
        }
    }
}
/// Merges two optional values.
///
/// Returns `f(a, b)` when both are present, the single present value when
/// only one is, and `None` when neither is.
fn zip_opt<T, F: Fn(T, T) -> T>(a: Option<T>, b: Option<T>, f: F) -> Option<T> {
    match (a, b) {
        (Some(lhs), Some(rhs)) => Some(f(lhs, rhs)),
        (Some(only), None) | (None, Some(only)) => Some(only),
        (None, None) => None,
    }
}
impl ir::Num {
    /// Applies the unary operator `op`, keeping the operand's type.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::ValueTooLarge` when `Self::new` rejects the result.
    fn un_op(self, op: ir::UnOp) -> Result<Self, ErrorKind> {
        let ty = self.ty;
        let operand = self.val;
        let val = match (op, ty.width_frac) {
            // With a known fractional width, complement the fixed-point bits.
            (ir::UnOp::Not, Some(frac_bits)) => map_int(operand, frac_bits, |bits| !bits),
            // NOTE(review): without a fractional width, NOT is computed as
            // plain negation (so `!5` gives `-5`, not `-6`) — confirm that
            // this is the intended semantics.
            (ir::UnOp::Neg, _) | (ir::UnOp::Not, None) => -operand,
        };
        match Self::new(val) {
            Some(num) => Ok(num.with_ty(ty)),
            None => Err(ErrorKind::ValueTooLarge),
        }
    }

    /// Applies the binary operator `op` to `self` (left) and `rhs` (right).
    ///
    /// Shifts and powers keep the left operand's type; every other operator
    /// uses the field-wise maximum of both operand types. The raw result is
    /// cast to that type before being returned.
    ///
    /// # Errors
    ///
    /// * `DivideByZero` — zero divisor for `Div`/`Rem`, or a zero base with a
    ///   negative exponent for `Pow`.
    /// * `OperandNonInteger` — non-integer right operand for a shift or power.
    /// * `OperandTooLarge` — shift amount, power base or power exponent larger
    ///   in magnitude than `ir::W::MAX`.
    /// * `ValueTooLarge` — `Self::new` rejects the result.
    /// * `CastWidthOutOfRange` — the final cast fails.
    fn bin_op(&self, rhs: &Self, op: ir::BinOp) -> Result<Self, ErrorKind> {
        let (a, b) = (&self.val, &rhs.val);
        let val = match op {
            ir::BinOp::Add => a + b,
            ir::BinOp::Sub => a - b,
            ir::BinOp::Mul => a * b,
            ir::BinOp::Div | ir::BinOp::Rem if b == &ir::Val::default() => {
                return Err(ErrorKind::DivideByZero);
            }
            ir::BinOp::Div => a / b,
            ir::BinOp::Rem => a % b,
            ir::BinOp::And => map2_int(a, b, |x, y| x & y),
            ir::BinOp::Or => map2_int(a, b, |x, y| x | y),
            ir::BinOp::Xor => map2_int(a, b, |x, y| x ^ y),
            ir::BinOp::Shl | ir::BinOp::Shr => {
                if !rhs.integer() {
                    return Err(ErrorKind::OperandNonInteger);
                }
                if b.to_integer().abs() > ir::W::MAX.into() {
                    return Err(ErrorKind::OperandTooLarge);
                }
                // A shift is a multiplication or division by 2^b.
                let factor = ir::Val::from_integer(2.into()).pow(b.to_integer());
                match op {
                    ir::BinOp::Shl => a * factor,
                    _ => a / factor,
                }
            }
            ir::BinOp::Pow => {
                if !rhs.integer() {
                    return Err(ErrorKind::OperandNonInteger);
                }
                let base_too_large = a.to_integer().abs() > ir::W::MAX.into();
                if base_too_large || b.to_integer().abs() > ir::W::MAX.into() {
                    return Err(ErrorKind::OperandTooLarge);
                }
                if a == &ir::Val::default() && b < &ir::Val::default() {
                    return Err(ErrorKind::DivideByZero);
                }
                num_traits::Pow::pow(a, b.to_integer())
            }
        };
        let ty = if matches!(op, ir::BinOp::Shl | ir::BinOp::Shr | ir::BinOp::Pow) {
            self.ty
        } else {
            self.ty.max(rhs.ty)
        };
        let num = Self::new(val).ok_or(ErrorKind::ValueTooLarge)?;
        num.cast(ty).ok_or(ErrorKind::CastWidthOutOfRange)
    }
}
/// Applies the integer function `f` to `val` viewed as a fixed-point number
/// with `width_frac` fractional bits.
///
/// `val` is multiplied by `2^width_frac`, converted with `to_integer`, passed
/// through `f`, and the result is divided by `2^width_frac` again.
fn map_int<F: Fn(ir::Int) -> ir::Int>(val: ir::Val, width_frac: ir::W, f: F) -> ir::Val {
    let scale = ir::Int::from(1) << width_frac;
    let scaled = (val * ir::Val::from_integer(scale.clone())).to_integer();
    ir::Val::from_integer(f(scaled)) / scale
}
/// Applies the two-argument integer function `f` to `a` and `b` after
/// bringing both to a common scale.
///
/// The scale is the larger of the two denominators; each operand is
/// multiplied by it and converted with `to_integer`, and the output of `f`
/// is divided by the same scale.
fn map2_int<F: Fn(ir::Int, ir::Int) -> ir::Int>(a: &ir::Val, b: &ir::Val, f: F) -> ir::Val {
    let scale = ir::Int::max(a.denom().clone(), b.denom().clone());
    let lhs = (a * ir::Val::from_integer(scale.clone())).to_integer();
    let rhs = (b * ir::Val::from_integer(scale.clone())).to_integer();
    ir::Val::from_integer(f(lhs, rhs)) / scale
}
#[cfg(test)]
mod test {
    impl super::ErrorKind {
        /// Test helper: builds the `Error` for this kind with a span made
        /// from `start` and `end`.
        pub fn span(self, start: usize, end: usize) -> super::Error {
            let span = super::Span::new(start, end);
            super::Error(self, span)
        }
    }
}