use cubecl_core::ir::{IndexAssignOperator, IndexOperator, Operator, StorageType, VariableKind};
use tracel_llvm::mlir_rs::{
dialect::{
arith, index, memref,
ods::{self, llvm, vector},
},
ir::{Operation, r#type::IntegerType},
};
use crate::compiler::visitor::prelude::*;
impl<'a> Visitor<'a> {
pub fn visit_operator_with_out(&mut self, operator: &Operator, out: Variable) {
match operator {
Operator::And(and) => {
let lhs = self.get_variable(and.lhs);
let rhs = self.get_variable(and.rhs);
let value = self.append_operation_with_result(arith::andi(lhs, rhs, self.location));
self.insert_variable(out, value);
}
Operator::Cast(cast) => {
self.visit_cast(cast.input, out);
}
Operator::CopyMemory(copy_memory) => {
let memref = self.get_memory(copy_memory.input);
let in_index = self.get_index(
copy_memory.in_index,
copy_memory.input.ty,
copy_memory.input.ty.is_vectorized(),
);
let out_memref = self.get_memory(out);
let out_index =
self.get_index(copy_memory.out_index, out.ty, out.ty.is_vectorized());
if out.ty.is_vectorized() {
let result = out.ty.to_type(self.context);
let value = self.append_operation_with_result(vector::load(
self.context,
result,
memref,
&[in_index],
self.location,
));
self.block.append_operation(vector::store(
self.context,
value,
out_memref,
&[out_index],
self.location,
));
} else {
let value = self.append_operation_with_result(memref::load(
memref,
&[in_index],
self.location,
));
self.block.append_operation(memref::store(
value,
out_memref,
&[out_index],
self.location,
));
}
}
Operator::CopyMemoryBulk(_copy_memory_bulk) => {
todo!("copy_memory_bulk is not implemented {}", operator)
}
Operator::Index(index) | Operator::UncheckedIndex(index) => {
let load_ssa = self.visit_index(index, out);
self.insert_variable(out, load_ssa);
}
Operator::IndexAssign(index_assign) | Operator::UncheckedIndexAssign(index_assign) => {
self.visit_index_assign(index_assign, out)
}
Operator::InitVector(init_vector) => {
let inputs: Vec<_> = init_vector
.inputs
.iter()
.map(|input| self.get_variable(*input))
.collect();
let result = out.ty.to_type(self.context);
let init_vector = self.append_operation_with_result(vector::from_elements(
self.context,
result,
&inputs,
self.location,
));
self.insert_variable(out, init_vector);
}
Operator::Not(not) => {
let lhs = self.get_variable(not.input);
let mask = self.create_int_constant_from_item(not.input.ty, -1);
let value =
self.append_operation_with_result(arith::xori(lhs, mask, self.location));
self.insert_variable(out, value);
}
Operator::Or(or) => {
let lhs = self.get_variable(or.lhs);
let rhs = self.get_variable(or.rhs);
let value = self.append_operation_with_result(arith::ori(lhs, rhs, self.location));
self.insert_variable(out, value);
}
Operator::Reinterpret(reinterpret) => {
let target_type = out.ty.to_type(self.context);
let input = self.get_variable(reinterpret.input);
let value = self.append_operation_with_result(arith::bitcast(
input,
target_type,
self.location,
));
self.insert_variable(out, value);
}
Operator::Select(select) => {
let condition = self.get_variable(select.cond);
let condition = self.cast_to_bool(condition, select.cond.ty);
let mut then = self.get_variable(select.then);
let mut or_else = self.get_variable(select.or_else);
if out.ty.is_vectorized() && !select.then.ty.is_vectorized() {
let vector = Type::vector(
&[out.vector_size() as u64],
select.then.storage_type().to_type(self.context),
);
then = self.append_operation_with_result(vector::splat(
self.context,
vector,
then,
self.location,
));
}
if out.ty.is_vectorized() && !select.or_else.ty.is_vectorized() {
let vector = Type::vector(
&[out.vector_size() as u64],
select.or_else.storage_type().to_type(self.context),
);
or_else = self.append_operation_with_result(vector::splat(
self.context,
vector,
or_else,
self.location,
));
}
let value = self.append_operation_with_result(arith::select(
condition,
then,
or_else,
self.location,
));
self.insert_variable(out, value);
}
}
}
fn visit_index(&mut self, index: &IndexOperator, out: Variable) -> Value<'a, 'a> {
assert!(index.vector_size == 0);
let mut index_value = self.get_index(index.index, out.ty, index.list.ty.is_vectorized());
if !self.is_memory(index.list) {
let to_extract = self.get_variable(index.list);
if !to_extract.r#type().is_vector() {
return to_extract;
}
let res = index.list.storage_type().to_type(self.context);
if index_value.r#type().is_index() {
let u32_int = IntegerType::new(self.context, 32).into();
index_value = self.append_operation_with_result(index::casts(
index_value,
u32_int,
self.location,
));
}
let vector_extract =
llvm::extractelement(self.context, res, to_extract, index_value, self.location);
self.append_operation_with_result(vector_extract)
} else if out.ty.is_vectorized() {
let vector_type = Type::vector(
&[out.vector_size() as u64],
index.list.storage_type().to_type(self.context),
);
let memref = self.get_memory(index.list);
self.append_operation_with_result(vector::load(
self.context,
vector_type,
memref,
&[index_value],
self.location,
))
} else {
let memref = self.get_memory(index.list);
self.append_operation_with_result(memref::load(memref, &[index_value], self.location))
}
}
fn visit_index_assign(&mut self, index_assign: &IndexAssignOperator, out: Variable) {
assert!(index_assign.vector_size == 0);
let value = self.get_variable(index_assign.value);
let memref = self.get_memory(out);
if matches!(
out.kind,
VariableKind::LocalMut { .. } | VariableKind::LocalConst { .. }
) {
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
let operation = if index_assign.value.ty.is_vectorized() {
vector::store(self.context, value, memref, &[indices], self.location).into()
} else {
memref::store(value, memref, &[indices], self.location)
};
self.block.append_operation(operation);
return;
}
let operation = if index_assign.value.ty.is_vectorized() {
let indices = self.get_index(
index_assign.index,
index_assign.value.ty,
out.ty.is_vectorized(),
);
vector::store(self.context, value, memref, &[indices], self.location)
} else {
let vector_type = Type::vector(
&[out.vector_size() as u64],
index_assign.value.storage_type().to_type(self.context),
);
let indices = self.get_index(index_assign.index, out.ty, out.ty.is_vectorized());
let splat = self.append_operation_with_result(vector::splat(
self.context,
vector_type,
value,
self.location,
));
vector::store(self.context, splat, memref, &[indices], self.location)
};
self.block.append_operation(operation);
}
pub(crate) fn visit_cast(&mut self, to_cast: Variable, out: Variable) {
let mut value = self.get_variable(to_cast);
let target = out.ty.to_type(self.context);
if !to_cast.ty.is_vectorized() && out.ty.is_vectorized() {
let r#type = to_cast.storage_type().to_type(self.context);
let vector_type = Type::vector(&[out.vector_size() as u64], r#type);
value = self.append_operation_with_result(vector::splat(
self.context,
vector_type,
value,
self.location,
));
};
let value = if to_cast.storage_type().is_int() == out.storage_type().is_int() {
self.get_cast_same_type_category(
to_cast.storage_type(),
out.storage_type(),
target,
value,
)
} else {
self.get_cast_different_type_category(
to_cast.storage_type(),
out.storage_type(),
target,
value,
)
};
self.insert_variable(out, value);
}
pub(crate) fn get_cast_different_type_category(
&self,
to_cast: StorageType,
out: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Value<'a, 'a> {
if to_cast.is_int() {
self.append_operation_with_result(self.cast_int_to_float(to_cast, target, value))
} else {
self.append_operation_with_result(self.cast_float_to_int(out, target, value))
}
}
fn cast_float_to_int(
&self,
out: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Operation<'a> {
if out.is_signed_int() {
arith::fptosi(value, target, self.location)
} else {
arith::fptoui(value, target, self.location)
}
}
fn cast_int_to_float(
&self,
to_cast: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Operation<'a> {
if to_cast.is_signed_int() {
arith::sitofp(value, target, self.location)
} else {
arith::uitofp(value, target, self.location)
}
}
fn get_cast_same_type_category(
&self,
to_cast: StorageType,
out: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Value<'a, 'a> {
if to_cast.size() > out.size() {
self.append_operation_with_result(self.get_trunc(to_cast, target, value))
} else if to_cast.size() < out.size() {
self.append_operation_with_result(self.get_ext(to_cast, target, value))
} else {
value
}
}
fn get_trunc(
&self,
to_cast: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Operation<'a> {
if to_cast.is_int() {
arith::trunci(value, target, self.location)
} else {
ods::arith::truncf(self.context, target, value, self.location).into()
}
}
fn get_ext(
&self,
to_cast: StorageType,
target: Type<'a>,
value: Value<'a, 'a>,
) -> Operation<'a> {
if to_cast.is_signed_int() {
arith::extsi(value, target, self.location)
} else if to_cast.is_unsigned_int() {
arith::extui(value, target, self.location)
} else {
arith::extf(value, target, self.location)
}
}
}