use crate::convert::apply_rewrites;
use crate::convert::ChangedOp;
use crate::convert::Pass;
use crate::convert::Rewrite;
use crate::convert::RewriteResult;
use crate::dialect;
use crate::dialect::arith;
use crate::dialect::func::Call;
use crate::dialect::func::Func;
use crate::dialect::llvm;
use crate::dialect::llvm::PointerType;
use crate::ir::APInt;
use crate::ir::Block;
use crate::ir::BlockArgumentName;
use crate::ir::IntegerAttr;
use crate::ir::IntegerType;
use crate::ir::Op;
use crate::ir::OpOperand;
use crate::ir::Operation;
use crate::ir::StringAttr;
use crate::ir::Value;
use crate::shared::Shared;
use crate::shared::SharedExt;
use anyhow::Result;
use dialect::experimental::PrintfOp;
use std::str::FromStr;
use std::sync::Arc;
/// Rewrite that lowers `experimental.printf` ([`PrintfOp`]) into LLVM/arith ops.
///
/// For each printf, the following ops are inserted before the original op:
/// 1. an `llvm::ConstantOp` holding the printf text as a C string,
/// 2. an `arith::ConstantOp` (`i16`) holding the byte length of that string,
/// 3. an `llvm::AllocaOp` of that many `i8` elements,
/// 4. an `llvm::StoreOp` writing the text into the allocation,
/// 5. an `llvm::CallOp` to `@printf` with the allocation and the printf arguments.
///
/// A private `@printf` declaration is added to the top-level op if none exists yet.
struct PrintLowering;

impl PrintLowering {
    /// Builds an `llvm::ConstantOp` whose value is the printf text as a C string.
    ///
    /// Returns the new op and the byte length of the string produced by
    /// `c_string()` (presumably including a trailing NUL — TODO confirm).
    fn text_constant(parent: &Shared<Block>, op: &PrintfOp) -> (Shared<dyn Op>, usize) {
        let mut const_operation = Operation::default();
        const_operation.set_parent(Some(parent.clone()));
        let text = op.text().clone();
        let text = text.c_string();
        let len = text.len();
        // Result type is an LLVM array type built from the C string bytes.
        let typ = llvm::ArrayType::for_bytes(&text);
        let typ = Shared::new(typ.into());
        let result = const_operation.add_new_op_result(typ);
        let const_op = llvm::ConstantOp::from_operation(const_operation);
        const_op.set_value(Arc::new(StringAttr::new(text)));
        let const_op = Shared::new(const_op.into());
        // Link the result back to the op that defines it.
        result.set_defining_op(Some(const_op.clone()));
        (const_op, len)
    }

    /// Builds an `arith::ConstantOp` of type `i16` whose value is `len`.
    ///
    /// NOTE(review): the length is encoded as `i16`; confirm how
    /// `APInt::from_str` behaves for texts longer than `i16::MAX` bytes.
    fn len_specifier(parent: &Shared<Block>, len: usize) -> Shared<dyn Op> {
        let mut operation = Operation::default();
        operation.set_parent(Some(parent.clone()));
        let typ = IntegerType::from_str("i16").unwrap();
        let result_type = Shared::new(typ.into());
        let result = operation.add_new_op_result(result_type);
        let op = arith::ConstantOp::from_operation(operation);
        let len = APInt::from_str("i16", &len.to_string());
        op.set_value(Arc::new(IntegerAttr::new(typ, len)));
        let op = Shared::new(op.into());
        result.set_defining_op(Some(op.clone()));
        op
    }

    /// Builds an `llvm::AllocaOp` with element type `i8`, whose array size is
    /// the first result of `len`. The op produces a single pointer result.
    fn alloca_op(parent: &Shared<Block>, len: Shared<dyn Op>) -> Shared<dyn Op> {
        let mut operation = Operation::default();
        operation.set_parent(Some(parent.clone()));
        let typ = llvm::PointerType::new();
        let result_type = Shared::new(typ.into());
        let result = operation.add_new_op_result(result_type);
        let array_size = len.rd().result(0);
        let array_size = OpOperand::new(array_size);
        let array_size = Shared::new(array_size.into());
        operation.set_operand(0, array_size);
        let mut op = llvm::AllocaOp::from_operation(operation);
        op.set_element_type("i8".to_string());
        let op = Shared::new(op.into());
        result.set_defining_op(Some(op.clone()));
        op
    }

    /// Builds an `llvm::StoreOp` that stores the first result of `text` at the
    /// address given by the first result of `alloca`. The store has no results.
    fn store_op(
        parent: &Shared<Block>,
        text: Shared<dyn Op>,
        alloca: Shared<dyn Op>,
    ) -> Shared<dyn Op> {
        let mut operation = Operation::default();
        operation.set_parent(Some(parent.clone()));
        let mut op = llvm::StoreOp::from_operation(operation);
        let value = text.rd().result(0);
        let value = OpOperand::new(value);
        op.set_value(Shared::new(value.into()));
        let addr = alloca.rd().result(0);
        let addr = OpOperand::new(addr);
        op.set_addr(Shared::new(addr.into()));
        Shared::new(op.into())
    }

    /// Builds an `llvm::CallOp` to `@printf` that returns `i32`.
    ///
    /// Operand 0 is the address produced by `alloca`. When `set_varargs` is true,
    /// every operand of the printf op from index 1 onward is forwarded at the same
    /// index, and the call is annotated with the variadic function type
    /// `!llvm.func<i32 (!llvm.ptr, ...)>`.
    ///
    /// # Panics
    /// Panics if a printf operand in `1..len` is missing.
    fn call_op(
        parent: &Shared<Block>,
        op: &PrintfOp,
        alloca: Shared<dyn Op>,
        set_varargs: bool,
    ) -> Shared<dyn Op> {
        let mut operation = Operation::default();
        operation.set_parent(Some(parent.clone()));
        {
            let text_addr = alloca.rd().result(0);
            let text_addr = OpOperand::new(text_addr);
            let text_addr = Shared::new(text_addr.into());
            operation.set_operand(0, text_addr);
        }
        if set_varargs {
            // Operands after the first are printf arguments (the same convention
            // `rewrite` uses to decide `set_varargs`). Forward all of them, not
            // just the first, so that multi-argument printfs are not truncated.
            let num_operands = op.operation().rd().operands().vec().rd().len();
            for index in 1..num_operands {
                let var = op.operation().rd().operand(index);
                let var = var.expect("expected vararg");
                operation.set_operand(index, var);
            }
        }
        let typ = IntegerType::from_str("i32").unwrap();
        let result_type = Shared::new(typ.into());
        let result = operation.add_new_op_result(result_type);
        let mut op = llvm::CallOp::from_operation(operation);
        op.set_identifier("@printf".to_string());
        if set_varargs {
            let varargs = "!llvm.func<i32 (!llvm.ptr, ...)>";
            let varargs = llvm::FunctionType::from_str(varargs).unwrap();
            let varargs = Shared::new(varargs.into());
            op.set_varargs(Some(varargs));
        }
        let op = Shared::new(op.into());
        result.set_defining_op(Some(op.clone()));
        op
    }

    /// Follows `parent_op` links upward from `op` and returns the outermost op.
    ///
    /// # Panics
    /// Panics after `MAX_DEPTH` ancestors have been visited without reaching the
    /// top. This guards against a cyclic parent chain.
    fn top_level_op(op: Shared<dyn Op>) -> Shared<dyn Op> {
        const MAX_DEPTH: usize = 1000;
        let mut out = op;
        for _ in 0..MAX_DEPTH {
            // Bind first so the read guard on `out` is released before reassigning.
            let parent_op = out.rd().parent_op();
            match parent_op {
                Some(parent_op) => out = parent_op,
                None => return out,
            }
        }
        panic!("infinite loop");
    }

    /// Returns true if a direct child of `top_level_op` is an `llvm::FuncOp`
    /// whose identifier is `@printf`.
    fn contains_printf(top_level_op: Shared<dyn Op>) -> bool {
        let ops = top_level_op.rd().ops();
        for op in ops {
            let op = op.rd();
            if !op.is_func() {
                continue;
            }
            let func = match op.as_any().downcast_ref::<llvm::FuncOp>() {
                Some(func) => func,
                None => continue,
            };
            // Compare as `&str` to avoid allocating a `String` per function.
            if func.identifier().as_deref() == Some("@printf") {
                return true;
            }
        }
        false
    }

    /// Builds a private `llvm::FuncOp` declaration of `@printf` returning `i32`.
    ///
    /// The declaration has one anonymous pointer argument. When `set_varargs` is
    /// true, it also has a variadic marker as argument 1.
    ///
    /// # Errors
    /// Propagates the error from `set_anonymous_result`.
    fn printf_func_def(parent: Shared<Block>, set_varargs: bool) -> Result<Shared<dyn Op>> {
        let mut operation = Operation::default();
        operation.set_parent(Some(parent));
        let result_type = IntegerType::from_str("i32").unwrap();
        let result_type = Shared::new(result_type.into());
        operation.set_anonymous_result(result_type)?;
        let mut op = llvm::FuncOp::from_operation(operation);
        op.set_identifier("@printf".to_string());
        op.set_sym_visibility(Some("private".to_string()));
        {
            let arg_type = PointerType::new();
            let arg_type = Shared::new(arg_type.into());
            let name = BlockArgumentName::Anonymous;
            let argument = crate::ir::BlockArgument::new(name, arg_type);
            let value = Value::BlockArgument(argument);
            let value = Shared::new(value.into());
            let operation = op.operation();
            operation.wr().set_argument(0, value);
        }
        if set_varargs {
            let value = Value::Variadic;
            let value = Shared::new(value.into());
            op.operation().wr().set_argument(1, value);
        }
        let op = Shared::new(op.into());
        Ok(op)
    }

    /// Ensures `@printf` is declared in the top-level op that contains `op`.
    ///
    /// If the declaration is missing, it is inserted before the first child of
    /// the top-level op.
    ///
    /// NOTE(review): an existing declaration is reused as-is. Its variadic-ness
    /// therefore follows whichever printf was lowered first; confirm this is
    /// valid when printfs with and without arguments are mixed.
    ///
    /// # Errors
    /// Propagates errors from building the declaration.
    ///
    /// # Panics
    /// Panics if the top-level op has no children, or if its first child has no
    /// parent block.
    fn define_printf(op: Shared<dyn Op>, set_varargs: bool) -> Result<()> {
        let top_level_op = Self::top_level_op(op);
        if Self::contains_printf(top_level_op.clone()) {
            return Ok(());
        }
        let ops = top_level_op.rd().ops();
        let first = ops[0].clone();
        let parent = first.rd().operation().rd().parent().unwrap();
        first
            .rd()
            .insert_before(Self::printf_func_def(parent, set_varargs)?);
        Ok(())
    }
}
impl Rewrite for PrintLowering {
fn name(&self) -> &'static str {
"experimental_to_mlir::PrintLowering"
}
fn parallelizable(&self) -> bool {
false
}
fn rewrite(&self, op: Shared<dyn Op>) -> Result<RewriteResult> {
let op_rd = op.clone();
let op_rd = op_rd.rd();
let op_rd = match op_rd
.as_any()
.downcast_ref::<dialect::experimental::PrintfOp>()
{
Some(op_rd) => op_rd,
None => return Ok(RewriteResult::Unchanged),
};
let set_varargs = 1 < op.rd().operation().rd().operands().vec().rd().len();
let parent = op_rd.operation().rd().parent();
let parent = parent.expect("no parent");
let (text, len) = PrintLowering::text_constant(&parent, op_rd);
op_rd.insert_before(text.clone());
let len = PrintLowering::len_specifier(&parent, len);
op_rd.insert_before(len.clone());
let alloca = PrintLowering::alloca_op(&parent, len);
op_rd.insert_before(alloca.clone());
let store = PrintLowering::store_op(&parent, text.clone(), alloca.clone());
op_rd.insert_before(store);
PrintLowering::define_printf(op, set_varargs)?;
let call = PrintLowering::call_op(&parent, op_rd, alloca, set_varargs);
op_rd.insert_before(call.clone());
op_rd.remove();
Ok(RewriteResult::Changed(ChangedOp::new(text)))
}
}
/// Pass `convert-experimental-to-mlir`: applies the lowering rewrites for the
/// `experimental` dialect. Currently this is only the printf lowering
/// (`PrintLowering`).
pub struct ConvertExperimentalToMLIR;

impl Pass for ConvertExperimentalToMLIR {
    const NAME: &'static str = "convert-experimental-to-mlir";

    /// Runs every registered rewrite over `op` via `apply_rewrites`.
    fn convert(op: Shared<dyn Op>) -> Result<RewriteResult> {
        let mut patterns: Vec<&dyn Rewrite> = Vec::new();
        patterns.push(&PrintLowering);
        apply_rewrites(op, &patterns)
    }
}