1use cubecl_common::ExecutionMode;
2use cubecl_core::{
3 Metadata, WgpuCompilationOptions, ir as core,
4 post_processing::{
5 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
6 unroll::UnrollProcessor,
7 },
8 prelude::FastMath,
9};
10use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
11use cubecl_runtime::{
12 EnumSet,
13 config::{GlobalConfig, compilation::CompilationLogLevel},
14};
15use std::{
16 collections::HashSet,
17 fmt::Debug,
18 mem::take,
19 ops::{Deref, DerefMut},
20 rc::Rc,
21};
22
23use cubecl_core::{Compiler, compute::KernelDefinition};
24use rspirv::{
25 dr::{Builder, InsertPoint, Instruction, Module, Operand},
26 spirv::{self, BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
27};
28
29use crate::{
30 SpirvKernel,
31 debug::DebugInfo,
32 item::Item,
33 lookups::LookupTables,
34 target::{GLCompute, SpirvTarget},
35 transformers::{BitwiseTransform, ErfTransform},
36};
37
/// Maximum vectorization factor, handed to `UnrollProcessor::new` when the optimizer is
/// configured in `SpirvCompiler::compile_kernel`.
pub const MAX_VECTORIZATION: u32 = 4;
39
40pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
41 pub target: Target,
42 pub(crate) builder: Builder,
43
44 pub mode: ExecutionMode,
45 pub debug_symbols: bool,
46 pub fp_math_mode: FPFastMathMode,
47 global_invocation_id: Word,
48 num_workgroups: Word,
49 pub setup_block: usize,
50 pub opt: Rc<Optimizer>,
51 pub uniformity: Rc<Uniformity>,
52 pub shared_liveness: Rc<SharedLiveness>,
53 pub current_block: Option<NodeIndex>,
54 pub visited: HashSet<NodeIndex>,
55
56 pub capabilities: HashSet<Capability>,
57 pub float_controls: bool,
58 pub state: LookupTables,
59 pub ext_meta_pos: Vec<u32>,
60 pub metadata: Metadata,
61 pub debug_info: Option<DebugInfo>,
62 pub compilation_options: WgpuCompilationOptions,
63}
64
65unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
66unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
67
68impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
69 fn clone(&self) -> Self {
70 Self {
71 target: self.target.clone(),
72 builder: Builder::new_from_module(self.module_ref().clone()),
73 mode: self.mode,
74 global_invocation_id: self.global_invocation_id,
75 num_workgroups: self.num_workgroups,
76 setup_block: self.setup_block,
77 opt: self.opt.clone(),
78 uniformity: self.uniformity.clone(),
79 shared_liveness: self.shared_liveness.clone(),
80 current_block: self.current_block,
81
82 capabilities: self.capabilities.clone(),
83 float_controls: self.float_controls,
84 state: self.state.clone(),
85 debug_symbols: self.debug_symbols,
86 fp_math_mode: self.fp_math_mode,
87 visited: self.visited.clone(),
88 metadata: self.metadata.clone(),
89 debug_info: self.debug_info.clone(),
90 ext_meta_pos: self.ext_meta_pos.clone(),
91 compilation_options: self.compilation_options.clone(),
92 }
93 }
94}
95
96fn debug_symbols_activated() -> bool {
97 matches!(
98 GlobalConfig::get().compilation.logger.level,
99 CompilationLogLevel::Full
100 )
101}
102
103impl<T: SpirvTarget> Default for SpirvCompiler<T> {
104 fn default() -> Self {
105 Self {
106 target: Default::default(),
107 builder: Builder::new(),
108 mode: Default::default(),
109 global_invocation_id: Default::default(),
110 num_workgroups: Default::default(),
111 capabilities: Default::default(),
112 float_controls: Default::default(),
113 state: Default::default(),
114 setup_block: Default::default(),
115 opt: Default::default(),
116 uniformity: Default::default(),
117 shared_liveness: Default::default(),
118 current_block: Default::default(),
119 debug_symbols: debug_symbols_activated(),
120 fp_math_mode: FPFastMathMode::NONE,
121 visited: Default::default(),
122 metadata: Default::default(),
123 debug_info: Default::default(),
124 ext_meta_pos: Default::default(),
125 compilation_options: Default::default(),
126 }
127 }
128}
129
130impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
131 type Target = Builder;
132
133 fn deref(&self) -> &Self::Target {
134 &self.builder
135 }
136}
137
138impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
139 fn deref_mut(&mut self) -> &mut Self::Target {
140 &mut self.builder
141 }
142}
143
144impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
145 type Representation = SpirvKernel;
146 type CompilationOptions = WgpuCompilationOptions;
147
148 fn compile(
149 &mut self,
150 value: KernelDefinition,
151 compilation_options: &Self::CompilationOptions,
152 mode: ExecutionMode,
153 ) -> Self::Representation {
154 let bindings = value.buffers.clone();
155 let scalars = value
156 .scalars
157 .iter()
158 .map(|s| (self.compile_storage_type(s.ty), s.count))
159 .collect();
160 let mut ext_meta_pos = Vec::new();
161 let mut num_ext = 0;
162
163 let mut all_meta: Vec<_> = value
164 .buffers
165 .iter()
166 .chain(value.tensor_maps.iter())
167 .map(|buf| (buf.id, buf.has_extended_meta))
168 .collect();
169 all_meta.sort_by_key(|(id, _)| *id);
170
171 let num_meta = all_meta.len();
172
173 for (_, has_extended_meta) in all_meta.iter() {
174 ext_meta_pos.push(num_ext);
175 if *has_extended_meta {
176 num_ext += 1;
177 }
178 }
179
180 self.mode = mode;
181 self.metadata = Metadata::new(num_meta as u32, num_ext);
182 self.compilation_options = compilation_options.clone();
183 self.ext_meta_pos = ext_meta_pos;
184
185 let (module, optimizer) = self.compile_kernel(value);
186 SpirvKernel {
187 module,
188 optimizer,
189 bindings,
190 scalars,
191 has_metadata: self.metadata.static_len() > 0,
192 }
193 }
194
195 fn elem_size(&self, elem: core::ElemType) -> usize {
196 elem.size()
197 }
198
199 fn extension(&self) -> &'static str {
200 "spv"
201 }
202}
203
204impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 write!(f, "spirv<{:?}>", self.target)
207 }
208}
209
210impl<Target: SpirvTarget> SpirvCompiler<Target> {
211 pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
212 let options = kernel.options.clone();
213
214 self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
215 self.fp_math_mode = match self.compilation_options.supports_fp_fast_math {
216 true => convert_math_mode(options.fp_math_mode),
217 false => FPFastMathMode::NONE,
218 };
219 self.float_controls = self.fp_math_mode != FPFastMathMode::NONE;
220
221 if self.float_controls {
222 self.capabilities.insert(Capability::FloatControls2);
223 }
224
225 self.set_version(1, 6);
226
227 let mut target = self.target.clone();
228
229 let mut opt = OptimizerBuilder::default()
230 .with_transformer(ErfTransform)
231 .with_transformer(BitwiseTransform)
232 .with_processor(CheckedIoProcessor::new(self.mode))
233 .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
234 .with_processor(SaturatingArithmeticProcessor::new(true))
235 .optimize(kernel.body.clone(), kernel.cube_dim);
236
237 self.uniformity = opt.analysis::<Uniformity>();
238 self.shared_liveness = opt.analysis::<SharedLiveness>();
239 self.opt = Rc::new(opt);
240
241 self.init_state(kernel.clone());
242 self.init_debug();
243
244 let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
245
246 target.set_kernel_name(options.kernel_name.clone());
247
248 let (main, debug_setup) = self.declare_main(&options.kernel_name);
249
250 let setup = self.id();
251 self.debug_name(setup, "setup");
252
253 let entry = self.opt.entry();
254 let body = self.label(entry);
255 let setup_block = self.setup(setup, debug_setup);
256 self.setup_block = setup_block;
257 self.compile_block(entry);
258
259 let ret = self.opt.ret;
260 self.compile_block(ret);
261
262 if self.selected_block().is_some() {
263 let label = self.label(ret);
264 self.branch(label).unwrap();
265 }
266
267 self.select_block(Some(setup_block)).unwrap();
268 self.branch(body).unwrap();
269
270 self.end_function().unwrap();
271
272 self.declare_shared_memories();
273
274 let builtins = self
275 .state
276 .used_builtins
277 .clone()
278 .into_iter()
279 .map(|(builtin, (id, item))| {
280 let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
281 self.variable(ty, Some(id), StorageClass::Input, None);
282 self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
283 id
284 })
285 .collect::<Vec<_>>();
286
287 target.set_modes(self, main, builtins, cube_dims);
288
289 let module = take(&mut self.builder).module();
290 (module, self.opt.as_ref().clone())
291 }
292
293 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
294 self.begin_block(Some(label)).unwrap();
295
296 let opt = self.opt.clone();
297 for const_arr in opt.const_arrays() {
298 self.register_const_array(const_arr);
299 }
300
301 debug_setup(self);
302
303 let setup_block = self.selected_block().unwrap();
304 self.select_block(None).unwrap();
305 setup_block
306 }
307
308 #[track_caller]
309 pub fn current_block(&self) -> BasicBlock {
310 self.opt.block(self.current_block.unwrap()).clone()
311 }
312
313 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
314 if let Some(existing) = self.state.used_builtins.get(&builtin) {
315 existing.0
316 } else {
317 let id = self.id();
318 self.state.used_builtins.insert(builtin, (id, item));
319 id
320 }
321 }
322
323 pub fn compile_block(&mut self, block: NodeIndex) {
324 if self.visited.contains(&block) {
325 return;
326 }
327 self.visited.insert(block);
328 self.current_block = Some(block);
329
330 let label = self.label(block);
331 self.begin_block(Some(label)).unwrap();
332 let block_id = self.selected_block().unwrap();
333
334 self.debug_start_block();
335
336 let operations = self.current_block().ops.borrow().clone();
337 for (_, operation) in operations {
338 self.compile_operation(operation);
339 }
340
341 let control_flow = self.current_block().control_flow.borrow().clone();
342 self.compile_control_flow(control_flow);
343
344 let current = self.selected_block();
345 self.select_block(Some(block_id)).unwrap();
346 let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
347 for phi in phi {
348 let out = self.compile_variable(phi.out);
349 let ty = out.item().id(self);
350 let out_id = self.write_id(&out);
351 let entries: Vec<_> = phi
352 .entries
353 .into_iter()
354 .map(|it| {
355 let label = self.end_label(it.block);
356 let value = self.compile_variable(it.value);
357 let value = self.read(&value);
358 (value, label)
359 })
360 .collect();
361 self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
362 .unwrap();
363 }
364 self.select_block(current).unwrap();
365 }
366
367 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
369 let setup = self.setup_block;
370 let id = self.id();
371 let var = Instruction::new(
372 Op::Variable,
373 Some(ty),
374 Some(id),
375 vec![Operand::StorageClass(StorageClass::Function)],
376 );
377 let current_block = self.selected_block();
378 self.select_block(Some(setup)).unwrap();
379 self.insert_into_block(InsertPoint::Begin, var).unwrap();
380 self.select_block(current_block).unwrap();
381 id
382 }
383
384 fn declare_shared_memories(&mut self) {
385 if self.compilation_options.supports_explicit_smem {
386 self.declare_shared_memories_explicit();
387 } else {
388 self.declare_shared_memories_implicit();
389 }
390 }
391
392 fn declare_shared_memories_explicit(&mut self) {
397 let shared_memories = self.state.shared_memories.clone();
398 if shared_memories.is_empty() {
399 return;
400 }
401
402 self.capabilities
403 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
404
405 for (index, memory) in shared_memories {
406 let item_size = memory.item.size();
407
408 match item_size {
411 1 => {
412 self.capabilities
413 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
414 }
415 2 => {
416 self.capabilities
417 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
418 }
419 _ => {}
420 }
421
422 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
423 let arr_id = arr_ty.id(self);
424
425 if !self.state.decorated_types.contains(&arr_id) {
426 self.decorate(
427 arr_id,
428 Decoration::ArrayStride,
429 [Operand::LiteralBit32(item_size)],
430 );
431 self.state.decorated_types.insert(arr_id);
432 }
433
434 let block_ty = Item::Struct(vec![arr_ty]);
435 let block_id = block_ty.id(self);
436
437 self.decorate(block_id, Decoration::Block, []);
438 self.member_decorate(
439 block_id,
440 0,
441 Decoration::Offset,
442 [Operand::LiteralBit32(memory.offset)],
443 );
444
445 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
446
447 self.debug_shared(memory.id, index);
448 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
449 self.decorate(memory.id, Decoration::Aliased, []);
450 }
451 }
452
453 fn declare_shared_memories_implicit(&mut self) {
454 let shared_memories = self.state.shared_memories.clone();
455 for (index, memory) in shared_memories {
456 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
457 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
458
459 self.debug_shared(memory.id, index);
460 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
461 }
462 }
463
464 pub fn declare_float_execution_modes(&mut self, main: Word) {
465 let mode = self.const_u32(self.fp_math_mode.bits());
466
467 let types = self.builder.module_ref().types_global_values.clone();
468 let scalars = types
469 .iter()
470 .filter(|inst| inst.class.opcode == Op::TypeFloat)
471 .map(|it| it.result_id.expect("OpTypeFloat always has result ID"))
472 .collect::<Vec<_>>();
473 for ty in scalars {
474 self.execution_mode(main, spirv::ExecutionMode::FPFastMathDefault, [ty, mode]);
475 }
476 }
477
478 pub fn is_uniform_block(&self) -> bool {
479 self.uniformity
480 .is_block_uniform(self.current_block.unwrap())
481 }
482}
483
484fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
485 let mut flags = FPFastMathMode::NONE;
486
487 for mode in math_mode.iter() {
488 match mode {
489 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
490 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
491 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
492 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
493 FastMath::AllowContraction => flags |= FPFastMathMode::from_bits_retain(0x10000),
494 FastMath::AllowReassociation => flags |= FPFastMathMode::from_bits_retain(0x20000),
495 FastMath::AllowTransform => {
496 flags |= FPFastMathMode::from_bits_retain(0x10000)
497 | FPFastMathMode::from_bits_retain(0x20000)
498 | FPFastMathMode::from_bits_retain(0x40000)
499 }
500 _ => {}
501 }
502 }
503
504 flags
505}