1use cubecl_core::{
2 self as cubecl,
3 ir::{ElemType, Operator},
4};
5use cubecl_core::{comptime, ir as core, prelude::*};
6use cubecl_core::{cube, ir::Bitwise};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Elem};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
12 if let Some(op) = bool_op(&op) {
13 self.compile_operator(op, out, uniform);
14 return;
15 }
16
17 let out = out.unwrap();
18 match op {
19 Bitwise::BitwiseAnd(op) => {
20 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
21 b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
22 })
23 }
24 Bitwise::BitwiseOr(op) => {
25 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
26 b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
27 })
28 }
29 Bitwise::BitwiseXor(op) => {
30 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
31 b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
32 })
33 }
34 Bitwise::BitwiseNot(op) => {
35 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
36 b.not(ty, Some(out), input).unwrap();
37 });
38 }
39 Bitwise::ShiftLeft(op) => {
40 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
41 b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
42 })
43 }
44 Bitwise::ShiftRight(op) => {
45 self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
46 match item.elem() {
47 Elem::Int(_, true) => {
49 b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
50 }
51 _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
52 };
53 })
54 }
55
56 Bitwise::CountOnes(op) => {
57 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
58 b.bit_count(ty, Some(out), input).unwrap();
59 });
60 }
61 Bitwise::ReverseBits(op) => {
62 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
63 b.bit_reverse(ty, Some(out), input).unwrap();
64 });
65 }
66 Bitwise::LeadingZeros(op) => {
67 let width = op.input.ty.storage_type().size() as u32 * 8;
68 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
69 let width = out_ty.const_u32(b, width - 1);
71 let msb = b.id();
72 T::find_msb(b, ty, input, msb);
73 b.mark_uniformity(msb, uniform);
74 b.i_sub(ty, Some(out), width, msb).unwrap();
75 });
76 }
77 Bitwise::FindFirstSet(op) => {
78 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
79 let one = out_ty.const_u32(b, 1);
80 let lsb = b.id();
81 T::find_lsb(b, ty, input, lsb);
82 b.mark_uniformity(lsb, uniform);
83 b.i_add(ty, Some(out), lsb, one).unwrap();
85 });
86 }
87 }
88 }
89}
90
91fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
102 match bitwise {
103 Bitwise::BitwiseAnd(op)
104 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
105 {
106 Some(Operator::And(op.clone()))
107 }
108 Bitwise::BitwiseOr(op)
109 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
110 {
111 Some(Operator::Or(op.clone()))
112 }
113 Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
114 Some(Operator::Not(op.clone()))
115 }
116 _ => None,
117 }
118}
119
120#[cube]
121pub(crate) fn small_int_reverse<I: Int>(x: Line<I>, #[comptime] width: u32) -> Line<I> {
122 let shift = comptime!(32 - width);
123
124 let reversed = Line::reverse_bits(Line::<u32>::cast_from(x));
125 Line::cast_from(reversed >> Line::new(shift))
126}
127
128#[cube]
129pub(crate) fn u64_reverse<I: Int>(x: Line<I>) -> Line<I> {
130 let shift = Line::new(I::new(32));
131
132 let low = Line::<u32>::cast_from(x);
133 let high = Line::<u32>::cast_from(x >> shift);
134
135 let low_rev = Line::reverse_bits(low);
136 let high_rev = Line::reverse_bits(high);
137 let high = Line::cast_from(low_rev) << shift;
139 high | Line::cast_from(high_rev)
140}
141
142#[cube]
143pub(crate) fn u64_count_bits<I: Int>(x: Line<I>) -> Line<u32> {
144 let shift = Line::new(I::new(32));
145
146 let low = Line::<u32>::cast_from(x);
147 let high = Line::<u32>::cast_from(x >> shift);
148
149 let low_cnt = Line::<u32>::cast_from(Line::count_ones(low));
150 let high_cnt = Line::<u32>::cast_from(Line::count_ones(high));
151 low_cnt + high_cnt
152}
153
154#[cube]
155pub(crate) fn u64_leading_zeros<I: Int>(x: Line<I>) -> Line<u32> {
156 let shift = Line::new(I::new(32));
157
158 let low = Line::<u32>::cast_from(x);
159 let high = Line::<u32>::cast_from(x >> shift);
160 let low_zeros = Line::leading_zeros(low);
161 let high_zeros = Line::leading_zeros(high);
162
163 select_many(
164 high_zeros.equal(Line::new(32)),
165 low_zeros + high_zeros,
166 high_zeros,
167 )
168}
169
170#[cube]
175pub(crate) fn u64_ffs<I: Int>(x: Line<I>) -> Line<u32> {
176 let shift = Line::new(I::new(32));
177
178 let low = Line::<u32>::cast_from(x);
179 let high = Line::<u32>::cast_from(x >> shift);
180 let low_ffs = Line::find_first_set(low);
181 let high_ffs = Line::find_first_set(high);
182
183 let high_ffs = select_many(
184 high_ffs.equal(Line::new(0)),
185 high_ffs,
186 high_ffs + Line::new(32),
187 );
188 select_many(low_ffs.equal(Line::new(0)), high_ffs, low_ffs)
189}