1use std::hash::Hash;
2use std::{collections::HashSet, fmt::Debug, num::NonZero};
3
4use cubecl_core::ir::expand_checked_index_assign;
5use cubecl_core::{
6 ir::{self as gpu},
7 Compiler, Feature,
8};
9use cubecl_runtime::{DeviceProperties, ExecutionMode};
10
11use super::{
12 AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Elem, Fragment,
13 FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory, UnaryInstruction,
14 Variable, VariableSettings, WarpInstruction, WmmaCompiler, WmmaInstruction,
15};
16
// Process-wide counter, presumably used to generate unique ids for temporary
// variables during code generation (not read in this chunk); it is reset to
// zero at the end of every `Compiler::compile` call.
// NOTE(review): being a global, concurrent compilations would interleave ids
// and the reset — confirm compilation is serialized.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
19
/// Abstraction over a concrete C++ GPU dialect (e.g. CUDA-style vs HIP-style).
///
/// Implementors supply the dialect-specific include directives, bfloat16 type
/// names and warp-intrinsic spellings used while printing the kernel source.
pub trait Dialect:
    WmmaCompiler<Self> + Default + Clone + Copy + Debug + Send + Sync + Eq + Hash + 'static
{
    /// Writes the include(s) required for f16 support.
    fn include_f16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes the include(s) required for bf16 support.
    fn include_bf16(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes the include(s) for the dialect's runtime helpers.
    fn include_runtime(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes the dialect's scalar bfloat16 type name.
    fn bfloat16_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Writes the dialect's packed two-lane bfloat16 type name.
    fn bfloat162_type_name(f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result;
    /// Returns the source expression shuffling `var` from lane `source`.
    fn warp_shuffle(var: &str, source: &str) -> String;
    /// Returns the source expression for a xor-lane shuffle of `var` by `offset`.
    fn warp_shuffle_xor(var: &str, offset: &str) -> String;
    /// Returns the source expression shuffling `var` down by `offset` lanes.
    fn warp_shuffle_down(var: &str, offset: &str) -> String;
    /// Returns the source expression testing that all lanes satisfy `var`.
    fn warp_all(var: &str) -> String;
    /// Returns the source expression testing that any lane satisfies `var`.
    fn warp_any(var: &str) -> String;
}
37
/// Options that influence code generation for a single compilation.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    // Number of threads per warp/plane on the target device; forwarded to
    // wmma `Execute` instructions.
    pub warp_size: u32,
}
42
43impl Default for CompilationOptions {
44 fn default() -> Self {
45 Self { warp_size: 32 }
46 }
47}
48
/// Compiler lowering cubecl IR (`gpu::*`) into a C++-dialect [`ComputeKernel`].
///
/// State is accumulated on `self` while walking the kernel body and is moved
/// into the final kernel representation at the end of `compile_ir`.
// NOTE(review): `too_many_arguments` applies to functions; on a struct this
// allow looks like a leftover — confirm it is still needed.
#[allow(clippy::too_many_arguments)]
#[derive(Clone, Debug, Default)]
pub struct CppCompiler<D: Dialect> {
    // Shared-memory buffers registered while compiling the body.
    shared_memories: Vec<SharedMemory<D>>,
    // Constant arrays hoisted out of scopes (see `compile_scope`).
    const_arrays: Vec<ConstArray<D>>,
    // Thread-local arrays registered while compiling the body.
    local_arrays: Vec<LocalArray<D>>,
    // Layout of the runtime metadata buffer (strides/shapes/lengths).
    metadata: cubecl_core::Metadata,
    // Set when a plane/warp op is compiled; copied into `Body`.
    warp_size_checked: bool,
    // Set when a wmma fragment is used; becomes `ComputeKernel::wmma_activated`.
    wmma: bool,
    // Set when a bf16 element is compiled; becomes `ComputeKernel::bf16`.
    bf16: bool,
    // Set when an f16 element is compiled; becomes `ComputeKernel::f16`.
    f16: bool,
    // Set when a `Print` debug op is compiled.
    // NOTE(review): not read anywhere in this chunk — confirm it is consumed.
    printf: bool,
    num_inputs: usize,
    num_outputs: usize,
    // Extended-metadata slot per global buffer, indexed by buffer position
    // (inputs first, then outputs); filled by `build_metadata`.
    ext_meta_positions: Vec<u32>,
    // Every item encountered, so type definitions can be emitted for each.
    items: HashSet<Item<D>>,
    // Checked vs unchecked execution (bounds checks on index/slice).
    strategy: ExecutionMode,
    // Which builtin variables the kernel body uses.
    settings: VariableSettings,
    compilation_options: CompilationOptions,
}
69
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    /// Compiles a kernel definition into the C++-dialect representation.
    fn compile(
        kernel: cubecl_core::ir::KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        let compiler = Self {
            compilation_options: compilation_options.clone(),
            strategy,
            ..Self::default()
        };
        let ir = compiler.compile_ir(kernel);
        // Reset the temp-variable counter so the next kernel starts at 0.
        // NOTE(review): a panic inside `compile_ir` skips this reset, and the
        // counter is process-global — confirm single compilation at a time.
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(elem: gpu::Elem) -> usize {
        elem.size()
    }

    /// 49152 bytes = 48 KiB, the classic static shared-memory limit per block
    /// on most CUDA-capable GPUs.
    fn max_shared_memory_size() -> usize {
        49152
    }
}
97
98impl<D: Dialect> CppCompiler<D> {
    /// Drives the whole lowering: compiles the body first (which populates the
    /// per-kernel state on `self`), then the bindings, and assembles the final
    /// [`ComputeKernel`].
    fn compile_ir(mut self, mut value: gpu::KernelDefinition) -> ComputeKernel<D> {
        self.build_metadata(&value);

        // Body first: walking it fills shared memories, const/local arrays,
        // feature flags and variable settings that are moved into the kernel
        // below.
        let instructions = self.compile_scope(&mut value.body);
        let inputs = value
            .inputs
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let outputs = value
            .outputs
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let named = value
            .named
            .into_iter()
            .map(|(name, binding)| (name, self.compile_binding(binding)))
            .collect();

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
            warp_size_checked: self.warp_size_checked,
            settings: self.settings,
        };

        ComputeKernel {
            inputs,
            outputs,
            named,
            cube_dim: value.cube_dim,
            body,
            wmma_activated: self.wmma,
            bf16: self.bf16,
            f16: self.f16,
            items: self.items,
            kernel_name: value.kernel_name,
        }
    }
141
142 fn build_metadata(&mut self, value: &gpu::KernelDefinition) {
143 self.num_inputs = value.inputs.len();
144 self.num_outputs = value.outputs.len();
145
146 let mut num_ext = 0;
147
148 for binding in value.inputs.iter().chain(value.outputs.iter()) {
149 self.ext_meta_positions.push(num_ext);
150 if binding.has_extended_meta {
151 num_ext += 1;
152 }
153 }
154
155 let num_meta = self.num_inputs + self.num_outputs;
156
157 self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
158 }
159
160 pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
161 let pos = match var.kind {
162 gpu::VariableKind::GlobalInputArray(id) => id as usize,
163 gpu::VariableKind::GlobalOutputArray(id) => self.num_inputs + id as usize,
164 other => panic!("Only global arrays have metadata, got {other:?}"),
165 };
166 self.ext_meta_positions[pos]
167 }
168
    /// Compiles a scope: hoists its const arrays to the kernel level, declares
    /// its variables, then lowers each operation in order.
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        // Const arrays are drained from the scope and hoisted onto `self`, so
        // they end up declared once at the kernel level.
        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_item(var.item),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let processing = scope.process();

        for var in processing.variables {
            // Slice variables get no up-front declaration.
            if let gpu::VariableKind::Slice { .. } = var.kind {
                continue;
            }
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .operations
            .into_iter()
            .for_each(|op| self.compile_operation(&mut instructions, op, scope));

        instructions
    }
205
    /// Lowers one IR instruction, appending the result(s) to `instructions`.
    ///
    /// `scope` is forwarded so operator lowering can expand extra IR (e.g.
    /// checked indexing) into the current scope.
    fn compile_operation(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
        scope: &mut gpu::Scope,
    ) {
        let out = instruction.out;
        match instruction.operation {
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Operator(op) => self.compile_instruction(op, out, instructions, scope),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                // Both synchronization scopes lower to the same barrier here.
                gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
            },
            gpu::Operation::Plane(op) => {
                // Any plane/warp op means the generated body must account for
                // the warp size (see `Body::warp_size_checked`).
                self.warp_size_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Max(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Min(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Wrap(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                // Call/source/line markers carry no codegen here.
                gpu::NonSemantic::BeginCall { .. }
                | gpu::NonSemantic::EndCall
                | gpu::NonSemantic::Source { .. }
                | gpu::NonSemantic::Line { .. } => {}
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => {
                    self.printf = true;
                    instructions.push(Instruction::Printf {
                        format_string,
                        args: args
                            .into_iter()
                            .map(|arg| self.compile_variable(arg))
                            .collect(),
                    })
                }
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
            },
        }
    }
306
    /// Lowers a cooperative-matrix (wmma) operation. `out` is the fragment or
    /// memory written by the operation and must be present.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        let out = self.compile_variable(out.unwrap());
        match cmma {
            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            }),
            gpu::CoopMma::Load {
                value,
                stride,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Load {
                frag: out,
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Load layout is optional; `Undefined` collapses to `None`.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            }),
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => Instruction::Wmma(WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            }),
            gpu::CoopMma::Store {
                mat,
                stride,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Store {
                output: out,
                frag: self.compile_variable(mat),
                stride: self.compile_variable(stride),
                // Unlike `Load`, a store must know its layout.
                layout: self
                    .compile_matrix_layout(layout)
                    .expect("Layout required for store instruction"),
            }),
            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            }),
        }
    }
353
354 fn compile_metadata(
355 &mut self,
356 metadata: gpu::Metadata,
357 out: Option<gpu::Variable>,
358 ) -> Instruction<D> {
359 let out = out.unwrap();
360 match metadata {
361 gpu::Metadata::Stride { dim, var } => {
362 let position = self.ext_meta_position(var);
363 let offset = self.metadata.stride_offset_index(position);
364 Instruction::ExtendedMetadata {
365 info_offset: self.compile_variable(offset.into()),
366 dim: self.compile_variable(dim),
367 out: self.compile_variable(out),
368 }
369 }
370 gpu::Metadata::Shape { dim, var } => {
371 let position = self.ext_meta_position(var);
372 let offset = self.metadata.shape_offset_index(position);
373 Instruction::ExtendedMetadata {
374 info_offset: self.compile_variable(offset.into()),
375 dim: self.compile_variable(dim),
376 out: self.compile_variable(out),
377 }
378 }
379 gpu::Metadata::Rank { var } => {
380 let out = self.compile_variable(out);
381 let pos = self.ext_meta_position(var);
382 let offset = self.metadata.rank_index(pos);
383 super::Instruction::Metadata {
384 info_offset: self.compile_variable(offset.into()),
385 out,
386 }
387 }
388 gpu::Metadata::Length { var } => {
389 let input = self.compile_variable(var);
390 let out = self.compile_variable(out);
391
392 match input {
393 Variable::Slice { .. } => Instruction::SliceLength { input, out },
394 _ => {
395 let id = match input {
396 Variable::GlobalInputArray(id, _) => id,
397 Variable::GlobalOutputArray(id, _) => self.num_inputs as u32 + id,
398 _ => panic!("Can only get length of global array"),
399 };
400 let offset = self.metadata.len_index(id);
401 Instruction::Metadata {
402 info_offset: self.compile_variable(offset.into()),
403 out,
404 }
405 }
406 }
407 }
408 gpu::Metadata::BufferLength { var } => {
409 let input = self.compile_variable(var);
410 let out = self.compile_variable(out);
411
412 match input {
413 Variable::Slice { .. } => Instruction::SliceLength { input, out },
414 _ => {
415 let id = match input {
416 Variable::GlobalInputArray(id, _) => id,
417 Variable::GlobalOutputArray(id, _) => self.num_inputs as u32 + id,
418 _ => panic!("Can only get buffer length of global array"),
419 };
420 let offset = self.metadata.buffer_len_index(id);
421 Instruction::Metadata {
422 info_offset: self.compile_variable(offset.into()),
423 out,
424 }
425 }
426 }
427 }
428 }
429 }
430
    /// Lowers control flow, recursively compiling each nested scope.
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                // Step is optional; `None` presumably means a unit step —
                // handled by the printer.
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }
468
469 fn compile_atomic(
470 &mut self,
471 value: gpu::AtomicOp,
472 out: Option<gpu::Variable>,
473 instructions: &mut Vec<Instruction<D>>,
474 ) {
475 let out = out.unwrap();
476 match value {
477 gpu::AtomicOp::Load(op) => {
478 instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
479 }
480 gpu::AtomicOp::Store(op) => {
481 instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
482 }
483 gpu::AtomicOp::Swap(op) => {
484 instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
485 }
486 gpu::AtomicOp::Add(op) => {
487 instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
488 }
489 gpu::AtomicOp::Sub(op) => {
490 instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
491 }
492 gpu::AtomicOp::Max(op) => {
493 instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
494 }
495 gpu::AtomicOp::Min(op) => {
496 instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
497 }
498 gpu::AtomicOp::And(op) => {
499 instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
500 }
501 gpu::AtomicOp::Or(op) => {
502 instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
503 }
504 gpu::AtomicOp::Xor(op) => {
505 instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
506 }
507 gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
508 input: self.compile_variable(op.input),
509 cmp: self.compile_variable(op.cmp),
510 val: self.compile_variable(op.val),
511 out: self.compile_variable(out),
512 }),
513 }
514 }
515
    /// Lowers a `gpu::Operator` to one or more backend instructions.
    ///
    /// `scope` is needed because, in checked mode, indexing and slicing expand
    /// additional IR (length lookups, bound-checked assigns) into the scope
    /// before it is lowered. `out` must be present for every operator.
    fn compile_instruction(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
        scope: &mut gpu::Scope,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Operator::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Operator::Div(op) => {
                instructions.push(Instruction::Div(self.compile_binary(op, out)))
            }
            gpu::Operator::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Operator::Slice(op) => {
                // Checked mode: fetch the input's length into a fresh local so
                // the slice can be bound-checked at runtime.
                if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
                    let input = op.input;
                    let input_len =
                        scope.create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));
                    instructions.extend(self.compile_scope(scope));

                    let length = match input.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: input },
                        false => gpu::Metadata::Length { var: input },
                    };

                    instructions.push(self.compile_metadata(length, Some(input_len)));
                    instructions.push(Instruction::CheckedSlice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                        len: self.compile_variable(input_len),
                    });
                } else {
                    instructions.push(Instruction::Slice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                    })
                }
            }
            gpu::Operator::Index(op) => {
                // Checked mode: same pattern as `Slice` — materialize the
                // array length, then emit a bound-checked read.
                if matches!(self.strategy, ExecutionMode::Checked) && op.lhs.has_length() {
                    let lhs = op.lhs;
                    let rhs = op.rhs;

                    let array_len =
                        scope.create_local(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));

                    instructions.extend(self.compile_scope(scope));

                    let length = match lhs.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: lhs },
                        false => gpu::Metadata::Length { var: lhs },
                    };
                    instructions.push(self.compile_metadata(length, Some(array_len)));
                    instructions.push(Instruction::CheckedIndex {
                        len: self.compile_variable(array_len),
                        lhs: self.compile_variable(lhs),
                        rhs: self.compile_variable(rhs),
                        out: self.compile_variable(out),
                    });
                } else {
                    instructions.push(Instruction::Index(self.compile_binary(op, out)));
                }
            }
            gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_binary(op, out)))
            }
            gpu::Operator::IndexAssign(op) => {
                // Checked mode: the bound check is expanded as IR into the
                // scope and compiled in place of a direct assign.
                if let ExecutionMode::Checked = self.strategy {
                    if out.has_length() {
                        expand_checked_index_assign(scope, op.lhs, op.rhs, out);
                        instructions.extend(self.compile_scope(scope));
                        return;
                    }
                };
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)));
            }
            gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)))
            }
            gpu::Operator::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Operator::Equal(op) => {
                instructions.push(Instruction::Equal(self.compile_binary(op, out)))
            }
            gpu::Operator::Lower(op) => {
                instructions.push(Instruction::Lower(self.compile_binary(op, out)))
            }
            gpu::Operator::Greater(op) => {
                instructions.push(Instruction::Greater(self.compile_binary(op, out)))
            }
            gpu::Operator::LowerEqual(op) => {
                instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::GreaterEqual(op) => {
                instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Operator::Exp(op) => {
                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
            }
            gpu::Operator::Log(op) => {
                instructions.push(Instruction::Log(self.compile_unary(op, out)))
            }
            gpu::Operator::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Operator::Cos(op) => {
                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
            }
            gpu::Operator::Sin(op) => {
                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
            }
            gpu::Operator::Tanh(op) => {
                instructions.push(Instruction::Tanh(self.compile_unary(op, out)))
            }
            gpu::Operator::Powf(op) => {
                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
            }
            gpu::Operator::Sqrt(op) => {
                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
            }
            gpu::Operator::Erf(op) => {
                instructions.push(Instruction::Erf(self.compile_unary(op, out)))
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            gpu::Operator::Max(op) => {
                instructions.push(Instruction::Max(self.compile_binary(op, out)))
            }
            gpu::Operator::Min(op) => {
                instructions.push(Instruction::Min(self.compile_binary(op, out)))
            }
            gpu::Operator::NotEqual(op) => {
                instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseOr(op) => {
                instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseAnd(op) => {
                instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
            }
            gpu::Operator::BitwiseXor(op) => {
                instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
            }
            gpu::Operator::CountOnes(op) => {
                instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
            }
            gpu::Operator::ReverseBits(op) => {
                instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
            }
            gpu::Operator::ShiftLeft(op) => {
                instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
            }
            gpu::Operator::ShiftRight(op) => {
                instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
            }
            gpu::Operator::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Recip(op) => {
                // Reciprocal is lowered as `1 / input`, with the constant one
                // built in the input's element type.
                let elem = op.input.item.elem();
                let lhs = match elem {
                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
                    gpu::Elem::AtomicInt(_)
                    | gpu::Elem::AtomicUInt(_)
                    | gpu::Elem::AtomicFloat(_) => {
                        panic!("Cannot use recip with atomics")
                    }
                };

                instructions.push(Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                }))
            }
            gpu::Operator::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Operator::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Operator::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Operator::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Operator::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Operator::Bitcast(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
            gpu::Operator::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Operator::Normalize(op) => {
                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
            }
            gpu::Operator::Magnitude(op) => {
                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
            }
            gpu::Operator::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len,
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Casts become plain assigns; the printer handles the conversion.
            gpu::Operator::Cast(op) => {
                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
            }
        };
    }
785
786 fn compile_binary(
787 &mut self,
788 value: gpu::BinaryOperator,
789 out: gpu::Variable,
790 ) -> BinaryInstruction<D> {
791 BinaryInstruction {
792 lhs: self.compile_variable(value.lhs),
793 rhs: self.compile_variable(value.rhs),
794 out: self.compile_variable(out),
795 }
796 }
797
798 fn compile_unary(
799 &mut self,
800 value: gpu::UnaryOperator,
801 out: gpu::Variable,
802 ) -> UnaryInstruction<D> {
803 UnaryInstruction {
804 input: self.compile_variable(value.input),
805 out: self.compile_variable(out),
806 }
807 }
808
    /// Translates an IR variable into a backend variable.
    ///
    /// As a side effect this registers resources the variable implies (shared
    /// memories, local arrays, wmma fragments) and records which builtin
    /// variables the kernel uses (`self.settings`).
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => {
                // Keeps both the compiled elem and the original IR elem.
                Variable::GlobalScalar(id, self.compile_item(item).elem, item.elem)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            // Versioned variables are lowered like mutable locals; the version
            // information is dropped.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Slice { id } => Variable::Slice {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory { id, length } => {
                let item = self.compile_item(item);
                // Register each shared memory once, keyed by its id.
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            // Builtins flip flags in `self.settings` where the code generator
            // needs to know the builtin is actually used.
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.settings.idx_global = true;
                    Variable::IdxGlobal
                }
                gpu::Builtin::UnitPos => {
                    self.settings.thread_idx_global = true;
                    Variable::ThreadIdxGlobal
                }
                gpu::Builtin::UnitPosX => Variable::ThreadIdxX,
                gpu::Builtin::UnitPosY => Variable::ThreadIdxY,
                gpu::Builtin::UnitPosZ => Variable::ThreadIdxZ,
                gpu::Builtin::CubePosX => Variable::BlockIdxX,
                gpu::Builtin::CubePosY => Variable::BlockIdxY,
                gpu::Builtin::CubePosZ => Variable::BlockIdxZ,
                gpu::Builtin::AbsolutePosX => {
                    self.settings.absolute_idx.0 = true;
                    Variable::AbsoluteIdxX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.settings.absolute_idx.1 = true;
                    Variable::AbsoluteIdxY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.settings.absolute_idx.2 = true;
                    Variable::AbsoluteIdxZ
                }
                gpu::Builtin::CubeDimX => Variable::BlockDimX,
                gpu::Builtin::CubeDimY => Variable::BlockDimY,
                gpu::Builtin::CubeDimZ => Variable::BlockDimZ,
                gpu::Builtin::CubeCountX => Variable::GridDimX,
                gpu::Builtin::CubeCountY => Variable::GridDimY,
                gpu::Builtin::CubeCountZ => Variable::GridDimZ,
                gpu::Builtin::CubePos => {
                    self.settings.block_idx_global = true;
                    Variable::BlockIdxGlobal
                }
                gpu::Builtin::CubeDim => {
                    self.settings.block_dim_global = true;
                    Variable::BlockDimGlobal
                }
                gpu::Builtin::CubeCount => {
                    self.settings.grid_dim_global = true;
                    Variable::GridDimGlobal
                }
                gpu::Builtin::PlaneDim => Variable::WarpSize,
                gpu::Builtin::UnitPosPlane => Variable::ThreadIdxWarp,
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                // Register each local array once, keyed by its id.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                // Any fragment use activates wmma support for the kernel.
                self.wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
        }
    }
916
917 fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
918 Fragment {
919 ident: self.compile_matrix_ident(matrix.ident),
920 m: matrix.m,
921 n: matrix.n,
922 k: matrix.k,
923 elem: self.compile_elem(matrix.elem),
924 layout: self.compile_matrix_layout(matrix.layout),
925 }
926 }
927
928 fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
929 match ident {
930 gpu::MatrixIdent::A => FragmentIdent::A,
931 gpu::MatrixIdent::B => FragmentIdent::B,
932 gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
933 }
934 }
935
936 fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
937 match layout {
938 gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
939 gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
940 gpu::MatrixLayout::Undefined => None,
941 }
942 }
943
944 fn compile_binding(&mut self, binding: gpu::Binding) -> Binding<D> {
945 Binding {
946 item: self.compile_item(binding.item),
947 size: binding.size,
948 vis: binding.visibility,
949 }
950 }
951
    /// Compiles an item (element type + vectorization width) and records it in
    /// `self.items` so the corresponding type definitions can be emitted.
    fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
        let item = Item::new(
            self.compile_elem(item.elem),
            // A missing vectorization factor means scalar (width 1).
            item.vectorization.map(NonZero::get).unwrap_or(1).into(),
        );
        if item.elem != super::Elem::TF32 {
            self.items.insert(item);
            self.items.insert(item.optimized());
        } else {
            // TF32 items are registered under F32 instead; only the inner
            // binding is rewritten — the item returned below still carries
            // TF32. NOTE(review): presumably TF32 shares F32's storage type;
            // confirm against the type-definition printer.
            let mut item = item;
            item.elem = super::Elem::F32;
            self.items.insert(item);
        }

        item
    }
969
    /// Maps an IR element type to the backend element type, recording f16/bf16
    /// usage flags as a side effect.
    ///
    /// # Panics
    /// Panics/`unimplemented!`s for atomic kinds with no backend counterpart.
    fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
        match value {
            gpu::Elem::Float(kind) => match kind {
                gpu::FloatKind::F16 => {
                    self.f16 = true;
                    Elem::F16
                }
                gpu::FloatKind::BF16 => {
                    self.bf16 = true;
                    Elem::BF16
                }
                gpu::FloatKind::TF32 => Elem::TF32,
                // Flex32 is lowered as a plain f32.
                gpu::FloatKind::Flex32 => Elem::F32,
                gpu::FloatKind::F32 => Elem::F32,
                gpu::FloatKind::F64 => Elem::F64,
            },
            gpu::Elem::AtomicFloat(kind) => match kind {
                gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
                gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
                gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
                gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
                // TF32 / Flex32 atomics are rejected.
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Int(kind) => match kind {
                gpu::IntKind::I8 => Elem::I8,
                gpu::IntKind::I16 => Elem::I16,
                gpu::IntKind::I32 => Elem::I32,
                gpu::IntKind::I64 => Elem::I64,
            },
            gpu::Elem::AtomicInt(kind) => match kind {
                gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
                gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
                // I8/I16 atomics are rejected.
                kind => panic!("atomic<{kind:?}> isn't supported yet"),
            },
            gpu::Elem::UInt(kind) => match kind {
                gpu::UIntKind::U8 => Elem::U8,
                gpu::UIntKind::U16 => Elem::U16,
                gpu::UIntKind::U32 => Elem::U32,
                gpu::UIntKind::U64 => Elem::U64,
            },
            gpu::Elem::AtomicUInt(kind) => match kind {
                gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
                gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
                // U8/U16 atomics are rejected.
                kind => unimplemented!("atomic<{kind:?}> not yet supported"),
            },
            gpu::Elem::Bool => Elem::Bool,
        }
    }
1018}
1019
1020pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
1021 let supported_types = [
1022 gpu::Elem::UInt(gpu::UIntKind::U8),
1023 gpu::Elem::UInt(gpu::UIntKind::U16),
1024 gpu::Elem::UInt(gpu::UIntKind::U32),
1025 gpu::Elem::UInt(gpu::UIntKind::U64),
1026 gpu::Elem::Int(gpu::IntKind::I8),
1027 gpu::Elem::Int(gpu::IntKind::I16),
1028 gpu::Elem::Int(gpu::IntKind::I32),
1029 gpu::Elem::Int(gpu::IntKind::I64),
1030 gpu::Elem::AtomicInt(gpu::IntKind::I32),
1031 gpu::Elem::AtomicInt(gpu::IntKind::I64),
1032 gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
1033 gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
1034 gpu::Elem::Float(gpu::FloatKind::BF16),
1035 gpu::Elem::Float(gpu::FloatKind::F16),
1036 gpu::Elem::Float(gpu::FloatKind::F32),
1037 gpu::Elem::Float(gpu::FloatKind::Flex32),
1038 gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
1039 gpu::Elem::Bool,
1042 ];
1043
1044 for ty in supported_types {
1045 props.register_feature(Feature::Type(ty));
1046 }
1047}