use eqv::{EqvRelation, equiv};
use haloumi_core::eqv::SymbolicEqv;
use haloumi_lowering::{
Lowering, Result as LoweringResult,
lowerable::{LowerableExpr, LowerableStmt},
};
use crate::{
error::Error,
expr::IRBexpr,
stmt::IRStmt,
traits::{Canonicalize, ConstantFolding},
};
/// A statement guarded by a boolean condition.
///
/// Pairs a condition expression with the statement it guards. The body is
/// stored behind a `Box`.
#[derive(Clone, PartialEq)]
pub struct CondBlock<T> {
/// Boolean expression guarding the body.
cond: IRBexpr<T>,
/// Statement guarded by `cond`.
body: Box<IRStmt<T>>,
}
impl<T> CondBlock<T> {
    /// Creates a conditional block that guards `body` with `cond`.
    ///
    /// The body is boxed internally.
    pub fn new(cond: IRBexpr<T>, body: IRStmt<T>) -> Self {
        Self {
            cond,
            body: Box::new(body),
        }
    }

    /// Returns a shared reference to the guard condition.
    pub fn cond(&self) -> &IRBexpr<T> {
        &self.cond
    }

    /// Returns a mutable reference to the guard condition.
    pub fn cond_mut(&mut self) -> &mut IRBexpr<T> {
        &mut self.cond
    }

    /// Returns a shared reference to the guarded statement.
    pub fn body(&self) -> &IRStmt<T> {
        &self.body
    }

    /// Returns a mutable reference to the guarded statement.
    pub fn body_mut(&mut self) -> &mut IRStmt<T> {
        &mut self.body
    }

    /// Consumes the block and converts it to a `CondBlock<O>` using `f`.
    ///
    /// `f` is handed to the condition first and to the body second.
    pub fn map<O>(self, f: &mut impl FnMut(T) -> O) -> CondBlock<O> {
        CondBlock {
            cond: IRBexpr::map(self.cond, f),
            body: Box::new(self.body.map(f)),
        }
    }

    /// Builds a `CondBlock<O>` from a borrowed block using `f`.
    ///
    /// `f` is handed to the condition first and to the body second.
    pub fn map_into<O>(&self, f: &mut impl FnMut(&T) -> O) -> CondBlock<O> {
        CondBlock {
            cond: IRBexpr::map_into(&self.cond, f),
            body: Box::new(self.body.map_into(f)),
        }
    }

    /// Fallible counterpart of [`CondBlock::map`].
    ///
    /// # Errors
    ///
    /// Returns the first error produced while mapping the condition or,
    /// after that succeeded, the body.
    pub fn try_map<O, E>(self, f: &mut impl FnMut(T) -> Result<O, E>) -> Result<CondBlock<O>, E> {
        Ok(CondBlock {
            cond: IRBexpr::try_map(self.cond, f)?,
            body: Box::new(self.body.try_map(f)?),
        })
    }

    /// Forwards `f` to the in-place mapping of the condition, then of the body.
    pub fn map_inplace(&mut self, f: &mut impl FnMut(&mut T)) {
        IRBexpr::map_inplace(&mut self.cond, f);
        self.body.map_inplace(f);
    }

    /// Fallible counterpart of [`CondBlock::map_inplace`].
    ///
    /// # Errors
    ///
    /// Returns the error from the condition if its mapping fails; in that case
    /// the body is not visited. Otherwise returns the result of mapping the body.
    pub fn try_map_inplace<E>(
        &mut self,
        f: &mut impl FnMut(&mut T) -> Result<(), E>,
    ) -> Result<(), E> {
        IRBexpr::try_map_inplace(&mut self.cond, f)?;
        self.body.try_map_inplace(f)
    }

    /// Folds constants in the body and the condition, then reports whether the
    /// whole block can be replaced by a simpler statement.
    ///
    /// Returns:
    /// - `Ok(Some(body))` when the folded condition is the constant `true`:
    ///   the guard is redundant, so the (already folded) body is the replacement.
    ///   The body is moved out of `self`, which is left holding
    ///   `IRStmt::empty()`; callers are expected to substitute the returned
    ///   statement for this block.
    /// - `Ok(Some(IRStmt::empty()))` when the folded condition is the constant
    ///   `false`.
    /// - `Ok(None)` when the condition has no constant value; the block stays
    ///   as it is (with folded sub-expressions).
    ///
    /// # Errors
    ///
    /// Propagates any error raised while folding the body or the condition.
    pub fn constant_fold(&mut self) -> Result<Option<IRStmt<T>>, Error>
    where
        IRBexpr<T>: ConstantFolding<T = bool>,
        IRStmt<T>: ConstantFolding<Error = Error>,
        Error: From<<IRBexpr<T> as ConstantFolding>::Error>,
    {
        self.body.constant_fold()?;
        self.cond.constant_fold()?;
        Ok(match self.cond.const_value() {
            // `T` carries no `Clone` bound here, so the body is moved out
            // rather than copied.
            Some(true) => Some(std::mem::replace(&mut *self.body, IRStmt::empty())),
            Some(false) => Some(IRStmt::empty()),
            None => None,
        })
    }

    /// Canonicalizes the condition and then the body, in place.
    pub fn canonicalize(&mut self)
    where
        IRBexpr<T>: Canonicalize,
        IRStmt<T>: Canonicalize,
    {
        self.cond.canonicalize();
        self.body.canonicalize();
    }
}
/// Symbolic equivalence for conditional blocks: two blocks are equivalent when
/// both their conditions and their bodies are symbolically equivalent.
impl<L, R> EqvRelation<CondBlock<L>, CondBlock<R>> for SymbolicEqv
where
    SymbolicEqv: EqvRelation<L, R>,
{
    fn equivalent(lhs: &CondBlock<L>, rhs: &CondBlock<R>) -> bool {
        // Both comparisons are evaluated before the results are combined.
        let conds_match = equiv! { SymbolicEqv | lhs.cond(), rhs.cond() };
        let bodies_match = equiv! { SymbolicEqv | lhs.body(), rhs.body() };
        conds_match && bodies_match
    }
}
/// Renders the block as `emit-if <cond> {`, a newline, the body, then ` }` and
/// a trailing newline.
///
/// The caller's formatter is handed straight to the condition and the body, so
/// flags such as `{:#?}` reach the nested values.
impl<T: std::fmt::Debug> std::fmt::Debug for CondBlock<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        use std::fmt::Debug;

        f.write_str("emit-if ")?;
        Debug::fmt(&self.cond, f)?;
        f.write_str(" {\n")?;
        Debug::fmt(&*self.body, f)?;
        f.write_str(" }\n")
    }
}
/// Lowering of a conditional block.
///
/// The body is skipped only when the condition's constant value is `false`;
/// for a constant `true` or a condition without a constant value the body is
/// lowered through `l`.
impl<T: LowerableExpr> LowerableStmt for CondBlock<T>
where
    IRBexpr<T>: ConstantFolding<T = bool>,
{
    fn lower<L>(self, l: &L) -> LoweringResult<()>
    where
        L: Lowering + ?Sized,
    {
        // Guard clause: a condition known to be false emits nothing.
        if matches!(self.cond.const_value(), Some(false)) {
            return Ok(());
        }
        self.body.lower(l)
    }
}
// Unit tests for `CondBlock`: validation of the condition and constant folding.
//
// NOTE(review): these tests call `CondBlock::validate` and convert the condition
// with `try_into()`; neither is defined in the visible part of this file, so they
// are presumably provided elsewhere in the crate — confirm.
#[cfg(test)]
mod tests {
use haloumi_core::{felt::Felt, slot::arg::ArgNo};
use rstest::{fixture, rstest};
use crate::expr::IRAexpr;
use super::*;
// Shorthands for the statement / boolean-expression types used by every test.
type Stmt = IRStmt<IRAexpr>;
type Bexpr = IRBexpr<IRAexpr>;
// Boolean expression built from the literal `true`.
fn true_bexpr() -> Bexpr {
true.into()
}
// Boolean expression built from the literal `false`.
fn false_bexpr() -> Bexpr {
false.into()
}
// Constant-only comparison `1 + 10 == 11`.
fn truthy_const_bexpr1() -> Bexpr {
let one = IRAexpr::constant(Felt::new(BabyBear::from(1)));
let ten = IRAexpr::constant(Felt::new(BabyBear::from(10)));
let eleven = IRAexpr::constant(Felt::new(BabyBear::from(11)));
Bexpr::eq(one + ten, eleven)
}
// Constant-only comparison `1 + 10 == 12`.
fn falshy_const_bexpr1() -> Bexpr {
let one = IRAexpr::constant(Felt::new(BabyBear::from(1)));
let ten = IRAexpr::constant(Felt::new(BabyBear::from(10)));
let twelve = IRAexpr::constant(Felt::new(BabyBear::from(12)));
Bexpr::eq(one + ten, twelve)
}
// Comparison `1 + arg0 == 11`, where `arg0` is a slot built from `ArgNo::from(0)`
// rather than a constant.
fn non_const_bexpr1() -> Bexpr {
let one = IRAexpr::constant(Felt::new(BabyBear::from(1)));
let arg0 = IRAexpr::slot(ArgNo::from(0));
let eleven = IRAexpr::constant(Felt::new(BabyBear::from(11)));
Bexpr::eq(one + arg0, eleven)
}
// Body used by the tests: a comment statement followed by an assertion on
// `truthy_const_bexpr1()`.
fn test_stmt() -> Stmt {
Stmt::comment("There is an assertion below").then(Stmt::assert(truthy_const_bexpr1()))
}
// rstest fixture handing `test_stmt()` to tests that take a `stmt` argument.
#[fixture]
fn stmt() -> Stmt {
test_stmt()
}
// `validate` must succeed for the constant conditions.
// The `should_panic` attribute sits directly before the non-constant case;
// NOTE(review): per rstest's per-case attribute rule it should apply to that
// case only — confirm against the rstest documentation.
#[rstest]
#[case(true_bexpr())]
#[case(false_bexpr())]
#[case(truthy_const_bexpr1())]
#[case(falshy_const_bexpr1())]
#[should_panic(expected = "NonConstIRBexprError")]
#[case(non_const_bexpr1())]
fn const_block_validation(stmt: Stmt, #[case] cond: Bexpr) {
let cb = CondBlock::new(cond.try_into().unwrap(), stmt);
let _ = cb.validate().unwrap();
}
// A block that validates initially is expected to fail validation once
// `map_inplace` has overwritten the expressions it visits with a slot reference.
#[rstest]
#[case(truthy_const_bexpr1())]
#[case(falshy_const_bexpr1())]
fn const_block_validation_after_map(stmt: Stmt, #[case] cond: Bexpr) {
let mut cb = CondBlock::new(cond.try_into().unwrap(), stmt);
let _ = cb.validate().unwrap();
cb.map_inplace(&mut |e| {
*e = IRAexpr::slot(ArgNo::from(1));
});
let res = cb.validate();
assert!(res.is_err());
}
// `constant_fold` is expected to return a replacement statement for every
// constant condition: `test_stmt()` for the true cases and `Stmt::empty()` for
// the false cases (the double `unwrap` requires `Ok(Some(_))`).
#[rstest]
#[case(true_bexpr(), test_stmt())]
#[case(false_bexpr(), Stmt::empty())]
#[case(truthy_const_bexpr1(), test_stmt())]
#[case(falshy_const_bexpr1(), Stmt::empty())]
fn const_block_folding(stmt: Stmt, #[case] cond: Bexpr, #[case] expected: Stmt) {
let mut cb = CondBlock::new(cond.try_into().unwrap(), stmt);
let _ = cb.validate().unwrap();
let folded = cb.constant_fold().unwrap().unwrap();
assert_eq!(folded, expected);
}
use ff::PrimeField;
// Test-only prime field used to build `Felt` constants: modulus 2013265921,
// generator 31, little-endian representation (all taken from the attributes below).
#[derive(PrimeField)]
#[PrimeFieldModulus = "2013265921"]
#[PrimeFieldGenerator = "31"]
#[PrimeFieldReprEndianness = "little"]
pub struct BabyBear([u64; 1]);
}