use crate::{
builder::OpBuilderLike,
error::Error,
ident,
macros::llzk_op_type,
value_ext::{OwningValueRange, ValueRange},
};
use llzk_sys::{
llzkOperationIsA_Poly_TemplateExprOp, llzkOperationIsA_Poly_TemplateOp,
llzkOperationIsA_Poly_TemplateParamOp, llzkOperationIsA_Poly_YieldOp,
llzkPoly_ApplyMapOpBuildWithAffineMap, llzkPoly_TemplateExprOpBuild,
llzkPoly_TemplateExprOpGetInitializerRegion, llzkPoly_TemplateExprOpGetType,
llzkPoly_TemplateOpBuild, llzkPoly_TemplateOpGetBody, llzkPoly_TemplateOpGetBodyRegion,
llzkPoly_TemplateOpGetConstExprNames, llzkPoly_TemplateOpGetConstParamNames,
llzkPoly_TemplateOpHasConstExprNamed, llzkPoly_TemplateOpHasConstExprOps,
llzkPoly_TemplateOpHasConstParamNamed, llzkPoly_TemplateOpHasConstParamOps,
llzkPoly_TemplateOpNumConstExprOps, llzkPoly_TemplateOpNumConstParamOps,
llzkPoly_TemplateParamOpBuild, llzkPoly_TemplateParamOpGetTypeOpt,
llzkPoly_TemplateParamOpSetTypeOpt, llzkPoly_YieldOpBuild,
};
use melior::ir::{
Attribute, AttributeLike, Block, BlockLike as _, BlockRef, Identifier, Location, Operation,
OperationRef, RegionLike as _, RegionRef, Type, Value, ValueLike as _,
attribute::{FlatSymbolRefAttribute, StringAttribute, TypeAttribute},
operation::{OperationBuilder, OperationLike},
};
use mlir_sys::MlirAttribute;
fn collect_flat_sym_ref_attrs<'c>(
count: usize,
fill: impl FnOnce(*mut MlirAttribute),
) -> Vec<FlatSymbolRefAttribute<'c>> {
let mut raw_attrs: Vec<MlirAttribute> = Vec::with_capacity(count);
fill(raw_attrs.as_mut_ptr());
unsafe { raw_attrs.set_len(count) };
raw_attrs
.into_iter()
.map(|attr| FlatSymbolRefAttribute::try_from(unsafe { Attribute::from_raw(attr) }).unwrap())
.collect()
}
/// Accessors shared by all `poly.template` operation wrappers.
///
/// Every method forwards to the LLZK C API using the raw handle returned by
/// `OperationLike::to_raw`.
pub trait TemplateOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
    /// Returns the body region of the template (via `llzkPoly_TemplateOpGetBodyRegion`).
    fn body_region(&self) -> RegionRef<'c, 'a> {
        let raw_op = self.to_raw();
        // SAFETY: `raw_op` is taken from `self` and is a valid operation handle for this call.
        unsafe { RegionRef::from_raw(llzkPoly_TemplateOpGetBodyRegion(raw_op)) }
    }

    /// Returns the body block of the template (via `llzkPoly_TemplateOpGetBody`).
    fn body(&self) -> BlockRef<'c, 'a> {
        let raw_op = self.to_raw();
        // SAFETY: `raw_op` is taken from `self` and is a valid operation handle for this call.
        unsafe { BlockRef::from_raw(llzkPoly_TemplateOpGetBody(raw_op)) }
    }

    /// Whether the template has constant-parameter ops, as reported by
    /// `llzkPoly_TemplateOpHasConstParamOps`.
    fn has_const_param_ops(&self) -> bool {
        let raw_op = self.to_raw();
        unsafe { llzkPoly_TemplateOpHasConstParamOps(raw_op) }
    }

    /// Whether the template has constant-expression ops, as reported by
    /// `llzkPoly_TemplateOpHasConstExprOps`.
    fn has_const_expr_ops(&self) -> bool {
        let raw_op = self.to_raw();
        unsafe { llzkPoly_TemplateOpHasConstExprOps(raw_op) }
    }

    /// Returns the names of the template's constant parameters, in the order the C API
    /// writes them.
    ///
    /// # Panics
    ///
    /// Panics if the C API reports a negative count or a non-symbol-reference name.
    fn const_param_names(&self) -> Vec<FlatSymbolRefAttribute<'c>> {
        let raw_op = self.to_raw();
        let count = usize::try_from(unsafe { llzkPoly_TemplateOpNumConstParamOps(raw_op) }).unwrap();
        collect_flat_sym_ref_attrs(count, |buffer| unsafe {
            llzkPoly_TemplateOpGetConstParamNames(raw_op, buffer);
        })
    }

    /// Returns the names of the template's constant expressions, in the order the C API
    /// writes them.
    ///
    /// # Panics
    ///
    /// Panics if the C API reports a negative count or a non-symbol-reference name.
    fn const_expr_names(&self) -> Vec<FlatSymbolRefAttribute<'c>> {
        let raw_op = self.to_raw();
        let count = usize::try_from(unsafe { llzkPoly_TemplateOpNumConstExprOps(raw_op) }).unwrap();
        collect_flat_sym_ref_attrs(count, |buffer| unsafe {
            llzkPoly_TemplateOpGetConstExprNames(raw_op, buffer);
        })
    }

    /// Whether a constant parameter named `find` exists
    /// (via `llzkPoly_TemplateOpHasConstParamNamed`).
    fn has_const_param_named(&self, find: &str) -> bool {
        let needle = melior::StringRef::new(find);
        let raw_op = self.to_raw();
        unsafe { llzkPoly_TemplateOpHasConstParamNamed(raw_op, needle.to_raw()) }
    }

    /// Whether a constant expression named `find` exists
    /// (via `llzkPoly_TemplateOpHasConstExprNamed`).
    fn has_const_expr_named(&self, find: &str) -> bool {
        let needle = melior::StringRef::new(find);
        let raw_op = self.to_raw();
        unsafe { llzkPoly_TemplateOpHasConstExprNamed(raw_op, needle.to_raw()) }
    }

    /// Returns every `poly.param` / `poly.expr` op found directly in the body block, in
    /// block order. Other ops in the body are skipped.
    ///
    /// # Panics
    ///
    /// Panics if the C API's combined op count is negative (it is only used as a
    /// capacity hint).
    fn const_binding_ops(&self) -> Vec<TemplateSymbolBindingOpRef<'c, 'a>> {
        let raw_op = self.to_raw();
        let expected = usize::try_from(unsafe {
            llzkPoly_TemplateOpNumConstParamOps(raw_op) + llzkPoly_TemplateOpNumConstExprOps(raw_op)
        })
        .unwrap();

        let bindings = std::iter::successors(self.body().first_operation(), |cur| {
            cur.next_in_block()
        })
        .filter_map(|cur| {
            let raw = cur.to_raw();
            // SAFETY: `raw` belongs to an op of the body block; each wrapper is only built
            // after the matching IsA check succeeded.
            unsafe {
                if llzkOperationIsA_Poly_TemplateParamOp(raw) {
                    Some(TemplateSymbolBindingOpRef::Param(TemplateParamOpRef::from_raw(raw)))
                } else if llzkOperationIsA_Poly_TemplateExprOp(raw) {
                    Some(TemplateSymbolBindingOpRef::Expr(TemplateExprOpRef::from_raw(raw)))
                } else {
                    None
                }
            }
        });

        let mut found = Vec::with_capacity(expected);
        found.extend(bindings);
        found
    }
}
// `poly.template` wrapper types. `llzk_op_type!` is defined elsewhere in the crate; judging
// by the impls below it produces at least `TemplateOp`, `TemplateOpRef` and
// `TemplateOpRefMut` for the op named "poly.template", with
// `llzkOperationIsA_Poly_TemplateOp` passed in as the C API type predicate.
// NOTE(review): confirm the exact generated items against the macro definition.
llzk_op_type!(
TemplateOp,
llzkOperationIsA_Poly_TemplateOp,
"poly.template"
);
// Every `poly.template` wrapper gets the default `TemplateOpLike` accessors; none of the
// trait methods are overridden.
impl<'c: 'a, 'a> TemplateOpLike<'c, 'a> for TemplateOp<'c> {}
impl<'c: 'a, 'a> TemplateOpLike<'c, 'a> for TemplateOpRef<'c, 'a> {}
impl<'c: 'a, 'a> TemplateOpLike<'c, 'a> for TemplateOpRefMut<'c, 'a> {}
/// Builds a `poly.template` op named `name` through `builder` and populates its body.
///
/// If the body region has no block yet, an empty one is appended. `fill` is then called
/// with `builder` positioned at the start of that block. Once `fill` returns, the builder's
/// previous insertion point is restored.
///
/// # Errors
///
/// Returns an error if the built operation cannot be converted to a `TemplateOpRef`, or
/// propagates the error returned by `fill`. NOTE(review): in the latter case the built
/// template op is not erased.
pub fn template<'c, 'a, B>(
    builder: &B,
    location: Location<'c>,
    name: &str,
    fill: impl FnOnce(&B) -> Result<(), Error>,
) -> Result<TemplateOpRef<'c, 'a>, Error>
where
    B: OpBuilderLike<'c>,
{
    let context = location.context();
    // SAFETY: all raw handles come from live wrappers, and the C API returns the newly built op.
    let template_op: TemplateOpRef<'c, 'a> = unsafe {
        let sym_name = Identifier::new(context.to_ref(), name);
        OperationRef::from_raw(llzkPoly_TemplateOpBuild(
            builder.to_raw(),
            location.to_raw(),
            sym_name.to_raw(),
        ))
    }
    .try_into()?;

    let body_region = template_op.body_region();
    let body_block = match body_region.first_block() {
        Some(existing) => existing,
        None => body_region.append_block(Block::new(&[])),
    };

    let restore_point = builder.save_insertion_point();
    builder.set_insertion_point_at_start(body_block);
    let outcome = fill(builder);
    builder.restore_insertion_point(restore_point);

    outcome.map(|()| template_op)
}
/// Reports whether `op` is a `poly.template` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.template"`.
#[inline]
pub fn is_template_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.template")
}
/// Common interface for the ops that bind a template-level symbol (`poly.param` and
/// `poly.expr` in this module).
pub trait TemplateSymbolBindingOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
    /// Returns the `sym_name` attribute naming the bound symbol.
    ///
    /// # Panics
    ///
    /// Panics if the operation has no `sym_name` attribute, or if that attribute is not a
    /// string attribute.
    fn sym_name_attr(&self) -> StringAttribute<'c> {
        // A binding op without a string `sym_name` is malformed IR; say so in the panic
        // instead of surfacing a bare `unwrap` on `Err`.
        self.attribute("sym_name")
            .and_then(StringAttribute::try_from)
            .expect("template symbol binding op must have a string `sym_name` attribute")
    }

    /// Returns the name of the bound symbol (the string value of `sym_name`).
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::sym_name_attr`].
    #[inline]
    fn sym_name(&self) -> &'c str {
        self.sym_name_attr().value()
    }

    /// Returns the type associated with the bound symbol, if any.
    ///
    /// In this module, `poly.param` implementations return the optional type restriction,
    /// and `poly.expr` implementations always return `Some` of the expression type.
    fn type_opt(&self) -> Option<Type<'c>>;
}
/// Owned handle to either kind of template symbol binding op.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TemplateSymbolBindingOp<'c> {
    /// An owned `poly.param` op.
    Param(TemplateParamOp<'c>),
    /// An owned `poly.expr` op.
    Expr(TemplateExprOp<'c>),
}

impl<'c> TemplateSymbolBindingOp<'c> {
    /// Borrows this owned op as a [`TemplateSymbolBindingOpRef`] of the same variant.
    pub fn as_ref<'a>(&'a self) -> TemplateSymbolBindingOpRef<'c, 'a> {
        // SAFETY: each raw handle comes from an op owned by `self`, which outlives `'a`.
        unsafe {
            match self {
                Self::Param(param_op) => {
                    TemplateSymbolBindingOpRef::Param(TemplateParamOpRef::from_raw(param_op.to_raw()))
                }
                Self::Expr(expr_op) => {
                    TemplateSymbolBindingOpRef::Expr(TemplateExprOpRef::from_raw(expr_op.to_raw()))
                }
            }
        }
    }
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateSymbolBindingOp<'c> {
    /// `poly.param` yields its optional type restriction; `poly.expr` always yields its type.
    fn type_opt(&self) -> Option<Type<'c>> {
        match self {
            Self::Param(param_op) => param_op.type_restriction(),
            Self::Expr(expr_op) => Some(expr_op.expr_type()),
        }
    }
}

impl std::fmt::Display for TemplateSymbolBindingOp<'_> {
    /// Formats the wrapped op with its own `Display` impl, forwarding the formatter
    /// (and therefore any flags) unchanged.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        let inner: &dyn std::fmt::Display = match self {
            Self::Param(param_op) => param_op,
            Self::Expr(expr_op) => expr_op,
        };
        std::fmt::Display::fmt(inner, formatter)
    }
}

impl<'c: 'a, 'a> OperationLike<'c, 'a> for TemplateSymbolBindingOp<'c> {
    /// Returns the raw handle of whichever op is wrapped.
    fn to_raw(&self) -> mlir_sys::MlirOperation {
        match self {
            Self::Param(param_op) => param_op.to_raw(),
            Self::Expr(expr_op) => expr_op.to_raw(),
        }
    }
}
impl<'c> From<TemplateParamOp<'c>> for TemplateSymbolBindingOp<'c> {
fn from(op: TemplateParamOp<'c>) -> Self {
Self::Param(op)
}
}
impl<'c> From<TemplateExprOp<'c>> for TemplateSymbolBindingOp<'c> {
fn from(op: TemplateExprOp<'c>) -> Self {
Self::Expr(op)
}
}
impl<'c> From<TemplateSymbolBindingOp<'c>> for Operation<'c> {
fn from(op: TemplateSymbolBindingOp<'c>) -> Self {
match op {
TemplateSymbolBindingOp::Param(inner) => inner.into(),
TemplateSymbolBindingOp::Expr(inner) => inner.into(),
}
}
}
impl<'c> TryFrom<Operation<'c>> for TemplateSymbolBindingOp<'c> {
type Error = crate::error::Error;
fn try_from(op: Operation<'c>) -> Result<Self, Self::Error> {
if is_param_op(&op) {
TemplateParamOp::try_from(op).map(Self::Param)
} else if is_expr_op(&op) {
TemplateExprOp::try_from(op).map(Self::Expr)
} else {
Err(Error::OperationExpected(
"poly.param or poly.expr",
op.to_string(),
))
}
}
}
/// Borrowed handle to either kind of template symbol binding op; the non-owning
/// counterpart of [`TemplateSymbolBindingOp`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum TemplateSymbolBindingOpRef<'c, 'a> {
    /// A borrowed `poly.param` op.
    Param(TemplateParamOpRef<'c, 'a>),
    /// A borrowed `poly.expr` op.
    Expr(TemplateExprOpRef<'c, 'a>),
}

impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateSymbolBindingOpRef<'c, 'a> {
    /// `poly.param` yields its optional type restriction; `poly.expr` always yields its type.
    fn type_opt(&self) -> Option<Type<'c>> {
        match *self {
            Self::Param(param_op) => param_op.type_restriction(),
            Self::Expr(expr_op) => Some(expr_op.expr_type()),
        }
    }
}

impl std::fmt::Display for TemplateSymbolBindingOpRef<'_, '_> {
    /// Formats the wrapped op with its own `Display` impl, forwarding the formatter unchanged.
    fn fmt(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        let inner: &dyn std::fmt::Display = match self {
            Self::Param(param_op) => param_op,
            Self::Expr(expr_op) => expr_op,
        };
        std::fmt::Display::fmt(inner, formatter)
    }
}

impl<'c: 'a, 'a> OperationLike<'c, 'a> for TemplateSymbolBindingOpRef<'c, 'a> {
    /// Returns the raw handle of whichever op is referenced.
    fn to_raw(&self) -> mlir_sys::MlirOperation {
        match *self {
            Self::Param(param_op) => param_op.to_raw(),
            Self::Expr(expr_op) => expr_op.to_raw(),
        }
    }
}
impl<'c, 'a> From<TemplateParamOpRef<'c, 'a>> for TemplateSymbolBindingOpRef<'c, 'a> {
fn from(op: TemplateParamOpRef<'c, 'a>) -> Self {
Self::Param(op)
}
}
impl<'c, 'a> From<TemplateExprOpRef<'c, 'a>> for TemplateSymbolBindingOpRef<'c, 'a> {
fn from(op: TemplateExprOpRef<'c, 'a>) -> Self {
Self::Expr(op)
}
}
impl<'c, 'a> From<TemplateSymbolBindingOpRef<'c, 'a>> for OperationRef<'c, 'a> {
fn from(op: TemplateSymbolBindingOpRef<'c, 'a>) -> Self {
match op {
TemplateSymbolBindingOpRef::Param(inner) => inner.into(),
TemplateSymbolBindingOpRef::Expr(inner) => inner.into(),
}
}
}
impl<'c, 'a> TryFrom<OperationRef<'c, 'a>> for TemplateSymbolBindingOpRef<'c, 'a> {
type Error = crate::error::Error;
fn try_from(op: OperationRef<'c, 'a>) -> Result<Self, Self::Error> {
if is_param_op(&op) {
TemplateParamOpRef::try_from(op).map(Self::Param)
} else if is_expr_op(&op) {
TemplateExprOpRef::try_from(op).map(Self::Expr)
} else {
Err(Error::OperationExpected(
"poly.param or poly.expr",
op.to_string(),
))
}
}
}
/// Accessors shared by all `poly.expr` operation wrappers.
pub trait TemplateExprOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
    /// Returns the initializer region (via `llzkPoly_TemplateExprOpGetInitializerRegion`).
    fn initializer_region(&self) -> RegionRef<'c, 'a> {
        let raw_op = self.to_raw();
        // SAFETY: `raw_op` is taken from `self` and is a valid operation handle for this call.
        let raw_region = unsafe { llzkPoly_TemplateExprOpGetInitializerRegion(raw_op) };
        unsafe { RegionRef::from_raw(raw_region) }
    }

    /// Returns the declared type of the expression (via `llzkPoly_TemplateExprOpGetType`).
    fn expr_type(&self) -> Type<'c> {
        let raw_op = self.to_raw();
        // SAFETY: `raw_op` is taken from `self` and is a valid operation handle for this call.
        let raw_type = unsafe { llzkPoly_TemplateExprOpGetType(raw_op) };
        unsafe { Type::from_raw(raw_type) }
    }
}
/// Accessors shared by all `poly.param` operation wrappers.
pub trait TemplateParamOpLike<'c: 'a, 'a>: OperationLike<'c, 'a> {
    /// Returns the parameter's optional type restriction.
    ///
    /// A null attribute from `llzkPoly_TemplateParamOpGetTypeOpt` means "no restriction".
    ///
    /// # Panics
    ///
    /// Panics if the stored attribute is present but is not a `TypeAttribute`.
    fn type_restriction(&self) -> Option<Type<'c>> {
        // SAFETY: `self.to_raw()` is a valid operation handle for this call.
        let raw_attr = unsafe { llzkPoly_TemplateParamOpGetTypeOpt(self.to_raw()) };
        if raw_attr.ptr.is_null() {
            return None;
        }
        // SAFETY: `raw_attr` is a non-null attribute returned by the C API.
        let type_attr = TypeAttribute::try_from(unsafe { Attribute::from_raw(raw_attr) })
            .expect("malformed poly.param type restriction attribute");
        Some(type_attr.value())
    }

    /// Sets, or clears when `type_opt` is `None`, the parameter's type restriction.
    ///
    /// Clearing is done by passing a null attribute to `llzkPoly_TemplateParamOpSetTypeOpt`.
    fn set_type_restriction(&self, type_opt: Option<Type<'c>>) {
        let raw_attr = type_opt.map_or_else(
            || MlirAttribute {
                ptr: std::ptr::null_mut(),
            },
            |restriction| TypeAttribute::new(restriction).to_raw(),
        );
        unsafe { llzkPoly_TemplateParamOpSetTypeOpt(self.to_raw(), raw_attr) }
    }
}
// `poly.expr` and `poly.param` wrapper types, generated by `llzk_op_type!` (defined elsewhere
// in the crate) from the op name and its C API IsA predicate. Judging by the impls below,
// each invocation yields at least `X`, `XRef` and `XRefMut`.
// NOTE(review): confirm the generated items against the macro definition.
llzk_op_type!(
TemplateExprOp,
llzkOperationIsA_Poly_TemplateExprOp,
"poly.expr"
);
llzk_op_type!(
TemplateParamOp,
llzkOperationIsA_Poly_TemplateParamOp,
"poly.param"
);
// All three wrapper flavours of each op get the default trait accessors; nothing is
// overridden.
impl<'c: 'a, 'a> TemplateExprOpLike<'c, 'a> for TemplateExprOp<'c> {}
impl<'c: 'a, 'a> TemplateExprOpLike<'c, 'a> for TemplateExprOpRef<'c, 'a> {}
impl<'c: 'a, 'a> TemplateExprOpLike<'c, 'a> for TemplateExprOpRefMut<'c, 'a> {}
impl<'c: 'a, 'a> TemplateParamOpLike<'c, 'a> for TemplateParamOp<'c> {}
impl<'c: 'a, 'a> TemplateParamOpLike<'c, 'a> for TemplateParamOpRef<'c, 'a> {}
impl<'c: 'a, 'a> TemplateParamOpLike<'c, 'a> for TemplateParamOpRefMut<'c, 'a> {}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateParamOp<'c> {
fn type_opt(&self) -> Option<Type<'c>> {
self.type_restriction()
}
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateParamOpRef<'c, 'a> {
fn type_opt(&self) -> Option<Type<'c>> {
self.type_restriction()
}
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateParamOpRefMut<'c, 'a> {
fn type_opt(&self) -> Option<Type<'c>> {
self.type_restriction()
}
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateExprOp<'c> {
fn type_opt(&self) -> Option<Type<'c>> {
Some(self.expr_type())
}
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateExprOpRef<'c, 'a> {
fn type_opt(&self) -> Option<Type<'c>> {
Some(self.expr_type())
}
}
impl<'c: 'a, 'a> TemplateSymbolBindingOpLike<'c, 'a> for TemplateExprOpRefMut<'c, 'a> {
fn type_opt(&self) -> Option<Type<'c>> {
Some(self.expr_type())
}
}
/// Builds a `poly.param` op declaring the template parameter `name` through `builder`.
///
/// When `type_opt` is `Some`, it is attached as a `TypeAttribute` type restriction.
/// Otherwise a null attribute is passed, meaning "unrestricted".
///
/// # Errors
///
/// Returns an error if the built operation cannot be converted to a `TemplateParamOpRef`.
pub fn param<'c, 'a>(
    builder: &impl OpBuilderLike<'c>,
    location: Location<'c>,
    name: &str,
    type_opt: Option<Type<'c>>,
) -> Result<TemplateParamOpRef<'c, 'a>, Error> {
    let context = location.context();
    let type_attr = type_opt.map_or_else(
        || MlirAttribute {
            ptr: std::ptr::null_mut(),
        },
        |restriction| TypeAttribute::new(restriction).to_raw(),
    );
    // SAFETY: all raw handles come from live wrappers, and the C API returns the newly built op.
    let built = unsafe {
        let sym_name = Identifier::new(context.to_ref(), name);
        OperationRef::from_raw(llzkPoly_TemplateParamOpBuild(
            builder.to_raw(),
            location.to_raw(),
            sym_name.to_raw(),
            type_attr,
        ))
    };
    built.try_into()
}
/// Reports whether `op` is a `poly.param` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.param"`.
#[inline]
pub fn is_param_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.param")
}
/// Builds a `poly.expr` op named `name` through `builder` and populates its initializer
/// region.
///
/// If the initializer region has no block yet, an empty one is appended. `fill_cb` is then
/// called with `builder` positioned at the start of that block. Once `fill_cb` returns, the
/// builder's previous insertion point is restored.
///
/// # Errors
///
/// Returns an error if the built operation cannot be converted to a `TemplateExprOpRef`, or
/// propagates the error returned by `fill_cb`. NOTE(review): in the latter case the built
/// `poly.expr` op is not erased.
pub fn expr<'c, 'a, B>(
    builder: &B,
    location: Location<'c>,
    name: &str,
    fill_cb: impl FnOnce(&B) -> Result<(), Error>,
) -> Result<TemplateExprOpRef<'c, 'a>, Error>
where
    B: OpBuilderLike<'c>,
{
    let context = location.context();
    // SAFETY: all raw handles come from live wrappers, and the C API returns the newly built op.
    let expr_op: TemplateExprOpRef<'c, 'a> = unsafe {
        let sym_name = Identifier::new(context.to_ref(), name);
        OperationRef::from_raw(llzkPoly_TemplateExprOpBuild(
            builder.to_raw(),
            location.to_raw(),
            sym_name.to_raw(),
        ))
    }
    .try_into()?;

    let init_region = expr_op.initializer_region();
    let init_block = match init_region.first_block() {
        Some(existing) => existing,
        None => init_region.append_block(Block::new(&[])),
    };

    let restore_point = builder.save_insertion_point();
    builder.set_insertion_point_at_start(init_block);
    let outcome = fill_cb(builder);
    builder.restore_insertion_point(restore_point);

    outcome.map(|()| expr_op)
}
/// Reports whether `op` is a `poly.expr` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.expr"`.
#[inline]
pub fn is_expr_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.expr")
}
// `poly.yield` wrapper types, generated by `llzk_op_type!` (defined elsewhere in the crate)
// from the op name and `llzkOperationIsA_Poly_YieldOp`.
llzk_op_type!(YieldOp, llzkOperationIsA_Poly_YieldOp, "poly.yield");
/// Builds a `poly.yield` op carrying `val` through `builder` (via `llzkPoly_YieldOpBuild`).
///
/// # Errors
///
/// Returns an error if the built operation cannot be converted to a `YieldOpRef`.
pub fn r#yield<'c, 'a>(
    builder: &impl OpBuilderLike<'c>,
    location: Location<'c>,
    val: Value<'c, '_>,
) -> Result<YieldOpRef<'c, 'a>, Error> {
    let (raw_builder, raw_location, raw_value) =
        (builder.to_raw(), location.to_raw(), val.to_raw());
    // SAFETY: all raw handles come from live wrappers, and the C API returns the newly built op.
    let yield_op = unsafe {
        OperationRef::from_raw(llzkPoly_YieldOpBuild(raw_builder, raw_location, raw_value))
    };
    yield_op.try_into()
}
/// Reports whether `op` is a `poly.yield` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.yield"`.
#[inline]
pub fn is_yield_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.yield")
}
pub fn read_const<'c>(location: Location<'c>, symbol: &str, result: Type<'c>) -> Operation<'c> {
let ctx = location.context();
OperationBuilder::new("poly.read_const", location)
.add_attributes(&[(
ident!(ctx, "const_name"),
FlatSymbolRefAttribute::new(unsafe { ctx.to_ref() }, symbol).into(),
)])
.add_results(&[result])
.build()
.expect("valid operation")
}
/// Reports whether `op` is a `poly.read_const` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.read_const"`.
#[inline]
pub fn is_read_const_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.read_const")
}
/// Creates an owned (not yet inserted) `poly.unifiable_cast` op that takes `input` as its
/// only operand and produces a single result of type `result`.
///
/// # Panics
///
/// Panics if `OperationBuilder::build` fails.
pub fn unifiable_cast<'c>(
    location: Location<'c>,
    input: Value<'c, '_>,
    result: Type<'c>,
) -> Operation<'c> {
    let cast_builder = OperationBuilder::new("poly.unifiable_cast", location)
        .add_operands(&[input])
        .add_results(&[result]);
    cast_builder.build().expect("valid operation")
}
/// Reports whether `op` is a `poly.unifiable_cast` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.unifiable_cast"`.
#[inline]
pub fn is_unifiable_cast_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.unifiable_cast")
}
/// Builds a `poly.applymap` op that applies the affine map held by `map` to `map_operands`,
/// then inserts it through `builder.insert`.
///
/// # Panics
///
/// Panics if `map` is not an affine map attribute, or if `map_operands` cannot be turned
/// into a `ValueRange`.
pub fn applymap<'c, 'a>(
    builder: &'c impl OpBuilderLike<'c>,
    location: Location<'c>,
    map: Attribute<'c>,
    map_operands: &[Value<'c, '_>],
) -> OperationRef<'c, 'a> {
    // Check the caller-supplied precondition first and report the offending attribute,
    // instead of a bare `assert!` with no message.
    assert!(
        unsafe { mlir_sys::mlirAttributeIsAAffineMap(map.to_raw()) },
        "poly.applymap expects an affine map attribute, got {map}"
    );
    let value_range = OwningValueRange::from(map_operands);
    let operands = ValueRange::try_from(&value_range)
        .expect("poly.applymap operands must form a valid value range");
    // NOTE(review): every other builder in this module treats the result of a `*Build` C call
    // as an already-placed `OperationRef`. Here the result is wrapped as an owned `Operation`
    // and inserted again via `builder.insert`. Confirm this cannot insert the op twice.
    // SAFETY: `map` was checked to be an affine map attribute above; all other raw handles
    // come from live wrappers, and `operands` borrows `value_range`, which outlives this call.
    let op = unsafe {
        Operation::from_raw(llzkPoly_ApplyMapOpBuildWithAffineMap(
            builder.to_raw(),
            location.to_raw(),
            mlir_sys::mlirAffineMapAttrGetValue(map.to_raw()),
            operands.to_raw(),
        ))
    };
    builder.insert(location, move |_, _| op)
}
/// Reports whether `op` is a `poly.applymap` operation.
///
/// Delegates to `crate::operation::isa` with the operation name `"poly.applymap"`.
#[inline]
pub fn is_applymap_op<'c: 'a, 'a>(op: &impl OperationLike<'c, 'a>) -> bool {
crate::operation::isa(op, "poly.applymap")
}