1use std::{collections::HashSet, fmt::Debug, num::NonZero};
2
3use cubecl_common::ExecutionMode;
4use cubecl_core::ir::{ExpandElement, UIntKind, VariableKind};
5use cubecl_core::prelude::{FloatExpand, Line};
6use cubecl_core::{
7 Compiler, Feature,
8 ir::{self as gpu},
9};
10use cubecl_core::{CubeDim, io::read_tensor_checked};
11use cubecl_core::{
12 ir::{Operation, SourceLoc},
13 prelude::{FastMath, KernelDefinition, expand_checked_index_assign},
14};
15use cubecl_runtime::DeviceProperties;
16
17use super::barrier::BarrierOps;
18use super::pipeline::PipelineOps;
19use super::{
20 AtomicKind, BinaryInstruction, Binding, Body, ComputeKernel, ConstArray, Dialect, Elem,
21 Fragment, FragmentIdent, FragmentLayout, Instruction, Item, LocalArray, SharedMemory,
22 UnaryInstruction, Variable, WarpInstruction, WmmaInstruction,
23};
24
/// Global counter used to generate unique ids for temporary variables during
/// compilation. Reset to 0 at the end of every `Compiler::compile` call so
/// temporaries start fresh for the next kernel.
pub(super) static COUNTER_TMP_VAR: std::sync::atomic::AtomicU32 =
    std::sync::atomic::AtomicU32::new(0);
27
/// Target-specific options controlling how a kernel is lowered to source.
#[derive(Clone, Debug)]
pub struct CompilationOptions {
    // Number of units in a plane/warp on the target hardware (e.g. 32 on CUDA).
    pub warp_size: u32,
    // When true, metadata/scalars use the split "grid constant" path
    // (threaded through as `use_grid_constants` / `split_meta` below).
    pub grid_constants: bool,
    // Whether the target supports thread-block clusters; when false, any
    // requested cluster dimension is stripped in `compile`.
    pub supports_clusters: bool,
}
34
35impl Default for CompilationOptions {
36 fn default() -> Self {
37 Self {
38 warp_size: 32,
39 grid_constants: false,
40 supports_clusters: false,
41 }
42 }
43}
44
/// Records which builtin index/dimension variables the kernel actually reads,
/// so the code generator only emits the index computations that are needed.
/// The final set is refined by `D::builtin_rules` before code generation
/// (see `compile_ir`).
#[derive(Debug, Clone, Default)]
pub struct CubeIndexFlags {
    pub absolute_pos: bool,
    pub absolute_pos_tuple: bool,
    pub cube_count: bool,
    pub cube_count_tuple: bool,
    pub cube_dim: bool,
    pub cube_dim_tuple: bool,
    pub cube_pos: bool,
    pub cube_pos_tuple: bool,
    pub plane_dim: bool,
    pub plane_dim_checked: bool,
    pub plane_index: bool,
    pub unit_pos: bool,
    pub unit_pos_tuple: bool,
    pub unit_pos_plane: bool,
    pub cluster_pos: bool,
}
65
/// Feature flags accumulated while compiling a kernel; they drive which
/// helpers, element types and instruction sets the code generator must emit.
#[derive(Debug, Clone, Default)]
pub struct Flags {
    // Element types used anywhere in the kernel.
    pub elem_bf16: bool,
    pub elem_f16: bool,
    // Builtin index variables the kernel reads.
    pub indexes: CubeIndexFlags,
    // Barrier / pipeline operations are present in the body.
    pub op_barrier: bool,
    pub op_pipeline: bool,
    // Reduced-precision fast math was requested (`FastMath::ReducedPrecision`).
    pub inst_fast_math: bool,
    // TMA instructions (and the im2col load variant) are present.
    pub inst_tma: bool,
    pub inst_tma_im2col: bool,
    // WMMA/CMMA matrix instructions are present.
    pub inst_wmma: bool,
    // Mirrors `CompilationOptions::grid_constants`.
    pub use_grid_constants: bool,
    // Length of the statically-known metadata section.
    pub static_meta_length: usize,
    // Whether non-static metadata exists for this kernel.
    pub has_dynamic_meta: bool,
    // Cluster dimensions, if clusters are enabled for this kernel.
    pub cluster_dim: Option<CubeDim>,
}
83
84#[allow(clippy::too_many_arguments)]
85#[derive(Clone, Debug, Default)]
86pub struct CppCompiler<D: Dialect> {
87 barriers: Vec<BarrierOps<D>>,
88 compilation_options: CompilationOptions,
89 const_arrays: Vec<ConstArray<D>>,
90 ext_meta_positions: Vec<u32>,
91 cluster_dim: CubeDim,
92 extensions: Vec<D::Extension>,
93 flags: Flags,
94 items: HashSet<Item<D>>,
95 local_arrays: Vec<LocalArray<D>>,
96 metadata: cubecl_core::Metadata,
97 pipelines: Vec<PipelineOps<D>>,
98 shared_memories: Vec<SharedMemory<D>>,
99 source_loc: Option<SourceLoc>,
100 strategy: ExecutionMode,
101}
102
impl<D: Dialect> Compiler for CppCompiler<D> {
    type Representation = ComputeKernel<D>;
    type CompilationOptions = CompilationOptions;

    /// Compiles a kernel definition into its C++ representation.
    ///
    /// Cluster dimensions are stripped when the target does not support
    /// clusters, and the global temporary-variable counter is reset after
    /// compilation so ids start at 0 for the next kernel.
    fn compile(
        &mut self,
        mut kernel: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        strategy: ExecutionMode,
    ) -> Self::Representation {
        self.compilation_options = compilation_options.clone();
        self.strategy = strategy;

        if !self.compilation_options.supports_clusters {
            kernel.options.cluster_dim = None;
        }
        self.cluster_dim = kernel.options.cluster_dim.unwrap_or(CubeDim::new_single());

        // `compile_ir` consumes the compiler, so run it on a clone and keep
        // `self` reusable for subsequent compilations.
        let ir = self.clone().compile_ir(kernel);
        COUNTER_TMP_VAR.store(0, std::sync::atomic::Ordering::Relaxed);
        ir
    }

    fn elem_size(&self, elem: gpu::Elem) -> usize {
        elem.size()
    }

    /// File extension of the generated source.
    fn extension(&self) -> &'static str {
        "cpp"
    }
}
134
135impl<D: Dialect> CppCompiler<D> {
    /// Consumes the compiler and a kernel definition, producing the final
    /// `ComputeKernel`. Compiling the body first is load-bearing: it populates
    /// `self.flags`, shared memories, pipelines, barriers and arrays that are
    /// then assembled into the kernel.
    fn compile_ir(mut self, mut value: KernelDefinition) -> ComputeKernel<D> {
        // Must run before anything queries metadata positions.
        self.build_metadata(&value);

        let instructions = self.compile_scope(&mut value.body);
        let buffers = value
            .buffers
            .into_iter()
            .map(|b| self.compile_binding(b))
            .collect();
        let scalars = value
            .scalars
            .into_iter()
            .map(|binding| (self.compile_elem(binding.elem), binding.count))
            .collect();

        // Snapshot the flags gathered while compiling the body; builtin index
        // usage is refined by dialect-specific rules.
        let flags = Flags {
            indexes: D::builtin_rules(&self.flags.indexes),
            inst_wmma: self.flags.inst_wmma,
            op_pipeline: self.flags.op_pipeline,
            op_barrier: self.flags.op_barrier,
            elem_bf16: self.flags.elem_bf16,
            elem_f16: self.flags.elem_f16,
            inst_fast_math: value
                .options
                .fp_math_mode
                .contains(FastMath::ReducedPrecision),
            inst_tma: self.flags.inst_tma,
            inst_tma_im2col: self.flags.inst_tma_im2col,
            use_grid_constants: self.compilation_options.grid_constants,
            // NOTE(review): this derives `has_dynamic_meta` from the *static*
            // metadata length — confirm this is intended and not meant to be a
            // check on dynamic metadata.
            has_dynamic_meta: self.metadata.static_len() > 0,
            static_meta_length: self.metadata.static_len() as usize,
            cluster_dim: value.options.cluster_dim,
        };

        let body = Body {
            instructions,
            shared_memories: self.shared_memories,
            pipelines: self.pipelines,
            barriers: self.barriers,
            const_arrays: self.const_arrays,
            local_arrays: self.local_arrays,
        };

        // Defensive re-check: `compile` already clears the cluster dim when
        // unsupported, but this keeps `compile_ir` safe on its own.
        let mut cluster_dim = value.options.cluster_dim;
        if !self.compilation_options.supports_clusters {
            cluster_dim = None;
        }

        ComputeKernel {
            tensor_maps: value.tensor_maps,
            buffers,
            scalars,
            meta_static_len: self.metadata.static_len() as usize,
            cube_dim: value.cube_dim,
            body,
            extensions: self.extensions,
            flags,
            items: self.items,
            kernel_name: value.options.kernel_name,
            cluster_dim,
        }
    }
201
202 fn build_metadata(&mut self, value: &KernelDefinition) {
203 let mut num_ext = 0;
204
205 let mut all_meta: Vec<_> = value
206 .buffers
207 .iter()
208 .map(|buf| (buf.id, buf.has_extended_meta))
209 .chain(value.tensor_maps.iter().map(|i| (*i, true)))
210 .collect();
211
212 all_meta.sort_by_key(|(id, _)| *id);
213
214 for (_, has_extended_meta) in &all_meta {
215 self.ext_meta_positions.push(num_ext);
216 if *has_extended_meta {
217 num_ext += 1;
218 }
219 }
220
221 let num_meta = all_meta.len();
222
223 self.metadata = cubecl_core::Metadata::new(num_meta as u32, num_ext);
224 }
225
226 pub(crate) fn ext_meta_position(&self, var: gpu::Variable) -> u32 {
227 let id = var.index().expect("Variable should have index");
228 self.ext_meta_positions[id as usize]
229 }
230
    /// Compiles a scope (the kernel body or a nested block) into a list of
    /// instructions. Const arrays are drained out of the scope into the
    /// compiler-wide list, and declarations are emitted for every local
    /// variable except slices, which are skipped here.
    fn compile_scope(&mut self, scope: &mut gpu::Scope) -> Vec<Instruction<D>> {
        let mut instructions = Vec::new();

        // Move the scope's const arrays into compiler state; they are emitted
        // once in the kernel body (see `compile_ir`).
        let const_arrays = scope
            .const_arrays
            .drain(..)
            .map(|(var, values)| ConstArray {
                index: var.index().unwrap(),
                item: self.compile_item(var.item),
                size: values.len() as u32,
                values: values
                    .into_iter()
                    .map(|val| self.compile_variable(val))
                    .collect(),
            })
            .collect::<Vec<_>>();
        self.const_arrays.extend(const_arrays);

        let processing = scope.process();

        for var in processing.variables {
            // Slices get no up-front declaration of their own.
            if let gpu::VariableKind::Slice { .. } = var.kind {
                continue;
            }
            instructions.push(Instruction::DeclareVariable {
                var: self.compile_variable(var),
            });
        }

        processing
            .instructions
            .into_iter()
            .for_each(|op| self.compile_instruction(&mut instructions, op, scope));

        instructions
    }
267
    /// Lowers a single IR instruction, appending the result to `instructions`.
    ///
    /// Side effects: emits a source-line marker when the debug location
    /// changes, and sets feature flags (`inst_tma`, `plane_dim_checked`, …)
    /// as the corresponding operations are encountered.
    fn compile_instruction(
        &mut self,
        instructions: &mut Vec<Instruction<D>>,
        instruction: gpu::Instruction,
        scope: &mut gpu::Scope,
    ) {
        self.update_debug_loc(instructions, &instruction);
        let out = instruction.out;
        match instruction.operation {
            // Plain copy lowers to an assignment.
            gpu::Operation::Copy(variable) => {
                instructions.push(Instruction::Assign(UnaryInstruction {
                    input: self.compile_variable(variable),
                    out: self.compile_variable(out.unwrap()),
                }));
            }
            gpu::Operation::Arithmetic(op) => self.compile_arithmetic(op, out, instructions),
            gpu::Operation::Comparison(op) => self.compile_comparison(op, out, instructions),
            gpu::Operation::Bitwise(op) => self.compile_bitwise(op, out, instructions),
            gpu::Operation::Operator(op) => self.compile_operator(op, out, instructions, scope),
            gpu::Operation::Atomic(op) => self.compile_atomic(op, out, instructions),
            gpu::Operation::Metadata(op) => instructions.push(self.compile_metadata(op, out)),
            gpu::Operation::Branch(val) => self.compile_branch(instructions, val),
            gpu::Operation::Synchronization(val) => match val {
                // Both unit and storage sync lower to the same thread barrier.
                gpu::Synchronization::SyncUnits => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncStorage => instructions.push(Instruction::SyncThreads),
                gpu::Synchronization::SyncProxyShared => {
                    self.flags.inst_tma = true;
                    instructions.push(Instruction::ProxySharedFence)
                }
            },
            gpu::Operation::Plane(op) => {
                // All plane ops need the checked plane dim to be computed.
                self.flags.indexes.plane_dim_checked = true;
                let out = self.compile_variable(out.unwrap());
                match op {
                    gpu::Plane::Sum(op) => {
                        let instruction = WarpInstruction::ReduceSum {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction));
                    }
                    // Scan variants additionally need the unit's position
                    // within the plane.
                    gpu::Plane::InclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::InclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::InclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveSum(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveSum {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::ExclusiveProd(op) => {
                        self.flags.indexes.unit_pos_plane = true;
                        instructions.push(Instruction::Warp(WarpInstruction::ExclusiveProd {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Prod(op) => {
                        let instruction = WarpInstruction::ReduceProd {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Max(op) => {
                        let instruction = WarpInstruction::ReduceMax {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    gpu::Plane::Min(op) => {
                        let instruction = WarpInstruction::ReduceMin {
                            input: self.compile_variable(op.input),
                            out,
                        };
                        D::register_warp_instruction_extension(&mut self.extensions, &instruction);
                        instructions.push(Instruction::Warp(instruction))
                    }
                    // Elect has no input operand.
                    gpu::Plane::Elect => {
                        instructions.push(Instruction::Warp(WarpInstruction::Elect { out }))
                    }
                    gpu::Plane::All(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::All {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Any(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Any {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Ballot(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Ballot {
                            input: self.compile_variable(op.input),
                            out,
                        }))
                    }
                    gpu::Plane::Broadcast(op) => {
                        instructions.push(Instruction::Warp(WarpInstruction::Broadcast {
                            input: self.compile_variable(op.lhs),
                            id: self.compile_variable(op.rhs),
                            out,
                        }))
                    }
                }
            }
            gpu::Operation::CoopMma(cmma) => instructions.push(self.compile_cmma(cmma, out)),
            gpu::Operation::NonSemantic(debug) => match debug {
                gpu::NonSemantic::Print {
                    format_string,
                    args,
                } => instructions.push(Instruction::Printf {
                    format_string,
                    args: args
                        .into_iter()
                        .map(|arg| self.compile_variable(arg))
                        .collect(),
                }),
                gpu::NonSemantic::Comment { content } => {
                    instructions.push(Instruction::Comment { content })
                }
                // Other non-semantic ops carry no codegen payload here.
                _ => {}
            },
            gpu::Operation::Pipeline(pipeline_ops) => match pipeline_ops {
                gpu::PipelineOps::MemCopyAsync {
                    pipeline,
                    source,
                    destination,
                } => {
                    instructions.push(Instruction::Pipeline(
                        super::pipeline::PipelineOps::MemCopyAsync {
                            pipeline: self.compile_variable(pipeline),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(destination),
                        },
                    ));
                }
                gpu::PipelineOps::ProducerAcquire { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ProducerAcquire {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),
                gpu::PipelineOps::ProducerCommit { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ProducerCommit {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),

                gpu::PipelineOps::ConsumerWait { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ConsumerWait {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),

                gpu::PipelineOps::ConsumerRelease { pipeline } => instructions.push(
                    Instruction::Pipeline(super::pipeline::PipelineOps::ConsumerRelease {
                        pipeline: self.compile_variable(pipeline),
                    }),
                ),
            },
            gpu::Operation::Barrier(barrier_ops) => match barrier_ops {
                gpu::BarrierOps::Init {
                    barrier,
                    with_cta_fence,
                } => {
                    // The barrier level lives on the variable's kind.
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    let barrier = self.compile_variable(barrier);
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Init {
                        barrier,
                        level,
                        with_cta_fence,
                    }));
                }
                gpu::BarrierOps::MemCopyAsync { barrier, source } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsync {
                            barrier: self.compile_variable(barrier),
                            source: self.compile_variable(source),
                            destination: self.compile_variable(out.unwrap()),
                            level,
                        },
                    ));
                }
                // NOTE(review): unlike `Operation::Tma` below, this arm does
                // not set `self.flags.inst_tma` — confirm the flag is set
                // elsewhere (e.g. when compiling tensor-map variables).
                gpu::BarrierOps::TmaLoad {
                    barrier,
                    tensor_map,
                    indices,
                } => {
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::MemCopyAsyncTensorGlobalToShared {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::TmaLoadIm2col {
                    barrier,
                    tensor_map,
                    indices,
                    offsets,
                } => {
                    self.flags.inst_tma_im2col = true;
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::TmaLoadIm2col {
                            barrier: self.compile_variable(barrier),
                            smem_buffer: self.compile_variable(out.unwrap()),
                            tensor_map: self.compile_variable(tensor_map),
                            indices: indices
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                            offsets: offsets
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        },
                    ));
                }
                gpu::BarrierOps::Arrive { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Arrive {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveTx {
                    barrier,
                    arrive_count_update,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ArriveTx {
                        barrier: self.compile_variable(barrier),
                        arrive_count_update: self.compile_variable(arrive_count_update),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::ExpectTx {
                    barrier,
                    transaction_count_update,
                } => {
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::ExpectTx {
                        barrier: self.compile_variable(barrier),
                        transaction_count_update: self.compile_variable(transaction_count_update),
                    }))
                }
                gpu::BarrierOps::Wait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(super::barrier::BarrierOps::Wait {
                        barrier: self.compile_variable(barrier),
                        level,
                    }))
                }
                gpu::BarrierOps::ArriveAndWait { barrier } => {
                    let VariableKind::Barrier { level, .. } = barrier.kind else {
                        unreachable!()
                    };
                    instructions.push(Instruction::Barrier(
                        super::barrier::BarrierOps::ArriveAndWait {
                            barrier: self.compile_variable(barrier),
                            level,
                        },
                    ))
                }
            },
            gpu::Operation::Tma(tma_ops) => {
                // Any TMA op requires the TMA feature in the generated source.
                self.flags.inst_tma = true;
                match tma_ops {
                    gpu::TmaOps::TmaStore {
                        source,
                        coordinates,
                    } => {
                        instructions.push(Instruction::MemCopyAsyncTensorSharedToGlobal {
                            smem_buffer: self.compile_variable(source),
                            tensor_map: self.compile_variable(out.unwrap()),
                            indices: coordinates
                                .into_iter()
                                .map(|it| self.compile_variable(it))
                                .collect(),
                        });
                    }
                    gpu::TmaOps::CommitGroup => {
                        instructions.push(Instruction::BulkCommitGroup);
                    }
                    gpu::TmaOps::WaitGroup { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroup { max_pending });
                    }
                    gpu::TmaOps::WaitGroupRead { max_pending } => {
                        instructions.push(Instruction::BulkWaitGroupRead { max_pending });
                    }
                }
            }
        }
    }
594
595 fn update_debug_loc(
596 &mut self,
597 instructions: &mut Vec<Instruction<D>>,
598 inst: &gpu::Instruction,
599 ) {
600 if !matches!(inst.operation, Operation::NonSemantic(_)) {
601 match &inst.source_loc {
602 Some(loc) if Some(loc) != self.source_loc.as_ref() => {
603 self.source_loc = Some(loc.clone());
604 instructions.push(Instruction::Line {
605 file: loc.source.file.clone(),
606 line: loc.line,
607 });
608 }
609 _ => {}
610 }
611 }
612 }
613
    /// Lowers a cooperative matrix (CMMA) operation to a WMMA instruction.
    /// `out` is the fragment or buffer the operation writes to and must be
    /// present.
    fn compile_cmma(&mut self, cmma: gpu::CoopMma, out: Option<gpu::Variable>) -> Instruction<D> {
        let out = self.compile_variable(out.unwrap());
        match cmma {
            gpu::CoopMma::Fill { value } => Instruction::Wmma(WmmaInstruction::Fill {
                frag: out,
                value: self.compile_variable(value),
            }),
            gpu::CoopMma::Load {
                value,
                stride,
                layout,
            } => Instruction::Wmma(WmmaInstruction::Load {
                frag: out,
                value: self.compile_variable(value),
                stride: self.compile_variable(stride),
                // Layout is optional on load; the dialect may reject some.
                layout: layout.and_then(|l| self.compile_matrix_layout(l)),
            }),
            gpu::CoopMma::Execute {
                mat_a,
                mat_b,
                mat_c,
            } => Instruction::Wmma(WmmaInstruction::Execute {
                frag_a: self.compile_variable(mat_a),
                frag_b: self.compile_variable(mat_b),
                frag_c: self.compile_variable(mat_c),
                frag_d: out,
                warp_size: self.compilation_options.warp_size,
            }),
            gpu::CoopMma::Store {
                mat,
                stride,
                layout,
            } => {
                // Store addressing needs the unit's position and plane index.
                self.flags.indexes.unit_pos = true;
                self.flags.indexes.plane_index = true;
                Instruction::Wmma(WmmaInstruction::Store {
                    output: out,
                    frag: self.compile_variable(mat),
                    stride: self.compile_variable(stride),
                    layout: self
                        .compile_matrix_layout(layout)
                        .expect("Layout required for store instruction"),
                })
            }
            gpu::CoopMma::Cast { input } => Instruction::Wmma(WmmaInstruction::Cast {
                input: self.compile_variable(input),
                output: out,
            }),
        }
    }
664
    /// Lowers a metadata query (stride/shape/rank/length) into an instruction
    /// that reads the appropriate offset of the metadata buffer. Stride and
    /// shape use the extended-metadata section; lengths of slices and shared
    /// memories are resolved without a metadata read.
    fn compile_metadata(
        &mut self,
        metadata: gpu::Metadata,
        out: Option<gpu::Variable>,
    ) -> Instruction<D> {
        let out = out.unwrap();
        match metadata {
            gpu::Metadata::Stride { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.stride_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Shape { dim, var } => {
                let position = self.ext_meta_position(var);
                let offset = self.metadata.shape_offset_index(position);
                Instruction::ExtendedMetadata {
                    info_offset: self.compile_variable(offset.into()),
                    dim: self.compile_variable(dim),
                    split_meta: self.compilation_options.grid_constants,
                    static_offset: self.metadata.static_len(),
                    out: self.compile_variable(out),
                }
            }
            gpu::Metadata::Rank { var } => {
                let out = self.compile_variable(out);
                let pos = self.ext_meta_position(var);
                let offset = self.metadata.rank_index(pos);
                super::Instruction::Metadata {
                    info_offset: self.compile_variable(offset.into()),
                    split_meta: self.compilation_options.grid_constants,
                    out,
                }
            }
            gpu::Metadata::Length { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    // Slice length is tracked on the slice itself.
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    // Shared memory length is a compile-time constant.
                    Variable::SharedMemory(_id, _item, length) => {
                        Instruction::ConstLength { length, out }
                    }
                    // Everything else reads the length from metadata.
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
            gpu::Metadata::BufferLength { var } => {
                let input = self.compile_variable(var);
                let out = self.compile_variable(out);

                match input {
                    Variable::Slice { .. } => Instruction::SliceLength { input, out },
                    _ => {
                        let id = input.id().expect("Variable should have id");
                        let offset = self.metadata.buffer_len_index(id);
                        Instruction::Metadata {
                            info_offset: self.compile_variable(offset.into()),
                            split_meta: self.compilation_options.grid_constants,
                            out,
                        }
                    }
                }
            }
        }
    }
743
    /// Lowers control flow. Nested scopes are compiled recursively via
    /// `compile_scope`, so each branch body gets its own instruction list.
    fn compile_branch(&mut self, instructions: &mut Vec<Instruction<D>>, branch: gpu::Branch) {
        match branch {
            gpu::Branch::If(mut op) => instructions.push(Instruction::If {
                cond: self.compile_variable(op.cond),
                instructions: self.compile_scope(&mut op.scope),
            }),
            gpu::Branch::IfElse(mut op) => instructions.push(Instruction::IfElse {
                cond: self.compile_variable(op.cond),
                instructions_if: self.compile_scope(&mut op.scope_if),
                instructions_else: self.compile_scope(&mut op.scope_else),
            }),
            gpu::Branch::Switch(mut op) => instructions.push(Instruction::Switch {
                value: self.compile_variable(op.value),
                instructions_default: self.compile_scope(&mut op.scope_default),
                instructions_cases: op
                    .cases
                    .into_iter()
                    .map(|(val, mut block)| {
                        (self.compile_variable(val), self.compile_scope(&mut block))
                    })
                    .collect(),
            }),
            gpu::Branch::Return => instructions.push(Instruction::Return),
            gpu::Branch::Break => instructions.push(Instruction::Break),
            gpu::Branch::RangeLoop(mut range_loop) => instructions.push(Instruction::RangeLoop {
                i: self.compile_variable(range_loop.i),
                start: self.compile_variable(range_loop.start),
                end: self.compile_variable(range_loop.end),
                // Step is optional; the backend uses a default when absent.
                step: range_loop.step.map(|it| self.compile_variable(it)),
                inclusive: range_loop.inclusive,
                instructions: self.compile_scope(&mut range_loop.scope),
            }),
            gpu::Branch::Loop(mut op) => instructions.push(Instruction::Loop {
                instructions: self.compile_scope(&mut op.scope),
            }),
        };
    }
781
782 fn compile_atomic(
783 &mut self,
784 value: gpu::AtomicOp,
785 out: Option<gpu::Variable>,
786 instructions: &mut Vec<Instruction<D>>,
787 ) {
788 let out = out.unwrap();
789 match value {
790 gpu::AtomicOp::Load(op) => {
791 instructions.push(Instruction::AtomicLoad(self.compile_unary(op, out)))
792 }
793 gpu::AtomicOp::Store(op) => {
794 instructions.push(Instruction::AtomicStore(self.compile_unary(op, out)))
795 }
796 gpu::AtomicOp::Swap(op) => {
797 instructions.push(Instruction::AtomicSwap(self.compile_binary(op, out)))
798 }
799 gpu::AtomicOp::Add(op) => {
800 instructions.push(Instruction::AtomicAdd(self.compile_binary(op, out)))
801 }
802 gpu::AtomicOp::Sub(op) => {
803 instructions.push(Instruction::AtomicSub(self.compile_binary(op, out)))
804 }
805 gpu::AtomicOp::Max(op) => {
806 instructions.push(Instruction::AtomicMax(self.compile_binary(op, out)))
807 }
808 gpu::AtomicOp::Min(op) => {
809 instructions.push(Instruction::AtomicMin(self.compile_binary(op, out)))
810 }
811 gpu::AtomicOp::And(op) => {
812 instructions.push(Instruction::AtomicAnd(self.compile_binary(op, out)))
813 }
814 gpu::AtomicOp::Or(op) => {
815 instructions.push(Instruction::AtomicOr(self.compile_binary(op, out)))
816 }
817 gpu::AtomicOp::Xor(op) => {
818 instructions.push(Instruction::AtomicXor(self.compile_binary(op, out)))
819 }
820 gpu::AtomicOp::CompareAndSwap(op) => instructions.push(Instruction::AtomicCAS {
821 input: self.compile_variable(op.input),
822 cmp: self.compile_variable(op.cmp),
823 val: self.compile_variable(op.val),
824 out: self.compile_variable(out),
825 }),
826 }
827 }
828
    /// Lowers an arithmetic operation. Some instructions (HiMul, Tanh, Erf,
    /// Max, Min) may need a dialect-specific extension, which is registered
    /// before the instruction is pushed. `Recip` is rewritten as `1 / x`.
    fn compile_arithmetic(
        &mut self,
        value: gpu::Arithmetic,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Arithmetic::Add(op) => {
                instructions.push(Instruction::Add(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Mul(op) => {
                instructions.push(Instruction::Mul(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Div(op) => {
                instructions.push(Instruction::Div(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sub(op) => {
                instructions.push(Instruction::Sub(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::MulHi(op) => {
                let instruction = Instruction::HiMul(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Modulo(op) => {
                instructions.push(Instruction::Modulo(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Abs(op) => {
                instructions.push(Instruction::Abs(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Exp(op) => {
                instructions.push(Instruction::Exp(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log(op) => {
                instructions.push(Instruction::Log(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Log1p(op) => {
                instructions.push(Instruction::Log1p(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Cos(op) => {
                instructions.push(Instruction::Cos(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Sin(op) => {
                instructions.push(Instruction::Sin(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Tanh(op) => {
                let instruction = Instruction::Tanh(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Powf(op) => {
                instructions.push(Instruction::Powf(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Sqrt(op) => {
                instructions.push(Instruction::Sqrt(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Erf(op) => {
                let instruction = Instruction::Erf(self.compile_unary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Max(op) => {
                let instruction = Instruction::Max(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Min(op) => {
                let instruction = Instruction::Min(self.compile_binary(op, out));
                D::register_instruction_extension(&mut self.extensions, &instruction);
                instructions.push(instruction)
            }
            gpu::Arithmetic::Clamp(op) => instructions.push(Instruction::Clamp {
                input: self.compile_variable(op.input),
                min_value: self.compile_variable(op.min_value),
                max_value: self.compile_variable(op.max_value),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Recip(op) => {
                // Rewrite recip(x) as (1 / x), building the constant `1` with
                // the element kind of the input.
                let elem = op.input.item.elem();
                let lhs = match elem {
                    gpu::Elem::Float(kind) => gpu::ConstantScalarValue::Float(1.0, kind),
                    gpu::Elem::Int(kind) => gpu::ConstantScalarValue::Int(1, kind),
                    gpu::Elem::UInt(kind) => gpu::ConstantScalarValue::UInt(1, kind),
                    gpu::Elem::Bool => gpu::ConstantScalarValue::Bool(true),
                    gpu::Elem::AtomicInt(_)
                    | gpu::Elem::AtomicUInt(_)
                    | gpu::Elem::AtomicFloat(_) => {
                        panic!("Cannot use recip with atomics")
                    }
                };

                instructions.push(Instruction::Div(BinaryInstruction {
                    lhs: Variable::ConstantScalar(lhs, self.compile_elem(elem)),
                    rhs: self.compile_variable(op.input),
                    out: self.compile_variable(out),
                }))
            }
            gpu::Arithmetic::Round(op) => {
                instructions.push(Instruction::Round(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Floor(op) => {
                instructions.push(Instruction::Floor(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Ceil(op) => {
                instructions.push(Instruction::Ceil(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Remainder(op) => {
                instructions.push(Instruction::Remainder(self.compile_binary(op, out)))
            }
            gpu::Arithmetic::Fma(op) => instructions.push(Instruction::Fma {
                a: self.compile_variable(op.a),
                b: self.compile_variable(op.b),
                c: self.compile_variable(op.c),
                out: self.compile_variable(out),
            }),
            gpu::Arithmetic::Neg(op) => {
                instructions.push(Instruction::Neg(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Normalize(op) => {
                instructions.push(Instruction::Normalize(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Magnitude(op) => {
                instructions.push(Instruction::Magnitude(self.compile_unary(op, out)))
            }
            gpu::Arithmetic::Dot(op) => {
                instructions.push(Instruction::Dot(self.compile_binary(op, out)))
            }
        };
    }
959
960 fn compile_comparison(
961 &mut self,
962 value: gpu::Comparison,
963 out: Option<gpu::Variable>,
964 instructions: &mut Vec<Instruction<D>>,
965 ) {
966 let out = out.unwrap();
967 match value {
968 gpu::Comparison::Equal(op) => {
969 instructions.push(Instruction::Equal(self.compile_binary(op, out)))
970 }
971 gpu::Comparison::Lower(op) => {
972 instructions.push(Instruction::Lower(self.compile_binary(op, out)))
973 }
974 gpu::Comparison::Greater(op) => {
975 instructions.push(Instruction::Greater(self.compile_binary(op, out)))
976 }
977 gpu::Comparison::LowerEqual(op) => {
978 instructions.push(Instruction::LowerEqual(self.compile_binary(op, out)))
979 }
980 gpu::Comparison::GreaterEqual(op) => {
981 instructions.push(Instruction::GreaterEqual(self.compile_binary(op, out)))
982 }
983 gpu::Comparison::NotEqual(op) => {
984 instructions.push(Instruction::NotEqual(self.compile_binary(op, out)))
985 }
986 };
987 }
988
989 fn compile_bitwise(
990 &mut self,
991 value: gpu::Bitwise,
992 out: Option<gpu::Variable>,
993 instructions: &mut Vec<Instruction<D>>,
994 ) {
995 let out = out.unwrap();
996 match value {
997 gpu::Bitwise::BitwiseOr(op) => {
998 instructions.push(Instruction::BitwiseOr(self.compile_binary(op, out)))
999 }
1000 gpu::Bitwise::BitwiseAnd(op) => {
1001 instructions.push(Instruction::BitwiseAnd(self.compile_binary(op, out)))
1002 }
1003 gpu::Bitwise::BitwiseXor(op) => {
1004 instructions.push(Instruction::BitwiseXor(self.compile_binary(op, out)))
1005 }
1006 gpu::Bitwise::CountOnes(op) => {
1007 instructions.push(Instruction::CountBits(self.compile_unary(op, out)))
1008 }
1009 gpu::Bitwise::ReverseBits(op) => {
1010 instructions.push(Instruction::ReverseBits(self.compile_unary(op, out)))
1011 }
1012 gpu::Bitwise::ShiftLeft(op) => {
1013 instructions.push(Instruction::ShiftLeft(self.compile_binary(op, out)))
1014 }
1015 gpu::Bitwise::ShiftRight(op) => {
1016 instructions.push(Instruction::ShiftRight(self.compile_binary(op, out)))
1017 }
1018 gpu::Bitwise::BitwiseNot(op) => {
1019 instructions.push(Instruction::BitwiseNot(self.compile_unary(op, out)))
1020 }
1021 gpu::Bitwise::LeadingZeros(op) => {
1022 instructions.push(Instruction::LeadingZeros(self.compile_unary(op, out)))
1023 }
1024 gpu::Bitwise::FindFirstSet(op) => {
1025 let instruction = Instruction::FindFirstSet(self.compile_unary(op, out));
1026 D::register_instruction_extension(&mut self.extensions, &instruction);
1027 instructions.push(instruction)
1028 }
1029 };
1030 }
1031
    /// Lowers a general operator to backend instructions.
    ///
    /// Unlike the other `compile_*` dispatchers, this one also needs the active
    /// `scope`: in `ExecutionMode::Checked`, slice and index operations expand
    /// extra IR (length queries, checked reads) into the scope and compile it
    /// inline before emitting the final instruction. `out` must be `Some`.
    fn compile_operator(
        &mut self,
        value: gpu::Operator,
        out: Option<gpu::Variable>,
        instructions: &mut Vec<Instruction<D>>,
        scope: &mut gpu::Scope,
    ) {
        let out = out.unwrap();
        match value {
            gpu::Operator::Slice(op) => {
                // Checked slices clamp against the input length, so the length
                // must be materialized into a local first. Inputs without a
                // length (e.g. shared memory — can't tell from here, TODO confirm)
                // fall back to the unchecked form.
                if matches!(self.strategy, ExecutionMode::Checked) && op.input.has_length() {
                    let input = op.input;
                    let input_len = *scope
                        .create_local_mut(gpu::Item::new(gpu::Elem::UInt(gpu::UIntKind::U32)));
                    // Flush any instructions already queued in the scope before
                    // appending the metadata read.
                    instructions.extend(self.compile_scope(scope));

                    // Prefer the buffer length when available so the bound
                    // covers the whole allocation rather than the logical view.
                    let length = match input.has_buffer_length() {
                        true => gpu::Metadata::BufferLength { var: input },
                        false => gpu::Metadata::Length { var: input },
                    };

                    instructions.push(self.compile_metadata(length, Some(input_len)));
                    instructions.push(Instruction::CheckedSlice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                        len: self.compile_variable(input_len),
                    });
                } else {
                    instructions.push(Instruction::Slice {
                        input: self.compile_variable(op.input),
                        start: self.compile_variable(op.start),
                        end: self.compile_variable(op.end),
                        out: self.compile_variable(out),
                    })
                }
            }
            gpu::Operator::ReinterpretSlice(op) => {
                instructions.push(Instruction::ReinterpretSlice {
                    input: self.compile_variable(op.input),
                    line_size: op.line_size,
                    out: self.compile_variable(out),
                })
            }
            gpu::Operator::Index(op) => {
                // Checked (non-atomic) reads are expanded through
                // `read_tensor_checked` in a child scope, then assigned to the
                // requested output. Atomic outputs keep the plain index so the
                // atomic semantics are preserved.
                if matches!(self.strategy, ExecutionMode::Checked)
                    && op.lhs.has_length()
                    && !out.elem().is_atomic()
                {
                    let list = ExpandElement::Plain(op.lhs);
                    let index = ExpandElement::Plain(op.rhs);
                    // Bind the element type so the generic expansion below
                    // resolves to the actual input element.
                    scope.register_elem::<FloatExpand<0>>(op.lhs.elem());

                    let mut child_scope = scope.child();
                    let input = read_tensor_checked::expand::<Line<FloatExpand<0>>>(
                        &mut child_scope,
                        list.into(),
                        index.into(),
                    );

                    for inst in self.compile_scope(&mut child_scope) {
                        instructions.push(inst);
                    }

                    instructions.push(Instruction::Assign(UnaryInstruction {
                        input: self.compile_variable(input.into_variable()),
                        out: self.compile_variable(out),
                    }))
                } else {
                    instructions.push(Instruction::Index(self.compile_binary(op, out)));
                }
            }
            gpu::Operator::UncheckedIndex(op) => {
                instructions.push(Instruction::Index(self.compile_binary(op, out)))
            }
            gpu::Operator::IndexAssign(op) => {
                // Checked writes on length-carrying outputs expand a guarded
                // store into the scope and return early; everything else emits
                // a plain index-assign.
                if let ExecutionMode::Checked = self.strategy {
                    if out.has_length() {
                        expand_checked_index_assign(scope, op.lhs, op.rhs, out);
                        instructions.extend(self.compile_scope(scope));
                        return;
                    }
                };
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)));
            }
            gpu::Operator::UncheckedIndexAssign(op) => {
                instructions.push(Instruction::IndexAssign(self.compile_binary(op, out)))
            }
            gpu::Operator::And(op) => {
                instructions.push(Instruction::And(self.compile_binary(op, out)))
            }
            gpu::Operator::Or(op) => {
                instructions.push(Instruction::Or(self.compile_binary(op, out)))
            }
            gpu::Operator::Not(op) => {
                instructions.push(Instruction::Not(self.compile_unary(op, out)))
            }
            // Vector construction from individual components.
            gpu::Operator::InitLine(op) => instructions.push(Instruction::VecInit {
                inputs: op
                    .inputs
                    .into_iter()
                    .map(|it| self.compile_variable(it))
                    .collect(),
                out: self.compile_variable(out),
            }),
            gpu::Operator::CopyMemory(op) => instructions.push(Instruction::Copy {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
            }),
            // Bulk copies require a compile-time constant length, hence the
            // `as_const().unwrap()`.
            gpu::Operator::CopyMemoryBulk(op) => instructions.push(Instruction::CopyBulk {
                input: self.compile_variable(op.input),
                in_index: self.compile_variable(op.in_index),
                out: self.compile_variable(out),
                out_index: self.compile_variable(op.out_index),
                len: op.len.as_const().unwrap().as_u32(),
            }),
            gpu::Operator::Select(op) => instructions.push(Instruction::Select {
                cond: self.compile_variable(op.cond),
                then: self.compile_variable(op.then),
                or_else: self.compile_variable(op.or_else),
                out: self.compile_variable(out),
            }),
            // Casts lower to a plain assignment; the conversion is implied by
            // the differing item types of input and output.
            gpu::Operator::Cast(op) => {
                instructions.push(Instruction::Assign(self.compile_unary(op, out)))
            }
            gpu::Operator::Reinterpret(op) => {
                instructions.push(Instruction::Bitcast(self.compile_unary(op, out)))
            }
        };
    }
1166
1167 fn compile_binary(
1168 &mut self,
1169 value: gpu::BinaryOperator,
1170 out: gpu::Variable,
1171 ) -> BinaryInstruction<D> {
1172 BinaryInstruction {
1173 lhs: self.compile_variable(value.lhs),
1174 rhs: self.compile_variable(value.rhs),
1175 out: self.compile_variable(out),
1176 }
1177 }
1178
1179 fn compile_unary(
1180 &mut self,
1181 value: gpu::UnaryOperator,
1182 out: gpu::Variable,
1183 ) -> UnaryInstruction<D> {
1184 UnaryInstruction {
1185 input: self.compile_variable(value.input),
1186 out: self.compile_variable(out),
1187 }
1188 }
1189
    /// Lowers an IR variable to its dialect representation.
    ///
    /// This has bookkeeping side effects: builtins flip the corresponding
    /// `self.flags.indexes` bit so the kernel prologue declares them, and
    /// shared memories, local arrays, pipelines, barriers, WMMA fragments and
    /// TMA usage are registered on first encounter.
    fn compile_variable(&mut self, value: gpu::Variable) -> Variable<D> {
        let item = value.item;
        match value.kind {
            gpu::VariableKind::GlobalInputArray(id) => {
                Variable::GlobalInputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::GlobalScalar(id) => Variable::GlobalScalar {
                id,
                elem: self.compile_elem(item.elem),
                // With grid constants enabled, scalars live inside a kernel
                // argument struct, which changes how they are addressed.
                in_struct: self.compilation_options.grid_constants,
            },
            gpu::VariableKind::TensorMap(id) => {
                self.flags.inst_tma = true;
                Variable::TensorMap(id)
            }
            gpu::VariableKind::LocalMut { id } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            // Versioned (SSA) locals collapse to plain mutable locals; the
            // version is irrelevant at this stage.
            gpu::VariableKind::Versioned { id, .. } => Variable::LocalMut {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::LocalConst { id } => Variable::LocalConst {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::Slice { id } => Variable::Slice {
                id,
                item: self.compile_item(item),
            },
            gpu::VariableKind::GlobalOutputArray(id) => {
                Variable::GlobalOutputArray(id, self.compile_item(item))
            }
            gpu::VariableKind::ConstantScalar(value) => {
                Variable::ConstantScalar(value, self.compile_elem(value.elem()))
            }
            gpu::VariableKind::SharedMemory {
                id,
                length,
                alignment,
            } => {
                let item = self.compile_item(item);
                // Register the shared memory once so the kernel declares it.
                if !self.shared_memories.iter().any(|s| s.index == id) {
                    self.shared_memories
                        .push(SharedMemory::new(id, item, length, alignment));
                }
                Variable::SharedMemory(id, item, length)
            }
            gpu::VariableKind::ConstantArray { id, length } => {
                let item = self.compile_item(item);
                Variable::ConstantArray(id, item, length)
            }
            gpu::VariableKind::Builtin(builtin) => match builtin {
                gpu::Builtin::AbsolutePos => {
                    self.flags.indexes.absolute_pos = true;
                    Variable::AbsolutePos
                }
                // Cluster builtins are only meaningful when the target supports
                // thread-block clusters; otherwise they fold to constant 0 below.
                gpu::Builtin::CubePosCluster if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterRank
                }
                gpu::Builtin::CubePosClusterX if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexX
                }
                gpu::Builtin::CubePosClusterY if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexY
                }
                gpu::Builtin::CubePosClusterZ if self.compilation_options.supports_clusters => {
                    self.flags.indexes.cluster_pos = true;
                    Variable::ClusterIndexZ
                }
                // Fallback when clusters are unsupported: a cluster of size 1,
                // so every position is 0.
                gpu::Builtin::CubePosCluster
                | gpu::Builtin::CubePosClusterX
                | gpu::Builtin::CubePosClusterY
                | gpu::Builtin::CubePosClusterZ => const_u32(0),
                gpu::Builtin::AbsolutePosX => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosX
                }
                gpu::Builtin::AbsolutePosY => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosY
                }
                gpu::Builtin::AbsolutePosZ => {
                    self.flags.indexes.absolute_pos_tuple = true;
                    Variable::AbsolutePosZ
                }
                gpu::Builtin::CubeDim => {
                    self.flags.indexes.cube_dim = true;
                    Variable::CubeDim
                }
                gpu::Builtin::CubeDimX => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimX
                }
                gpu::Builtin::CubeDimY => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimY
                }
                gpu::Builtin::CubeDimZ => {
                    self.flags.indexes.cube_dim_tuple = true;
                    Variable::CubeDimZ
                }
                // Cluster dimensions are compile-time constants taken from the
                // configured cluster dim.
                gpu::Builtin::CubeClusterDim => const_u32(self.cluster_dim.num_elems()),
                gpu::Builtin::CubeClusterDimX => const_u32(self.cluster_dim.x),
                gpu::Builtin::CubeClusterDimY => const_u32(self.cluster_dim.y),
                gpu::Builtin::CubeClusterDimZ => const_u32(self.cluster_dim.z),
                gpu::Builtin::CubePos => {
                    self.flags.indexes.cube_pos = true;
                    Variable::CubePos
                }
                gpu::Builtin::CubePosX => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosX
                }
                gpu::Builtin::CubePosY => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosY
                }
                gpu::Builtin::CubePosZ => {
                    self.flags.indexes.cube_pos_tuple = true;
                    Variable::CubePosZ
                }
                gpu::Builtin::CubeCount => {
                    self.flags.indexes.cube_count = true;
                    Variable::CubeCount
                }
                gpu::Builtin::CubeCountX => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountX
                }
                gpu::Builtin::CubeCountY => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountY
                }
                gpu::Builtin::CubeCountZ => {
                    self.flags.indexes.cube_count_tuple = true;
                    Variable::CubeCountZ
                }
                gpu::Builtin::UnitPos => {
                    self.flags.indexes.unit_pos = true;
                    Variable::UnitPos
                }
                gpu::Builtin::UnitPosX => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosX
                }
                gpu::Builtin::UnitPosY => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosY
                }
                gpu::Builtin::UnitPosZ => {
                    self.flags.indexes.unit_pos_tuple = true;
                    Variable::UnitPosZ
                }
                gpu::Builtin::PlaneDim => {
                    self.flags.indexes.plane_dim = true;
                    Variable::PlaneDim
                }
                gpu::Builtin::UnitPosPlane => {
                    self.flags.indexes.unit_pos_plane = true;
                    Variable::UnitPosPlane
                }
            },
            gpu::VariableKind::LocalArray { id, length } => {
                let item = self.compile_item(item);
                // Register the local array once so the kernel declares it.
                if !self.local_arrays.iter().any(|s| s.index == id) {
                    self.local_arrays.push(LocalArray::new(id, item, length));
                }
                Variable::LocalArray(id, item, length)
            }
            gpu::VariableKind::Matrix { id, mat } => {
                self.flags.inst_wmma = true;
                Variable::WmmaFragment {
                    id,
                    frag: self.compile_matrix(mat),
                }
            }
            gpu::VariableKind::Pipeline {
                id,
                item,
                num_stages,
            } => {
                self.flags.op_pipeline = true;
                let pipeline = Variable::Pipeline {
                    id,
                    item: self.compile_item(item),
                };
                // Emit the pipeline initialization only on first encounter.
                if !self.pipelines.iter().any(|s| s.pipeline_id() == id) {
                    self.pipelines.push(PipelineOps::Init {
                        pipeline,
                        num_stages,
                    });
                }
                pipeline
            }
            gpu::VariableKind::Barrier { id, item, level } => {
                self.flags.op_barrier = true;
                // Cube-level barriers need cube-dim / unit-pos builtins for
                // their initialization code.
                match level {
                    gpu::BarrierLevel::CubeCoop(_) | gpu::BarrierLevel::CubeManual(_) => {
                        self.flags.indexes.cube_dim = true;
                        self.flags.indexes.unit_pos = true;
                    }
                    _ => {}
                }
                Variable::Barrier {
                    id,
                    item: self.compile_item(item),
                    level,
                }
            }
        }
    }
1408
1409 fn compile_matrix(&mut self, matrix: gpu::Matrix) -> Fragment<D> {
1410 Fragment {
1411 ident: self.compile_matrix_ident(matrix.ident),
1412 m: matrix.m,
1413 n: matrix.n,
1414 k: matrix.k,
1415 elem: self.compile_elem(matrix.elem),
1416 layout: self.compile_matrix_layout(matrix.layout),
1417 }
1418 }
1419
1420 fn compile_matrix_ident(&mut self, ident: gpu::MatrixIdent) -> FragmentIdent<D> {
1421 match ident {
1422 gpu::MatrixIdent::A => FragmentIdent::A,
1423 gpu::MatrixIdent::B => FragmentIdent::B,
1424 gpu::MatrixIdent::Accumulator => FragmentIdent::Accumulator,
1425 }
1426 }
1427
1428 fn compile_matrix_layout(&mut self, layout: gpu::MatrixLayout) -> Option<FragmentLayout<D>> {
1429 match layout {
1430 gpu::MatrixLayout::ColMajor => Some(FragmentLayout::ColMajor),
1431 gpu::MatrixLayout::RowMajor => Some(FragmentLayout::RowMajor),
1432 gpu::MatrixLayout::Undefined => None,
1433 }
1434 }
1435
1436 fn compile_binding(&mut self, binding: cubecl_core::compute::Binding) -> Binding<D> {
1437 Binding {
1438 id: binding.id,
1439 item: self.compile_item(binding.item),
1440 location: binding.location,
1441 size: binding.size,
1442 vis: binding.visibility,
1443 }
1444 }
1445
1446 fn compile_item(&mut self, item: gpu::Item) -> Item<D> {
1447 let item = Item::new(
1448 self.compile_elem(item.elem),
1449 item.vectorization.map(NonZero::get).unwrap_or(1).into(),
1450 false,
1451 );
1452 if item.elem != super::Elem::TF32 {
1453 self.items.insert(item);
1454 self.items.insert(item.optimized());
1455 } else {
1456 let mut item = item;
1458 item.elem = super::Elem::F32;
1459 self.items.insert(item);
1460 }
1461
1462 item
1463 }
1464
1465 fn compile_elem(&mut self, value: gpu::Elem) -> Elem<D> {
1466 match value {
1467 gpu::Elem::Float(kind) => match kind {
1468 gpu::FloatKind::F16 => {
1469 self.flags.elem_f16 = true;
1470 Elem::F16
1471 }
1472 gpu::FloatKind::BF16 => {
1473 self.flags.elem_bf16 = true;
1474 Elem::BF16
1475 }
1476 gpu::FloatKind::TF32 => Elem::TF32,
1477 gpu::FloatKind::Flex32 => Elem::F32,
1478 gpu::FloatKind::F32 => Elem::F32,
1479 gpu::FloatKind::F64 => Elem::F64,
1480 },
1481 gpu::Elem::AtomicFloat(kind) => match kind {
1482 gpu::FloatKind::F16 => Elem::Atomic(AtomicKind::F16),
1483 gpu::FloatKind::BF16 => Elem::Atomic(AtomicKind::BF16),
1484 gpu::FloatKind::F32 => Elem::Atomic(AtomicKind::F32),
1485 gpu::FloatKind::F64 => Elem::Atomic(AtomicKind::F64),
1486 kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1487 },
1488 gpu::Elem::Int(kind) => match kind {
1489 gpu::IntKind::I8 => Elem::I8,
1490 gpu::IntKind::I16 => Elem::I16,
1491 gpu::IntKind::I32 => Elem::I32,
1492 gpu::IntKind::I64 => Elem::I64,
1493 },
1494 gpu::Elem::AtomicInt(kind) => match kind {
1495 gpu::IntKind::I32 => Elem::Atomic(AtomicKind::I32),
1496 gpu::IntKind::I64 => Elem::Atomic(AtomicKind::I64),
1497 kind => panic!("atomic<{kind:?}> isn't supported yet"),
1498 },
1499 gpu::Elem::UInt(kind) => match kind {
1500 gpu::UIntKind::U8 => Elem::U8,
1501 gpu::UIntKind::U16 => Elem::U16,
1502 gpu::UIntKind::U32 => Elem::U32,
1503 gpu::UIntKind::U64 => Elem::U64,
1504 },
1505 gpu::Elem::AtomicUInt(kind) => match kind {
1506 gpu::UIntKind::U32 => Elem::Atomic(AtomicKind::U32),
1507 gpu::UIntKind::U64 => Elem::Atomic(AtomicKind::U64),
1508 kind => unimplemented!("atomic<{kind:?}> not yet supported"),
1509 },
1510 gpu::Elem::Bool => Elem::Bool,
1511 }
1512 }
1513}
1514
1515fn const_u32<D: Dialect>(value: u32) -> Variable<D> {
1516 Variable::ConstantScalar(
1517 gpu::ConstantScalarValue::UInt(value as u64, UIntKind::U32),
1518 Elem::U32,
1519 )
1520}
1521
1522pub fn register_supported_types(props: &mut DeviceProperties<Feature>) {
1523 let supported_types = [
1524 gpu::Elem::UInt(gpu::UIntKind::U8),
1525 gpu::Elem::UInt(gpu::UIntKind::U16),
1526 gpu::Elem::UInt(gpu::UIntKind::U32),
1527 gpu::Elem::UInt(gpu::UIntKind::U64),
1528 gpu::Elem::Int(gpu::IntKind::I8),
1529 gpu::Elem::Int(gpu::IntKind::I16),
1530 gpu::Elem::Int(gpu::IntKind::I32),
1531 gpu::Elem::Int(gpu::IntKind::I64),
1532 gpu::Elem::AtomicInt(gpu::IntKind::I32),
1533 gpu::Elem::AtomicInt(gpu::IntKind::I64),
1534 gpu::Elem::AtomicUInt(gpu::UIntKind::U32),
1535 gpu::Elem::AtomicUInt(gpu::UIntKind::U64),
1536 gpu::Elem::Float(gpu::FloatKind::BF16),
1537 gpu::Elem::Float(gpu::FloatKind::F16),
1538 gpu::Elem::Float(gpu::FloatKind::F32),
1539 gpu::Elem::Float(gpu::FloatKind::Flex32),
1540 gpu::Elem::AtomicFloat(gpu::FloatKind::F32),
1541 gpu::Elem::Bool,
1544 ];
1545
1546 for ty in supported_types {
1547 props.register_feature(Feature::Type(ty));
1548 }
1549}