use melior::{
Context,
ir::{
Attribute, Block, BlockLike as _, Identifier, Location, RegionLike as _, Type,
operation::OperationLike as _, r#type::FunctionType,
},
};
use crate::{
attributes::NamedAttribute,
builder::{EntryPoint, OpBuilder},
dialect,
error::Error,
prelude::{
FUNC_NAME_COMPUTE, FUNC_NAME_CONSTRAIN, FUNC_NAME_PRODUCT, FeltType, FuncDefOp,
FuncDefOpLike as _, StructDefOp,
},
};
use super::r#type::StructType;
pub fn compute_fn<'c>(
loc: Location<'c>,
struct_type: StructType<'c>,
inputs: &[(Type<'c>, Location<'c>)],
arg_attrs: Option<&[Vec<NamedAttribute<'c>>]>,
) -> Result<FuncDefOp<'c>, Error> {
let context = loc.context();
let input_types: Vec<Type<'c>> = inputs.iter().map(|(t, _)| *t).collect();
dialect::function::def(
loc,
FUNC_NAME_COMPUTE.as_ref(),
FunctionType::new(
unsafe { context.to_ref() },
&input_types,
&[struct_type.into()],
),
&[],
arg_attrs,
)
.and_then(|f| {
let block = Block::new(inputs);
let new_struct = block.append_operation(super::new(loc, struct_type));
block.append_operation(dialect::function::r#return(
loc,
&[new_struct.result(0)?.into()],
));
f.set_allow_witness_attr(true);
f.set_allow_non_native_field_ops_attr(true);
f.region(0)?.append_block(block);
Ok(f)
})
}
pub fn constrain_fn<'c>(
loc: Location<'c>,
struct_type: StructType<'c>,
inputs: &[(Type<'c>, Location<'c>)],
arg_attrs: Option<&[Vec<NamedAttribute<'c>>]>,
) -> Result<FuncDefOp<'c>, Error> {
let context = loc.context();
let mut input_types: Vec<Type<'c>> = Vec::with_capacity(inputs.len() + 1);
input_types.push(struct_type.into());
input_types.extend(inputs.iter().map(|(t, _)| *t));
let mut all_inputs = vec![(struct_type.into(), loc)];
all_inputs.extend(inputs);
let all_arg_attrs = arg_attrs.map(|original| {
let mut result: Vec<Vec<(Identifier<'_>, Attribute<'_>)>> = vec![vec![]];
result.extend(original.iter().cloned());
result
});
dialect::function::def(
loc,
FUNC_NAME_CONSTRAIN.as_ref(),
FunctionType::new(unsafe { context.to_ref() }, &input_types, &[]),
&[],
all_arg_attrs.as_deref(),
)
.and_then(|f| {
let block = Block::new(&all_inputs);
block.append_operation(dialect::function::r#return(loc, &[]));
f.set_allow_constraint_attr(true);
f.set_allow_non_native_field_ops_attr(true);
f.region(0)?.append_block(block);
Ok(f)
})
}
pub fn product_fn<'c>(
loc: Location<'c>,
struct_type: StructType<'c>,
inputs: &[(Type<'c>, Location<'c>)],
arg_attrs: Option<&[Vec<NamedAttribute<'c>>]>,
) -> Result<FuncDefOp<'c>, Error> {
let context = loc.context();
let input_types: Vec<Type<'c>> = inputs.iter().map(|(t, _)| *t).collect();
dialect::function::def(
loc,
FUNC_NAME_PRODUCT.as_ref(),
FunctionType::new(
unsafe { context.to_ref() },
&input_types,
&[struct_type.into()],
),
&[],
arg_attrs,
)
.and_then(|f| {
let block = Block::new(inputs);
let new_struct = block.append_operation(super::new(loc, struct_type));
block.append_operation(dialect::function::r#return(
loc,
&[new_struct.result(0)?.into()],
));
f.set_allow_constraint_attr(true);
f.set_allow_witness_attr(true);
f.set_allow_non_native_field_ops_attr(true);
f.region(0)?.append_block(block);
Ok(f)
})
}
#[deprecated]
pub fn define_signal_struct<'c>(context: &'c Context) -> Result<StructDefOp<'c>, Error> {
let loc = Location::new(context, "Signal struct", 0, 0);
let typ = StructType::from_str(context, "Signal");
let reg = "reg";
super::def(loc, "Signal", {
[
super::member(loc, reg, FeltType::new(context), false, true).map(Into::into),
compute_fn(loc, typ, &[(FeltType::new(context).into(), loc)], None)
.and_then(|compute| {
let block = compute
.region(0)?
.first_block()
.ok_or(Error::BlockExpected(0))?;
let fst = block.first_operation().ok_or(Error::EmptyBlock)?;
if fst.name() != Identifier::new(context, "struct.new") {
return Err(Error::OperationExpected(
"struct.new",
fst.name().as_string_ref().as_str()?.to_owned(),
));
}
block.insert_operation_after(
fst,
super::writem(loc, fst.result(0)?.into(), reg, block.argument(0)?.into())?,
);
Ok(compute)
})
.map(Into::into),
constrain_fn(loc, typ, &[(FeltType::new(context).into(), loc)], None)
.and_then(|constrain| {
let block = constrain
.region(0)?
.first_block()
.ok_or(Error::BlockExpected(0))?;
let fst = block.first_operation().ok_or(Error::EmptyBlock)?;
if fst.name() != Identifier::new(context, "function.return") {
return Err(Error::OperationExpected(
"function.return",
fst.name().as_string_ref().as_str()?.to_owned(),
));
}
let reg = block.insert_operation_before(
fst,
super::readm(
&OpBuilder::new(context, EntryPoint::Before(fst)),
loc,
FeltType::new(context).into(),
block.argument(0)?.into(),
"reg",
)?,
);
block.insert_operation_after(
reg,
dialect::constrain::eq(
loc,
reg.result(0)?.into(),
block.argument(1)?.into(),
),
);
Ok(constrain)
})
.map(Into::into),
]
})
}