1use cubecl_core::ir::{
2 self as core, BinaryOperator, Comparison, Instruction, Operation, Operator, UnaryOperator,
3};
4use rspirv::spirv::{Capability, Decoration, Word};
5
6use crate::{
7 SpirvCompiler, SpirvTarget,
8 item::{Elem, Item},
9 variable::IndexedVariable,
10};
11
12impl<T: SpirvTarget> SpirvCompiler<T> {
13 pub fn compile_operation(&mut self, inst: Instruction) {
14 if !matches!(inst.operation, Operation::NonSemantic(_)) {
16 self.set_source_loc(&inst.source_loc);
17 }
18 let uniform = matches!(inst.out, Some(out) if self.uniformity.is_var_uniform(out));
19 match inst.operation {
20 Operation::Copy(var) => {
21 let input = self.compile_variable(var);
22 let out = self.compile_variable(inst.out());
23 let ty = out.item().id(self);
24 let in_id = self.read(&input);
25 let in_id = input.item().broadcast(self, in_id, None, &out.item());
26 let out_id = self.write_id(&out);
27
28 self.copy_object(ty, Some(out_id), in_id).unwrap();
29 self.mark_uniformity(out_id, uniform);
30 self.write(&out, out_id);
31 }
32 Operation::Arithmetic(operator) => self.compile_arithmetic(operator, inst.out, uniform),
33 Operation::Comparison(operator) => self.compile_cmp(operator, inst.out, uniform),
34 Operation::Bitwise(operator) => self.compile_bitwise(operator, inst.out, uniform),
35 Operation::Operator(operator) => self.compile_operator(operator, inst.out, uniform),
36 Operation::Atomic(atomic) => self.compile_atomic(atomic, inst.out),
37 Operation::Branch(_) => unreachable!("Branches shouldn't exist in optimized IR"),
38 Operation::Metadata(meta) => self.compile_meta(meta, inst.out, uniform),
39 Operation::Plane(plane) => self.compile_plane(plane, inst.out, uniform),
40 Operation::Synchronization(sync) => self.compile_sync(sync),
41 Operation::CoopMma(cmma) => self.compile_cmma(cmma, inst.out),
42 Operation::NonSemantic(debug) => self.compile_debug(debug),
43 Operation::Barrier(_) => panic!("Barrier not supported in SPIR-V"),
44 Operation::Tma(_) => panic!("TMA not supported in SPIR-V"),
45 Operation::Free(_) => {}
46 }
47 }
48
49 pub fn compile_cmp(&mut self, op: Comparison, out: Option<core::Variable>, uniform: bool) {
50 let out = out.unwrap();
51 match op {
52 Comparison::Equal(op) => {
53 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
54 match lhs_ty.elem() {
55 Elem::Bool => b.logical_equal(ty, Some(out), lhs, rhs),
56 Elem::Int(_, _) => b.i_equal(ty, Some(out), lhs, rhs),
57 Elem::Float(..) => b.f_ord_equal(ty, Some(out), lhs, rhs),
58 Elem::Relaxed => {
59 b.decorate(out, Decoration::RelaxedPrecision, []);
60 b.f_ord_equal(ty, Some(out), lhs, rhs)
61 }
62 Elem::Void => unreachable!(),
63 }
64 .unwrap();
65 });
66 }
67 Comparison::NotEqual(op) => {
68 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
69 match lhs_ty.elem() {
70 Elem::Bool => b.logical_not_equal(ty, Some(out), lhs, rhs),
71 Elem::Int(_, _) => b.i_not_equal(ty, Some(out), lhs, rhs),
72 Elem::Float(..) => b.f_ord_not_equal(ty, Some(out), lhs, rhs),
73 Elem::Relaxed => {
74 b.decorate(out, Decoration::RelaxedPrecision, []);
75 b.f_ord_not_equal(ty, Some(out), lhs, rhs)
76 }
77 Elem::Void => unreachable!(),
78 }
79 .unwrap();
80 });
81 }
82 Comparison::Lower(op) => {
83 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
84 match lhs_ty.elem() {
85 Elem::Int(_, false) => b.u_less_than(ty, Some(out), lhs, rhs),
86 Elem::Int(_, true) => b.s_less_than(ty, Some(out), lhs, rhs),
87 Elem::Float(..) => b.f_ord_less_than(ty, Some(out), lhs, rhs),
88 Elem::Relaxed => {
89 b.decorate(out, Decoration::RelaxedPrecision, []);
90 b.f_ord_less_than(ty, Some(out), lhs, rhs)
91 }
92 _ => unreachable!(),
93 }
94 .unwrap();
95 });
96 }
97 Comparison::LowerEqual(op) => {
98 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
99 match lhs_ty.elem() {
100 Elem::Int(_, false) => b.u_less_than_equal(ty, Some(out), lhs, rhs),
101 Elem::Int(_, true) => b.s_less_than_equal(ty, Some(out), lhs, rhs),
102 Elem::Float(..) => b.f_ord_less_than_equal(ty, Some(out), lhs, rhs),
103 Elem::Relaxed => {
104 b.decorate(out, Decoration::RelaxedPrecision, []);
105 b.f_ord_less_than_equal(ty, Some(out), lhs, rhs)
106 }
107 _ => unreachable!(),
108 }
109 .unwrap();
110 });
111 }
112 Comparison::Greater(op) => {
113 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
114 match lhs_ty.elem() {
115 Elem::Int(_, false) => b.u_greater_than(ty, Some(out), lhs, rhs),
116 Elem::Int(_, true) => b.s_greater_than(ty, Some(out), lhs, rhs),
117 Elem::Float(..) => b.f_ord_greater_than(ty, Some(out), lhs, rhs),
118 Elem::Relaxed => {
119 b.decorate(out, Decoration::RelaxedPrecision, []);
120 b.f_ord_greater_than(ty, Some(out), lhs, rhs)
121 }
122 _ => unreachable!(),
123 }
124 .unwrap();
125 });
126 }
127 Comparison::GreaterEqual(op) => {
128 self.compile_binary_op_bool(op, out, uniform, |b, lhs_ty, ty, lhs, rhs, out| {
129 match lhs_ty.elem() {
130 Elem::Int(_, false) => b.u_greater_than_equal(ty, Some(out), lhs, rhs),
131 Elem::Int(_, true) => b.s_greater_than_equal(ty, Some(out), lhs, rhs),
132 Elem::Float(..) => b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs),
133 Elem::Relaxed => {
134 b.decorate(out, Decoration::RelaxedPrecision, []);
135 b.f_ord_greater_than_equal(ty, Some(out), lhs, rhs)
136 }
137 _ => unreachable!(),
138 }
139 .unwrap();
140 });
141 }
142 Comparison::IsNan(op) => {
143 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
144 b.is_nan(ty, Some(out), input).unwrap();
145 });
146 }
147 Comparison::IsInf(op) => {
148 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
149 b.is_inf(ty, Some(out), input).unwrap();
150 });
151 }
152 }
153 }
154
155 pub fn compile_operator(&mut self, op: Operator, out: Option<core::Variable>, uniform: bool) {
156 let out = out.unwrap();
157 match op {
158 Operator::Index(op) | Operator::UncheckedIndex(op) => {
159 let is_atomic = op.list.ty.is_atomic();
160 let value = self.compile_variable(op.list);
161 let index = self.compile_variable(op.index);
162 let out = self.compile_variable(out);
163
164 if is_atomic {
165 let ptr = match self.index(&value, &index, true) {
166 IndexedVariable::Pointer(ptr, _) => ptr,
167 _ => unreachable!("Atomic is always pointer"),
168 };
169 let out_id = out.as_binding().unwrap();
170
171 self.merge_binding(out_id, ptr);
173 } else {
174 let out_id = self.read_indexed(&out, &value, &index);
175 self.mark_uniformity(out_id, uniform);
176 self.write(&out, out_id);
177 }
178 }
179 Operator::IndexAssign(op) | Operator::UncheckedIndexAssign(op) => {
180 let index = self.compile_variable(op.index);
181 let value = self.compile_variable(op.value);
182 let out = self.compile_variable(out);
183 let value_id = self.read_as(&value, &out.indexed_item());
184
185 self.write_indexed(&out, &index, value_id);
186 }
187 Operator::Cast(op) => {
188 let input = self.compile_variable(op.input);
189 let out = self.compile_variable(out);
190 let ty = out.item().id(self);
191 let in_id = self.read(&input);
192 let out_id = self.write_id(&out);
193 self.mark_uniformity(out_id, uniform);
194
195 if let Some(as_const) = input.as_const() {
196 let cast = self.static_cast(as_const, &input.elem(), &out.item());
197 self.copy_object(ty, Some(out_id), cast).unwrap();
198 } else {
199 input.item().cast_to(self, Some(out_id), in_id, &out.item());
200 }
201
202 self.write(&out, out_id);
203 }
204 Operator::And(op) => {
205 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
206 b.logical_and(ty, Some(out), lhs, rhs).unwrap();
207 });
208 }
209 Operator::Or(op) => {
210 self.compile_binary_op(op, out, uniform, |b, _, ty, lhs, rhs, out| {
211 b.logical_or(ty, Some(out), lhs, rhs).unwrap();
212 });
213 }
214 Operator::Not(op) => {
215 self.compile_unary_op_cast(op, out, uniform, |b, _, ty, input, out| {
216 b.logical_not(ty, Some(out), input).unwrap();
217 });
218 }
219 Operator::Reinterpret(op) => {
220 self.compile_unary_op(op, out, uniform, |b, _, ty, input, out| {
221 b.bitcast(ty, Some(out), input).unwrap();
222 })
223 }
224 Operator::InitLine(op) => {
225 let values = op
226 .inputs
227 .into_iter()
228 .map(|input| self.compile_variable(input))
229 .collect::<Vec<_>>()
230 .into_iter()
231 .map(|it| self.read(&it))
232 .collect::<Vec<_>>();
233 let item = self.compile_type(out.ty);
234 let out = self.compile_variable(out);
235 let out_id = self.write_id(&out);
236 self.mark_uniformity(out_id, uniform);
237 let ty = item.id(self);
238 self.composite_construct(ty, Some(out_id), values).unwrap();
239 self.write(&out, out_id);
240 }
241 Operator::CopyMemory(op) => {
242 let input = self.compile_variable(op.input);
243 let in_index = self.compile_variable(op.in_index);
244 let out = self.compile_variable(out);
245 let out_index = self.compile_variable(op.out_index);
246
247 let in_ptr = self.index_ptr(&input, &in_index);
248 let out_ptr = self.index_ptr(&out, &out_index);
249 self.copy_memory(out_ptr, in_ptr, None, None, vec![])
250 .unwrap();
251 }
252 Operator::CopyMemoryBulk(op) => {
253 self.capabilities.insert(Capability::Addresses);
254 let input = self.compile_variable(op.input);
255 let in_index = self.compile_variable(op.in_index);
256 let out = self.compile_variable(out);
257 let out_index = self.compile_variable(op.out_index);
258 let len = op.len;
259
260 let source = self.index_ptr(&input, &in_index);
261 let target = self.index_ptr(&out, &out_index);
262 let size = self.const_u32(len * out.item().size());
263 self.copy_memory_sized(target, source, size, None, None, vec![])
264 .unwrap();
265 }
266 Operator::Select(op) => self.compile_select(op.cond, op.then, op.or_else, out, uniform),
267 }
268 }
269
270 pub fn compile_unary_op_cast(
271 &mut self,
272 op: UnaryOperator,
273 out: core::Variable,
274 uniform: bool,
275 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
276 ) {
277 let input = self.compile_variable(op.input);
278 let out = self.compile_variable(out);
279 let out_ty = out.item();
280
281 let input_id = self.read_as(&input, &out_ty);
282 let out_id = self.write_id(&out);
283 self.mark_uniformity(out_id, uniform);
284
285 let ty = out_ty.id(self);
286
287 exec(self, out_ty, ty, input_id, out_id);
288 self.write(&out, out_id);
289 }
290
291 pub fn compile_unary_op(
292 &mut self,
293 op: UnaryOperator,
294 out: core::Variable,
295 uniform: bool,
296 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
297 ) {
298 let input = self.compile_variable(op.input);
299 let out = self.compile_variable(out);
300 let out_ty = out.item();
301
302 let input_id = self.read(&input);
303 let out_id = self.write_id(&out);
304 self.mark_uniformity(out_id, uniform);
305
306 let ty = out_ty.id(self);
307
308 exec(self, out_ty, ty, input_id, out_id);
309 self.write(&out, out_id);
310 }
311
312 pub fn compile_unary_op_bool(
313 &mut self,
314 op: UnaryOperator,
315 out: core::Variable,
316 uniform: bool,
317 exec: impl FnOnce(&mut Self, Item, Word, Word, Word),
318 ) {
319 let input = self.compile_variable(op.input);
320 let out = self.compile_variable(out);
321 let in_ty = input.item();
322
323 let input_id = self.read(&input);
324 let out_id = self.write_id(&out);
325 self.mark_uniformity(out_id, uniform);
326
327 let ty = out.item().id(self);
328
329 exec(self, in_ty, ty, input_id, out_id);
330 self.write(&out, out_id);
331 }
332
333 pub fn compile_binary_op(
334 &mut self,
335 op: BinaryOperator,
336 out: core::Variable,
337 uniform: bool,
338 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
339 ) {
340 let lhs = self.compile_variable(op.lhs);
341 let rhs = self.compile_variable(op.rhs);
342 let out = self.compile_variable(out);
343 let out_ty = out.item();
344
345 let lhs_id = self.read_as(&lhs, &out_ty);
346 let rhs_id = self.read_as(&rhs, &out_ty);
347 let out_id = self.write_id(&out);
348 self.mark_uniformity(out_id, uniform);
349
350 let ty = out_ty.id(self);
351
352 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
353 self.write(&out, out_id);
354 }
355
356 pub fn compile_binary_op_no_cast(
357 &mut self,
358 op: BinaryOperator,
359 out: core::Variable,
360 uniform: bool,
361 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
362 ) {
363 let lhs = self.compile_variable(op.lhs);
364 let rhs = self.compile_variable(op.rhs);
365 let out = self.compile_variable(out);
366 let out_ty = out.item();
367
368 let lhs_id = self.read(&lhs);
369 let rhs_id = self.read(&rhs);
370 let out_id = self.write_id(&out);
371 self.mark_uniformity(out_id, uniform);
372
373 let ty = out_ty.id(self);
374
375 exec(self, out_ty, ty, lhs_id, rhs_id, out_id);
376 self.write(&out, out_id);
377 }
378
379 pub fn compile_binary_op_bool(
380 &mut self,
381 op: BinaryOperator,
382 out: core::Variable,
383 uniform: bool,
384 exec: impl FnOnce(&mut Self, Item, Word, Word, Word, Word),
385 ) {
386 let lhs = self.compile_variable(op.lhs);
387 let rhs = self.compile_variable(op.rhs);
388 let out = self.compile_variable(out);
389
390 let in_ty = out.item().same_vectorization(lhs.elem());
391
392 let lhs_id = self.read_as(&lhs, &in_ty);
393 let rhs_id = self.read_as(&rhs, &in_ty);
394 let out_id = self.write_id(&out);
395 self.mark_uniformity(out_id, uniform);
396
397 let ty = out.item().id(self);
398
399 exec(self, in_ty, ty, lhs_id, rhs_id, out_id);
400 self.write(&out, out_id);
401 }
402
403 pub fn compile_select(
404 &mut self,
405 cond: core::Variable,
406 then: core::Variable,
407 or_else: core::Variable,
408 out: core::Variable,
409 uniform: bool,
410 ) {
411 let cond = self.compile_variable(cond);
412 let then = self.compile_variable(then);
413 let or_else = self.compile_variable(or_else);
414 let out = self.compile_variable(out);
415
416 let out_ty = out.item();
417 let ty = out_ty.id(self);
418
419 let cond_id = self.read(&cond);
420 let then = self.read_as(&then, &out_ty);
421 let or_else = self.read_as(&or_else, &out_ty);
422 let out_id = self.write_id(&out);
423 self.mark_uniformity(out_id, uniform);
424
425 self.select(ty, Some(out_id), cond_id, then, or_else)
426 .unwrap();
427 self.write(&out, out_id);
428 }
429
430 pub fn mark_uniformity(&mut self, id: Word, uniform: bool) {
431 if uniform {
432 self.decorate(id, Decoration::Uniform, []);
433 }
434 }
435}