// cubecl-cpu 0.10.0-pre.3
//
// CPU runtime for CubeCL
pub mod builtin;
pub(super) mod external_function;
pub(super) mod memref;
pub mod mlir_data;
pub mod mlir_engine;
pub mod module;
pub mod passes;
pub(super) mod visitor;

use cubecl_common::backtrace::BackTrace;
use cubecl_runtime::compiler::CompilationError;
use passes::shared_memories::SharedMemories;
pub use visitor::elem::register_supported_types;

use cubecl_core::{
    Compiler,
    ir::{self, StorageType},
    post_processing::{
        checked_io::CheckedIoProcessor, predicate::PredicateProcessor,
        saturating::SaturatingArithmeticProcessor,
    },
    prelude::KernelDefinition,
    server::ExecutionMode,
};
use cubecl_opt::OptimizerBuilder;
use mlir_engine::MlirEngine;

use crate::compiler::passes::{
    erf_transform::ErfTransform,
    trigonometries_transform::{HypotTransform, RhypotTransform},
};

/// Stateless [`Compiler`] implementation that turns a CubeCL [`KernelDefinition`] into an
/// [`MlirEngine`].
///
/// The struct has no fields; it can be created with `MlirCompiler::default()`.
#[derive(Clone, Debug, Default)]
pub struct MlirCompiler {}

/// Options type used as `CompilationOptions` by [`MlirCompiler`].
///
/// It has no fields, and `MlirCompiler::compile` does not read it yet.
#[derive(Default, Debug)]
pub struct MlirCompilerOptions {}

impl Compiler for MlirCompiler {
    /// Compiled artifact returned by [`Compiler::compile`].
    type Representation = MlirEngine;

    /// Options accepted by [`Compiler::compile`]; currently an empty struct.
    type CompilationOptions = MlirCompilerOptions;

    /// Compiles a kernel definition into an [`MlirEngine`].
    ///
    /// Steps, in order:
    /// 1. Errors recorded on the kernel body are popped; if there are any, compilation stops.
    /// 2. The body is run through the optimizer with the erf/hypot/rhypot transformers and the
    ///    checked-io, saturating-arithmetic and predicate processors.
    /// 3. Shared memories are collected by visiting the optimized kernel.
    /// 4. The engine is built from the kernel, the optimizer output, the shared memories and
    ///    `addr_type`.
    ///
    /// With the `mlir-dump` feature enabled, the input scope and the optimizer output are also
    /// dumped to disk.
    ///
    /// # Errors
    ///
    /// Returns [`CompilationError::Validation`] when the kernel body carries errors. The reason
    /// is a header line followed by one line per error.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        _compilation_options: &Self::CompilationOptions, // TODO pass this through the visitor, though it doesn't need anything for the moment
        mode: ExecutionMode, // TODO support this by adding array bound checking
        addr_type: StorageType,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = kernel.body.pop_errors();
        if !errors.is_empty() {
            // The header ends with a newline so the first error is not glued onto it.
            let mut reason = String::from("Can't compile mlir kernel\n");
            for error in errors {
                reason.push_str(error.as_str());
                reason.push('\n');
            }

            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        #[cfg(feature = "mlir-dump")]
        dump_scope(&kernel.body, &kernel.options.kernel_name);

        // `kernel` is handed over by value further down, so the optimizer gets its own copy of
        // the body. `mode` is forwarded to the checked-io processor.
        let opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(HypotTransform)
            .with_transformer(RhypotTransform)
            .with_processor(CheckedIoProcessor::new(
                mode,
                kernel.options.kernel_name.clone(),
            ))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .with_processor(PredicateProcessor)
            .optimize(kernel.body.clone(), kernel.cube_dim);

        let mut shared_memories = SharedMemories::default();
        shared_memories.visit(&opt);

        #[cfg(feature = "mlir-dump")]
        dump_opt(&opt, &kernel.options.kernel_name);

        Ok(MlirEngine::from_cubecl_ir(
            kernel,
            &opt,
            shared_memories,
            addr_type,
        ))
    }

    /// Returns the size of `elem` as reported by `ElemType::size`.
    fn elem_size(&self, elem: ir::ElemType) -> usize {
        elem.size()
    }

    /// Returns the extension string associated with this compiler's output: `"mlir"`.
    fn extension(&self) -> &'static str {
        "mlir"
    }
}

#[cfg(feature = "mlir-dump")]
/// Writes the `Display` form of `scope` to `<dir>/<name>/cubecl.ir.txt`, where `<dir>` is the
/// value of the `CUBECL_DEBUG_MLIR` environment variable.
///
/// Does nothing when `CUBECL_DEBUG_MLIR` cannot be read (unset or not valid Unicode).
///
/// # Panics
///
/// Panics if the dump directory cannot be created or the file cannot be written. The panic
/// message names the offending path.
fn dump_scope(scope: &cubecl_core::prelude::Scope, name: &str) {
    use std::{fs, path::PathBuf};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = PathBuf::from(dir).join(name);
        // `create_dir_all` also creates missing parents and succeeds if the directory already
        // exists, so a real failure can be reported instead of being discarded.
        fs::create_dir_all(&path).unwrap_or_else(|err| {
            panic!(
                "failed to create MLIR dump directory {}: {err}",
                path.display()
            )
        });

        let file = path.join("cubecl.ir.txt");
        fs::write(&file, scope.to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", file.display()));
    }
}

#[cfg(feature = "mlir-dump")]
/// Dumps the optimizer state into `<dir>/<name>/`, where `<dir>` is the value of the
/// `CUBECL_DEBUG_MLIR` environment variable.
///
/// Two files are written:
/// - `cubecl-opt.ir.txt`: the `Display` form of `opt`;
/// - `cubecl-opt.ir.dot`: the `Display` form of `opt.dot_viz()`.
///
/// Does nothing when `CUBECL_DEBUG_MLIR` cannot be read (unset or not valid Unicode).
///
/// # Panics
///
/// Panics if the dump directory cannot be created or either file cannot be written. The panic
/// message names the offending path.
fn dump_opt(opt: &cubecl_opt::Optimizer, name: &str) {
    use std::{fs, path::PathBuf};

    if let Ok(dir) = std::env::var("CUBECL_DEBUG_MLIR") {
        let path = PathBuf::from(dir).join(name);
        // `create_dir_all` also creates missing parents and succeeds if the directory already
        // exists, so a real failure can be reported instead of being discarded.
        fs::create_dir_all(&path).unwrap_or_else(|err| {
            panic!(
                "failed to create MLIR dump directory {}: {err}",
                path.display()
            )
        });

        let ir_file = path.join("cubecl-opt.ir.txt");
        fs::write(&ir_file, opt.to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", ir_file.display()));

        let dot_file = path.join("cubecl-opt.ir.dot");
        fs::write(&dot_file, opt.dot_viz().to_string())
            .unwrap_or_else(|err| panic!("failed to write {}: {err}", dot_file.display()));
    }
}