#[cfg(feature = "mlir")]
use super::MlirError;
#[cfg(feature = "mlir")]
use crate::ast::{BinaryOp, UnaryOp};
#[cfg(feature = "mlir")]
use melior::{
    dialect::{arith, func, scf},
    ir::{
        attribute::{FlatSymbolRefAttribute, IntegerAttribute, StringAttribute, TypeAttribute},
        operation::OperationBuilder,
        r#type::{FunctionType, IntegerType},
        Block, Identifier, Location, Operation, Region, Type, Value, ValueLike,
    },
    Context,
};
/// Small helper for constructing MLIR operations, types and attributes in one [`Context`].
///
/// The builder only borrows the context, so it is freely copyable.
#[cfg(feature = "mlir")]
#[derive(Clone, Copy)]
pub struct OpBuilder<'c> {
    /// Context used to create the types, attributes and context-dependent operations.
    context: &'c Context,
}
#[cfg(feature = "mlir")]
impl<'c> OpBuilder<'c> {
    /// Creates a builder that creates its types and attributes in `context`.
    pub fn new(context: &'c Context) -> Self {
        Self { context }
    }

    /// Lowers a binary operator on integer operands to a single `arith` operation.
    ///
    /// Division and remainder are signed (`arith.divsi` / `arith.remsi`), comparisons use
    /// the signed `arith.cmpi` predicates, and `And` / `Or` lower to the bitwise
    /// `arith.andi` / `arith.ori` (so there is no short-circuiting).
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] for operators with no direct `arith` equivalent:
    /// `Pow`, the functional operators (`Pipe`, `Compose`, `Apply`, `Bind`), member
    /// access, and the functor operators (`Map`, `Ap`).
    pub fn build_binary_arith(
        &self,
        op: BinaryOp,
        lhs: Value<'c, '_>,
        rhs: Value<'c, '_>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        match op {
            BinaryOp::Add => Ok(arith::addi(lhs, rhs, location).into()),
            BinaryOp::Sub => Ok(arith::subi(lhs, rhs, location).into()),
            BinaryOp::Mul => Ok(arith::muli(lhs, rhs, location).into()),
            BinaryOp::Div => Ok(arith::divsi(lhs, rhs, location).into()),
            BinaryOp::Mod => Ok(arith::remsi(lhs, rhs, location).into()),
            BinaryOp::Eq => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Eq, lhs, rhs, location).into())
            }
            BinaryOp::Ne => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Ne, lhs, rhs, location).into())
            }
            BinaryOp::Lt => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Slt, lhs, rhs, location).into())
            }
            BinaryOp::Le => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Sle, lhs, rhs, location).into())
            }
            BinaryOp::Gt => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Sgt, lhs, rhs, location).into())
            }
            BinaryOp::Ge => {
                Ok(arith::cmpi(self.context, arith::CmpiPredicate::Sge, lhs, rhs, location).into())
            }
            BinaryOp::And => Ok(arith::andi(lhs, rhs, location).into()),
            BinaryOp::Or => Ok(arith::ori(lhs, rhs, location).into()),
            BinaryOp::Pow => Err(MlirError::new(
                "exponentiation (^) requires math dialect or custom implementation",
            )),
            BinaryOp::Pipe | BinaryOp::Compose | BinaryOp::Apply | BinaryOp::Bind => {
                Err(MlirError::new(format!(
                    "functional operator {:?} not supported in MLIR lowering",
                    op
                )))
            }
            BinaryOp::Member => Err(MlirError::new(
                "member access requires struct/object lowering",
            )),
            BinaryOp::Map | BinaryOp::Ap => Err(MlirError::new(format!(
                "functor operator {:?} requires custom dialect",
                op
            ))),
        }
    }

    /// Lowers a binary operator on floating-point operands to a single `arith` operation.
    ///
    /// Comparisons follow IEEE-754 / C semantics. `==`, `<`, `<=`, `>` and `>=` use the
    /// *ordered* `arith.cmpf` predicates, which are false when either side is NaN. `!=`
    /// uses the *unordered* `une` predicate, so `NaN != x` is true.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] for every operator other than arithmetic and comparison.
    pub fn build_binary_float(
        &self,
        op: BinaryOp,
        lhs: Value<'c, '_>,
        rhs: Value<'c, '_>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        match op {
            BinaryOp::Add => Ok(arith::addf(lhs, rhs, location).into()),
            BinaryOp::Sub => Ok(arith::subf(lhs, rhs, location).into()),
            BinaryOp::Mul => Ok(arith::mulf(lhs, rhs, location).into()),
            BinaryOp::Div => Ok(arith::divf(lhs, rhs, location).into()),
            BinaryOp::Mod => Ok(arith::remf(lhs, rhs, location).into()),
            BinaryOp::Eq => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Oeq, lhs, rhs, location).into())
            }
            // Unordered not-equal: `One` would make `NaN != NaN` evaluate to false.
            BinaryOp::Ne => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Une, lhs, rhs, location).into())
            }
            BinaryOp::Lt => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Olt, lhs, rhs, location).into())
            }
            BinaryOp::Le => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Ole, lhs, rhs, location).into())
            }
            BinaryOp::Gt => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Ogt, lhs, rhs, location).into())
            }
            BinaryOp::Ge => {
                Ok(arith::cmpf(self.context, arith::CmpfPredicate::Oge, lhs, rhs, location).into())
            }
            // Listed explicitly (no `_`) so a new `BinaryOp` variant forces a decision here.
            BinaryOp::And
            | BinaryOp::Or
            | BinaryOp::Pow
            | BinaryOp::Pipe
            | BinaryOp::Compose
            | BinaryOp::Apply
            | BinaryOp::Bind
            | BinaryOp::Member
            | BinaryOp::Map
            | BinaryOp::Ap => Err(MlirError::new(format!(
                "operator {:?} not supported for floating-point",
                op
            ))),
        }
    }

    /// Lowers a unary operator that needs no auxiliary operations.
    ///
    /// # Errors
    ///
    /// - `Neg` and `Not` are rejected. They need a helper constant (`0` / `true`), and
    ///   that constant must be owned by a block before the returned operation can safely
    ///   use it. Use [`Self::build_unary_in_block`] for them instead.
    /// - `Quote` and `Reflect` are always rejected.
    pub fn build_unary(
        &self,
        op: UnaryOp,
        _operand: Value<'c, '_>,
        _location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        match op {
            // Building the helper constant as a detached `Operation` here would destroy it
            // when it goes out of scope, leaving the returned op with a dangling operand.
            UnaryOp::Neg | UnaryOp::Not => Err(MlirError::new(format!(
                "unary operator {:?} needs a helper constant; use build_unary_in_block",
                op
            ))),
            UnaryOp::Quote | UnaryOp::Reflect => Err(Self::unsupported_meta_operator(op)),
        }
    }

    /// Lowers a unary operator and appends any helper constant it needs to `block`.
    ///
    /// The helper constant is appended to `block` immediately. The returned operation is
    /// *not* appended; the caller must append it to `block` (after the constant) so the
    /// operand's definition dominates its use.
    ///
    /// - `Neg` lowers to `0 - operand` (`arith.subi`). The zero is built in the operand's
    ///   own type, so any integer width works.
    /// - `Not` lowers to `operand xor true` (`arith.xori` with an `i1` constant), i.e. a
    ///   logical not on `i1` operands.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] for `Quote` / `Reflect`, or if a helper constant has no
    /// result.
    pub fn build_unary_in_block(
        &self,
        block: &Block<'c>,
        op: UnaryOp,
        operand: Value<'c, '_>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        match op {
            UnaryOp::Neg => {
                let zero = block.append_operation(self.build_integer_constant(
                    0,
                    operand.r#type(),
                    location,
                )?);
                Ok(arith::subi(zero.result(0)?.into(), operand, location).into())
            }
            UnaryOp::Not => {
                let one = block.append_operation(self.build_constant_i1(true, location)?);
                Ok(arith::xori(operand, one.result(0)?.into(), location).into())
            }
            UnaryOp::Quote | UnaryOp::Reflect => Err(Self::unsupported_meta_operator(op)),
        }
    }

    /// Error returned for metaprogramming operators, which have no MLIR lowering.
    fn unsupported_meta_operator(op: UnaryOp) -> MlirError {
        MlirError::new(format!(
            "metaprogramming operator {:?} not supported in MLIR lowering",
            op
        ))
    }

    /// Builds an `arith.constant` holding `value`, typed as `r#type`.
    fn build_integer_constant(
        &self,
        value: i64,
        r#type: Type<'c>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        Ok(arith::constant(
            self.context,
            IntegerAttribute::new(value.into(), r#type).into(),
            location,
        )
        .into())
    }

    /// Builds an `arith.constant` of type `i64`.
    pub fn build_constant_i64(
        &self,
        value: i64,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        self.build_integer_constant(value, IntegerType::new(self.context, 64).into(), location)
    }

    /// Builds an `arith.constant` of type `i1` (`true` -> 1, `false` -> 0).
    pub fn build_constant_i1(
        &self,
        value: bool,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        self.build_integer_constant(
            i64::from(value),
            IntegerType::new(self.context, 1).into(),
            location,
        )
    }

    /// Builds an `scf.if` on `condition` producing `result_types`.
    ///
    /// `scf.if` always has exactly two regions. When `else_region` is `None`, an empty
    /// region is supplied, which MLIR accepts only when `result_types` is empty.
    /// `condition` should be an `i1`; this is not checked here.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_if(
        &self,
        condition: Value<'c, '_>,
        result_types: &[Type<'c>],
        then_region: Region<'c>,
        else_region: Option<Region<'c>>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        // Omitting the else region would produce a one-region `scf.if`, which is invalid IR.
        let else_region = else_region.unwrap_or_else(Region::new);
        OperationBuilder::new("scf.if", location)
            .add_operands(&[condition])
            .add_results(result_types)
            .add_regions([then_region, else_region])
            .build()
            .map_err(|e| MlirError::new(format!("failed to create scf.if operation: {}", e)))
    }

    /// Builds an `scf.for` with operands `lower_bound, upper_bound, step, init_args...`.
    ///
    /// The result types are taken from the types of `init_args`.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_for(
        &self,
        lower_bound: Value<'c, '_>,
        upper_bound: Value<'c, '_>,
        step: Value<'c, '_>,
        init_args: &[Value<'c, '_>],
        body_region: Region<'c>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        let mut operands = vec![lower_bound, upper_bound, step];
        operands.extend_from_slice(init_args);
        let result_types: Vec<Type<'c>> = init_args.iter().map(|v| v.r#type()).collect();
        OperationBuilder::new("scf.for", location)
            .add_operands(&operands)
            .add_results(&result_types)
            .add_regions([body_region])
            .build()
            .map_err(|e| MlirError::new(format!("failed to create scf.for operation: {}", e)))
    }

    /// Builds an `scf.while` from `init_args` and its "before" / "after" regions.
    ///
    /// The result types are taken from the types of `init_args`.
    /// NOTE(review): `scf.while` results follow the values forwarded by `scf.condition`.
    /// This helper is therefore only correct when those match the `init_args` types.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_while(
        &self,
        init_args: &[Value<'c, '_>],
        before_region: Region<'c>,
        after_region: Region<'c>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        let result_types: Vec<Type<'c>> = init_args.iter().map(|v| v.r#type()).collect();
        OperationBuilder::new("scf.while", location)
            .add_operands(init_args)
            .add_results(&result_types)
            .add_regions([before_region, after_region])
            .build()
            .map_err(|e| MlirError::new(format!("failed to create scf.while operation: {}", e)))
    }

    /// Builds a `func.func` named `name` with signature `function_type`.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_func(
        &self,
        name: &str,
        function_type: FunctionType<'c>,
        body_region: Region<'c>,
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        // Attribute names are `Identifier`s, not string attributes.
        OperationBuilder::new("func.func", location)
            .add_attributes(&[
                (
                    Identifier::new(self.context, "sym_name"),
                    StringAttribute::new(self.context, name).into(),
                ),
                (
                    Identifier::new(self.context, "function_type"),
                    TypeAttribute::new(function_type.into()).into(),
                ),
            ])
            .add_regions([body_region])
            .build()
            .map_err(|e| MlirError::new(format!("failed to create func.func operation: {}", e)))
    }

    /// Builds a direct `func.call` to the symbol `callee`.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_call(
        &self,
        callee: &str,
        arguments: &[Value<'c, '_>],
        result_types: &[Type<'c>],
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        // `func.call` requires `callee` to be a flat symbol reference (`@name`), not a string.
        OperationBuilder::new("func.call", location)
            .add_attributes(&[(
                Identifier::new(self.context, "callee"),
                FlatSymbolRefAttribute::new(self.context, callee).into(),
            )])
            .add_operands(arguments)
            .add_results(result_types)
            .build()
            .map_err(|e| MlirError::new(format!("failed to create func.call operation: {}", e)))
    }

    /// Builds a `func.return` returning `operands`.
    ///
    /// # Errors
    ///
    /// Returns an [`MlirError`] if the operation cannot be created.
    pub fn build_return(
        &self,
        operands: &[Value<'c, '_>],
        location: Location<'c>,
    ) -> Result<Operation<'c>, MlirError> {
        OperationBuilder::new("func.return", location)
            .add_operands(operands)
            .build()
            .map_err(|e| MlirError::new(format!("failed to create func.return operation: {}", e)))
    }
}
#[cfg(all(test, feature = "mlir"))]
mod tests {
    use super::*;
    use melior::ir::Location;

    /// Fails the test unless `result` is an error whose message contains `needle`.
    fn assert_err_mentions(result: Result<Operation<'_>, MlirError>, needle: &str) {
        match result {
            Ok(_) => panic!("expected an error mentioning {:?}", needle),
            Err(error) => assert!(error.message.contains(needle)),
        }
    }

    #[test]
    fn test_opbuilder_creation() {
        let context = Context::new();
        let _ = OpBuilder::new(&context);
    }

    #[test]
    fn test_build_constant_i64() {
        let context = Context::new();
        let location = Location::unknown(&context);
        assert!(OpBuilder::new(&context)
            .build_constant_i64(42, location)
            .is_ok());
    }

    #[test]
    fn test_build_constant_i1() {
        let context = Context::new();
        let location = Location::unknown(&context);
        assert!(OpBuilder::new(&context)
            .build_constant_i1(true, location)
            .is_ok());
    }

    #[test]
    fn test_unsupported_operations() {
        let context = Context::new();
        let location = Location::unknown(&context);
        let builder = OpBuilder::new(&context);

        // Any value will do: these operators are rejected before the operands are used.
        let one = builder
            .build_constant_i64(1, location)
            .expect("i64 constant should build");
        let value = one.result(0).expect("constant has a result").into();

        assert_err_mentions(
            builder.build_binary_arith(BinaryOp::Pow, value, value, location),
            "exponentiation",
        );
        assert_err_mentions(
            builder.build_binary_arith(BinaryOp::Pipe, value, value, location),
            "functional operator",
        );
        assert_err_mentions(
            builder.build_unary(UnaryOp::Quote, value, location),
            "metaprogramming operator",
        );
    }
}