cubecl_spirv/
target.rs

1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3    self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass, Word,
4};
5use std::{fmt::Debug, iter};
6
7use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
8
9pub trait SpirvTarget:
10    TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
11{
12    fn set_modes(
13        &mut self,
14        b: &mut SpirvCompiler<Self>,
15        main: Word,
16        builtins: Vec<Word>,
17        cube_dims: Vec<u32>,
18    );
19    fn generate_binding(
20        &mut self,
21        b: &mut SpirvCompiler<Self>,
22        binding: Binding,
23        name: String,
24    ) -> Word;
25
26    fn set_kernel_name(&mut self, name: impl Into<String>);
27}
28
/// SPIR-V target producing Vulkan-style `GLCompute` compute kernels.
#[derive(Clone)]
pub struct GLCompute {
    /// Name the kernel entry point is exported under; `"main"` by default.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Creates a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        GLCompute { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always renders as `gl_compute`, independent of the configured kernel name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
47
48impl SpirvTarget for GLCompute {
49    fn set_modes(
50        &mut self,
51        b: &mut SpirvCompiler<Self>,
52        main: Word,
53        builtins: Vec<Word>,
54        cube_dims: Vec<u32>,
55    ) {
56        let interface: Vec<u32> = builtins
57            .into_iter()
58            .chain(b.state.buffers.iter().copied())
59            .chain(iter::once(b.state.info))
60            .chain(b.state.scalar_bindings.values().copied())
61            .chain(b.state.shared_arrays.values().map(|it| it.id))
62            .chain(b.state.shared.values().map(|it| it.id))
63            .collect();
64
65        b.capability(Capability::Shader);
66        b.capability(Capability::VulkanMemoryModel);
67        b.capability(Capability::VulkanMemoryModelDeviceScope);
68        b.capability(Capability::GroupNonUniform);
69
70        if b.compilation_options.supports_explicit_smem {
71            b.extension("SPV_KHR_workgroup_memory_explicit_layout");
72        }
73
74        let caps: Vec<_> = b.capabilities.iter().copied().collect();
75        for cap in caps.iter() {
76            b.capability(*cap);
77        }
78
79        if caps.contains(&Capability::CooperativeMatrixKHR) {
80            b.extension("SPV_KHR_cooperative_matrix");
81        }
82
83        if caps.contains(&Capability::AtomicFloat16AddEXT) {
84            b.extension("SPV_EXT_shader_atomic_float16_add");
85        }
86
87        if caps.contains(&Capability::AtomicFloat32AddEXT)
88            | caps.contains(&Capability::AtomicFloat64AddEXT)
89        {
90            b.extension("SPV_EXT_shader_atomic_float_add");
91        }
92
93        if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
94            | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
95            | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
96        {
97            b.extension("SPV_EXT_shader_atomic_float_min_max");
98        }
99
100        if caps.contains(&Capability::BFloat16TypeKHR)
101            || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
102            || caps.contains(&Capability::BFloat16DotProductKHR)
103        {
104            b.extension("SPV_KHR_bfloat16");
105        }
106
107        if caps.contains(&Capability::Float8EXT)
108            || caps.contains(&Capability::Float8CooperativeMatrixEXT)
109        {
110            b.extension("SPV_EXT_float8");
111        }
112
113        if caps.contains(&Capability::FloatControls2) {
114            b.extension("SPV_KHR_float_controls2");
115        }
116
117        if b.debug_symbols {
118            b.extension("SPV_KHR_non_semantic_info");
119        }
120
121        b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
122        b.entry_point(
123            ExecutionModel::GLCompute,
124            main,
125            &self.kernel_name,
126            interface,
127        );
128        b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
129    }
130
131    fn generate_binding(
132        &mut self,
133        b: &mut SpirvCompiler<Self>,
134        binding: Binding,
135        name: String,
136    ) -> Word {
137        let index = binding.id;
138        let item = b.compile_type(binding.ty);
139        let item_size = item.size();
140        let item = match binding.size {
141            Some(size) => Item::Array(Box::new(item), size as u32),
142            None => Item::RuntimeArray(Box::new(item)),
143        };
144        let arr = item.id(b); // pre-generate type
145
146        if !b.state.array_types.contains(&arr) {
147            b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
148            b.state.array_types.insert(arr);
149        }
150
151        let struct_ty = b.id();
152        b.type_struct_id(Some(struct_ty), vec![arr]);
153
154        let location = match binding.location {
155            Location::Cube => StorageClass::Workgroup,
156            Location::Storage => StorageClass::StorageBuffer,
157        };
158        let ptr_ty = b.type_pointer(None, location, struct_ty);
159        let var = b.variable(ptr_ty, None, location, None);
160
161        b.debug_name(var, name);
162
163        if matches!(binding.visibility, Visibility::Read) {
164            b.decorate(var, Decoration::NonWritable, vec![]);
165        }
166
167        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
168        b.decorate(var, Decoration::Binding, vec![index.into()]);
169        b.decorate(struct_ty, Decoration::Block, vec![]);
170        b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
171
172        var
173    }
174
175    fn set_kernel_name(&mut self, name: impl Into<String>) {
176        self.kernel_name = name.into();
177    }
178}