// image_convolution/gpu_device.rs

1use std::borrow::Cow;
2use wgpu::util::DeviceExt;
3
4pub struct GpuDevice {
5    pub(crate) device: wgpu::Device,
6    pub(crate) queue: wgpu::Queue,
7}
8
9pub fn create_gpu_device() -> GpuDevice {
10    let (device, queue) = futures::executor::block_on(create_device_queue());
11    GpuDevice { device, queue }
12}
13
14pub async fn create_device_queue() -> (wgpu::Device, wgpu::Queue) {
15    let instance = wgpu::Instance::new(wgpu::BackendBit::PRIMARY);
16    let adapter = instance
17        .request_adapter(&wgpu::RequestAdapterOptions {
18            power_preference: wgpu::PowerPreference::HighPerformance,
19            compatible_surface: None,
20        })
21        .await
22        .expect("Failed to find an appropriate adapter");
23
24    adapter
25        .request_device(
26            &wgpu::DeviceDescriptor {
27                label: None,
28                features: wgpu::Features::empty(),
29                limits: wgpu::Limits::default(),
30            },
31            None,
32        )
33        .await
34        .expect("Failed to create device")
35}
36
37impl GpuDevice {
38    pub fn create_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
39        self.device.create_buffer(&wgpu::BufferDescriptor {
40            label: Some(label),
41            size,
42            usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_SRC,
43            mapped_at_creation: false,
44        })
45    }
46
47    pub fn create_data_buffer(&self, label: &str, contents: &[u8]) -> wgpu::Buffer {
48        self.device
49            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
50                label: Some(label),
51                contents,
52                usage: wgpu::BufferUsage::STORAGE | wgpu::BufferUsage::COPY_SRC,
53            })
54    }
55
56    pub fn create_uniform_buffer(&self, label: &str, contents: &[u8]) -> wgpu::Buffer {
57        self.device
58            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
59                label: Some(label),
60                contents,
61                usage: wgpu::BufferUsage::UNIFORM,
62            })
63    }
64
65    pub fn create_output_buffer(&self, label: &str, size: u64) -> wgpu::Buffer {
66        self.device.create_buffer(&wgpu::BufferDescriptor {
67            label: Some(label),
68            size,
69            usage: wgpu::BufferUsage::MAP_READ | wgpu::BufferUsage::COPY_DST,
70            mapped_at_creation: false,
71        })
72    }
73
74    pub fn create_bind_group(
75        &self,
76        buffers: &[(&wgpu::Buffer, u64, wgpu::BufferBindingType)],
77    ) -> (wgpu::BindGroupLayout, wgpu::BindGroup) {
78        let layout_entries = buffers
79            .iter()
80            .enumerate()
81            .map(|(index, (_, size, ty))| wgpu::BindGroupLayoutEntry {
82                binding: index as u32,
83                visibility: wgpu::ShaderStage::COMPUTE,
84                ty: wgpu::BindingType::Buffer {
85                    ty: *ty,
86                    has_dynamic_offset: false,
87                    min_binding_size: wgpu::BufferSize::new(*size),
88                },
89                count: None,
90            })
91            .collect::<Vec<_>>();
92        let bind_group_layout =
93            self.device
94                .create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
95                    label: None,
96                    entries: &layout_entries,
97                });
98        let group_entries = buffers
99            .iter()
100            .enumerate()
101            .map(|(index, (buffer, _, _))| wgpu::BindGroupEntry {
102                binding: index as u32,
103                resource: buffer.as_entire_binding(),
104            })
105            .collect::<Vec<_>>();
106        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
107            label: None,
108            layout: &bind_group_layout,
109            entries: &group_entries,
110        });
111        (bind_group_layout, bind_group)
112    }
113
114    pub fn create_compute_pipeline(
115        &self,
116        buffers: &[(&wgpu::Buffer, u64, wgpu::BufferBindingType)],
117        shader: &str,
118    ) -> (wgpu::BindGroup, wgpu::ComputePipeline) {
119        let (bind_group_layout, bind_group) = self.create_bind_group(buffers);
120
121        // create shader module
122        let cs_module = self
123            .device
124            .create_shader_module(&wgpu::ShaderModuleDescriptor {
125                label: None,
126                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(shader)),
127                flags: wgpu::ShaderFlags::VALIDATION,
128            });
129
130        // create pipeline for shader
131        let pipeline_layout = self
132            .device
133            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
134                label: None,
135                bind_group_layouts: &[&bind_group_layout],
136                push_constant_ranges: &[],
137            });
138        let compute_pipeline =
139            self.device
140                .create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
141                    label: None,
142                    layout: Some(&pipeline_layout),
143                    module: &cs_module,
144                    entry_point: "main",
145                });
146
147        (bind_group, compute_pipeline)
148    }
149}