use cubecl_core::{ir::StorageType, prelude::KernelDefinition};
use cubecl_opt::Optimizer;
use tracel_llvm::mlir_rs::{
Context, ExecutionEngine,
ir::{Location, operation::OperationLike},
pass::{self, PassManager},
};
use super::{passes::shared_memories::SharedMemories, visitor::Visitor};
pub(super) struct Module<'a> {
module: tracel_llvm::mlir_rs::ir::Module<'a>,
#[allow(unused)]
name: String,
location: Location<'a>,
context: &'a Context,
}
impl<'a> Module<'a> {
pub(super) fn new(context: &'a Context, name: String) -> Self {
let location = Location::unknown(context);
let module = tracel_llvm::mlir_rs::ir::Module::new(location);
Self {
module,
context,
name,
location,
}
}
pub(super) fn visit_kernel(
&mut self,
kernel: &KernelDefinition,
opt: &Optimizer,
shared_memories: &SharedMemories,
addr_type: StorageType,
) {
Visitor::visit_kernel(
self.context,
self.location,
kernel,
&self.module,
opt,
shared_memories,
addr_type,
)
}
pub(super) fn run_pass(&mut self) {
let pass_manager = PassManager::new(self.context);
pass_manager.enable_verifier(true);
#[cfg(feature = "mlir-dump")]
if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
use std::path::PathBuf;
use tracel_llvm::mlir_rs::{
ir::operation::OperationPrintingFlags, pass::PassIrPrintingOptions,
};
let dir = dir.to_string() + "/" + &self.name;
pass_manager.enable_ir_printing(&PassIrPrintingOptions {
before_all: true,
after_all: true,
module_scope: true,
on_change: true,
on_failure: true,
flags: OperationPrintingFlags::new(),
tree_printing_path: PathBuf::from(dir),
});
}
pass_manager.add_pass(pass::transform::create_canonicalizer());
pass_manager.add_pass(pass::conversion::create_math_to_libm());
pass_manager.add_pass(pass::conversion::create_scf_to_control_flow());
pass_manager.add_pass(pass::conversion::create_finalize_mem_ref_to_llvm());
pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
let func_passes = pass_manager.nested_under("func.func");
func_passes.add_pass(pass::conversion::create_index_to_llvm());
func_passes.add_pass(pass::conversion::create_math_to_llvm());
func_passes.add_pass(pass::conversion::create_vector_to_llvm());
func_passes.add_pass(pass::conversion::create_arith_to_llvm());
pass_manager.add_pass(pass::conversion::create_control_flow_to_llvm());
pass_manager.add_pass(pass::conversion::create_func_to_llvm());
pass_manager.add_pass(pass::transform::create_inliner());
pass_manager.add_pass(pass::conversion::create_reconcile_unrealized_casts());
pass_manager.add_pass(pass::transform::create_sccp());
pass_manager.add_pass(pass::transform::create_mem_2_reg());
pass_manager.add_pass(pass::transform::create_control_flow_sink());
pass_manager.add_pass(pass::transform::create_cse());
if let Err(err) = pass_manager.run(&mut self.module) {
panic!("{}", err);
}
self.module.as_operation().verify();
}
pub(super) fn into_execution_engine(self) -> ExecutionEngine {
ExecutionEngine::new(&self.module, 0, &[], true)
}
}
#[cfg(test)]
mod tests {
use tracel_llvm::mlir_rs::{
Context,
dialect::{arith, func, index, memref, scf},
ir::{
Block, Region, RegionLike, Type,
attribute::{IntegerAttribute, StringAttribute, TypeAttribute},
block::BlockLike,
operation::OperationLike,
r#type::{FunctionType, IntegerType, MemRefType},
},
utility::{register_all_dialects, register_all_llvm_translations, register_all_passes},
};
#[test]
fn it_lowers_scf_if_memref_result_through_cf_to_llvm() {
register_all_passes();
let registry = tracel_llvm::mlir_rs::dialect::DialectRegistry::new();
register_all_dialects(®istry);
let context = Context::new();
register_all_llvm_translations(&context);
context.enable_multi_threading(false);
context.append_dialect_registry(®istry);
context.load_all_available_dialects();
let mut module = super::Module::new(&context, "scf_if_memref_result".to_string());
let location = module.location;
let memref_ty = MemRefType::new(Type::bfloat16(&context), &[8], None, None);
let i32_ty: Type = IntegerType::new(&context, 32).into();
module.module.body().append_operation(func::func(
&context,
StringAttribute::new(&context, "kernel"),
TypeAttribute::new(FunctionType::new(&context, &[], &[]).into()),
{
let block = Block::new(&[]);
let c0 = block.append_operation(arith::constant(
&context,
IntegerAttribute::new(i32_ty, 0).into(),
location,
));
let c1 = block.append_operation(arith::constant(
&context,
IntegerAttribute::new(i32_ty, 1).into(),
location,
));
let cond = block.append_operation(arith::cmpi(
&context,
arith::CmpiPredicate::Ne,
c0.result(0).unwrap().into(),
c1.result(0).unwrap().into(),
location,
));
let a = block.append_operation(memref::alloca(
&context,
memref_ty,
&[],
&[],
None,
location,
));
let b = block.append_operation(memref::alloca(
&context,
memref_ty,
&[],
&[],
None,
location,
));
let then_region = {
let then_block = Block::new(&[]);
then_block
.append_operation(scf::r#yield(&[a.result(0).unwrap().into()], location));
let region = Region::new();
region.append_block(then_block);
region
};
let else_region = {
let else_block = Block::new(&[]);
else_block
.append_operation(scf::r#yield(&[b.result(0).unwrap().into()], location));
let region = Region::new();
region.append_block(else_block);
region
};
let if_op = block.append_operation(scf::r#if(
cond.result(0).unwrap().into(),
&[memref_ty.into()],
then_region,
else_region,
location,
));
let idx0 = block.append_operation(index::constant(
&context,
IntegerAttribute::new(IntegerType::new(&context, 64).into(), 0),
location,
));
let _ = block.append_operation(memref::load(
if_op.result(0).unwrap().into(),
&[idx0.result(0).unwrap().into()],
location,
));
block.append_operation(func::r#return(&[], location));
let region = Region::new();
region.append_block(block);
region
},
&[],
location,
));
module.run_pass();
assert!(module.module.as_operation().verify());
}
}