image_convolution/
gpu_device.rs1use std::borrow::Cow;
2use wgpu::util::DeviceExt;
3
4pub struct GpuDevice {
5 pub(crate) device: wgpu::Device,
6 pub(crate) queue: wgpu::Queue,
7}
8
9pub fn create_gpu_device() -> GpuDevice {
10 let (device, queue) = futures::executor::block_on(create_device_queue());
11 GpuDevice { device, queue }
12}
13
14pub async fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
15 let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
16 let adapter = instance
17 .request_adapter(&wgpu::RequestAdapterOptions {
18 power_preference: wgpu::PowerPreference::HighPerformance,
19 compatible_surface: None,
20 })
21 .await
22 .expect("Failed to find an appropriate adapter");
23
24 adapter
25 .request_device(
26 &wgpu::DeviceDescriptor {
27 label: None,
28 features: wgpu::Features::empty(),
29 limits: wgpu::Limits::default(),
30 },
31 None,
32 )
33 .await
34 .expect("Failed to create device")
35}
36
37impl GpuDevice {
38 pub fn create_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
39 self.device.create_buffer(&wgpu::BufferDescriptor {
40 label: Some(label),
41 size,
42 usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_SRC,
43 mapped_at_creation: false,
44 })
45 }
46
47 pub fn create_data_buffer(&self, label: &str, contents: &[u8]) -> wgpu::Buffer {
48 self.device
49 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
50 label: Some(label),
51 contents,
52 usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_SRC,
53 })
54 }
55
56 pub fn create_uniform_buffer(&self, label: &str, contents: &[u8]) -> wgpu::Buffer {
57 self.device
58 .create_buffer_init(&wgpu::util::BufferInitDescriptor {
59 label: Some(label),
60 contents,
61 usage: wgpu::BufferUsage::UNIFORM,
62 })
63 }
64
65 pub fn create_output_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
66 self.device.create_buffer(&wgpu::BufferDescriptor {
67 label: Some(label),
68 size,
69 usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
70 mapped_at_creation: false,
71 })
72 }
73
74 pub fn create_bind_group(
75 &self,
76 buffers: &[(&wgpu::Buffer, u64, wgpu::BufferBindingType)],
77 ) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
78 let layout_entries = buffers
79 .iter()
80 .enumerate()
81 .map(|(index, (_, size, ty))| wgpu::BindGroupLayoutEntry {
82 binding: index as u32,
83 visibility: wgpu::ShaderStage::COMPUTE,
84 ty: wgpu::BindingType::Buffer {
85 ty: *ty,
86 has_dynamic_offset: false,
87 min_binding_size: wgpu::BufferSize::new(*size),
88 },
89 count: None,
90 })
91 .collect::<Vec<_>>();
92 let bind_group_layout =
93 self.device
94 .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
95 label: None,
96 entries: &layout_entries,
97 });
98 let group_entries = buffers
99 .iter()
100 .enumerate()
101 .map(|(index, (buffer, _, _))| wgpu::BindGroupEntry {
102 binding: index as u32,
103 resource: buffer.as_entire_binding(),
104 })
105 .collect::<Vec<_>>();
106 let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
107 label: None,
108 layout: &bind_group_layout,
109 entries: &group_entries,
110 });
111 (bind_group_layout, bind_group)
112 }
113
114 pub fn create_compute_pipeline(
115 &self,
116 buffers: &[(&wgpu::Buffer, u64, wgpu::BufferBindingType)],
117 shader: &str,
118 ) -> (wgpu::BindGroup, wgpu::ComputePipeline) {
119 let (bind_group_layout, bind_group) = self.create_bind_group(buffers);
120
121 let cs_module = self
123 .device
124 .create_shader_module(&wgpu::ShaderModuleDescriptor {
125 label: None,
126 source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(shader)),
127 flags: wgpu::ShaderFlags::VALIDATION,
128 });
129
130 let pipeline_layout = self
132 .device
133 .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
134 label: None,
135 bind_group_layouts: &[&bind_group_layout],
136 push_constant_ranges: &[],
137 });
138 let compute_pipeline =
139 self.device
140 .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
141 label: None,
142 layout: Some(&pipeline_layout),
143 module: &cs_module,
144 entry_point: "main",
145 });
146
147 (bind_group, compute_pipeline)
148 }
149}