use wgpu::{BackendOptions, Backends, InstanceFlags, Limits};
use crate::prelude::GpuBuffers;
#[derive(Debug, PartialEq, Eq)]
pub enum MemoryMetric{
GB,
MB,
KB,
B,
}
pub fn get_size_using_metric(size: u64, metric: &MemoryMetric) -> u64{
if metric == &MemoryMetric::GB{
return (size * 1024 * 1024 * 1024);
}
else if metric == &MemoryMetric::MB{
return (size * 1024 * 1024);
}
else if metric == &MemoryMetric::KB{
return (size * 1024);
}
else if metric == &MemoryMetric::B{
return size;
}
0
}
pub async fn gpu_init(max_buffer_size: u64, metric: &MemoryMetric) -> (wgpu::Device, wgpu::Queue){
let mut real_buffer_size: u64;
real_buffer_size = get_size_using_metric(max_buffer_size, metric);
if real_buffer_size > 2_147_483_648{
panic!("Buffer size too big");
}
if real_buffer_size == 0{
real_buffer_size = 1024*1024*256;
}
if real_buffer_size == 2_147_483_648{
real_buffer_size = 2_147_483_647;
}
let limits = Limits{
max_buffer_size: real_buffer_size,
max_storage_buffer_binding_size: real_buffer_size as u32,
..Limits::downlevel_defaults()
};
let instance = wgpu::Instance::new(&wgpu::InstanceDescriptor{
backends: Backends::PRIMARY,
flags: InstanceFlags::default(),
backend_options: BackendOptions::default(),
});
let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions::default())
.await.expect("No adapter found");
let device_descriptor = wgpu::DeviceDescriptor{
label: Some("New Device"),
required_features: wgpu::Features::empty(),
required_limits: limits,
memory_hints: wgpu::MemoryHints::Performance,
trace: wgpu::Trace::Off,
};
adapter.request_device(&device_descriptor)
.await.expect("No device")
}
pub fn get_bind_group_layout(buffers: &GpuBuffers) -> wgpu::BindGroupLayout{
let mut bind_group_layout_entries = vec!{
wgpu::BindGroupLayoutEntry{
binding: 0,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None
},
count: None,
},
wgpu::BindGroupLayoutEntry{
binding: 3,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: false },
has_dynamic_offset: false,
min_binding_size: None
},
count: None,
}
};
if buffers.shapes_buffer.is_some(){
bind_group_layout_entries.push(
wgpu::BindGroupLayoutEntry{
binding: 1,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Storage { read_only: true },
has_dynamic_offset: false,
min_binding_size: None
},
count: None,
},
);
}
if buffers.params_buffer.is_some(){
bind_group_layout_entries.push(
wgpu::BindGroupLayoutEntry{
binding: 2,
visibility: wgpu::ShaderStages::COMPUTE,
ty: wgpu::BindingType::Buffer {
ty: wgpu::BufferBindingType::Uniform,
has_dynamic_offset: false,
min_binding_size: None
},
count: None,
},
);
}
let bind_group_layout = buffers.device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor{
label: Some("Bing group layout"),
entries: &bind_group_layout_entries,
});
bind_group_layout
}
pub fn get_bind_group(buffers: &GpuBuffers) -> wgpu::BindGroup{
let mut bind_group_entries = vec!{
wgpu::BindGroupEntry{
binding: 0,
resource: buffers.inputs_buffer.as_entire_binding(),
},
wgpu::BindGroupEntry{
binding: 3,
resource: buffers.output_buffer.as_entire_binding(),
}
};
if buffers.shapes_buffer.is_some(){
bind_group_entries.push(
wgpu::BindGroupEntry{
binding: 1,
resource: buffers.shapes_buffer.as_ref().unwrap().as_entire_binding(),
}
);
}
if buffers.params_buffer.is_some(){
bind_group_entries.push(
wgpu::BindGroupEntry{
binding: 2,
resource: buffers.params_buffer.as_ref().unwrap().as_entire_binding(),
}
);
}
let bind_group = buffers.device.create_bind_group(&wgpu::BindGroupDescriptor{
label: Some("Bind group"),
layout: buffers.bind_group_layout.as_ref().unwrap(),
entries: &bind_group_entries,
});
bind_group
}
pub fn get_pipeline_layout(device: &wgpu::Device, bind_group_layout: &wgpu::BindGroupLayout) -> wgpu::PipelineLayout{
let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor{
label: Some("Pipeline layout"),
bind_group_layouts: &[&bind_group_layout],
push_constant_ranges: &[],
});
pipeline_layout
}
pub fn get_pipeline(device: &wgpu::Device, shader: &wgpu::ShaderModule, pipeline_layout: &wgpu::PipelineLayout) -> wgpu::ComputePipeline{
let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor{
label: Some("Compute pipeline"),
layout: Some(&pipeline_layout),
module: &shader,
entry_point: Some("main"),
compilation_options: wgpu::PipelineCompilationOptions::default(),
cache: None,
});
pipeline
}
pub async fn dispatch_and_receive(device: &wgpu::Device, pipeline: &wgpu::ComputePipeline, bind_group: &wgpu::BindGroup, queue: &wgpu::Queue, input_data_len: usize, output_buffer: &wgpu::Buffer, output_len: usize) -> Vec<f32>{
let mut encoder = device.create_command_encoder(&wgpu::CommandEncoderDescriptor {
label: Some("Encoder"),
});
{
let workgroup_size = 64;
let total_invocations = output_len as u32;
let total_workgroups = (total_invocations + workgroup_size - 1) / workgroup_size;
let x = total_workgroups.min(65535);
let y = ((total_workgroups / 65535) + 1).min(65535);
let z = 1;
let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor{
label: Some("Compute pass"),
timestamp_writes: None,
});
compute_pass.set_pipeline(pipeline);
compute_pass.set_bind_group(0, bind_group, &[]);
compute_pass.dispatch_workgroups(x, y.max(1), z);
}
let staging = device.create_buffer(&wgpu::BufferDescriptor {
label: Some("Staging"),
size: output_buffer.size(),
usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
mapped_at_creation: false,
});
encoder.copy_buffer_to_buffer(output_buffer, 0, &staging, 0, staging.size());
queue.submit(Some(encoder.finish()));
let slice = staging.slice(..);
slice.map_async(wgpu::MapMode::Read, |_| {});
device.poll(wgpu::MaintainBase::Wait);
let data = slice.get_mapped_range();
let result: Vec<f32> = bytemuck::cast_slice(&data).to_vec();
result
}