use crate::{hir, ty::Gcx};
use alloy_primitives::U256;
use solar_ast::LitKind;
use solar_interface::{Span, diagnostics::ErrorGuaranteed};
use std::fmt;
/// Maximum number of nested [`ConstantEvaluator::try_eval`] calls; exceeding it aborts the
/// evaluation with [`EvalErrorKind::RecursionLimitReached`].
const RECURSION_LIMIT: usize = 64;

/// Folds HIR expressions into integer constants.
pub struct ConstantEvaluator<'gcx> {
    /// Global context, used to look up variables and to emit diagnostics.
    pub gcx: Gcx<'gcx>,
    /// Number of `try_eval` calls currently on the stack.
    depth: usize,
}

/// Outcome of evaluating a constant expression.
type EvalResult<'gcx> = std::result::Result<IntScalar, EvalError>;
impl<'gcx> ConstantEvaluator<'gcx> {
pub fn new(gcx: Gcx<'gcx>) -> Self {
Self { gcx, depth: 0 }
}
pub fn eval(&mut self, expr: &hir::Expr<'_>) -> Result<IntScalar, ErrorGuaranteed> {
self.try_eval(expr).map_err(|err| self.emit_eval_error(expr, err))
}
pub fn try_eval(&mut self, expr: &hir::Expr<'_>) -> EvalResult<'gcx> {
self.depth += 1;
if self.depth > RECURSION_LIMIT {
return Err(EE::RecursionLimitReached.spanned(expr.span));
}
let mut res = self.eval_expr(expr);
if let Err(e) = &mut res
&& e.span.is_dummy()
{
e.span = expr.span;
}
self.depth = self.depth.checked_sub(1).unwrap();
res
}
pub fn emit_eval_error(&self, expr: &hir::Expr<'_>, err: EvalError) -> ErrorGuaranteed {
match err.kind {
EE::AlreadyEmitted(guar) => guar,
_ => {
let msg = format!("failed to evaluate constant: {}", err.kind.msg());
let label = "evaluation of constant value failed here";
self.gcx.dcx().err(msg).span(expr.span).span_label(err.span, label).emit()
}
}
}
fn eval_expr(&mut self, expr: &hir::Expr<'_>) -> EvalResult<'gcx> {
let expr = expr.peel_parens();
match expr.kind {
hir::ExprKind::Binary(l, bin_op, r) => {
let l = self.try_eval(l)?;
let r = self.try_eval(r)?;
l.binop(&r, bin_op.kind).map_err(Into::into)
}
hir::ExprKind::Ident(res) => {
let Some(v) = res.iter().find_map(|res| res.as_variable()) else {
return Err(EE::NonConstantVar.into());
};
let v = self.gcx.hir.variable(v);
if v.mutability != Some(hir::VarMut::Constant) {
return Err(EE::NonConstantVar.into());
}
self.try_eval(v.initializer.expect("constant variable has no initializer"))
}
hir::ExprKind::Lit(lit) => self.eval_lit(lit),
hir::ExprKind::Unary(un_op, v) => {
let v = self.try_eval(v)?;
v.unop(un_op.kind).map_err(Into::into)
}
hir::ExprKind::Err(guar) => Err(EE::AlreadyEmitted(guar).into()),
_ => Err(EE::UnsupportedExpr.into()),
}
}
fn eval_lit(&mut self, lit: &hir::Lit<'_>) -> EvalResult<'gcx> {
match lit.kind {
LitKind::Number(n) => Ok(IntScalar::new(n)),
LitKind::Address(address) => Ok(IntScalar::from_be_bytes(address.as_slice())),
LitKind::Bool(bool) => Ok(IntScalar::from_be_bytes(&[bool as u8])),
LitKind::Err(guar) => Err(EE::AlreadyEmitted(guar).into()),
_ => Err(EE::UnsupportedLiteral.into()),
}
}
}
/// An evaluated integer constant, stored as a raw 256-bit word.
///
/// Booleans are represented as `0` (`false`) and `1` (`true`); any non-zero value is truthy.
pub struct IntScalar {
    /// The raw 256-bit value.
    pub data: U256,
}

impl IntScalar {
    /// Wraps a raw 256-bit value.
    pub fn new(data: U256) -> Self {
        Self { data }
    }

    /// Creates the canonical boolean scalar: `1` for `true`, `0` for `false`.
    pub fn from_bool(value: bool) -> Self {
        Self { data: U256::from(value as u8) }
    }

    /// Interprets `bytes` as a big-endian unsigned integer.
    ///
    /// NOTE(review): presumably panics if the value does not fit in 256 bits — confirm against
    /// the `U256::from_be_slice` documentation.
    pub fn from_be_bytes(bytes: &[u8]) -> Self {
        Self { data: U256::from_be_slice(bytes) }
    }

    /// Returns `true` if the value is non-zero.
    pub fn to_bool(&self) -> bool {
        !self.data.is_zero()
    }

    /// Applies the unary operator `op`.
    ///
    /// `!` yields a canonical boolean, `~` flips every bit, and `-` negates with wrapping
    /// (two's complement).
    ///
    /// # Errors
    ///
    /// Increment and decrement operators return [`EvalErrorKind::UnsupportedUnaryOp`].
    pub fn unop(&self, op: hir::UnOpKind) -> Result<Self, EE> {
        Ok(match op {
            hir::UnOpKind::PreInc
            | hir::UnOpKind::PreDec
            | hir::UnOpKind::PostInc
            | hir::UnOpKind::PostDec => return Err(EE::UnsupportedUnaryOp),
            // Logical negation must produce a canonical boolean: flipping every bit would turn
            // `true` (1) into `MAX - 1`, which is still truthy.
            hir::UnOpKind::Not => Self::from_bool(!self.to_bool()),
            hir::UnOpKind::BitNot => Self::new(!self.data),
            hir::UnOpKind::Neg => Self::new(self.data.wrapping_neg()),
        })
    }

    /// Applies the binary operator `op`, with `self` as the left-hand operand.
    ///
    /// Bitwise and shift operators never fail. Shift amounts that do not fit in `usize` are
    /// clamped to `usize::MAX`.
    ///
    /// # Errors
    ///
    /// - [`EvalErrorKind::ArithmeticOverflow`] if `+`, `-`, `*` or `**` overflows or underflows
    ///   the 256-bit range.
    /// - [`EvalErrorKind::DivisionByZero`] if `/` or `%` has a zero divisor.
    /// - [`EvalErrorKind::UnsupportedBinaryOp`] for comparison and logical operators.
    pub fn binop(&self, r: &Self, op: hir::BinOpKind) -> Result<Self, EE> {
        let l = self;
        Ok(match op {
            hir::BinOpKind::BitOr => Self::new(l.data | r.data),
            hir::BinOpKind::BitAnd => Self::new(l.data & r.data),
            hir::BinOpKind::BitXor => Self::new(l.data ^ r.data),
            hir::BinOpKind::Shr => {
                Self::new(l.data.wrapping_shr(r.data.try_into().unwrap_or(usize::MAX)))
            }
            hir::BinOpKind::Shl => {
                Self::new(l.data.wrapping_shl(r.data.try_into().unwrap_or(usize::MAX)))
            }
            hir::BinOpKind::Sar => {
                Self::new(l.data.arithmetic_shr(r.data.try_into().unwrap_or(usize::MAX)))
            }
            hir::BinOpKind::Add => {
                Self::new(l.data.checked_add(r.data).ok_or(EE::ArithmeticOverflow)?)
            }
            hir::BinOpKind::Sub => {
                Self::new(l.data.checked_sub(r.data).ok_or(EE::ArithmeticOverflow)?)
            }
            hir::BinOpKind::Pow => {
                Self::new(l.data.checked_pow(r.data).ok_or(EE::ArithmeticOverflow)?)
            }
            hir::BinOpKind::Mul => {
                Self::new(l.data.checked_mul(r.data).ok_or(EE::ArithmeticOverflow)?)
            }
            hir::BinOpKind::Div => Self::new(l.data.checked_div(r.data).ok_or(EE::DivisionByZero)?),
            hir::BinOpKind::Rem => Self::new(l.data.checked_rem(r.data).ok_or(EE::DivisionByZero)?),
            hir::BinOpKind::Lt
            | hir::BinOpKind::Le
            | hir::BinOpKind::Gt
            | hir::BinOpKind::Ge
            | hir::BinOpKind::Eq
            | hir::BinOpKind::Ne
            | hir::BinOpKind::Or
            | hir::BinOpKind::And => return Err(EE::UnsupportedBinaryOp),
        })
    }
}
/// The reason a constant evaluation failed.
#[derive(Debug)]
pub enum EvalErrorKind {
    /// Evaluation was nested more than `RECURSION_LIMIT` levels deep.
    RecursionLimitReached,
    /// A checked `+`, `-`, `*` or `**` left the 256-bit range.
    ArithmeticOverflow,
    /// A `/` or `%` had a zero divisor.
    DivisionByZero,
    /// The literal kind cannot be folded into an integer.
    UnsupportedLiteral,
    /// The unary operator is not supported in constant expressions.
    UnsupportedUnaryOp,
    /// The binary operator is not supported in constant expressions.
    UnsupportedBinaryOp,
    /// The expression kind is not supported in constant expressions.
    UnsupportedExpr,
    /// An identifier did not resolve to a `constant` variable.
    NonConstantVar,
    /// An error has already been reported; carries its [`ErrorGuaranteed`].
    AlreadyEmitted(ErrorGuaranteed),
}

use EvalErrorKind as EE;

impl EvalErrorKind {
    /// Attaches `span` to this error kind.
    pub fn spanned(self, span: Span) -> EvalError {
        EvalError { kind: self, span }
    }

    /// Returns a short, human-readable description of this error.
    fn msg(&self) -> &'static str {
        match self {
            Self::RecursionLimitReached => "recursion limit reached",
            Self::ArithmeticOverflow => "arithmetic overflow",
            Self::DivisionByZero => "attempted to divide by zero",
            Self::UnsupportedLiteral => "unsupported literal",
            Self::UnsupportedUnaryOp => "unsupported unary operation",
            Self::UnsupportedBinaryOp => "unsupported binary operation",
            Self::UnsupportedExpr => "unsupported expression",
            Self::NonConstantVar => "only constant variables are allowed",
            // Not actually unreachable: `EvalError`'s `Display` impl calls `msg` for every kind,
            // so panicking here would make formatting such an error crash.
            Self::AlreadyEmitted(_) => "an error was already emitted",
        }
    }
}
/// A failed constant evaluation, together with the span it is attributed to.
#[derive(Debug)]
pub struct EvalError {
    /// Location of the failure; `Span::DUMMY` when built via `From<EvalErrorKind>`.
    pub span: Span,
    /// What went wrong.
    pub kind: EvalErrorKind,
}

impl From<EE> for EvalError {
    /// Wraps `value` with a placeholder (dummy) span.
    fn from(value: EE) -> Self {
        value.spanned(Span::DUMMY)
    }
}

impl fmt::Display for EvalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let description = self.kind.msg();
        f.pad(description)
    }
}

impl std::error::Error for EvalError {}