// hwcalc_lib 0.2.0
//
// Backend for the hwcalc calculator.
// Documentation
use num_traits::sign::Signed;
use num_traits::Pow;

use crate::ir;
use crate::Span;

/// An evaluation failure: the kind of problem (field `0`) and the source span it was
/// reported for (field `1`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Error(pub ErrorKind, pub Span);

/// The category of an evaluation error.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// Raised in this module when `Num::cast` rejects the requested type.
    CastWidthOutOfRange,
    /// Division or remainder by zero, or zero raised to a negative power.
    DivideByZero,
    /// An identifier was not present in the namespace.
    NameNotFound,
    /// A shift amount or exponent was not an integer.
    OperandNonInteger,
    /// A shift amount, exponent, or power base exceeded `ir::W::MAX` in magnitude.
    OperandTooLarge,
    /// Expression nesting exceeded the evaluator's recursion limit.
    RecursionLimitExceeded,
    /// Raised in this module when `Num::new` rejects a computed value.
    ValueTooLarge,
}

impl From<Error> for super::Error {
    fn from(e: Error) -> Self {
        Self {
            kind: super::ErrorKind::Eval(e.0),
            span: e.1,
        }
    }
}

impl std::fmt::Display for ErrorKind {
    /// Writes a short, lowercase, human-readable description of the error kind.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        let message = match self {
            Self::CastWidthOutOfRange => "casting to out of range width",
            Self::DivideByZero => "divide by zero",
            Self::NameNotFound => "name not found",
            Self::OperandNonInteger => "operand non-integer",
            Self::OperandTooLarge => "operand too large",
            Self::RecursionLimitExceeded => "recursion limit exceeded",
            Self::ValueTooLarge => "value too large",
        };
        f.write_str(message)
    }
}

/// Evaluates IR expressions against a mutable variable namespace.
pub struct Evaluator<'a> {
    /// Variables: read by identifier lookups, written by assignments.
    ns: &'a mut super::Namespace,
    /// Current recursion depth, tracked to avoid stack overflow.
    depth: usize,
}

impl<'a> Evaluator<'a> {
    pub(crate) fn new(ns: &'a mut super::Namespace) -> Self {
        Self { depth: 0, ns }
    }

    pub(crate) fn eval(&mut self, expr: ir::Expr) -> Result<ir::Num, Error> {
        self.depth += 1;

        let res = if self.depth <= super::RECURSION_LIMIT {
            match expr.kind {
                ir::ExprKind::Num(i) => Ok(i),
                ir::ExprKind::Ident(s) => self.ns.get(&s).cloned().ok_or(ErrorKind::NameNotFound),
                ir::ExprKind::Un(op, e) => self.eval(*e)?.un_op(op),
                ir::ExprKind::Bin(op, e0, e1) => self.eval(*e0)?.bin_op(&self.eval(*e1)?, op),
                ir::ExprKind::Cast { ty, expr } => self
                    .eval(*expr)?
                    .cast(ty)
                    .ok_or(ErrorKind::CastWidthOutOfRange),
                ir::ExprKind::Assign(s, e) => {
                    let n = self.eval(*e)?;
                    self.ns.insert(s, n.clone());
                    Ok(n)
                }
            }
        } else {
            Err(ErrorKind::RecursionLimitExceeded)
        }
        .map_err(|kind| Error(kind, expr.span));

        self.depth -= 1;

        if self.depth == 0 {
            if let Ok(num) = &res {
                self.ns.insert(super::Sym::UNDERSCORE, num.clone());
            }
        }

        res
    }
}

impl ir::Ty {
    /// Joins two types into the widest type that covers both operands.
    ///
    /// The result is signed if either side is signed. Its `width` and `width_frac` are each
    /// the larger of the two. When only one side specifies a property (the other is `None`),
    /// the specified value is kept. The result is `None` only if both sides are `None`.
    fn max(self, rhs: Self) -> Self {
        Self {
            signed: zip_opt(self.signed, rhs.signed, |a, b| a || b),
            width: zip_opt(self.width, rhs.width, ir::W::max),
            width_frac: zip_opt(self.width_frac, rhs.width_frac, ir::W::max),
        }
    }
}

/// Merges two optional values.
///
/// When both are present they are combined with `f`. When only one is present it is
/// returned as-is, without calling `f`. When neither is present the result is `None`.
fn zip_opt<T, F: Fn(T, T) -> T>(a: Option<T>, b: Option<T>, f: F) -> Option<T> {
    match (a, b) {
        (Some(left), Some(right)) => Some(f(left, right)),
        (Some(only), None) | (None, Some(only)) => Some(only),
        (None, None) => None,
    }
}

impl ir::Num {
    fn un_op(self, op: ir::UnOp) -> Result<Self, ErrorKind> {
        let x = self.val;
        let val = match op {
            ir::UnOp::Neg => -x,
            ir::UnOp::Not => {
                if let Some(w) = self.ty.width_frac {
                    map_int(x, w, |x| !x)
                } else {
                    // infinite precision rounds up
                    // e.g. 2.25 = !10.01(0).. = ..(1)01.10(1)... = ..(1)01.11 = -2.25
                    -x
                }
            }
        };
        Self::new(val)
            .ok_or(ErrorKind::ValueTooLarge)
            .map(|v| v.with_ty(self.ty))
    }

    fn bin_op(&self, rhs: &Self, op: ir::BinOp) -> Result<Self, ErrorKind> {
        let (a, b) = (&self.val, &rhs.val);
        let val = match op {
            ir::BinOp::Add => a + b,
            ir::BinOp::Sub => a - b,
            ir::BinOp::Mul => a * b,
            ir::BinOp::Div | ir::BinOp::Rem => {
                if b == &ir::Val::default() {
                    return Err(ErrorKind::DivideByZero);
                } else if matches!(op, ir::BinOp::Div) {
                    a / b
                } else {
                    a % b
                }
            }
            ir::BinOp::And => map2_int(a, b, |a, b| a & b),
            ir::BinOp::Or => map2_int(a, b, |a, b| a | b),
            ir::BinOp::Xor => map2_int(a, b, |a, b| a ^ b),
            ir::BinOp::Shl | ir::BinOp::Shr => {
                if !rhs.integer() {
                    return Err(ErrorKind::OperandNonInteger);
                } else if b.to_integer().abs() > ir::W::MAX.into() {
                    return Err(ErrorKind::OperandTooLarge);
                }

                let m = ir::Val::from_integer(2.into()).pow(b.to_integer());
                if matches!(op, ir::BinOp::Shl) {
                    a * m
                } else {
                    a / m
                }
            }
            ir::BinOp::Pow => {
                if !rhs.integer() {
                    return Err(ErrorKind::OperandNonInteger);
                } else if a.to_integer().abs() > ir::W::MAX.into()
                    || b.to_integer().abs() > ir::W::MAX.into()
                {
                    return Err(ErrorKind::OperandTooLarge);
                } else if a == &ir::Val::default() && b < &ir::Val::default() {
                    return Err(ErrorKind::DivideByZero);
                }
                num_traits::Pow::pow(a, b.to_integer())
            }
        };

        let ty = match op {
            ir::BinOp::Shl | ir::BinOp::Shr | ir::BinOp::Pow => self.ty,
            _ => self.ty.max(rhs.ty),
        };

        Self::new(val)
            .ok_or(ErrorKind::ValueTooLarge)
            .and_then(|v| v.cast(ty).ok_or(ErrorKind::CastWidthOutOfRange))
    }
}

/// Applies `f` to `val` as if it were an integer, i.e. without a fractional point.
///
/// The value is scaled up by `2^width_frac`. The scaled value is truncated to an integer
/// with `to_integer` and passed to `f`, and the result is scaled back down.
fn map_int<F: Fn(ir::Int) -> ir::Int>(val: ir::Val, width_frac: ir::W, f: F) -> ir::Val {
    let scale = ir::Int::from(1) << width_frac;
    let scaled = (val * ir::Val::from_integer(scale.clone())).to_integer();
    ir::Val::from_integer(f(scaled)) / scale
}

/// Applies the binary function `f` as if both operands were integers.
///
/// Both operands are scaled by the larger of their two denominators and truncated with
/// `to_integer`. The result of `f` is then divided by that same scale.
fn map2_int<F: Fn(ir::Int, ir::Int) -> ir::Int>(a: &ir::Val, b: &ir::Val, f: F) -> ir::Val {
    let scale = std::cmp::max(a.denom().clone(), b.denom().clone());
    let as_int = |x: &ir::Val| (x * ir::Val::from_integer(scale.clone())).to_integer();
    let combined = f(as_int(a), as_int(b));
    ir::Val::from_integer(combined) / scale
}

#[cfg(test)]
mod test {
    use super::{Error, ErrorKind, Span};

    impl ErrorKind {
        /// Test helper: pairs this kind with a span built from `start` and `end`.
        pub fn span(self, start: usize, end: usize) -> Error {
            Error(self, Span::new(start, end))
        }
    }
}