// cubecl-cpu 0.10.0-pre.4
//
// CPU runtime for CubeCL
use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Operator, StorageType, VariableKind};
use tracel_llvm::mlir_rs::{
    dialect::{
        arith, index, memref,
        ods::{self, llvm, vector},
    },
    ir::{Operation, r#type::IntegerType},
};

use crate::compiler::visitor::prelude::*;

impl<'a> Visitor<'a> {
    /// Lowers `operator` to MLIR operations appended to the current block.
    ///
    /// Value-producing operators register their result under `out` through
    /// `insert_variable`. `CopyMemory` and the index-assign variants instead emit stores
    /// into the memory obtained for `out`.
    ///
    /// # Panics
    ///
    /// Panics on `Operator::CopyMemoryBulk`, which is not implemented.
    pub fn visit_operator_with_out(&mut self, operator: &Operator, out: Variable) {
        match operator {
            Operator::And(op) => {
                let left = self.get_variable(op.lhs);
                let right = self.get_variable(op.rhs);
                let and_op = arith::andi(left, right, self.location);
                let result = self.append_operation_with_result(and_op);
                self.insert_variable(out, result);
            }
            Operator::Cast(cast) => self.visit_cast(cast.input, out),
            Operator::CopyMemory(copy) => {
                let source = self.get_memory(copy.input);
                let source_index = self.get_index(
                    copy.in_index,
                    copy.input.ty,
                    copy.input.ty.is_vectorized(),
                );
                let destination = self.get_memory(out);
                let destination_index =
                    self.get_index(copy.out_index, out.ty, out.ty.is_vectorized());

                let source_indices = [source_index];
                let destination_indices = [destination_index];
                if out.ty.is_vectorized() {
                    // Vectorized output: move a whole vector with one load/store pair.
                    let loaded_type = out.ty.to_type(self.context);
                    let loaded = self.append_operation_with_result(vector::load(
                        self.context,
                        loaded_type,
                        source,
                        &source_indices,
                        self.location,
                    ));
                    self.block.append_operation(vector::store(
                        self.context,
                        loaded,
                        destination,
                        &destination_indices,
                        self.location,
                    ));
                } else {
                    // Scalar output: plain memref load followed by a memref store.
                    let loaded = self.append_operation_with_result(memref::load(
                        source,
                        &source_indices,
                        self.location,
                    ));
                    self.block.append_operation(memref::store(
                        loaded,
                        destination,
                        &destination_indices,
                        self.location,
                    ));
                }
            }
            Operator::CopyMemoryBulk(_) => {
                todo!("copy_memory_bulk is not implemented {}", operator)
            }
            Operator::Index(access) | Operator::UncheckedIndex(access) => {
                let loaded = self.visit_index(access, out);
                self.insert_variable(out, loaded);
            }
            Operator::IndexAssign(assign) | Operator::UncheckedIndexAssign(assign) => {
                self.visit_index_assign(assign, out)
            }
            Operator::InitVector(init) => {
                let elements: Vec<_> = init
                    .inputs
                    .iter()
                    .copied()
                    .map(|input| self.get_variable(input))
                    .collect();
                let vector_type = out.ty.to_type(self.context);
                let assembled = self.append_operation_with_result(vector::from_elements(
                    self.context,
                    vector_type,
                    &elements,
                    self.location,
                ));
                self.insert_variable(out, assembled);
            }
            Operator::Not(op) => {
                let operand = self.get_variable(op.input);
                // NOT is emitted as XOR with a `-1` constant of the input's type.
                let minus_one = self.create_int_constant_from_item(op.input.ty, -1);
                let xor_op = arith::xori(operand, minus_one, self.location);
                let result = self.append_operation_with_result(xor_op);
                self.insert_variable(out, result);
            }
            Operator::Or(op) => {
                let left = self.get_variable(op.lhs);
                let right = self.get_variable(op.rhs);
                let or_op = arith::ori(left, right, self.location);
                let result = self.append_operation_with_result(or_op);
                self.insert_variable(out, result);
            }
            Operator::Reinterpret(reinterpret) => {
                let reinterpreted_type = out.ty.to_type(self.context);
                let source = self.get_variable(reinterpret.input);
                let bitcast_op = arith::bitcast(source, reinterpreted_type, self.location);
                let result = self.append_operation_with_result(bitcast_op);
                self.insert_variable(out, result);
            }
            Operator::Select(select) => {
                let raw_condition = self.get_variable(select.cond);
                let condition = self.cast_to_bool(raw_condition, select.cond.ty);
                let then_value = self.get_variable(select.then);
                let else_value = self.get_variable(select.or_else);
                let out_is_vector = out.ty.is_vectorized();

                // A non-vectorized branch feeding a vectorized output is splatted to the
                // output's vector size first.
                let then_value = if out_is_vector && !select.then.ty.is_vectorized() {
                    let broadcast_type = Type::vector(
                        &[out.vector_size() as u64],
                        select.then.storage_type().to_type(self.context),
                    );
                    self.append_operation_with_result(vector::splat(
                        self.context,
                        broadcast_type,
                        then_value,
                        self.location,
                    ))
                } else {
                    then_value
                };
                let else_value = if out_is_vector && !select.or_else.ty.is_vectorized() {
                    let broadcast_type = Type::vector(
                        &[out.vector_size() as u64],
                        select.or_else.storage_type().to_type(self.context),
                    );
                    self.append_operation_with_result(vector::splat(
                        self.context,
                        broadcast_type,
                        else_value,
                        self.location,
                    ))
                } else {
                    else_value
                };

                let select_op = arith::select(condition, then_value, else_value, self.location);
                let result = self.append_operation_with_result(select_op);
                self.insert_variable(out, result);
            }
        }
    }

    /// Reads `index.list` at `index.index` and returns the resulting value.
    ///
    /// * If `index.list` is backed by memory, a `vector::load` is emitted when `out` is
    ///   vectorized and a `memref::load` otherwise.
    /// * Otherwise the variable's value is returned unchanged when it is not an MLIR
    ///   vector, or a single element is extracted from it with `llvm::extractelement`.
    ///
    /// # Panics
    ///
    /// Panics if `index.vector_size` is not `0`.
    fn visit_index(&mut self, index: &IndexOperator, out: Variable) -> Value<'a, 'a> {
        assert!(index.vector_size == 0);
        let position = self.get_index(index.index, out.ty, index.list.ty.is_vectorized());

        if self.is_memory(index.list) {
            return if out.ty.is_vectorized() {
                let loaded_type = Type::vector(
                    &[out.vector_size() as u64],
                    index.list.storage_type().to_type(self.context),
                );
                let source = self.get_memory(index.list);
                self.append_operation_with_result(vector::load(
                    self.context,
                    loaded_type,
                    source,
                    &[position],
                    self.location,
                ))
            } else {
                let source = self.get_memory(index.list);
                self.append_operation_with_result(memref::load(source, &[position], self.location))
            };
        }

        let container = self.get_variable(index.list);
        // Item of size 1: there is nothing to extract from.
        if !container.r#type().is_vector() {
            return container;
        }

        let element_type = index.list.storage_type().to_type(self.context);
        // An `index`-typed position is cast to a 32-bit integer before the extraction.
        let position = if position.r#type().is_index() {
            let i32_type = IntegerType::new(self.context, 32).into();
            self.append_operation_with_result(index::casts(position, i32_type, self.location))
        } else {
            position
        };
        let extract_op = llvm::extractelement(
            self.context,
            element_type,
            container,
            position,
            self.location,
        );
        self.append_operation_with_result(extract_op)
    }

    /// Stores `index_assign.value` into the memory obtained for `out` at
    /// `index_assign.index`.
    ///
    /// * When `out` is a `LocalMut`/`LocalConst` variable, or the value is vectorized, the
    ///   index is computed from the value's type and the value is written with
    ///   `vector::store` (vectorized value) or `memref::store` (scalar value).
    /// * Otherwise (scalar value, non-local destination) the value is splatted to a vector
    ///   of `out.vector_size()` elements and written with `vector::store`, using an index
    ///   computed from `out`'s type.
    ///
    /// # Panics
    ///
    /// Panics if `index_assign.vector_size` is not `0`.
    fn visit_index_assign(&mut self, index_assign: &IndexAssignOperator, out: Variable) {
        assert!(index_assign.vector_size == 0);
        let stored = self.get_variable(index_assign.value);
        let destination = self.get_memory(out);

        let value_is_vector = index_assign.value.ty.is_vectorized();
        let out_is_local = matches!(
            out.kind,
            VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. }
        );

        if out_is_local || value_is_vector {
            let position = self.get_index(
                index_assign.index,
                index_assign.value.ty,
                out.ty.is_vectorized(),
            );
            let store_op = if value_is_vector {
                vector::store(self.context, stored, destination, &[position], self.location).into()
            } else {
                memref::store(stored, destination, &[position], self.location)
            };
            self.block.append_operation(store_op);
        } else {
            let broadcast_type = Type::vector(
                &[out.vector_size() as u64],
                index_assign.value.storage_type().to_type(self.context),
            );
            let position = self.get_index(index_assign.index, out.ty, out.ty.is_vectorized());
            let broadcast = self.append_operation_with_result(vector::splat(
                self.context,
                broadcast_type,
                stored,
                self.location,
            ));
            self.block.append_operation(vector::store(
                self.context,
                broadcast,
                destination,
                &[position],
                self.location,
            ));
        }
    }

    /// Converts `to_cast` to the type of `out` and registers the result under `out`.
    ///
    /// A non-vectorized input feeding a vectorized output is first splatted to
    /// `out.vector_size()` elements of the input's storage type. The conversion is then
    /// delegated to `get_cast_same_type_category` when both storage types agree on
    /// `is_int()`, and to `get_cast_different_type_category` otherwise.
    pub(crate) fn visit_cast(&mut self, to_cast: Variable, out: Variable) {
        let input = self.get_variable(to_cast);
        let target = out.ty.to_type(self.context);
        let from = to_cast.storage_type();
        let to = out.storage_type();

        let needs_broadcast = out.ty.is_vectorized() && !to_cast.ty.is_vectorized();
        let input = if needs_broadcast {
            let element_type = from.to_type(self.context);
            let broadcast_type = Type::vector(&[out.vector_size() as u64], element_type);
            self.append_operation_with_result(vector::splat(
                self.context,
                broadcast_type,
                input,
                self.location,
            ))
        } else {
            input
        };

        let converted = if from.is_int() != to.is_int() {
            self.get_cast_different_type_category(from, to, target, input)
        } else {
            self.get_cast_same_type_category(from, to, target, input)
        };
        self.insert_variable(out, converted);
    }

    /// Emits a conversion between an integer and a non-integer storage type.
    ///
    /// When `to_cast.is_int()` holds the op comes from `cast_int_to_float`; otherwise it
    /// comes from `cast_float_to_int`. The op is appended and its result returned.
    pub(crate) fn get_cast_different_type_category(
        &self,
        to_cast: StorageType,
        out: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Value<'a, 'a> {
        let conversion = if to_cast.is_int() {
            self.cast_int_to_float(to_cast, target, value)
        } else {
            self.cast_float_to_int(out, target, value)
        };
        self.append_operation_with_result(conversion)
    }

    /// Builds (without appending) the op converting `value` to the integer type `target`.
    ///
    /// Uses `arith::fptosi` when `out.is_signed_int()` holds and `arith::fptoui` otherwise.
    fn cast_float_to_int(
        &self,
        out: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Operation<'a> {
        let location = self.location;
        if !out.is_signed_int() {
            return arith::fptoui(value, target, location);
        }
        arith::fptosi(value, target, location)
    }

    /// Builds (without appending) the op converting the integer `value` to `target`.
    ///
    /// Uses `arith::sitofp` when `to_cast.is_signed_int()` holds and `arith::uitofp`
    /// otherwise.
    fn cast_int_to_float(
        &self,
        to_cast: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Operation<'a> {
        let location = self.location;
        if !to_cast.is_signed_int() {
            return arith::uitofp(value, target, location);
        }
        arith::sitofp(value, target, location)
    }

    /// Converts `value` between two storage types of the same category by comparing
    /// their sizes.
    ///
    /// A larger source is truncated (`get_trunc`), a smaller source is extended
    /// (`get_ext`), and equal sizes return `value` untouched without emitting anything.
    fn get_cast_same_type_category(
        &self,
        to_cast: StorageType,
        out: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Value<'a, 'a> {
        use std::cmp::Ordering;

        match to_cast.size().cmp(&out.size()) {
            Ordering::Greater => {
                self.append_operation_with_result(self.get_trunc(to_cast, target, value))
            }
            Ordering::Less => {
                self.append_operation_with_result(self.get_ext(to_cast, target, value))
            }
            Ordering::Equal => value,
        }
    }

    /// Builds (without appending) the op narrowing `value` to `target`.
    ///
    /// Uses `arith::trunci` when `to_cast.is_int()` holds and `arith::truncf` otherwise.
    fn get_trunc(
        &self,
        to_cast: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Operation<'a> {
        if !to_cast.is_int() {
            return ods::arith::truncf(self.context, target, value, self.location).into();
        }
        arith::trunci(value, target, self.location)
    }

    /// Builds (without appending) the op widening `value` to `target`.
    ///
    /// Signed integers use `arith::extsi`, unsigned integers use `arith::extui`, and
    /// every other storage type uses `arith::extf`.
    fn get_ext(
        &self,
        to_cast: StorageType,
        target: Type<'a>,
        value: Value<'a, 'a>,
    ) -> Operation<'a> {
        let location = self.location;
        if to_cast.is_signed_int() {
            return arith::extsi(value, target, location);
        }
        if to_cast.is_unsigned_int() {
            return arith::extui(value, target, location);
        }
        arith::extf(value, target, location)
    }
}