cubecl_spirv/
compiler.rs

1use crate::{
2    SpirvKernel,
3    debug::DebugInfo,
4    item::Item,
5    lookups::LookupTables,
6    target::{GLCompute, SpirvTarget},
7    transformers::{BitwiseTransform, ErfTransform, HypotTransform, RhypotTransform},
8};
9use cubecl_common::backtrace::BackTrace;
10use cubecl_core::{
11    Compiler, CubeDim, Metadata, WgpuCompilationOptions,
12    ir::{self as core, ElemType, InstructionModes, StorageType, UIntKind, features::EnumSet},
13    post_processing::{
14        checked_io::CheckedIoProcessor, saturating::SaturatingArithmeticProcessor,
15        unroll::UnrollProcessor,
16    },
17    prelude::{FastMath, KernelDefinition},
18    server::ExecutionMode,
19};
20use cubecl_opt::{BasicBlock, NodeIndex, Optimizer, OptimizerBuilder, SharedLiveness, Uniformity};
21use cubecl_runtime::{
22    compiler::CompilationError,
23    config::{GlobalConfig, compilation::CompilationLogLevel},
24};
25use rspirv::{
26    dr::{Builder, InsertPoint, Instruction, Module, Operand},
27    spirv::{BuiltIn, Capability, Decoration, FPFastMathMode, Op, StorageClass, Word},
28};
29use std::{
30    collections::HashSet,
31    fmt::Debug,
32    mem::take,
33    ops::{Deref, DerefMut},
34    rc::Rc,
35};
36
/// Limit handed to `UnrollProcessor::new` when the optimizer pipeline is assembled in
/// `SpirvCompiler::compile_kernel`.
// NOTE(review): presumably the widest vector the SPIR-V backend handles natively; wider
// operations would then be unrolled — confirm against `UnrollProcessor`.
pub const MAX_VECTORIZATION: usize = 4;
38
/// Compiler state used to lower a `KernelDefinition` into a SPIR-V module.
///
/// The type dereferences to the wrapped rspirv [`Builder`], so builder methods can be called
/// directly on the compiler.
pub struct SpirvCompiler<Target: SpirvTarget = GLCompute> {
    /// Target description; `compile_kernel` clones it, gives the clone the kernel name and
    /// lets it set the entry point modes.
    pub target: Target,
    /// rspirv builder that accumulates the module being generated.
    pub(crate) builder: Builder,

    /// Cube dimensions copied from the kernel definition in `compile`.
    pub cube_dim: CubeDim,
    /// Execution mode passed to `compile`; forwarded to `CheckedIoProcessor`.
    pub mode: ExecutionMode,
    /// Address storage type passed to `compile` (defaults to `u32`).
    pub addr_type: StorageType,
    /// Set when the global compilation log level is `Full` or the kernel options request
    /// debug symbols.
    pub debug_symbols: bool,
    // NOTE(review): the two ids below are not read in this part of the file; presumably they
    // cache the ids of the corresponding builtins — confirm.
    global_invocation_id: Word,
    num_workgroups: Word,
    /// Builder index of the setup block; function-local variables are inserted at its start.
    pub setup_block: usize,
    /// Optimizer output for the kernel currently being compiled.
    pub opt: Rc<Optimizer>,
    /// Uniformity analysis obtained from the optimizer.
    pub uniformity: Rc<Uniformity>,
    /// Shared-memory liveness analysis obtained from the optimizer.
    pub shared_liveness: Rc<SharedLiveness>,
    /// Optimizer block currently being compiled by `compile_block`.
    pub current_block: Option<NodeIndex>,
    /// Optimizer blocks that `compile_block` has already processed.
    pub visited: HashSet<NodeIndex>,

    /// SPIR-V capabilities requested so far.
    pub capabilities: HashSet<Capability>,
    /// Lookup tables (used builtins, shared memories, decorated types, ...).
    pub state: LookupTables,
    /// For each buffer/tensor map (sorted by id), the number of bindings with extended
    /// metadata that precede it.
    pub ext_meta_pos: Vec<u32>,
    /// Metadata layout built in `compile` from the binding count and the extended count.
    pub metadata: Metadata,
    /// Debug information state, if any.
    pub debug_info: Option<DebugInfo>,
    /// Copy of the compilation options passed to `compile`.
    pub compilation_options: WgpuCompilationOptions,
}
63
// SAFETY: NOTE(review): `SpirvCompiler` holds `Rc` fields (`opt`, `uniformity`,
// `shared_liveness`), which are not thread-safe on their own. These impls presumably rely on a
// compiler instance (and its clones, which share those `Rc`s) never being used from several
// threads at the same time — confirm that this invariant holds for every caller.
unsafe impl<T: SpirvTarget> Send for SpirvCompiler<T> {}
unsafe impl<T: SpirvTarget> Sync for SpirvCompiler<T> {}
66
67impl<T: SpirvTarget> Clone for SpirvCompiler<T> {
68    fn clone(&self) -> Self {
69        Self {
70            target: self.target.clone(),
71            builder: Builder::new_from_module(self.module_ref().clone()),
72            cube_dim: self.cube_dim,
73            mode: self.mode,
74            addr_type: self.addr_type,
75            global_invocation_id: self.global_invocation_id,
76            num_workgroups: self.num_workgroups,
77            setup_block: self.setup_block,
78            opt: self.opt.clone(),
79            uniformity: self.uniformity.clone(),
80            shared_liveness: self.shared_liveness.clone(),
81            current_block: self.current_block,
82            capabilities: self.capabilities.clone(),
83            state: self.state.clone(),
84            debug_symbols: self.debug_symbols,
85            visited: self.visited.clone(),
86            metadata: self.metadata.clone(),
87            debug_info: self.debug_info.clone(),
88            ext_meta_pos: self.ext_meta_pos.clone(),
89            compilation_options: self.compilation_options.clone(),
90        }
91    }
92}
93
94fn debug_symbols_activated() -> bool {
95    matches!(
96        GlobalConfig::get().compilation.logger.level,
97        CompilationLogLevel::Full
98    )
99}
100
101impl<T: SpirvTarget> Default for SpirvCompiler<T> {
102    fn default() -> Self {
103        Self {
104            target: Default::default(),
105            builder: Builder::new(),
106            cube_dim: CubeDim::new_single(),
107            mode: Default::default(),
108            addr_type: ElemType::UInt(UIntKind::U32).into(),
109            global_invocation_id: Default::default(),
110            num_workgroups: Default::default(),
111            capabilities: Default::default(),
112            state: Default::default(),
113            setup_block: Default::default(),
114            opt: Default::default(),
115            uniformity: Default::default(),
116            shared_liveness: Default::default(),
117            current_block: Default::default(),
118            debug_symbols: debug_symbols_activated(),
119            visited: Default::default(),
120            metadata: Default::default(),
121            debug_info: Default::default(),
122            ext_meta_pos: Default::default(),
123            compilation_options: Default::default(),
124        }
125    }
126}
127
128impl<T: SpirvTarget> Deref for SpirvCompiler<T> {
129    type Target = Builder;
130
131    fn deref(&self) -> &Self::Target {
132        &self.builder
133    }
134}
135
136impl<T: SpirvTarget> DerefMut for SpirvCompiler<T> {
137    fn deref_mut(&mut self) -> &mut Self::Target {
138        &mut self.builder
139    }
140}
141
142impl<T: SpirvTarget> Compiler for SpirvCompiler<T> {
143    type Representation = SpirvKernel;
144    type CompilationOptions = WgpuCompilationOptions;
145
146    fn compile(
147        &mut self,
148        mut value: KernelDefinition,
149        compilation_options: &Self::CompilationOptions,
150        mode: ExecutionMode,
151        addr_type: StorageType,
152    ) -> Result<Self::Representation, CompilationError> {
153        let errors = value.body.pop_errors();
154        if !errors.is_empty() {
155            let mut reason = "Can't compile spirv kernel".to_string();
156            for error in errors {
157                reason += error.as_str();
158                reason += "\n";
159            }
160
161            return Err(CompilationError::Validation {
162                reason,
163                backtrace: BackTrace::capture(),
164            });
165        }
166
167        let bindings = value.buffers.clone();
168        let scalars = value
169            .scalars
170            .iter()
171            .map(|s| (self.compile_storage_type(s.ty), s.count))
172            .collect();
173        let mut ext_meta_pos = Vec::new();
174        let mut num_ext = 0;
175
176        let mut all_meta: Vec<_> = value
177            .buffers
178            .iter()
179            .chain(value.tensor_maps.iter())
180            .map(|buf| (buf.id, buf.has_extended_meta))
181            .collect();
182        all_meta.sort_by_key(|(id, _)| *id);
183
184        let num_meta = all_meta.len();
185
186        for (_, has_extended_meta) in all_meta.iter() {
187            ext_meta_pos.push(num_ext);
188            if *has_extended_meta {
189                num_ext += 1;
190            }
191        }
192
193        self.cube_dim = value.cube_dim;
194        self.mode = mode;
195        self.addr_type = addr_type;
196        self.metadata = Metadata::new(num_meta as u32, num_ext);
197        self.compilation_options = compilation_options.clone();
198        self.ext_meta_pos = ext_meta_pos;
199
200        let (module, optimizer) = self.compile_kernel(value);
201        Ok(SpirvKernel {
202            module,
203            optimizer,
204            bindings,
205            scalars,
206            has_metadata: self.metadata.static_len() > 0,
207        })
208    }
209
210    fn elem_size(&self, elem: core::ElemType) -> usize {
211        elem.size()
212    }
213
214    fn extension(&self) -> &'static str {
215        "spv"
216    }
217}
218
219impl<Target: SpirvTarget> Debug for SpirvCompiler<Target> {
220    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
221        write!(f, "spirv<{:?}>", self.target)
222    }
223}
224
225impl<Target: SpirvTarget> SpirvCompiler<Target> {
226    pub fn compile_kernel(&mut self, kernel: KernelDefinition) -> (Module, Optimizer) {
227        let options = kernel.options.clone();
228
229        self.debug_symbols = debug_symbols_activated() || options.debug_symbols;
230
231        self.set_version(1, 6);
232
233        let mut target = self.target.clone();
234
235        let mut opt = OptimizerBuilder::default()
236            .with_transformer(ErfTransform)
237            .with_transformer(BitwiseTransform)
238            .with_transformer(HypotTransform)
239            .with_transformer(RhypotTransform)
240            .with_processor(CheckedIoProcessor::new(self.mode))
241            .with_processor(UnrollProcessor::new(MAX_VECTORIZATION))
242            .with_processor(SaturatingArithmeticProcessor::new(true))
243            .optimize(kernel.body.clone(), kernel.cube_dim);
244
245        self.uniformity = opt.analysis::<Uniformity>();
246        self.shared_liveness = opt.analysis::<SharedLiveness>();
247        self.opt = Rc::new(opt);
248
249        self.init_state(kernel.clone());
250        self.init_debug();
251
252        let cube_dims = vec![kernel.cube_dim.x, kernel.cube_dim.y, kernel.cube_dim.z];
253
254        target.set_kernel_name(options.kernel_name.clone());
255
256        let (main, debug_setup) = self.declare_main(&options.kernel_name);
257
258        let setup = self.id();
259        self.debug_name(setup, "setup");
260
261        let entry = self.opt.entry();
262        let body = self.label(entry);
263        let setup_block = self.setup(setup, debug_setup);
264        self.setup_block = setup_block;
265        self.compile_block(entry);
266
267        let ret = self.opt.ret;
268        self.compile_block(ret);
269
270        if self.selected_block().is_some() {
271            let label = self.label(ret);
272            self.branch(label).unwrap();
273        }
274
275        self.select_block(Some(setup_block)).unwrap();
276        self.branch(body).unwrap();
277
278        self.end_function().unwrap();
279
280        self.declare_shared_memories();
281
282        let builtins = self
283            .state
284            .used_builtins
285            .clone()
286            .into_iter()
287            .map(|(builtin, (id, item))| {
288                let ty = Item::Pointer(StorageClass::Input, Box::new(item)).id(self);
289                self.variable(ty, Some(id), StorageClass::Input, None);
290                self.decorate(id, Decoration::BuiltIn, vec![builtin.into()]);
291                id
292            })
293            .collect::<Vec<_>>();
294
295        target.set_modes(self, main, builtins, cube_dims);
296
297        let module = take(&mut self.builder).module();
298        (module, self.opt.as_ref().clone())
299    }
300
301    fn setup(&mut self, label: Word, debug_setup: impl Fn(&mut Self)) -> usize {
302        self.begin_block(Some(label)).unwrap();
303
304        let opt = self.opt.clone();
305        for const_arr in opt.const_arrays() {
306            self.register_const_array(const_arr);
307        }
308
309        debug_setup(self);
310
311        let setup_block = self.selected_block().unwrap();
312        self.select_block(None).unwrap();
313        setup_block
314    }
315
316    #[track_caller]
317    pub fn current_block(&self) -> BasicBlock {
318        self.opt.block(self.current_block.unwrap()).clone()
319    }
320
321    pub fn builtin(&mut self, builtin: BuiltIn, item: Item) -> Word {
322        if let Some(existing) = self.state.used_builtins.get(&builtin) {
323            existing.0
324        } else {
325            let id = self.id();
326            self.state.used_builtins.insert(builtin, (id, item));
327            id
328        }
329    }
330
331    pub fn compile_block(&mut self, block: NodeIndex) {
332        if self.visited.contains(&block) {
333            return;
334        }
335        self.visited.insert(block);
336        self.current_block = Some(block);
337
338        let label = self.label(block);
339        self.begin_block(Some(label)).unwrap();
340        let block_id = self.selected_block().unwrap();
341
342        self.debug_start_block();
343
344        let operations = self.current_block().ops.borrow().clone();
345        for (_, operation) in operations {
346            self.compile_operation(operation);
347        }
348
349        let control_flow = self.current_block().control_flow.borrow().clone();
350        self.compile_control_flow(control_flow);
351
352        let current = self.selected_block();
353        self.select_block(Some(block_id)).unwrap();
354        let phi = { self.opt.block(block).phi_nodes.borrow().clone() };
355        for phi in phi {
356            let out = self.compile_variable(phi.out);
357            let ty = out.item().id(self);
358            let out_id = self.write_id(&out);
359            let entries: Vec<_> = phi
360                .entries
361                .into_iter()
362                .map(|it| {
363                    let label = self.end_label(it.block);
364                    let value = self.compile_variable(it.value);
365                    let value = self.read(&value);
366                    (value, label)
367                })
368                .collect();
369            self.insert_phi(InsertPoint::Begin, ty, Some(out_id), entries)
370                .unwrap();
371        }
372        self.select_block(current).unwrap();
373    }
374
375    // Declare variable in the first block of the function
376    pub fn declare_function_variable(&mut self, ty: Word) -> Word {
377        let setup = self.setup_block;
378        let id = self.id();
379        let var = Instruction::new(
380            Op::Variable,
381            Some(ty),
382            Some(id),
383            vec![Operand::StorageClass(StorageClass::Function)],
384        );
385        let current_block = self.selected_block();
386        self.select_block(Some(setup)).unwrap();
387        self.insert_into_block(InsertPoint::Begin, var).unwrap();
388        self.select_block(current_block).unwrap();
389        id
390    }
391
392    fn declare_shared_memories(&mut self) {
393        if self.compilation_options.supports_explicit_smem {
394            self.declare_shared_memories_explicit();
395        } else {
396            self.declare_shared_memories_implicit();
397        }
398    }
399
400    /// When using `VK_KHR_workgroup_memory_explicit_layout`, all shared memory is declared as a
401    /// `Block`. This means they are all pointers into the same chunk of memory, with different
402    /// offsets and sizes. Unlike C++, this shared block is declared implicitly, not explicitly.
403    /// Alignment and total size is calculated by the driver.
404    fn declare_shared_memories_explicit(&mut self) {
405        let shared_arrays = self.state.shared_arrays.clone();
406        let shared = self.state.shared.clone();
407        if shared_arrays.is_empty() && shared.is_empty() {
408            return;
409        }
410
411        self.capabilities
412            .insert(Capability::WorkgroupMemoryExplicitLayoutKHR);
413
414        for (index, memory) in shared_arrays {
415            let item_size = memory.item.size();
416
417            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
418            // explicit layout as well.
419            match item_size {
420                1 => {
421                    self.capabilities
422                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
423                }
424                2 => {
425                    self.capabilities
426                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
427                }
428                _ => {}
429            }
430
431            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
432            let arr_id = arr_ty.id(self);
433
434            if !self.state.decorated_types.contains(&arr_id) {
435                self.decorate(
436                    arr_id,
437                    Decoration::ArrayStride,
438                    [Operand::LiteralBit32(item_size)],
439                );
440                self.state.decorated_types.insert(arr_id);
441            }
442
443            let block_ty = Item::Struct(vec![arr_ty]);
444            let block_id = block_ty.id(self);
445
446            self.decorate(block_id, Decoration::Block, []);
447            self.member_decorate(
448                block_id,
449                0,
450                Decoration::Offset,
451                [Operand::LiteralBit32(memory.offset)],
452            );
453
454            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
455
456            self.debug_shared(memory.id, index);
457            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
458            self.decorate(memory.id, Decoration::Aliased, []);
459        }
460
461        for (index, memory) in shared {
462            let item_size = memory.item.size();
463
464            // It's safe to assume that if 8-bit/16-bit types are supported, they're supported for
465            // explicit layout as well.
466            match item_size {
467                1 => {
468                    self.capabilities
469                        .insert(Capability::WorkgroupMemoryExplicitLayout8BitAccessKHR);
470                }
471                2 => {
472                    self.capabilities
473                        .insert(Capability::WorkgroupMemoryExplicitLayout16BitAccessKHR);
474                }
475                _ => {}
476            }
477
478            let block_ty = Item::Struct(vec![memory.item]);
479            let block_id = block_ty.id(self);
480
481            self.decorate(block_id, Decoration::Block, []);
482            self.member_decorate(
483                block_id,
484                0,
485                Decoration::Offset,
486                [Operand::LiteralBit32(memory.offset)],
487            );
488
489            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(block_ty)).id(self);
490
491            self.debug_shared(memory.id, index);
492            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
493            self.decorate(memory.id, Decoration::Aliased, []);
494        }
495    }
496
497    fn declare_shared_memories_implicit(&mut self) {
498        let shared_memories = self.state.shared_arrays.clone();
499        for (index, memory) in shared_memories {
500            let arr_ty = Item::Array(Box::new(memory.item), memory.len);
501            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(arr_ty)).id(self);
502
503            self.debug_shared(memory.id, index);
504            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
505        }
506        let shared = self.state.shared.clone();
507        for (index, memory) in shared {
508            let ptr_ty = Item::Pointer(StorageClass::Workgroup, Box::new(memory.item)).id(self);
509
510            self.debug_shared(memory.id, index);
511            self.variable(ptr_ty, Some(memory.id), StorageClass::Workgroup, None);
512        }
513    }
514
515    pub fn declare_math_mode(&mut self, modes: InstructionModes, out_id: Word) {
516        if !self.compilation_options.supports_fp_fast_math || modes.fp_math_mode.is_empty() {
517            return;
518        }
519        let mode = convert_math_mode(modes.fp_math_mode);
520        self.capabilities.insert(Capability::FloatControls2);
521        self.decorate(
522            out_id,
523            Decoration::FPFastMathMode,
524            [Operand::FPFastMathMode(mode)],
525        );
526    }
527
528    pub fn is_uniform_block(&self) -> bool {
529        self.uniformity
530            .is_block_uniform(self.current_block.unwrap())
531    }
532}
533
534pub(crate) fn convert_math_mode(math_mode: EnumSet<FastMath>) -> FPFastMathMode {
535    let mut flags = FPFastMathMode::NONE;
536
537    for mode in math_mode.iter() {
538        match mode {
539            FastMath::NotNaN => flags |= FPFastMathMode::NOT_NAN,
540            FastMath::NotInf => flags |= FPFastMathMode::NOT_INF,
541            FastMath::UnsignedZero => flags |= FPFastMathMode::NSZ,
542            FastMath::AllowReciprocal => flags |= FPFastMathMode::ALLOW_RECIP,
543            FastMath::AllowContraction => flags |= FPFastMathMode::ALLOW_CONTRACT,
544            FastMath::AllowReassociation => flags |= FPFastMathMode::ALLOW_REASSOC,
545            FastMath::AllowTransform => {
546                flags |= FPFastMathMode::ALLOW_CONTRACT
547                    | FPFastMathMode::ALLOW_REASSOC
548                    | FPFastMathMode::ALLOW_TRANSFORM
549            }
550            _ => {}
551        }
552    }
553
554    flags
555}