use cubecl_core::ir::{Bitwise, ElemType, IntKind, UIntKind};
use tracel_llvm::mlir_rs::{
dialect::arith::{self, CmpiPredicate},
dialect::llvm,
ir::r#type::IntegerType,
};
use crate::compiler::visitor::prelude::*;
impl<'a> Visitor<'a> {
fn convert_bit_count_to_u32(&mut self, value: Value<'a, 'a>, input: Variable) -> Value<'a, 'a> {
match input.elem_type() {
ElemType::Int(IntKind::I8)
| ElemType::UInt(UIntKind::U8)
| ElemType::Int(IntKind::I16)
| ElemType::UInt(UIntKind::U16)
| ElemType::Int(IntKind::I64)
| ElemType::UInt(UIntKind::U64) => {
let mut r#type = IntegerType::new(self.context, 32).into();
if input.ty.is_vectorized() {
r#type = Type::vector(&[input.vector_size() as u64], r#type);
}
match input.elem_type() {
ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64) => self
.append_operation_with_result(arith::trunci(value, r#type, self.location)),
_ => self.append_operation_with_result(arith::extui(
value,
r#type,
self.location,
)),
}
}
ElemType::Int(IntKind::I32) | ElemType::UInt(UIntKind::U32) => value,
_ => panic!("Unsupported type for bit counting operation"),
}
}
fn count_zeros_with_clamp(
&mut self,
value: Value<'a, 'a>,
input: Variable,
out: Variable,
) -> Value<'a, 'a> {
match input.elem_type() {
ElemType::Int(IntKind::I8)
| ElemType::UInt(UIntKind::U8)
| ElemType::Int(IntKind::I16)
| ElemType::UInt(UIntKind::U16) => {
let mut r#type = IntegerType::new(self.context, 32).into();
if input.ty.is_vectorized() {
r#type = Type::vector(&[input.vector_size() as u64], r#type);
}
let value =
self.append_operation_with_result(arith::extui(value, r#type, self.location));
let max = self.create_int_constant_from_item(
out.ty,
input.ty.storage_type().size_bits() as i64,
);
self.append_operation_with_result(arith::minui(value, max, self.location))
}
ElemType::Int(IntKind::I32) | ElemType::UInt(UIntKind::U32) => value,
ElemType::Int(IntKind::I64) | ElemType::UInt(UIntKind::U64) => {
let mut r#type = IntegerType::new(self.context, 32).into();
if input.ty.is_vectorized() {
r#type = Type::vector(&[input.vector_size() as u64], r#type);
}
self.append_operation_with_result(arith::trunci(value, r#type, self.location))
}
_ => panic!("Unsupported type for leading/trailing zeros"),
}
}
pub fn visit_bitwise(&mut self, bitwise: &Bitwise, out: Variable) {
let value = match bitwise {
Bitwise::BitwiseAnd(bin_op) => {
let (lhs, rhs) = self.get_binary_op_variable(bin_op.lhs, bin_op.rhs);
self.append_operation_with_result(arith::andi(lhs, rhs, self.location))
}
Bitwise::BitwiseOr(bin_op) => {
let (lhs, rhs) = self.get_binary_op_variable(bin_op.lhs, bin_op.rhs);
self.append_operation_with_result(arith::ori(lhs, rhs, self.location))
}
Bitwise::BitwiseXor(bin_op) => {
let (lhs, rhs) = self.get_binary_op_variable(bin_op.lhs, bin_op.rhs);
self.append_operation_with_result(arith::xori(lhs, rhs, self.location))
}
Bitwise::ShiftLeft(bin_op) => {
let (lhs, rhs) = self.get_binary_op_variable(bin_op.lhs, bin_op.rhs);
self.append_operation_with_result(arith::shli(lhs, rhs, self.location))
}
Bitwise::ShiftRight(bin_op) => {
let (lhs, rhs) = self.get_binary_op_variable(bin_op.lhs, bin_op.rhs);
let operation = if bin_op.lhs.storage_type().is_signed_int() {
arith::shrsi(lhs, rhs, self.location)
} else {
arith::shrui(lhs, rhs, self.location)
};
self.append_operation_with_result(operation)
}
Bitwise::CountOnes(unary_op) => {
let value = self.get_variable(unary_op.input);
let result_type = unary_op.input.ty.to_type(self.context);
let value = self.append_operation_with_result(llvm::intr_ctpop(
value,
result_type,
self.location,
));
self.convert_bit_count_to_u32(value, unary_op.input)
}
Bitwise::ReverseBits(unary_op) => {
let value = self.get_variable(unary_op.input);
let result_type = unary_op.input.ty.to_type(self.context);
self.append_operation_with_result(llvm::intr_bitreverse(
value,
result_type,
self.location,
))
}
Bitwise::BitwiseNot(unary_op) => {
let value = self.get_variable(unary_op.input);
let mask = self.create_int_constant_from_item(unary_op.input.ty, -1);
self.append_operation_with_result(arith::xori(value, mask, self.location))
}
Bitwise::LeadingZeros(unary_op) => {
let value = self.get_variable(unary_op.input);
let result_type = unary_op.input.ty.to_type(self.context);
let value = self.append_operation_with_result(llvm::intr_ctlz(
self.context,
value,
true,
result_type,
self.location,
));
self.count_zeros_with_clamp(value, unary_op.input, out)
}
Bitwise::TrailingZeros(unary_op) => {
let value = self.get_variable(unary_op.input);
let result_type = unary_op.input.ty.to_type(self.context);
let value = self.append_operation_with_result(llvm::intr_cttz(
self.context,
value,
true,
result_type,
self.location,
));
self.count_zeros_with_clamp(value, unary_op.input, out)
}
Bitwise::FindFirstSet(unary_op) => {
let value = self.get_variable(unary_op.input);
let result_type = unary_op.input.ty.to_type(self.context);
let value = self.append_operation_with_result(llvm::intr_cttz(
self.context,
value,
false,
result_type,
self.location,
));
let one = self.create_int_constant_from_item(unary_op.input.ty, 1);
let value =
self.append_operation_with_result(arith::addi(value, one, self.location));
let max = self.create_int_constant_from_item(
unary_op.input.ty,
unary_op.input.ty.storage_type().size_bits() as i64 + 1,
);
let cond = self.append_operation_with_result(arith::cmpi(
self.context,
CmpiPredicate::Uge,
value,
max,
self.location,
));
let zero = self.create_int_constant_from_item(unary_op.input.ty, 0);
let value = self.append_operation_with_result(arith::select(
cond,
zero,
value,
self.location,
));
self.convert_bit_count_to_u32(value, unary_op.input)
}
};
self.insert_variable(out, value);
}
}