cubecl_spirv/
compiler.rs

1use cubecl_common::backtrace::BackTrace;
2use cubecl_core::{
3    Metadata, WgpuCompilationOptions,
4    ir::{self as core, InstructionModes},
5    post_processing::{
6        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7        unroll::UnrollProcessor,
8    },
9    prelude::{FastMath, KernelDefinition},
10    server::ExecutionMode,
11};
12use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
13use cubecl_runtime::{
14    EnumSet,
15    compiler::CompilationError,
16    config::{GlobalConfig, compilation::CompilationLogLevel},
17};
18use std::{
19    collections::HashSet,
20    fmt::Debug,
21    mem::take,
22    ops::{Deref, DerefMut},
23    rc::Rc,
24};
25
26use cubecl_core::Compiler;
27use rspirv::{
28    dr::{Builder, InsertPoint, Instruction, Module, Operand},
29    spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
30};
31
32use crate::{
33    SpirvKernel,
34    debug::DebugInfo,
35    item::Item,
36    lookups::LookupTables,
37    target::{GLCompute, SpirvTarget},
38    transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
39};
40
/// Largest vector width handed to the `UnrollProcessor` when optimizing kernels for SPIR-V.
pub const MAX_VECTORIZATION: u32 = 4;
42
43pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
44    pub target: Target,
45    pub(crate) builder: Builder,
46
47    pub mode: ExecutionMode,
48    pub debug_symbols: bool,
49    global_invocation_id: Word,
50    num_workgroups: Word,
51    pub setup_block: usize,
52    pub opt: Rc<Optimizer>,
53    pub uniformity: Rc<Uniformity>,
54    pub shared_liveness: Rc<SharedLiveness>,
55    pub current_block: Option<NodeIndex>,
56    pub visited: HashSet<NodeIndex>,
57
58    pub capabilities: HashSet<Capability>,
59    pub state: LookupTables,
60    pub ext_meta_pos: Vec<u32>,
61    pub metadata: Metadata,
62    pub debug_info: Option<DebugInfo>,
63    pub compilation_options: WgpuCompilationOptions,
64}
65
66unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
67unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
68
69impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
70    fn clone(&self) -> Self {
71        Self {
72            target: self.target.clone(),
73            builder: Builder::new_from_module(self.module_ref().clone()),
74            mode: self.mode,
75            global_invocation_id: self.global_invocation_id,
76            num_workgroups: self.num_workgroups,
77            setup_block: self.setup_block,
78            opt: self.opt.clone(),
79            uniformity: self.uniformity.clone(),
80            shared_liveness: self.shared_liveness.clone(),
81            current_block: self.current_block,
82            capabilities: self.capabilities.clone(),
83            state: self.state.clone(),
84            debug_symbols: self.debug_symbols,
85            visited: self.visited.clone(),
86            metadata: self.metadata.clone(),
87            debug_info: self.debug_info.clone(),
88            ext_meta_pos: self.ext_meta_pos.clone(),
89            compilation_options: self.compilation_options.clone(),
90        }
91    }
92}
93
94fn debug_symbols_activated() -> bool {
95    matches!(
96        GlobalConfig::get().compilation.logger.level,
97        CompilationLogLevel::Full
98    )
99}
100
101impl<T: SpirvTarget> Default for SpirvCompiler<T> {
102    fn default() -> Self {
103        Self {
104            target: Default::default(),
105            builder: Builder::new(),
106            mode: Default::default(),
107            global_invocation_id: Default::default(),
108            num_workgroups: Default::default(),
109            capabilities: Default::default(),
110            state: Default::default(),
111            setup_block: Default::default(),
112            opt: Default::default(),
113            uniformity: Default::default(),
114            shared_liveness: Default::default(),
115            current_block: Default::default(),
116            debug_symbols: debug_symbols_activated(),
117            visited: Default::default(),
118            metadata: Default::default(),
119            debug_info: Default::default(),
120            ext_meta_pos: Default::default(),
121            compilation_options: Default::default(),
122        }
123    }
124}
125
126impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
127    type Target = Builder;
128
129    fn deref(&self) -> &Self::Target {
130        &self.builder
131    }
132}
133
134impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
135    fn deref_mut(&mut self) -> &mut Self::Target {
136        &mut self.builder
137    }
138}
139
140impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
141    type Representation = SpirvKernel;
142    type CompilationOptions = WgpuCompilationOptions;
143
144    fn compile(
145        &mut self,
146        mut value: KernelDefinition,
147        compilation_options: &Self::CompilationOptions,
148        mode: ExecutionMode,
149    ) -> Result<Self::Representation, CompilationError> {
150        let errors = value.body.pop_errors();
151        if !errors.is_empty() {
152            let mut reason = "Can't compile spirv kernel".to_string();
153            for error in errors {
154                reason += error.as_str();
155                reason += "\n";
156            }
157
158            return Err(CompilationError::Validation {
159                reason,
160                backtrace: BackTrace::capture(),
161            });
162        }
163
164        let bindings = value.buffers.clone();
165        let scalars = value
166            .scalars
167            .iter()
168            .map(|s| (self.compile_storage_type(s.ty), s.count))
169            .collect();
170        let mut ext_meta_pos = Vec::new();
171        let mut num_ext = 0;
172
173        let mut all_meta: Vec<_> = value
174            .buffers
175            .iter()
176            .chain(value.tensor_maps.iter())
177            .map(|buf| (buf.id, buf.has_extended_meta))
178            .collect();
179        all_meta.sort_by_key(|(id, _)| *id);
180
181        let num_meta = all_meta.len();
182
183        for (_, has_extended_meta) in all_meta.iter() {
184            ext_meta_pos.push(num_ext);
185            if *has_extended_meta {
186                num_ext += 1;
187            }
188        }
189
190        self.mode = mode;
191        self.metadata = Metadata::new(num_meta as u32, num_ext);
192        self.compilation_options = compilation_options.clone();
193        self.ext_meta_pos = ext_meta_pos;
194
195        let (module, optimizer) = self.compile_kernel(value);
196        Ok(SpirvKernel {
197            module,
198            optimizer,
199            bindings,
200            scalars,
201            has_metadata: self.metadata.static_len() > 0,
202        })
203    }
204
205    fn elem_size(&self, elem: core::ElemType) -> usize {
206        elem.size()
207    }
208
209    fn extension(&self) -> &'static str {
210        "spv"
211    }
212}
213
214impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
215    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
216        write!(f, "spirv<{:?}>", self.target)
217    }
218}
219
220impl<Target: SpirvTarget> SpirvCompiler<Target> {
221    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
222        let options = kernel.options.clone();
223
224        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
225
226        self.set_version(1, 6);
227
228        let mut target = self.target.clone();
229
230        let mut opt = OptimizerBuilder::default()
231            .with_transformer(ErfTransform)
232            .with_transformer(BitwiseTransform)
233            .with_transformer(HypotTransform)
234            .with_transformer(RhypotTransform)
235            .with_processor(CheckedIoProcessor::new(self.mode))
236            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
237            .with_processor(SaturatingArithmeticProcessor::new(true))
238            .optimize(kernel.body.clone(), kernel.cube_dim);
239
240        self.uniformity = opt.analysis::<Uniformity>();
241        self.shared_liveness = opt.analysis::<SharedLiveness>();
242        self.opt = Rc::new(opt);
243
244        self.init_state(kernel.clone());
245        self.init_debug();
246
247        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
248
249        target.set_kernel_name(options.kernel_name.clone());
250
251        let (main, debug_setup) = self.declare_main(&options.kernel_name);
252
253        let setup = self.id();
254        self.debug_name(setup, "setup");
255
256        let entry = self.opt.entry();
257        let body = self.label(entry);
258        let setup_block = self.setup(setup, debug_setup);
259        self.setup_block = setup_block;
260        self.compile_block(entry);
261
262        let ret = self.opt.ret;
263        self.compile_block(ret);
264
265        if self.selected_block().is_some() {
266            let label = self.label(ret);
267            self.branch(label).unwrap();
268        }
269
270        self.select_block(Some(setup_block)).unwrap();
271        self.branch(body).unwrap();
272
273        self.end_function().unwrap();
274
275        self.declare_shared_memories();
276
277        let builtins = self
278            .state
279            .used_builtins
280            .clone()
281            .into_iter()
282            .map(|(builtin, (id, item))| {
283                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
284                self.variable(ty, Some(id), StorageClass::Input, None);
285                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
286                id
287            })
288            .collect::<Vec<_>>();
289
290        target.set_modes(self, main, builtins, cube_dims);
291
292        let module = take(&mut self.builder).module();
293        (module, self.opt.as_ref().clone())
294    }
295
296    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
297        self.begin_block(Some(label)).unwrap();
298
299        let opt = self.opt.clone();
300        for const_arr in opt.const_arrays() {
301            self.register_const_array(const_arr);
302        }
303
304        debug_setup(self);
305
306        let setup_block = self.selected_block().unwrap();
307        self.select_block(None).unwrap();
308        setup_block
309    }
310
311    #[track_caller]
312    pub fn current_block(&self) -> BasicBlock {
313        self.opt.block(self.current_block.unwrap()).clone()
314    }
315
316    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
317        if let Some(existing) = self.state.used_builtins.get(&builtin) {
318            existing.0
319        } else {
320            let id = self.id();
321            self.state.used_builtins.insert(builtin, (id, item));
322            id
323        }
324    }
325
326    pub fn compile_block(&mut self, block: NodeIndex) {
327        if self.visited.contains(&block) {
328            return;
329        }
330        self.visited.insert(block);
331        self.current_block = Some(block);
332
333        let label = self.label(block);
334        self.begin_block(Some(label)).unwrap();
335        let block_id = self.selected_block().unwrap();
336
337        self.debug_start_block();
338
339        let operations = self.current_block().ops.borrow().clone();
340        for (_, operation) in operations {
341            self.compile_operation(operation);
342        }
343
344        let control_flow = self.current_block().control_flow.borrow().clone();
345        self.compile_control_flow(control_flow);
346
347        let current = self.selected_block();
348        self.select_block(Some(block_id)).unwrap();
349        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
350        for phi in phi {
351            let out = self.compile_variable(phi.out);
352            let ty = out.item().id(self);
353            let out_id = self.write_id(&out);
354            let entries: Vec<_> = phi
355                .entries
356                .into_iter()
357                .map(|it| {
358                    let label = self.end_label(it.block);
359                    let value = self.compile_variable(it.value);
360                    let value = self.read(&value);
361                    (value, label)
362                })
363                .collect();
364            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
365                .unwrap();
366        }
367        self.select_block(current).unwrap();
368    }
369
370    // Declare variable in the first block of the function
371    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
372        let setup = self.setup_block;
373        let id = self.id();
374        let var = Instruction::new(
375            Op::Variable,
376            Some(ty),
377            Some(id),
378            vec![Operand::StorageClass(StorageClass::Function)],
379        );
380        let current_block = self.selected_block();
381        self.select_block(Some(setup)).unwrap();
382        self.insert_into_block(InsertPoint::Begin, var).unwrap();
383        self.select_block(current_block).unwrap();
384        id
385    }
386
387    fn declare_shared_memories(&mut self) {
388        if self.compilation_options.supports_explicit_smem {
389            self.declare_shared_memories_explicit();
390        } else {
391            self.declare_shared_memories_implicit();
392        }
393    }
394
395    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
396    /// `Block`. This means they are all pointers into the same chunk of memory, with different
397    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
398    /// Alignment and total size is calculated by the driver.
399    fn declare_shared_memories_explicit(&mut self) {
400        let shared_arrays = self.state.shared_arrays.clone();
401        let shared = self.state.shared.clone();
402        if shared_arrays.is_empty() && shared.is_empty() {
403            return;
404        }
405
406        self.capabilities
407            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
408
409        for (index, memory) in shared_arrays {
410            let item_size = memory.item.size();
411
412            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
413            // explicit layout as well.
414            match item_size {
415                1 => {
416                    self.capabilities
417                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
418                }
419                2 => {
420                    self.capabilities
421                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
422                }
423                _ => {}
424            }
425
426            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
427            let arr_id = arr_ty.id(self);
428
429            if !self.state.decorated_types.contains(&arr_id) {
430                self.decorate(
431                    arr_id,
432                    Decoration::ArrayStride,
433                    [Operand::LiteralBit32(item_size)],
434                );
435                self.state.decorated_types.insert(arr_id);
436            }
437
438            let block_ty = Item::Struct(vec![arr_ty]);
439            let block_id = block_ty.id(self);
440
441            self.decorate(block_id, Decoration::Block, []);
442            self.member_decorate(
443                block_id,
444                0,
445                Decoration::Offset,
446                [Operand::LiteralBit32(memory.offset)],
447            );
448
449            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
450
451            self.debug_shared(memory.id, index);
452            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
453            self.decorate(memory.id, Decoration::Aliased, []);
454        }
455
456        for (index, memory) in shared {
457            let item_size = memory.item.size();
458
459            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
460            // explicit layout as well.
461            match item_size {
462                1 => {
463                    self.capabilities
464                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
465                }
466                2 => {
467                    self.capabilities
468                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
469                }
470                _ => {}
471            }
472
473            let block_ty = Item::Struct(vec![memory.item]);
474            let block_id = block_ty.id(self);
475
476            self.decorate(block_id, Decoration::Block, []);
477            self.member_decorate(
478                block_id,
479                0,
480                Decoration::Offset,
481                [Operand::LiteralBit32(memory.offset)],
482            );
483
484            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
485
486            self.debug_shared(memory.id, index);
487            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
488            self.decorate(memory.id, Decoration::Aliased, []);
489        }
490    }
491
492    fn declare_shared_memories_implicit(&mut self) {
493        let shared_memories = self.state.shared_arrays.clone();
494        for (index, memory) in shared_memories {
495            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
496            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
497
498            self.debug_shared(memory.id, index);
499            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
500        }
501        let shared = self.state.shared.clone();
502        for (index, memory) in shared {
503            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
504
505            self.debug_shared(memory.id, index);
506            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
507        }
508    }
509
510    pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
511        if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
512            return;
513        }
514        let mode = convert_math_mode(modes.fp_math_mode);
515        self.capabilities.insert(Capability::FloatControls2);
516        self.decorate(
517            out_id,
518            Decoration::FPFastMathMode,
519            [Operand::FPFastMathMode(mode)],
520        );
521    }
522
523    pub fn is_uniform_block(&self) -> bool {
524        self.uniformity
525            .is_block_uniform(self.current_block.unwrap())
526    }
527}
528
529pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
530    let mut flags = FPFastMathMode::NONE;
531
532    for mode in math_mode.iter() {
533        match mode {
534            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
535            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
536            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
537            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
538            FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
539            FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
540            FastMath::AllowTransform => {
541                flags |= FPFastMathMode::ALLOW_CONTRACT
542                    | FPFastMathMode::ALLOW_REASSOC
543                    | FPFastMathMode::ALLOW_TRANSFORM
544            }
545            _ => {}
546        }
547    }
548
549    flags
550}