1use cubecl_core::{
2 self as cubecl,
3 ir::{ElemType, Operator},
4};
5use cubecl_core::{comptime, ir as core, prelude::*};
6use cubecl_core::{cube, ir::Bitwise};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Elem};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11 pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
12 if let Some(op) = bool_op(&op) {
13 self.compile_operator(op, out, uniform);
14 return;
15 }
16
17 let out = out.unwrap();
18 match op {
19 Bitwise::BitwiseAnd(op) => {
20 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
21 b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
22 })
23 }
24 Bitwise::BitwiseOr(op) => {
25 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
26 b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
27 })
28 }
29 Bitwise::BitwiseXor(op) => {
30 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
31 b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
32 })
33 }
34 Bitwise::BitwiseNot(op) => {
35 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
36 b.not(ty, Some(out), input).unwrap();
37 });
38 }
39 Bitwise::ShiftLeft(op) => {
40 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
41 b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
42 })
43 }
44 Bitwise::ShiftRight(op) => {
45 self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
46 match item.elem() {
47 Elem::Int(_, true) => {
49 b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
50 }
51 _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
52 };
53 })
54 }
55
56 Bitwise::CountOnes(op) => {
57 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
58 b.bit_count(ty, Some(out), input).unwrap();
59 });
60 }
61 Bitwise::ReverseBits(op) => {
62 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
63 b.bit_reverse(ty, Some(out), input).unwrap();
64 });
65 }
66 Bitwise::LeadingZeros(op) => {
67 let width = op.input.ty.storage_type().size() as u32 * 8;
68 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
69 let width = out_ty.const_u32(b, width - 1);
71 let msb = b.id();
72 T::find_msb(b, ty, input, msb);
73 b.mark_uniformity(msb, uniform);
74 b.i_sub(ty, Some(out), width, msb).unwrap();
75 });
76 }
77 Bitwise::FindFirstSet(op) => {
78 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
79 let one = out_ty.const_u32(b, 1);
80 let lsb = b.id();
81 T::find_lsb(b, ty, input, lsb);
82 b.mark_uniformity(lsb, uniform);
83 b.i_add(ty, Some(out), lsb, one).unwrap();
85 });
86 }
87 Bitwise::TrailingZeros(op) => {
88 let width = op.input.ty.storage_type().size() as u32 * 8;
89 self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
90 let width_const = out_ty.const_u32(b, width);
93 let zero = out_ty.const_u32(b, 0);
94 let lsb = b.id();
95 T::find_lsb(b, ty, input, lsb);
96 b.mark_uniformity(lsb, uniform);
97 let bool_ty = b.type_bool();
99 let is_zero = b.id();
100 b.i_equal(bool_ty, Some(is_zero), input, zero).unwrap();
101 b.mark_uniformity(is_zero, uniform);
102 b.select(ty, Some(out), is_zero, width_const, lsb).unwrap();
104 });
105 }
106 }
107 }
108}
109
110fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
121 match bitwise {
122 Bitwise::BitwiseAnd(op)
123 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
124 {
125 Some(Operator::And(op.clone()))
126 }
127 Bitwise::BitwiseOr(op)
128 if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
129 {
130 Some(Operator::Or(op.clone()))
131 }
132 Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
133 Some(Operator::Not(op.clone()))
134 }
135 _ => None,
136 }
137}
138
139#[cube]
140pub(crate) fn small_int_reverse<I: Int>(x: Line<I>, #[comptime] width: u32) -> Line<I> {
141 let shift = comptime!(32 - width);
142
143 let reversed = Line::reverse_bits(Line::<u32>::cast_from(x));
144 Line::cast_from(reversed >> Line::new(shift))
145}
146
147#[cube]
148pub(crate) fn u64_reverse<I: Int>(x: Line<I>) -> Line<I> {
149 let shift = Line::new(I::new(32));
150
151 let low = Line::<u32>::cast_from(x);
152 let high = Line::<u32>::cast_from(x >> shift);
153
154 let low_rev = Line::reverse_bits(low);
155 let high_rev = Line::reverse_bits(high);
156 let high = Line::cast_from(low_rev) << shift;
158 high | Line::cast_from(high_rev)
159}
160
161#[cube]
162pub(crate) fn u64_count_bits<I: Int>(x: Line<I>) -> Line<u32> {
163 let shift = Line::new(I::new(32));
164
165 let low = Line::<u32>::cast_from(x);
166 let high = Line::<u32>::cast_from(x >> shift);
167
168 let low_cnt = Line::<u32>::cast_from(Line::count_ones(low));
169 let high_cnt = Line::<u32>::cast_from(Line::count_ones(high));
170 low_cnt + high_cnt
171}
172
173#[cube]
174pub(crate) fn u64_leading_zeros<I: Int>(x: Line<I>) -> Line<u32> {
175 let shift = Line::new(I::new(32));
176
177 let low = Line::<u32>::cast_from(x);
178 let high = Line::<u32>::cast_from(x >> shift);
179 let low_zeros = Line::leading_zeros(low);
180 let high_zeros = Line::leading_zeros(high);
181
182 select_many(
183 high_zeros.equal(Line::new(32)),
184 low_zeros + high_zeros,
185 high_zeros,
186 )
187}
188
189#[cube]
194pub(crate) fn u64_ffs<I: Int>(x: Line<I>) -> Line<u32> {
195 let shift = Line::new(I::new(32));
196
197 let low = Line::<u32>::cast_from(x);
198 let high = Line::<u32>::cast_from(x >> shift);
199 let low_ffs = Line::find_first_set(low);
200 let high_ffs = Line::find_first_set(high);
201
202 let high_ffs = select_many(
203 high_ffs.equal(Line::new(0)),
204 high_ffs,
205 high_ffs + Line::new(32),
206 );
207 select_many(low_ffs.equal(Line::new(0)), high_ffs, low_ffs)
208}