use super::expr::IRBexpr;
use crate::{
diagnostics::{Diagnostic, Validation},
error::Error,
expr::{ExprProperties, IRAexpr, IRConstBexpr},
meta::{HasMeta, Meta},
printer::IRPrintable,
traits::{Canonicalize, ConstantFolding, Evaluate, Validatable},
};
use eqv::{EqvRelation, equiv};
use haloumi_core::{cmp::CmpOp, eqv::SymbolicEqv, slot::Slot};
use haloumi_lowering::{
Lowering,
lowerable::{LowerableExpr, LowerableStmt},
lowering_err,
};
use std::fmt::Write;
mod assert;
mod assume_determ;
mod call;
mod comment;
mod cond_block;
mod constraint;
mod post_cond;
mod seq;
use assert::Assert;
use assume_determ::AssumeDeterministic;
use call::Call;
use comment::Comment;
use cond_block::CondBlock;
use constraint::Constraint;
use post_cond::PostCond;
use seq::Seq;
/// Private module used to seal [`EmitIf`]: because `sealed` is not public, code outside this
/// crate cannot name `EmitIfSealed` and therefore cannot implement `EmitIf` for its own types.
mod sealed {
    /// Marker supertrait of `EmitIf`. In this file its only impl is the blanket impl for
    /// iterables of `IRStmt`.
    pub trait EmitIfSealed {}
}
/// Extension trait that wraps a collection of statements in a conditionally emitted block.
///
/// The trait is sealed (see the private `sealed` module). It is implemented for every
/// `IntoIterator<Item = IRStmt<T>>`.
pub trait EmitIf<T>: sealed::EmitIfSealed {
    /// Collects `self` into the body of a `CondBlock` guarded by the constant condition `cond`.
    /// Returns the block as a single statement.
    fn emit_if(self, cond: IRConstBexpr<T>) -> IRStmt<T>;
}
// Any iterable of statements is admitted into the sealed set, and so gets `emit_if`.
impl<T, I> sealed::EmitIfSealed for I where I: IntoIterator<Item = IRStmt<T>> {}

impl<T, I> EmitIf<T> for I
where
    I: IntoIterator<Item = IRStmt<T>>,
{
    /// Gathers all statements into the body of a conditional block guarded by `cond`.
    fn emit_if(self, cond: IRConstBexpr<T>) -> IRStmt<T> {
        let body = self.into_iter().collect();
        CondBlock::new(cond, body).into()
    }
}
/// A statement of the IR, generic over the expression type `T`.
///
/// Field `0` is the statement kind; field `1` is the statement's [`Meta`]data.
pub struct IRStmt<T>(IRStmtImpl<T>, Meta);
/// The kinds of statement an [`IRStmt`] can hold.
enum IRStmtImpl<T> {
    /// A call to `callee` with input expressions and output slots (see `IRStmt::call`).
    ConstraintCall(Call<T>),
    /// A comparison `lhs <op> rhs` between two expressions (see `IRStmt::constraint`).
    Constraint(Constraint<T>),
    /// A textual comment.
    Comment(Comment),
    /// Marks a slot as assumed to be deterministic.
    AssumeDeterministic(AssumeDeterministic),
    /// An assertion of a boolean expression.
    Assert(Assert<T>),
    /// An ordered sequence of statements. The statement iterators flatten these, so a `Seq`
    /// node itself is never yielded by them.
    Seq(Seq<T>),
    /// A post-condition over a boolean expression.
    PostCond(PostCond<T>),
    /// A block of statements guarded by a constant boolean condition (printed as `emit-if`).
    /// It must be resolved before lowering; lowering one is an error.
    CondBlock(CondBlock<T>),
}
impl<T> HasMeta for IRStmt<T> {
    /// Borrows the metadata attached to this statement.
    fn meta(&self) -> &Meta {
        let Self(_, meta) = self;
        meta
    }

    /// Mutably borrows the metadata attached to this statement.
    fn meta_mut(&mut self) -> &mut Meta {
        let Self(_, meta) = self;
        meta
    }
}
impl<T: PartialEq> PartialEq for IRStmt<T> {
    /// Structural equality on the flattened form of the two statements.
    ///
    /// Both sides are walked with [`IRStmt::iter`], which flattens nested sequences. The
    /// resulting leaf statements are compared pairwise. The statements are equal only if every
    /// pair matches *and* both walks end at the same time. For example, a statement `a` is not
    /// equal to the sequence `[a, b]`, and the empty sequence is only equal to other statements
    /// with no leaves. Metadata is not compared.
    fn eq(&self, other: &Self) -> bool {
        let mut lhs_leaves = self.iter();
        let mut rhs_leaves = other.iter();
        loop {
            let (lhs, rhs) = match (lhs_leaves.next(), rhs_leaves.next()) {
                (None, None) => return true,
                (Some(lhs), Some(rhs)) => (lhs, rhs),
                // One side has more leaf statements than the other. A plain `zip` would stop at
                // the shorter side here and wrongly report equality.
                (Some(_), None) | (None, Some(_)) => return false,
            };
            let leaves_equal = match (&lhs.0, &rhs.0) {
                (IRStmtImpl::ConstraintCall(lhs), IRStmtImpl::ConstraintCall(rhs)) => lhs.eq(rhs),
                (IRStmtImpl::Constraint(lhs), IRStmtImpl::Constraint(rhs)) => lhs.eq(rhs),
                (IRStmtImpl::Comment(lhs), IRStmtImpl::Comment(rhs)) => lhs.eq(rhs),
                (IRStmtImpl::AssumeDeterministic(lhs), IRStmtImpl::AssumeDeterministic(rhs)) => {
                    lhs.eq(rhs)
                }
                (IRStmtImpl::Assert(lhs), IRStmtImpl::Assert(rhs)) => lhs.eq(rhs),
                (IRStmtImpl::PostCond(lhs), IRStmtImpl::PostCond(rhs)) => lhs.eq(rhs),
                (IRStmtImpl::CondBlock(lhs), IRStmtImpl::CondBlock(rhs)) => lhs.eq(rhs),
                // `iter` flattens sequences, so it never yields a `Seq` node.
                (IRStmtImpl::Seq(_), _) | (_, IRStmtImpl::Seq(_)) => unreachable!(),
                // Different statement kinds.
                _ => false,
            };
            if !leaves_equal {
                return false;
            }
        }
    }
}
impl<T: std::fmt::Debug> std::fmt::Debug for IRStmt<T> {
    /// Formats only the wrapped statement kind; the metadata is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let inner: &dyn std::fmt::Debug = match &self.0 {
            IRStmtImpl::ConstraintCall(call) => call,
            IRStmtImpl::Constraint(constraint) => constraint,
            IRStmtImpl::Comment(comment) => comment,
            IRStmtImpl::AssumeDeterministic(assume_deterministic) => assume_deterministic,
            IRStmtImpl::Assert(assert) => assert,
            IRStmtImpl::PostCond(pc) => pc,
            IRStmtImpl::CondBlock(cb) => cb,
            IRStmtImpl::Seq(seq) => seq,
        };
        write!(f, "{inner:?}")
    }
}
impl<T> IRStmt<T> {
pub fn call(
callee: impl AsRef<str>,
inputs: impl IntoIterator<Item = T>,
outputs: impl IntoIterator<Item = Slot>,
) -> Self {
Call::new(callee, inputs, outputs).into()
}
pub fn post_cond(cond: IRBexpr<T>) -> Self {
PostCond::new(cond).into()
}
pub fn constraint(op: CmpOp, lhs: T, rhs: T) -> Self {
Constraint::new(op, lhs, rhs).into()
}
#[inline]
pub fn eq(lhs: T, rhs: T) -> Self {
Self::constraint(CmpOp::Eq, lhs, rhs)
}
#[inline]
pub fn lt(lhs: T, rhs: T) -> Self {
Self::constraint(CmpOp::Lt, lhs, rhs)
}
#[inline]
pub fn le(lhs: T, rhs: T) -> Self {
Self::constraint(CmpOp::Le, lhs, rhs)
}
#[inline]
pub fn gt(lhs: T, rhs: T) -> Self {
Self::constraint(CmpOp::Gt, lhs, rhs)
}
#[inline]
pub fn ge(lhs: T, rhs: T) -> Self {
Self::constraint(CmpOp::Ge, lhs, rhs)
}
pub fn comment(s: impl AsRef<str>) -> Self {
Comment::new(s).into()
}
pub fn assume_deterministic(f: impl Into<Slot>) -> Self {
AssumeDeterministic::new(f.into()).into()
}
pub fn assert(cond: IRBexpr<T>) -> Self {
Assert::new(cond).into()
}
pub fn seq<I>(stmts: impl IntoIterator<Item = IRStmt<I>>) -> Self
where
I: Into<T>,
{
Seq::new(stmts).into()
}
pub fn empty() -> Self {
Seq::empty().into()
}
pub fn is_empty(&self) -> bool {
match &self.0 {
IRStmtImpl::Seq(s) => s.is_empty(),
_ => false,
}
}
pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> IRStmt<O> {
match self.0 {
IRStmtImpl::ConstraintCall(call) => call.map(f).into(),
IRStmtImpl::Constraint(constraint) => constraint.map(f).into(),
IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
IRStmtImpl::Assert(assert) => assert.map(f).into(),
IRStmtImpl::PostCond(pc) => pc.map(f).into(),
IRStmtImpl::CondBlock(cb) => cb.map(f).into(),
IRStmtImpl::Seq(seq) => Seq::new(seq.into_iter().map(|s| s.map(f))).into(),
}
}
pub fn with<O>(self, other: O) -> IRStmt<(O, T)>
where
O: Clone,
{
self.map(&mut |t| (other.clone(), t))
}
pub fn with_fn<O>(self, other: impl Fn() -> O) -> IRStmt<(O, T)> {
self.map(&mut |t| (other(), t))
}
pub fn into<O>(self) -> IRStmt<O>
where
O: From<T> + Evaluate<ExprProperties>,
{
self.map(&mut Into::into)
}
pub fn from<O>(value: IRStmt<O>) -> Self
where
O: Into<T>,
{
value.map(&mut Into::into)
}
pub fn then(self, other: impl Into<Self>) -> Self {
match self.0 {
IRStmtImpl::Seq(mut seq) => {
seq.push(other.into());
seq.into()
}
this => Seq::new([Self(this, self.1), other.into()]).into(),
}
}
pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> IRStmt<O> {
match &self.0 {
IRStmtImpl::ConstraintCall(call) => call.map_into(f).into(),
IRStmtImpl::Constraint(constraint) => constraint.map_into(f).into(),
IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
IRStmtImpl::Assert(assert) => assert.map_into(f).into(),
IRStmtImpl::PostCond(pc) => pc.map_into(f).into(),
IRStmtImpl::CondBlock(cb) => cb.map_into(f).into(),
IRStmtImpl::Seq(seq) => Seq::new(seq.iter().map(|s| s.map_into(f))).into(),
}
}
pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<IRStmt<O>, E> {
Ok(match self.0 {
IRStmtImpl::ConstraintCall(call) => call.try_map(f)?.into(),
IRStmtImpl::Constraint(constraint) => constraint.try_map(f)?.into(),
IRStmtImpl::Comment(comment) => Comment::new(comment.value()).into(),
IRStmtImpl::AssumeDeterministic(ad) => AssumeDeterministic::new(ad.value()).into(),
IRStmtImpl::Assert(assert) => assert.try_map(f)?.into(),
IRStmtImpl::PostCond(pc) => pc.try_map(f)?.into(),
IRStmtImpl::CondBlock(cb) => cb.try_map(f)?.into(),
IRStmtImpl::Seq(seq) => Seq::new(
seq.into_iter()
.map(|s| s.try_map(f))
.collect::<Result<Vec<_>, _>>()?,
)
.into(),
})
}
pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
match &mut self.0 {
IRStmtImpl::ConstraintCall(call) => call.map_inplace(f),
IRStmtImpl::Constraint(constraint) => constraint.map_inplace(f),
IRStmtImpl::Assert(assert) => assert.map_inplace(f),
IRStmtImpl::PostCond(pc) => pc.map_inplace(f),
IRStmtImpl::CondBlock(cb) => cb.map_inplace(f),
IRStmtImpl::Seq(seq) => seq.iter_mut().for_each(|stmt| stmt.map_inplace(f)),
_ => {}
}
}
pub fn try_map_inplace<E>(
&mut self,
f: &mut impl FnMut(&mut T) -> Result<(), E>,
) -> Result<(), E> {
match &mut self.0 {
IRStmtImpl::ConstraintCall(call) => call.try_map_inplace(f),
IRStmtImpl::Constraint(constraint) => constraint.try_map_inplace(f),
IRStmtImpl::Assert(assert) => assert.try_map_inplace(f),
IRStmtImpl::PostCond(pc) => pc.try_map_inplace(f),
IRStmtImpl::CondBlock(cb) => cb.try_map_inplace(f),
IRStmtImpl::Seq(seq) => seq.iter_mut().try_for_each(|stmt| stmt.try_map_inplace(f)),
_ => Ok(()),
}
}
pub fn map_slot_inplace(&mut self, f: &mut impl FnMut(&mut Slot)) {
match &mut self.0 {
IRStmtImpl::ConstraintCall(call) => call.outputs_mut().iter_mut().for_each(f),
IRStmtImpl::AssumeDeterministic(det) => f(det.value_mut()),
IRStmtImpl::Seq(seq) => seq.iter_mut().for_each(|stmt| stmt.map_slot_inplace(f)),
_ => {}
}
}
pub fn try_map_slot_inplace<E>(
&mut self,
f: &mut impl FnMut(&mut Slot) -> Result<(), E>,
) -> Result<(), E> {
match &mut self.0 {
IRStmtImpl::ConstraintCall(call) => call.outputs_mut().iter_mut().try_for_each(f),
IRStmtImpl::AssumeDeterministic(det) => f(det.value_mut()),
IRStmtImpl::Seq(seq) => seq
.iter_mut()
.try_for_each(|stmt| stmt.try_map_slot_inplace(f)),
_ => Ok(()),
}
}
pub fn iter(&self) -> IRStmtRefIter<'_, T> {
IRStmtRefIter { stack: vec![self] }
}
pub fn iter_mut(&mut self) -> IRStmtRefMutIter<'_, T> {
IRStmtRefMutIter { stack: vec![self] }
}
pub fn propagate_meta(&mut self) {
match &mut self.0 {
IRStmtImpl::Seq(s) => {
for stmt in s.iter_mut() {
stmt.meta_mut().complete_with(self.1);
}
}
_ => {}
}
}
}
/// Constant folding over statements.
///
/// Folds the expressions inside each statement kind. Constraints, assertions, post-conditions
/// and conditional blocks may ask to be replaced wholesale by another statement, which happens
/// when their fold returns `Some(replacement)`. Comments and determinism assumptions are left
/// untouched.
impl<T> ConstantFolding for IRStmt<T>
where
    T: ConstantFolding + std::fmt::Debug + Clone,
    Error: From<T::Error>,
    T::T: Eq + Ord,
{
    type Error = Error;
    type T = ();
    fn constant_fold(&mut self) -> Result<(), Error> {
        match &mut self.0 {
            IRStmtImpl::ConstraintCall(call) => call.constant_fold()?,
            IRStmtImpl::Constraint(constraint) => {
                // The statement's metadata is passed to the fold, presumably so that a
                // replacement statement can carry it.
                if let Some(replacement) = constraint.constant_fold(self.1)? {
                    *self = replacement;
                }
            }
            IRStmtImpl::Comment(_) => {}
            IRStmtImpl::AssumeDeterministic(_) => {}
            IRStmtImpl::Assert(assert) => {
                if let Some(replacement) = assert.constant_fold(self.1)? {
                    *self = replacement;
                }
            }
            IRStmtImpl::PostCond(pc) => {
                if let Some(replacement) = pc.constant_fold(self.1)? {
                    *self = replacement;
                }
            }
            IRStmtImpl::CondBlock(cb) => {
                // NOTE(review): unlike the arms above, no metadata is passed here, so the
                // replacement may not carry this statement's metadata; confirm intended.
                if let Some(replacement) = cb.constant_fold()? {
                    *self = replacement;
                }
            }
            IRStmtImpl::Seq(seq) => seq.constant_fold()?,
        }
        Ok(())
    }
}
impl Canonicalize for IRStmt<IRAexpr> {
    /// Canonicalizes the expressions held by this statement in place. Comments and
    /// determinism assumptions hold no `IRAexpr` and are left as they are.
    fn canonicalize(&mut self) {
        let Self(kind, _) = self;
        match kind {
            IRStmtImpl::Comment(_) | IRStmtImpl::AssumeDeterministic(_) => {}
            IRStmtImpl::ConstraintCall(call) => call.canonicalize(),
            IRStmtImpl::Constraint(constraint) => constraint.canonicalize(),
            IRStmtImpl::Assert(assert) => assert.canonicalize(),
            IRStmtImpl::PostCond(pc) => pc.canonicalize(),
            IRStmtImpl::CondBlock(cb) => cb.canonicalize(),
            IRStmtImpl::Seq(seq) => seq.canonicalize(),
        }
    }
}
impl<T, D> Validatable for IRStmt<T>
where
    IRConstBexpr<T>: Validatable<Diagnostic = D, Context = ()>,
    D: Diagnostic,
{
    type Diagnostic = D;
    type Context = ();

    /// Validates the conditional blocks reachable from this statement.
    ///
    /// A `CondBlock` defers to its own validation, and a sequence validates each child and
    /// merges the results. Every other statement kind validates trivially.
    fn validate_with_context(
        &self,
        _: &Self::Context,
    ) -> Result<Vec<Self::Diagnostic>, Vec<Self::Diagnostic>> {
        let children = match &self.0 {
            IRStmtImpl::CondBlock(cond_block) => return cond_block.validate(),
            IRStmtImpl::Seq(seq) => seq,
            _ => return Validation::new().into(),
        };
        let mut validation = Validation::new();
        for child in children.iter() {
            validation.append_from_result(child.validate(), "");
        }
        validation.into()
    }
}
impl<L, R> EqvRelation<IRStmt<L>, IRStmt<R>> for SymbolicEqv
where
    SymbolicEqv: EqvRelation<L, R> + EqvRelation<Slot, Slot>,
{
    /// Symbolic equivalence on the flattened form of two statements.
    ///
    /// Both sides are walked with `IRStmt::iter`, which flattens nested sequences. They are
    /// equivalent only if both walks have the same number of leaves and each pair of leaves
    /// has the same kind and is equivalent. Comments must line up positionally but their text
    /// is ignored. Metadata is ignored.
    fn equivalent(lhs: &IRStmt<L>, rhs: &IRStmt<R>) -> bool {
        let mut lhs_leaves = lhs.iter();
        let mut rhs_leaves = rhs.iter();
        loop {
            let (lhs, rhs) = match (lhs_leaves.next(), rhs_leaves.next()) {
                (None, None) => return true,
                (Some(lhs), Some(rhs)) => (lhs, rhs),
                // One side has more leaf statements than the other. A plain `zip` would stop at
                // the shorter side here and wrongly report equivalence.
                (Some(_), None) | (None, Some(_)) => return false,
            };
            let leaves_equivalent = match (&lhs.0, &rhs.0) {
                (IRStmtImpl::ConstraintCall(lhs), IRStmtImpl::ConstraintCall(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                (IRStmtImpl::Constraint(lhs), IRStmtImpl::Constraint(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                (IRStmtImpl::Comment(_), IRStmtImpl::Comment(_)) => true,
                (IRStmtImpl::AssumeDeterministic(lhs), IRStmtImpl::AssumeDeterministic(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                (IRStmtImpl::Assert(lhs), IRStmtImpl::Assert(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                (IRStmtImpl::PostCond(lhs), IRStmtImpl::PostCond(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                (IRStmtImpl::CondBlock(lhs), IRStmtImpl::CondBlock(rhs)) => {
                    equiv! { SymbolicEqv | lhs, rhs }
                }
                // `iter` flattens sequences, so it never yields a `Seq` node.
                (IRStmtImpl::Seq(_), _) | (_, IRStmtImpl::Seq(_)) => unreachable!(),
                // Different statement kinds.
                _ => false,
            };
            if !leaves_equivalent {
                return false;
            }
        }
    }
}
/// Iterator over the leaf statements of an [`IRStmt`], borrowed immutably.
///
/// Sequences are flattened in order (depth-first, left to right). The sequence nodes
/// themselves are never yielded.
#[derive(Debug)]
pub struct IRStmtRefIter<'a, T> {
    /// Statements still to visit; the next one to visit is at the end.
    stack: Vec<&'a IRStmt<T>>,
}

impl<'a, T> Iterator for IRStmtRefIter<'a, T> {
    type Item = &'a IRStmt<T>;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let node = self.stack.pop()?;
            let IRStmtImpl::Seq(children) = &node.0 else {
                return Some(node);
            };
            // Pushed in reverse so the first child is popped (visited) first.
            self.stack.extend(children.iter().rev());
        }
    }
}
/// Iterator over the leaf statements of an [`IRStmt`], borrowed mutably.
///
/// Sequences are flattened in order (depth-first, left to right). The sequence nodes
/// themselves are never yielded.
#[derive(Debug)]
pub struct IRStmtRefMutIter<'a, T> {
    /// Statements still to visit; the next one to visit is at the end.
    stack: Vec<&'a mut IRStmt<T>>,
}
impl<'a, T> Iterator for IRStmtRefMutIter<'a, T> {
    type Item = &'a mut IRStmt<T>;
    fn next(&mut self) -> Option<Self::Item> {
        while let Some(node) = self.stack.pop() {
            if let IRStmt(IRStmtImpl::Seq(children), _) = node {
                // Pushed in reverse so the first child is popped (visited) first.
                self.stack.extend(children.iter_mut().rev());
            } else {
                return Some(node);
            }
        }
        None
    }
}
impl<T> Default for IRStmt<T> {
fn default() -> Self {
Self::empty()
}
}
/// Owning iterator over the leaf statements of an [`IRStmt`].
///
/// Sequences are flattened in order (depth-first, left to right). The sequence nodes and their
/// metadata are discarded; only leaf statements are yielded.
#[derive(Debug)]
pub struct IRStmtIter<T> {
    /// Statements still to visit; the next one to visit is at the end.
    stack: Vec<IRStmt<T>>,
}

impl<T> Iterator for IRStmtIter<T> {
    type Item = IRStmt<T>;

    fn next(&mut self) -> Option<Self::Item> {
        loop {
            match self.stack.pop()? {
                IRStmt(IRStmtImpl::Seq(children), _) => {
                    // Reversed so the first child ends up on top of the stack.
                    self.stack.extend(children.into_iter().rev())
                }
                leaf => break Some(leaf),
            }
        }
    }
}
impl<T> IntoIterator for IRStmt<T> {
type Item = Self;
type IntoIter = IRStmtIter<T>;
fn into_iter(self) -> Self::IntoIter {
IRStmtIter { stack: vec![self] }
}
}
impl<'a, T> IntoIterator for &'a IRStmt<T> {
type Item = Self;
type IntoIter = IRStmtRefIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
impl<'a, T> IntoIterator for &'a mut IRStmt<T> {
type Item = Self;
type IntoIter = IRStmtRefMutIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter_mut()
}
}
/// Collecting statements yields a single sequence statement containing them, in order.
impl<S> FromIterator<IRStmt<S>> for IRStmt<S> {
    fn from_iter<It: IntoIterator<Item = IRStmt<S>>>(stmts: It) -> Self {
        Self::seq(stmts)
    }
}
impl<T> From<Call<T>> for IRStmt<T> {
fn from(value: Call<T>) -> Self {
Self(IRStmtImpl::ConstraintCall(value), Default::default())
}
}
impl<T> From<Constraint<T>> for IRStmt<T> {
fn from(value: Constraint<T>) -> Self {
Self(IRStmtImpl::Constraint(value), Default::default())
}
}
impl<T> From<Comment> for IRStmt<T> {
fn from(value: Comment) -> Self {
Self(IRStmtImpl::Comment(value), Default::default())
}
}
impl<T> From<AssumeDeterministic> for IRStmt<T> {
fn from(value: AssumeDeterministic) -> Self {
Self(IRStmtImpl::AssumeDeterministic(value), Default::default())
}
}
impl<T> From<Assert<T>> for IRStmt<T> {
fn from(value: Assert<T>) -> Self {
Self(IRStmtImpl::Assert(value), Default::default())
}
}
impl<T> From<PostCond<T>> for IRStmt<T> {
fn from(value: PostCond<T>) -> Self {
Self(IRStmtImpl::PostCond(value), Default::default())
}
}
impl<T> From<CondBlock<T>> for IRStmt<T> {
fn from(value: CondBlock<T>) -> Self {
Self(IRStmtImpl::CondBlock(value), Default::default())
}
}
impl<T> From<Seq<T>> for IRStmt<T> {
fn from(value: Seq<T>) -> Self {
Self(IRStmtImpl::Seq(value), Default::default())
}
}
/// Error reported when a `CondBlock` statement reaches lowering.
///
/// Conditionally emitted blocks have to be resolved into ordinary statements before an
/// [`IRStmt`] is lowered. `LowerableStmt::lower` returns this error (wrapped by
/// `lowering_err!`) when one is still present.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct UnresolvedCondBlockError;

impl std::fmt::Display for UnresolvedCondBlockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("attempted to lower an unresolved conditionally emitted block")
    }
}

// `Debug` is already derived above, so no extra bound is needed.
impl std::error::Error for UnresolvedCondBlockError {}
impl<T: LowerableExpr> LowerableStmt for IRStmt<T> {
    /// Lowers this statement through `l` by delegating to the wrapped statement kind.
    ///
    /// # Errors
    ///
    /// Fails with [`UnresolvedCondBlockError`] (wrapped by `lowering_err!`) if the statement is
    /// a `CondBlock`. Also propagates any error from lowering the wrapped statement.
    fn lower<L>(self, l: &L) -> haloumi_lowering::Result<()>
    where
        L: Lowering + ?Sized,
    {
        let Self(kind, _) = self;
        match kind {
            IRStmtImpl::CondBlock(_) => Err(lowering_err!(UnresolvedCondBlockError)),
            IRStmtImpl::Seq(seq) => seq.lower(l),
            IRStmtImpl::ConstraintCall(call) => call.lower(l),
            IRStmtImpl::Constraint(constraint) => constraint.lower(l),
            IRStmtImpl::Comment(comment) => comment.lower(l),
            IRStmtImpl::AssumeDeterministic(ad) => ad.lower(l),
            IRStmtImpl::Assert(assert) => assert.lower(l),
            IRStmtImpl::PostCond(pc) => pc.lower(l),
        }
    }
}
impl<T: Clone> Clone for IRStmt<T> {
fn clone(&self) -> Self {
match &self.0 {
IRStmtImpl::ConstraintCall(call) => call.clone().into(),
IRStmtImpl::Constraint(c) => c.clone().into(),
IRStmtImpl::Comment(c) => c.clone().into(),
IRStmtImpl::AssumeDeterministic(func_io) => func_io.clone().into(),
IRStmtImpl::Assert(e) => e.clone().into(),
IRStmtImpl::PostCond(e) => e.clone().into(),
IRStmtImpl::CondBlock(e) => e.clone().into(),
IRStmtImpl::Seq(stmts) => stmts.clone().into(),
}
}
}
/// Pretty-prints a statement in the textual IR form through an `IRPrinterCtx`.
impl<T: IRPrintable> IRPrintable for IRStmt<T> {
    fn fmt(&self, ctx: &mut crate::printer::IRPrinterCtx<'_, '_>) -> crate::printer::Result {
        match &self.0 {
            // Calls use the context's dedicated call formatter.
            IRStmtImpl::ConstraintCall(call) => {
                ctx.fmt_call(call.callee(), call.inputs(), call.outputs(), None)
            }
            // Printed as an `assert/<op>` block. A newline precedes `lhs` when it is nested
            // (depth > 1), and separates the operands when either operand is nested.
            IRStmtImpl::Constraint(constraint) => {
                ctx.block(format!("assert/{}", constraint.op()).as_str(), |ctx| {
                    if constraint.lhs().depth() > 1 {
                        ctx.nl()?;
                    }
                    constraint.lhs().fmt(ctx)?;
                    if constraint.lhs().depth() > 1 || constraint.rhs().depth() > 1 {
                        ctx.nl()?;
                    }
                    constraint.rhs().fmt(ctx)
                })
            }
            // Comments start on a fresh line, prefixed with `; `.
            IRStmtImpl::Comment(comment) => {
                ctx.nl()?;
                writeln!(ctx, "; {}", comment.value())
            }
            IRStmtImpl::AssumeDeterministic(assume_deterministic) => ctx
                .list_nl("assume-deterministic", |ctx| {
                    assume_deterministic.value().fmt(ctx)
                }),
            IRStmtImpl::Assert(assert) => ctx.block("assert", |ctx| assert.cond().fmt(ctx)),
            // Sequences print their children back to back with no enclosing block.
            IRStmtImpl::Seq(seq) => {
                for stmt in seq.iter() {
                    stmt.fmt(ctx)?;
                }
                Ok(())
            }
            // `emit-if` block: the condition, a newline, then the body.
            IRStmtImpl::CondBlock(cb) => ctx.block("emit-if", |ctx| {
                cb.cond().fmt(ctx)?;
                ctx.nl()?;
                cb.body().fmt(ctx)
            }),
            IRStmtImpl::PostCond(post_cond) => {
                ctx.block("post-cond", |ctx| post_cond.cond().fmt(ctx))
            }
        }
    }
}
/// Type-erased error wrapper; the name suggests it is used by `try_map`-style conversions.
/// NOTE(review): nothing in this file constructs it, so confirm against callers.
///
/// `#[error(transparent)]` forwards `Display` and `source` to the wrapped error, and `#[from]`
/// provides `From<Box<dyn std::error::Error>>`. The boxed error is not required to be
/// `Send + Sync`.
#[derive(Debug, thiserror::Error)]
#[error(transparent)]
pub struct TryMapError(#[from] Box<dyn std::error::Error>);
#[cfg(test)]
mod test;