// cubecl_spirv/target.rs

1use cubecl_core::compute::{Binding, Location, Visibility};
2use hashbrown::HashMap;
3use rspirv::{
4    dr,
5    spirv::{
6        self, AddressingModel, Capability, Decoration, ExecutionModel, MemoryModel, StorageClass,
7        Word,
8    },
9};
10use std::{fmt::Debug, iter};
11
12use crate::{
13    SpirvCompiler, debug,
14    extensions::{TargetExtensions, glcompute},
15    item::Item,
16};
17
18pub trait SpirvTarget:
19    TargetExtensions<Self> + Debug + Clone + Default + Send + Sync + 'static
20{
21    fn extensions(&mut self, b: &mut SpirvCompiler<Self>) -> HashMap<String, Word>;
22    fn set_modes(
23        &mut self,
24        b: &mut SpirvCompiler<Self>,
25        main: Word,
26        builtins: Vec<Word>,
27        cube_dims: Vec<u32>,
28    );
29    fn generate_binding(
30        &mut self,
31        b: &mut SpirvCompiler<Self>,
32        binding: Binding,
33        name: String,
34    ) -> Word;
35
36    fn set_kernel_name(&mut self, name: impl Into<String>);
37}
38
/// SPIR-V target that emits a `GLCompute` entry point.
#[derive(Clone)]
pub struct GLCompute {
    /// Name given to the generated entry point (`"main"` unless changed).
    kernel_name: String,
}

impl Default for GLCompute {
    /// Creates a target whose entry point is named `"main"`.
    fn default() -> Self {
        let kernel_name = String::from("main");
        Self { kernel_name }
    }
}

impl Debug for GLCompute {
    /// Always formats as the fixed tag `gl_compute`; the kernel name is not printed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "gl_compute")
    }
}
57
58impl SpirvTarget for GLCompute {
59    fn set_modes(
60        &mut self,
61        b: &mut SpirvCompiler<Self>,
62        main: Word,
63        builtins: Vec<Word>,
64        cube_dims: Vec<u32>,
65    ) {
66        let interface: Vec<u32> = builtins
67            .into_iter()
68            .chain(b.state.buffers.iter().copied())
69            .chain(iter::once(b.state.info))
70            .chain(b.state.scalar_bindings.values().copied())
71            .chain(b.state.shared_memories.values().map(|it| it.id))
72            .collect();
73
74        b.capability(Capability::Shader);
75        b.capability(Capability::VulkanMemoryModel);
76        b.capability(Capability::VulkanMemoryModelDeviceScope);
77
78        let caps: Vec<_> = b.capabilities.iter().copied().collect();
79        for cap in caps.iter() {
80            b.capability(*cap);
81        }
82        if b.float_controls {
83            let inst = dr::Instruction::new(
84                spirv::Op::Capability,
85                None,
86                None,
87                vec![dr::Operand::LiteralBit32(6029)],
88            );
89            b.module_mut().capabilities.push(inst);
90        }
91
92        if caps.contains(&Capability::CooperativeMatrixKHR) {
93            b.extension("SPV_KHR_cooperative_matrix");
94        }
95
96        if caps.contains(&Capability::AtomicFloat16AddEXT) {
97            b.extension("SPV_EXT_shader_atomic_float16_add");
98        }
99
100        if caps.contains(&Capability::AtomicFloat32AddEXT)
101            | caps.contains(&Capability::AtomicFloat64AddEXT)
102        {
103            b.extension("SPV_EXT_shader_atomic_float_add");
104        }
105
106        if caps.contains(&Capability::AtomicFloat16MinMaxEXT)
107            | caps.contains(&Capability::AtomicFloat32MinMaxEXT)
108            | caps.contains(&Capability::AtomicFloat64MinMaxEXT)
109        {
110            b.extension("SPV_EXT_shader_atomic_float_min_max");
111        }
112
113        if b.float_controls {
114            b.extension("SPV_KHR_float_controls2");
115        }
116
117        if b.debug_symbols {
118            b.extension("SPV_KHR_non_semantic_info");
119        }
120
121        b.memory_model(AddressingModel::Logical, MemoryModel::Vulkan);
122        b.entry_point(
123            ExecutionModel::GLCompute,
124            main,
125            &self.kernel_name,
126            interface,
127        );
128        b.execution_mode(main, spirv::ExecutionMode::LocalSize, cube_dims);
129
130        if b.float_controls {
131            b.declare_float_execution_modes(main);
132        }
133    }
134
135    fn generate_binding(
136        &mut self,
137        b: &mut SpirvCompiler<Self>,
138        binding: Binding,
139        name: String,
140    ) -> Word {
141        let index = binding.id;
142        let item = b.compile_item(binding.item);
143        let item = match binding.size {
144            Some(size) => Item::Array(Box::new(item), size as u32),
145            None => Item::RuntimeArray(Box::new(item)),
146        };
147        let arr = item.id(b); // pre-generate type
148        let struct_ty = b.id();
149        b.type_struct_id(Some(struct_ty), vec![arr]);
150
151        let location = match binding.location {
152            Location::Cube => StorageClass::Workgroup,
153            Location::Storage => StorageClass::StorageBuffer,
154        };
155        let ptr_ty = b.type_pointer(None, location, struct_ty);
156        let var = b.variable(ptr_ty, None, location, None);
157
158        b.debug_name(var, name);
159
160        if matches!(binding.visibility, Visibility::Read) {
161            b.decorate(var, Decoration::NonWritable, vec![]);
162        }
163
164        b.decorate(var, Decoration::DescriptorSet, vec![0u32.into()]);
165        b.decorate(var, Decoration::Binding, vec![index.into()]);
166        b.decorate(struct_ty, Decoration::Block, vec![]);
167        b.member_decorate(struct_ty, 0, Decoration::Offset, vec![0u32.into()]);
168
169        var
170    }
171
172    fn extensions(&mut self, b: &mut SpirvCompiler<Self>) -> HashMap<String, Word> {
173        let mut extensions = HashMap::new();
174        extensions.insert(
175            glcompute::STD_NAME.to_string(),
176            b.ext_inst_import(glcompute::STD_NAME),
177        );
178        if b.debug_symbols {
179            extensions.insert(
180                debug::DEBUG_EXT_NAME.to_string(),
181                b.ext_inst_import(debug::DEBUG_EXT_NAME),
182            );
183            extensions.insert(
184                debug::PRINT_EXT_NAME.to_string(),
185                b.ext_inst_import(debug::PRINT_EXT_NAME),
186            );
187        }
188        extensions
189    }
190
191    fn set_kernel_name(&mut self, name: impl Into<String>) {
192        self.kernel_name = name.into();
193    }
194}