1use cubecl_common::ExecutionMode;
2use cubecl_core::{
3 Metadata, WgpuCompilationOptions,
4 ir::{self as core, InstructionModes},
5 post_processing::{
6 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7 unroll::UnrollProcessor,
8 },
9 prelude::FastMath,
10};
11use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
12use cubecl_runtime::{
13 EnumSet,
14 config::{GlobalConfig, compilation::CompilationLogLevel},
15};
16use std::{
17 collections::HashSet,
18 fmt::Debug,
19 mem::take,
20 ops::{Deref, DerefMut},
21 rc::Rc,
22};
23
24use cubecl_core::{Compiler, compute::KernelDefinition};
25use rspirv::{
26 dr::{Builder, InsertPoint, Instruction, Module, Operand},
27 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
28};
29
30use crate::{
31 SpirvKernel,
32 debug::DebugInfo,
33 item::Item,
34 lookups::LookupTables,
35 target::{GLCompute, SpirvTarget},
36 transformers::{BitwiseTransform, ErfTransform},
37};
38
/// Maximum vector width handed to the [`UnrollProcessor`] when optimizing a kernel in
/// `compile_kernel`.
pub const MAX_VECTORIZATION: u32 = 4;
40
/// SPIR-V backend compiler that lowers cubecl kernel definitions into `rspirv` modules.
///
/// The compiler dereferences to its underlying [`Builder`], so builder methods can be called
/// on it directly.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target-specific hooks (kernel name and execution modes are delegated to it).
    pub target: Target,
    /// Builder accumulating the module under construction.
    pub(crate) builder: Builder,

    /// Execution mode of the kernel being compiled (forwarded to the checked-IO processor).
    pub mode: ExecutionMode,
    /// Whether debug symbols were requested, either by the global compilation log level or by
    /// the kernel options.
    pub debug_symbols: bool,
    // NOTE(review): not read anywhere in the code visible in this file — confirm it is still used.
    global_invocation_id: Word,
    // NOTE(review): not read anywhere in the code visible in this file — confirm it is still used.
    num_workgroups: Word,
    /// Index of the setup block, into which function-scope variables are inserted.
    pub setup_block: usize,
    /// Optimizer holding the control-flow graph of the kernel being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis computed on `opt`.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis computed on `opt`.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer node currently being compiled, if any.
    pub current_block: Option<NodeIndex>,
    /// Optimizer nodes that have already been compiled.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities collected while compiling.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables and other per-kernel compilation state.
    pub state: LookupTables,
    /// For each buffer / tensor map (sorted by binding id), the number of preceding entries
    /// that carry extended metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout of the kernel being compiled.
    pub metadata: Metadata,
    /// Optional debug-info state.
    pub debug_info: Option<DebugInfo>,
    /// Backend options (e.g. explicit shared memory and fast-math support).
    pub compilation_options: WgpuCompilationOptions,
}
63
// NOTE(review): these impls carry no SAFETY argument. `SpirvCompiler` holds `Rc` fields
// (`opt`, `uniformity`, `shared_liveness`) and its `Clone` impl shares them. Two threads
// touching clones of the same compiler (or cloning through a shared `&SpirvCompiler`) would
// race on the non-atomic reference counts. The impls are only sound if no two threads ever
// access the same `Rc` allocations concurrently — confirm this with the callers that need
// `Send`/`Sync`, or move these fields to `Arc`.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
66
67impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
68 fn clone(&self) -> Self {
69 Self {
70 target: self.target.clone(),
71 builder: Builder::new_from_module(self.module_ref().clone()),
72 mode: self.mode,
73 global_invocation_id: self.global_invocation_id,
74 num_workgroups: self.num_workgroups,
75 setup_block: self.setup_block,
76 opt: self.opt.clone(),
77 uniformity: self.uniformity.clone(),
78 shared_liveness: self.shared_liveness.clone(),
79 current_block: self.current_block,
80 capabilities: self.capabilities.clone(),
81 state: self.state.clone(),
82 debug_symbols: self.debug_symbols,
83 visited: self.visited.clone(),
84 metadata: self.metadata.clone(),
85 debug_info: self.debug_info.clone(),
86 ext_meta_pos: self.ext_meta_pos.clone(),
87 compilation_options: self.compilation_options.clone(),
88 }
89 }
90}
91
92fn debug_symbols_activated() -> bool {
93 matches!(
94 GlobalConfig::get().compilation.logger.level,
95 CompilationLogLevel::Full
96 )
97}
98
99impl<T: SpirvTarget> Default for SpirvCompiler<T> {
100 fn default() -> Self {
101 Self {
102 target: Default::default(),
103 builder: Builder::new(),
104 mode: Default::default(),
105 global_invocation_id: Default::default(),
106 num_workgroups: Default::default(),
107 capabilities: Default::default(),
108 state: Default::default(),
109 setup_block: Default::default(),
110 opt: Default::default(),
111 uniformity: Default::default(),
112 shared_liveness: Default::default(),
113 current_block: Default::default(),
114 debug_symbols: debug_symbols_activated(),
115 visited: Default::default(),
116 metadata: Default::default(),
117 debug_info: Default::default(),
118 ext_meta_pos: Default::default(),
119 compilation_options: Default::default(),
120 }
121 }
122}
123
124impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
125 type Target = Builder;
126
127 fn deref(&self) -> &Self::Target {
128 &self.builder
129 }
130}
131
132impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
133 fn deref_mut(&mut self) -> &mut Self::Target {
134 &mut self.builder
135 }
136}
137
138impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
139 type Representation = SpirvKernel;
140 type CompilationOptions = WgpuCompilationOptions;
141
142 fn compile(
143 &mut self,
144 value: KernelDefinition,
145 compilation_options: &Self::CompilationOptions,
146 mode: ExecutionMode,
147 ) -> Self::Representation {
148 let bindings = value.buffers.clone();
149 let scalars = value
150 .scalars
151 .iter()
152 .map(|s| (self.compile_storage_type(s.ty), s.count))
153 .collect();
154 let mut ext_meta_pos = Vec::new();
155 let mut num_ext = 0;
156
157 let mut all_meta: Vec<_> = value
158 .buffers
159 .iter()
160 .chain(value.tensor_maps.iter())
161 .map(|buf| (buf.id, buf.has_extended_meta))
162 .collect();
163 all_meta.sort_by_key(|(id, _)| *id);
164
165 let num_meta = all_meta.len();
166
167 for (_, has_extended_meta) in all_meta.iter() {
168 ext_meta_pos.push(num_ext);
169 if *has_extended_meta {
170 num_ext += 1;
171 }
172 }
173
174 self.mode = mode;
175 self.metadata = Metadata::new(num_meta as u32, num_ext);
176 self.compilation_options = compilation_options.clone();
177 self.ext_meta_pos = ext_meta_pos;
178
179 let (module, optimizer) = self.compile_kernel(value);
180 SpirvKernel {
181 module,
182 optimizer,
183 bindings,
184 scalars,
185 has_metadata: self.metadata.static_len() > 0,
186 }
187 }
188
189 fn elem_size(&self, elem: core::ElemType) -> usize {
190 elem.size()
191 }
192
193 fn extension(&self) -> &'static str {
194 "spv"
195 }
196}
197
198impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
199 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200 write!(f, "spirv<{:?}>", self.target)
201 }
202}
203
204impl<Target: SpirvTarget> SpirvCompiler<Target> {
205 pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
206 let options = kernel.options.clone();
207
208 self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
209
210 self.set_version(1, 6);
211
212 let mut target = self.target.clone();
213
214 let mut opt = OptimizerBuilder::default()
215 .with_transformer(ErfTransform)
216 .with_transformer(BitwiseTransform)
217 .with_processor(CheckedIoProcessor::new(self.mode))
218 .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
219 .with_processor(SaturatingArithmeticProcessor::new(true))
220 .optimize(kernel.body.clone(), kernel.cube_dim);
221
222 self.uniformity = opt.analysis::<Uniformity>();
223 self.shared_liveness = opt.analysis::<SharedLiveness>();
224 self.opt = Rc::new(opt);
225
226 self.init_state(kernel.clone());
227 self.init_debug();
228
229 let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
230
231 target.set_kernel_name(options.kernel_name.clone());
232
233 let (main, debug_setup) = self.declare_main(&options.kernel_name);
234
235 let setup = self.id();
236 self.debug_name(setup, "setup");
237
238 let entry = self.opt.entry();
239 let body = self.label(entry);
240 let setup_block = self.setup(setup, debug_setup);
241 self.setup_block = setup_block;
242 self.compile_block(entry);
243
244 let ret = self.opt.ret;
245 self.compile_block(ret);
246
247 if self.selected_block().is_some() {
248 let label = self.label(ret);
249 self.branch(label).unwrap();
250 }
251
252 self.select_block(Some(setup_block)).unwrap();
253 self.branch(body).unwrap();
254
255 self.end_function().unwrap();
256
257 self.declare_shared_memories();
258
259 let builtins = self
260 .state
261 .used_builtins
262 .clone()
263 .into_iter()
264 .map(|(builtin, (id, item))| {
265 let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
266 self.variable(ty, Some(id), StorageClass::Input, None);
267 self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
268 id
269 })
270 .collect::<Vec<_>>();
271
272 target.set_modes(self, main, builtins, cube_dims);
273
274 let module = take(&mut self.builder).module();
275 (module, self.opt.as_ref().clone())
276 }
277
278 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
279 self.begin_block(Some(label)).unwrap();
280
281 let opt = self.opt.clone();
282 for const_arr in opt.const_arrays() {
283 self.register_const_array(const_arr);
284 }
285
286 debug_setup(self);
287
288 let setup_block = self.selected_block().unwrap();
289 self.select_block(None).unwrap();
290 setup_block
291 }
292
293 #[track_caller]
294 pub fn current_block(&self) -> BasicBlock {
295 self.opt.block(self.current_block.unwrap()).clone()
296 }
297
298 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
299 if let Some(existing) = self.state.used_builtins.get(&builtin) {
300 existing.0
301 } else {
302 let id = self.id();
303 self.state.used_builtins.insert(builtin, (id, item));
304 id
305 }
306 }
307
308 pub fn compile_block(&mut self, block: NodeIndex) {
309 if self.visited.contains(&block) {
310 return;
311 }
312 self.visited.insert(block);
313 self.current_block = Some(block);
314
315 let label = self.label(block);
316 self.begin_block(Some(label)).unwrap();
317 let block_id = self.selected_block().unwrap();
318
319 self.debug_start_block();
320
321 let operations = self.current_block().ops.borrow().clone();
322 for (_, operation) in operations {
323 self.compile_operation(operation);
324 }
325
326 let control_flow = self.current_block().control_flow.borrow().clone();
327 self.compile_control_flow(control_flow);
328
329 let current = self.selected_block();
330 self.select_block(Some(block_id)).unwrap();
331 let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
332 for phi in phi {
333 let out = self.compile_variable(phi.out);
334 let ty = out.item().id(self);
335 let out_id = self.write_id(&out);
336 let entries: Vec<_> = phi
337 .entries
338 .into_iter()
339 .map(|it| {
340 let label = self.end_label(it.block);
341 let value = self.compile_variable(it.value);
342 let value = self.read(&value);
343 (value, label)
344 })
345 .collect();
346 self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
347 .unwrap();
348 }
349 self.select_block(current).unwrap();
350 }
351
352 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
354 let setup = self.setup_block;
355 let id = self.id();
356 let var = Instruction::new(
357 Op::Variable,
358 Some(ty),
359 Some(id),
360 vec![Operand::StorageClass(StorageClass::Function)],
361 );
362 let current_block = self.selected_block();
363 self.select_block(Some(setup)).unwrap();
364 self.insert_into_block(InsertPoint::Begin, var).unwrap();
365 self.select_block(current_block).unwrap();
366 id
367 }
368
369 fn declare_shared_memories(&mut self) {
370 if self.compilation_options.supports_explicit_smem {
371 self.declare_shared_memories_explicit();
372 } else {
373 self.declare_shared_memories_implicit();
374 }
375 }
376
377 fn declare_shared_memories_explicit(&mut self) {
382 let shared_memories = self.state.shared_memories.clone();
383 if shared_memories.is_empty() {
384 return;
385 }
386
387 self.capabilities
388 .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
389
390 for (index, memory) in shared_memories {
391 let item_size = memory.item.size();
392
393 match item_size {
396 1 => {
397 self.capabilities
398 .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
399 }
400 2 => {
401 self.capabilities
402 .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
403 }
404 _ => {}
405 }
406
407 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
408 let arr_id = arr_ty.id(self);
409
410 if !self.state.decorated_types.contains(&arr_id) {
411 self.decorate(
412 arr_id,
413 Decoration::ArrayStride,
414 [Operand::LiteralBit32(item_size)],
415 );
416 self.state.decorated_types.insert(arr_id);
417 }
418
419 let block_ty = Item::Struct(vec![arr_ty]);
420 let block_id = block_ty.id(self);
421
422 self.decorate(block_id, Decoration::Block, []);
423 self.member_decorate(
424 block_id,
425 0,
426 Decoration::Offset,
427 [Operand::LiteralBit32(memory.offset)],
428 );
429
430 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
431
432 self.debug_shared(memory.id, index);
433 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
434 self.decorate(memory.id, Decoration::Aliased, []);
435 }
436 }
437
438 fn declare_shared_memories_implicit(&mut self) {
439 let shared_memories = self.state.shared_memories.clone();
440 for (index, memory) in shared_memories {
441 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
442 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
443
444 self.debug_shared(memory.id, index);
445 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
446 }
447 }
448
449 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
450 if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
451 return;
452 }
453 let mode = convert_math_mode(modes.fp_math_mode);
454 self.capabilities.insert(Capability::FloatControls2);
455 self.decorate(
456 out_id,
457 Decoration::FPFastMathMode,
458 [Operand::FPFastMathMode(mode)],
459 );
460 }
461
462 pub fn is_uniform_block(&self) -> bool {
463 self.uniformity
464 .is_block_uniform(self.current_block.unwrap())
465 }
466}
467
468pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
469 let mut flags = FPFastMathMode::NONE;
470
471 for mode in math_mode.iter() {
472 match mode {
473 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
474 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
475 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
476 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
477 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
478 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
479 FastMath::AllowTransform => {
480 flags |= FPFastMathMode::ALLOW_CONTRACT
481 | FPFastMathMode::ALLOW_REASSOC
482 | FPFastMathMode::ALLOW_TRANSFORM
483 }
484 _ => {}
485 }
486 }
487
488 flags
489}