// rootvg_quad/pipeline/gradient.rs
//
// The following code is taken and modified from
// https://github.com/iced-rs/iced/blob/master/wgpu/src/quad/gradient.rs
//
// Iced MIT license: https://github.com/iced-rs/iced/blob/master/LICENSE

6use rootvg_core::{
7    buffer::Buffer,
8    math::{PhysicalSizeI32, ScaleFactor},
9    pipeline::DefaultConstantUniforms,
10};
11use wgpu::PipelineCompilationOptions;
12
13use super::INITIAL_INSTANCES;
14
15use crate::GradientQuadPrimitive;
16
17pub struct GradientQuadBatchBuffer {
18    buffer: Buffer<GradientQuadPrimitive>,
19    num_primitives: usize,
20}
21
22pub struct GradientQuadPipeline {
23    pipeline: wgpu::RenderPipeline,
24
25    constants_buffer: wgpu::Buffer,
26    constants_bind_group: wgpu::BindGroup,
27
28    screen_size: PhysicalSizeI32,
29    scale_factor: ScaleFactor,
30}
31
32impl GradientQuadPipeline {
33    pub fn new(
34        device: &wgpu::Device,
35        format: wgpu::TextureFormat,
36        multisample: wgpu::MultisampleState,
37    ) -> Self {
38        let (constants_layout, constants_buffer, constants_bind_group) =
39            DefaultConstantUniforms::layout_buffer_and_bind_group(device);
40
41        let layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
42            label: Some("rootvg-quad gradient pipeline layout"),
43            push_constant_ranges: &[],
44            bind_group_layouts: &[&constants_layout],
45        });
46
47        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
48            label: Some("rootvg-quad gradient shader"),
49            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(concat!(
50                include_str!("../shader/quad.wgsl"),
51                "\n",
52                include_str!("../shader/gradient.wgsl"),
53                "\n",
54                include_str!("../shader/oklab.wgsl")
55            ))),
56        });
57
58        let pipeline = device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
59            label: Some("rootvg-quad gradient pipeline"),
60            layout: Some(&layout),
61            vertex: wgpu::VertexState {
62                module: &shader,
63                entry_point: "gradient_vs_main",
64                buffers: &[wgpu::VertexBufferLayout {
65                    array_stride: std::mem::size_of::<GradientQuadPrimitive>() as u64,
66                    step_mode: wgpu::VertexStepMode::Instance,
67                    attributes: &wgpu::vertex_attr_array!(
68                        // Colors 1-2
69                        0 => Uint32x4,
70                        // Colors 3-4
71                        1 => Uint32x4,
72                        // Colors 5-6
73                        2 => Uint32x4,
74                        // Colors 7-8
75                        3 => Uint32x4,
76                        // Offsets 1-8
77                        4 => Uint32x4,
78                        // Direction
79                        5 => Float32x4,
80                        // Position
81                        6 => Float32x2,
82                        // Size
83                        7 => Float32x2,
84                        // Border color
85                        8 => Float32x4,
86                        // Border radius
87                        9 => Float32x4,
88                        // Border width
89                        10 => Float32
90                    ),
91                }],
92                compilation_options: PipelineCompilationOptions::default(),
93            },
94            fragment: Some(wgpu::FragmentState {
95                module: &shader,
96                entry_point: "gradient_fs_main",
97                targets: &super::color_target_state(format),
98                compilation_options: PipelineCompilationOptions::default(),
99            }),
100            primitive: wgpu::PrimitiveState {
101                topology: wgpu::PrimitiveTopology::TriangleList,
102                front_face: wgpu::FrontFace::Cw,
103                ..Default::default()
104            },
105            depth_stencil: None,
106            multisample,
107            multiview: None,
108        });
109
110        Self {
111            constants_buffer,
112            constants_bind_group,
113            pipeline,
114            screen_size: PhysicalSizeI32::default(),
115            scale_factor: ScaleFactor::default(),
116        }
117    }
118
119    pub fn create_batch(&mut self, device: &wgpu::Device) -> GradientQuadBatchBuffer {
120        GradientQuadBatchBuffer {
121            buffer: Buffer::new(
122                device,
123                "rootvg-quad gradient buffer",
124                INITIAL_INSTANCES,
125                wgpu::BufferUsages::VERTEX | wgpu::BufferUsages::COPY_DST,
126            ),
127            num_primitives: 0,
128        }
129    }
130
131    pub fn start_preparations(
132        &mut self,
133        _device: &wgpu::Device,
134        queue: &wgpu::Queue,
135        screen_size: PhysicalSizeI32,
136        scale_factor: ScaleFactor,
137    ) {
138        if self.screen_size == screen_size && self.scale_factor == scale_factor {
139            return;
140        }
141
142        self.screen_size = screen_size;
143        self.scale_factor = scale_factor;
144
145        DefaultConstantUniforms::prepare_buffer(
146            &self.constants_buffer,
147            screen_size,
148            scale_factor,
149            queue,
150        );
151    }
152
153    pub fn prepare_batch(
154        &mut self,
155        batch: &mut GradientQuadBatchBuffer,
156        primitives: &[GradientQuadPrimitive],
157        device: &wgpu::Device,
158        queue: &wgpu::Queue,
159    ) {
160        let _ = batch
161            .buffer
162            .expand_to_fit_new_size(device, primitives.len());
163        let _ = batch.buffer.write(queue, 0, primitives);
164
165        batch.num_primitives = primitives.len();
166    }
167
168    pub fn render_batch<'pass>(
169        &'pass self,
170        batch: &'pass GradientQuadBatchBuffer,
171        render_pass: &mut wgpu::RenderPass<'pass>,
172    ) {
173        if batch.num_primitives == 0 {
174            return;
175        }
176
177        render_pass.set_pipeline(&self.pipeline);
178        render_pass.set_bind_group(0, &self.constants_bind_group, &[]);
179
180        render_pass.set_vertex_buffer(0, batch.buffer.slice(0..batch.num_primitives));
181
182        render_pass.draw(0..6, 0..batch.num_primitives as u32);
183    }
184}