use crate::{
crate_prelude::*,
hir::HirNode,
ty::{Type, TypeKind},
ParamEnv, ParamEnvBinding,
};
use bit_vec::BitVec;
use itertools::Itertools;
use num::{BigInt, BigRational, Integer, One, ToPrimitive, Zero};
/// A borrowed constant value, as handed out by `Context::intern_value`.
pub type Value<'t> = &'t ValueData<'t>;

/// The contents of a constant value: the value's type plus its data.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub struct ValueData<'t> {
    /// The type the value has.
    pub ty: Type<'t>,
    /// The actual data of the value.
    pub kind: ValueKind<'t>,
}
impl<'t> ValueData<'t> {
pub fn is_error(&self) -> bool {
self.ty.is_error() || self.kind.is_error()
}
pub fn is_true(&self) -> bool {
!self.is_false()
}
pub fn is_false(&self) -> bool {
match self.kind {
ValueKind::Void => true,
ValueKind::Int(ref v, ..) => v.is_zero(),
ValueKind::Time(ref v) => v.is_zero(),
ValueKind::StructOrArray(_) => false,
ValueKind::Error => true,
}
}
pub fn get_int(&self) -> Option<&BigInt> {
match self.kind {
ValueKind::Int(ref v, ..) => Some(v),
_ => None,
}
}
}
/// The data carried by a constant value.
#[derive(Clone, Debug, Eq, Hash, PartialEq)]
pub enum ValueKind<'t> {
    /// The value of something of `void` type.
    Void,
    /// An integer, followed by two bit vectors which `make_int_special`
    /// fills from its `special_bits` and `x_bits` arguments.
    Int(BigInt, BitVec, BitVec),
    /// A time value, stored as a rational number.
    Time(BigRational),
    /// The members of a struct or the elements of an array.
    StructOrArray(Vec<Value<'t>>),
    /// A value that could not be computed.
    Error,
}

impl<'t> ValueKind<'t> {
    /// Check whether this is the `Error` variant.
    pub fn is_error(&self) -> bool {
        if let ValueKind::Error = self {
            true
        } else {
            false
        }
    }
}
impl std::fmt::Display for ValueKind<'_> {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
match self {
ValueKind::Void => write!(f, "void"),
ValueKind::Int(v, ..) => write!(f, "{}", v),
ValueKind::Time(v) => write!(f, "{}", v),
ValueKind::StructOrArray(v) => {
write!(f, "{{ {} }}", v.iter().map(|v| &v.kind).format(", "))
}
ValueKind::Error => write!(f, "<error>"),
}
}
}
pub fn make_error(ty: Type) -> ValueData {
ValueData {
ty,
kind: ValueKind::Error,
}
}
/// Create an integer value of type `ty` whose special and X bit vectors are
/// all cleared (each as wide as `ty`).
pub fn make_int(ty: Type, value: BigInt) -> ValueData {
    let width = ty.width();
    let cleared_special = BitVec::from_elem(width, false);
    let cleared_x = cleared_special.clone();
    make_int_special(ty, value, cleared_special, cleared_x)
}
/// Create an integer value of type `ty` with explicit special and X bits.
///
/// The value is reduced with `%` modulo `2^width` for `Int`/`BitVector`
/// types and modulo 2 for single-bit types. `%` on `BigInt` keeps the sign
/// of the dividend, so negative inputs stay negative.
///
/// # Panics
/// If `ty` does not resolve to an integer-like type.
pub fn make_int_special(
    ty: Type,
    value: BigInt,
    special_bits: BitVec,
    x_bits: BitVec,
) -> ValueData {
    let modulus = match *ty.resolve_name() {
        TypeKind::Int(width, _)
        | TypeKind::BitVector {
            range: ty::Range { size: width, .. },
            ..
        } => BigInt::from(1) << width,
        TypeKind::Bit(_) | TypeKind::BitScalar { .. } => BigInt::from(2),
        _ => panic!("create int value `{}` with non-int type {:?}", value, ty),
    };
    ValueData {
        ty,
        kind: ValueKind::Int(value % modulus, special_bits, x_bits),
    }
}
/// Create a time value of the builtin time type.
pub fn make_time(value: BigRational) -> ValueData<'static> {
    let kind = ValueKind::Time(value);
    ValueData {
        ty: &ty::TIME_TYPE,
        kind,
    }
}
/// Create a struct value of type `ty` from its field values.
///
/// # Panics
/// If `ty` is not a struct type.
pub fn make_struct<'t>(ty: Type<'t>, fields: Vec<Value<'t>>) -> ValueData<'t> {
    assert!(ty.is_struct());
    let kind = ValueKind::StructOrArray(fields);
    ValueData { ty, kind }
}
/// Create an array value of type `ty` from its element values.
///
/// # Panics
/// If `ty` is not an array type.
pub fn make_array<'t>(ty: Type<'t>, elements: Vec<Value<'t>>) -> ValueData<'t> {
    assert!(ty.is_array());
    let kind = ValueKind::StructOrArray(elements);
    ValueData { ty, kind }
}
/// Compute the constant value of node `node_id` in parameter environment
/// `env`.
///
/// With `Verbosity::CONSTS` enabled, the source line, the node's source text
/// and the resulting value (or `<error>`) are printed to stdout.
pub(crate) fn constant_value_of<'gcx>(
    cx: &impl Context<'gcx>,
    node_id: NodeId,
    env: ParamEnv,
) -> Result<Value<'gcx>> {
    let result = const_node(cx, node_id, env);
    if cx.sess().has_verbosity(Verbosity::CONSTS) {
        let shown = match &result {
            Ok(value) => format!("{}, {}", value.ty, value.kind),
            Err(_) => "<error>".to_string(),
        };
        let span = cx.span(node_id);
        let source = span.extract();
        let line = span.begin().human_line();
        println!("{}: const({}) = {}", line, source, shown);
    }
    result
}
fn const_node<'gcx>(
cx: &impl Context<'gcx>,
node_id: NodeId,
env: ParamEnv,
) -> Result<Value<'gcx>> {
let hir = cx.hir_of(node_id)?;
match hir {
HirNode::Expr(expr) => {
let mir = cx.mir_rvalue(expr.id, env);
Ok(cx.const_mir_rvalue(mir.into()))
}
HirNode::ValueParam(param) => {
let env_data = cx.param_env_data(env);
match env_data.find_value(node_id) {
Some(ParamEnvBinding::Indirect(assigned_id)) => {
return cx.constant_value_of(assigned_id.id(), assigned_id.env())
}
Some(ParamEnvBinding::Direct(v)) => return Ok(v),
_ => (),
}
if let Some(default) = param.default {
return cx.constant_value_of(default, env);
}
let d = DiagBuilder2::error(format!(
"{} not assigned and has no default",
param.desc_full(),
));
let contexts = cx.param_env_contexts(env);
for &context in &contexts {
cx.emit(
d.clone()
.span(cx.span(context))
.add_note("Parameter declared here:")
.span(param.human_span()),
);
}
if contexts.is_empty() {
cx.emit(d.span(param.human_span()));
}
Err(())
}
HirNode::GenvarDecl(decl) => {
let env_data = cx.param_env_data(env);
match env_data.find_value(node_id) {
Some(ParamEnvBinding::Indirect(assigned_id)) => {
return cx.constant_value_of(assigned_id.id(), assigned_id.env())
}
Some(ParamEnvBinding::Direct(v)) => return Ok(v),
_ => (),
}
if let Some(init) = decl.init {
return cx.constant_value_of(init, env);
}
cx.emit(
DiagBuilder2::error(format!("{} not initialized", decl.desc_full()))
.span(decl.human_span()),
);
Err(())
}
HirNode::VarDecl(_) => {
cx.emit(
DiagBuilder2::error(format!("{} has no constant value", hir.desc_full()))
.span(hir.human_span()),
);
Err(())
}
HirNode::EnumVariant(var) => match var.value {
Some(v) => cx.constant_value_of(v, env),
None => Ok(cx.intern_value(make_int(cx.type_of(node_id, env)?, var.index.into()))),
},
_ => cx.unimp_msg("constant value computation of", &hir),
}
}
/// Query entry point: fold the MIR rvalue behind `mir` to a constant.
pub(crate) fn const_mir_rvalue_query<'gcx>(
    cx: &impl Context<'gcx>,
    mir: Ref<'gcx, mir::Rvalue<'gcx>>,
) -> Value<'gcx> {
    let rvalue = *mir;
    const_mir_rvalue(cx, rvalue)
}
/// Fold a MIR rvalue to a constant, printing the result to stdout when
/// `Verbosity::CONSTS` is enabled.
fn const_mir_rvalue<'gcx>(cx: &impl Context<'gcx>, mir: &'gcx mir::Rvalue<'gcx>) -> Value<'gcx> {
    let folded = const_mir_rvalue_inner(cx, mir);
    let verbose = cx.sess().has_verbosity(Verbosity::CONSTS);
    if verbose {
        let span = mir.span;
        let source = span.extract();
        let line = span.begin().human_line();
        println!(
            "{}: const_mir({}) = {}, {}",
            line, source, folded.ty, folded.kind
        );
    }
    folded
}
fn const_mir_rvalue_inner<'gcx>(
cx: &impl Context<'gcx>,
mir: &'gcx mir::Rvalue<'gcx>,
) -> Value<'gcx> {
if mir.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match mir.kind {
mir::RvalueKind::CastValueDomain { value, .. }
| mir::RvalueKind::CastVectorToAtom { value, .. }
| mir::RvalueKind::CastAtomToVector { value, .. }
| mir::RvalueKind::CastSign(_, value)
| mir::RvalueKind::Truncate(_, value)
| mir::RvalueKind::ZeroExtend(_, value)
| mir::RvalueKind::SignExtend(_, value) => {
cx.emit(
DiagBuilder2::warning("cast ignored during constant evaluation")
.span(mir.span)
.add_note(format!(
"Casts `{}` from `{}` to `{}`",
value.span.extract(),
value.ty,
mir.ty
))
.span(value.span),
);
let v = cx.const_mir_rvalue(value.into());
cx.intern_value(ValueData {
ty: mir.ty,
kind: v.kind.clone(),
})
}
mir::RvalueKind::CastToBool(value) => {
let value = cx.const_mir_rvalue(value.into());
if value.is_error() {
return cx.intern_value(make_error(mir.ty));
}
cx.intern_value(make_int(mir.ty, (value.is_true() as usize).into()))
}
mir::RvalueKind::ConstructArray(ref values) => cx.intern_value(make_array(
mir.ty,
(0..values.len())
.map(|index| cx.const_mir_rvalue(values[&index].into()))
.collect(),
)),
mir::RvalueKind::ConstructStruct(ref values) => cx.intern_value(make_struct(
mir.ty,
values
.iter()
.map(|&value| cx.const_mir_rvalue(value.into()))
.collect(),
)),
mir::RvalueKind::Const(value) => value,
mir::RvalueKind::UnaryBitwise { op, arg } => {
let arg_val = cx.const_mir_rvalue(arg.into());
if arg_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match arg_val.kind {
ValueKind::Int(ref arg_int, ..) => cx.intern_value(make_int(
mir.ty,
const_unary_bitwise_int(cx, arg.ty, op, arg_int),
)),
_ => unreachable!(),
}
}
mir::RvalueKind::BinaryBitwise { op, lhs, rhs } => {
let lhs_val = cx.const_mir_rvalue(lhs.into());
let rhs_val = cx.const_mir_rvalue(rhs.into());
if lhs_val.is_error() || rhs_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match (&lhs_val.kind, &rhs_val.kind) {
(ValueKind::Int(lhs_int, ..), ValueKind::Int(rhs_int, ..)) => {
cx.intern_value(make_int(
mir.ty,
const_binary_bitwise_int(cx, lhs.ty, op, lhs_int, rhs_int),
))
}
_ => unreachable!(),
}
}
mir::RvalueKind::IntUnaryArith { op, arg, .. } => {
let arg_val = cx.const_mir_rvalue(arg.into());
if arg_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match arg_val.kind {
ValueKind::Int(ref arg_int, ..) => cx.intern_value(make_int(
mir.ty,
const_unary_arith_int(cx, arg.ty, op, arg_int),
)),
_ => unreachable!(),
}
}
mir::RvalueKind::IntBinaryArith { op, lhs, rhs, .. } => {
let lhs_val = cx.const_mir_rvalue(lhs.into());
let rhs_val = cx.const_mir_rvalue(rhs.into());
if lhs_val.is_error() || rhs_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match (&lhs_val.kind, &rhs_val.kind) {
(ValueKind::Int(lhs_int, ..), ValueKind::Int(rhs_int, ..)) => {
cx.intern_value(make_int(
mir.ty,
const_binary_arith_int(cx, lhs.ty, op, lhs_int, rhs_int),
))
}
_ => unreachable!(),
}
}
mir::RvalueKind::IntComp { op, lhs, rhs, .. } => {
let lhs_val = cx.const_mir_rvalue(lhs.into());
let rhs_val = cx.const_mir_rvalue(rhs.into());
if lhs_val.is_error() || rhs_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match (&lhs_val.kind, &rhs_val.kind) {
(ValueKind::Int(lhs_int, ..), ValueKind::Int(rhs_int, ..)) => cx.intern_value(
make_int(mir.ty, const_comp_int(cx, lhs.ty, op, lhs_int, rhs_int)),
),
_ => unreachable!(),
}
}
mir::RvalueKind::Concat(ref values) => {
let mut result = BigInt::zero();
for &value in values {
result <<= value.ty.width();
result |= cx
.const_mir_rvalue(value.into())
.get_int()
.expect("concat non-integer");
}
cx.intern_value(make_int(mir.ty, result))
}
mir::RvalueKind::Repeat(count, value) => {
let value_const = cx.const_mir_rvalue(value.into());
if value_const.is_error() {
return cx.intern_value(make_error(mir.ty));
}
let mut result = BigInt::zero();
for _ in 0..count {
result <<= value.ty.width();
result |= value_const.get_int().expect("repeat non-integer");
}
cx.intern_value(make_int(mir.ty, result))
}
mir::RvalueKind::Assignment { .. } | mir::RvalueKind::Var(_) | mir::RvalueKind::Port(_) => {
cx.emit(DiagBuilder2::error("value is not constant").span(mir.span));
cx.intern_value(make_error(mir.ty))
}
mir::RvalueKind::Member { value, field } => {
let value_const = cx.const_mir_rvalue(value.into());
if value_const.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match value_const.kind {
ValueKind::StructOrArray(ref fields) => fields[field],
_ => unreachable!("member access on non-struct should be caught in typeck"),
}
}
mir::RvalueKind::Ternary {
cond,
true_value,
false_value,
} => {
let cond_val = cx.const_mir_rvalue(cond.into());
let true_val = cx.const_mir_rvalue(true_value.into());
let false_val = cx.const_mir_rvalue(false_value.into());
match cond_val.is_true() {
true => true_val,
false => false_val,
}
}
mir::RvalueKind::Shift {
op,
arith,
value,
amount,
..
} => {
let value_val = cx.const_mir_rvalue(value.into());
let amount_val = cx.const_mir_rvalue(amount.into());
if value_val.is_error() || amount_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match (&value_val.kind, &amount_val.kind) {
(ValueKind::Int(value_int, ..), ValueKind::Int(amount_int, ..)) => {
cx.intern_value(make_int(
mir.ty,
const_shift_int(cx, value.ty, op, arith, value_int, amount_int),
))
}
_ => unreachable!(),
}
}
mir::RvalueKind::Reduction { op, arg } => {
let arg_val = cx.const_mir_rvalue(arg.into());
if arg_val.is_error() {
return cx.intern_value(make_error(mir.ty));
}
match arg_val.kind {
ValueKind::Int(ref arg_int, ..) => cx.intern_value(make_int(
mir.ty,
const_reduction_int(cx, arg.ty, op, arg_int),
)),
_ => unreachable!(),
}
}
mir::RvalueKind::Index {
..
} => {
bug_span!(mir.span, cx, "constant folding of slices not implemented");
}
mir::RvalueKind::Error => cx.intern_value(make_error(mir.ty)),
}
}
/// Fold a unary bitwise operation on a constant integer of type `ty`.
///
/// `Not` is computed as `(2^width - 1) - arg`. The result is reduced to the
/// result type's width later, by `make_int`.
fn const_unary_bitwise_int<'gcx>(
    _cx: &impl Context<'gcx>,
    ty: Type,
    op: mir::UnaryBitwiseOp,
    arg: &BigInt,
) -> BigInt {
    match op {
        mir::UnaryBitwiseOp::Not => {
            let all_ones = (BigInt::one() << ty.width()) - 1;
            all_ones - arg
        }
    }
}
/// Fold a binary bitwise operation on two constant integers.
///
/// All three operators are commutative, so operand order is irrelevant.
fn const_binary_bitwise_int<'gcx>(
    _cx: &impl Context<'gcx>,
    _ty: Type,
    op: mir::BinaryBitwiseOp,
    lhs: &BigInt,
    rhs: &BigInt,
) -> BigInt {
    match op {
        mir::BinaryBitwiseOp::And => rhs & lhs,
        mir::BinaryBitwiseOp::Or => rhs | lhs,
        mir::BinaryBitwiseOp::Xor => rhs ^ lhs,
    }
}
fn const_unary_arith_int<'gcx>(
_cx: &impl Context<'gcx>,
_ty: Type,
op: mir::IntUnaryArithOp,
arg: &BigInt,
) -> BigInt {
match op {
mir::IntUnaryArithOp::Neg => -arg,
}
}
fn const_binary_arith_int<'gcx>(
_cx: &impl Context<'gcx>,
_ty: Type,
op: mir::IntBinaryArithOp,
lhs: &BigInt,
rhs: &BigInt,
) -> BigInt {
match op {
mir::IntBinaryArithOp::Add => lhs + rhs,
mir::IntBinaryArithOp::Sub => lhs - rhs,
mir::IntBinaryArithOp::Mul => lhs * rhs,
mir::IntBinaryArithOp::Div => lhs / rhs,
mir::IntBinaryArithOp::Mod => lhs % rhs,
mir::IntBinaryArithOp::Pow => {
let mut result = num::one();
let mut cnt = rhs.clone();
while !cnt.is_zero() {
result = result * lhs;
cnt = cnt - 1;
}
result
}
}
}
/// Fold an integer comparison on two constants, yielding 1 if the relation
/// holds and 0 otherwise.
fn const_comp_int<'gcx>(
    _cx: &impl Context<'gcx>,
    _ty: Type,
    op: mir::IntCompOp,
    lhs: &BigInt,
    rhs: &BigInt,
) -> BigInt {
    let holds = match op {
        mir::IntCompOp::Eq => lhs == rhs,
        mir::IntCompOp::Neq => lhs != rhs,
        mir::IntCompOp::Lt => lhs < rhs,
        mir::IntCompOp::Leq => lhs <= rhs,
        mir::IntCompOp::Gt => lhs > rhs,
        mir::IntCompOp::Geq => lhs >= rhs,
    };
    (holds as usize).into()
}
/// Fold a shift of a constant integer by a constant amount.
///
/// A negative amount shifts in the opposite direction. An amount that does
/// not fit in `isize` yields 0.
///
/// NOTE(review): the `arith` flag is ignored here, and SystemVerilog treats
/// the shift amount as unsigned; both deserve a closer look.
fn const_shift_int<'gcx>(
    _cx: &impl Context<'gcx>,
    _ty: Type,
    op: mir::ShiftOp,
    _arith: bool,
    value: &BigInt,
    amount: &BigInt,
) -> BigInt {
    let shift = match amount.to_isize() {
        Some(shift) => shift,
        None => return num::zero(),
    };
    let distance = if shift < 0 {
        (-shift) as usize
    } else {
        shift as usize
    };
    let towards_msb = match op {
        mir::ShiftOp::Left => shift >= 0,
        mir::ShiftOp::Right => shift < 0,
    };
    if towards_msb {
        value << distance
    } else {
        value >> distance
    }
}
/// Fold a reduction operator over the bits of a constant integer of type
/// `ty`, yielding 1 or 0.
///
/// The reduction runs over the operand's two's-complement bit pattern
/// restricted to `ty.width()`. Stored values can be negative, because
/// `make_int_special` reduces with `%`, which keeps the dividend's sign. The
/// old code then compared against the all-ones mask and counted the bytes of
/// the magnitude, which gave wrong And and Xor results for such values.
fn const_reduction_int<'gcx>(
    _cx: &impl Context<'gcx>,
    ty: Type,
    op: mir::BinaryBitwiseOp,
    arg: &BigInt,
) -> BigInt {
    let mask = (BigInt::one() << ty.width()) - 1;
    // Masking is a no-op for in-range non-negative values and maps negative
    // ones onto their two's-complement bits within the width.
    let bits = arg & &mask;
    let reduced = match op {
        mir::BinaryBitwiseOp::And => bits == mask,
        mir::BinaryBitwiseOp::Or => !bits.is_zero(),
        mir::BinaryBitwiseOp::Xor => bits
            .to_bytes_le()
            .1
            .into_iter()
            .map(|byte| byte.count_ones())
            .sum::<u32>()
            .is_odd(),
    };
    (reduced as usize).into()
}
/// Check whether `node_id` refers to something that always has a constant
/// value: a value parameter, a genvar, or an enum variant.
pub(crate) fn is_constant<'gcx>(cx: &impl Context<'gcx>, node_id: NodeId) -> Result<bool> {
    let constant = match cx.hir_of(node_id)? {
        HirNode::ValueParam(_) | HirNode::GenvarDecl(_) | HirNode::EnumVariant(_) => true,
        _ => false,
    };
    Ok(constant)
}
pub(crate) fn type_default_value<'gcx>(cx: &impl Context<'gcx>, ty: Type<'gcx>) -> Value<'gcx> {
match *ty {
TypeKind::Error => cx.intern_value(ValueData {
ty: &ty::ERROR_TYPE,
kind: ValueKind::Error,
}),
TypeKind::Void => cx.intern_value(ValueData {
ty: &ty::VOID_TYPE,
kind: ValueKind::Void,
}),
TypeKind::Time => cx.intern_value(make_time(Zero::zero())),
TypeKind::Bit(..)
| TypeKind::Int(..)
| TypeKind::BitVector { .. }
| TypeKind::BitScalar { .. } => cx.intern_value(make_int(ty, Zero::zero())),
TypeKind::Named(_, _, ty) => type_default_value(cx, ty),
TypeKind::Struct(id) => {
let def = cx.struct_def(id).unwrap();
let fields = def
.fields
.iter()
.map(|field| {
type_default_value(
cx,
cx.map_to_type(field.ty, cx.default_param_env()).unwrap(),
)
})
.collect();
cx.intern_value(make_struct(ty, fields))
}
TypeKind::PackedArray(length, elem_ty) => cx.intern_value(make_array(
ty.clone(),
std::iter::repeat(cx.type_default_value(elem_ty.clone()))
.take(length)
.collect(),
)),
}
}