// rootvg_image/pipeline.rs

1use rustc_hash::FxHashMap;
2use smallvec::SmallVec;
3use std::{cell::RefCell, ops::Range, rc::Rc};
4use wgpu::PipelineCompilationOptions;
5
6use rootvg_core::{
7    buffer::Buffer,
8    math::{PhysicalSizeI32, ScaleFactor},
9    pipeline::DefaultConstantUniforms,
10};
11
12use crate::{
13    primitive::{ImagePrimitive, ImageVertex},
14    texture::TextureInner,
15    RcTexture,
16};
17
18const INITIAL_INSTANCES: usize = 16;
19const INITIAL_SUB_BATCHES: usize = 16;
20
21struct Batch {
22    range_in_buffer: Range<u32>,
23    texture: RcTexture,
24}
25
26pub struct ImageBatchBuffer {
27    buffer: Buffer<ImageVertex>,
28    sub_batches: Vec<Batch>,
29    num_instances: usize,
30
31    prev_primitives: Vec<ImagePrimitive>,
32}
33
34impl ImageBatchBuffer {
35    fn new(device: &wgpu::Device) -> Self {
36        Self {
37            buffer: Buffer::new(
38                device,
39                "rootvg-image instance buffer",
40                INITIAL_INSTANCES,
41                wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
42            ),
43            sub_batches: Vec::with_capacity(INITIAL_SUB_BATCHES),
44            num_instances: 0,
45            prev_primitives: Vec::new(),
46        }
47    }
48
49    fn prepare(
50        &mut self,
51        primitives: &[ImagePrimitive],
52        device: &wgpu::Device,
53        queue: &wgpu::Queue,
54        texture_bind_group_layout: &wgpu::BindGroupLayout,
55    ) {
56        // Don't prepare if primitives have not changed since the last
57        // prepare.
58        if primitives == &self.prev_primitives {
59            return;
60        }
61        self.prev_primitives = primitives.into();
62
63        self.sub_batches.clear();
64        self.num_instances = primitives.len();
65
66        self.buffer.expand_to_fit_new_size(device, primitives.len());
67
68        struct TempSubBatchEntry {
69            vertices: SmallVec<[ImageVertex; INITIAL_INSTANCES]>,
70            texture: RcTexture,
71        }
72
73        // TODO: reuse the allocation of this hash map?
74        let mut sub_batches_map: FxHashMap<*const RefCell<TextureInner>, TempSubBatchEntry> =
75            FxHashMap::default();
76        sub_batches_map.reserve(INITIAL_SUB_BATCHES);
77
78        for image in primitives.iter() {
79            image
80                .texture
81                .upload_if_needed(device, queue, texture_bind_group_layout);
82
83            let sub_batch = sub_batches_map
84                .entry(Rc::as_ptr(&image.texture.inner))
85                .or_insert_with(|| TempSubBatchEntry {
86                    vertices: SmallVec::new(),
87                    texture: image.texture.clone(),
88                });
89
90            sub_batch.vertices.push(image.vertex);
91        }
92
93        let mut range_start = 0;
94        for sub_batch in sub_batches_map.values() {
95            self.buffer.write(queue, range_start, &sub_batch.vertices);
96
97            self.sub_batches.push(Batch {
98                range_in_buffer: range_start as u32
99                    ..(range_start + sub_batch.vertices.len()) as u32,
100                texture: sub_batch.texture.clone(),
101            });
102
103            range_start += sub_batch.vertices.len();
104        }
105    }
106}
107
108pub struct ImagePipeline {
109    pipeline: wgpu::RenderPipeline,
110
111    constants_buffer: wgpu::Buffer,
112    constants_bind_group: wgpu::BindGroup,
113    texture_layout: wgpu::BindGroupLayout,
114
115    screen_size: PhysicalSizeI32,
116    scale_factor: ScaleFactor,
117}
118
119impl ImagePipeline {
120    pub fn new(
121        device: &wgpu::Device,
122        format: wgpu::TextureFormat,
123        multisample: wgpu::MultisampleState,
124    ) -> Self {
125        let constants_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
126            label: Some("rootvg-image constants layout"),
127            entries: &[
128                DefaultConstantUniforms::entry(0),
129                wgpu::BindGroupLayoutEntry {
130                    binding: 1,
131                    visibility: wgpu::ShaderStages::FRAGMENT,
132                    // This should match the filterable field of the
133                    // Texture entry.
134                    ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
135                    count: None,
136                },
137            ],
138        });
139
140        let constants_buffer = device.create_buffer(&wgpu::BufferDescriptor {
141            label: Some("rootvg-image constants buffer"),
142            size: std::mem::size_of::<Self>() as wgpu::BufferAddress,
143            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
144            mapped_at_creation: false,
145        });
146
147        let sampler = device.create_sampler(&wgpu::SamplerDescriptor {
148            address_mode_u: wgpu::AddressMode::ClampToEdge,
149            address_mode_v: wgpu::AddressMode::ClampToEdge,
150            address_mode_w: wgpu::AddressMode::ClampToEdge,
151            mag_filter: wgpu::FilterMode::Linear,
152            min_filter: wgpu::FilterMode::Linear,
153            mipmap_filter: wgpu::FilterMode::Nearest,
154            ..Default::default()
155        });
156
157        let constants_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
158            label: Some("rootvg-image constants bind group"),
159            layout: &constants_layout,
160            entries: &[
161                wgpu::BindGroupEntry {
162                    binding: 0,
163                    resource: constants_buffer.as_entire_binding(),
164                },
165                wgpu::BindGroupEntry {
166                    binding: 1,
167                    resource: wgpu::BindingResource::Sampler(&sampler),
168                },
169            ],
170        });
171
172        let texture_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
173            label: Some("rootvg-image texture layout"),
174            entries: &[wgpu::BindGroupLayoutEntry {
175                binding: 0,
176                visibility: wgpu::ShaderStages::FRAGMENT,
177                ty: wgpu::BindingType::Texture {
178                    multisampled: false,
179                    view_dimension: wgpu::TextureViewDimension::D2,
180                    sample_type: wgpu::TextureSampleType::Float { filterable: true },
181                },
182                count: None,
183            }],
184        });
185
186        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
187            label: Some("rootvg-image pipeline layout"),
188            push_constant_ranges: &[],
189            bind_group_layouts: &[&constants_layout, &texture_layout],
190        });
191
192        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
193            label: Some("rootvg-image shader"),
194            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(concat!(include_str!(
195                "shader/image.wgsl"
196            ),))),
197        });
198
199        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
200            label: Some("rootvg-image pipeline"),
201            layout: Some(&layout),
202            vertex: wgpu::VertexState {
203                module: &shader,
204                entry_point: "vs_main",
205                buffers: &[wgpu::VertexBufferLayout {
206                    array_stride: std::mem::size_of::<ImageVertex>() as u64,
207                    step_mode: wgpu::VertexStepMode::Instance,
208                    attributes: &wgpu::vertex_attr_array!(
209                        // Position
210                        0 => Float32x2,
211                        // Size
212                        1 => Float32x2,
213                        // Normalized UV position
214                        2 => Float32x2,
215                        // Normalized UV size
216                        3 => Float32x2,
217                        // Transform Matrix 3x2
218                        4 => Float32x2,
219                        5 => Float32x2,
220                        6 => Float32x2,
221                        // Has Transformation
222                        7 => Uint32,
223                    ),
224                }],
225                compilation_options: PipelineCompilationOptions::default(),
226            },
227            fragment: Some(wgpu::FragmentState {
228                module: &shader,
229                entry_point: "fs_main",
230                targets: &[Some(wgpu::ColorTargetState {
231                    format,
232                    blend: Some(wgpu::BlendState {
233                        color: wgpu::BlendComponent {
234                            src_factor: wgpu::BlendFactor::SrcAlpha,
235                            dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
236                            operation: wgpu::BlendOperation::Add,
237                        },
238                        alpha: wgpu::BlendComponent {
239                            src_factor: wgpu::BlendFactor::One,
240                            dst_factor: wgpu::BlendFactor::OneMinusSrcAlpha,
241                            operation: wgpu::BlendOperation::Add,
242                        },
243                    }),
244                    write_mask: wgpu::ColorWrites::ALL,
245                })],
246                compilation_options: PipelineCompilationOptions::default(),
247            }),
248            primitive: wgpu::PrimitiveState {
249                topology: wgpu::PrimitiveTopology::TriangleList,
250                front_face: wgpu::FrontFace::Cw,
251                ..Default::default()
252            },
253            depth_stencil: None,
254            multisample,
255            multiview: None,
256        });
257
258        Self {
259            pipeline,
260            constants_buffer,
261            constants_bind_group,
262            texture_layout,
263            screen_size: PhysicalSizeI32::default(),
264            scale_factor: ScaleFactor::default(),
265        }
266    }
267
268    pub fn create_batch(&mut self, device: &wgpu::Device) -> ImageBatchBuffer {
269        ImageBatchBuffer::new(device)
270    }
271
272    pub fn start_preparations(
273        &mut self,
274        _device: &wgpu::Device,
275        queue: &wgpu::Queue,
276        screen_size: PhysicalSizeI32,
277        scale_factor: ScaleFactor,
278    ) {
279        if self.screen_size == screen_size && self.scale_factor == scale_factor {
280            return;
281        }
282
283        self.screen_size = screen_size;
284        self.scale_factor = scale_factor;
285
286        DefaultConstantUniforms::prepare_buffer(
287            &self.constants_buffer,
288            screen_size,
289            scale_factor,
290            queue,
291        );
292    }
293
294    pub fn prepare_batch(
295        &mut self,
296        batch: &mut ImageBatchBuffer,
297        primitives: &[ImagePrimitive],
298        device: &wgpu::Device,
299        queue: &wgpu::Queue,
300    ) {
301        batch.prepare(primitives, device, queue, &self.texture_layout);
302    }
303
304    pub fn render_batch<'pass>(
305        &'pass self,
306        batch: &'pass ImageBatchBuffer,
307        render_pass: &mut wgpu::RenderPass<'pass>,
308    ) {
309        if batch.num_instances == 0 {
310            return;
311        }
312
313        render_pass.set_pipeline(&self.pipeline);
314        render_pass.set_bind_group(0, &self.constants_bind_group, &[]);
315
316        render_pass.set_vertex_buffer(0, batch.buffer.slice(0..batch.num_instances));
317
318        for sub_batch in batch.sub_batches.iter() {
319            // # SAFETY:
320            //
321            // Because wgpu requires the bind group to be borrowed for `'pass`, we
322            // are not able to use the safe option that returns a `std::cell::Ref`.
323            //
324            // By design, data is only mutated during the prepare stage, not during
325            // the render pass stage. So there is no chance for the
326            // `RefCell<TextureInner>` to be borrowed mutably during the render
327            // pass.
328            let texture_bind_group = unsafe {
329                &RefCell::try_borrow_unguarded(&sub_batch.texture.inner)
330                    .unwrap()
331                    .bind_group
332                    .as_ref()
333                    .unwrap()
334            };
335
336            render_pass.set_bind_group(1, texture_bind_group, &[]);
337
338            render_pass.draw(0..6, sub_batch.range_in_buffer.clone());
339        }
340    }
341}