use num::BigInt;
use serde::{Deserialize, Serialize};
use spade_common::location_info::{Loc, WithLocation};
use spade_types::KnownType;
use crate::{
equation::{TypeVar, TypeVarID},
TypeState,
};
#[derive(Debug, Clone)]
pub enum ConstraintExpr {
Bool(bool),
Integer(BigInt),
String(String),
Var(TypeVarID),
Sum(Box<ConstraintExpr>, Box<ConstraintExpr>),
Difference(Box<ConstraintExpr>, Box<ConstraintExpr>),
Product(Box<ConstraintExpr>, Box<ConstraintExpr>),
Div(Box<ConstraintExpr>, Box<ConstraintExpr>),
Mod(Box<ConstraintExpr>, Box<ConstraintExpr>),
Sub(Box<ConstraintExpr>),
Eq(Box<ConstraintExpr>, Box<ConstraintExpr>),
NotEq(Box<ConstraintExpr>, Box<ConstraintExpr>),
UintBitsToRepresent(Box<ConstraintExpr>),
}
impl ConstraintExpr {
pub fn debug_display(&self, type_state: &TypeState) -> String {
match self {
ConstraintExpr::Bool(b) => format!("{b}"),
ConstraintExpr::Integer(v) => format!("{v}"),
ConstraintExpr::String(v) => format!("{v:?}"),
ConstraintExpr::Var(type_var_id) => {
format!("{}", type_var_id.debug_resolve(type_state))
}
ConstraintExpr::Sum(lhs, rhs) => {
format!(
"({} + {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::Difference(lhs, rhs) => {
format!(
"({} - {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::Product(lhs, rhs) => {
format!(
"({} * {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::Div(lhs, rhs) => {
format!(
"({} / {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::Mod(lhs, rhs) => {
format!(
"({} % {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::Sub(lhs) => {
format!("(-{})", lhs.debug_display(type_state))
}
ConstraintExpr::Eq(lhs, rhs) => {
format!(
"({} == {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::NotEq(lhs, rhs) => {
format!(
"({} != {})",
lhs.debug_display(type_state),
rhs.debug_display(type_state)
)
}
ConstraintExpr::UintBitsToRepresent(c) => {
format!("uint_bits_to_fit({})", c.debug_display(type_state))
}
}
}
}
impl ConstraintExpr {
fn evaluate(&self, type_state: &TypeState) -> ConstraintExpr {
let binop =
|lhs: &ConstraintExpr, rhs: &ConstraintExpr, op: &dyn Fn(BigInt, BigInt) -> BigInt| {
match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
(ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
ConstraintExpr::Integer(op(l, r))
}
_ => self.clone(),
}
};
match self {
ConstraintExpr::Integer(_) => self.clone(),
ConstraintExpr::Bool(_) => self.clone(),
ConstraintExpr::String(_) => self.clone(),
ConstraintExpr::Var(v) => match v.resolve(type_state) {
TypeVar::Known(_, known_type, _) => match known_type {
KnownType::Integer(i) => ConstraintExpr::Integer(i.clone()),
KnownType::Bool(b) => ConstraintExpr::Bool(b.clone()),
KnownType::String(s) => ConstraintExpr::String(s.clone()),
KnownType::Error => self.clone(),
KnownType::Named(_)
| KnownType::Tuple
| KnownType::Array
| KnownType::Wire
| KnownType::Inverted => {
panic!("Inferred non-integer or bool for constraint variable")
}
},
TypeVar::Unknown(_, _, _, _) => self.clone(),
},
ConstraintExpr::Sum(lhs, rhs) => binop(lhs, rhs, &|l, r| l + r),
ConstraintExpr::Difference(lhs, rhs) => binop(lhs, rhs, &|l, r| l - r),
ConstraintExpr::Product(lhs, rhs) => binop(lhs, rhs, &|l, r| l * r),
ConstraintExpr::Div(lhs, rhs) => binop(lhs, rhs, &|l, r| l / r),
ConstraintExpr::Mod(lhs, rhs) => binop(lhs, rhs, &|l, r| l % r),
ConstraintExpr::Sub(inner) => match inner.evaluate(type_state) {
ConstraintExpr::Integer(val) => ConstraintExpr::Integer(-val),
_ => self.clone(),
},
ConstraintExpr::Eq(lhs, rhs) => {
match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
(ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
ConstraintExpr::Bool(l == r)
}
(ConstraintExpr::String(l), ConstraintExpr::String(r)) => {
ConstraintExpr::Bool(l == r)
}
_ => self.clone(),
}
}
ConstraintExpr::NotEq(lhs, rhs) => {
match (lhs.evaluate(type_state), rhs.evaluate(type_state)) {
(ConstraintExpr::Integer(l), ConstraintExpr::Integer(r)) => {
ConstraintExpr::Bool(l != r)
}
(ConstraintExpr::String(l), ConstraintExpr::String(r)) => {
ConstraintExpr::Bool(l != r)
}
_ => self.clone(),
}
}
ConstraintExpr::UintBitsToRepresent(inner) => match inner.evaluate(type_state) {
ConstraintExpr::Integer(val) => ConstraintExpr::Integer(val.bits().into()),
_ => self.clone(),
},
}
}
pub fn with_context(
self,
replaces: &TypeVarID,
inside: &TypeVarID,
source: ConstraintSource,
) -> ConstraintRhs {
ConstraintRhs {
constraint: self,
context: ConstraintContext {
replaces: replaces.clone(),
inside: inside.clone(),
source,
},
}
}
}
impl std::ops::Add for ConstraintExpr {
type Output = ConstraintExpr;
fn add(self, rhs: Self) -> Self::Output {
ConstraintExpr::Sum(Box::new(self), Box::new(rhs))
}
}
impl std::ops::Sub for ConstraintExpr {
type Output = ConstraintExpr;
fn sub(self, rhs: Self) -> Self::Output {
ConstraintExpr::Sum(Box::new(self), Box::new(-rhs))
}
}
impl std::ops::Neg for ConstraintExpr {
type Output = ConstraintExpr;
fn neg(self) -> Self::Output {
ConstraintExpr::Sub(Box::new(self))
}
}
pub fn bits_to_store(inner: ConstraintExpr) -> ConstraintExpr {
ConstraintExpr::UintBitsToRepresent(Box::new(inner))
}
pub fn ce_var(v: &TypeVarID) -> ConstraintExpr {
ConstraintExpr::Var(v.clone())
}
pub fn ce_int(v: BigInt) -> ConstraintExpr {
ConstraintExpr::Integer(v)
}
#[derive(Debug, Clone, PartialEq)]
pub enum ConstraintSource {
AdditionOutput,
MultOutput,
ArrayIndexing,
MemoryIndexing,
Concatenation,
PipelineRegOffset { reg: Loc<()>, total: Loc<()> },
PipelineRegCount { reg: Loc<()>, total: Loc<()> },
PipelineAvailDepth,
RangeIndex,
RangeIndexOutputSize,
ArraySize,
TypeLevelIf,
Where,
}
impl std::fmt::Display for ConstraintSource {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ConstraintSource::AdditionOutput => write!(f, "AdditionOutput"),
ConstraintSource::MultOutput => write!(f, "MultiplicationOutput"),
ConstraintSource::ArrayIndexing => write!(f, "ArrayIndexing"),
ConstraintSource::MemoryIndexing => write!(f, "MemoryIndexing"),
ConstraintSource::Concatenation => write!(f, "Concatenation"),
ConstraintSource::Where => write!(f, "Where"),
ConstraintSource::RangeIndex => write!(f, "RangeIndex"),
ConstraintSource::RangeIndexOutputSize => write!(f, "RangeIndexOutputSize"),
ConstraintSource::ArraySize => write!(f, "ArraySize"),
ConstraintSource::PipelineRegOffset { .. } => write!(f, "PipelineRegOffset"),
ConstraintSource::PipelineRegCount { .. } => write!(f, "PipelineRegOffset"),
ConstraintSource::PipelineAvailDepth => write!(f, "PipelineAvailDepth"),
ConstraintSource::TypeLevelIf => write!(f, "TypeLevelIf"),
}
}
}
#[derive(Debug, Clone)]
pub struct ConstraintRhs {
pub constraint: ConstraintExpr,
pub context: ConstraintContext,
}
impl ConstraintRhs {
pub fn debug_display(&self, type_state: &TypeState) -> String {
self.constraint.debug_display(type_state)
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct TypeConstraints {
#[serde(skip)]
pub inner: Vec<(TypeVarID, Loc<ConstraintRhs>)>,
}
impl TypeConstraints {
pub fn new() -> Self {
Self { inner: vec![] }
}
pub fn add_int_constraint(&mut self, lhs: TypeVarID, rhs: Loc<ConstraintRhs>) {
self.inner.push((lhs, rhs));
}
pub fn update_type_level_value_constraints(
self,
type_state: &TypeState,
) -> (
TypeConstraints,
Vec<Loc<(TypeVarID, ConstraintReplacement)>>,
) {
let mut new_known = vec![];
let remaining = self
.inner
.into_iter()
.filter_map(|(expr, rhs)| {
let mut rhs = rhs.clone();
rhs.constraint = rhs.constraint.evaluate(type_state);
match &rhs.constraint {
ConstraintExpr::Integer(val) => {
let replacement = ConstraintReplacement {
val: KnownType::Integer(val.clone()),
context: rhs.context.clone(),
};
new_known
.push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
None
}
ConstraintExpr::Bool(val) => {
let replacement = ConstraintReplacement {
val: KnownType::Bool(val.clone()),
context: rhs.context.clone(),
};
new_known
.push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
None
}
ConstraintExpr::String(val) => {
let replacement = ConstraintReplacement {
val: KnownType::String(val.clone()),
context: rhs.context.clone(),
};
new_known
.push(().at_loc(&rhs).map(|_| (expr.clone(), replacement.clone())));
None
}
ConstraintExpr::Var(_)
| ConstraintExpr::Sum(_, _)
| ConstraintExpr::Div(_, _)
| ConstraintExpr::Mod(_, _)
| ConstraintExpr::Eq(_, _)
| ConstraintExpr::NotEq(_, _)
| ConstraintExpr::Difference(_, _)
| ConstraintExpr::Product(_, _)
| ConstraintExpr::UintBitsToRepresent(_)
| ConstraintExpr::Sub(_) => Some((expr.clone(), rhs)),
}
})
.collect();
(TypeConstraints { inner: remaining }, new_known)
}
}
#[derive(Clone, Debug)]
pub struct ConstraintReplacement {
pub val: KnownType,
pub context: ConstraintContext,
}
#[derive(Clone, Debug)]
pub struct ConstraintContext {
pub inside: TypeVarID,
pub replaces: TypeVarID,
pub source: ConstraintSource,
}