agpu 0.1.2

Abstract GPU Project
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
use tracing::warn;
pub use wgpu::CompareFunction;

use std::borrow::Cow;

use wgpu::ShaderModuleDescriptor;
use wgpu::ShaderSource;

use crate::Gpu;
use crate::GpuError;

use crate::RenderPipeline;

/// Chainable helpers for configuring a colour target's blending and write mask.
///
/// Each method takes `self` by value and returns `Self`, so calls can be chained
/// builder-style, e.g. `target.blend_add().write_mask(0b0111)`.
pub trait ColorTargetBuilderExt {
    /// Configure straight ("over") alpha blending.
    fn blend_over(self) -> Self;

    /// Configure "over" blending for colours with premultiplied alpha.
    fn blend_over_premult(self) -> Self;

    /// Configure additive blending.
    fn blend_add(self) -> Self;

    /// Configure subtractive blending.
    fn blend_subtract(self) -> Self;

    /// Set which colour channels are written, given as raw mask bits.
    fn write_mask(self, mask: u32) -> Self;
}
impl ColorTargetBuilderExt for wgpu::ColorTargetState {
    /// Straight alpha blending (`wgpu::BlendState::ALPHA_BLENDING`).
    fn blend_over(mut self) -> Self {
        self.blend = Some(wgpu::BlendState::ALPHA_BLENDING);
        self
    }

    /// Premultiplied alpha blending (`wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING`).
    fn blend_over_premult(mut self) -> Self {
        self.blend = Some(wgpu::BlendState::PREMULTIPLIED_ALPHA_BLENDING);
        self
    }

    /// Additive blending of colour: `out = src * src_alpha + dst`.
    ///
    /// The alpha channel uses the regular "over" component.
    fn blend_add(mut self) -> Self {
        self.blend = Some(wgpu::BlendState {
            color: wgpu::BlendComponent {
                src_factor: wgpu::BlendFactor::SrcAlpha,
                // Keep the destination at full weight so the source is *added* to it.
                dst_factor: wgpu::BlendFactor::One,
                operation: wgpu::BlendOperation::Add,
            },
            alpha: wgpu::BlendComponent::OVER,
        });
        self
    }

    /// Subtractive blending of colour: `out = dst - src * src_alpha`.
    ///
    /// The alpha channel uses the regular "over" component.
    fn blend_subtract(mut self) -> Self {
        self.blend = Some(wgpu::BlendState {
            color: wgpu::BlendComponent {
                src_factor: wgpu::BlendFactor::SrcAlpha,
                // Keep the destination at full weight and subtract the weighted source
                // from it (ReverseSubtract computes dst - src).
                dst_factor: wgpu::BlendFactor::One,
                operation: wgpu::BlendOperation::ReverseSubtract,
            },
            alpha: wgpu::BlendComponent::OVER,
        });
        self
    }

    /// Set the channel write mask from raw [`wgpu::ColorWrites`] bits.
    ///
    /// Bits that do not correspond to a colour channel are ignored instead of
    /// causing a panic.
    fn write_mask(mut self, mask: u32) -> Self {
        self.write_mask = wgpu::ColorWrites::from_bits_truncate(mask);
        self
    }
}

/// Builder for a [`RenderPipeline`].
///
/// Start with [`PipelineBuilder::new`], chain the configuration methods, then call
/// [`PipelineBuilder::create`].
pub struct PipelineBuilder<'a> {
    /// Handle to the Gpu
    gpu: Gpu,
    /// Debug label of the pipeline; also used as the prefix of the pipeline layout's label.
    label: Option<&'a str>,
    /// Data that is used to build the pipeline.
    /// This is a separate struct to take advantage of `Default` trait derivation.
    desc: PipelineDescriptor<'a>,

    /// Descriptor (label + SPIR-V or WGSL source) of the vertex shader module.
    vertex: ShaderModuleDescriptor<'a>,
    /// Descriptor (label + SPIR-V or WGSL source) of the fragment shader module.
    /// `None` means the pipeline is created without a fragment stage.
    fragment: Option<ShaderModuleDescriptor<'a>>,
    /// Name of the entry-point function in the vertex shader.
    vertex_entry: &'a str,
    /// Name of the entry-point function in the fragment shader.
    fragment_entry: &'a str,
    /// Colour targets written by the fragment stage.
    fragment_targets: &'a [wgpu::ColorTargetState],
}

/// Pipeline state collected by [`PipelineBuilder`]; every field starts at its `Default`.
#[derive(Default)]
struct PipelineDescriptor<'a> {
    // PIPELINE LAYOUT
    /// Bind groups that this pipeline uses. The first entry will provide all the bindings for
    /// "set = 0", second entry will provide all the bindings for "set = 1" etc.
    bind_group_layouts: &'a [&'a wgpu::BindGroupLayout],
    /// Set of push constant ranges this pipeline uses. Each shader stage that uses push constants
    /// must define the range in push constant memory that corresponds to its single
    /// `layout(push_constant)` uniform block.
    /// Requires [`wgpu::Features::PUSH_CONSTANTS`].
    /// NOTE: no builder method sets this yet, so it is currently always empty.
    push_constant_ranges: &'a [wgpu::PushConstantRange],
    // RENDER PIPELINE
    /// Primitive type the input mesh is composed of, plus culling and polygon mode.
    primitive: wgpu::PrimitiveState,
    /// Describes the depth/stencil state in a render pipeline. `None` disables depth testing.
    depth_stencil: Option<wgpu::DepthStencilState>,
    /// Multisampling state. NOTE: no builder method sets this yet, so it stays at its default.
    multisample: wgpu::MultisampleState,
    /// Layouts of the vertex buffers consumed by the vertex stage, one per buffer slot.
    vertex_layouts: &'a [wgpu::VertexBufferLayout<'a>],
}
impl PipelineBuilder<'_> {
    /// Returns `true` if `bytes` looks like a binary SPIR-V module.
    ///
    /// Mirrors the assertions in `wgpu::util::make_spirv`: the length must be a non-zero
    /// multiple of 4 and the first word, read in native byte order, must be the SPIR-V magic
    /// number. Bytes that pass can be handed to `make_spirv` without it panicking.
    /// NOTE(review): re-check these conditions against wgpu when upgrading it.
    fn is_spirv(bytes: &[u8]) -> bool {
        const SPIRV_MAGIC: u32 = 0x0723_0203;
        const WORD: usize = std::mem::size_of::<u32>();

        bytes.len() >= WORD
            && bytes.len() % WORD == 0
            && u32::from_ne_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]) == SPIRV_MAGIC
    }

    /// Interpret `bytes` as a binary SPIR-V module, borrowing them where possible.
    ///
    /// The input is validated up front instead of letting `wgpu::util::make_spirv` panic.
    /// This avoids catching the panic, which required swapping the process-global panic
    /// hook and so could hide panics from other threads.
    ///
    /// # Errors
    /// Returns [`GpuError::ShaderParseError`] if `bytes` is empty, is not a multiple of four
    /// bytes long, or does not start with the SPIR-V magic number.
    pub fn make_spirv(bytes: &[u8]) -> Result<ShaderSource, GpuError> {
        if Self::is_spirv(bytes) {
            Ok(wgpu::util::make_spirv(bytes))
        } else {
            Err(GpuError::ShaderParseError)
        }
    }

    /// Interpret an owned byte buffer as a binary SPIR-V module.
    ///
    /// The bytes are copied into properly aligned `u32` words in native byte order, the same
    /// interpretation `wgpu::util::make_spirv` uses. No `unsafe` reinterpretation of the `u8`
    /// allocation is done, since that allocation is not guaranteed to be aligned for `u32`.
    ///
    /// # Errors
    /// Returns [`GpuError::ShaderParseError`] under the same conditions as
    /// [`Self::make_spirv`].
    pub fn make_spirv_owned<'f>(vec8: Vec<u8>) -> Result<ShaderSource<'f>, GpuError> {
        if !Self::is_spirv(&vec8) {
            return Err(GpuError::ShaderParseError);
        }
        let words: Vec<u32> = vec8
            .chunks_exact(std::mem::size_of::<u32>())
            .map(|w| u32::from_ne_bytes([w[0], w[1], w[2], w[3]]))
            .collect();
        Ok(ShaderSource::SpirV(Cow::Owned(words)))
    }

    /// Wrap borrowed WGSL source text. Never fails; returns `Result` for API symmetry.
    pub fn make_wgsl(wgsl: &str) -> Result<ShaderSource, GpuError> {
        Ok(ShaderSource::Wgsl(Cow::Borrowed(wgsl)))
    }

    /// Wrap owned WGSL source text. Never fails; returns `Result` for API symmetry.
    pub fn make_wgsl_owned<'f>(wgsl: String) -> Result<ShaderSource<'f>, GpuError> {
        Ok(ShaderSource::Wgsl(Cow::Owned(wgsl)))
    }

    /// Load a shader from `path`, auto-detecting binary SPIR-V or WGSL text.
    ///
    /// The file is read once. If its contents look like SPIR-V they are used as such;
    /// otherwise they must be valid UTF-8 and are treated as WGSL.
    ///
    /// 'a: lifetime of the shader source
    /// 'b: lifetime of the input path
    ///
    /// # Errors
    /// Returns [`GpuError::ShaderParseError`] if the file cannot be read or is neither
    /// SPIR-V nor UTF-8 text.
    /// NOTE(review): the I/O error detail is dropped; consider a dedicated `GpuError` variant.
    pub fn shader_auto_load<'a, 'b>(path: &'b str) -> Result<ShaderSource<'a>, GpuError> {
        let bytes = std::fs::read(path).map_err(|_| GpuError::ShaderParseError)?;
        if Self::is_spirv(&bytes) {
            Self::make_spirv_owned(bytes)
        } else {
            let text = String::from_utf8(bytes).map_err(|_| GpuError::ShaderParseError)?;
            Self::make_wgsl_owned(text)
        }
    }

    /// Interpret `bytes` as a shader, auto-detecting binary SPIR-V or WGSL text.
    ///
    /// # Errors
    /// Returns [`GpuError::ShaderParseError`] if `bytes` is neither SPIR-V nor UTF-8 text.
    pub fn shader_auto(bytes: &[u8]) -> Result<ShaderSource, GpuError> {
        if Self::is_spirv(bytes) {
            Self::make_spirv(bytes)
        } else {
            Self::make_wgsl(Self::str_from_bytes(bytes)?)
        }
    }
}
impl<'a> PipelineBuilder<'a> {
    /// Create a builder that uses the crate's bundled default shaders.
    ///
    /// Defaults: the `screen.vert.spv` vertex shader and `uv.frag.spv` fragment shader
    /// (embedded with `include_bytes!`, both with entry point `main`), and a single
    /// `Bgra8UnormSrgb` colour target with alpha blending that writes all channels.
    /// No depth/stencil state is set.
    pub fn new(gpu: Gpu, label: &'a str) -> Self {
        const DEFAULT_FRAGMENT_FORMAT: wgpu::TextureFormat = wgpu::TextureFormat::Bgra8UnormSrgb;

        let vertex = ShaderModuleDescriptor {
            label: Some("Default vertex shader"),
            source: wgpu::util::make_spirv(include_bytes!("../../shader/screen.vert.spv")),
        };
        let fragment = Some(ShaderModuleDescriptor {
            label: Some("Default fragment shader"),
            source: wgpu::util::make_spirv(include_bytes!("../../shader/uv.frag.spv")),
        });

        Self {
            gpu,
            label: Some(label),
            desc: PipelineDescriptor::default(),
            vertex,
            fragment,
            vertex_entry: "main",
            fragment_entry: "main",
            fragment_targets: &[wgpu::ColorTargetState {
                // TODO: Use gpu.preferred_format
                format: DEFAULT_FRAGMENT_FORMAT,
                blend: Some(wgpu::BlendState::ALPHA_BLENDING),
                write_mask: wgpu::ColorWrites::ALL,
            }],
        }
    }

    /// Set the label
    pub fn with_label(mut self, label: &'a str) -> Self {
        self.label = Some(label);
        self
    }

    /// Set the vertex buffer layouts
    pub fn with_vertex_layouts(mut self, layouts: &'a [wgpu::VertexBufferLayout<'a>]) -> Self {
        self.desc.vertex_layouts = layouts;
        self
    }

    /// Set the colour targets written by the fragment stage.
    pub fn with_fragment_targets(mut self, targets: &'a [wgpu::ColorTargetState]) -> Self {
        self.fragment_targets = targets;
        self
    }

    /// Declare a depth state for the pipeline: `Depth32Float` format, depth writes
    /// enabled, `Less` comparison, no stencil and no bias.
    ///
    /// MUST be called if the pipeline is used in a render pass with a depth attachment.
    pub fn with_depth(mut self) -> Self {
        self.desc.depth_stencil = Some(wgpu::DepthStencilState {
            depth_write_enabled: true,
            depth_compare: wgpu::CompareFunction::Less,
            stencil: wgpu::StencilState::default(),
            format: wgpu::TextureFormat::Depth32Float,
            bias: wgpu::DepthBiasState::default(),
        });
        self
    }

    /// Apply `op` to the depth/stencil state, or log a warning if none has been declared.
    fn do_depth<F>(&mut self, op: F)
    where
        F: FnOnce(&mut wgpu::DepthStencilState),
    {
        if let Some(desc) = self.desc.depth_stencil.as_mut() {
            op(desc);
        } else {
            warn!("Depth mod was called before with_depth() was called in pipeline builder");
        }
    }

    /// Set the depth bias (`wgpu::DepthBiasState`).
    ///
    /// `constant` is a constant bias in basic units of the depth format. `slope` is a
    /// slope-scaled factor, scaled by how steeply the polygon's depth changes.
    ///
    /// Only takes effect after [`Self::with_depth`] or [`Self::with_depth_stencil`];
    /// otherwise a warning is logged and nothing changes.
    pub fn depth_bias(mut self, constant: i32, slope: f32) -> Self {
        self.do_depth(|desc| {
            desc.bias.constant = constant;
            desc.bias.slope_scale = slope;
        });
        self
    }

    /// Set the depth bias clamp value (absolute).
    ///
    /// Only takes effect after a depth state was declared; otherwise a warning is logged.
    pub fn depth_bias_clamp(mut self, clamp: f32) -> Self {
        self.do_depth(|desc| {
            desc.bias.clamp = clamp;
        });
        self
    }

    /// Set the depth comparison function.
    /// Values testing `true` will pass the depth test.
    ///
    /// Only takes effect after a depth state was declared; otherwise a warning is logged.
    pub fn depth_compare(mut self, compare: CompareFunction) -> Self {
        self.do_depth(|desc| {
            desc.depth_compare = compare;
        });
        self
    }

    /// Declare a depth/stencil state using the `Depth24PlusStencil8` format, with depth
    /// writes enabled and `Less` comparison.
    ///
    /// TODO: the stencil state is left at its default, so stencil operations are not
    /// configurable yet.
    pub fn with_depth_stencil(mut self) -> Self {
        self.desc.depth_stencil = Some(wgpu::DepthStencilState {
            depth_write_enabled: true,
            depth_compare: wgpu::CompareFunction::Less,
            stencil: wgpu::StencilState::default(),
            format: wgpu::TextureFormat::Depth24PlusStencil8,
            bias: wgpu::DepthBiasState::default(),
        });
        self
    }

    /// Decode `bytes` as UTF-8 text, mapping failure to `GpuError::ShaderParseError`.
    fn str_from_bytes(bytes: &[u8]) -> Result<&str, GpuError> {
        std::str::from_utf8(bytes).map_err(|_| GpuError::ShaderParseError)
    }

    /// Load the vertex shader (SPIR-V or WGSL) from a file path at runtime.
    /// See `with_vertex()` for loading static bytes.
    ///
    /// # Panics
    /// Panics if the file cannot be loaded or parsed.
    pub fn load_vertex(mut self, path: &'a str) -> Self {
        self.vertex = ShaderModuleDescriptor {
            label: Some("Vertex shader"),
            source: Self::shader_auto_load(path)
                .unwrap_or_else(|err| panic!("Load vertex shader `{}`: {:?}", path, err)),
        };
        self
    }

    /// Load the vertex shader (SPIR-V or WGSL) from bytes.
    /// This is convenient for static bytes. If you want to load from a file at
    /// runtime, see `load_vertex()`.
    ///
    /// # Panics
    /// Panics if `bytes` is neither SPIR-V nor UTF-8 text. WGSL is not validated here.
    pub fn with_vertex(mut self, bytes: &'a [u8]) -> Self {
        self.vertex = ShaderModuleDescriptor {
            label: Some("Vertex shader"),
            source: Self::shader_auto(bytes).expect("Parse vertex shader"),
        };
        self
    }

    /// Load the fragment shader (SPIR-V or WGSL) from bytes.
    /// This is convenient for static bytes. If you want to load from a file at
    /// runtime, see `load_fragment()`.
    ///
    /// # Panics
    /// Panics if `bytes` is neither SPIR-V nor UTF-8 text. WGSL is not validated here.
    pub fn with_fragment(mut self, bytes: &'a [u8]) -> Self {
        self.fragment = Some(ShaderModuleDescriptor {
            label: Some("Fragment shader"),
            source: Self::shader_auto(bytes).expect("Parse fragment shader"),
        });
        self
    }

    /// Set the name of the fragment shader's entry-point function.
    pub const fn with_fragment_entry(mut self, entry: &'a str) -> Self {
        self.fragment_entry = entry;
        self
    }

    /// Set the name of the vertex shader's entry-point function.
    pub const fn with_vertex_entry(mut self, entry: &'a str) -> Self {
        self.vertex_entry = entry;
        self
    }

    /// Convenience method for `with_vertex()` + `with_fragment()` on a single source that
    /// contains both stages.
    /// This also sets the entry points to `vs_main` and `fs_main` respectively.
    ///
    /// # Panics
    /// Panics if `bytes` is neither SPIR-V nor UTF-8 text.
    pub fn with_vertex_fragment(mut self, bytes: &'a [u8]) -> Self {
        self.vertex_entry = "vs_main";
        self.fragment_entry = "fs_main";
        self.with_vertex(bytes).with_fragment(bytes)
    }

    /// Optional version of `with_fragment()`, for use in macros.
    /// This has no effect if `None` is provided. To remove the fragment shader,
    /// use [`Self::no_fragment`] instead.
    pub fn with_fragment_opt(self, fragment_bytes: Option<&'a [u8]>) -> Self {
        match fragment_bytes {
            Some(bytes) => self.with_fragment(bytes),
            None => self,
        }
    }

    /// Load the fragment shader (SPIR-V or WGSL) from a file path at runtime.
    /// See `with_fragment()` for loading static bytes.
    ///
    /// # Panics
    /// Panics if the file cannot be loaded or parsed.
    pub fn load_fragment(mut self, fragment: &'a str) -> Self {
        self.fragment = Some(ShaderModuleDescriptor {
            label: Some("Fragment shader"),
            source: Self::shader_auto_load(fragment)
                .unwrap_or_else(|err| panic!("Load fragment shader `{}`: {:?}", fragment, err)),
        });
        self
    }

    /// Remove the fragment shader. `create()` then builds the pipeline with no
    /// fragment state.
    pub fn no_fragment(mut self) -> Self {
        self.fragment = None;
        self
    }

    /// Set the bind group layouts; entry `i` provides the bindings for `set = i`.
    pub const fn with_bind_groups(mut self, bind_groups: &'a [&wgpu::BindGroupLayout]) -> Self {
        self.desc.bind_group_layouts = bind_groups;
        self
    }

    /// Cull front faces.
    /// Front is CCW.
    pub const fn cull_front(mut self) -> Self {
        self.desc.primitive.cull_mode = Some(wgpu::Face::Front);
        self
    }

    /// Cull back faces.
    /// Back is CW.
    pub const fn cull_back(mut self) -> Self {
        self.desc.primitive.cull_mode = Some(wgpu::Face::Back);
        self
    }

    /// Draws lines instead of filling in triangles.
    /// NOTE(review): wgpu gates non-fill polygon modes behind a device feature; confirm
    /// it is enabled on the device.
    pub const fn wireframe(mut self) -> Self {
        self.desc.primitive.polygon_mode = wgpu::PolygonMode::Line;
        self
    }

    /// Treat vertices as a list of points.
    pub const fn vertex_points(mut self) -> Self {
        self.desc.primitive.topology = wgpu::PrimitiveTopology::PointList;
        self
    }

    /// Treat vertices as lines: a connected strip if `strip`, otherwise independent pairs.
    pub const fn vertex_lines(mut self, strip: bool) -> Self {
        self.desc.primitive.topology = if strip {
            wgpu::PrimitiveTopology::LineStrip
        } else {
            wgpu::PrimitiveTopology::LineList
        };
        self
    }

    /// Treat vertices as triangles: a strip if `strip`, otherwise independent triples.
    pub const fn vertex_triangles(mut self, strip: bool) -> Self {
        self.desc.primitive.topology = if strip {
            wgpu::PrimitiveTopology::TriangleStrip
        } else {
            wgpu::PrimitiveTopology::TriangleList
        };
        self
    }

    /// Compile the shaders and create the render pipeline.
    ///
    /// The builder is only borrowed, so it can create several pipelines. Shader modules and
    /// the pipeline layout are created anew on every call.
    #[must_use]
    pub fn create(&self) -> RenderPipeline {
        // Create vertex module
        let vertex_module = self.gpu.device.create_shader_module(&self.vertex);

        // Create fragment module, if a fragment shader is set
        let fragment_module = self
            .fragment
            .as_ref()
            .map(|fragment| self.gpu.device.create_shader_module(fragment));

        // Map fragment state if Some() otherwise it is None
        let fragment = fragment_module
            .as_ref()
            .map(|fs_module| wgpu::FragmentState {
                module: fs_module,
                entry_point: self.fragment_entry,
                targets: self.fragment_targets,
            });

        // The pipeline layout
        let layout = self
            .gpu
            .device
            .create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: self.label_suffix("pipeline layout").as_deref(),
                bind_group_layouts: self.desc.bind_group_layouts,
                push_constant_ranges: self.desc.push_constant_ranges,
            });

        let pipeline_desc = wgpu::RenderPipelineDescriptor {
            layout: Some(&layout),
            label: self.label,
            vertex: wgpu::VertexState {
                module: &vertex_module,
                entry_point: self.vertex_entry,
                buffers: self.desc.vertex_layouts,
            },
            primitive: self.desc.primitive,
            depth_stencil: self.desc.depth_stencil.clone(),
            multisample: self.desc.multisample,
            fragment,
            // TODO: Implement multiview interface
            multiview: None,
        };

        // Create the pipeline
        let pipeline = self.gpu.device.create_render_pipeline(&pipeline_desc);
        RenderPipeline {
            depth_stencil: self.desc.depth_stencil.clone(),
            gpu: self.gpu.clone(),
            inner: pipeline,
        }
    }

    /// Helper function to append a suffix to the label, if Some
    fn label_suffix(&self, suffix: &str) -> Option<String> {
        self.label.map(|label| format!("{} {}", label, suffix))
    }
}