1use cubecl_common::{ExecutionMode, backtrace::BackTrace};
2use cubecl_core::{
3 Metadata, WgpuCompilationOptions,
4 ir::{self as core, InstructionModes},
5 post_processing::{
6 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7 unroll::UnrollProcessor,
8 },
9 prelude::{FastMath, KernelDefinition},
10};
11use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
12use cubecl_runtime::{
13 EnumSet,
14 compiler::CompilationError,
15 config::{GlobalConfig, compilation::CompilationLogLevel},
16};
17use std::{
18 collections::HashSet,
19 fmt::Debug,
20 mem::take,
21 ops::{Deref, DerefMut},
22 rc::Rc,
23};
24
25use cubecl_core::Compiler;
26use rspirv::{
27 dr::{Builder, InsertPoint, Instruction, Module, Operand},
28 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
29};
30
31use crate::{
32 SpirvKernel,
33 debug::DebugInfo,
34 item::Item,
35 lookups::LookupTables,
36 target::{GLCompute, SpirvTarget},
37 transformers::{BitwiseTransform, ErfTransform},
38};
39
/// Vectorization limit handed to `UnrollProcessor::new` in `compile_kernel`.
///
/// NOTE(review): presumably the widest vector (line) size kept before the optimizer unrolls
/// wider items — confirm against `UnrollProcessor`.
pub const MAX_VECTORIZATION: u32 = 4;
41
/// SPIR-V compiler that lowers a cubecl `KernelDefinition` into an rspirv `Module` for the
/// given `Target` (`GLCompute` by default).
///
/// Derefs to the underlying rspirv `Builder`, so builder methods can be called on it directly.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target-specific hooks; `compile_kernel` works on a clone of it to set the kernel name
    /// and emit execution modes.
    pub target: Target,
    /// rspirv builder the module is emitted into; taken (reset to default) at the end of
    /// `compile_kernel`.
    pub(crate) builder: Builder,

    /// Execution mode of the current compilation (forwarded to `CheckedIoProcessor`).
    pub mode: ExecutionMode,
    /// Whether debug information is emitted; recomputed per kernel from the global logger
    /// level and the kernel's `debug_symbols` option.
    pub debug_symbols: bool,
    // NOTE(review): the next two fields are not read in this part of the file; presumably
    // cached built-in variable ids — confirm.
    global_invocation_id: Word,
    num_workgroups: Word,
    /// Builder index of the setup block, into which function-scope `OpVariable`s are hoisted.
    pub setup_block: usize,
    /// Optimized control-flow graph of the kernel being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis computed on `opt`.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis computed on `opt`.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block currently being compiled, if any.
    pub current_block: Option<NodeIndex>,
    /// Optimizer blocks already emitted; prevents emitting a block twice.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested while compiling.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used built-ins, shared memories, already-decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map, in id order, the number of preceding entries that have
    /// extended metadata (i.e. its position among extended-metadata slots).
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout sized from the kernel's buffers, tensor maps and extended entries.
    pub metadata: Metadata,
    /// Debug-info state; presumably filled by `init_debug` when debug symbols are on — confirm.
    pub debug_info: Option<DebugInfo>,
    /// Options of the current compilation (e.g. explicit shared memory, fp fast-math support).
    pub compilation_options: WgpuCompilationOptions,
}
64
// NOTE(review): `SpirvCompiler` holds `Rc` fields (`opt`, `uniformity`, `shared_liveness`),
// which are neither `Send` nor `Sync`, and its `Clone` impl shares those `Rc`s instead of
// copying their contents. Cloning a compiler (or cloning through a shared `&SpirvCompiler`)
// on several threads at once would race on the non-atomic reference count. These impls are
// only sound if callers never use a compiler and its clones concurrently from different
// threads — confirm, or switch those fields to `Arc`.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
67
68impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
69 fn clone(&self) -> Self {
70 Self {
71 target: self.target.clone(),
72 builder: Builder::new_from_module(self.module_ref().clone()),
73 mode: self.mode,
74 global_invocation_id: self.global_invocation_id,
75 num_workgroups: self.num_workgroups,
76 setup_block: self.setup_block,
77 opt: self.opt.clone(),
78 uniformity: self.uniformity.clone(),
79 shared_liveness: self.shared_liveness.clone(),
80 current_block: self.current_block,
81 capabilities: self.capabilities.clone(),
82 state: self.state.clone(),
83 debug_symbols: self.debug_symbols,
84 visited: self.visited.clone(),
85 metadata: self.metadata.clone(),
86 debug_info: self.debug_info.clone(),
87 ext_meta_pos: self.ext_meta_pos.clone(),
88 compilation_options: self.compilation_options.clone(),
89 }
90 }
91}
92
93fn debug_symbols_activated() -> bool {
94 matches!(
95 GlobalConfig::get().compilation.logger.level,
96 CompilationLogLevel::Full
97 )
98}
99
100impl<T: SpirvTarget> Default for SpirvCompiler<T> {
101 fn default() -> Self {
102 Self {
103 target: Default::default(),
104 builder: Builder::new(),
105 mode: Default::default(),
106 global_invocation_id: Default::default(),
107 num_workgroups: Default::default(),
108 capabilities: Default::default(),
109 state: Default::default(),
110 setup_block: Default::default(),
111 opt: Default::default(),
112 uniformity: Default::default(),
113 shared_liveness: Default::default(),
114 current_block: Default::default(),
115 debug_symbols: debug_symbols_activated(),
116 visited: Default::default(),
117 metadata: Default::default(),
118 debug_info: Default::default(),
119 ext_meta_pos: Default::default(),
120 compilation_options: Default::default(),
121 }
122 }
123}
124
125impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
126 type Target = Builder;
127
128 fn deref(&self) -> &Self::Target {
129 &self.builder
130 }
131}
132
133impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
134 fn deref_mut(&mut self) -> &mut Self::Target {
135 &mut self.builder
136 }
137}
138
139impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
140 type Representation = SpirvKernel;
141 type CompilationOptions = WgpuCompilationOptions;
142
143 fn compile(
144 &mut self,
145 mut value: KernelDefinition,
146 compilation_options: &Self::CompilationOptions,
147 mode: ExecutionMode,
148 ) -> Result<Self::Representation, CompilationError> {
149 let errors = value.body.pop_errors();
150 if !errors.is_empty() {
151 let mut reason = "Can't compile spirv kernel".to_string();
152 for error in errors {
153 reason += error.as_str();
154 reason += "\n";
155 }
156
157 return Err(CompilationError::Validation {
158 reason,
159 backtrace: BackTrace::capture(),
160 });
161 }
162
163 let bindings = value.buffers.clone();
164 let scalars = value
165 .scalars
166 .iter()
167 .map(|s| (self.compile_storage_type(s.ty), s.count))
168 .collect();
169 let mut ext_meta_pos = Vec::new();
170 let mut num_ext = 0;
171
172 let mut all_meta: Vec<_> = value
173 .buffers
174 .iter()
175 .chain(value.tensor_maps.iter())
176 .map(|buf| (buf.id, buf.has_extended_meta))
177 .collect();
178 all_meta.sort_by_key(|(id, _)| *id);
179
180 let num_meta = all_meta.len();
181
182 for (_, has_extended_meta) in all_meta.iter() {
183 ext_meta_pos.push(num_ext);
184 if *has_extended_meta {
185 num_ext += 1;
186 }
187 }
188
189 self.mode = mode;
190 self.metadata = Metadata::new(num_meta as u32, num_ext);
191 self.compilation_options = compilation_options.clone();
192 self.ext_meta_pos = ext_meta_pos;
193
194 let (module, optimizer) = self.compile_kernel(value);
195 Ok(SpirvKernel {
196 module,
197 optimizer,
198 bindings,
199 scalars,
200 has_metadata: self.metadata.static_len() > 0,
201 })
202 }
203
204 fn elem_size(&self, elem: core::ElemType) -> usize {
205 elem.size()
206 }
207
208 fn extension(&self) -> &'static str {
209 "spv"
210 }
211}
212
213impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
214 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
215 write!(f, "spirv<{:?}>", self.target)
216 }
217}
218
219impl<Target: SpirvTarget> SpirvCompiler<Target> {
220 pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
221 let options = kernel.options.clone();
222
223 self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
224
225 self.set_version(1, 6);
226
227 let mut target = self.target.clone();
228
229 let mut opt = OptimizerBuilder::default()
230 .with_transformer(ErfTransform)
231 .with_transformer(BitwiseTransform)
232 .with_processor(CheckedIoProcessor::new(self.mode))
233 .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
234 .with_processor(SaturatingArithmeticProcessor::new(true))
235 .optimize(kernel.body.clone(), kernel.cube_dim);
236
237 self.uniformity = opt.analysis::<Uniformity>();
238 self.shared_liveness = opt.analysis::<SharedLiveness>();
239 self.opt = Rc::new(opt);
240
241 self.init_state(kernel.clone());
242 self.init_debug();
243
244 let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
245
246 target.set_kernel_name(options.kernel_name.clone());
247
248 let (main, debug_setup) = self.declare_main(&options.kernel_name);
249
250 let setup = self.id();
251 self.debug_name(setup, "setup");
252
253 let entry = self.opt.entry();
254 let body = self.label(entry);
255 let setup_block = self.setup(setup, debug_setup);
256 self.setup_block = setup_block;
257 self.compile_block(entry);
258
259 let ret = self.opt.ret;
260 self.compile_block(ret);
261
262 if self.selected_block().is_some() {
263 let label = self.label(ret);
264 self.branch(label).unwrap();
265 }
266
267 self.select_block(Some(setup_block)).unwrap();
268 self.branch(body).unwrap();
269
270 self.end_function().unwrap();
271
272 self.declare_shared_memories();
273
274 let builtins = self
275 .state
276 .used_builtins
277 .clone()
278 .into_iter()
279 .map(|(builtin, (id, item))| {
280 let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
281 self.variable(ty, Some(id), StorageClass::Input, None);
282 self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
283 id
284 })
285 .collect::<Vec<_>>();
286
287 target.set_modes(self, main, builtins, cube_dims);
288
289 let module = take(&mut self.builder).module();
290 (module, self.opt.as_ref().clone())
291 }
292
293 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
294 self.begin_block(Some(label)).unwrap();
295
296 let opt = self.opt.clone();
297 for const_arr in opt.const_arrays() {
298 self.register_const_array(const_arr);
299 }
300
301 debug_setup(self);
302
303 let setup_block = self.selected_block().unwrap();
304 self.select_block(None).unwrap();
305 setup_block
306 }
307
308 #[track_caller]
309 pub fn current_block(&self) -> BasicBlock {
310 self.opt.block(self.current_block.unwrap()).clone()
311 }
312
313 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
314 if let Some(existing) = self.state.used_builtins.get(&builtin) {
315 existing.0
316 } else {
317 let id = self.id();
318 self.state.used_builtins.insert(builtin, (id, item));
319 id
320 }
321 }
322
323 pub fn compile_block(&mut self, block: NodeIndex) {
324 if self.visited.contains(&block) {
325 return;
326 }
327 self.visited.insert(block);
328 self.current_block = Some(block);
329
330 let label = self.label(block);
331 self.begin_block(Some(label)).unwrap();
332 let block_id = self.selected_block().unwrap();
333
334 self.debug_start_block();
335
336 let operations = self.current_block().ops.borrow().clone();
337 for (_, operation) in operations {
338 self.compile_operation(operation);
339 }
340
341 let control_flow = self.current_block().control_flow.borrow().clone();
342 self.compile_control_flow(control_flow);
343
344 let current = self.selected_block();
345 self.select_block(Some(block_id)).unwrap();
346 let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
347 for phi in phi {
348 let out = self.compile_variable(phi.out);
349 let ty = out.item().id(self);
350 let out_id = self.write_id(&out);
351 let entries: Vec<_> = phi
352 .entries
353 .into_iter()
354 .map(|it| {
355 let label = self.end_label(it.block);
356 let value = self.compile_variable(it.value);
357 let value = self.read(&value);
358 (value, label)
359 })
360 .collect();
361 self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
362 .unwrap();
363 }
364 self.select_block(current).unwrap();
365 }
366
367 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
369 let setup = self.setup_block;
370 let id = self.id();
371 let var = Instruction::new(
372 Op::Variable,
373 Some(ty),
374 Some(id),
375 vec![Operand::StorageClass(StorageClass::Function)],
376 );
377 let current_block = self.selected_block();
378 self.select_block(Some(setup)).unwrap();
379 self.insert_into_block(InsertPoint::Begin, var).unwrap();
380 self.select_block(current_block).unwrap();
381 id
382 }
383
384 fn declare_shared_memories(&mut self) {
385 if self.compilation_options.supports_explicit_smem {
386 self.declare_shared_memories_explicit();
387 } else {
388 self.declare_shared_memories_implicit();
389 }
390 }
391
392 fn declare_shared_memories_explicit(&mut self) {
397 let shared_arrays = self.state.shared_arrays.clone();
398 let shared = self.state.shared.clone();
399 if shared_arrays.is_empty() && shared.is_empty() {
400 return;
401 }
402
403 self.capabilities
404 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
405
406 for (index, memory) in shared_arrays {
407 let item_size = memory.item.size();
408
409 match item_size {
412 1 => {
413 self.capabilities
414 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
415 }
416 2 => {
417 self.capabilities
418 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
419 }
420 _ => {}
421 }
422
423 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
424 let arr_id = arr_ty.id(self);
425
426 if !self.state.decorated_types.contains(&arr_id) {
427 self.decorate(
428 arr_id,
429 Decoration::ArrayStride,
430 [Operand::LiteralBit32(item_size)],
431 );
432 self.state.decorated_types.insert(arr_id);
433 }
434
435 let block_ty = Item::Struct(vec![arr_ty]);
436 let block_id = block_ty.id(self);
437
438 self.decorate(block_id, Decoration::Block, []);
439 self.member_decorate(
440 block_id,
441 0,
442 Decoration::Offset,
443 [Operand::LiteralBit32(memory.offset)],
444 );
445
446 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
447
448 self.debug_shared(memory.id, index);
449 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
450 self.decorate(memory.id, Decoration::Aliased, []);
451 }
452
453 for (index, memory) in shared {
454 let item_size = memory.item.size();
455
456 match item_size {
459 1 => {
460 self.capabilities
461 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
462 }
463 2 => {
464 self.capabilities
465 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
466 }
467 _ => {}
468 }
469
470 let block_ty = Item::Struct(vec![memory.item]);
471 let block_id = block_ty.id(self);
472
473 self.decorate(block_id, Decoration::Block, []);
474 self.member_decorate(
475 block_id,
476 0,
477 Decoration::Offset,
478 [Operand::LiteralBit32(memory.offset)],
479 );
480
481 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
482
483 self.debug_shared(memory.id, index);
484 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
485 self.decorate(memory.id, Decoration::Aliased, []);
486 }
487 }
488
489 fn declare_shared_memories_implicit(&mut self) {
490 let shared_memories = self.state.shared_arrays.clone();
491 for (index, memory) in shared_memories {
492 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
493 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
494
495 self.debug_shared(memory.id, index);
496 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
497 }
498 let shared = self.state.shared.clone();
499 for (index, memory) in shared {
500 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
501
502 self.debug_shared(memory.id, index);
503 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
504 }
505 }
506
507 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
508 if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
509 return;
510 }
511 let mode = convert_math_mode(modes.fp_math_mode);
512 self.capabilities.insert(Capability::FloatControls2);
513 self.decorate(
514 out_id,
515 Decoration::FPFastMathMode,
516 [Operand::FPFastMathMode(mode)],
517 );
518 }
519
520 pub fn is_uniform_block(&self) -> bool {
521 self.uniformity
522 .is_block_uniform(self.current_block.unwrap())
523 }
524}
525
526pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
527 let mut flags = FPFastMathMode::NONE;
528
529 for mode in math_mode.iter() {
530 match mode {
531 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
532 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
533 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
534 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
535 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
536 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
537 FastMath::AllowTransform => {
538 flags |= FPFastMathMode::ALLOW_CONTRACT
539 | FPFastMathMode::ALLOW_REASSOC
540 | FPFastMathMode::ALLOW_TRANSFORM
541 }
542 _ => {}
543 }
544 }
545
546 flags
547}