cubecl_spirv/
bitwise.rs

1use cubecl_core::{
2    self as cubecl,
3    ir::{ElemType, Operator},
4};
5use cubecl_core::{comptime, ir as core, prelude::*};
6use cubecl_core::{cube, ir::Bitwise};
7
8use crate::{SpirvCompiler, SpirvTarget, item::Elem};
9
10impl<T: SpirvTarget> SpirvCompiler<T> {
11    pub fn compile_bitwise(&mut self, op: Bitwise, out: Option<core::Variable>, uniform: bool) {
12        if let Some(op) = bool_op(&op) {
13            self.compile_operator(op, out, uniform);
14            return;
15        }
16
17        let out = out.unwrap();
18        match op {
19            Bitwise::BitwiseAnd(op) => {
20                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
21                    b.bitwise_and(ty, Some(out), lhs, rhs).unwrap();
22                })
23            }
24            Bitwise::BitwiseOr(op) => {
25                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
26                    b.bitwise_or(ty, Some(out), lhs, rhs).unwrap();
27                })
28            }
29            Bitwise::BitwiseXor(op) => {
30                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
31                    b.bitwise_xor(ty, Some(out), lhs, rhs).unwrap();
32                })
33            }
34            Bitwise::BitwiseNot(op) => {
35                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
36                    b.not(ty, Some(out), input).unwrap();
37                });
38            }
39            Bitwise::ShiftLeft(op) => {
40                self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
41                    b.shift_left_logical(ty, Some(out), lhs, rhs).unwrap();
42                })
43            }
44            Bitwise::ShiftRight(op) => {
45                self.compile_binary_op(op, out, uniform, |b, item, ty, lhs, rhs, out| {
46                    match item.elem() {
47                        // Match behaviour on most compilers
48                        Elem::Int(_, true) => {
49                            b.shift_right_arithmetic(ty, Some(out), lhs, rhs).unwrap()
50                        }
51                        _ => b.shift_right_logical(ty, Some(out), lhs, rhs).unwrap(),
52                    };
53                })
54            }
55
56            Bitwise::CountOnes(op) => {
57                self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
58                    b.bit_count(ty, Some(out), input).unwrap();
59                });
60            }
61            Bitwise::ReverseBits(op) => {
62                self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
63                    b.bit_reverse(ty, Some(out), input).unwrap();
64                });
65            }
66            Bitwise::LeadingZeros(op) => {
67                let width = op.input.ty.storage_type().size() as u32 * 8;
68                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
69                    // Indices are zero based, so subtract 1
70                    let width = out_ty.const_u32(b, width - 1);
71                    let msb = b.id();
72                    T::find_msb(b, ty, input, msb);
73                    b.mark_uniformity(msb, uniform);
74                    b.i_sub(ty, Some(out), width, msb).unwrap();
75                });
76            }
77            Bitwise::FindFirstSet(op) => {
78                self.compile_unary_op_cast(op, out, uniform, |b, out_ty, ty, input, out| {
79                    let one = out_ty.const_u32(b, 1);
80                    let lsb = b.id();
81                    T::find_lsb(b, ty, input, lsb);
82                    b.mark_uniformity(lsb, uniform);
83                    // Normalize to CUDA/POSIX convention of 1 based index, with 0 meaning not found
84                    b.i_add(ty, Some(out), lsb, one).unwrap();
85                });
86            }
87        }
88    }
89}
90
91/// Map bitwise on boolean to logical, since bitwise ops aren't allowed in Vulkan. This fixes the
92/// case of
93/// ```ignore
94/// let a = true;
95/// for shape in 0..dims {
96///     a |= shape < width;
97/// }
98/// ```
99///
100/// Rust maps this to logical and/or internally, but the macro only sees the bitwise op.
101fn bool_op(bitwise: &Bitwise) -> Option<Operator> {
102    match bitwise {
103        Bitwise::BitwiseAnd(op)
104            if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
105        {
106            Some(Operator::And(op.clone()))
107        }
108        Bitwise::BitwiseOr(op)
109            if op.lhs.elem_type() == ElemType::Bool || op.rhs.elem_type() == ElemType::Bool =>
110        {
111            Some(Operator::Or(op.clone()))
112        }
113        Bitwise::BitwiseNot(op) if op.input.elem_type() == ElemType::Bool => {
114            Some(Operator::Not(op.clone()))
115        }
116        _ => None,
117    }
118}
119
120#[cube]
121pub(crate) fn small_int_reverse<I: Int>(x: Line<I>, #[comptime] width: u32) -> Line<I> {
122    let shift = comptime!(32 - width);
123
124    let reversed = Line::reverse_bits(Line::<u32>::cast_from(x));
125    Line::cast_from(reversed >> Line::new(shift))
126}
127
128#[cube]
129pub(crate) fn u64_reverse<I: Int>(x: Line<I>) -> Line<I> {
130    let shift = Line::new(I::new(32));
131
132    let low = Line::<u32>::cast_from(x);
133    let high = Line::<u32>::cast_from(x >> shift);
134
135    let low_rev = Line::reverse_bits(low);
136    let high_rev = Line::reverse_bits(high);
137    // Swap low and high values
138    let high = Line::cast_from(low_rev) << shift;
139    high | Line::cast_from(high_rev)
140}
141
142#[cube]
143pub(crate) fn u64_count_bits<I: Int>(x: Line<I>) -> Line<u32> {
144    let shift = Line::new(I::new(32));
145
146    let low = Line::<u32>::cast_from(x);
147    let high = Line::<u32>::cast_from(x >> shift);
148
149    let low_cnt = Line::<u32>::cast_from(Line::count_ones(low));
150    let high_cnt = Line::<u32>::cast_from(Line::count_ones(high));
151    low_cnt + high_cnt
152}
153
154#[cube]
155pub(crate) fn u64_leading_zeros<I: Int>(x: Line<I>) -> Line<u32> {
156    let shift = Line::new(I::new(32));
157
158    let low = Line::<u32>::cast_from(x);
159    let high = Line::<u32>::cast_from(x >> shift);
160    let low_zeros = Line::leading_zeros(low);
161    let high_zeros = Line::leading_zeros(high);
162
163    select_many(
164        high_zeros.equal(Line::new(32)),
165        low_zeros + high_zeros,
166        high_zeros,
167    )
168}
169
170/// There are three possible outcomes:
171/// * low has any set -> return low
172/// * low is empty, high has any set -> return high + 32
173/// * low and high are empty -> return 0
174#[cube]
175pub(crate) fn u64_ffs<I: Int>(x: Line<I>) -> Line<u32> {
176    let shift = Line::new(I::new(32));
177
178    let low = Line::<u32>::cast_from(x);
179    let high = Line::<u32>::cast_from(x >> shift);
180    let low_ffs = Line::find_first_set(low);
181    let high_ffs = Line::find_first_set(high);
182
183    let high_ffs = select_many(
184        high_ffs.equal(Line::new(0)),
185        high_ffs,
186        high_ffs + Line::new(32),
187    );
188    select_many(low_ffs.equal(Line::new(0)), high_ffs, low_ffs)
189}