use crate::ir_eval::canonical_f32;
use crate::ir_inner::model::types::{BinOp, UnOp};
use std::collections::HashMap;
use std::fmt;
use std::sync::Arc;
/// Identifier of an operation, stored as a reference-counted string slice.
///
/// In this file it names the operation carried by `NodeStorage::Extern`.
pub type OpId = Arc<str>;
/// Newtype over `u32` that identifies a node.
///
/// In this file it keys the values stored in `InterpCtx` and is how
/// `NodeStorage` variants refer to their operand nodes.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct NodeId(pub u32);
/// Newtype over `u32`.
///
/// NOTE(review): presumably identifies a variable; nothing in this part of the
/// file uses it, so confirm the intended meaning against its callers.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct VarId(pub u32);
/// Newtype over `u32` that identifies a byte region.
///
/// In this file it keys the byte buffers stored in `InterpCtx`.
#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash, Ord, PartialOrd)]
pub struct RegionId(pub u32);
/// A scalar value produced or consumed by the interpreter.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must include a
/// wildcard arm when matching on it. Equality is derived, so `F32` compares
/// with `f32` semantics (a NaN is not equal to itself).
#[non_exhaustive]
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum Value {
/// 32-bit unsigned integer.
U32(u32),
/// 64-bit unsigned integer.
U64(u64),
/// 32-bit signed integer.
I32(i32),
/// 32-bit floating-point number.
F32(f32),
/// Boolean.
Bool(bool),
}
/// Error returned when interpreting a node fails.
///
/// It carries a single human-readable message; `Display` prints that message
/// verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EvalError {
    message: String,
}

impl EvalError {
    /// Builds an error from anything convertible into a `String`.
    #[must_use]
    pub fn new(message: impl Into<String>) -> Self {
        let message = message.into();
        Self { message }
    }

    /// Returns the stored message text.
    #[must_use]
    pub fn message(&self) -> &str {
        self.message.as_str()
    }
}

impl fmt::Display for EvalError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Written straight to the formatter, so width/fill flags are not applied.
        f.write_str(self.message())
    }
}

impl std::error::Error for EvalError {}
/// Mutable state used while interpreting nodes.
///
/// `Default` yields a context with no values, no bound operands and no regions.
#[derive(Debug, Default)]
pub struct InterpCtx {
/// Values recorded per node via `set` and read back via `get`.
values: HashMap<NodeId, Value>,
/// Currently bound operand list; replaced as a whole by `set_operands`.
operands: Vec<NodeId>,
/// Byte buffers keyed by region id.
regions: HashMap<RegionId, Vec<u8>>,
}
impl InterpCtx {
    /// Records `value` as the result of node `id`, replacing any earlier entry.
    pub fn set(&mut self, id: NodeId, value: Value) {
        self.values.insert(id, value);
    }

    /// Looks up the value previously recorded for node `id`.
    ///
    /// # Errors
    /// Returns an [`EvalError`] when no value has been recorded for `id`.
    pub fn get(&self, id: NodeId) -> Result<Value, EvalError> {
        match self.values.get(&id) {
            Some(value) => Ok(*value),
            None => Err(EvalError::new(format!(
                "missing interpreter value for node {}. Fix: topologically sort the program before interpretation and ensure every operand node runs before its users.",
                id.0
            ))),
        }
    }

    /// Replaces the bound operand list with the ids yielded by `operands`.
    pub fn set_operands<I>(&mut self, operands: I)
    where
        I: IntoIterator<Item = NodeId>,
    {
        // Clear first so the existing allocation is reused for the new list.
        let bound = &mut self.operands;
        bound.clear();
        bound.extend(operands);
    }

    /// Returns the currently bound operand ids in binding order.
    #[must_use]
    pub fn operands(&self) -> &[NodeId] {
        self.operands.as_slice()
    }

    /// Resolves the bound operand at position `index` to its recorded value.
    ///
    /// # Errors
    /// Returns an [`EvalError`] when `index` is past the end of the bound
    /// operand list, or when the operand's node has no recorded value.
    pub fn operand(&self, index: usize) -> Result<Value, EvalError> {
        match self.operands.get(index) {
            Some(&id) => self.get(id),
            None => Err(EvalError::new(format!(
                "missing operand {index}. Fix: bind the primitive with the arity declared by its metadata before interpretation."
            ))),
        }
    }

    /// Stores `bytes` as the contents of region `id`, replacing any earlier buffer.
    pub fn set_region(&mut self, id: RegionId, bytes: Vec<u8>) {
        self.regions.insert(id, bytes);
    }

    /// Borrows the bytes of region `id`.
    ///
    /// # Errors
    /// Returns an [`EvalError`] when region `id` has not been stored.
    pub fn region(&self, id: RegionId) -> Result<&[u8], EvalError> {
        match self.regions.get(&id) {
            Some(bytes) => Ok(bytes.as_slice()),
            None => Err(EvalError::new(format!(
                "missing interpreter region {}. Fix: initialize every primitive input region before reference execution.",
                id.0
            ))),
        }
    }

    /// Mutably borrows the byte buffer of region `id`.
    ///
    /// # Errors
    /// Returns an [`EvalError`] when region `id` has not been stored.
    pub fn region_mut(&mut self, id: RegionId) -> Result<&mut Vec<u8>, EvalError> {
        match self.regions.get_mut(&id) {
            Some(bytes) => Ok(bytes),
            None => Err(EvalError::new(format!(
                "missing mutable interpreter region {}. Fix: allocate primitive output regions before reference execution.",
                id.0
            ))),
        }
    }
}
/// Payload of a single node: a literal, a built-in unary/binary operation, or
/// an externally defined operation.
#[derive(Debug, Clone)]
pub enum NodeStorage {
/// `u32` literal.
LitU32(u32),
/// `i32` literal.
LitI32(i32),
/// `f32` literal.
LitF32(f32),
/// Boolean literal.
LitBool(bool),
/// Binary operation applied to the values of two other nodes.
BinOp {
op: BinOp,
left: NodeId,
right: NodeId,
},
/// Unary operation applied to the value of one other node.
UnOp {
op: UnOp,
operand: NodeId,
},
/// Operation identified by `op_id` with its operand nodes and a byte payload.
/// `interpret` in this file always returns an error for this variant and
/// reads neither `operands` nor `payload`.
Extern {
op_id: OpId,
operands: Arc<[NodeId]>,
payload: Arc<[u8]>,
},
}
impl NodeStorage {
    /// Lists the nodes this node reads, in operand order.
    ///
    /// Literals read nothing and yield an empty vector.
    #[must_use]
    pub fn input_ids(&self) -> Vec<NodeId> {
        match self {
            Self::LitU32(_) | Self::LitI32(_) | Self::LitF32(_) | Self::LitBool(_) => Vec::new(),
            Self::UnOp { operand, .. } => vec![*operand],
            Self::BinOp { left, right, .. } => vec![*left, *right],
            Self::Extern { operands, .. } => operands.to_vec(),
        }
    }

    /// Evaluates this node against the values already recorded in `ctx`.
    ///
    /// Literals evaluate to themselves. Unary and binary operations fetch
    /// their operand values from `ctx` (left before right) and apply the
    /// operation.
    ///
    /// # Errors
    /// Returns an [`EvalError`] when an operand value is missing from `ctx`,
    /// when the operation is not supported for the operand types, or when the
    /// node is an `Extern`, which this method cannot interpret.
    pub fn interpret(&self, ctx: &mut InterpCtx) -> Result<Value, EvalError> {
        let literal = match self {
            Self::LitU32(raw) => Value::U32(*raw),
            Self::LitI32(raw) => Value::I32(*raw),
            Self::LitF32(raw) => Value::F32(*raw),
            Self::LitBool(raw) => Value::Bool(*raw),
            Self::UnOp { op, operand } => {
                let input = ctx.get(*operand)?;
                return interpret_un_op(op, input);
            }
            Self::BinOp { op, left, right } => {
                let lhs = ctx.get(*left)?;
                let rhs = ctx.get(*right)?;
                return interpret_bin_op(op, lhs, rhs);
            }
            Self::Extern { op_id, .. } => {
                return Err(EvalError::new(format!(
                    "extern node `{op_id}` has no linked interpreter. Fix: link the primitive crate that registered this op or lower it to a hot NodeStorage variant before reference execution."
                )));
            }
        };
        Ok(literal)
    }
}
fn interpret_bin_op(op: &BinOp, left: Value, right: Value) -> Result<Value, EvalError> {
match (left, right) {
(Value::U32(left), Value::U32(right)) => match op {
BinOp::Add => Ok(Value::U32(left.wrapping_add(right))),
BinOp::Sub => Ok(Value::U32(left.wrapping_sub(right))),
BinOp::Mul => Ok(Value::U32(left.wrapping_mul(right))),
BinOp::Div => {
if right == 0 {
Ok(Value::U32(u32::MAX))
} else {
Ok(Value::U32(left / right))
}
}
BinOp::Mod => {
if right == 0 {
Ok(Value::U32(0))
} else {
Ok(Value::U32(left % right))
}
}
BinOp::BitAnd => Ok(Value::U32(left & right)),
BinOp::BitOr => Ok(Value::U32(left | right)),
BinOp::BitXor => Ok(Value::U32(left ^ right)),
BinOp::Shl => Ok(Value::U32(left.wrapping_shl(right & 31))),
BinOp::Shr => Ok(Value::U32(left.wrapping_shr(right & 31))),
BinOp::Eq => Ok(Value::Bool(left == right)),
BinOp::Ne => Ok(Value::Bool(left != right)),
BinOp::Lt => Ok(Value::Bool(left < right)),
BinOp::Le => Ok(Value::Bool(left <= right)),
BinOp::Gt => Ok(Value::Bool(left > right)),
BinOp::Ge => Ok(Value::Bool(left >= right)),
BinOp::Min => Ok(Value::U32(left.min(right))),
BinOp::Max => Ok(Value::U32(left.max(right))),
BinOp::SaturatingAdd => Ok(Value::U32(left.saturating_add(right))),
BinOp::SaturatingSub => Ok(Value::U32(left.saturating_sub(right))),
BinOp::SaturatingMul => Ok(Value::U32(left.saturating_mul(right))),
BinOp::AbsDiff => Ok(Value::U32(left.abs_diff(right))),
BinOp::RotateLeft => Ok(Value::U32(left.rotate_left(right & 31))),
BinOp::RotateRight => Ok(Value::U32(left.rotate_right(right & 31))),
BinOp::MulHigh => Ok(Value::U32(((left as u64).wrapping_mul(right as u64) >> 32) as u32)),
BinOp::And => Ok(Value::Bool(left != 0 && right != 0)),
BinOp::Or => Ok(Value::Bool(left != 0 || right != 0)),
_ => Err(EvalError::new(format!(
"unsupported u32 binary operation {op:?}. Fix: add interpreter semantics before registering this operation."
))),
},
(Value::Bool(left), Value::Bool(right)) => match op {
BinOp::And => Ok(Value::Bool(left && right)),
BinOp::Or => Ok(Value::Bool(left || right)),
BinOp::Eq => Ok(Value::Bool(left == right)),
BinOp::Ne => Ok(Value::Bool(left != right)),
_ => Err(EvalError::new(format!(
"unsupported bool binary operation {op:?}. Fix: add interpreter semantics before registering this operation."
))),
},
(Value::F32(left), Value::F32(right)) => {
let left = canonical_f32(left);
let right = canonical_f32(right);
match op {
BinOp::Add => Ok(Value::F32(canonical_f32(left + right))),
BinOp::Sub => Ok(Value::F32(canonical_f32(left - right))),
BinOp::Mul => Ok(Value::F32(canonical_f32(left * right))),
BinOp::Div => Ok(Value::F32(canonical_f32(left / right))),
BinOp::Eq => Ok(Value::Bool(left == right)),
BinOp::Ne => Ok(Value::Bool(left != right)),
BinOp::Lt => Ok(Value::Bool(left < right)),
BinOp::Le => Ok(Value::Bool(left <= right)),
BinOp::Gt => Ok(Value::Bool(left > right)),
BinOp::Ge => Ok(Value::Bool(left >= right)),
BinOp::Min => Ok(Value::F32(canonical_f32(left.min(right)))),
BinOp::Max => Ok(Value::F32(canonical_f32(left.max(right)))),
_ => Err(EvalError::new(format!(
"unsupported f32 binary operation {op:?}. Fix: add interpreter semantics before registering this operation."
))),
}
},
(Value::I32(left), Value::I32(right)) => match op {
BinOp::Add => Ok(Value::I32(left.wrapping_add(right))),
BinOp::Sub => Ok(Value::I32(left.wrapping_sub(right))),
BinOp::Mul => Ok(Value::I32(left.wrapping_mul(right))),
BinOp::Div => {
if right == 0 || (left == i32::MIN && right == -1) {
Err(undefined_i32_division("division", left, right))
} else {
Ok(Value::I32(left / right))
}
}
BinOp::Mod => {
if right == 0 || (left == i32::MIN && right == -1) {
Err(undefined_i32_division("remainder", left, right))
} else {
Ok(Value::I32(left % right))
}
}
BinOp::BitAnd => Ok(Value::I32(left & right)),
BinOp::BitOr => Ok(Value::I32(left | right)),
BinOp::BitXor => Ok(Value::I32(left ^ right)),
BinOp::Shl => Ok(Value::I32(left.wrapping_shl((right as u32) & 31))),
BinOp::Shr => Ok(Value::I32(left.wrapping_shr((right as u32) & 31))),
BinOp::Eq => Ok(Value::Bool(left == right)),
BinOp::Ne => Ok(Value::Bool(left != right)),
BinOp::Lt => Ok(Value::Bool(left < right)),
BinOp::Le => Ok(Value::Bool(left <= right)),
BinOp::Gt => Ok(Value::Bool(left > right)),
BinOp::Ge => Ok(Value::Bool(left >= right)),
BinOp::Min => Ok(Value::I32(left.min(right))),
BinOp::Max => Ok(Value::I32(left.max(right))),
BinOp::SaturatingAdd => Ok(Value::I32(left.saturating_add(right))),
BinOp::SaturatingSub => Ok(Value::I32(left.saturating_sub(right))),
BinOp::SaturatingMul => Ok(Value::I32(left.saturating_mul(right))),
_ => Err(EvalError::new(format!(
"unsupported i32 binary operation {op:?}. Fix: add interpreter semantics before registering this operation."
))),
},
_ => Err(EvalError::new(
"type mismatch in binary operation. Fix: validate operand types before interpretation.",
)),
}
}
/// Builds the error for a signed division or remainder that is rejected.
///
/// `kind` is inserted verbatim into the message (callers in this file pass
/// `"division"` or `"remainder"`). The operands are always rendered as
/// `left / right`, regardless of `kind`.
fn undefined_i32_division(kind: &str, left: i32, right: i32) -> EvalError {
    let message = format!(
        "i32 {kind} `{left} / {right}` has undefined target-text semantics. Fix: guard the signed divisor/overflow case before interpretation, or use unsigned operands when zero-divisor semantics must be total."
    );
    EvalError::new(message)
}
/// Evaluates the unary operation `op` on an already-computed operand value.
///
/// Each supported (operand type, operation) pair has its own arm; the final
/// arm for each operand type reports every other operation as unsupported.
///
/// # Errors
/// Returns an [`EvalError`] when `op` is not supported for the operand's type.
fn interpret_un_op(op: &UnOp, operand: Value) -> Result<Value, EvalError> {
    match (operand, op) {
        // u32
        (Value::U32(bits), UnOp::BitNot) => Ok(Value::U32(!bits)),
        // Logical not on an integer is true only for zero.
        (Value::U32(bits), UnOp::LogicalNot) => Ok(Value::Bool(bits == 0)),
        (Value::U32(bits), UnOp::Popcount) => Ok(Value::U32(bits.count_ones())),
        (Value::U32(bits), UnOp::Clz) => Ok(Value::U32(bits.leading_zeros())),
        (Value::U32(bits), UnOp::Ctz) => Ok(Value::U32(bits.trailing_zeros())),
        (Value::U32(bits), UnOp::ReverseBits) => Ok(Value::U32(bits.reverse_bits())),
        (Value::U32(_), _) => Err(EvalError::new(format!(
            "unsupported u32 unary operation {op:?}. Fix: add interpreter semantics before registering this operation."
        ))),

        // bool
        (Value::Bool(flag), UnOp::LogicalNot) => Ok(Value::Bool(!flag)),
        (Value::Bool(_), _) => Err(EvalError::new(format!(
            "unsupported bool unary operation {op:?}. Fix: add interpreter semantics before registering this operation."
        ))),

        // f32: the input and the result both pass through `canonical_f32`.
        (Value::F32(raw), UnOp::Negate) => {
            let input = canonical_f32(raw);
            Ok(Value::F32(canonical_f32(-input)))
        }
        (Value::F32(raw), UnOp::InverseSqrt) => {
            let input = canonical_f32(raw);
            Ok(Value::F32(canonical_f32(1.0 / input.sqrt())))
        }
        (Value::F32(raw), UnOp::Reciprocal) => {
            let input = canonical_f32(raw);
            Ok(Value::F32(canonical_f32(1.0 / input)))
        }
        (Value::F32(_), _) => Err(EvalError::new(format!(
            "unsupported f32 unary operation {op:?}. Fix: add interpreter semantics before registering this operation."
        ))),

        // i32
        (Value::I32(number), UnOp::Negate) => Ok(Value::I32(number.wrapping_neg())),
        (Value::I32(_), _) => Err(EvalError::new(format!(
            "unsupported i32 unary operation {op:?}. Fix: add interpreter semantics before registering this operation."
        ))),

        // u64: bit counts are widened so the result stays a u64.
        (Value::U64(bits), UnOp::BitNot) => Ok(Value::U64(!bits)),
        (Value::U64(bits), UnOp::Popcount) => Ok(Value::U64(u64::from(bits.count_ones()))),
        (Value::U64(bits), UnOp::Clz) => Ok(Value::U64(u64::from(bits.leading_zeros()))),
        (Value::U64(bits), UnOp::Ctz) => Ok(Value::U64(u64::from(bits.trailing_zeros()))),
        (Value::U64(bits), UnOp::ReverseBits) => Ok(Value::U64(bits.reverse_bits())),
        (Value::U64(_), _) => Err(EvalError::new(format!(
            "unsupported u64 unary operation {op:?}. Fix: register explicit u64 semantics before interpreting this operation."
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Interprets a single `BinOp` node whose operands are pre-seeded in a
    /// fresh context under node ids 0 (left) and 1 (right).
    fn eval_binary(op: BinOp, left: Value, right: Value) -> Result<Value, EvalError> {
        const LHS: NodeId = NodeId(0);
        const RHS: NodeId = NodeId(1);

        let mut ctx = InterpCtx::default();
        ctx.set(LHS, left);
        ctx.set(RHS, right);

        let node = NodeStorage::BinOp {
            op,
            left: LHS,
            right: RHS,
        };
        node.interpret(&mut ctx)
    }

    /// Unsigned division and remainder by zero return fixed values, not errors.
    #[test]
    fn unsigned_zero_division_matches_reference_total_contract() {
        let quotient = eval_binary(BinOp::Div, Value::U32(9), Value::U32(0)).unwrap();
        assert_eq!(quotient, Value::U32(u32::MAX));

        let remainder = eval_binary(BinOp::Mod, Value::U32(9), Value::U32(0)).unwrap();
        assert_eq!(remainder, Value::U32(0));
    }

    /// Signed division and remainder reject zero divisors and `i32::MIN` by `-1`.
    #[test]
    fn signed_undefined_division_returns_errors() {
        let cases = [
            (BinOp::Div, i32::MIN, -1),
            (BinOp::Mod, i32::MIN, -1),
            (BinOp::Div, 1, 0),
            (BinOp::Mod, 1, 0),
        ];
        for (op, left, right) in cases {
            let outcome = eval_binary(op, Value::I32(left), Value::I32(right));
            let error = outcome.unwrap_err().to_string();
            assert!(
                error.contains("undefined target-text semantics"),
                "unexpected error for {op:?}({left}, {right}): {error}"
            );
        }
    }

    /// Halving the smallest positive normal `f32` yields exactly `+0.0`.
    #[test]
    fn f32_subnormal_results_are_canonicalized() {
        let dividend = Value::F32(f32::MIN_POSITIVE);
        let divisor = Value::F32(2.0);
        let result = eval_binary(BinOp::Div, dividend, divisor).unwrap();
        let positive_zero_bits = 0.0f32.to_bits();
        assert!(matches!(result, Value::F32(value) if value.to_bits() == positive_zero_bits));
    }
}