cubecl_spirv/
target.rs

1use cubecl_core::prelude::{Binding, Location, Visibility};
2use rspirv::spirv::{
3    self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass, Word,
4};
5use std::{fmt::Debug, iter};
6
7use crate::{SpirvCompiler, extensions::TargetExtensions, item::Item};
8
9pub trait SpirvTarget:
10    TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
11{
12    fn set_modes(
13        &mut self,
14        b: &mut SpirvCompiler<Self>,
15        main: Word,
16        builtins: Vec<Word>,
17        cube_dims: Vec<u32>,
18    );
19    fn generate_binding(
20        &mut self,
21        b: &mut SpirvCompiler<Self>,
22        binding: Binding,
23        name: String,
24    ) -> Word;
25
26    fn set_kernel_name(&mut self, name: impl Into<String>);
27}
28
/// SPIR-V target that emits kernels with the `GLCompute` execution model.
#[derive(Clone)]
pub struct GLCompute {
    /// Name passed as the entry-point name when the module's modes are emitted.
    kernel_name: String,
}

impl Default for GLCompute {
    /// Builds a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        Self { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always renders as `gl_compute`, whatever the configured kernel name is.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
47
48impl SpirvTarget for GLCompute {
49    fn set_modes(
50        &mut self,
51        b: &mut SpirvCompiler<Self>,
52        main: Word,
53        builtins: Vec<Word>,
54        cube_dims: Vec<u32>,
55    ) {
56        let interface: Vec<u32> = builtins
57            .into_iter()
58            .chain(b.state.buffers.iter().copied())
59            .chain(iter::once(b.state.info))
60            .chain(b.state.scalar_bindings.values().copied())
61            .chain(b.state.shared_memories.values().map(|it| it.id))
62            .collect();
63
64        b.capability(Capability::Shader);
65        b.capability(Capability::VulkanMemoryModel);
66        b.capability(Capability::VulkanMemoryModelDeviceScope);
67        b.capability(Capability::GroupNonUniform);
68
69        if b.compilation_options.supports_explicit_smem {
70            b.extension("SPV_KHR_workgroup_memory_explicit_layout");
71        }
72
73        let caps: Vec<_> = b.capabilities.iter().copied().collect();
74        for cap in caps.iter() {
75            b.capability(*cap);
76        }
77
78        if caps.contains(&Capability::CooperativeMatrixKHR) {
79            b.extension("SPV_KHR_cooperative_matrix");
80        }
81
82        if caps.contains(&Capability::AtomicFloat16AddEXT) {
83            b.extension("SPV_EXT_shader_atomic_float16_add");
84        }
85
86        if caps.contains(&Capability::AtomicFloat32AddEXT)
87            | caps.contains(&Capability::AtomicFloat64AddEXT)
88        {
89            b.extension("SPV_EXT_shader_atomic_float_add");
90        }
91
92        if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
93            | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
94            | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
95        {
96            b.extension("SPV_EXT_shader_atomic_float_min_max");
97        }
98
99        if caps.contains(&Capability::BFloat16TypeKHR)
100            || caps.contains(&Capability::BFloat16CooperativeMatrixKHR)
101            || caps.contains(&Capability::BFloat16DotProductKHR)
102        {
103            b.extension("SPV_KHR_bfloat16");
104        }
105
106        if caps.contains(&Capability::Float8EXT)
107            || caps.contains(&Capability::Float8CooperativeMatrixEXT)
108        {
109            b.extension("SPV_EXT_float8");
110        }
111
112        if caps.contains(&Capability::FloatControls2) {
113            b.extension("SPV_KHR_float_controls2");
114        }
115
116        if b.debug_symbols {
117            b.extension("SPV_KHR_non_semantic_info");
118        }
119
120        b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
121        b.entry_point(
122            ExecutionModel::GLCompute,
123            main,
124            &self.kernel_name,
125            interface,
126        );
127        b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
128    }
129
130    fn generate_binding(
131        &mut self,
132        b: &mut SpirvCompiler<Self>,
133        binding: Binding,
134        name: String,
135    ) -> Word {
136        let index = binding.id;
137        let item = b.compile_type(binding.ty);
138        let item_size = item.size();
139        let item = match binding.size {
140            Some(size) => Item::Array(Box::new(item), size as u32),
141            None => Item::RuntimeArray(Box::new(item)),
142        };
143        let arr = item.id(b); // pre-generate type
144
145        if !b.state.array_types.contains(&arr) {
146            b.decorate(arr, Decoration::ArrayStride, [item_size.into()]);
147            b.state.array_types.insert(arr);
148        }
149
150        let struct_ty = b.id();
151        b.type_struct_id(Some(struct_ty), vec![arr]);
152
153        let location = match binding.location {
154            Location::Cube => StorageClass::Workgroup,
155            Location::Storage => StorageClass::StorageBuffer,
156        };
157        let ptr_ty = b.type_pointer(None, location, struct_ty);
158        let var = b.variable(ptr_ty, None, location, None);
159
160        b.debug_name(var, name);
161
162        if matches!(binding.visibility, Visibility::Read) {
163            b.decorate(var, Decoration::NonWritable, vec![]);
164        }
165
166        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
167        b.decorate(var, Decoration::Binding, vec![index.into()]);
168        b.decorate(struct_ty, Decoration::Block, vec![]);
169        b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
170
171        var
172    }
173
174    fn set_kernel_name(&mut self, name: impl Into<String>) {
175        self.kernel_name = name.into();
176    }
177}