1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::{
3 Metadata, WgpuCompilationOptions,
4 ir::{self as core, InstructionModes},
5 post_processing::{
6 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7 unroll::UnrollProcessor,
8 },
9 prelude::{FastMath, KernelDefinition},
10 server::ExecutionMode,
11};
12use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
13use cubecl_runtime::{
14 EnumSet,
15 compiler::CompilationError,
16 config::{GlobalConfig, compilation::CompilationLogLevel},
17};
18use std::{
19 collections::HashSet,
20 fmt::Debug,
21 mem::take,
22 ops::{Deref, DerefMut},
23 rc::Rc,
24};
25
26use cubecl_core::Compiler;
27use rspirv::{
28 dr::{Builder, InsertPoint, Instruction, Module, Operand},
29 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
30};
31
32use crate::{
33 SpirvKernel,
34 debug::DebugInfo,
35 item::Item,
36 lookups::LookupTables,
37 target::{GLCompute, SpirvTarget},
38 transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
39};
40
/// Vectorization limit handed to `UnrollProcessor::new` when `compile_kernel` builds the
/// optimizer pipeline.
///
/// NOTE(review): presumably the widest vector this backend emits directly, with wider vectors
/// being unrolled — confirm against `UnrollProcessor`.
pub const MAX_VECTORIZATION: u32 = 4;
42
/// Compiler state used to turn a `KernelDefinition` into a SPIR-V module.
///
/// The struct dereferences to its `rspirv` [`Builder`], so builder methods can be called
/// directly on the compiler.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target implementation; `compile_kernel` clones it to set the kernel name and the
    /// execution modes.
    pub target: Target,
    /// Builder that accumulates the module. `compile_kernel` takes it out (leaving a default
    /// builder behind) to produce the final [`Module`].
    pub(crate) builder: Builder,

    /// Execution mode of the kernel being compiled; set by `compile` and forwarded to the
    /// `CheckedIoProcessor`.
    pub mode: ExecutionMode,
    /// Whether debug symbols are enabled, either through the global compilation log level or
    /// through the kernel options.
    pub debug_symbols: bool,
    /// NOTE(review): not read in this part of the file apart from construction and cloning —
    /// presumably the id of the matching builtin; confirm at the use sites.
    global_invocation_id: Word,
    /// NOTE(review): not read in this part of the file apart from construction and cloning —
    /// presumably the id of the matching builtin; confirm at the use sites.
    num_workgroups: Word,
    /// Index of the setup block within the current function. `declare_function_variable`
    /// inserts its `OpVariable` instructions at the beginning of this block.
    pub setup_block: usize,
    /// Optimizer output for the kernel currently being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block being compiled at the moment, if any; set by `compile_block`.
    pub current_block: Option<NodeIndex>,
    /// Blocks `compile_block` has already been entered for, so each is compiled only once.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested while compiling.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used builtins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map, in id order: how many bindings with extended metadata
    /// come before it. Computed by `compile`.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata description built from the number of bindings and the number of bindings with
    /// extended metadata.
    pub metadata: Metadata,
    /// Debug information state, if any.
    pub debug_info: Option<DebugInfo>,
    /// Compilation options supplied to the last `compile` call.
    pub compilation_options: WgpuCompilationOptions,
}
65
// SAFETY: NOTE(review): `SpirvCompiler` stores `Rc` handles (`opt`, `uniformity`,
// `shared_liveness`), which are neither `Send` nor `Sync` on their own, and its `Clone` impl
// shares those handles between instances. These impls are therefore only sound if instances
// sharing such handles are never used from different threads at the same time — confirm that
// invariant with the callers.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
68
69impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
70 fn clone(&self) -> Self {
71 Self {
72 target: self.target.clone(),
73 builder: Builder::new_from_module(self.module_ref().clone()),
74 mode: self.mode,
75 global_invocation_id: self.global_invocation_id,
76 num_workgroups: self.num_workgroups,
77 setup_block: self.setup_block,
78 opt: self.opt.clone(),
79 uniformity: self.uniformity.clone(),
80 shared_liveness: self.shared_liveness.clone(),
81 current_block: self.current_block,
82 capabilities: self.capabilities.clone(),
83 state: self.state.clone(),
84 debug_symbols: self.debug_symbols,
85 visited: self.visited.clone(),
86 metadata: self.metadata.clone(),
87 debug_info: self.debug_info.clone(),
88 ext_meta_pos: self.ext_meta_pos.clone(),
89 compilation_options: self.compilation_options.clone(),
90 }
91 }
92}
93
94fn debug_symbols_activated() -> bool {
95 matches!(
96 GlobalConfig::get().compilation.logger.level,
97 CompilationLogLevel::Full
98 )
99}
100
101impl<T: SpirvTarget> Default for SpirvCompiler<T> {
102 fn default() -> Self {
103 Self {
104 target: Default::default(),
105 builder: Builder::new(),
106 mode: Default::default(),
107 global_invocation_id: Default::default(),
108 num_workgroups: Default::default(),
109 capabilities: Default::default(),
110 state: Default::default(),
111 setup_block: Default::default(),
112 opt: Default::default(),
113 uniformity: Default::default(),
114 shared_liveness: Default::default(),
115 current_block: Default::default(),
116 debug_symbols: debug_symbols_activated(),
117 visited: Default::default(),
118 metadata: Default::default(),
119 debug_info: Default::default(),
120 ext_meta_pos: Default::default(),
121 compilation_options: Default::default(),
122 }
123 }
124}
125
/// Exposes the wrapped [`Builder`], so its methods can be called directly on the compiler.
impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
    type Target = Builder;

    /// Borrows the underlying builder.
    fn deref(&self) -> &Self::Target {
        &self.builder
    }
}
133
/// Mutable counterpart of the `Deref` impl: gives mutable access to the wrapped [`Builder`].
impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
    /// Mutably borrows the underlying builder.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.builder
    }
}
139
140impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
141 type Representation = SpirvKernel;
142 type CompilationOptions = WgpuCompilationOptions;
143
144 fn compile(
145 &mut self,
146 mut value: KernelDefinition,
147 compilation_options: &Self::CompilationOptions,
148 mode: ExecutionMode,
149 ) -> Result<Self::Representation, CompilationError> {
150 let errors = value.body.pop_errors();
151 if !errors.is_empty() {
152 let mut reason = "Can't compile spirv kernel".to_string();
153 for error in errors {
154 reason += error.as_str();
155 reason += "\n";
156 }
157
158 return Err(CompilationError::Validation {
159 reason,
160 backtrace: BackTrace::capture(),
161 });
162 }
163
164 let bindings = value.buffers.clone();
165 let scalars = value
166 .scalars
167 .iter()
168 .map(|s| (self.compile_storage_type(s.ty), s.count))
169 .collect();
170 let mut ext_meta_pos = Vec::new();
171 let mut num_ext = 0;
172
173 let mut all_meta: Vec<_> = value
174 .buffers
175 .iter()
176 .chain(value.tensor_maps.iter())
177 .map(|buf| (buf.id, buf.has_extended_meta))
178 .collect();
179 all_meta.sort_by_key(|(id, _)| *id);
180
181 let num_meta = all_meta.len();
182
183 for (_, has_extended_meta) in all_meta.iter() {
184 ext_meta_pos.push(num_ext);
185 if *has_extended_meta {
186 num_ext += 1;
187 }
188 }
189
190 self.mode = mode;
191 self.metadata = Metadata::new(num_meta as u32, num_ext);
192 self.compilation_options = compilation_options.clone();
193 self.ext_meta_pos = ext_meta_pos;
194
195 let (module, optimizer) = self.compile_kernel(value);
196 Ok(SpirvKernel {
197 module,
198 optimizer,
199 bindings,
200 scalars,
201 has_metadata: self.metadata.static_len() > 0,
202 })
203 }
204
205 fn elem_size(&self, elem: core::ElemType) -> usize {
206 elem.size()
207 }
208
209 fn extension(&self) -> &'static str {
210 "spv"
211 }
212}
213
214impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
215 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216 write!(f, "spirv<{:?}>", self.target)
217 }
218}
219
220impl<Target: SpirvTarget> SpirvCompiler<Target> {
221 pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
222 let options = kernel.options.clone();
223
224 self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
225
226 self.set_version(1, 6);
227
228 let mut target = self.target.clone();
229
230 let mut opt = OptimizerBuilder::default()
231 .with_transformer(ErfTransform)
232 .with_transformer(BitwiseTransform)
233 .with_transformer(HypotTransform)
234 .with_transformer(RhypotTransform)
235 .with_processor(CheckedIoProcessor::new(self.mode))
236 .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
237 .with_processor(SaturatingArithmeticProcessor::new(true))
238 .optimize(kernel.body.clone(), kernel.cube_dim);
239
240 self.uniformity = opt.analysis::<Uniformity>();
241 self.shared_liveness = opt.analysis::<SharedLiveness>();
242 self.opt = Rc::new(opt);
243
244 self.init_state(kernel.clone());
245 self.init_debug();
246
247 let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
248
249 target.set_kernel_name(options.kernel_name.clone());
250
251 let (main, debug_setup) = self.declare_main(&options.kernel_name);
252
253 let setup = self.id();
254 self.debug_name(setup, "setup");
255
256 let entry = self.opt.entry();
257 let body = self.label(entry);
258 let setup_block = self.setup(setup, debug_setup);
259 self.setup_block = setup_block;
260 self.compile_block(entry);
261
262 let ret = self.opt.ret;
263 self.compile_block(ret);
264
265 if self.selected_block().is_some() {
266 let label = self.label(ret);
267 self.branch(label).unwrap();
268 }
269
270 self.select_block(Some(setup_block)).unwrap();
271 self.branch(body).unwrap();
272
273 self.end_function().unwrap();
274
275 self.declare_shared_memories();
276
277 let builtins = self
278 .state
279 .used_builtins
280 .clone()
281 .into_iter()
282 .map(|(builtin, (id, item))| {
283 let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
284 self.variable(ty, Some(id), StorageClass::Input, None);
285 self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
286 id
287 })
288 .collect::<Vec<_>>();
289
290 target.set_modes(self, main, builtins, cube_dims);
291
292 let module = take(&mut self.builder).module();
293 (module, self.opt.as_ref().clone())
294 }
295
296 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
297 self.begin_block(Some(label)).unwrap();
298
299 let opt = self.opt.clone();
300 for const_arr in opt.const_arrays() {
301 self.register_const_array(const_arr);
302 }
303
304 debug_setup(self);
305
306 let setup_block = self.selected_block().unwrap();
307 self.select_block(None).unwrap();
308 setup_block
309 }
310
311 #[track_caller]
312 pub fn current_block(&self) -> BasicBlock {
313 self.opt.block(self.current_block.unwrap()).clone()
314 }
315
316 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
317 if let Some(existing) = self.state.used_builtins.get(&builtin) {
318 existing.0
319 } else {
320 let id = self.id();
321 self.state.used_builtins.insert(builtin, (id, item));
322 id
323 }
324 }
325
326 pub fn compile_block(&mut self, block: NodeIndex) {
327 if self.visited.contains(&block) {
328 return;
329 }
330 self.visited.insert(block);
331 self.current_block = Some(block);
332
333 let label = self.label(block);
334 self.begin_block(Some(label)).unwrap();
335 let block_id = self.selected_block().unwrap();
336
337 self.debug_start_block();
338
339 let operations = self.current_block().ops.borrow().clone();
340 for (_, operation) in operations {
341 self.compile_operation(operation);
342 }
343
344 let control_flow = self.current_block().control_flow.borrow().clone();
345 self.compile_control_flow(control_flow);
346
347 let current = self.selected_block();
348 self.select_block(Some(block_id)).unwrap();
349 let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
350 for phi in phi {
351 let out = self.compile_variable(phi.out);
352 let ty = out.item().id(self);
353 let out_id = self.write_id(&out);
354 let entries: Vec<_> = phi
355 .entries
356 .into_iter()
357 .map(|it| {
358 let label = self.end_label(it.block);
359 let value = self.compile_variable(it.value);
360 let value = self.read(&value);
361 (value, label)
362 })
363 .collect();
364 self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
365 .unwrap();
366 }
367 self.select_block(current).unwrap();
368 }
369
370 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
372 let setup = self.setup_block;
373 let id = self.id();
374 let var = Instruction::new(
375 Op::Variable,
376 Some(ty),
377 Some(id),
378 vec![Operand::StorageClass(StorageClass::Function)],
379 );
380 let current_block = self.selected_block();
381 self.select_block(Some(setup)).unwrap();
382 self.insert_into_block(InsertPoint::Begin, var).unwrap();
383 self.select_block(current_block).unwrap();
384 id
385 }
386
387 fn declare_shared_memories(&mut self) {
388 if self.compilation_options.supports_explicit_smem {
389 self.declare_shared_memories_explicit();
390 } else {
391 self.declare_shared_memories_implicit();
392 }
393 }
394
395 fn declare_shared_memories_explicit(&mut self) {
400 let shared_arrays = self.state.shared_arrays.clone();
401 let shared = self.state.shared.clone();
402 if shared_arrays.is_empty() && shared.is_empty() {
403 return;
404 }
405
406 self.capabilities
407 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
408
409 for (index, memory) in shared_arrays {
410 let item_size = memory.item.size();
411
412 match item_size {
415 1 => {
416 self.capabilities
417 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
418 }
419 2 => {
420 self.capabilities
421 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
422 }
423 _ => {}
424 }
425
426 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
427 let arr_id = arr_ty.id(self);
428
429 if !self.state.decorated_types.contains(&arr_id) {
430 self.decorate(
431 arr_id,
432 Decoration::ArrayStride,
433 [Operand::LiteralBit32(item_size)],
434 );
435 self.state.decorated_types.insert(arr_id);
436 }
437
438 let block_ty = Item::Struct(vec![arr_ty]);
439 let block_id = block_ty.id(self);
440
441 self.decorate(block_id, Decoration::Block, []);
442 self.member_decorate(
443 block_id,
444 0,
445 Decoration::Offset,
446 [Operand::LiteralBit32(memory.offset)],
447 );
448
449 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
450
451 self.debug_shared(memory.id, index);
452 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
453 self.decorate(memory.id, Decoration::Aliased, []);
454 }
455
456 for (index, memory) in shared {
457 let item_size = memory.item.size();
458
459 match item_size {
462 1 => {
463 self.capabilities
464 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
465 }
466 2 => {
467 self.capabilities
468 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
469 }
470 _ => {}
471 }
472
473 let block_ty = Item::Struct(vec![memory.item]);
474 let block_id = block_ty.id(self);
475
476 self.decorate(block_id, Decoration::Block, []);
477 self.member_decorate(
478 block_id,
479 0,
480 Decoration::Offset,
481 [Operand::LiteralBit32(memory.offset)],
482 );
483
484 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
485
486 self.debug_shared(memory.id, index);
487 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
488 self.decorate(memory.id, Decoration::Aliased, []);
489 }
490 }
491
492 fn declare_shared_memories_implicit(&mut self) {
493 let shared_memories = self.state.shared_arrays.clone();
494 for (index, memory) in shared_memories {
495 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
496 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
497
498 self.debug_shared(memory.id, index);
499 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
500 }
501 let shared = self.state.shared.clone();
502 for (index, memory) in shared {
503 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
504
505 self.debug_shared(memory.id, index);
506 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
507 }
508 }
509
510 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
511 if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
512 return;
513 }
514 let mode = convert_math_mode(modes.fp_math_mode);
515 self.capabilities.insert(Capability::FloatControls2);
516 self.decorate(
517 out_id,
518 Decoration::FPFastMathMode,
519 [Operand::FPFastMathMode(mode)],
520 );
521 }
522
523 pub fn is_uniform_block(&self) -> bool {
524 self.uniformity
525 .is_block_uniform(self.current_block.unwrap())
526 }
527}
528
529pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
530 let mut flags = FPFastMathMode::NONE;
531
532 for mode in math_mode.iter() {
533 match mode {
534 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
535 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
536 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
537 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
538 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
539 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
540 FastMath::AllowTransform => {
541 flags |= FPFastMathMode::ALLOW_CONTRACT
542 | FPFastMathMode::ALLOW_REASSOC
543 | FPFastMathMode::ALLOW_TRANSFORM
544 }
545 _ => {}
546 }
547 }
548
549 flags
550}