cubecl_spirv/
compiler.rs

1use cubecl_common::ExecutionMode;
2use cubecl_core::{
3    Metadata, WgpuCompilationOptions,
4    ir::{self as core, InstructionModes},
5    post_processing::{
6        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
7        unroll::UnrollProcessor,
8    },
9    prelude::FastMath,
10};
11use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
12use cubecl_runtime::{
13    EnumSet,
14    config::{GlobalConfig, compilation::CompilationLogLevel},
15};
16use std::{
17    collections::HashSet,
18    fmt::Debug,
19    mem::take,
20    ops::{Deref, DerefMut},
21    rc::Rc,
22};
23
24use cubecl_core::{Compiler, compute::KernelDefinition};
25use rspirv::{
26    dr::{Builder, InsertPoint, Instruction, Module, Operand},
27    spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
28};
29
30use crate::{
31    SpirvKernel,
32    debug::DebugInfo,
33    item::Item,
34    lookups::LookupTables,
35    target::{GLCompute, SpirvTarget},
36    transformers::{BitwiseTransform, ErfTransform},
37};
38
/// Vectorization limit handed to the `UnrollProcessor` when building the optimizer in
/// `SpirvCompiler::compile_kernel`.
pub const MAX_VECTORIZATION: u32 = 4;
40
41pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
42    pub target: Target,
43    pub(crate) builder: Builder,
44
45    pub mode: ExecutionMode,
46    pub debug_symbols: bool,
47    global_invocation_id: Word,
48    num_workgroups: Word,
49    pub setup_block: usize,
50    pub opt: Rc<Optimizer>,
51    pub uniformity: Rc<Uniformity>,
52    pub shared_liveness: Rc<SharedLiveness>,
53    pub current_block: Option<NodeIndex>,
54    pub visited: HashSet<NodeIndex>,
55
56    pub capabilities: HashSet<Capability>,
57    pub state: LookupTables,
58    pub ext_meta_pos: Vec<u32>,
59    pub metadata: Metadata,
60    pub debug_info: Option<DebugInfo>,
61    pub compilation_options: WgpuCompilationOptions,
62}
63
64unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
65unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
66
67impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
68    fn clone(&self) -> Self {
69        Self {
70            target: self.target.clone(),
71            builder: Builder::new_from_module(self.module_ref().clone()),
72            mode: self.mode,
73            global_invocation_id: self.global_invocation_id,
74            num_workgroups: self.num_workgroups,
75            setup_block: self.setup_block,
76            opt: self.opt.clone(),
77            uniformity: self.uniformity.clone(),
78            shared_liveness: self.shared_liveness.clone(),
79            current_block: self.current_block,
80            capabilities: self.capabilities.clone(),
81            state: self.state.clone(),
82            debug_symbols: self.debug_symbols,
83            visited: self.visited.clone(),
84            metadata: self.metadata.clone(),
85            debug_info: self.debug_info.clone(),
86            ext_meta_pos: self.ext_meta_pos.clone(),
87            compilation_options: self.compilation_options.clone(),
88        }
89    }
90}
91
92fn debug_symbols_activated() -> bool {
93    matches!(
94        GlobalConfig::get().compilation.logger.level,
95        CompilationLogLevel::Full
96    )
97}
98
99impl<T: SpirvTarget> Default for SpirvCompiler<T> {
100    fn default() -> Self {
101        Self {
102            target: Default::default(),
103            builder: Builder::new(),
104            mode: Default::default(),
105            global_invocation_id: Default::default(),
106            num_workgroups: Default::default(),
107            capabilities: Default::default(),
108            state: Default::default(),
109            setup_block: Default::default(),
110            opt: Default::default(),
111            uniformity: Default::default(),
112            shared_liveness: Default::default(),
113            current_block: Default::default(),
114            debug_symbols: debug_symbols_activated(),
115            visited: Default::default(),
116            metadata: Default::default(),
117            debug_info: Default::default(),
118            ext_meta_pos: Default::default(),
119            compilation_options: Default::default(),
120        }
121    }
122}
123
124impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
125    type Target = Builder;
126
127    fn deref(&self) -> &Self::Target {
128        &self.builder
129    }
130}
131
132impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
133    fn deref_mut(&mut self) -> &mut Self::Target {
134        &mut self.builder
135    }
136}
137
138impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
139    type Representation = SpirvKernel;
140    type CompilationOptions = WgpuCompilationOptions;
141
142    fn compile(
143        &mut self,
144        value: KernelDefinition,
145        compilation_options: &Self::CompilationOptions,
146        mode: ExecutionMode,
147    ) -> Self::Representation {
148        let bindings = value.buffers.clone();
149        let scalars = value
150            .scalars
151            .iter()
152            .map(|s| (self.compile_storage_type(s.ty), s.count))
153            .collect();
154        let mut ext_meta_pos = Vec::new();
155        let mut num_ext = 0;
156
157        let mut all_meta: Vec<_> = value
158            .buffers
159            .iter()
160            .chain(value.tensor_maps.iter())
161            .map(|buf| (buf.id, buf.has_extended_meta))
162            .collect();
163        all_meta.sort_by_key(|(id, _)| *id);
164
165        let num_meta = all_meta.len();
166
167        for (_, has_extended_meta) in all_meta.iter() {
168            ext_meta_pos.push(num_ext);
169            if *has_extended_meta {
170                num_ext += 1;
171            }
172        }
173
174        self.mode = mode;
175        self.metadata = Metadata::new(num_meta as u32, num_ext);
176        self.compilation_options = compilation_options.clone();
177        self.ext_meta_pos = ext_meta_pos;
178
179        let (module, optimizer) = self.compile_kernel(value);
180        SpirvKernel {
181            module,
182            optimizer,
183            bindings,
184            scalars,
185            has_metadata: self.metadata.static_len() > 0,
186        }
187    }
188
189    fn elem_size(&self, elem: core::ElemType) -> usize {
190        elem.size()
191    }
192
193    fn extension(&self) -> &'static str {
194        "spv"
195    }
196}
197
198impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
199    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
200        write!(f, "spirv<{:?}>", self.target)
201    }
202}
203
204impl<Target: SpirvTarget> SpirvCompiler<Target> {
205    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
206        let options = kernel.options.clone();
207
208        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
209
210        self.set_version(1, 6);
211
212        let mut target = self.target.clone();
213
214        let mut opt = OptimizerBuilder::default()
215            .with_transformer(ErfTransform)
216            .with_transformer(BitwiseTransform)
217            .with_processor(CheckedIoProcessor::new(self.mode))
218            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
219            .with_processor(SaturatingArithmeticProcessor::new(true))
220            .optimize(kernel.body.clone(), kernel.cube_dim);
221
222        self.uniformity = opt.analysis::<Uniformity>();
223        self.shared_liveness = opt.analysis::<SharedLiveness>();
224        self.opt = Rc::new(opt);
225
226        self.init_state(kernel.clone());
227        self.init_debug();
228
229        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
230
231        target.set_kernel_name(options.kernel_name.clone());
232
233        let (main, debug_setup) = self.declare_main(&options.kernel_name);
234
235        let setup = self.id();
236        self.debug_name(setup, "setup");
237
238        let entry = self.opt.entry();
239        let body = self.label(entry);
240        let setup_block = self.setup(setup, debug_setup);
241        self.setup_block = setup_block;
242        self.compile_block(entry);
243
244        let ret = self.opt.ret;
245        self.compile_block(ret);
246
247        if self.selected_block().is_some() {
248            let label = self.label(ret);
249            self.branch(label).unwrap();
250        }
251
252        self.select_block(Some(setup_block)).unwrap();
253        self.branch(body).unwrap();
254
255        self.end_function().unwrap();
256
257        self.declare_shared_memories();
258
259        let builtins = self
260            .state
261            .used_builtins
262            .clone()
263            .into_iter()
264            .map(|(builtin, (id, item))| {
265                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
266                self.variable(ty, Some(id), StorageClass::Input, None);
267                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
268                id
269            })
270            .collect::<Vec<_>>();
271
272        target.set_modes(self, main, builtins, cube_dims);
273
274        let module = take(&mut self.builder).module();
275        (module, self.opt.as_ref().clone())
276    }
277
278    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
279        self.begin_block(Some(label)).unwrap();
280
281        let opt = self.opt.clone();
282        for const_arr in opt.const_arrays() {
283            self.register_const_array(const_arr);
284        }
285
286        debug_setup(self);
287
288        let setup_block = self.selected_block().unwrap();
289        self.select_block(None).unwrap();
290        setup_block
291    }
292
293    #[track_caller]
294    pub fn current_block(&self) -> BasicBlock {
295        self.opt.block(self.current_block.unwrap()).clone()
296    }
297
298    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
299        if let Some(existing) = self.state.used_builtins.get(&builtin) {
300            existing.0
301        } else {
302            let id = self.id();
303            self.state.used_builtins.insert(builtin, (id, item));
304            id
305        }
306    }
307
308    pub fn compile_block(&mut self, block: NodeIndex) {
309        if self.visited.contains(&block) {
310            return;
311        }
312        self.visited.insert(block);
313        self.current_block = Some(block);
314
315        let label = self.label(block);
316        self.begin_block(Some(label)).unwrap();
317        let block_id = self.selected_block().unwrap();
318
319        self.debug_start_block();
320
321        let operations = self.current_block().ops.borrow().clone();
322        for (_, operation) in operations {
323            self.compile_operation(operation);
324        }
325
326        let control_flow = self.current_block().control_flow.borrow().clone();
327        self.compile_control_flow(control_flow);
328
329        let current = self.selected_block();
330        self.select_block(Some(block_id)).unwrap();
331        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
332        for phi in phi {
333            let out = self.compile_variable(phi.out);
334            let ty = out.item().id(self);
335            let out_id = self.write_id(&out);
336            let entries: Vec<_> = phi
337                .entries
338                .into_iter()
339                .map(|it| {
340                    let label = self.end_label(it.block);
341                    let value = self.compile_variable(it.value);
342                    let value = self.read(&value);
343                    (value, label)
344                })
345                .collect();
346            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
347                .unwrap();
348        }
349        self.select_block(current).unwrap();
350    }
351
352    // Declare variable in the first block of the function
353    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
354        let setup = self.setup_block;
355        let id = self.id();
356        let var = Instruction::new(
357            Op::Variable,
358            Some(ty),
359            Some(id),
360            vec![Operand::StorageClass(StorageClass::Function)],
361        );
362        let current_block = self.selected_block();
363        self.select_block(Some(setup)).unwrap();
364        self.insert_into_block(InsertPoint::Begin, var).unwrap();
365        self.select_block(current_block).unwrap();
366        id
367    }
368
369    fn declare_shared_memories(&mut self) {
370        if self.compilation_options.supports_explicit_smem {
371            self.declare_shared_memories_explicit();
372        } else {
373            self.declare_shared_memories_implicit();
374        }
375    }
376
377    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
378    /// `Block`. This means they are all pointers into the same chunk of memory, with different
379    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
380    /// Alignment and total size is calculated by the driver.
381    fn declare_shared_memories_explicit(&mut self) {
382        let shared_memories = self.state.shared_memories.clone();
383        if shared_memories.is_empty() {
384            return;
385        }
386
387        self.capabilities
388            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
389
390        for (index, memory) in shared_memories {
391            let item_size = memory.item.size();
392
393            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
394            // explicit layout as well.
395            match item_size {
396                1 => {
397                    self.capabilities
398                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
399                }
400                2 => {
401                    self.capabilities
402                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
403                }
404                _ => {}
405            }
406
407            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
408            let arr_id = arr_ty.id(self);
409
410            if !self.state.decorated_types.contains(&arr_id) {
411                self.decorate(
412                    arr_id,
413                    Decoration::ArrayStride,
414                    [Operand::LiteralBit32(item_size)],
415                );
416                self.state.decorated_types.insert(arr_id);
417            }
418
419            let block_ty = Item::Struct(vec![arr_ty]);
420            let block_id = block_ty.id(self);
421
422            self.decorate(block_id, Decoration::Block, []);
423            self.member_decorate(
424                block_id,
425                0,
426                Decoration::Offset,
427                [Operand::LiteralBit32(memory.offset)],
428            );
429
430            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
431
432            self.debug_shared(memory.id, index);
433            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
434            self.decorate(memory.id, Decoration::Aliased, []);
435        }
436    }
437
438    fn declare_shared_memories_implicit(&mut self) {
439        let shared_memories = self.state.shared_memories.clone();
440        for (index, memory) in shared_memories {
441            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
442            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
443
444            self.debug_shared(memory.id, index);
445            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
446        }
447    }
448
449    pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
450        if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
451            return;
452        }
453        let mode = convert_math_mode(modes.fp_math_mode);
454        self.capabilities.insert(Capability::FloatControls2);
455        self.decorate(
456            out_id,
457            Decoration::FPFastMathMode,
458            [Operand::FPFastMathMode(mode)],
459        );
460    }
461
462    pub fn is_uniform_block(&self) -> bool {
463        self.uniformity
464            .is_block_uniform(self.current_block.unwrap())
465    }
466}
467
468pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
469    let mut flags = FPFastMathMode::NONE;
470
471    for mode in math_mode.iter() {
472        match mode {
473            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
474            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
475            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
476            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
477            FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
478            FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
479            FastMath::AllowTransform => {
480                flags |= FPFastMathMode::ALLOW_CONTRACT
481                    | FPFastMathMode::ALLOW_REASSOC
482                    | FPFastMathMode::ALLOW_TRANSFORM
483            }
484            _ => {}
485        }
486    }
487
488    flags
489}