// cubecl_spirv/bitwise.rs

use cubecl_core::{
    self as cubecl,
    ir::{ElemType, Operator},
};
use cubecl_core::{comptime, ir as core, prelude::*};
use cubecl_core::{cube, ir::Bitwise};

use crate::{SpirvCompiler, SpirvTarget, item::Elem};

10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
12        if let Some(op) = bool_op(&op) {
13            self.compile_operator(op, out, uniform);
14            return;
15        }
16
17        let out = out.unwrap();
18        match op {
19            Bitwise::BitwiseAnd(op) => {
20                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
21                    b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
22                })
23            }
24            Bitwise::BitwiseOr(op) => {
25                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
26                    b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
27                })
28            }
29            Bitwise::BitwiseXor(op) => {
30                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
31                    b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
32                })
33            }
34            Bitwise::BitwiseNot(op) => {
35                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
36                    b.not(ty, Some(out), input).unwrap();
37                });
38            }
39            Bitwise::ShiftLeft(op) => {
40                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
41                    b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
42                })
43            }
44            Bitwise::ShiftRight(op) => {
45                self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
46                    match item.elem() {
47                        // Match behaviour on most compilers
48                        Elem::Int(_, true) => {
49                            b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
50                        }
51                        _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
52                    };
53                })
54            }
55
56            Bitwise::CountOnes(op) => {
57                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
58                    b.bit_count(ty, Some(out), input).unwrap();
59                });
60            }
61            Bitwise::ReverseBits(op) => {
62                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
63                    b.bit_reverse(ty, Some(out), input).unwrap();
64                });
65            }
66            Bitwise::LeadingZeros(op) => {
67                let width = op.input.ty.storage_type().size() as u32 * 8;
68                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
69                    // Indices are zero based, so subtract 1
70                    let width = out_ty.const_u32(b, width - 1);
71                    let msb = b.id();
72                    T::find_msb(b, ty, input, msb);
73                    b.mark_uniformity(msb, uniform);
74                    b.i_sub(ty, Some(out), width, msb).unwrap();
75                });
76            }
77            Bitwise::FindFirstSet(op) => {
78                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
79                    let one = out_ty.const_u32(b, 1);
80                    let lsb = b.id();
81                    T::find_lsb(b, ty, input, lsb);
82                    b.mark_uniformity(lsb, uniform);
83                    // Normalize to CUDA/POSIX convention of 1 based index, with 0 meaning not found
84                    b.i_add(ty, Some(out), lsb, one).unwrap();
85                });
86            }
87            Bitwise::TrailingZeros(op) => {
88                let width = op.input.ty.storage_type().size() as u32 * 8;
89                self.compile_unary_op(op, out, uniform, |b, out_ty, ty, input, out| {
90                    // find_lsb returns -1 (0xFFFFFFFF) for zero input
91                    // trailing_zeros should return bit_width for zero input
92                    let width_const = out_ty.const_u32(b, width);
93                    let zero = out_ty.const_u32(b, 0);
94                    let lsb = b.id();
95                    T::find_lsb(b, ty, input, lsb);
96                    b.mark_uniformity(lsb, uniform);
97                    // Check if input is zero
98                    let bool_ty = out_ty.same_vectorization(Elem::Bool).id(b);
99                    let is_zero = b.id();
100                    b.i_equal(bool_ty, Some(is_zero), input, zero).unwrap();
101                    b.mark_uniformity(is_zero, uniform);
102                    // Select width if zero, otherwise lsb
103                    b.select(ty, Some(out), is_zero, width_const, lsb).unwrap();
104                });
105            }
106        }
107    }
108}
109
110/// Map bitwise on boolean to logical, since bitwise ops aren't allowed in Vulkan. This fixes the
111/// case of
112/// ```ignore
113/// let a = true;
114/// for shape in 0..dims {
115///     a |= shape < width;
116/// }
117/// ```
118///
119/// Rust maps this to logical and/or internally, but the macro only sees the bitwise op.
120fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
121    match bitwise {
122        Bitwise::BitwiseAnd(op)
123            if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
124        {
125            Some(Operator::And(op.clone()))
126        }
127        Bitwise::BitwiseOr(op)
128            if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
129        {
130            Some(Operator::Or(op.clone()))
131        }
132        Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
133            Some(Operator::Not(op.clone()))
134        }
135        _ => None,
136    }
137}
138
139#[cube]
140pub(crate) fn small_int_reverse<I: Int, N: Size>(
141    x: Vector<I, N>,
142    #[comptime] width: u32,
143) -> Vector<I, N> {
144    let shift = comptime!(32 - width);
145
146    let reversed = Vector::reverse_bits(Vector::<u32, N>::cast_from(x));
147    Vector::cast_from(reversed >> Vector::new(shift))
148}
149
150#[cube]
151pub(crate) fn u64_reverse<I: Int, N: Size>(x: Vector<I, N>) -> Vector<I, N> {
152    let shift = Vector::new(I::new(32));
153
154    let low = Vector::<u32, N>::cast_from(x);
155    let high = Vector::<u32, N>::cast_from(x >> shift);
156
157    let low_rev = Vector::reverse_bits(low);
158    let high_rev = Vector::reverse_bits(high);
159    // Swap low and high values
160    let high = Vector::cast_from(low_rev) << shift;
161    high | Vector::cast_from(high_rev)
162}
163
164#[cube]
165pub(crate) fn u64_count_bits<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
166    let shift = Vector::new(I::new(32));
167
168    let low = Vector::<u32, N>::cast_from(x);
169    let high = Vector::<u32, N>::cast_from(x >> shift);
170
171    let low_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(low));
172    let high_cnt = Vector::<u32, N>::cast_from(Vector::count_ones(high));
173    low_cnt + high_cnt
174}
175
176#[cube]
177pub(crate) fn u64_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
178    let shift = Vector::new(I::new(32));
179
180    let low = Vector::<u32, N>::cast_from(x);
181    let high = Vector::<u32, N>::cast_from(x >> shift);
182    let low_zeros = Vector::leading_zeros(low);
183    let high_zeros = Vector::leading_zeros(high);
184
185    select_many(
186        high_zeros.equal(Vector::new(32)),
187        low_zeros + high_zeros,
188        high_zeros,
189    )
190}
191
192/// There are three possible outcomes:
193/// * low has any set -> return low
194/// * low is empty, high has any set -> return high + 32
195/// * low and high are empty -> return 0
196#[cube]
197pub(crate) fn u64_ffs<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
198    let shift = Vector::new(I::new(32));
199
200    let low = Vector::<u32, N>::cast_from(x);
201    let high = Vector::<u32, N>::cast_from(x >> shift);
202    let low_ffs = Vector::find_first_set(low);
203    let high_ffs = Vector::find_first_set(high);
204
205    let high_ffs = select_many(
206        high_ffs.equal(Vector::new(0)),
207        high_ffs,
208        high_ffs + Vector::new(32),
209    );
210    select_many(low_ffs.equal(Vector::new(0)), high_ffs, low_ffs)
211}
212
213/// Subtract extra leading zeros after normalizing
214#[cube]
215pub(crate) fn u16_u8_leading_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
216    let width = I::type_size_bits().comptime() as u32;
217    let over_width = Vector::new(32 - width);
218
219    let x = Vector::<u32, N>::cast_from(x);
220    let lz = x.leading_zeros();
221    lz - over_width
222}
223
224/// There are three possible outcomes:
225/// * low has any set -> return low
226/// * low is empty, high has any set -> return high + 32
227/// * low and high are empty -> return 0
228#[cube]
229pub(crate) fn u64_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
230    let shift = Vector::new(I::new(32));
231
232    let low = Vector::<u32, N>::cast_from(x);
233    let high = Vector::<u32, N>::cast_from(x >> shift);
234    let low_tz = Vector::trailing_zeros(low);
235    let high_tz = Vector::trailing_zeros(high);
236
237    let high_tz = select_many(
238        high_tz.equal(Vector::new(32)),
239        Vector::new(64),
240        high_tz + Vector::new(32),
241    );
242    select_many(low_tz.equal(Vector::new(32)), high_tz, low_tz)
243}
244
245/// Clamp to width
246#[cube]
247pub(crate) fn u16_u8_trailing_zeros<I: Int, N: Size>(x: Vector<I, N>) -> Vector<u32, N> {
248    let width = Vector::new(I::type_size_bits().comptime() as u32);
249
250    let x = Vector::<u32, N>::cast_from(x);
251    let lz = x.trailing_zeros();
252    lz.min(width)
253}