use crate::{
attributes::NamedAttribute,
builder::{OpBuilder, OpBuilderLike as _},
dialect::r#struct::StructType,
error::Error,
macros::llzk_op_type,
symbol_ref::{SymbolRefAttrLike, SymbolRefAttribute},
};
use llzk_sys::{
llzkFunction_CallOpBuild, llzkFunction_CallOpCalleeIsCompute,
llzkFunction_CallOpCalleeIsConstrain, llzkFunction_CallOpCalleeIsStructCompute,
llzkFunction_CallOpCalleeIsStructConstrain, llzkFunction_CallOpGetSelfValueFromCompute,
llzkFunction_CallOpGetSelfValueFromConstrain, llzkFunction_FuncDefOpCreateWithAttrsAndArgAttrs,
llzkFunction_FuncDefOpGetFullyQualifiedName, llzkFunction_FuncDefOpGetSelfValueFromCompute,
llzkFunction_FuncDefOpGetSelfValueFromConstrain,
llzkFunction_FuncDefOpGetSingleResultTypeOfCompute,
llzkFunction_FuncDefOpHasAllowConstraintAttr,
llzkFunction_FuncDefOpHasAllowNonNativeFieldOpsAttr, llzkFunction_FuncDefOpHasAllowWitnessAttr,
llzkFunction_FuncDefOpHasArgPublicAttr, llzkFunction_FuncDefOpIsInStruct,
llzkFunction_FuncDefOpIsStructCompute, llzkFunction_FuncDefOpIsStructConstrain,
llzkFunction_FuncDefOpNameIsCompute, llzkFunction_FuncDefOpNameIsConstrain,
llzkFunction_FuncDefOpSetAllowConstraintAttr,
llzkFunction_FuncDefOpSetAllowNonNativeFieldOpsAttr, llzkFunction_FuncDefOpSetAllowWitnessAttr,
llzkOperationIsA_Function_CallOp, llzkOperationIsA_Function_FuncDefOp,
};
use melior::{
Context, StringRef,
ir::{
Attribute, AttributeLike, BlockLike as _, Location, Operation, RegionLike as _, Type,
TypeLike, Value,
attribute::{ArrayAttribute, TypeAttribute},
block::BlockArgument,
operation::{OperationBuilder, OperationLike, OperationMutLike},
r#type::FunctionType,
},
};
use mlir_sys::{MlirAttribute, MlirNamedAttribute, mlirDictionaryAttrGet, mlirNamedAttributeGet};
use std::ptr::null;
fn create_out_of_bounds_error<'c: 'a, 'a>(
func: &(impl FuncDefOpLike<'c, 'a> + ?Sized),
idx: usize,
) -> Error {
match SymbolRefAttribute::try_from(func.fully_qualified_name()) {
Ok(fqn) => Error::OutOfBoundsArgument(Some(fqn.to_string()), idx),
Err(err) => err.into(),
}
}
pub trait FuncDefOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
fn has_allow_constraint_attr(&self) -> bool {
unsafe { llzkFunction_FuncDefOpHasAllowConstraintAttr(self.to_raw()) }
}
fn set_allow_constraint_attr(&self, value: bool) {
unsafe { llzkFunction_FuncDefOpSetAllowConstraintAttr(self.to_raw(), value) }
}
fn has_allow_witness_attr(&self) -> bool {
unsafe { llzkFunction_FuncDefOpHasAllowWitnessAttr(self.to_raw()) }
}
fn set_allow_witness_attr(&self, value: bool) {
unsafe { llzkFunction_FuncDefOpSetAllowWitnessAttr(self.to_raw(), value) }
}
fn has_allow_non_native_field_ops_attr(&self) -> bool {
unsafe { llzkFunction_FuncDefOpHasAllowNonNativeFieldOpsAttr(self.to_raw()) }
}
fn set_allow_non_native_field_ops_attr(&self, value: bool) {
unsafe { llzkFunction_FuncDefOpSetAllowNonNativeFieldOpsAttr(self.to_raw(), value) }
}
fn arg_is_pub(&self, idx: u32) -> bool {
unsafe { llzkFunction_FuncDefOpHasArgPublicAttr(self.to_raw(), idx) }
}
fn fully_qualified_name(&self) -> Attribute<'c> {
unsafe {
Attribute::from_raw(llzkFunction_FuncDefOpGetFullyQualifiedName(
self.to_raw(),
false,
))
}
}
fn name_is_compute(&self) -> bool {
unsafe { llzkFunction_FuncDefOpNameIsCompute(self.to_raw()) }
}
fn name_is_constrain(&self) -> bool {
unsafe { llzkFunction_FuncDefOpNameIsConstrain(self.to_raw()) }
}
fn is_in_struct(&self) -> bool {
unsafe { llzkFunction_FuncDefOpIsInStruct(self.to_raw()) }
}
fn is_struct_compute(&self) -> bool {
unsafe { llzkFunction_FuncDefOpIsStructCompute(self.to_raw()) }
}
fn is_struct_constrain(&self) -> bool {
unsafe { llzkFunction_FuncDefOpIsStructConstrain(self.to_raw()) }
}
fn self_value_of_compute(&self) -> Result<Value<'c, 'a>, Error> {
if self.name_is_compute() {
Ok(unsafe {
Value::from_raw(llzkFunction_FuncDefOpGetSelfValueFromCompute(self.to_raw()))
})
} else {
Err(Error::ExpectedFunctionName(&llzk_sys::FUNC_NAME_COMPUTE))
}
}
fn self_value_of_constrain(&self) -> Result<Value<'c, 'a>, Error> {
if self.name_is_constrain() {
Ok(unsafe {
Value::from_raw(llzkFunction_FuncDefOpGetSelfValueFromConstrain(
self.to_raw(),
))
})
} else {
Err(Error::ExpectedFunctionName(&llzk_sys::FUNC_NAME_CONSTRAIN))
}
}
fn result_type_of_compute(&self) -> StructType<'c> {
unsafe {
Type::from_raw(llzkFunction_FuncDefOpGetSingleResultTypeOfCompute(
self.to_raw(),
))
}
.try_into()
.expect("struct type")
}
fn argument(&self, idx: usize) -> Result<BlockArgument<'c, 'a>, Error> {
self.region(0)
.map_err(Into::into)
.and_then(|region| {
region
.first_block()
.ok_or(create_out_of_bounds_error(self, idx))
})
.and_then(|block| block.argument(idx).map_err(Into::into))
}
fn argument_attr(&self, idx: usize, name: &str) -> Result<Attribute<'c>, Error> {
let arg_attrs: ArrayAttribute = self.attribute("arg_attrs")?.try_into()?;
let arg = arg_attrs.element(idx)?;
let name_ref = StringRef::new(name);
unsafe {
Attribute::from_option_raw(mlir_sys::mlirDictionaryAttrGetElementByName(
arg.to_raw(),
name_ref.to_raw(),
))
}
.ok_or_else(|| Error::AttributeNotFound(name.to_string()))
}
fn get_function_type_attribute(&self) -> Result<FunctionType<'c>, Error> {
let attr = self.attribute("function_type")?;
let type_attr: TypeAttribute<'c> = attr.try_into()?;
let func_type: FunctionType<'c> = type_attr.value().try_into()?;
Ok(func_type)
}
}
/// Marker trait for types that are both [`FuncDefOpLike`] and
/// [`OperationMutLike`]; it declares no methods of its own.
pub trait FuncDefOpMutLike<'c: 'a, 'a>: FuncDefOpLike<'c, 'a> + OperationMutLike<'c, 'a> {}
// Invokes `llzk_op_type!` for `FuncDefOp`, pairing it with the
// `llzkOperationIsA_Function_FuncDefOp` predicate and the "function.def" name.
// NOTE(review): presumably the macro also generates `FuncDefOpRef` and
// `FuncDefOpRefMut`, which this file implements traits for — confirm against
// the macro definition.
llzk_op_type!(
FuncDefOp,
llzkOperationIsA_Function_FuncDefOp,
"function.def"
);
// All three `FuncDefOp` wrapper types get the `FuncDefOpLike` accessors through
// the trait's provided methods; only the owned and mutable-reference wrappers
// are additionally marked as `FuncDefOpMutLike`.
impl<'a, 'c: 'a> FuncDefOpLike<'c, 'a> for FuncDefOp<'c> {}
impl<'a, 'c: 'a> FuncDefOpLike<'c, 'a> for FuncDefOpRef<'c, 'a> {}
impl<'a, 'c: 'a> FuncDefOpLike<'c, 'a> for FuncDefOpRefMut<'c, 'a> {}
impl<'a, 'c: 'a> FuncDefOpMutLike<'c, 'a> for FuncDefOp<'c> {}
impl<'a, 'c: 'a> FuncDefOpMutLike<'c, 'a> for FuncDefOpRefMut<'c, 'a> {}
pub trait CallOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
fn callee_is_compute(&self) -> bool {
unsafe { llzkFunction_CallOpCalleeIsCompute(self.to_raw()) }
}
fn callee_is_constrain(&self) -> bool {
unsafe { llzkFunction_CallOpCalleeIsConstrain(self.to_raw()) }
}
fn callee_is_struct_compute(&self) -> bool {
unsafe { llzkFunction_CallOpCalleeIsStructCompute(self.to_raw()) }
}
fn callee_is_struct_constrain(&self) -> bool {
unsafe { llzkFunction_CallOpCalleeIsStructConstrain(self.to_raw()) }
}
fn self_value_of_compute(&self) -> Result<Value<'c, 'a>, Error> {
if self.callee_is_compute() {
Ok(unsafe {
Value::from_raw(llzkFunction_CallOpGetSelfValueFromCompute(self.to_raw()))
})
} else {
Err(Error::ExpectedFunctionName(&llzk_sys::FUNC_NAME_COMPUTE))
}
}
fn self_value_of_constrain(&self) -> Result<Value<'c, 'a>, Error> {
if self.callee_is_constrain() {
Ok(unsafe {
Value::from_raw(llzkFunction_CallOpGetSelfValueFromConstrain(self.to_raw()))
})
} else {
Err(Error::ExpectedFunctionName(&llzk_sys::FUNC_NAME_CONSTRAIN))
}
}
}
// Invokes `llzk_op_type!` for `CallOp`, pairing it with the
// `llzkOperationIsA_Function_CallOp` predicate and the "function.call" name.
// NOTE(review): presumably the macro also generates `CallOpRef` and
// `CallOpRefMut`, which this file implements `CallOpLike` for — confirm
// against the macro definition.
llzk_op_type!(CallOp, llzkOperationIsA_Function_CallOp, "function.call");
// All three `CallOp` wrapper types get the `CallOpLike` accessors through the
// trait's provided methods, so the impl bodies are empty.
impl<'a, 'c: 'a> CallOpLike<'c, 'a> for CallOp<'c> {}
impl<'a, 'c: 'a> CallOpLike<'c, 'a> for CallOpRef<'c, 'a> {}
impl<'a, 'c: 'a> CallOpLike<'c, 'a> for CallOpRefMut<'c, 'a> {}
/// Converts a `(name, attribute)` pair into a raw `MlirNamedAttribute` via
/// `mlirNamedAttributeGet`.
fn tuple_to_named_attr(named: &NamedAttribute) -> MlirNamedAttribute {
    let (name, attr) = named;
    unsafe { mlirNamedAttributeGet(name.to_raw(), attr.to_raw()) }
}
/// Builds one raw dictionary attribute per function input.
///
/// * With `arg_attrs == None`, returns `input_count` copies of a single empty
///   dictionary attribute.
/// * Otherwise, turns each inner list of named attributes into its own
///   dictionary attribute, preserving order.
///
/// The inputs are logged at debug level before any work is done.
///
/// # Panics
///
/// Panics if `arg_attrs` is provided and its length differs from `input_count`.
fn prepare_arg_attrs<'c>(
    arg_attrs: Option<&[Vec<NamedAttribute<'c>>]>,
    input_count: usize,
    ctx: &'c Context,
) -> Vec<MlirAttribute> {
    log::debug!("prepare_arg_attrs(\n{arg_attrs:?},\n{input_count},\n{ctx:?})");
    match arg_attrs {
        None => {
            // One empty dictionary, replicated for every input.
            let empty_dict = unsafe { mlirDictionaryAttrGet(ctx.to_raw(), 0, null()) };
            vec![empty_dict; input_count]
        }
        Some(per_arg) => {
            assert_eq!(per_arg.len(), input_count);
            let mut dicts = Vec::with_capacity(per_arg.len());
            for entries in per_arg {
                // `raw_entries` must stay alive until the C call has read the pointer.
                let raw_entries = Vec::from_iter(entries.iter().map(tuple_to_named_attr));
                let dict = unsafe {
                    mlirDictionaryAttrGet(
                        ctx.to_raw(),
                        raw_entries.len() as isize,
                        raw_entries.as_ptr(),
                    )
                };
                dicts.push(dict);
            }
            dicts
        }
    }
}
/// Creates a function-definition operation through
/// `llzkFunction_FuncDefOpCreateWithAttrsAndArgAttrs`.
///
/// * `location` – location of the new op; its context is also used to build
///   the argument attribute dictionaries.
/// * `name` – function name.
/// * `r#type` – function type; its input count determines how many argument
///   attribute dictionaries are produced.
/// * `attrs` – named attributes passed for the op itself.
/// * `arg_attrs` – optional per-argument named attributes, one list per input.
///
/// # Errors
///
/// Returns the error reported when the created operation cannot be converted
/// into a [`FuncDefOp`].
///
/// # Panics
///
/// Panics if `arg_attrs` is provided and its length differs from the number of
/// inputs of `r#type` (see `prepare_arg_attrs`).
pub fn def<'c>(
    location: Location<'c>,
    name: &str,
    r#type: FunctionType<'c>,
    attrs: &[NamedAttribute<'c>],
    arg_attrs: Option<&[Vec<NamedAttribute<'c>>]>,
) -> Result<FuncDefOp<'c>, Error> {
    let context = location.context();
    let name_ref = StringRef::new(name);
    let raw_attrs = Vec::from_iter(attrs.iter().map(tuple_to_named_attr));
    let raw_arg_attrs =
        prepare_arg_attrs(arg_attrs, r#type.input_count(), unsafe { context.to_ref() });
    // Both vectors outlive the C call below, so the pointers stay valid for it.
    let operation = unsafe {
        Operation::from_raw(llzkFunction_FuncDefOpCreateWithAttrsAndArgAttrs(
            location.to_raw(),
            name_ref.to_raw(),
            r#type.to_raw(),
            raw_attrs.len() as isize,
            raw_attrs.as_ptr(),
            raw_arg_attrs.len() as isize,
            raw_arg_attrs.as_ptr(),
        ))
    };
    operation.try_into()
}
/// Returns the result of `crate::operation::isa` for `op` and the
/// `"function.def"` operation name.
#[inline]
pub fn is_func_def<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
    const OP_NAME: &str = "function.def";
    crate::operation::isa(op, OP_NAME)
}
/// Creates a call operation through `llzkFunction_CallOpBuild`.
///
/// * `builder` – builder whose raw handle is passed to the C API.
/// * `location` – location of the new op.
/// * `name` – symbol reference passed as the callee.
/// * `args` – values passed as the call's arguments.
/// * `return_types` – types passed as the call's result types.
///
/// # Errors
///
/// Returns the error reported when the created operation cannot be converted
/// into a [`CallOp`].
pub fn call<'c>(
    builder: &OpBuilder<'c>,
    location: Location<'c>,
    name: impl SymbolRefAttrLike<'c>,
    args: &[Value<'c, '_>],
    return_types: &[impl TypeLike<'c>],
) -> Result<CallOp<'c>, Error> {
    // Lower every type through `to_raw()` instead of reinterpreting the slice:
    // `TypeLike` gives no guarantee that an implementor has the same layout as
    // the raw MLIR type, so a pointer cast over the slice would be unsound for
    // wrappers that carry extra data.
    let raw_return_types: Vec<_> = return_types.iter().map(|ty| ty.to_raw()).collect();
    unsafe {
        Operation::from_raw(llzkFunction_CallOpBuild(
            builder.to_raw(),
            location.to_raw(),
            raw_return_types.len() as isize,
            raw_return_types.as_ptr(),
            name.to_raw(),
            args.len() as isize,
            // NOTE(review): this cast assumes `Value` is layout-compatible with
            // the raw MLIR value handle — confirm against melior.
            args.as_ptr() as *const _,
        ))
    }
    .try_into()
}
/// Returns the result of `crate::operation::isa` for `op` and the
/// `"function.call"` operation name.
#[inline]
pub fn is_func_call<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
    const OP_NAME: &str = "function.call";
    crate::operation::isa(op, OP_NAME)
}
/// Builds a `function.return` operation at `location` with `values` as its
/// operands.
///
/// # Panics
///
/// Panics if the operation builder fails to build the operation.
pub fn r#return<'c>(location: Location<'c>, values: &[Value<'c, '_>]) -> Operation<'c> {
    let builder = OperationBuilder::new("function.return", location).add_operands(values);
    builder.build().unwrap()
}
/// Returns the result of `crate::operation::isa` for `op` and the
/// `"function.return"` operation name.
#[inline]
pub fn is_func_return<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
    const OP_NAME: &str = "function.return";
    crate::operation::isa(op, OP_NAME)
}