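/// Describes a compute dispatch: the workgroup counts passed to
/// `dispatch_workgroups` and the WGSL source that is compiled into the pipeline.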
#[derive(Debug, Clone)]
pub struct compute_kernel {
    /// Workgroup count along x.
    pub x: u32,
    /// Workgroup count along y.
    pub y: u32,
    /// Workgroup count along z.
    pub z: u32,
    /// WGSL shader source.
    pub code: String,
}

impl compute_kernel {
    /// Creates a kernel with a 1x1x1 dispatch; adjust `x`, `y` and `z` as needed.
    pub fn new(code: String) -> Self {
        compute_kernel {
            x: 1,
            y: 1,
            z: 1,
            code,
        }
    }
}

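/// One shader input/output. `data` is uploaded into a storage buffer exposed to the
/// shader at `@group(group) @binding(bind)` and is overwritten with the GPU results.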
#[derive(Debug, Clone)]
pub struct info<T> {
    /// Binding index inside the bind group.
    pub bind: u32,
    /// Bind group index.
    pub group: u32,
    /// Host-side data; must expose `as_slice()` and be castable to bytes (e.g. `Vec<f32>`).
    pub data: T,
}

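/// The wgpu handles (instance, adapter, device, queue) plus the shader entry point
/// name used by `compute_ext!`.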
#[derive(Debug)]
pub struct compute_config {
    pub _wgpu_instance: wgpu::Instance,
    pub _wgpu_adapter: wgpu::Adapter,
    pub _wgpu_queue: wgpu::Queue,
    pub _wgpu_device: wgpu::Device,
    pub _entry_point: String,
}

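/// Requests a default adapter and device, blocking on the async wgpu calls with
/// `pollster`; the entry point defaults to `main`.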
impl Default for compute_config {
    fn default() -> Self {
        let instance = wgpu::Instance::default();
        // Pick any available adapter with the default options.
        let adapter = pollster::block_on(instance
            .request_adapter(&wgpu::RequestAdapterOptions::default()))
            .expect("ERROR : failed to get adapter");
        // Request a device and queue with downlevel limits so weaker backends still work.
        let (device, queue) = pollster::block_on(adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: None,
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::downlevel_defaults(),
                    memory_hints: wgpu::MemoryHints::MemoryUsage,
                },
                None,
            ))
            .expect("ERROR : failed to get device from adapter");

        Self {
            _wgpu_instance: instance,
            _wgpu_adapter: adapter,
            _wgpu_queue: queue,
            _wgpu_device: device,
            _entry_point: "main".to_string(),
        }
    }
}

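/// Runs `$kernel` against an explicit `$config`: compiles the WGSL, uploads every
/// `$data` (an `info`) into its own storage buffer, records one compute pass, and
/// copies the buffers back into each `$data.data` once the GPU is done.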
#[macro_export]
macro_rules! compute_ext {
    ($config:expr, $kernel:expr, $($data:expr),*) => {
        {
            use wgpu::util::DeviceExt;
            use std::collections::HashMap;

            // Take the wgpu handles out of the config; the instance and adapter are
            // not used directly but are kept alive for the duration of the block.
            let _instance = $config._wgpu_instance;
            let _adapter = $config._wgpu_adapter;
            let device = $config._wgpu_device;
            let queue = $config._wgpu_queue;

            // Compile the WGSL source into a shader module.
            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("Shader"),
                source: wgpu::ShaderSource::Wgsl($kernel.code.into()),
            });

            // `layout: None` lets wgpu derive the bind group layouts from the shader.
            let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: None,
                layout: None,
                module: &shader,
                entry_point: &$config._entry_point,
                compilation_options: Default::default(),
                cache: None,
            });

            let mut encoder =
                device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });

            let mut staging_buffers: Vec<wgpu::Buffer> = Vec::new();
            let mut sizes: Vec<wgpu::BufferAddress> = Vec::new();
            let mut storage_buffers: Vec<wgpu::Buffer> = Vec::new();

            // Remembers which storage buffer index serves which shader binding.
            #[derive(Debug)]
            struct BufIndex {
                index: usize,
                bind: u32,
            }

            // Buffer indices grouped by bind group index.
            let mut grouponized: HashMap<u32, Vec<BufIndex>> = HashMap::new();

            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: None,
                    timestamp_writes: None,
                });

                // For every `$data`: create a CPU-mappable staging buffer and a GPU
                // storage buffer initialized with the input, and record which
                // group/binding the storage buffer belongs to.
                $(
                    if !grouponized.contains_key(&$data.group) {
                        grouponized.insert($data.group, Vec::new());
                    }
                    let refr = $data.data.as_slice();
                    let size = std::mem::size_of_val(refr) as wgpu::BufferAddress;

                    sizes.push(size);

                    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                        label: None,
                        size,
                        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });

                    staging_buffers.push(staging_buffer);

                    let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("Storage Buffer"),
                        contents: bytemuck::cast_slice(refr),
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_DST
                            | wgpu::BufferUsages::COPY_SRC,
                    });
                    storage_buffers.push(storage_buffer);

                    grouponized.get_mut(&$data.group).expect("ERROR : bind group entry missing").push(BufIndex {
                        index: sizes.len() - 1,
                        bind: $data.bind,
                    });
                )*

                // Build one bind group per group index and attach it to the pass.
                for group in grouponized.keys() {
                    let bind_group_layout = compute_pipeline.get_bind_group_layout(*group);

                    let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
                    let data = grouponized.get(group).expect("ERROR : bind group entry missing");
                    for group_entry in data {
                        entries.push(wgpu::BindGroupEntry {
                            binding: group_entry.bind,
                            resource: storage_buffers[group_entry.index].as_entire_binding(),
                        });
                    }
                    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                        label: None,
                        layout: &bind_group_layout,
                        entries: entries.as_slice(),
                    });

                    cpass.set_pipeline(&compute_pipeline);
                    cpass.set_bind_group(*group, &bind_group, &[]);
                }

                cpass.insert_debug_marker("debug_marker");
                cpass.dispatch_workgroups($kernel.x, $kernel.y, $kernel.z);
            }

            // Copy each storage buffer into its staging buffer so the CPU can map it.
            for (index, storage_buffer) in storage_buffers.iter().enumerate() {
                encoder.copy_buffer_to_buffer(storage_buffer, 0, &staging_buffers[index], 0, sizes[index]);
            }

            queue.submit(Some(encoder.finish()));

            // Map each staging buffer and copy the results back into `$data.data`.
            let mut index = 0;
            $(
                let buffer_slice = staging_buffers[index].slice(..);
                let (sender, receiver) = flume::bounded(1);
                buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

                // Block until the GPU has finished and the map request has resolved.
                device.poll(wgpu::Maintain::wait()).panic_on_timeout();

                if let Ok(Ok(())) = pollster::block_on(receiver.recv_async()) {
                    let data = buffer_slice.get_mapped_range();
                    let casted_data = bytemuck::cast_slice(&data).to_vec();

                    for (i, &value) in casted_data.iter().enumerate() {
                        $data.data[i] = value;
                    }

                    // Drop the mapped view before unmapping the buffer.
                    drop(data);
                    staging_buffers[index].unmap();
                } else {
                    panic!("failed to run compute on gpu!")
                }

                index += 1;
            )*
        }
    };
}

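/// Runs `$kernel` with a freshly created default `compute_config` and writes the GPU
/// results back into each `$data.data`.
///
/// A minimal usage sketch (assumes the crate is used under the name `core_compute`,
/// the data is a `Vec<f32>`, and the shader entry point is the default `main`):
///
/// ```ignore
/// use core_compute::{compute, compute_kernel, info};
///
/// let mut values = info {
///     bind: 0,
///     group: 0,
///     data: vec![1.0f32, 2.0, 3.0, 4.0],
/// };
///
/// let mut kernel = compute_kernel::new(
///     "@group(0) @binding(0) var<storage, read_write> v: array<f32>;
///      @compute @workgroup_size(1)
///      fn main(@builtin(global_invocation_id) id: vec3<u32>) {
///          v[id.x] = v[id.x] * 2.0;
///      }"
///     .to_string(),
/// );
/// kernel.x = 4; // one workgroup per element
///
/// compute!(kernel, values);
/// // values.data should now hold [2.0, 4.0, 6.0, 8.0]
/// ```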
#[macro_export]
macro_rules! compute {
    ($kernel:expr, $($data:expr),*) => {
        {
            let config = core_compute::compute_config::default();
            core_compute::compute_ext!(config, $kernel, $($data),*);
        }
    };
}