use anyhow::Context;
use crate::typesys::{ConstExpr, Type, Variable};
/// Rebuilds `t` with selected const-generic variables erased.
///
/// A const expression is treated as "tainted" when `erase` returns `true` for
/// at least one of the variables reported by its `cvars()`. Tainted
/// expressions are handled per variant:
///
/// * `NatRange`: each tainted bound is replaced by `0`; untainted bounds are
///   cloned.
/// * `Vectorof`: a tainted length turns the type into `DynVectorof`.
/// * `Bytes`: a tainted length turns the type into `DynBytes`.
///
/// `Vector`, `Vectorof`, `Struct`, `Union` and `DynVectorof` have the same
/// erasure applied recursively to the types they contain. `Nothing`, `Any`,
/// `Var` and `DynBytes` are reproduced unchanged.
pub fn partially_erase_cg<T: Variable, C: Variable>(
    t: &Type<T, C>,
    erase: impl Fn(&C) -> bool + Copy,
) -> Type<T, C> {
    match t {
        // Leaf variants carry no const expressions and no nested types.
        Type::Nothing => Type::Nothing,
        Type::Any => Type::Any,
        Type::DynBytes => Type::DynBytes,
        Type::Var(tvar) => Type::Var(tvar.clone()),

        Type::NatRange(lower, upper) => {
            let lower_tainted = lower.cvars().iter().any(erase);
            let upper_tainted = upper.cvars().iter().any(erase);
            Type::NatRange(
                if lower_tainted { 0.into() } else { lower.clone() },
                if upper_tainted { 0.into() } else { upper.clone() },
            )
        }

        Type::Vectorof(elem, len) => {
            let len_tainted = len.cvars().iter().any(erase);
            let erased_elem = partially_erase_cg(elem, erase);
            if len_tainted {
                // The length can no longer be expressed, so fall back to the
                // dynamically-sized vector type.
                Type::DynVectorof(erased_elem.into())
            } else {
                Type::Vectorof(erased_elem.into(), len.clone())
            }
        }

        Type::Bytes(len) => {
            let len_tainted = len.cvars().iter().any(erase);
            if len_tainted {
                Type::DynBytes
            } else {
                Type::Bytes(len.clone())
            }
        }

        // Purely structural variants: rebuild with erased components.
        Type::Vector(elems) => Type::Vector(
            elems
                .iter()
                .map(|elem| partially_erase_cg(elem, erase))
                .collect(),
        ),
        Type::Struct(name, fields) => Type::Struct(
            *name,
            fields
                .iter()
                .map(|(field, ty)| (*field, partially_erase_cg(ty, erase)))
                .collect(),
        ),
        Type::Union(left, right) => Type::Union(
            partially_erase_cg(left, erase).into(),
            partially_erase_cg(right, erase).into(),
        ),
        Type::DynVectorof(elem) => Type::DynVectorof(partially_erase_cg(elem, erase).into()),
    }
}
/// Replaces vector-length slots in `t` with const-generic variables produced
/// by `gensym`.
///
/// * `Vectorof(elem, _)`: the existing length is discarded and replaced by
///   `ConstExpr::Var(gensym())`. The element type is cloned as-is; it is not
///   traversed.
/// * `Vector`: rebuilt by applying this function to every element.
/// * Any other variant except `Struct` and `Union`: returned as a clone of `t`.
///
/// # Panics
///
/// `Struct` and `Union` are not implemented yet and hit `todo!()`.
pub fn cgify_all_slots<T: Variable, C: Variable>(
    t: &Type<T, C>,
    gensym: impl Fn() -> C + Copy,
) -> Type<T, C> {
    // Recursion helper so the `Vector` arm can map it directly.
    let descend = |inner: &Type<T, C>| cgify_all_slots(inner, gensym);
    match t {
        Type::Vectorof(elem, _) => {
            let elem = elem.clone();
            let fresh_len = ConstExpr::Var(gensym());
            Type::Vectorof(elem, fresh_len)
        }
        Type::Vector(elems) => Type::Vector(elems.iter().map(descend).collect()),
        Type::Struct(..) | Type::Union(..) => todo!(),
        _ => t.clone(),
    }
}
/// Solves a constant-step recurrence over a const-generic variable in closed
/// form.
///
/// `post_step` is the value of the variable after one step, expressed in
/// terms of `pre_step`, the variable standing for its value before the step.
/// The per-step change is `post_step - pre_step`; when that evaluates to a
/// constant `d`, the returned expression is `initial + d * iterations`.
///
/// # Errors
///
/// Returns an error when
/// * `checked_sub` cannot subtract `pre_step` from `post_step`, or
/// * the resulting difference does not evaluate to a constant via
///   `try_eval`.
pub fn solve_recurrence<C: Variable>(
    initial: ConstExpr<C>,
    pre_step: C,
    post_step: ConstExpr<C>,
    iterations: ConstExpr<C>,
) -> anyhow::Result<ConstExpr<C>> {
    let symbolic_diff = post_step
        .checked_sub(&ConstExpr::Var(pre_step))
        .context("pre_step cannot be subtracted from post_step")?;
    // Build the error message lazily so the success path does not pay for
    // Debug-formatting the expression.
    let per_step = symbolic_diff.try_eval().with_context(|| {
        format!(
            "per-step change in constant-generic variable is not constant: {:?}",
            symbolic_diff
        )
    })?;
    Ok(ConstExpr::Add(
        initial.into(),
        ConstExpr::Mul(ConstExpr::Lit(per_step).into(), iterations.into()).into(),
    ))
}