Skip to main content

cubecl_spirv/
compiler.rs

1use crate::{
2    SpirvKernel,
3    debug::DebugInfo,
4    item::Item,
5    lookups::LookupTables,
6    target::{GLCompute, SpirvTarget},
7    transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
8};
9use cubecl_common::backtrace::BackTrace;
10use cubecl_core::{
11    Compiler, CubeDim, Info, Metadata, WgpuCompilationOptions,
12    ir::{self as core, ElemType, InstructionModes, StorageType, UIntKind, features::EnumSet},
13    post_processing::{
14        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
15        unroll::UnrollProcessor,
16    },
17    prelude::{FastMath, KernelDefinition},
18    server::ExecutionMode,
19};
20use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
21use cubecl_runtime::{
22    compiler::CompilationError,
23    config::{GlobalConfig, compilation::CompilationLogLevel},
24};
25use rspirv::{
26    binary::Assemble,
27    dr::{Builder, InsertPoint, Instruction, Module, Operand},
28    spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
29};
30use std::{
31    collections::HashSet,
32    fmt::Debug,
33    mem::take,
34    ops::{Deref, DerefMut},
35    rc::Rc,
36    sync::Arc,
37};
38
/// Vectorization limit handed to `UnrollProcessor` when a kernel is optimized.
// NOTE(review): presumably vectors wider than this get split by the unroll pass — confirm
// against `UnrollProcessor`.
pub const MAX_VECTORIZATION: usize = 4;
40
/// Compiler that turns a CubeCL `KernelDefinition` into a SPIR-V module.
///
/// The compiler dereferences to its rspirv [`Builder`], so builder methods can be called
/// directly on it.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target-specific behaviour (kernel naming, execution modes, info storage class).
    pub target: Target,
    /// rspirv builder the module is emitted into.
    pub(crate) builder: Builder,

    /// Cube dimensions of the kernel being compiled, copied from its definition.
    pub cube_dim: CubeDim,
    /// Execution mode given to `compile`; forwarded to the checked-IO processor.
    pub mode: ExecutionMode,
    /// Address storage type given to `compile`.
    pub addr_type: StorageType,
    /// Debug-symbol flag: set when the compilation log level is `Full` or when the kernel's
    /// options request it.
    pub debug_symbols: bool,
    // NOTE(review): not read in this part of the file; presumably the id of the matching
    // built-in — confirm.
    global_invocation_id: Word,
    // NOTE(review): not read in this part of the file; presumably the id of the matching
    // built-in — confirm.
    num_workgroups: Word,
    /// Index of the setup block, which is where function-scope variables are inserted.
    pub setup_block: usize,
    /// Optimizer holding the optimized kernel (blocks, const arrays, entry and return nodes).
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// Shared-liveness analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer node currently being compiled, if any.
    pub current_block: Option<NodeIndex>,
    /// Optimizer nodes that were already compiled, so each block is emitted only once.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested so far.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (built-ins in use, shared memory declarations, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer / tensor map in id order, the number of earlier bindings that carry
    /// extended metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Info built from the kernel's scalars, its metadata layout and the address type.
    pub info: Info,
    /// Optional debug-info state.
    pub debug_info: Option<DebugInfo>,
    /// Compilation options given to the most recent `compile` call.
    pub compilation_options: WgpuCompilationOptions,
}
65
// SAFETY: NOTE(review): `SpirvCompiler` stores `Rc` handles (`opt`, `uniformity`,
// `shared_liveness`), which are neither `Send` nor `Sync` on their own. These impls therefore
// rely on callers never sharing those handles across threads while a compiler is in use;
// nothing in this file enforces that — confirm the invariant holds at every use site.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
68
69impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
70    fn clone(&self) -> Self {
71        Self {
72            target: self.target.clone(),
73            builder: Builder::new_from_module(self.module_ref().clone()),
74            cube_dim: self.cube_dim,
75            mode: self.mode,
76            addr_type: self.addr_type,
77            global_invocation_id: self.global_invocation_id,
78            num_workgroups: self.num_workgroups,
79            setup_block: self.setup_block,
80            opt: self.opt.clone(),
81            uniformity: self.uniformity.clone(),
82            shared_liveness: self.shared_liveness.clone(),
83            current_block: self.current_block,
84            capabilities: self.capabilities.clone(),
85            state: self.state.clone(),
86            debug_symbols: self.debug_symbols,
87            visited: self.visited.clone(),
88            info: self.info.clone(),
89            debug_info: self.debug_info.clone(),
90            ext_meta_pos: self.ext_meta_pos.clone(),
91            compilation_options: self.compilation_options,
92        }
93    }
94}
95
96fn debug_symbols_activated() -> bool {
97    matches!(
98        GlobalConfig::get().compilation.logger.level,
99        CompilationLogLevel::Full
100    )
101}
102
103impl<T: SpirvTarget> Default for SpirvCompiler<T> {
104    fn default() -> Self {
105        Self {
106            target: Default::default(),
107            builder: Builder::new(),
108            cube_dim: CubeDim::new_single(),
109            mode: Default::default(),
110            addr_type: ElemType::UInt(UIntKind::U32).into(),
111            global_invocation_id: Default::default(),
112            num_workgroups: Default::default(),
113            capabilities: Default::default(),
114            state: Default::default(),
115            setup_block: Default::default(),
116            opt: Default::default(),
117            uniformity: Default::default(),
118            shared_liveness: Default::default(),
119            current_block: Default::default(),
120            debug_symbols: debug_symbols_activated(),
121            visited: Default::default(),
122            info: Default::default(),
123            debug_info: Default::default(),
124            ext_meta_pos: Default::default(),
125            compilation_options: Default::default(),
126        }
127    }
128}
129
/// Exposes the wrapped rspirv `Builder` so its `&self` methods can be called directly on the
/// compiler.
impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
    type Target = Builder;

    fn deref(&self) -> &Self::Target {
        &self.builder
    }
}
137
/// Exposes the wrapped rspirv `Builder` mutably so its `&mut self` methods can be called
/// directly on the compiler.
impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.builder
    }
}
143
144impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
145    type Representation = SpirvKernel;
146    type CompilationOptions = WgpuCompilationOptions;
147
148    fn compile(
149        &mut self,
150        mut value: KernelDefinition,
151        compilation_options: &Self::CompilationOptions,
152        mode: ExecutionMode,
153        addr_type: StorageType,
154    ) -> Result<Self::Representation, CompilationError> {
155        let errors = value.body.pop_errors();
156        if !errors.is_empty() {
157            let mut reason = "Can't compile spirv kernel".to_string();
158            for error in errors {
159                reason += error.as_str();
160                reason += "\n";
161            }
162
163            return Err(CompilationError::Validation {
164                reason,
165                backtrace: BackTrace::capture(),
166            });
167        }
168
169        let bindings = value.buffers.clone();
170        let mut ext_meta_pos = Vec::new();
171        let mut num_ext = 0;
172
173        let mut all_meta: Vec<_> = value
174            .buffers
175            .iter()
176            .chain(value.tensor_maps.iter())
177            .map(|buf| (buf.id, buf.has_extended_meta))
178            .collect();
179        all_meta.sort_by_key(|(id, _)| *id);
180
181        let num_meta = all_meta.len();
182
183        for (_, has_extended_meta) in all_meta.iter() {
184            ext_meta_pos.push(num_ext);
185            if *has_extended_meta {
186                num_ext += 1;
187            }
188        }
189
190        let metadata = Metadata::new(num_meta as u32, num_ext);
191
192        self.cube_dim = value.cube_dim;
193        self.mode = mode;
194        self.addr_type = addr_type;
195        self.info = Info::new(&value.scalars, metadata, addr_type);
196        self.compilation_options = *compilation_options;
197        self.ext_meta_pos = ext_meta_pos;
198
199        let (module, optimizer, shared_size) = self.compile_kernel(value);
200        let uniform_info = matches!(T::info_storage_class(self), StorageClass::Uniform);
201
202        Ok(SpirvKernel {
203            assembled_module: module.assemble(),
204            module: Some(Arc::new(module)),
205            optimizer: Some(Arc::new(optimizer)),
206            bindings: bindings.iter().map(|it| it.visibility).collect(),
207            shared_size,
208            uniform_info,
209        })
210    }
211
212    fn elem_size(&self, elem: core::ElemType) -> usize {
213        elem.size()
214    }
215
216    fn extension(&self) -> &'static str {
217        "spv"
218    }
219}
220
221impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        write!(f, "spirv<{:?}>", self.target)
224    }
225}
226
227impl<Target: SpirvTarget> SpirvCompiler<Target> {
228    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer, usize) {
229        let options = kernel.options.clone();
230
231        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
232
233        let version = self.compilation_options.vulkan.max_spirv_version;
234        self.set_version(version.0, version.1);
235
236        let mut target = self.target.clone();
237
238        let mut opt = OptimizerBuilder::default()
239            .with_transformer(ErfTransform)
240            .with_transformer(BitwiseTransform::new(
241                self.compilation_options.vulkan.supports_arbitrary_bitwise,
242            ))
243            .with_transformer(HypotTransform)
244            .with_transformer(RhypotTransform)
245            .with_processor(CheckedIoProcessor::new(
246                self.mode,
247                kernel.options.kernel_name.clone(),
248            ))
249            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
250            .with_processor(SaturatingArithmeticProcessor::new(true))
251            .optimize(kernel.body.clone(), kernel.cube_dim);
252
253        self.uniformity = opt.analysis::<Uniformity>();
254        self.shared_liveness = opt.analysis::<SharedLiveness>();
255        self.opt = Rc::new(opt);
256
257        self.init_state(kernel.clone());
258        self.init_debug();
259
260        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
261
262        target.set_kernel_name(options.kernel_name.clone());
263
264        let (main, debug_setup) = self.declare_main(&options.kernel_name);
265
266        let setup = self.id();
267        self.debug_name(setup, "setup");
268
269        let entry = self.opt.entry();
270        let body = self.label(entry);
271        let setup_block = self.setup(setup, debug_setup);
272        self.setup_block = setup_block;
273        self.compile_block(entry);
274
275        let ret = self.opt.ret;
276        self.compile_block(ret);
277
278        if self.selected_block().is_some() {
279            let label = self.label(ret);
280            self.branch(label).unwrap();
281        }
282
283        self.select_block(Some(setup_block)).unwrap();
284        self.branch(body).unwrap();
285
286        self.end_function().unwrap();
287
288        let shared_size = self.declare_shared_memories();
289
290        let builtins = self
291            .state
292            .used_builtins
293            .clone()
294            .into_iter()
295            .map(|(builtin, (id, item))| {
296                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
297                self.variable(ty, Some(id), StorageClass::Input, None);
298                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
299                id
300            })
301            .collect::<Vec<_>>();
302
303        target.set_modes(self, main, builtins, cube_dims);
304
305        let module = take(&mut self.builder).module();
306        (module, self.opt.as_ref().clone(), shared_size)
307    }
308
309    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
310        self.begin_block(Some(label)).unwrap();
311
312        let opt = self.opt.clone();
313        for const_arr in opt.const_arrays() {
314            self.register_const_array(const_arr);
315        }
316
317        debug_setup(self);
318
319        let setup_block = self.selected_block().unwrap();
320        self.select_block(None).unwrap();
321        setup_block
322    }
323
324    #[track_caller]
325    pub fn current_block(&self) -> BasicBlock {
326        self.opt.block(self.current_block.unwrap()).clone()
327    }
328
329    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
330        if let Some(existing) = self.state.used_builtins.get(&builtin) {
331            existing.0
332        } else {
333            let id = self.id();
334            self.state.used_builtins.insert(builtin, (id, item));
335            id
336        }
337    }
338
339    pub fn compile_block(&mut self, block: NodeIndex) {
340        if self.visited.contains(&block) {
341            return;
342        }
343        self.visited.insert(block);
344        self.current_block = Some(block);
345
346        let label = self.label(block);
347        self.begin_block(Some(label)).unwrap();
348        let block_id = self.selected_block().unwrap();
349
350        self.debug_start_block();
351
352        let operations = self.current_block().ops.borrow().clone();
353        for (_, operation) in operations {
354            self.compile_operation(operation);
355        }
356
357        let control_flow = self.current_block().control_flow.borrow().clone();
358        self.compile_control_flow(control_flow);
359
360        let current = self.selected_block();
361        self.select_block(Some(block_id)).unwrap();
362        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
363        for phi in phi {
364            let out = self.compile_variable(phi.out);
365            let ty = out.item().id(self);
366            let out_id = self.write_id(&out);
367            let entries: Vec<_> = phi
368                .entries
369                .into_iter()
370                .map(|it| {
371                    let label = self.end_label(it.block);
372                    let value = self.compile_variable(it.value);
373                    let value = self.read(&value);
374                    (value, label)
375                })
376                .collect();
377            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
378                .unwrap();
379        }
380        self.select_block(current).unwrap();
381    }
382
383    // Declare variable in the first block of the function
384    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
385        let setup = self.setup_block;
386        let id = self.id();
387        let var = Instruction::new(
388            Op::Variable,
389            Some(ty),
390            Some(id),
391            vec![Operand::StorageClass(StorageClass::Function)],
392        );
393        let current_block = self.selected_block();
394        self.select_block(Some(setup)).unwrap();
395        self.insert_into_block(InsertPoint::Begin, var).unwrap();
396        self.select_block(current_block).unwrap();
397        id
398    }
399
400    fn declare_shared_memories(&mut self) -> usize {
401        if self.compilation_options.vulkan.supports_explicit_smem {
402            self.declare_shared_memories_explicit() as usize
403        } else {
404            self.declare_shared_memories_implicit() as usize
405        }
406    }
407
408    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
409    /// `Block`. This means they are all pointers into the same chunk of memory, with different
410    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
411    /// Alignment and total size is calculated by the driver.
412    fn declare_shared_memories_explicit(&mut self) -> u32 {
413        let mut shared_size = 0;
414
415        let shared_arrays = self.state.shared_arrays.clone();
416        let shared = self.state.shared.clone();
417        if shared_arrays.is_empty() && shared.is_empty() {
418            return shared_size;
419        }
420
421        self.capabilities
422            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
423
424        for (index, memory) in shared_arrays {
425            let item_size = memory.item.size();
426            shared_size = shared_size.max(memory.offset + memory.len * item_size);
427
428            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
429            // explicit layout as well.
430            match item_size {
431                1 => {
432                    self.capabilities
433                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
434                }
435                2 => {
436                    self.capabilities
437                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
438                }
439                _ => {}
440            }
441
442            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
443            let arr_id = arr_ty.id(self);
444
445            if !self.state.decorated_types.contains(&arr_id) {
446                self.decorate(
447                    arr_id,
448                    Decoration::ArrayStride,
449                    [Operand::LiteralBit32(item_size)],
450                );
451                self.state.decorated_types.insert(arr_id);
452            }
453
454            let block_ty = Item::Struct(vec![arr_ty]);
455            let block_id = block_ty.id(self);
456
457            self.decorate(block_id, Decoration::Block, []);
458            self.member_decorate(
459                block_id,
460                0,
461                Decoration::Offset,
462                [Operand::LiteralBit32(memory.offset)],
463            );
464
465            let ptr_ty = self.type_pointer(None, StorageClass::Workgroup, block_id);
466
467            self.debug_shared(memory.id, index);
468            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
469            self.decorate(memory.id, Decoration::Aliased, []);
470        }
471
472        for (index, memory) in shared {
473            let item_size = memory.item.size();
474            shared_size = shared_size.max(memory.offset + item_size);
475
476            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
477            // explicit layout as well.
478            match item_size {
479                1 => {
480                    self.capabilities
481                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
482                }
483                2 => {
484                    self.capabilities
485                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
486                }
487                _ => {}
488            }
489
490            let block_ty = Item::Struct(vec![memory.item]);
491            let block_id = block_ty.id(self);
492
493            self.decorate(block_id, Decoration::Block, []);
494            self.member_decorate(
495                block_id,
496                0,
497                Decoration::Offset,
498                [Operand::LiteralBit32(memory.offset)],
499            );
500
501            let ptr_ty = self.type_pointer(None, StorageClass::Workgroup, block_id);
502
503            self.debug_shared(memory.id, index);
504            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
505            self.decorate(memory.id, Decoration::Aliased, []);
506        }
507
508        shared_size
509    }
510
511    fn declare_shared_memories_implicit(&mut self) -> u32 {
512        let mut shared_size = 0;
513        let shared_memories = self.state.shared_arrays.clone();
514        for (index, memory) in shared_memories {
515            shared_size += memory.len * memory.item.size();
516
517            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
518            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
519
520            self.debug_shared(memory.id, index);
521            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
522        }
523        let shared = self.state.shared.clone();
524        for (index, memory) in shared {
525            shared_size += memory.item.size();
526
527            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
528
529            self.debug_shared(memory.id, index);
530            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
531        }
532        shared_size
533    }
534
535    pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
536        if !self.compilation_options.vulkan.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
537            return;
538        }
539        let mode = convert_math_mode(modes.fp_math_mode);
540        self.capabilities.insert(Capability::FloatControls2);
541        self.decorate(
542            out_id,
543            Decoration::FPFastMathMode,
544            [Operand::FPFastMathMode(mode)],
545        );
546    }
547
548    pub fn is_uniform_block(&self) -> bool {
549        self.uniformity
550            .is_block_uniform(self.current_block.unwrap())
551    }
552}
553
554pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
555    let mut flags = FPFastMathMode::NONE;
556
557    for mode in math_mode.iter() {
558        match mode {
559            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
560            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
561            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
562            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
563            FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
564            FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
565            FastMath::AllowTransform => {
566                flags |= FPFastMathMode::ALLOW_CONTRACT
567                    | FPFastMathMode::ALLOW_REASSOC
568                    | FPFastMathMode::ALLOW_TRANSFORM
569            }
570            _ => {}
571        }
572    }
573
574    flags
575}