/// Describes one compute dispatch: the WGSL shader source plus the
/// workgroup counts later passed to `dispatch_workgroups(x, y, z)`.
#[derive(Debug, Clone)]
pub struct compute_kernel {
    /// Workgroup count along the x axis.
    pub x: u32,
    /// Workgroup count along the y axis.
    pub y: u32,
    /// Workgroup count along the z axis.
    pub z: u32,
    /// WGSL shader source code.
    pub code: String,
}
impl compute_kernel {
    /// Creates a kernel that dispatches a single workgroup
    /// (`x = y = z = 1`); adjust the public fields afterwards for
    /// larger dispatches.
    ///
    /// Made `pub`: the constructor was private on a public type, so it
    /// was unusable outside this module. Widening visibility is
    /// backward-compatible for existing callers.
    pub fn new(code: String) -> Self {
        compute_kernel {
            x: 1,
            y: 1,
            z: 1,
            // Field-init shorthand replaces the redundant `code : code`.
            code,
        }
    }
}
/// Pairs a payload with the shader slot it is exposed through, i.e.
/// `@group(group) @binding(bind)` on the WGSL side.
#[derive(Debug, Clone)]
pub struct info<T> {
    /// Binding index inside the bind group.
    pub bind: u32,
    /// Bind-group index.
    pub group: u32,
    /// The data uploaded to (and read back from) the GPU buffer.
    pub data: T,
}
/// Owns the wgpu objects needed to run a compute pass.
///
/// Built via `Default`; consumed by the `compute_ext!` macro, which
/// moves these fields out (hence one config per invocation).
#[derive(Debug)]
pub struct compute_config {
    /// The wgpu instance the adapter was requested from.
    pub _wgpu_instance: wgpu::Instance,
    /// The physical adapter backing the device.
    pub _wgpu_adapter: wgpu::Adapter,
    /// Queue used to submit command buffers.
    pub _wgpu_queue: wgpu::Queue,
    /// Logical device used to create pipelines and buffers.
    pub _wgpu_device: wgpu::Device,
    /// Name of the shader entry point (defaults to `"main"`).
    pub _entry_point: String,
}
impl Default for compute_config {
    /// Builds a ready-to-use GPU context: default instance, the first
    /// adapter wgpu offers, and a device/queue pair with downlevel
    /// limits. Blocks on the async wgpu requests via `pollster`.
    ///
    /// # Panics
    /// Panics if no adapter is available or the adapter cannot
    /// provide a matching device.
    fn default() -> Self {
        let wgpu_instance = wgpu::Instance::default();
        let wgpu_adapter = pollster::block_on(
            wgpu_instance.request_adapter(&wgpu::RequestAdapterOptions::default()),
        )
        .expect("ERROR : failed to get adapter");
        // Downlevel limits keep the device usable on weaker backends.
        let descriptor = wgpu::DeviceDescriptor {
            label: None,
            required_features: wgpu::Features::empty(),
            required_limits: wgpu::Limits::downlevel_defaults(),
            memory_hints: wgpu::MemoryHints::MemoryUsage,
        };
        let (wgpu_device, wgpu_queue) =
            pollster::block_on(wgpu_adapter.request_device(&descriptor, None))
                .expect("ERROR : Adapter could not find the device");
        Self {
            _wgpu_instance: wgpu_instance,
            _wgpu_adapter: wgpu_adapter,
            _wgpu_queue: wgpu_queue,
            _wgpu_device: wgpu_device,
            _entry_point: String::from("main"),
        }
    }
}
/// Executes `$kernel` on the device owned by `$config`.
///
/// Each `$data` must expose `bind: u32`, `group: u32`, and `data` (a
/// contiguous container with `as_slice()` and index assignment, e.g.
/// `Vec<T>` with `T: bytemuck::Pod`). Every `$data.data` is uploaded
/// to a storage buffer at `@group($data.group) @binding($data.bind)`,
/// the kernel is dispatched with `$kernel`'s workgroup counts, and the
/// resulting buffer contents are written back into `$data.data`.
///
/// Note: the fields of `$config` are moved out, so a config serves
/// exactly one invocation.
#[macro_export]
macro_rules! compute_ext {
    ($config:expr , $kernel:expr, $($data:expr),*) => {
        {
            use wgpu::util::DeviceExt;
            use std::collections::HashMap;
            // Underscore-prefixed: unused, but deliberately kept alive
            // for the whole block instead of being dropped immediately.
            let _instance = $config._wgpu_instance;
            let _adapter = $config._wgpu_adapter;
            let device = $config._wgpu_device;
            let queue = $config._wgpu_queue;
            let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some("Shader"),
                source: wgpu::ShaderSource::Wgsl($kernel.code.into()),
            });
            // `layout: None` derives the pipeline layout from the shader,
            // which is what allows `get_bind_group_layout` below.
            let compute_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: None,
                layout: None,
                module: &shader,
                entry_point: &$config._entry_point,
                compilation_options: Default::default(),
                cache: None,
            });
            let mut encoder =
                device.create_command_encoder(&wgpu::CommandEncoderDescriptor { label: None });
            // Per-argument bookkeeping; all three vectors are indexed in
            // the order the `$data` arguments appear in the invocation.
            let mut staging_buffers: Vec<wgpu::Buffer> = Vec::new();
            let mut sizes: Vec<wgpu::BufferAddress> = Vec::new();
            let mut storage_buffers: Vec<wgpu::Buffer> = Vec::new();
            // Position of a buffer inside the vectors above, plus its
            // binding index inside its bind group.
            #[derive(Debug)]
            struct buf_index {
                index: usize,
                bind: u32,
            }
            // bind-group index -> buffers that belong in that group.
            let mut grouponized: HashMap<u32, Vec<buf_index>> = HashMap::new();
            {
                let mut cpass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: None,
                    timestamp_writes: None,
                });
                $(
                    // Entry API replaces the contains_key/insert pair
                    // (single lookup instead of two).
                    grouponized.entry($data.group).or_insert_with(Vec::new);
                    let refr = $data.data.as_slice();
                    let size = std::mem::size_of_val(refr) as wgpu::BufferAddress;
                    sizes.push(size);
                    // CPU-visible buffer used to read results back after
                    // the dispatch completes.
                    let staging_buffer = device.create_buffer(&wgpu::BufferDescriptor {
                        label: None,
                        size,
                        usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
                        mapped_at_creation: false,
                    });
                    staging_buffers.push(staging_buffer);
                    // GPU-side buffer initialised with the caller's data.
                    let storage_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
                        label: Some("Storage Buffer"),
                        contents: bytemuck::cast_slice(refr),
                        usage: wgpu::BufferUsages::STORAGE
                            | wgpu::BufferUsages::COPY_DST
                            | wgpu::BufferUsages::COPY_SRC,
                    });
                    storage_buffers.push(storage_buffer);
                    grouponized.get_mut(&$data.group).expect("ERROR : smth went wrong !").push(buf_index {
                        index: sizes.len() - 1,
                        bind: $data.bind
                    });
                )*
                for group in grouponized.keys() {
                    // `group` is `&u32`; dereference instead of `.clone()`.
                    let bind_group_layout = compute_pipeline.get_bind_group_layout(*group);
                    let mut entries: Vec<wgpu::BindGroupEntry> = Vec::new();
                    // FIX: the original `get(&group)` passed `&&u32`,
                    // which fails the `u32: Borrow<&u32>` bound at macro
                    // expansion; `group` is already the needed `&u32`.
                    let data = grouponized.get(group).expect("ERROR : smth went wrong !");
                    for group_entry in data {
                        entries.push(wgpu::BindGroupEntry {
                            binding: group_entry.bind,
                            resource: storage_buffers[group_entry.index].as_entire_binding(),
                        });
                    }
                    let bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
                        label: None,
                        layout: &bind_group_layout,
                        entries: entries.as_slice(),
                    });
                    cpass.set_pipeline(&compute_pipeline);
                    cpass.set_bind_group(*group, &bind_group, &[]);
                }
                cpass.insert_debug_marker("debug_marker");
                cpass.dispatch_workgroups($kernel.x, $kernel.y, $kernel.z);
            }
            // Queue a copy of each result into its CPU-visible twin.
            for (index, storage_buffer) in storage_buffers.iter().enumerate() {
                encoder.copy_buffer_to_buffer(&storage_buffer, 0, &staging_buffers[index], 0, sizes[index]);
            }
            queue.submit(Some(encoder.finish()));
            let mut index = 0;
            $(
                // Map each staging buffer and copy its bytes back into
                // the caller-supplied container, in invocation order.
                let buffer_slice = staging_buffers[index].slice(..);
                let (sender, receiver) = flume::bounded(1);
                buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());
                device.poll(wgpu::Maintain::wait()).panic_on_timeout();
                if let Ok(Ok(())) = pollster::block_on(receiver.recv_async()) {
                    let data = buffer_slice.get_mapped_range();
                    let casted_data = bytemuck::cast_slice(&data).to_vec();
                    for (i, &value) in casted_data.iter().enumerate() {
                        $data.data[i] = value;
                    }
                    // The mapped view must be dropped before unmapping.
                    drop(data);
                    staging_buffers[index].unmap();
                } else {
                    panic!("failed to run compute on gpu!")
                }
                index += 1;
            )*
        }
    };
}
#[macro_export]
/// One-shot convenience wrapper around `compute_ext!`: builds a
/// `compute_config` via `Default` (default instance/adapter/device)
/// and immediately uploads each `$data`, dispatches `$kernel`, and
/// writes results back into `$data.data`.
///
/// NOTE(review): the hard-coded `core_compute::` paths assume this
/// crate is visible as `core_compute` at the call site — confirm the
/// crate name if this is ever renamed or re-exported.
macro_rules! compute {
($kernel:expr, $($data:expr),*) => {
{
// A fresh config per call: `compute_ext!` consumes its fields.
let config = core_compute::compute_config::default();
core_compute::compute_ext!(config , $kernel, $($data),*);
}
};
}