cubecl-cpu 0.10.0-pre.3

CPU runtime for CubeCL
use cubecl_core::ir::{self, ConstantValue, VariableKind};
use tracel_llvm::mlir_rs::{
    Context,
    dialect::{arith, ods::vector},
    ir::{
        Attribute, Type, Value,
        attribute::{FloatAttribute, IntegerAttribute},
        r#type::IntegerType,
    },
};

use super::prelude::*;

impl IntoType for ir::Type {
    /// Lowers this CubeCL type to an MLIR type.
    ///
    /// The element is the lowered storage type; when the type has more than one
    /// lane, it is wrapped in a one-dimensional MLIR vector of that many lanes.
    fn to_type<'a>(self, context: &'a Context) -> Type<'a> {
        let element = self.storage_type().to_type(context);
        let lanes = self.vector_size();
        if lanes <= 1 {
            element
        } else {
            Type::vector(&[lanes as u64], element)
        }
    }

    /// Returns `true` when this type carries more than one lane.
    fn is_vectorized(&self) -> bool {
        let lanes = self.vector_size();
        lanes > 1
    }
}

impl<'a> Visitor<'a> {
    pub fn into_attribute(
        context: &'a Context,
        var: Variable,
        item: ir::Type,
    ) -> Option<Attribute<'a>> {
        let r#type = item.storage_type().to_type(context);
        match var.kind {
            VariableKind::Constant(ConstantValue::Float(float)) => {
                if item.is_float() {
                    Some(FloatAttribute::new(context, r#type, float).into())
                } else {
                    Some(IntegerAttribute::new(r#type, float as i64).into())
                }
            }
            VariableKind::Constant(ConstantValue::Bool(bool)) => {
                if item.is_float() {
                    Some(FloatAttribute::new(context, r#type, bool as i64 as f64).into())
                } else {
                    Some(IntegerAttribute::new(r#type, bool as i64).into())
                }
            }
            VariableKind::Constant(ConstantValue::Int(int)) => {
                if item.is_float() {
                    Some(FloatAttribute::new(context, r#type, int as f64).into())
                } else {
                    Some(IntegerAttribute::new(r#type, int).into())
                }
            }
            VariableKind::Constant(ConstantValue::UInt(u_int)) => {
                if item.is_float() {
                    Some(FloatAttribute::new(context, r#type, u_int as f64).into())
                } else {
                    Some(IntegerAttribute::new(r#type, u_int as i64).into())
                }
            }
            _ => None,
        }
    }

    pub fn create_float_constant_from_item(&self, item: ir::Type, constant: f64) -> Value<'a, 'a> {
        let float = item.storage_type().to_type(self.context);
        let constant = FloatAttribute::new(self.context, float, constant);
        let constant = self.append_operation_with_result(arith::constant(
            self.context,
            constant.into(),
            self.location,
        ));
        let result_type = item.to_type(self.context);
        match item.is_vectorized() {
            true => self.append_operation_with_result(vector::splat(
                self.context,
                result_type,
                constant,
                self.location,
            )),
            false => constant,
        }
    }

    pub fn create_int_constant_from_item(&self, item: ir::Type, constant: i64) -> Value<'a, 'a> {
        let integer = item.storage_type().to_type(self.context);
        let constant = IntegerAttribute::new(integer, constant);
        let constant = self.append_operation_with_result(arith::constant(
            self.context,
            constant.into(),
            self.location,
        ));
        let result_type = item.to_type(self.context);
        match item.is_vectorized() {
            true => self.append_operation_with_result(vector::splat(
                self.context,
                result_type,
                constant,
                self.location,
            )),
            false => constant,
        }
    }

    pub fn cast_to_bool(&self, value: Value<'a, 'a>, item: ir::Type) -> Value<'a, 'a> {
        let mut bool = IntegerType::new(self.context, 1).into();
        if item.is_vectorized() {
            bool = Type::vector(&[item.vector_size() as u64], bool);
        }
        self.append_operation_with_result(arith::trunci(value, bool, self.location))
    }

    pub fn cast_to_u8(&self, value: Value<'a, 'a>, item: ir::Type) -> Value<'a, 'a> {
        let mut byte = IntegerType::new(self.context, 8).into();
        if item.is_vectorized() {
            byte = Type::vector(&[item.vector_size() as u64], byte);
        }
        self.append_operation_with_result(arith::extui(value, byte, self.location))
    }
}