1use crate::{
2 SpirvKernel,
3 debug::DebugInfo,
4 item::Item,
5 lookups::LookupTables,
6 target::{GLCompute, SpirvTarget},
7 transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
8};
9use cubecl_common::backtrace::BackTrace;
10use cubecl_core::{
11 Compiler, CubeDim, Metadata, WgpuCompilationOptions,
12 ir::{self as core, ElemType, InstructionModes, StorageType, UIntKind, features::EnumSet},
13 post_processing::{
14 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
15 unroll::UnrollProcessor,
16 },
17 prelude::{FastMath, KernelDefinition},
18 server::ExecutionMode,
19};
20use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
21use cubecl_runtime::{
22 compiler::CompilationError,
23 config::{GlobalConfig, compilation::CompilationLogLevel},
24};
25use rspirv::{
26 dr::{Builder, InsertPoint, Instruction, Module, Operand},
27 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
28};
29use std::{
30 collections::HashSet,
31 fmt::Debug,
32 mem::take,
33 ops::{Deref, DerefMut},
34 rc::Rc,
35};
36
/// Vectorization limit handed to `UnrollProcessor` when the optimizer pipeline is built in
/// `compile_kernel`.
///
/// NOTE(review): presumably the widest vector the SPIR-V backend handles natively, with wider
/// operations being unrolled — confirm against `UnrollProcessor`.
pub const MAX_VECTORIZATION: usize = 4;
38
/// Compiler that turns a CubeCL `KernelDefinition` into a SPIR-V module.
///
/// The type dereferences to the inner rspirv [`Builder`], so builder methods can be called
/// directly on the compiler.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target-specific hooks; `compile_kernel` uses a clone of it to set the kernel name and the
    /// execution modes.
    pub target: Target,
    /// rspirv builder accumulating the module; taken (and reset to its default) at the end of
    /// `compile_kernel`.
    pub(crate) builder: Builder,

    /// Cube dimensions of the kernel passed to the last `compile` call.
    pub cube_dim: CubeDim,
    /// Execution mode of the last `compile` call; forwarded to `CheckedIoProcessor`.
    pub mode: ExecutionMode,
    /// Address storage type of the last `compile` call; defaults to `u32`.
    pub addr_type: StorageType,
    /// Set when the global compilation log level is `Full` or the kernel options request debug
    /// symbols.
    pub debug_symbols: bool,
    // The two ids below are only copied by `Clone` and default-initialized by `Default` in the
    // code visible here. NOTE(review): presumably ids of the matching builtins — confirm.
    global_invocation_id: Word,
    num_workgroups: Word,
    /// Index of the setup block (as returned by `selected_block`); function-local variables are
    /// inserted at its beginning.
    pub setup_block: usize,
    /// Optimizer holding the optimized kernel body.
    pub opt: Rc<Optimizer>,
    /// `Uniformity` analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// `SharedLiveness` analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block currently being compiled; set by `compile_block`.
    pub current_block: Option<NodeIndex>,
    /// Optimizer blocks that `compile_block` has already processed.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested by the code generated so far.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used builtins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map in ascending id order, the number of preceding entries that
    /// have extended metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout built from the number of buffers/tensor maps and the number of those with
    /// extended metadata.
    pub metadata: Metadata,
    /// Debug information state, if any. NOTE(review): presumably populated by `init_debug` —
    /// confirm.
    pub debug_info: Option<DebugInfo>,
    /// Compilation options of the last `compile` call.
    pub compilation_options: WgpuCompilationOptions,
}
63
// SAFETY: NOTE(review): no invariant is documented for these impls. The struct holds `Rc`
// handles (`opt`, `uniformity`, `shared_liveness`) and `Clone` shares them between copies, so
// this is only sound if a compiler and its clones are never used from several threads at the
// same time — TODO confirm against the call sites.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
66
67impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
68 fn clone(&self) -> Self {
69 Self {
70 target: self.target.clone(),
71 builder: Builder::new_from_module(self.module_ref().clone()),
72 cube_dim: self.cube_dim,
73 mode: self.mode,
74 addr_type: self.addr_type,
75 global_invocation_id: self.global_invocation_id,
76 num_workgroups: self.num_workgroups,
77 setup_block: self.setup_block,
78 opt: self.opt.clone(),
79 uniformity: self.uniformity.clone(),
80 shared_liveness: self.shared_liveness.clone(),
81 current_block: self.current_block,
82 capabilities: self.capabilities.clone(),
83 state: self.state.clone(),
84 debug_symbols: self.debug_symbols,
85 visited: self.visited.clone(),
86 metadata: self.metadata.clone(),
87 debug_info: self.debug_info.clone(),
88 ext_meta_pos: self.ext_meta_pos.clone(),
89 compilation_options: self.compilation_options.clone(),
90 }
91 }
92}
93
94fn debug_symbols_activated() -> bool {
95 matches!(
96 GlobalConfig::get().compilation.logger.level,
97 CompilationLogLevel::Full
98 )
99}
100
101impl<T: SpirvTarget> Default for SpirvCompiler<T> {
102 fn default() -> Self {
103 Self {
104 target: Default::default(),
105 builder: Builder::new(),
106 cube_dim: CubeDim::new_single(),
107 mode: Default::default(),
108 addr_type: ElemType::UInt(UIntKind::U32).into(),
109 global_invocation_id: Default::default(),
110 num_workgroups: Default::default(),
111 capabilities: Default::default(),
112 state: Default::default(),
113 setup_block: Default::default(),
114 opt: Default::default(),
115 uniformity: Default::default(),
116 shared_liveness: Default::default(),
117 current_block: Default::default(),
118 debug_symbols: debug_symbols_activated(),
119 visited: Default::default(),
120 metadata: Default::default(),
121 debug_info: Default::default(),
122 ext_meta_pos: Default::default(),
123 compilation_options: Default::default(),
124 }
125 }
126}
127
/// Exposes the inner rspirv [`Builder`] so its methods can be called directly on the compiler.
impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
    type Target = Builder;

    /// Returns a shared reference to the inner builder.
    fn deref(&self) -> &Self::Target {
        &self.builder
    }
}
135
/// Mutable counterpart of the `Deref` impl: builder methods taking `&mut self` can be called on
/// the compiler.
impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
    /// Returns an exclusive reference to the inner builder.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.builder
    }
}
141
142impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
143 type Representation = SpirvKernel;
144 type CompilationOptions = WgpuCompilationOptions;
145
146 fn compile(
147 &mut self,
148 mut value: KernelDefinition,
149 compilation_options: &Self::CompilationOptions,
150 mode: ExecutionMode,
151 addr_type: StorageType,
152 ) -> Result<Self::Representation, CompilationError> {
153 let errors = value.body.pop_errors();
154 if !errors.is_empty() {
155 let mut reason = "Can't compile spirv kernel".to_string();
156 for error in errors {
157 reason += error.as_str();
158 reason += "\n";
159 }
160
161 return Err(CompilationError::Validation {
162 reason,
163 backtrace: BackTrace::capture(),
164 });
165 }
166
167 let bindings = value.buffers.clone();
168 let scalars = value
169 .scalars
170 .iter()
171 .map(|s| (self.compile_storage_type(s.ty), s.count))
172 .collect();
173 let mut ext_meta_pos = Vec::new();
174 let mut num_ext = 0;
175
176 let mut all_meta: Vec<_> = value
177 .buffers
178 .iter()
179 .chain(value.tensor_maps.iter())
180 .map(|buf| (buf.id, buf.has_extended_meta))
181 .collect();
182 all_meta.sort_by_key(|(id, _)| *id);
183
184 let num_meta = all_meta.len();
185
186 for (_, has_extended_meta) in all_meta.iter() {
187 ext_meta_pos.push(num_ext);
188 if *has_extended_meta {
189 num_ext += 1;
190 }
191 }
192
193 self.cube_dim = value.cube_dim;
194 self.mode = mode;
195 self.addr_type = addr_type;
196 self.metadata = Metadata::new(num_meta as u32, num_ext);
197 self.compilation_options = compilation_options.clone();
198 self.ext_meta_pos = ext_meta_pos;
199
200 let (module, optimizer) = self.compile_kernel(value);
201 Ok(SpirvKernel {
202 module,
203 optimizer,
204 bindings,
205 scalars,
206 has_metadata: self.metadata.static_len() > 0,
207 })
208 }
209
210 fn elem_size(&self, elem: core::ElemType) -> usize {
211 elem.size()
212 }
213
214 fn extension(&self) -> &'static str {
215 "spv"
216 }
217}
218
219impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
220 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221 write!(f, "spirv<{:?}>", self.target)
222 }
223}
224
225impl<Target: SpirvTarget> SpirvCompiler<Target> {
    /// Generates the SPIR-V module for `kernel` and returns it together with a copy of the
    /// optimizer that produced the compiled body.
    ///
    /// The builder is taken out of `self` at the end, leaving a default builder behind.
    ///
    /// # Panics
    ///
    /// Panics if one of the builder operations unwrapped below fails.
    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
        let options = kernel.options.clone();

        // Debug symbols can be enabled globally (log level) or per kernel (options).
        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;

        // The module is emitted as SPIR-V 1.6.
        self.set_version(1, 6);

        // Work on a copy of the target: the kernel name is set on the copy only, `self.target`
        // is left untouched.
        let mut target = self.target.clone();

        // Optimize the kernel body with the SPIR-V specific transformers and processors.
        let mut opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(BitwiseTransform)
            .with_transformer(HypotTransform)
            .with_transformer(RhypotTransform)
            .with_processor(CheckedIoProcessor::new(self.mode))
            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .optimize(kernel.body.clone(), kernel.cube_dim);

        // Fetch the analyses before the optimizer is moved into the shared `Rc`.
        self.uniformity = opt.analysis::<Uniformity>();
        self.shared_liveness = opt.analysis::<SharedLiveness>();
        self.opt = Rc::new(opt);

        self.init_state(kernel.clone());
        self.init_debug();

        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];

        target.set_kernel_name(options.kernel_name.clone());

        let (main, debug_setup) = self.declare_main(&options.kernel_name);

        // Id of the setup block's label.
        let setup = self.id();
        self.debug_name(setup, "setup");

        // Reserve the label of the entry block before emitting the setup block; the setup block
        // branches to it once the body has been compiled.
        let entry = self.opt.entry();
        let body = self.label(entry);
        let setup_block = self.setup(setup, debug_setup);
        self.setup_block = setup_block;
        self.compile_block(entry);

        // `compile_block` returns immediately if the return block was already visited.
        let ret = self.opt.ret;
        self.compile_block(ret);

        // If a block is still selected at this point, end it with a branch to the return
        // block's label.
        if self.selected_block().is_some() {
            let label = self.label(ret);
            self.branch(label).unwrap();
        }

        // Terminate the setup block last, after everything inserted into it during body
        // compilation (e.g. function variables).
        self.select_block(Some(setup_block)).unwrap();
        self.branch(body).unwrap();

        self.end_function().unwrap();

        self.declare_shared_memories();

        // Declare every builtin requested through `builtin()` as a decorated `Input` variable
        // and collect the variable ids for the target.
        let builtins = self
            .state
            .used_builtins
            .clone()
            .into_iter()
            .map(|(builtin, (id, item))| {
                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
                self.variable(ty, Some(id), StorageClass::Input, None);
                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
                id
            })
            .collect::<Vec<_>>();

        target.set_modes(self, main, builtins, cube_dims);

        // Move the builder out (a default one replaces it) and extract the finished module.
        let module = take(&mut self.builder).module();
        (module, self.opt.as_ref().clone())
    }
300
301 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
302 self.begin_block(Some(label)).unwrap();
303
304 let opt = self.opt.clone();
305 for const_arr in opt.const_arrays() {
306 self.register_const_array(const_arr);
307 }
308
309 debug_setup(self);
310
311 let setup_block = self.selected_block().unwrap();
312 self.select_block(None).unwrap();
313 setup_block
314 }
315
316 #[track_caller]
317 pub fn current_block(&self) -> BasicBlock {
318 self.opt.block(self.current_block.unwrap()).clone()
319 }
320
321 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
322 if let Some(existing) = self.state.used_builtins.get(&builtin) {
323 existing.0
324 } else {
325 let id = self.id();
326 self.state.used_builtins.insert(builtin, (id, item));
327 id
328 }
329 }
330
    /// Compiles the optimizer block `block`: its operations, its control flow and finally its
    /// phi nodes.
    ///
    /// Blocks that were already visited are skipped. The block selected in the builder after
    /// compiling the control flow is selected again before returning.
    ///
    /// # Panics
    ///
    /// Panics if one of the builder operations unwrapped below fails.
    pub fn compile_block(&mut self, block: NodeIndex) {
        // Each block is compiled at most once.
        if self.visited.contains(&block) {
            return;
        }
        self.visited.insert(block);
        self.current_block = Some(block);

        let label = self.label(block);
        self.begin_block(Some(label)).unwrap();
        // Builder index of this block, needed to come back to it for the phi nodes.
        let block_id = self.selected_block().unwrap();

        self.debug_start_block();

        // Clone the operation list so `self` is free to be mutated while compiling.
        let operations = self.current_block().ops.borrow().clone();
        for (_, operation) in operations {
            self.compile_operation(operation);
        }

        let control_flow = self.current_block().control_flow.borrow().clone();
        self.compile_control_flow(control_flow);

        // Phi nodes are emitted after the body and control flow, but inserted at the beginning
        // of this block. Remember the current selection so it can be restored afterwards.
        let current = self.selected_block();
        self.select_block(Some(block_id)).unwrap();
        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
        for phi in phi {
            let out = self.compile_variable(phi.out);
            let ty = out.item().id(self);
            let out_id = self.write_id(&out);
            // One `(value, label)` pair per incoming entry, using the end label of the entry's
            // block.
            let entries: Vec<_> = phi
                .entries
                .into_iter()
                .map(|it| {
                    let label = self.end_label(it.block);
                    let value = self.compile_variable(it.value);
                    let value = self.read(&value);
                    (value, label)
                })
                .collect();
            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
                .unwrap();
        }
        self.select_block(current).unwrap();
    }
374
375 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
377 let setup = self.setup_block;
378 let id = self.id();
379 let var = Instruction::new(
380 Op::Variable,
381 Some(ty),
382 Some(id),
383 vec![Operand::StorageClass(StorageClass::Function)],
384 );
385 let current_block = self.selected_block();
386 self.select_block(Some(setup)).unwrap();
387 self.insert_into_block(InsertPoint::Begin, var).unwrap();
388 self.select_block(current_block).unwrap();
389 id
390 }
391
392 fn declare_shared_memories(&mut self) {
393 if self.compilation_options.supports_explicit_smem {
394 self.declare_shared_memories_explicit();
395 } else {
396 self.declare_shared_memories_implicit();
397 }
398 }
399
400 fn declare_shared_memories_explicit(&mut self) {
405 let shared_arrays = self.state.shared_arrays.clone();
406 let shared = self.state.shared.clone();
407 if shared_arrays.is_empty() && shared.is_empty() {
408 return;
409 }
410
411 self.capabilities
412 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
413
414 for (index, memory) in shared_arrays {
415 let item_size = memory.item.size();
416
417 match item_size {
420 1 => {
421 self.capabilities
422 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
423 }
424 2 => {
425 self.capabilities
426 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
427 }
428 _ => {}
429 }
430
431 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
432 let arr_id = arr_ty.id(self);
433
434 if !self.state.decorated_types.contains(&arr_id) {
435 self.decorate(
436 arr_id,
437 Decoration::ArrayStride,
438 [Operand::LiteralBit32(item_size)],
439 );
440 self.state.decorated_types.insert(arr_id);
441 }
442
443 let block_ty = Item::Struct(vec![arr_ty]);
444 let block_id = block_ty.id(self);
445
446 self.decorate(block_id, Decoration::Block, []);
447 self.member_decorate(
448 block_id,
449 0,
450 Decoration::Offset,
451 [Operand::LiteralBit32(memory.offset)],
452 );
453
454 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
455
456 self.debug_shared(memory.id, index);
457 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
458 self.decorate(memory.id, Decoration::Aliased, []);
459 }
460
461 for (index, memory) in shared {
462 let item_size = memory.item.size();
463
464 match item_size {
467 1 => {
468 self.capabilities
469 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
470 }
471 2 => {
472 self.capabilities
473 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
474 }
475 _ => {}
476 }
477
478 let block_ty = Item::Struct(vec![memory.item]);
479 let block_id = block_ty.id(self);
480
481 self.decorate(block_id, Decoration::Block, []);
482 self.member_decorate(
483 block_id,
484 0,
485 Decoration::Offset,
486 [Operand::LiteralBit32(memory.offset)],
487 );
488
489 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
490
491 self.debug_shared(memory.id, index);
492 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
493 self.decorate(memory.id, Decoration::Aliased, []);
494 }
495 }
496
497 fn declare_shared_memories_implicit(&mut self) {
498 let shared_memories = self.state.shared_arrays.clone();
499 for (index, memory) in shared_memories {
500 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
501 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
502
503 self.debug_shared(memory.id, index);
504 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
505 }
506 let shared = self.state.shared.clone();
507 for (index, memory) in shared {
508 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
509
510 self.debug_shared(memory.id, index);
511 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
512 }
513 }
514
515 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
516 if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
517 return;
518 }
519 let mode = convert_math_mode(modes.fp_math_mode);
520 self.capabilities.insert(Capability::FloatControls2);
521 self.decorate(
522 out_id,
523 Decoration::FPFastMathMode,
524 [Operand::FPFastMathMode(mode)],
525 );
526 }
527
528 pub fn is_uniform_block(&self) -> bool {
529 self.uniformity
530 .is_block_uniform(self.current_block.unwrap())
531 }
532}
533
534pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
535 let mut flags = FPFastMathMode::NONE;
536
537 for mode in math_mode.iter() {
538 match mode {
539 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
540 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
541 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
542 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
543 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
544 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
545 FastMath::AllowTransform => {
546 flags |= FPFastMathMode::ALLOW_CONTRACT
547 | FPFastMathMode::ALLOW_REASSOC
548 | FPFastMathMode::ALLOW_TRANSFORM
549 }
550 _ => {}
551 }
552 }
553
554 flags
555}