1use crate::{
2 SpirvKernel,
3 debug::DebugInfo,
4 item::Item,
5 lookups::LookupTables,
6 target::{GLCompute, SpirvTarget},
7 transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
8};
9use cubecl_common::backtrace::BackTrace;
10use cubecl_core::{
11 Compiler, CubeDim, Info, Metadata, WgpuCompilationOptions,
12 ir::{self as core, ElemType, InstructionModes, StorageType, UIntKind, features::EnumSet},
13 post_processing::{
14 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
15 unroll::UnrollProcessor,
16 },
17 prelude::{FastMath, KernelDefinition},
18 server::ExecutionMode,
19};
20use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
21use cubecl_runtime::{
22 compiler::CompilationError,
23 config::{GlobalConfig, compilation::CompilationLogLevel},
24};
25use rspirv::{
26 binary::Assemble,
27 dr::{Builder, InsertPoint, Instruction, Module, Operand},
28 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
29};
30use std::{
31 collections::HashSet,
32 fmt::Debug,
33 mem::take,
34 ops::{Deref, DerefMut},
35 rc::Rc,
36 sync::Arc,
37};
38
/// Limit handed to `UnrollProcessor::new` when the optimizer pipeline is assembled in
/// `SpirvCompiler::compile_kernel`.
///
/// NOTE(review): presumably the widest vector width that is kept as a native vector, with
/// anything wider being unrolled — confirm against `UnrollProcessor`.
pub const MAX_VECTORIZATION: usize = 4;
40
/// Compiler that lowers a `KernelDefinition` to a SPIR-V module for a given [`SpirvTarget`].
///
/// The compiler dereferences to its rspirv [`Builder`], so builder methods can be called
/// directly on it.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target the module is generated for; it receives the kernel name and the execution
    /// modes during `compile_kernel`.
    pub target: Target,
    /// rspirv builder that accumulates the module; taken (and reset to its default) at the
    /// end of `compile_kernel`.
    pub(crate) builder: Builder,

    /// Cube dimensions of the kernel being compiled (copied from the kernel definition).
    pub cube_dim: CubeDim,
    /// Execution mode passed to `compile`; forwarded to the checked I/O processor.
    pub mode: ExecutionMode,
    /// Address storage type passed to `compile`; defaults to `u32`.
    pub addr_type: StorageType,
    /// Whether debug symbols are emitted: enabled by the `Full` compilation log level or by
    /// the kernel's own `debug_symbols` option.
    pub debug_symbols: bool,
    // NOTE(review): these two ids are only cloned/defaulted in this file; presumably they
    // hold the ids of the matching built-in variables — confirm where they are assigned.
    global_invocation_id: Word,
    num_workgroups: Word,
    /// Builder index of the setup block; function-local variables are inserted at its start.
    pub setup_block: usize,
    /// Optimizer holding the optimized form of the kernel being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block currently being compiled, if any.
    pub current_block: Option<NodeIndex>,
    /// Optimizer blocks that `compile_block` has already processed.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested so far while compiling.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used built-ins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map in id order, the number of preceding entries that carry
    /// extended metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Kernel info built from the scalars, the metadata layout and the address type.
    pub info: Info,
    /// Debug info state, when present.
    pub debug_info: Option<DebugInfo>,
    /// Copy of the compilation options passed to `compile`.
    pub compilation_options: WgpuCompilationOptions,
}
65
// SAFETY / NOTE(review): `SpirvCompiler` stores `Rc` handles (`opt`, `uniformity`,
// `shared_liveness`), which are neither `Send` nor `Sync` on their own, and `Clone` shares the
// same `Rc` allocations between clones. These impls assert thread-safety manually; the
// invariant that makes this sound (presumably that clones sharing an `Rc` are never used from
// different threads at the same time) is not visible in this file — TODO confirm and document.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
68
69impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
70 fn clone(&self) -> Self {
71 Self {
72 target: self.target.clone(),
73 builder: Builder::new_from_module(self.module_ref().clone()),
74 cube_dim: self.cube_dim,
75 mode: self.mode,
76 addr_type: self.addr_type,
77 global_invocation_id: self.global_invocation_id,
78 num_workgroups: self.num_workgroups,
79 setup_block: self.setup_block,
80 opt: self.opt.clone(),
81 uniformity: self.uniformity.clone(),
82 shared_liveness: self.shared_liveness.clone(),
83 current_block: self.current_block,
84 capabilities: self.capabilities.clone(),
85 state: self.state.clone(),
86 debug_symbols: self.debug_symbols,
87 visited: self.visited.clone(),
88 info: self.info.clone(),
89 debug_info: self.debug_info.clone(),
90 ext_meta_pos: self.ext_meta_pos.clone(),
91 compilation_options: self.compilation_options,
92 }
93 }
94}
95
96fn debug_symbols_activated() -> bool {
97 matches!(
98 GlobalConfig::get().compilation.logger.level,
99 CompilationLogLevel::Full
100 )
101}
102
103impl<T: SpirvTarget> Default for SpirvCompiler<T> {
104 fn default() -> Self {
105 Self {
106 target: Default::default(),
107 builder: Builder::new(),
108 cube_dim: CubeDim::new_single(),
109 mode: Default::default(),
110 addr_type: ElemType::UInt(UIntKind::U32).into(),
111 global_invocation_id: Default::default(),
112 num_workgroups: Default::default(),
113 capabilities: Default::default(),
114 state: Default::default(),
115 setup_block: Default::default(),
116 opt: Default::default(),
117 uniformity: Default::default(),
118 shared_liveness: Default::default(),
119 current_block: Default::default(),
120 debug_symbols: debug_symbols_activated(),
121 visited: Default::default(),
122 info: Default::default(),
123 debug_info: Default::default(),
124 ext_meta_pos: Default::default(),
125 compilation_options: Default::default(),
126 }
127 }
128}
129
130impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
131 type Target = Builder;
132
133 fn deref(&self) -> &Self::Target {
134 &self.builder
135 }
136}
137
138impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.builder
141 }
142}
143
144impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
145 type Representation = SpirvKernel;
146 type CompilationOptions = WgpuCompilationOptions;
147
148 fn compile(
149 &mut self,
150 mut value: KernelDefinition,
151 compilation_options: &Self::CompilationOptions,
152 mode: ExecutionMode,
153 addr_type: StorageType,
154 ) -> Result<Self::Representation, CompilationError> {
155 let errors = value.body.pop_errors();
156 if !errors.is_empty() {
157 let mut reason = "Can't compile spirv kernel".to_string();
158 for error in errors {
159 reason += error.as_str();
160 reason += "\n";
161 }
162
163 return Err(CompilationError::Validation {
164 reason,
165 backtrace: BackTrace::capture(),
166 });
167 }
168
169 let bindings = value.buffers.clone();
170 let mut ext_meta_pos = Vec::new();
171 let mut num_ext = 0;
172
173 let mut all_meta: Vec<_> = value
174 .buffers
175 .iter()
176 .chain(value.tensor_maps.iter())
177 .map(|buf| (buf.id, buf.has_extended_meta))
178 .collect();
179 all_meta.sort_by_key(|(id, _)| *id);
180
181 let num_meta = all_meta.len();
182
183 for (_, has_extended_meta) in all_meta.iter() {
184 ext_meta_pos.push(num_ext);
185 if *has_extended_meta {
186 num_ext += 1;
187 }
188 }
189
190 let metadata = Metadata::new(num_meta as u32, num_ext);
191
192 self.cube_dim = value.cube_dim;
193 self.mode = mode;
194 self.addr_type = addr_type;
195 self.info = Info::new(&value.scalars, metadata, addr_type);
196 self.compilation_options = *compilation_options;
197 self.ext_meta_pos = ext_meta_pos;
198
199 let (module, optimizer, shared_size) = self.compile_kernel(value);
200 let uniform_info = matches!(T::info_storage_class(self), StorageClass::Uniform);
201
202 Ok(SpirvKernel {
203 assembled_module: module.assemble(),
204 module: Some(Arc::new(module)),
205 optimizer: Some(Arc::new(optimizer)),
206 bindings: bindings.iter().map(|it| it.visibility).collect(),
207 shared_size,
208 uniform_info,
209 })
210 }
211
212 fn elem_size(&self, elem: core::ElemType) -> usize {
213 elem.size()
214 }
215
216 fn extension(&self) -> &'static str {
217 "spv"
218 }
219}
220
221impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 write!(f, "spirv<{:?}>", self.target)
224 }
225}
226
227impl<Target: SpirvTarget> SpirvCompiler<Target> {
228 pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer, usize) {
229 let options = kernel.options.clone();
230
231 self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
232
233 let version = self.compilation_options.vulkan.max_spirv_version;
234 self.set_version(version.0, version.1);
235
236 let mut target = self.target.clone();
237
238 let mut opt = OptimizerBuilder::default()
239 .with_transformer(ErfTransform)
240 .with_transformer(BitwiseTransform::new(
241 self.compilation_options.vulkan.supports_arbitrary_bitwise,
242 ))
243 .with_transformer(HypotTransform)
244 .with_transformer(RhypotTransform)
245 .with_processor(CheckedIoProcessor::new(
246 self.mode,
247 kernel.options.kernel_name.clone(),
248 ))
249 .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
250 .with_processor(SaturatingArithmeticProcessor::new(true))
251 .optimize(kernel.body.clone(), kernel.cube_dim);
252
253 self.uniformity = opt.analysis::<Uniformity>();
254 self.shared_liveness = opt.analysis::<SharedLiveness>();
255 self.opt = Rc::new(opt);
256
257 self.init_state(kernel.clone());
258 self.init_debug();
259
260 let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
261
262 target.set_kernel_name(options.kernel_name.clone());
263
264 let (main, debug_setup) = self.declare_main(&options.kernel_name);
265
266 let setup = self.id();
267 self.debug_name(setup, "setup");
268
269 let entry = self.opt.entry();
270 let body = self.label(entry);
271 let setup_block = self.setup(setup, debug_setup);
272 self.setup_block = setup_block;
273 self.compile_block(entry);
274
275 let ret = self.opt.ret;
276 self.compile_block(ret);
277
278 if self.selected_block().is_some() {
279 let label = self.label(ret);
280 self.branch(label).unwrap();
281 }
282
283 self.select_block(Some(setup_block)).unwrap();
284 self.branch(body).unwrap();
285
286 self.end_function().unwrap();
287
288 let shared_size = self.declare_shared_memories();
289
290 let builtins = self
291 .state
292 .used_builtins
293 .clone()
294 .into_iter()
295 .map(|(builtin, (id, item))| {
296 let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
297 self.variable(ty, Some(id), StorageClass::Input, None);
298 self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
299 id
300 })
301 .collect::<Vec<_>>();
302
303 target.set_modes(self, main, builtins, cube_dims);
304
305 let module = take(&mut self.builder).module();
306 (module, self.opt.as_ref().clone(), shared_size)
307 }
308
309 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
310 self.begin_block(Some(label)).unwrap();
311
312 let opt = self.opt.clone();
313 for const_arr in opt.const_arrays() {
314 self.register_const_array(const_arr);
315 }
316
317 debug_setup(self);
318
319 let setup_block = self.selected_block().unwrap();
320 self.select_block(None).unwrap();
321 setup_block
322 }
323
324 #[track_caller]
325 pub fn current_block(&self) -> BasicBlock {
326 self.opt.block(self.current_block.unwrap()).clone()
327 }
328
329 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
330 if let Some(existing) = self.state.used_builtins.get(&builtin) {
331 existing.0
332 } else {
333 let id = self.id();
334 self.state.used_builtins.insert(builtin, (id, item));
335 id
336 }
337 }
338
339 pub fn compile_block(&mut self, block: NodeIndex) {
340 if self.visited.contains(&block) {
341 return;
342 }
343 self.visited.insert(block);
344 self.current_block = Some(block);
345
346 let label = self.label(block);
347 self.begin_block(Some(label)).unwrap();
348 let block_id = self.selected_block().unwrap();
349
350 self.debug_start_block();
351
352 let operations = self.current_block().ops.borrow().clone();
353 for (_, operation) in operations {
354 self.compile_operation(operation);
355 }
356
357 let control_flow = self.current_block().control_flow.borrow().clone();
358 self.compile_control_flow(control_flow);
359
360 let current = self.selected_block();
361 self.select_block(Some(block_id)).unwrap();
362 let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
363 for phi in phi {
364 let out = self.compile_variable(phi.out);
365 let ty = out.item().id(self);
366 let out_id = self.write_id(&out);
367 let entries: Vec<_> = phi
368 .entries
369 .into_iter()
370 .map(|it| {
371 let label = self.end_label(it.block);
372 let value = self.compile_variable(it.value);
373 let value = self.read(&value);
374 (value, label)
375 })
376 .collect();
377 self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
378 .unwrap();
379 }
380 self.select_block(current).unwrap();
381 }
382
383 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
385 let setup = self.setup_block;
386 let id = self.id();
387 let var = Instruction::new(
388 Op::Variable,
389 Some(ty),
390 Some(id),
391 vec![Operand::StorageClass(StorageClass::Function)],
392 );
393 let current_block = self.selected_block();
394 self.select_block(Some(setup)).unwrap();
395 self.insert_into_block(InsertPoint::Begin, var).unwrap();
396 self.select_block(current_block).unwrap();
397 id
398 }
399
400 fn declare_shared_memories(&mut self) -> usize {
401 if self.compilation_options.vulkan.supports_explicit_smem {
402 self.declare_shared_memories_explicit() as usize
403 } else {
404 self.declare_shared_memories_implicit() as usize
405 }
406 }
407
408 fn declare_shared_memories_explicit(&mut self) -> u32 {
413 let mut shared_size = 0;
414
415 let shared_arrays = self.state.shared_arrays.clone();
416 let shared = self.state.shared.clone();
417 if shared_arrays.is_empty() && shared.is_empty() {
418 return shared_size;
419 }
420
421 self.capabilities
422 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
423
424 for (index, memory) in shared_arrays {
425 let item_size = memory.item.size();
426 shared_size = shared_size.max(memory.offset + memory.len * item_size);
427
428 match item_size {
431 1 => {
432 self.capabilities
433 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
434 }
435 2 => {
436 self.capabilities
437 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
438 }
439 _ => {}
440 }
441
442 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
443 let arr_id = arr_ty.id(self);
444
445 if !self.state.decorated_types.contains(&arr_id) {
446 self.decorate(
447 arr_id,
448 Decoration::ArrayStride,
449 [Operand::LiteralBit32(item_size)],
450 );
451 self.state.decorated_types.insert(arr_id);
452 }
453
454 let block_ty = Item::Struct(vec![arr_ty]);
455 let block_id = block_ty.id(self);
456
457 self.decorate(block_id, Decoration::Block, []);
458 self.member_decorate(
459 block_id,
460 0,
461 Decoration::Offset,
462 [Operand::LiteralBit32(memory.offset)],
463 );
464
465 let ptr_ty = self.type_pointer(None, StorageClass::Workgroup, block_id);
466
467 self.debug_shared(memory.id, index);
468 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
469 self.decorate(memory.id, Decoration::Aliased, []);
470 }
471
472 for (index, memory) in shared {
473 let item_size = memory.item.size();
474 shared_size = shared_size.max(memory.offset + item_size);
475
476 match item_size {
479 1 => {
480 self.capabilities
481 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
482 }
483 2 => {
484 self.capabilities
485 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
486 }
487 _ => {}
488 }
489
490 let block_ty = Item::Struct(vec![memory.item]);
491 let block_id = block_ty.id(self);
492
493 self.decorate(block_id, Decoration::Block, []);
494 self.member_decorate(
495 block_id,
496 0,
497 Decoration::Offset,
498 [Operand::LiteralBit32(memory.offset)],
499 );
500
501 let ptr_ty = self.type_pointer(None, StorageClass::Workgroup, block_id);
502
503 self.debug_shared(memory.id, index);
504 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
505 self.decorate(memory.id, Decoration::Aliased, []);
506 }
507
508 shared_size
509 }
510
511 fn declare_shared_memories_implicit(&mut self) -> u32 {
512 let mut shared_size = 0;
513 let shared_memories = self.state.shared_arrays.clone();
514 for (index, memory) in shared_memories {
515 shared_size += memory.len * memory.item.size();
516
517 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
518 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
519
520 self.debug_shared(memory.id, index);
521 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
522 }
523 let shared = self.state.shared.clone();
524 for (index, memory) in shared {
525 shared_size += memory.item.size();
526
527 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
528
529 self.debug_shared(memory.id, index);
530 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
531 }
532 shared_size
533 }
534
535 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
536 if !self.compilation_options.vulkan.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
537 return;
538 }
539 let mode = convert_math_mode(modes.fp_math_mode);
540 self.capabilities.insert(Capability::FloatControls2);
541 self.decorate(
542 out_id,
543 Decoration::FPFastMathMode,
544 [Operand::FPFastMathMode(mode)],
545 );
546 }
547
548 pub fn is_uniform_block(&self) -> bool {
549 self.uniformity
550 .is_block_uniform(self.current_block.unwrap())
551 }
552}
553
554pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
555 let mut flags = FPFastMathMode::NONE;
556
557 for mode in math_mode.iter() {
558 match mode {
559 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
560 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
561 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
562 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
563 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
564 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
565 FastMath::AllowTransform => {
566 flags |= FPFastMathMode::ALLOW_CONTRACT
567 | FPFastMathMode::ALLOW_REASSOC
568 | FPFastMathMode::ALLOW_TRANSFORM
569 }
570 _ => {}
571 }
572 }
573
574 flags
575}