use std::sync::Arc;
use super::vars::{NamingConvention, VarKey, VarKeySeed};
use crate::pcl::{ModuleLike as _, expr, stmt};
use haloumi_core::{
cmp::CmpOp,
felt::Felt,
slot::{Slot as FuncIO, arg::ArgNo, output::OutputId as FieldId},
};
use haloumi_ir::Prime;
use haloumi_lowering::{ExprLowering, Lowering, Result};
pub type PicusModuleRef = crate::pcl::ModuleRef<VarKey>;
pub(super) type PicusExpr = crate::pcl::expr::Expr;
/// Lowers expressions, constraints and statements into a single Picus module.
///
/// Holds three things:
/// - a shared handle to the target module, which the `Lowering` methods write to via `borrow_mut()`;
/// - the naming convention used when creating variables for function inputs/outputs;
/// - the prime field whose constants are used when building literal expressions.
#[derive(Debug, Clone)]
pub struct PicusModuleLowering {
    module: PicusModuleRef,
    naming_convention: NamingConvention,
    prime: Prime,
}

impl PicusModuleLowering {
    /// Creates a lowering that writes into `module`.
    ///
    /// Variables for function I/O are named according to `naming_convention`.
    /// Constants are drawn from `prime`.
    pub(crate) fn new(
        module: PicusModuleRef,
        naming_convention: NamingConvention,
        prime: Prime,
    ) -> Self {
        PicusModuleLowering {
            module,
            naming_convention,
            prime,
        }
    }

    /// Returns the Picus variable expression for a function input/output slot.
    ///
    /// The variable is obtained through `expr::var` on this lowering's module. It is keyed by
    /// a [`VarKeySeed::io`] seed built from `func_io` and the configured naming convention.
    pub(crate) fn lower_func_io(&self, func_io: FuncIO) -> PicusExpr {
        expr::var(&self.module, VarKeySeed::io(func_io, self.naming_convention))
    }
}
fn err(e: anyhow::Error) -> haloumi_lowering::error::Error {
let b: Box<dyn std::error::Error> = e.into();
let a: Arc<dyn std::error::Error> = b.into();
haloumi_lowering::error::Error::Backend(a)
}
impl Lowering for PicusModuleLowering {
    /// Adds the constraint `lhs <op> rhs` to the module's constraint list.
    ///
    /// Each [`CmpOp`] maps to the matching `expr` comparison builder. `Ne` uses
    /// `expr::ne`, the same builder `lower_ne` uses, so a `!=` constraint no longer
    /// aborts lowering with a panic.
    fn generate_constraint(
        &self,
        op: CmpOp,
        lhs: &Self::CellOutput,
        rhs: &Self::CellOutput,
    ) -> Result<()> {
        // Build the expression before taking the mutable borrow so the borrow
        // is held only for the insertion itself.
        let constraint = match op {
            CmpOp::Eq => expr::eq(lhs, rhs),
            CmpOp::Lt => expr::lt(lhs, rhs),
            CmpOp::Le => expr::le(lhs, rhs),
            CmpOp::Gt => expr::gt(lhs, rhs),
            CmpOp::Ge => expr::ge(lhs, rhs),
            CmpOp::Ne => expr::ne(lhs, rhs),
        };
        self.module.borrow_mut().add_constraint(constraint);
        Ok(())
    }

    /// Returns the number of constraints currently recorded in the module
    /// (as reported by its `constraints_len`).
    fn num_constraints(&self) -> usize {
        self.module.constraints_len()
    }

    /// Appends a comment statement containing `s` to the module.
    fn generate_comment(&self, s: String) -> Result<()> {
        self.module.borrow_mut().add_stmt(stmt::comment(s));
        Ok(())
    }

    /// Appends a call statement invoking `name`.
    ///
    /// `inputs` are passed as the call's arguments. Each entry of `outputs` is lowered
    /// to its module variable with `lower_func_io`.
    ///
    /// # Errors
    /// Returns a backend error if `stmt::call` rejects the call.
    fn generate_call(
        &self,
        name: &str,
        inputs: &[Self::CellOutput],
        outputs: &[FuncIO],
    ) -> Result<()> {
        let outputs = outputs
            .iter()
            .copied()
            .map(|o| self.lower_func_io(o))
            .collect();
        let stmt = stmt::call(name.to_owned(), inputs.to_vec(), outputs).map_err(err)?;
        self.module.borrow_mut().add_stmt(stmt);
        Ok(())
    }

    /// Appends an assume-deterministic statement for the variable backing `func_io`.
    ///
    /// # Errors
    /// Returns a backend error if `stmt::assume_deterministic` fails.
    fn generate_assume_deterministic(&self, func_io: FuncIO) -> Result<()> {
        let stmt = stmt::assume_deterministic(self.lower_func_io(func_io)).map_err(err)?;
        self.module.borrow_mut().add_stmt(stmt);
        Ok(())
    }

    /// Appends a `constrain` statement over `expr`.
    ///
    /// The statement goes into the module's statement list, not into the constraint list
    /// used by `generate_constraint`.
    fn generate_assert(&self, expr: &Self::CellOutput) -> Result<()> {
        let stmt = stmt::constrain(expr.clone());
        self.module.borrow_mut().add_stmt(stmt);
        Ok(())
    }

    /// Appends a post-condition statement over `expr` to the module.
    fn generate_post_condition(&self, expr: &Self::CellOutput) -> Result<()> {
        let stmt = stmt::post_condition(expr.clone());
        self.module.borrow_mut().add_stmt(stmt);
        Ok(())
    }
}
impl ExprLowering for PicusModuleLowering {
type CellOutput = PicusExpr;
fn lower_sum(
&self,
lhs: &Self::CellOutput,
rhs: &Self::CellOutput,
) -> Result<Self::CellOutput> {
Ok(expr::add(lhs, rhs))
}
fn lower_product(
&self,
lhs: &Self::CellOutput,
rhs: &Self::CellOutput,
) -> Result<Self::CellOutput> {
Ok(expr::mul(lhs, rhs))
}
fn lower_neg(&self, expr: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::neg(expr))
}
fn lower_constant(&self, f: Felt) -> Result<Self::CellOutput> {
let expr = expr::r#const(f);
log::debug!(
"[PicusBackend::lower_constant] Constant value {f:?} becomes expression {expr:?}"
);
Ok(expr)
}
fn lower_eq(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::eq(lhs, rhs))
}
fn lower_and(
&self,
lhs: &Self::CellOutput,
rhs: &Self::CellOutput,
) -> Result<Self::CellOutput> {
Ok(expr::and(lhs, rhs))
}
fn lower_or(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::or(lhs, rhs))
}
fn lower_function_input(&self, i: usize) -> FuncIO {
ArgNo::from(i).into()
}
fn lower_function_output(&self, o: usize) -> FuncIO {
FieldId::from(o).into()
}
fn lower_funcio<IO>(&self, io: IO) -> Result<Self::CellOutput>
where
IO: Into<FuncIO>,
{
Ok(self.lower_func_io(io.into()))
}
fn lower_lt(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::lt(lhs, rhs))
}
fn lower_le(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::le(lhs, rhs))
}
fn lower_gt(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::gt(lhs, rhs))
}
fn lower_ge(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::ge(lhs, rhs))
}
fn lower_ne(&self, lhs: &Self::CellOutput, rhs: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::ne(lhs, rhs))
}
fn lower_not(&self, value: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::not(value))
}
fn lower_true(&self) -> Result<Self::CellOutput> {
Ok(expr::eq(
&expr::r#const(self.prime.zero()),
&expr::r#const(self.prime.zero()),
))
}
fn lower_false(&self) -> Result<Self::CellOutput> {
Ok(expr::eq(
&expr::r#const(self.prime.zero()),
&expr::r#const(self.prime.one()),
))
}
fn lower_det(&self, expr: &Self::CellOutput) -> Result<Self::CellOutput> {
Ok(expr::det(expr))
}
fn lower_implies(
&self,
lhs: &Self::CellOutput,
rhs: &Self::CellOutput,
) -> Result<Self::CellOutput> {
Ok(expr::implies(lhs, rhs))
}
fn lower_iff(
&self,
lhs: &Self::CellOutput,
rhs: &Self::CellOutput,
) -> Result<Self::CellOutput> {
Ok(expr::iff(lhs, rhs))
}
}