cubecl_spirv/
compiler.rs

1use cubecl_common::ExecutionMode;
2use cubecl_core::{
3    Metadata, WgpuCompilationOptions, ir as core,
4    post_processing::{
5        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
6        unroll::UnrollProcessor,
7    },
8    prelude::FastMath,
9};
10use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
11use cubecl_runtime::{
12    EnumSet,
13    config::{GlobalConfig, compilation::CompilationLogLevel},
14};
15use std::{
16    collections::HashSet,
17    fmt::Debug,
18    mem::take,
19    ops::{Deref, DerefMut},
20    rc::Rc,
21};
22
23use cubecl_core::{Compiler, compute::KernelDefinition};
24use rspirv::{
25    dr::{Builder, InsertPoint, Instruction, Module, Operand},
26    spirv::{self, BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
27};
28
29use crate::{
30    SpirvKernel,
31    debug::DebugInfo,
32    item::Item,
33    lookups::LookupTables,
34    target::{GLCompute, SpirvTarget},
35    transformers::{BitwiseTransform, ErfTransform},
36};
37
/// Vector width handed to `UnrollProcessor::new` when `SpirvCompiler::compile_kernel` builds the
/// optimizer pipeline.
// NOTE(review): presumably the widest vector this backend emits without unrolling — confirm.
pub const MAX_VECTORIZATION: u32 = 4;
39
/// Compiler that lowers a `KernelDefinition` to a SPIR-V module for a given `SpirvTarget`.
///
/// The type dereferences to its rspirv `Builder`, so builder methods can be called directly on
/// the compiler.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target description; `compile_kernel` works on a clone of it.
    pub target: Target,
    /// rspirv builder that accumulates the module; taken out of `self` at the end of
    /// `compile_kernel`.
    pub(crate) builder: Builder,

    /// Execution mode received by `compile` and forwarded to `CheckedIoProcessor`.
    pub mode: ExecutionMode,
    /// True when the global compilation log level is `Full` or the kernel options ask for debug
    /// symbols.
    pub debug_symbols: bool,
    /// Fast-math flags derived from the kernel options, or `NONE` when the compilation options
    /// report no fast-math support.
    pub fp_math_mode: FPFastMathMode,
    // NOTE(review): the two ids below are only copied/defaulted in the visible part of this file;
    // their users live elsewhere.
    global_invocation_id: Word,
    num_workgroups: Word,
    /// Builder index of the setup block, as returned by `setup`; `declare_function_variable`
    /// inserts into this block.
    pub setup_block: usize,
    /// Optimizer output for the kernel currently being compiled.
    pub opt: Rc<Optimizer>,
    /// `Uniformity` analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// `SharedLiveness` analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer node currently being compiled by `compile_block`, if any.
    pub current_block: Option<NodeIndex>,
    /// Optimizer nodes that `compile_block` has already processed.
    pub visited: HashSet<NodeIndex>,

    /// Capabilities collected while compiling (e.g. `FloatControls2`, explicit workgroup layout).
    pub capabilities: HashSet<Capability>,
    /// True when `fp_math_mode` is not `NONE`.
    pub float_controls: bool,
    /// Lookup tables (used builtins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map, ordered by id: how many earlier entries have extended
    /// metadata.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout built in `compile` from the binding count and the extended-metadata count.
    pub metadata: Metadata,
    /// Optional debug information state.
    pub debug_info: Option<DebugInfo>,
    /// Copy of the compilation options passed to `compile`.
    pub compilation_options: WgpuCompilationOptions,
}
64
// SAFETY: NOTE(review) — these impls are asserted manually because the struct holds at least
// three `Rc` fields (`opt`, `uniformity`, `shared_liveness`), and `Rc` is neither `Send` nor
// `Sync`. Nothing in this file proves that clones sharing those `Rc`s never end up on different
// threads; confirm the invariant callers rely on and record it here.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
67
68impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
69    fn clone(&self) -> Self {
70        Self {
71            target: self.target.clone(),
72            builder: Builder::new_from_module(self.module_ref().clone()),
73            mode: self.mode,
74            global_invocation_id: self.global_invocation_id,
75            num_workgroups: self.num_workgroups,
76            setup_block: self.setup_block,
77            opt: self.opt.clone(),
78            uniformity: self.uniformity.clone(),
79            shared_liveness: self.shared_liveness.clone(),
80            current_block: self.current_block,
81
82            capabilities: self.capabilities.clone(),
83            float_controls: self.float_controls,
84            state: self.state.clone(),
85            debug_symbols: self.debug_symbols,
86            fp_math_mode: self.fp_math_mode,
87            visited: self.visited.clone(),
88            metadata: self.metadata.clone(),
89            debug_info: self.debug_info.clone(),
90            ext_meta_pos: self.ext_meta_pos.clone(),
91            compilation_options: self.compilation_options.clone(),
92        }
93    }
94}
95
96fn debug_symbols_activated() -> bool {
97    matches!(
98        GlobalConfig::get().compilation.logger.level,
99        CompilationLogLevel::Full
100    )
101}
102
103impl<T: SpirvTarget> Default for SpirvCompiler<T> {
104    fn default() -> Self {
105        Self {
106            target: Default::default(),
107            builder: Builder::new(),
108            mode: Default::default(),
109            global_invocation_id: Default::default(),
110            num_workgroups: Default::default(),
111            capabilities: Default::default(),
112            float_controls: Default::default(),
113            state: Default::default(),
114            setup_block: Default::default(),
115            opt: Default::default(),
116            uniformity: Default::default(),
117            shared_liveness: Default::default(),
118            current_block: Default::default(),
119            debug_symbols: debug_symbols_activated(),
120            fp_math_mode: FPFastMathMode::NONE,
121            visited: Default::default(),
122            metadata: Default::default(),
123            debug_info: Default::default(),
124            ext_meta_pos: Default::default(),
125            compilation_options: Default::default(),
126        }
127    }
128}
129
130impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
131    type Target = Builder;
132
133    fn deref(&self) -> &Self::Target {
134        &self.builder
135    }
136}
137
138impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
139    fn deref_mut(&mut self) -> &mut Self::Target {
140        &mut self.builder
141    }
142}
143
144impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
145    type Representation = SpirvKernel;
146    type CompilationOptions = WgpuCompilationOptions;
147
148    fn compile(
149        &mut self,
150        value: KernelDefinition,
151        compilation_options: &Self::CompilationOptions,
152        mode: ExecutionMode,
153    ) -> Self::Representation {
154        let bindings = value.buffers.clone();
155        let scalars = value
156            .scalars
157            .iter()
158            .map(|s| (self.compile_storage_type(s.ty), s.count))
159            .collect();
160        let mut ext_meta_pos = Vec::new();
161        let mut num_ext = 0;
162
163        let mut all_meta: Vec<_> = value
164            .buffers
165            .iter()
166            .chain(value.tensor_maps.iter())
167            .map(|buf| (buf.id, buf.has_extended_meta))
168            .collect();
169        all_meta.sort_by_key(|(id, _)| *id);
170
171        let num_meta = all_meta.len();
172
173        for (_, has_extended_meta) in all_meta.iter() {
174            ext_meta_pos.push(num_ext);
175            if *has_extended_meta {
176                num_ext += 1;
177            }
178        }
179
180        self.mode = mode;
181        self.metadata = Metadata::new(num_meta as u32, num_ext);
182        self.compilation_options = compilation_options.clone();
183        self.ext_meta_pos = ext_meta_pos;
184
185        let (module, optimizer) = self.compile_kernel(value);
186        SpirvKernel {
187            module,
188            optimizer,
189            bindings,
190            scalars,
191            has_metadata: self.metadata.static_len() > 0,
192        }
193    }
194
195    fn elem_size(&self, elem: core::ElemType) -> usize {
196        elem.size()
197    }
198
199    fn extension(&self) -> &'static str {
200        "spv"
201    }
202}
203
204impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        write!(f, "spirv<{:?}>", self.target)
207    }
208}
209
210impl<Target: SpirvTarget> SpirvCompiler<Target> {
    /// Lowers `kernel` to a SPIR-V module and returns it together with a copy of the optimizer
    /// state that was used to build it.
    ///
    /// In order: derive debug and fast-math settings from the kernel options, run the optimizer
    /// pipeline, declare the entry point, emit the setup block plus the entry and return blocks,
    /// declare shared memories and builtin inputs, and finally let the target emit its
    /// module-level modes.
    ///
    /// The builder is moved out of `self` at the end, leaving a default builder in its place.
    ///
    /// # Panics
    /// Panics if any of the unwrapped builder calls (`branch`, `select_block`, `end_function`)
    /// returns an error.
    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
        let options = kernel.options.clone();

        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
        // Fast-math flags are only honoured when the compilation options report support for them.
        self.fp_math_mode = match self.compilation_options.supports_fp_fast_math {
            true => convert_math_mode(options.fp_math_mode),
            false => FPFastMathMode::NONE,
        };
        self.float_controls = self.fp_math_mode != FPFastMathMode::NONE;

        if self.float_controls {
            self.capabilities.insert(Capability::FloatControls2);
        }

        self.set_version(1, 6);

        // Work on a copy of the target: `set_modes` below needs `self` mutably. The copy is never
        // written back to `self.target`.
        let mut target = self.target.clone();

        let mut opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(BitwiseTransform)
            .with_processor(CheckedIoProcessor::new(self.mode))
            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .optimize(kernel.body.clone(), kernel.cube_dim);

        // Fetch the analyses before the optimizer is moved behind the `Rc`.
        self.uniformity = opt.analysis::<Uniformity>();
        self.shared_liveness = opt.analysis::<SharedLiveness>();
        self.opt = Rc::new(opt);

        self.init_state(kernel.clone());
        self.init_debug();

        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];

        target.set_kernel_name(options.kernel_name.clone());

        let (main, debug_setup) = self.declare_main(&options.kernel_name);

        let setup = self.id();
        self.debug_name(setup, "setup");

        let entry = self.opt.entry();
        // Label of the entry block; the setup block branches to it further down.
        let body = self.label(entry);
        let setup_block = self.setup(setup, debug_setup);
        self.setup_block = setup_block;
        self.compile_block(entry);

        let ret = self.opt.ret;
        self.compile_block(ret);

        // If a block is still selected at this point, terminate it with a branch to the return
        // block's label.
        if self.selected_block().is_some() {
            let label = self.label(ret);
            self.branch(label).unwrap();
        }

        // The setup block is terminated only now; `declare_function_variable` inserts into it
        // while the other blocks are being compiled.
        self.select_block(Some(setup_block)).unwrap();
        self.branch(body).unwrap();

        self.end_function().unwrap();

        self.declare_shared_memories();

        // Declare every builtin requested through `builtin()` as an `Input` variable decorated
        // with its `BuiltIn`, and collect the ids for the target.
        let builtins = self
            .state
            .used_builtins
            .clone()
            .into_iter()
            .map(|(builtin, (id, item))| {
                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
                self.variable(ty, Some(id), StorageClass::Input, None);
                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
                id
            })
            .collect::<Vec<_>>();

        target.set_modes(self, main, builtins, cube_dims);

        // `take` swaps in a default builder so the finished module can be moved out.
        let module = take(&mut self.builder).module();
        (module, self.opt.as_ref().clone())
    }
292
293    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
294        self.begin_block(Some(label)).unwrap();
295
296        let opt = self.opt.clone();
297        for const_arr in opt.const_arrays() {
298            self.register_const_array(const_arr);
299        }
300
301        debug_setup(self);
302
303        let setup_block = self.selected_block().unwrap();
304        self.select_block(None).unwrap();
305        setup_block
306    }
307
308    #[track_caller]
309    pub fn current_block(&self) -> BasicBlock {
310        self.opt.block(self.current_block.unwrap()).clone()
311    }
312
313    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
314        if let Some(existing) = self.state.used_builtins.get(&builtin) {
315            existing.0
316        } else {
317            let id = self.id();
318            self.state.used_builtins.insert(builtin, (id, item));
319            id
320        }
321    }
322
323    pub fn compile_block(&mut self, block: NodeIndex) {
324        if self.visited.contains(&block) {
325            return;
326        }
327        self.visited.insert(block);
328        self.current_block = Some(block);
329
330        let label = self.label(block);
331        self.begin_block(Some(label)).unwrap();
332        let block_id = self.selected_block().unwrap();
333
334        self.debug_start_block();
335
336        let operations = self.current_block().ops.borrow().clone();
337        for (_, operation) in operations {
338            self.compile_operation(operation);
339        }
340
341        let control_flow = self.current_block().control_flow.borrow().clone();
342        self.compile_control_flow(control_flow);
343
344        let current = self.selected_block();
345        self.select_block(Some(block_id)).unwrap();
346        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
347        for phi in phi {
348            let out = self.compile_variable(phi.out);
349            let ty = out.item().id(self);
350            let out_id = self.write_id(&out);
351            let entries: Vec<_> = phi
352                .entries
353                .into_iter()
354                .map(|it| {
355                    let label = self.end_label(it.block);
356                    let value = self.compile_variable(it.value);
357                    let value = self.read(&value);
358                    (value, label)
359                })
360                .collect();
361            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
362                .unwrap();
363        }
364        self.select_block(current).unwrap();
365    }
366
367    // Declare variable in the first block of the function
368    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
369        let setup = self.setup_block;
370        let id = self.id();
371        let var = Instruction::new(
372            Op::Variable,
373            Some(ty),
374            Some(id),
375            vec![Operand::StorageClass(StorageClass::Function)],
376        );
377        let current_block = self.selected_block();
378        self.select_block(Some(setup)).unwrap();
379        self.insert_into_block(InsertPoint::Begin, var).unwrap();
380        self.select_block(current_block).unwrap();
381        id
382    }
383
384    fn declare_shared_memories(&mut self) {
385        if self.compilation_options.supports_explicit_smem {
386            self.declare_shared_memories_explicit();
387        } else {
388            self.declare_shared_memories_implicit();
389        }
390    }
391
392    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
393    /// `Block`. This means they are all pointers into the same chunk of memory, with different
394    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
395    /// Alignment and total size is calculated by the driver.
396    fn declare_shared_memories_explicit(&mut self) {
397        let shared_memories = self.state.shared_memories.clone();
398        if shared_memories.is_empty() {
399            return;
400        }
401
402        self.capabilities
403            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
404
405        for (index, memory) in shared_memories {
406            let item_size = memory.item.size();
407
408            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
409            // explicit layout as well.
410            match item_size {
411                1 => {
412                    self.capabilities
413                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
414                }
415                2 => {
416                    self.capabilities
417                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
418                }
419                _ => {}
420            }
421
422            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
423            let arr_id = arr_ty.id(self);
424
425            if !self.state.decorated_types.contains(&arr_id) {
426                self.decorate(
427                    arr_id,
428                    Decoration::ArrayStride,
429                    [Operand::LiteralBit32(item_size)],
430                );
431                self.state.decorated_types.insert(arr_id);
432            }
433
434            let block_ty = Item::Struct(vec![arr_ty]);
435            let block_id = block_ty.id(self);
436
437            self.decorate(block_id, Decoration::Block, []);
438            self.member_decorate(
439                block_id,
440                0,
441                Decoration::Offset,
442                [Operand::LiteralBit32(memory.offset)],
443            );
444
445            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
446
447            self.debug_shared(memory.id, index);
448            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
449            self.decorate(memory.id, Decoration::Aliased, []);
450        }
451    }
452
453    fn declare_shared_memories_implicit(&mut self) {
454        let shared_memories = self.state.shared_memories.clone();
455        for (index, memory) in shared_memories {
456            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
457            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
458
459            self.debug_shared(memory.id, index);
460            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
461        }
462    }
463
464    pub fn declare_float_execution_modes(&mut self, main: Word) {
465        let mode = self.const_u32(self.fp_math_mode.bits());
466
467        let types = self.builder.module_ref().types_global_values.clone();
468        let scalars = types
469            .iter()
470            .filter(|inst| inst.class.opcode == Op::TypeFloat)
471            .map(|it| it.result_id.expect("OpTypeFloat always has result ID"))
472            .collect::<Vec<_>>();
473        for ty in scalars {
474            self.execution_mode(main, spirv::ExecutionMode::FPFastMathDefault, [ty, mode]);
475        }
476    }
477
478    pub fn is_uniform_block(&self) -> bool {
479        self.uniformity
480            .is_block_uniform(self.current_block.unwrap())
481    }
482}
483
484fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
485    let mut flags = FPFastMathMode::NONE;
486
487    for mode in math_mode.iter() {
488        match mode {
489            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
490            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
491            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
492            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
493            FastMath::AllowContraction => flags |= FPFastMathMode::from_bits_retain(0x10000),
494            FastMath::AllowReassociation => flags |= FPFastMathMode::from_bits_retain(0x20000),
495            FastMath::AllowTransform => {
496                flags |= FPFastMathMode::from_bits_retain(0x10000)
497                    | FPFastMathMode::from_bits_retain(0x20000)
498                    | FPFastMathMode::from_bits_retain(0x40000)
499            }
500            _ => {}
501        }
502    }
503
504    flags
505}