cubecl_spirv/compiler.rs

use cubecl_common::{ExecutionMode, backtrace::BackTrace};
use cubecl_core::{
    Metadata, WgpuCompilationOptions,
    ir::{self as core, InstructionModes},
    post_processing::{
        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
        unroll::UnrollProcessor,
    },
    prelude::{FastMath, KernelDefinition},
};
use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
use cubecl_runtime::{
    EnumSet,
    compiler::CompilationError,
    config::{GlobalConfig, compilation::CompilationLogLevel},
};
use std::{
    collections::HashSet,
    fmt::Debug,
    mem::take,
    ops::{Deref, DerefMut},
    rc::Rc,
};

use cubecl_core::Compiler;
use rspirv::{
    dr::{Builder, InsertPoint, Instruction, Module, Operand},
    spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
};

use crate::{
    SpirvKernel,
    debug::DebugInfo,
    item::Item,
    lookups::LookupTables,
    target::{GLCompute, SpirvTarget},
    transformers::{BitwiseTransform, ErfTransform},
};

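/// Maximum vector width supported by the backend; the `UnrollProcessor` is configured with this
/// value so wider operations are split during optimization.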
pub const MAX_VECTORIZATION: u32 = 4;

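/// Compiler that lowers an optimized CubeCL kernel into a SPIR-V module, wrapping an `rspirv`
/// [`Builder`] (exposed through `Deref`/`DerefMut`) together with the lookup state, analyses, and
/// compilation options needed during lowering.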
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    pub target: Target,
    pub(crate) builder: Builder,

    pub mode: ExecutionMode,
    pub debug_symbols: bool,
    global_invocation_id: Word,
    num_workgroups: Word,
    pub setup_block: usize,
    pub opt: Rc<Optimizer>,
    pub uniformity: Rc<Uniformity>,
    pub shared_liveness: Rc<SharedLiveness>,
    pub current_block: Option<NodeIndex>,
    pub visited: HashSet<NodeIndex>,

    pub capabilities: HashSet<Capability>,
    pub state: LookupTables,
    pub ext_meta_pos: Vec<u32>,
    pub metadata: Metadata,
    pub debug_info: Option<DebugInfo>,
    pub compilation_options: WgpuCompilationOptions,
}

unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}

impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
    fn clone(&self) -> Self {
        Self {
            target: self.target.clone(),
            builder: Builder::new_from_module(self.module_ref().clone()),
            mode: self.mode,
            global_invocation_id: self.global_invocation_id,
            num_workgroups: self.num_workgroups,
            setup_block: self.setup_block,
            opt: self.opt.clone(),
            uniformity: self.uniformity.clone(),
            shared_liveness: self.shared_liveness.clone(),
            current_block: self.current_block,
            capabilities: self.capabilities.clone(),
            state: self.state.clone(),
            debug_symbols: self.debug_symbols,
            visited: self.visited.clone(),
            metadata: self.metadata.clone(),
            debug_info: self.debug_info.clone(),
            ext_meta_pos: self.ext_meta_pos.clone(),
            compilation_options: self.compilation_options.clone(),
        }
    }
}

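/// Debug symbols are forced on when the global compilation logger is set to `Full`.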
fn debug_symbols_activated() -> bool {
    matches!(
        GlobalConfig::get().compilation.logger.level,
        CompilationLogLevel::Full
    )
}

impl<T: SpirvTarget> Default for SpirvCompiler<T> {
    fn default() -> Self {
        Self {
            target: Default::default(),
            builder: Builder::new(),
            mode: Default::default(),
            global_invocation_id: Default::default(),
            num_workgroups: Default::default(),
            capabilities: Default::default(),
            state: Default::default(),
            setup_block: Default::default(),
            opt: Default::default(),
            uniformity: Default::default(),
            shared_liveness: Default::default(),
            current_block: Default::default(),
            debug_symbols: debug_symbols_activated(),
            visited: Default::default(),
            metadata: Default::default(),
            debug_info: Default::default(),
            ext_meta_pos: Default::default(),
            compilation_options: Default::default(),
        }
    }
}

impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
    type Target = Builder;

    fn deref(&self) -> &Self::Target {
        &self.builder
    }
}

impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.builder
    }
}

impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
    type Representation = SpirvKernel;
    type CompilationOptions = WgpuCompilationOptions;

    fn compile(
        &mut self,
        mut value: KernelDefinition,
        compilation_options: &Self::CompilationOptions,
        mode: ExecutionMode,
    ) -> Result<Self::Representation, CompilationError> {
        let errors = value.body.pop_errors();
        if !errors.is_empty() {
            let mut reason = "Can't compile SPIR-V kernel:\n".to_string();
            for error in errors {
                reason += error.as_str();
                reason += "\n";
            }

            return Err(CompilationError::Validation {
                reason,
                backtrace: BackTrace::capture(),
            });
        }

        let bindings = value.buffers.clone();
        let scalars = value
            .scalars
            .iter()
            .map(|s| (self.compile_storage_type(s.ty), s.count))
            .collect();
        let mut ext_meta_pos = Vec::new();
        let mut num_ext = 0;

        let mut all_meta: Vec<_> = value
            .buffers
            .iter()
            .chain(value.tensor_maps.iter())
            .map(|buf| (buf.id, buf.has_extended_meta))
            .collect();
        all_meta.sort_by_key(|(id, _)| *id);

        let num_meta = all_meta.len();

        for (_, has_extended_meta) in all_meta.iter() {
            ext_meta_pos.push(num_ext);
            if *has_extended_meta {
                num_ext += 1;
            }
        }

        self.mode = mode;
        self.metadata = Metadata::new(num_meta as u32, num_ext);
        self.compilation_options = compilation_options.clone();
        self.ext_meta_pos = ext_meta_pos;

        let (module, optimizer) = self.compile_kernel(value);
        Ok(SpirvKernel {
            module,
            optimizer,
            bindings,
            scalars,
            has_metadata: self.metadata.static_len() > 0,
        })
    }

    fn elem_size(&self, elem: core::ElemType) -> usize {
        elem.size()
    }

    fn extension(&self) -> &'static str {
        "spv"
    }
}

impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "spirv<{:?}>", self.target)
    }
}

impl<Target: SpirvTarget> SpirvCompiler<Target> {
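    /// Runs the optimizer pipeline on the kernel body, emits the entry point, setup block, body
    /// blocks, shared memories and builtins, and returns the finished module together with the
    /// optimizer used to build it.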
    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
        let options = kernel.options.clone();

        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;

        self.set_version(1, 6);

        let mut target = self.target.clone();

        let mut opt = OptimizerBuilder::default()
            .with_transformer(ErfTransform)
            .with_transformer(BitwiseTransform)
            .with_processor(CheckedIoProcessor::new(self.mode))
            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
            .with_processor(SaturatingArithmeticProcessor::new(true))
            .optimize(kernel.body.clone(), kernel.cube_dim);

        self.uniformity = opt.analysis::<Uniformity>();
        self.shared_liveness = opt.analysis::<SharedLiveness>();
        self.opt = Rc::new(opt);

        self.init_state(kernel.clone());
        self.init_debug();

        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];

        target.set_kernel_name(options.kernel_name.clone());

        let (main, debug_setup) = self.declare_main(&options.kernel_name);

        let setup = self.id();
        self.debug_name(setup, "setup");

        let entry = self.opt.entry();
        let body = self.label(entry);
        let setup_block = self.setup(setup, debug_setup);
        self.setup_block = setup_block;
        self.compile_block(entry);

        let ret = self.opt.ret;
        self.compile_block(ret);

        if self.selected_block().is_some() {
            let label = self.label(ret);
            self.branch(label).unwrap();
        }

        self.select_block(Some(setup_block)).unwrap();
        self.branch(body).unwrap();

        self.end_function().unwrap();

        self.declare_shared_memories();

        let builtins = self
            .state
            .used_builtins
            .clone()
            .into_iter()
            .map(|(builtin, (id, item))| {
                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
                self.variable(ty, Some(id), StorageClass::Input, None);
                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
                id
            })
            .collect::<Vec<_>>();

        target.set_modes(self, main, builtins, cube_dims);

        let module = take(&mut self.builder).module();
        (module, self.opt.as_ref().clone())
    }

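    /// Begins the setup block: registers constant arrays and runs the debug setup closure, then
    /// deselects the block. Returns the setup block's index so function-scope variables and the
    /// branch into the body can be inserted into it later.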
    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
        self.begin_block(Some(label)).unwrap();

        let opt = self.opt.clone();
        for const_arr in opt.const_arrays() {
            self.register_const_array(const_arr);
        }

        debug_setup(self);

        let setup_block = self.selected_block().unwrap();
        self.select_block(None).unwrap();
        setup_block
    }

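    /// The optimizer [`BasicBlock`] currently being compiled. Panics if no block is active.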
    #[track_caller]
    pub fn current_block(&self) -> BasicBlock {
        self.opt.block(self.current_block.unwrap()).clone()
    }

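    /// Returns the id of a builtin input variable, allocating and registering it on first use.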
    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
        if let Some(existing) = self.state.used_builtins.get(&builtin) {
            existing.0
        } else {
            let id = self.id();
            self.state.used_builtins.insert(builtin, (id, item));
            id
        }
    }

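    /// Compiles a single optimizer block: emits its label and operations, lowers its control
    /// flow, then inserts the block's phi nodes at the start of the emitted SPIR-V block.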
    pub fn compile_block(&mut self, block: NodeIndex) {
        if self.visited.contains(&block) {
            return;
        }
        self.visited.insert(block);
        self.current_block = Some(block);

        let label = self.label(block);
        self.begin_block(Some(label)).unwrap();
        let block_id = self.selected_block().unwrap();

        self.debug_start_block();

        let operations = self.current_block().ops.borrow().clone();
        for (_, operation) in operations {
            self.compile_operation(operation);
        }

        let control_flow = self.current_block().control_flow.borrow().clone();
        self.compile_control_flow(control_flow);

        let current = self.selected_block();
        self.select_block(Some(block_id)).unwrap();
        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
        for phi in phi {
            let out = self.compile_variable(phi.out);
            let ty = out.item().id(self);
            let out_id = self.write_id(&out);
            let entries: Vec<_> = phi
                .entries
                .into_iter()
                .map(|it| {
                    let label = self.end_label(it.block);
                    let value = self.compile_variable(it.value);
                    let value = self.read(&value);
                    (value, label)
                })
                .collect();
            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
                .unwrap();
        }
        self.select_block(current).unwrap();
    }

    /// Declares a variable in the first (setup) block of the function and returns its id.
    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
        let setup = self.setup_block;
        let id = self.id();
        let var = Instruction::new(
            Op::Variable,
            Some(ty),
            Some(id),
            vec![Operand::StorageClass(StorageClass::Function)],
        );
        let current_block = self.selected_block();
        self.select_block(Some(setup)).unwrap();
        self.insert_into_block(InsertPoint::Begin, var).unwrap();
        self.select_block(current_block).unwrap();
        id
    }

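    /// Declares all shared memories, using the explicit-layout path when explicit shared memory
    /// layout is supported and the implicit path otherwise.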
    fn declare_shared_memories(&mut self) {
        if self.compilation_options.supports_explicit_smem {
            self.declare_shared_memories_explicit();
        } else {
            self.declare_shared_memories_implicit();
        }
    }

    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
    /// `Block`. This means they are all pointers into the same chunk of memory, with different
    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
    /// Alignment and total size are calculated by the driver.
    fn declare_shared_memories_explicit(&mut self) {
        let shared_arrays = self.state.shared_arrays.clone();
        let shared = self.state.shared.clone();
        if shared_arrays.is_empty() && shared.is_empty() {
            return;
        }

        self.capabilities
            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);

        for (index, memory) in shared_arrays {
            let item_size = memory.item.size();

            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
            // explicit layout as well.
            match item_size {
                1 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
                }
                2 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
                }
                _ => {}
            }

            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
            let arr_id = arr_ty.id(self);

            if !self.state.decorated_types.contains(&arr_id) {
                self.decorate(
                    arr_id,
                    Decoration::ArrayStride,
                    [Operand::LiteralBit32(item_size)],
                );
                self.state.decorated_types.insert(arr_id);
            }

            let block_ty = Item::Struct(vec![arr_ty]);
            let block_id = block_ty.id(self);

            self.decorate(block_id, Decoration::Block, []);
            self.member_decorate(
                block_id,
                0,
                Decoration::Offset,
                [Operand::LiteralBit32(memory.offset)],
            );

            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);

            self.debug_shared(memory.id, index);
            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
            self.decorate(memory.id, Decoration::Aliased, []);
        }

        for (index, memory) in shared {
            let item_size = memory.item.size();

            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
            // explicit layout as well.
            match item_size {
                1 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
                }
                2 => {
                    self.capabilities
                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
                }
                _ => {}
            }

            let block_ty = Item::Struct(vec![memory.item]);
            let block_id = block_ty.id(self);

            self.decorate(block_id, Decoration::Block, []);
            self.member_decorate(
                block_id,
                0,
                Decoration::Offset,
                [Operand::LiteralBit32(memory.offset)],
            );

            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);

            self.debug_shared(memory.id, index);
            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
            self.decorate(memory.id, Decoration::Aliased, []);
        }
    }

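    /// Without explicit layout support, each shared memory becomes its own `Workgroup` variable
    /// and the driver decides how workgroup memory is laid out.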
    fn declare_shared_memories_implicit(&mut self) {
        let shared_memories = self.state.shared_arrays.clone();
        for (index, memory) in shared_memories {
            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);

            self.debug_shared(memory.id, index);
            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
        }
        let shared = self.state.shared.clone();
        for (index, memory) in shared {
            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);

            self.debug_shared(memory.id, index);
            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
        }
    }

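    /// Decorates `out_id` with the instruction's fast-math flags when `FPFastMathMode`
    /// decorations are supported and any relaxation is requested.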
    pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
        if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
            return;
        }
        let mode = convert_math_mode(modes.fp_math_mode);
        self.capabilities.insert(Capability::FloatControls2);
        self.decorate(
            out_id,
            Decoration::FPFastMathMode,
            [Operand::FPFastMathMode(mode)],
        );
    }

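    /// Whether the current block is uniform across invocations according to the uniformity
    /// analysis.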
    pub fn is_uniform_block(&self) -> bool {
        self.uniformity
            .is_block_uniform(self.current_block.unwrap())
    }
}

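/// Maps CubeCL [`FastMath`] flags onto the corresponding SPIR-V [`FPFastMathMode`] bits;
/// `AllowTransform` also sets `ALLOW_CONTRACT` and `ALLOW_REASSOC`, which SPIR-V requires to
/// accompany `ALLOW_TRANSFORM`.
///
/// A minimal sketch of the expected mapping (assuming `EnumSet` supports `|` on `FastMath`
/// values, as with the `enumset` crate):
///
/// ```ignore
/// let flags = convert_math_mode(FastMath::NotNaN | FastMath::NotInf);
/// assert_eq!(flags, FPFastMathMode::NOT_NAN | FPFastMathMode::NOT_INF);
/// ```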
pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
    let mut flags = FPFastMathMode::NONE;

    for mode in math_mode.iter() {
        match mode {
            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
            FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
            FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
            FastMath::AllowTransform => {
                flags |= FPFastMathMode::ALLOW_CONTRACT
                    | FPFastMathMode::ALLOW_REASSOC
                    | FPFastMathMode::ALLOW_TRANSFORM
            }
            _ => {}
        }
    }

    flags
}