use cubecl_core::{
self as cubecl,
ir::{ElemType, Operator},
};
use cubecl_core::{comptime, ir as core, prelude::*};
use cubecl_core::{cube, ir::Bitwise};
use crate::{SpirvCompiler, SpirvTarget, item::Elem};
impl<T: SpirvTarget> SpirvCompiler<T> {
pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
if let Some(op) = bool_op(&op) {
self.compile_operator(op, out, uniform);
return;
}
let out = out.unwrap();
match op {
Bitwise::BitwiseAnd(op) => {
self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
})
}
Bitwise::BitwiseOr(op) => {
self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
})
}
Bitwise::BitwiseXor(op) => {
self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
})
}
Bitwise::BitwiseNot(op) => {
self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
b.not(ty, Some(out), input).unwrap();
});
}
Bitwise::ShiftLeft(op) => {
self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
})
}
Bitwise::ShiftRight(op) => {
self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
match item.elem() {
Elem::Int(_, true) => {
b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
}
_ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
};
})
}
Bitwise::CountOnes(op) => {
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
b.bit_count(ty, Some(out), input).unwrap();
});
}
Bitwise::ReverseBits(op) => {
self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
b.bit_reverse(ty, Some(out), input).unwrap();
});
}
Bitwise::LeadingZeros(op) => {
let width = op.input.ty.storage_type().size() as u32 * 8;
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
let width = out_ty.const_u32(b, width - 1);
let msb = b.id();
T::find_msb(b, ty, input, msb);
b.mark_uniformity(msb, uniform);
b.i_sub(ty, Some(out), width, msb).unwrap();
});
}
Bitwise::FindFirstSet(op) => {
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
let one = out_ty.const_u32(b, 1);
let lsb = b.id();
T::find_lsb(b, ty, input, lsb);
b.mark_uniformity(lsb, uniform);
b.i_add(ty, Some(out), lsb, one).unwrap();
});
}
Bitwise::TrailingZeros(op) => {
let width = op.input.ty.storage_type().size() as u32 * 8;
self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
let width_const = out_ty.const_u32(b, width);
let zero = out_ty.const_u32(b, 0);
let lsb = b.id();
T::find_lsb(b, ty, input, lsb);
b.mark_uniformity(lsb, uniform);
let bool_ty = out_ty.same_vectorization(Elem::Bool).id(b);
let is_zero = b.id();
b.i_equal(bool_ty, Some(is_zero), input, zero).unwrap();
b.mark_uniformity(is_zero, uniform);
b.select(ty, Some(out), is_zero, width_const, lsb).unwrap();
});
}
}
}
}
fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
match bitwise {
Bitwise::BitwiseAnd(op)
if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
{
Some(Operator::And(op.clone()))
}
Bitwise::BitwiseOr(op)
if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
{
Some(Operator::Or(op.clone()))
}
Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
Some(Operator::Not(op.clone()))
}
_ => None,
}
}
#[cube]
pub(crate) fn small_int_reverse<I: Int, N: Size>(
x: Vector<I, N>,
#[comptime] width: u32,
) -> Vector<I, N> {
let shift = comptime!(32 - width);
let reversed = Vector::reverse_bits(Vector::<u32, N>::cast_from(x));
Vector::cast_from(reversed >> Vector::new(shift))
}
#[cube]
pub(crate) fn u64_reverse<I: Int, N: Size>(x: Vector<I, N>) -> Vector<I, N> {
let shift = Vector::new(I::new(32));
let low = Vector::<u32, N>::cast_from(x);
let high = Vector::<u32, N>::cast_from(x >> shift);
let low_rev = Vector::reverse_bits(low);
let high_rev = Vector::reverse_bits(high);
let high = Vector::cast_from(low_rev) << shift;
high | Vector::cast_from(high_rev)
}
#[cube]
pub(crate) fn u64_count_bits<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let shift = Vector::new(I::new(32));
let low = Vector::<u32, N>::cast_from(x);
let high = Vector::<u32, N>::cast_from(x >> shift);
let low_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(low));
let high_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(high));
low_cnt + high_cnt
}
#[cube]
pub(crate) fn u64_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let shift = Vector::new(I::new(32));
let low = Vector::<u32, N>::cast_from(x);
let high = Vector::<u32, N>::cast_from(x >> shift);
let low_zeros = Vector::leading_zeros(low);
let high_zeros = Vector::leading_zeros(high);
select_many(
high_zeros.equal(Vector::new(32)),
low_zeros + high_zeros,
high_zeros,
)
}
#[cube]
pub(crate) fn u64_ffs<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let shift = Vector::new(I::new(32));
let low = Vector::<u32, N>::cast_from(x);
let high = Vector::<u32, N>::cast_from(x >> shift);
let low_ffs = Vector::find_first_set(low);
let high_ffs = Vector::find_first_set(high);
let high_ffs = select_many(
high_ffs.equal(Vector::new(0)),
high_ffs,
high_ffs + Vector::new(32),
);
select_many(low_ffs.equal(Vector::new(0)), high_ffs, low_ffs)
}
#[cube]
pub(crate) fn u16_u8_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let width = I::type_size_bits().comptime() as u32;
let over_width = Vector::new(32 - width);
let x = Vector::<u32, N>::cast_from(x);
let lz = x.leading_zeros();
lz - over_width
}
#[cube]
pub(crate) fn u64_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let shift = Vector::new(I::new(32));
let low = Vector::<u32, N>::cast_from(x);
let high = Vector::<u32, N>::cast_from(x >> shift);
let low_tz = Vector::trailing_zeros(low);
let high_tz = Vector::trailing_zeros(high);
let high_tz = select_many(
high_tz.equal(Vector::new(32)),
Vector::new(64),
high_tz + Vector::new(32),
);
select_many(low_tz.equal(Vector::new(32)), high_tz, low_tz)
}
#[cube]
pub(crate) fn u16_u8_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
let width = Vector::new(I::type_size_bits().comptime() as u32);
let x = Vector::<u32, N>::cast_from(x);
let lz = x.trailing_zeros();
lz.min(width)
}