Skip to main content

cubecl_spirv/
target.rs

1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3    self, AddressingModel, Capability, Decoration, ExecutionMode, ExecutionModel, MemoryModel,
4    StorageClass, Word,
5};
6use std::{fmt::Debug, iter};
7
8use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
9
10pub trait SpirvTarget:
11    TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
12{
13    fn set_modes(
14        &mut self,
15        b: &mut SpirvCompiler<Self>,
16        main: Word,
17        builtins: Vec<Word>,
18        cube_dims: Vec<u32>,
19    );
20    fn generate_binding(
21        &mut self,
22        b: &mut SpirvCompiler<Self>,
23        binding: Binding,
24        name: String,
25    ) -> Word;
26
27    fn set_kernel_name(&mut self, name: impl Into<String>);
28}
29
/// SPIR-V target that emits kernels using the `GLCompute` execution model.
#[derive(Clone)]
pub struct GLCompute {
    /// Name under which the kernel entry point is registered.
    kernel_name: String,
}
34
35impl Default for GLCompute {
36    fn default() -> Self {
37        Self {
38            kernel_name: "main".into(),
39        }
40    }
41}
42
43impl Debug for GLCompute {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.write_str("gl_compute")
46    }
47}
48
49impl SpirvTarget for GLCompute {
50    fn set_modes(
51        &mut self,
52        b: &mut SpirvCompiler<Self>,
53        main: Word,
54        builtins: Vec<Word>,
55        cube_dims: Vec<u32>,
56    ) {
57        let interface: Vec<u32> = builtins
58            .into_iter()
59            .chain(b.state.buffers.iter().copied())
60            .chain(iter::once(b.state.info))
61            .chain(b.state.scalar_bindings.values().copied())
62            .chain(b.state.shared_arrays.values().map(|it| it.id))
63            .chain(b.state.shared.values().map(|it| it.id))
64            .collect();
65
66        b.capability(Capability::Shader);
67        b.capability(Capability::VulkanMemoryModel);
68        b.capability(Capability::VulkanMemoryModelDeviceScope);
69        b.capability(Capability::GroupNonUniform);
70
71        if b.compilation_options.supports_explicit_smem {
72            b.extension("SPV_KHR_workgroup_memory_explicit_layout");
73        }
74
75        if b.addr_type.size_bits() == 64 {
76            b.extension("SPV_EXT_shader_64bit_indexing");
77            b.capability(Capability::Shader64BitIndexingEXT);
78            b.execution_mode(main, ExecutionMode::Shader64BitIndexingEXT, []);
79        }
80
81        let caps: Vec<_> = b.capabilities.iter().copied().collect();
82        for cap in caps.iter() {
83            b.capability(*cap);
84        }
85
86        if caps.contains(&Capability::CooperativeMatrixKHR) {
87            b.extension("SPV_KHR_cooperative_matrix");
88        }
89
90        if caps.contains(&Capability::AtomicFloat16AddEXT) {
91            b.extension("SPV_EXT_shader_atomic_float16_add");
92        }
93
94        if caps.contains(&Capability::AtomicFloat32AddEXT)
95            | caps.contains(&Capability::AtomicFloat64AddEXT)
96        {
97            b.extension("SPV_EXT_shader_atomic_float_add");
98        }
99
100        if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
101            | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
102            | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
103        {
104            b.extension("SPV_EXT_shader_atomic_float_min_max");
105        }
106
107        if caps.contains(&Capability::BFloat16TypeKHR)
108            || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
109            || caps.contains(&Capability::BFloat16DotProductKHR)
110        {
111            b.extension("SPV_KHR_bfloat16");
112        }
113
114        if caps.contains(&Capability::Float8EXT)
115            || caps.contains(&Capability::Float8CooperativeMatrixEXT)
116        {
117            b.extension("SPV_EXT_float8");
118        }
119
120        if caps.contains(&Capability::FloatControls2) {
121            b.extension("SPV_KHR_float_controls2");
122        }
123
124        if b.debug_symbols {
125            b.extension("SPV_KHR_non_semantic_info");
126        }
127
128        b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
129        b.entry_point(
130            ExecutionModel::GLCompute,
131            main,
132            &self.kernel_name,
133            interface,
134        );
135        b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
136    }
137
138    fn generate_binding(
139        &mut self,
140        b: &mut SpirvCompiler<Self>,
141        binding: Binding,
142        name: String,
143    ) -> Word {
144        let index = binding.id;
145        let item = b.compile_type(binding.ty);
146        let item_size = item.size();
147        let item = match binding.size {
148            Some(size) => Item::Array(Box::new(item), size as u32),
149            None => Item::RuntimeArray(Box::new(item)),
150        };
151        let arr = item.id(b); // pre-generate type
152
153        if !b.state.array_types.contains(&arr) {
154            b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
155            b.state.array_types.insert(arr);
156        }
157
158        let struct_ty = b.id();
159        b.type_struct_id(Some(struct_ty), vec![arr]);
160
161        let location = match binding.location {
162            Location::Cube => StorageClass::Workgroup,
163            Location::Storage => StorageClass::StorageBuffer,
164        };
165        let ptr_ty = b.type_pointer(None, location, struct_ty);
166        let var = b.variable(ptr_ty, None, location, None);
167
168        b.debug_name(var, name);
169
170        if matches!(binding.visibility, Visibility::Read) {
171            b.decorate(var, Decoration::NonWritable, vec![]);
172        }
173
174        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
175        b.decorate(var, Decoration::Binding, vec![index.into()]);
176        b.decorate(struct_ty, Decoration::Block, vec![]);
177        b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
178
179        var
180    }
181
182    fn set_kernel_name(&mut self, name: impl Into<String>) {
183        self.kernel_name = name.into();
184    }
185}