1use cubecl_common::ExecutionMode;
2use cubecl_core::{
3 Metadata, WgpuCompilationOptions,
4 ir::{self as core, InstructionModes},
5 post_processing::{
6 checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7 unroll::UnrollProcessor,
8 },
9 prelude::{FastMath, KernelDefinition},
10};
11use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
12use cubecl_runtime::{
13 EnumSet,
14 compiler::CompilationError,
15 config::{GlobalConfig, compilation::CompilationLogLevel},
16};
17use std::{
18 collections::HashSet,
19 fmt::Debug,
20 mem::take,
21 ops::{Deref, DerefMut},
22 rc::Rc,
23};
24
25use cubecl_core::Compiler;
26use rspirv::{
27 dr::{Builder, InsertPoint, Instruction, Module, Operand},
28 spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
29};
30
31use crate::{
32 SpirvKernel,
33 debug::DebugInfo,
34 item::Item,
35 lookups::LookupTables,
36 target::{GLCompute, SpirvTarget},
37 transformers::{BitwiseTransform, ErfTransform},
38};
39
/// Vectorization limit handed to the `UnrollProcessor` when a kernel is optimized in
/// `SpirvCompiler::compile_kernel`.
pub const MAX_VECTORIZATION: u32 = 4;
41
/// State of the SPIR-V compiler for one kernel, generic over the [`SpirvTarget`].
///
/// The compiler dereferences to the wrapped rspirv [`Builder`], so builder methods can be
/// called directly on it.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target implementation; `compile_kernel` works on a clone of it to set the kernel name
    /// and the execution modes.
    pub target: Target,
    /// rspirv builder that accumulates the module; exposed through `Deref`/`DerefMut`.
    pub(crate) builder: Builder,

    /// Execution mode passed to `compile`; forwarded to the `CheckedIoProcessor`.
    pub mode: ExecutionMode,
    /// Whether debug symbols are requested, either by the global compilation log level being
    /// `Full` or by the kernel options.
    pub debug_symbols: bool,
    // Not read or written in this part of the file beyond construction and cloning.
    global_invocation_id: Word,
    // Not read or written in this part of the file beyond construction and cloning.
    num_workgroups: Word,
    /// Builder index of the setup block; `declare_function_variable` inserts its variables
    /// at the beginning of that block.
    pub setup_block: usize,
    /// Optimizer output for the kernel being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block currently being compiled; set by `compile_block`.
    pub current_block: Option<NodeIndex>,
    /// Optimizer blocks that `compile_block` has already started compiling.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities collected while compiling.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used builtins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer / tensor map (ordered by id), the number of entries before it that have
    /// extended metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout built in `compile` from the number of bindings and the number of
    /// bindings with extended metadata.
    pub metadata: Metadata,
    /// Optional debug information state.
    pub debug_info: Option<DebugInfo>,
    /// Compilation options supplied to `compile`.
    pub compilation_options: WgpuCompilationOptions,
}
64
// SAFETY: NOTE(review): nothing visible in this file establishes why these impls are sound.
// The struct holds `Rc` handles (`opt`, `uniformity`, `shared_liveness`), which are neither
// `Send` nor `Sync` on their own, so soundness presumably relies on a compiler instance never
// being used from two threads at once and on those `Rc`s not being shared outside of it —
// TODO confirm.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
67
68impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
69 fn clone(&self) -> Self {
70 Self {
71 target: self.target.clone(),
72 builder: Builder::new_from_module(self.module_ref().clone()),
73 mode: self.mode,
74 global_invocation_id: self.global_invocation_id,
75 num_workgroups: self.num_workgroups,
76 setup_block: self.setup_block,
77 opt: self.opt.clone(),
78 uniformity: self.uniformity.clone(),
79 shared_liveness: self.shared_liveness.clone(),
80 current_block: self.current_block,
81 capabilities: self.capabilities.clone(),
82 state: self.state.clone(),
83 debug_symbols: self.debug_symbols,
84 visited: self.visited.clone(),
85 metadata: self.metadata.clone(),
86 debug_info: self.debug_info.clone(),
87 ext_meta_pos: self.ext_meta_pos.clone(),
88 compilation_options: self.compilation_options.clone(),
89 }
90 }
91}
92
93fn debug_symbols_activated() -> bool {
94 matches!(
95 GlobalConfig::get().compilation.logger.level,
96 CompilationLogLevel::Full
97 )
98}
99
100impl<T: SpirvTarget> Default for SpirvCompiler<T> {
101 fn default() -> Self {
102 Self {
103 target: Default::default(),
104 builder: Builder::new(),
105 mode: Default::default(),
106 global_invocation_id: Default::default(),
107 num_workgroups: Default::default(),
108 capabilities: Default::default(),
109 state: Default::default(),
110 setup_block: Default::default(),
111 opt: Default::default(),
112 uniformity: Default::default(),
113 shared_liveness: Default::default(),
114 current_block: Default::default(),
115 debug_symbols: debug_symbols_activated(),
116 visited: Default::default(),
117 metadata: Default::default(),
118 debug_info: Default::default(),
119 ext_meta_pos: Default::default(),
120 compilation_options: Default::default(),
121 }
122 }
123}
124
125impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
126 type Target = Builder;
127
128 fn deref(&self) -> &Self::Target {
129 &self.builder
130 }
131}
132
133impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
134 fn deref_mut(&mut self) -> &mut Self::Target {
135 &mut self.builder
136 }
137}
138
139impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
140 type Representation = SpirvKernel;
141 type CompilationOptions = WgpuCompilationOptions;
142
143 fn compile(
144 &mut self,
145 value: KernelDefinition,
146 compilation_options: &Self::CompilationOptions,
147 mode: ExecutionMode,
148 ) -> Result<Self::Representation, CompilationError> {
149 let bindings = value.buffers.clone();
150 let scalars = value
151 .scalars
152 .iter()
153 .map(|s| (self.compile_storage_type(s.ty), s.count))
154 .collect();
155 let mut ext_meta_pos = Vec::new();
156 let mut num_ext = 0;
157
158 let mut all_meta: Vec<_> = value
159 .buffers
160 .iter()
161 .chain(value.tensor_maps.iter())
162 .map(|buf| (buf.id, buf.has_extended_meta))
163 .collect();
164 all_meta.sort_by_key(|(id, _)| *id);
165
166 let num_meta = all_meta.len();
167
168 for (_, has_extended_meta) in all_meta.iter() {
169 ext_meta_pos.push(num_ext);
170 if *has_extended_meta {
171 num_ext += 1;
172 }
173 }
174
175 self.mode = mode;
176 self.metadata = Metadata::new(num_meta as u32, num_ext);
177 self.compilation_options = compilation_options.clone();
178 self.ext_meta_pos = ext_meta_pos;
179
180 let (module, optimizer) = self.compile_kernel(value);
181 Ok(SpirvKernel {
182 module,
183 optimizer,
184 bindings,
185 scalars,
186 has_metadata: self.metadata.static_len() > 0,
187 })
188 }
189
190 fn elem_size(&self, elem: core::ElemType) -> usize {
191 elem.size()
192 }
193
194 fn extension(&self) -> &'static str {
195 "spv"
196 }
197}
198
199impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 write!(f, "spirv<{:?}>", self.target)
202 }
203}
204
205impl<Target: SpirvTarget> SpirvCompiler<Target> {
    /// Compiles `kernel` into a SPIR-V [`Module`] and returns it together with a copy of the
    /// optimizer state that was built for the kernel.
    ///
    /// The builder is moved out of `self` at the end (and replaced with its `Default` value),
    /// so the module under construction no longer lives in the compiler afterwards.
    ///
    /// # Panics
    ///
    /// Panics if one of the builder operations (`branch`, `select_block`, `end_function`)
    /// returns an error.
    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
        let options = kernel.options.clone();

        // Debug symbols are enabled by either the global log level or the kernel options.
        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;

        self.set_version(1, 6);

        // Work on a copy of the target so that `self` can be passed to it mutably later on.
        let mut target = self.target.clone();

        // Run the optimizer with the transformers and processors this backend relies on.
        let mut opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(BitwiseTransform)
            .with_processor(CheckedIoProcessor::new(self.mode))
            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .optimize(kernel.body.clone(), kernel.cube_dim);

        // The analyses are fetched while `opt` is still mutable, before it is frozen in an `Rc`.
        self.uniformity = opt.analysis::<Uniformity>();
        self.shared_liveness = opt.analysis::<SharedLiveness>();
        self.opt = Rc::new(opt);

        self.init_state(kernel.clone());
        self.init_debug();

        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];

        target.set_kernel_name(options.kernel_name.clone());

        let (main, debug_setup) = self.declare_main(&options.kernel_name);

        // Id used as the label of the setup block.
        let setup = self.id();
        self.debug_name(setup, "setup");

        let entry = self.opt.entry();
        let body = self.label(entry);
        // The setup block is created first but left without a terminator; its branch to the
        // body is appended further down, once the rest of the function has been emitted.
        let setup_block = self.setup(setup, debug_setup);
        self.setup_block = setup_block;
        self.compile_block(entry);

        let ret = self.opt.ret;
        self.compile_block(ret);

        // If a block is still selected at this point, terminate it with a branch to the
        // return block's label.
        if self.selected_block().is_some() {
            let label = self.label(ret);
            self.branch(label).unwrap();
        }

        // Terminate the setup block by jumping to the entry block of the kernel body.
        self.select_block(Some(setup_block)).unwrap();
        self.branch(body).unwrap();

        self.end_function().unwrap();

        self.declare_shared_memories();

        // Declare an `Input` variable, decorated as a builtin, for every builtin registered
        // through `builtin`, and collect the variable ids for the target.
        let builtins = self
            .state
            .used_builtins
            .clone()
            .into_iter()
            .map(|(builtin, (id, item))| {
                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
                self.variable(ty, Some(id), StorageClass::Input, None);
                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
                id
            })
            .collect::<Vec<_>>();

        target.set_modes(self, main, builtins, cube_dims);

        // Move the builder out of `self` to obtain the finished module.
        let module = take(&mut self.builder).module();
        (module, self.opt.as_ref().clone())
    }
278
279 fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
280 self.begin_block(Some(label)).unwrap();
281
282 let opt = self.opt.clone();
283 for const_arr in opt.const_arrays() {
284 self.register_const_array(const_arr);
285 }
286
287 debug_setup(self);
288
289 let setup_block = self.selected_block().unwrap();
290 self.select_block(None).unwrap();
291 setup_block
292 }
293
294 #[track_caller]
295 pub fn current_block(&self) -> BasicBlock {
296 self.opt.block(self.current_block.unwrap()).clone()
297 }
298
299 pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
300 if let Some(existing) = self.state.used_builtins.get(&builtin) {
301 existing.0
302 } else {
303 let id = self.id();
304 self.state.used_builtins.insert(builtin, (id, item));
305 id
306 }
307 }
308
    /// Compiles the optimizer block `block` unless it has already been visited.
    ///
    /// Emits the block's operations and control flow, then inserts its phi nodes at the
    /// beginning of the block and restores whichever builder block was selected after the
    /// control flow was compiled.
    ///
    /// # Panics
    ///
    /// Panics if a builder operation (`begin_block`, `select_block`, `insert_phi`) fails.
    pub fn compile_block(&mut self, block: NodeIndex) {
        // Each block is compiled once at most.
        if self.visited.contains(&block) {
            return;
        }
        self.visited.insert(block);
        self.current_block = Some(block);

        let label = self.label(block);
        self.begin_block(Some(label)).unwrap();
        // Remember the builder index of this block so it can be re-selected for the phi nodes.
        let block_id = self.selected_block().unwrap();

        self.debug_start_block();

        let operations = self.current_block().ops.borrow().clone();
        for (_, operation) in operations {
            self.compile_operation(operation);
        }

        let control_flow = self.current_block().control_flow.borrow().clone();
        self.compile_control_flow(control_flow);

        // Compiling the control flow may leave a different block (or none) selected; save
        // that selection and go back to this block to emit its phi nodes.
        // NOTE(review): the phi nodes are presumably emitted last so that the end labels and
        // values of the incoming blocks are already known — confirm.
        let current = self.selected_block();
        self.select_block(Some(block_id)).unwrap();
        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
        for phi in phi {
            let out = self.compile_variable(phi.out);
            let ty = out.item().id(self);
            let out_id = self.write_id(&out);
            // Pair every incoming value with the end label of the block it comes from.
            let entries: Vec<_> = phi
                .entries
                .into_iter()
                .map(|it| {
                    let label = self.end_label(it.block);
                    let value = self.compile_variable(it.value);
                    let value = self.read(&value);
                    (value, label)
                })
                .collect();
            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
                .unwrap();
        }
        self.select_block(current).unwrap();
    }
352
353 pub fn declare_function_variable(&mut self, ty: Word) -> Word {
355 let setup = self.setup_block;
356 let id = self.id();
357 let var = Instruction::new(
358 Op::Variable,
359 Some(ty),
360 Some(id),
361 vec![Operand::StorageClass(StorageClass::Function)],
362 );
363 let current_block = self.selected_block();
364 self.select_block(Some(setup)).unwrap();
365 self.insert_into_block(InsertPoint::Begin, var).unwrap();
366 self.select_block(current_block).unwrap();
367 id
368 }
369
370 fn declare_shared_memories(&mut self) {
371 if self.compilation_options.supports_explicit_smem {
372 self.declare_shared_memories_explicit();
373 } else {
374 self.declare_shared_memories_implicit();
375 }
376 }
377
    /// Declares every registered shared memory with an explicit workgroup memory layout.
    ///
    /// Each shared memory becomes a `Workgroup` variable whose type is a single-member struct
    /// decorated `Block`; the member is the memory's array, placed at `memory.offset`, and the
    /// variable itself is decorated `Aliased`. Does nothing when no shared memory is registered.
    fn declare_shared_memories_explicit(&mut self) {
        let shared_memories = self.state.shared_memories.clone();
        // Without shared memories there is no need to request the capability at all.
        if shared_memories.is_empty() {
            return;
        }

        self.capabilities
            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);

        for (index, memory) in shared_memories {
            let item_size = memory.item.size();

            // Items of size 1 and 2 additionally need the matching 8-bit / 16-bit access
            // capability.
            match item_size {
                1 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
                }
                2 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
                }
                _ => {}
            }

            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
            let arr_id = arr_ty.id(self);

            // Decorate each distinct array type with its stride only once.
            if !self.state.decorated_types.contains(&arr_id) {
                self.decorate(
                    arr_id,
                    Decoration::ArrayStride,
                    [Operand::LiteralBit32(item_size)],
                );
                self.state.decorated_types.insert(arr_id);
            }

            // Wrap the array in a `Block` struct whose only member sits at the memory's offset.
            let block_ty = Item::Struct(vec![arr_ty]);
            let block_id = block_ty.id(self);

            self.decorate(block_id, Decoration::Block, []);
            self.member_decorate(
                block_id,
                0,
                Decoration::Offset,
                [Operand::LiteralBit32(memory.offset)],
            );

            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);

            self.debug_shared(memory.id, index);
            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
            self.decorate(memory.id, Decoration::Aliased, []);
        }
    }
438
439 fn declare_shared_memories_implicit(&mut self) {
440 let shared_memories = self.state.shared_memories.clone();
441 for (index, memory) in shared_memories {
442 let arr_ty = Item::Array(Box::new(memory.item), memory.len);
443 let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
444
445 self.debug_shared(memory.id, index);
446 self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
447 }
448 }
449
450 pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
451 if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
452 return;
453 }
454 let mode = convert_math_mode(modes.fp_math_mode);
455 self.capabilities.insert(Capability::FloatControls2);
456 self.decorate(
457 out_id,
458 Decoration::FPFastMathMode,
459 [Operand::FPFastMathMode(mode)],
460 );
461 }
462
463 pub fn is_uniform_block(&self) -> bool {
464 self.uniformity
465 .is_block_uniform(self.current_block.unwrap())
466 }
467}
468
469pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
470 let mut flags = FPFastMathMode::NONE;
471
472 for mode in math_mode.iter() {
473 match mode {
474 FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
475 FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
476 FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
477 FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
478 FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
479 FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
480 FastMath::AllowTransform => {
481 flags |= FPFastMathMode::ALLOW_CONTRACT
482 | FPFastMathMode::ALLOW_REASSOC
483 | FPFastMathMode::ALLOW_TRANSFORM
484 }
485 _ => {}
486 }
487 }
488
489 flags
490}