1use cubecl_core::{
2 self as cubecl,
3 ir::{ElemType, Operator},
4};
5use cubecl_core::{comptime, ir as core, prelude::*};
6use cubecl_core::{cube, ir::Bitwise};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Elem};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
12 if let Some(op) = bool_op(&op) {
13 self.compile_operator(op, out, uniform);
14 return;
15 }
16
17 let out = out.unwrap();
18 match op {
19 Bitwise::BitwiseAnd(op) => {
20 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
21 b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
22 })
23 }
24 Bitwise::BitwiseOr(op) => {
25 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
26 b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
27 })
28 }
29 Bitwise::BitwiseXor(op) => {
30 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
31 b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
32 })
33 }
34 Bitwise::BitwiseNot(op) => {
35 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
36 b.not(ty, Some(out), input).unwrap();
37 });
38 }
39 Bitwise::ShiftLeft(op) => {
40 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
41 b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
42 })
43 }
44 Bitwise::ShiftRight(op) => {
45 self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
46 match item.elem() {
47 Elem::Int(_, true) => {
49 b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
50 }
51 _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
52 };
53 })
54 }
55
56 Bitwise::CountOnes(op) => {
57 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
58 b.bit_count(ty, Some(out), input).unwrap();
59 });
60 }
61 Bitwise::ReverseBits(op) => {
62 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
63 b.bit_reverse(ty, Some(out), input).unwrap();
64 });
65 }
66 Bitwise::LeadingZeros(op) => {
67 let width = op.input.ty.storage_type().size() as u32 * 8;
68 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
69 let width = out_ty.const_u32(b, width - 1);
71 let msb = b.id();
72 T::find_msb(b, ty, input, msb);
73 b.mark_uniformity(msb, uniform);
74 b.i_sub(ty, Some(out), width, msb).unwrap();
75 });
76 }
77 Bitwise::FindFirstSet(op) => {
78 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
79 let one = out_ty.const_u32(b, 1);
80 let lsb = b.id();
81 T::find_lsb(b, ty, input, lsb);
82 b.mark_uniformity(lsb, uniform);
83 b.i_add(ty, Some(out), lsb, one).unwrap();
85 });
86 }
87 Bitwise::TrailingZeros(op) => {
88 let width = op.input.ty.storage_type().size() as u32 * 8;
89 self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
90 let width_const = out_ty.const_u32(b, width);
93 let zero = out_ty.const_u32(b, 0);
94 let lsb = b.id();
95 T::find_lsb(b, ty, input, lsb);
96 b.mark_uniformity(lsb, uniform);
97 let bool_ty = out_ty.same_vectorization(Elem::Bool).id(b);
99 let is_zero = b.id();
100 b.i_equal(bool_ty, Some(is_zero), input, zero).unwrap();
101 b.mark_uniformity(is_zero, uniform);
102 b.select(ty, Some(out), is_zero, width_const, lsb).unwrap();
104 });
105 }
106 }
107 }
108}
109
110fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
121 match bitwise {
122 Bitwise::BitwiseAnd(op)
123 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
124 {
125 Some(Operator::And(op.clone()))
126 }
127 Bitwise::BitwiseOr(op)
128 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
129 {
130 Some(Operator::Or(op.clone()))
131 }
132 Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
133 Some(Operator::Not(op.clone()))
134 }
135 _ => None,
136 }
137}
138
139#[cube]
140pub(crate) fn small_int_reverse<I: Int, N: Size>(
141 x: Vector<I, N>,
142 #[comptime] width: u32,
143) -> Vector<I, N> {
144 let shift = comptime!(32 - width);
145
146 let reversed = Vector::reverse_bits(Vector::<u32, N>::cast_from(x));
147 Vector::cast_from(reversed >> Vector::new(shift))
148}
149
150#[cube]
151pub(crate) fn u64_reverse<I: Int, N: Size>(x: Vector<I, N>) -> Vector<I, N> {
152 let shift = Vector::new(I::new(32));
153
154 let low = Vector::<u32, N>::cast_from(x);
155 let high = Vector::<u32, N>::cast_from(x >> shift);
156
157 let low_rev = Vector::reverse_bits(low);
158 let high_rev = Vector::reverse_bits(high);
159 let high = Vector::cast_from(low_rev) << shift;
161 high | Vector::cast_from(high_rev)
162}
163
164#[cube]
165pub(crate) fn u64_count_bits<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
166 let shift = Vector::new(I::new(32));
167
168 let low = Vector::<u32, N>::cast_from(x);
169 let high = Vector::<u32, N>::cast_from(x >> shift);
170
171 let low_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(low));
172 let high_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(high));
173 low_cnt + high_cnt
174}
175
176#[cube]
177pub(crate) fn u64_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
178 let shift = Vector::new(I::new(32));
179
180 let low = Vector::<u32, N>::cast_from(x);
181 let high = Vector::<u32, N>::cast_from(x >> shift);
182 let low_zeros = Vector::leading_zeros(low);
183 let high_zeros = Vector::leading_zeros(high);
184
185 select_many(
186 high_zeros.equal(Vector::new(32)),
187 low_zeros + high_zeros,
188 high_zeros,
189 )
190}
191
192#[cube]
197pub(crate) fn u64_ffs<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
198 let shift = Vector::new(I::new(32));
199
200 let low = Vector::<u32, N>::cast_from(x);
201 let high = Vector::<u32, N>::cast_from(x >> shift);
202 let low_ffs = Vector::find_first_set(low);
203 let high_ffs = Vector::find_first_set(high);
204
205 let high_ffs = select_many(
206 high_ffs.equal(Vector::new(0)),
207 high_ffs,
208 high_ffs + Vector::new(32),
209 );
210 select_many(low_ffs.equal(Vector::new(0)), high_ffs, low_ffs)
211}
212
213#[cube]
215pub(crate) fn u16_u8_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
216 let width = I::type_size_bits().comptime() as u32;
217 let over_width = Vector::new(32 - width);
218
219 let x = Vector::<u32, N>::cast_from(x);
220 let lz = x.leading_zeros();
221 lz - over_width
222}
223
224#[cube]
229pub(crate) fn u64_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
230 let shift = Vector::new(I::new(32));
231
232 let low = Vector::<u32, N>::cast_from(x);
233 let high = Vector::<u32, N>::cast_from(x >> shift);
234 let low_tz = Vector::trailing_zeros(low);
235 let high_tz = Vector::trailing_zeros(high);
236
237 let high_tz = select_many(
238 high_tz.equal(Vector::new(32)),
239 Vector::new(64),
240 high_tz + Vector::new(32),
241 );
242 select_many(low_tz.equal(Vector::new(32)), high_tz, low_tz)
243}
244
245#[cube]
247pub(crate) fn u16_u8_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
248 let width = Vector::new(I::type_size_bits().comptime() as u32);
249
250 let x = Vector::<u32, N>::cast_from(x);
251 let lz = x.trailing_zeros();
252 lz.min(width)
253}