pub mod builtin;
pub(super) mod external_function;
pub(super) mod memref;
pub mod mlir_data;
pub mod mlir_engine;
pub mod module;
pub mod passes;
pub(super) mod visitor;
use cubecl_common::backtrace::BackTrace;
use cubecl_runtime::compiler::CompilationError;
use passes::shared_memories::SharedMemories;
pub use visitor::elem::register_supported_types;
use cubecl_core::{
Compiler,
ir::{self, StorageType},
post_processing::{
checked_io::CheckedIoProcessor, predicate::PredicateProcessor,
saturating::SaturatingArithmeticProcessor,
},
prelude::KernelDefinition,
server::ExecutionMode,
};
use cubecl_opt::OptimizerBuilder;
use mlir_engine::MlirEngine;
use crate::compiler::passes::{
erf_transform::ErfTransform,
trigonometries_transform::{HypotTransform, RhypotTransform},
};
/// CubeCL `Compiler` implementation whose compiled representation is an `MlirEngine`.
///
/// The struct has no fields; all inputs are passed to `compile` directly.
#[derive(Clone, Debug, Default)]
pub struct MlirCompiler {}
/// Compilation options type associated with `MlirCompiler`
/// (its `Compiler::CompilationOptions`).
///
/// It currently has no fields, and `MlirCompiler::compile` ignores the value it receives.
#[derive(Default, Debug)]
pub struct MlirCompilerOptions {}
impl Compiler for MlirCompiler {
    type Representation = MlirEngine;
    type CompilationOptions = MlirCompilerOptions;

    /// Compiles a kernel definition into an `MlirEngine`.
    ///
    /// The kernel body is first checked for pending errors. If there are none, the body
    /// is run through the optimizer configured with the erf/hypot/rhypot transformers and
    /// the checked-IO, saturating-arithmetic and predicate processors. A `SharedMemories`
    /// pass then visits the optimizer output, and the engine is built from the kernel,
    /// the optimizer output, the shared memories and `addr_type`.
    ///
    /// `_compilation_options` is currently unused.
    ///
    /// # Errors
    ///
    /// Returns `CompilationError::Validation` when `kernel.body.pop_errors()` yields at
    /// least one error. The reason is a header line followed by one line per error.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        _compilation_options: &Self::CompilationOptions,
        mode: ExecutionMode,
        addr_type: StorageType,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            // End the header with a newline so the first error is not glued to it.
            let mut reason = String::from("Can't compile mlir kernel:\n");
            for error in errors {
                reason.push_str(error.as_str());
                reason.push('\n');
            }
            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        #[cfg(feature = "mlir-dump")]
        dump_scope(&kernel.body, &kernel.options.kernel_name);

        // The body is cloned because `kernel` itself is handed to the engine below.
        let opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(HypotTransform)
            .with_transformer(RhypotTransform)
            .with_processor(CheckedIoProcessor::new(
                mode,
                kernel.options.kernel_name.clone(),
            ))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .with_processor(PredicateProcessor)
            .optimize(kernel.body.clone(), kernel.cube_dim);

        let mut shared_memories = SharedMemories::default();
        shared_memories.visit(&opt);

        #[cfg(feature = "mlir-dump")]
        dump_opt(&opt, &kernel.options.kernel_name);

        Ok(MlirEngine::from_cubecl_ir(
            kernel,
            &opt,
            shared_memories,
            addr_type,
        ))
    }

    /// Returns the size of `elem` as reported by `ElemType::size`.
    fn elem_size(&self, elem: ir::ElemType) -> usize {
        elem.size()
    }

    /// Returns the extension string associated with this compiler's output: `"mlir"`.
    fn extension(&self) -> &'static str {
        "mlir"
    }
}
/// Writes the `Display` output of `scope` to `$CUBECL_DEBUG_MLIR/<name>/cubecl.ir.txt`.
///
/// Does nothing when the `CUBECL_DEBUG_MLIR` environment variable cannot be read
/// (unset or not valid Unicode).
///
/// # Panics
///
/// Panics if the output directory cannot be created or the file cannot be written.
#[cfg(feature = "mlir-dump")]
fn dump_scope(scope: &cubecl_core::prelude::Scope, name: &str) {
    use std::{fs, path::PathBuf};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = PathBuf::from(format!("{dir}/{name}"));
        // `create_dir_all` also creates missing parent directories and succeeds when
        // the directory already exists, so a failure here is a real error.
        fs::create_dir_all(&path).unwrap_or_else(|err| {
            panic!(
                "failed to create MLIR dump directory {}: {err}",
                path.display()
            )
        });

        let file = path.join("cubecl.ir.txt");
        fs::write(&file, scope.to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", file.display()));
    }
}
/// Writes two views of `opt` into `$CUBECL_DEBUG_MLIR/<name>/`:
/// its `Display` output as `cubecl-opt.ir.txt` and the output of `dot_viz()` as
/// `cubecl-opt.ir.dot`.
///
/// Does nothing when the `CUBECL_DEBUG_MLIR` environment variable cannot be read
/// (unset or not valid Unicode).
///
/// # Panics
///
/// Panics if the output directory cannot be created or either file cannot be written.
#[cfg(feature = "mlir-dump")]
fn dump_opt(opt: &cubecl_opt::Optimizer, name: &str) {
    use std::{fs, path::PathBuf};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = PathBuf::from(format!("{dir}/{name}"));
        // `create_dir_all` also creates missing parent directories and succeeds when
        // the directory already exists, so a failure here is a real error.
        fs::create_dir_all(&path).unwrap_or_else(|err| {
            panic!(
                "failed to create MLIR dump directory {}: {err}",
                path.display()
            )
        });

        let text_file = path.join("cubecl-opt.ir.txt");
        fs::write(&text_file, opt.to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", text_file.display()));

        let dot_file = path.join("cubecl-opt.ir.dot");
        fs::write(&dot_file, opt.dot_viz().to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", dot_file.display()));
    }
}