1use crate::bind_group_layout::BindGroupLayoutDesc;
2use smallvec::SmallVec;
4use std::borrow::Cow;
5use wgpu;
6
/// Name of the WGSL function used as the vertex-stage entry point of every pipeline
/// built from a [`RenderPipelineDesc`].
pub const DEFAULT_VS_SHADER_ENTRY_POINT: &str = "vs_main";
/// Name of the WGSL function used as the fragment-stage entry point of every pipeline
/// built from a [`RenderPipelineDesc`] (when the fragment stage is enabled).
pub const DEFAULT_FS_SHADER_ENTRY_POINT: &str = "fs_main";
9
10#[derive(Clone, Hash, PartialEq, Eq)]
15pub struct RenderPipelineDesc {
16    pub label: Option<String>,
17    pub shader_code: Option<String>, pub shader_code_vert: Option<String>,
20    pub shader_code_frag: Option<String>,
21    pub shader_label: Option<String>,
22    pub disable_fragment_shader: bool, pub bind_group_layouts_desc: SmallVec<[BindGroupLayoutDesc; 4]>,
25    pub vertex_buffers_layouts: SmallVec<[wgpu::VertexBufferLayout<'static>; 4]>,
27    pub render_targets: SmallVec<[Option<wgpu::ColorTargetState>; 4]>,
29
30    pub primitive: wgpu::PrimitiveState,
34    pub depth_stencil: Option<wgpu::DepthStencilState>,
37    pub multisample: wgpu::MultisampleState,
39}
40
41impl Default for RenderPipelineDesc {
42    fn default() -> Self {
43        Self {
44            label: None,
45            shader_code: None,
47            shader_code_vert: None,
48            shader_code_frag: None,
49            shader_label: Some(String::from("Shader")),
50            disable_fragment_shader: false,
51            bind_group_layouts_desc: SmallVec::new(),
53            vertex_buffers_layouts: SmallVec::new(),
55            render_targets: SmallVec::new(),
57
58            primitive: wgpu::PrimitiveState {
61                topology: wgpu::PrimitiveTopology::TriangleList,
62                strip_index_format: None,
63                front_face: wgpu::FrontFace::Ccw,
64                cull_mode: None, polygon_mode: wgpu::PolygonMode::Fill,
69                unclipped_depth: false,
71                conservative: false,
73            },
74            depth_stencil: None,
76            multisample: wgpu::MultisampleState {
78                count: 1,
79                mask: !0,
80                alpha_to_coverage_enabled: false,
81            },
82        }
83    }
84}
85impl RenderPipelineDesc {
86    pub fn into_render_pipeline(
89        self,
90        device: &wgpu::Device,
91        ) -> wgpu::RenderPipeline {
93        let shader_monolithic: Option<wgpu::ShaderModule>;
97        let shader_vert: Option<wgpu::ShaderModule>;
98        let shader_frag: Option<wgpu::ShaderModule>;
99        let (shader_vert, shader_frag) = if let Some(shader_code_monlithic) = self.shader_code {
100            shader_monolithic = Some(device.create_shader_module(wgpu::ShaderModuleDescriptor {
101                label: self.shader_label.as_deref(),
102                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(&shader_code_monlithic)),
103            }));
104            (shader_monolithic.as_ref().unwrap(), shader_monolithic.as_ref().unwrap())
105        } else {
106            shader_vert = Some(device.create_shader_module(wgpu::ShaderModuleDescriptor {
107                label: self.shader_label.as_deref(),
108                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(self.shader_code_vert.as_ref().unwrap())),
109            }));
110            shader_frag = Some(device.create_shader_module(wgpu::ShaderModuleDescriptor {
111                label: self.shader_label.as_deref(),
112                source: wgpu::ShaderSource::Wgsl(Cow::Borrowed(self.shader_code_frag.as_ref().unwrap())),
113            }));
114            (shader_vert.as_ref().unwrap(), shader_frag.as_ref().unwrap())
115        };
116
117        let mut bind_group_layouts: Vec<wgpu::BindGroupLayout> = Vec::new();
124        for bgl_desc in self.bind_group_layouts_desc {
126            let bgl = bgl_desc.into_bind_group_layout(device);
130            bind_group_layouts.push(bgl);
131        }
132
133        let bind_group_layouts: Vec<&wgpu::BindGroupLayout> = bind_group_layouts.iter().collect();
135
136        let render_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
138            label: None,
139            bind_group_layouts: bind_group_layouts.as_slice(),
140            push_constant_ranges: &[],
141        });
142
143        let fragment_state: Option<wgpu::FragmentState> = if self.disable_fragment_shader {
145            None
146        } else {
147            Some(wgpu::FragmentState {
148                module: shader_frag,
149                entry_point: DEFAULT_FS_SHADER_ENTRY_POINT,
150                targets: self.render_targets.as_slice(),
151                compilation_options: wgpu::PipelineCompilationOptions::default(),
152            })
153        };
154
155        let pipeline_desc = wgpu::RenderPipelineDescriptor {
156            label: self.label.as_deref(),
157            layout: Some(&render_pipeline_layout),
159            vertex: wgpu::VertexState {
160                module: shader_vert,
161                entry_point: DEFAULT_VS_SHADER_ENTRY_POINT,
162                buffers: self.vertex_buffers_layouts.as_slice(),
163                compilation_options: wgpu::PipelineCompilationOptions::default(),
164            },
165            fragment: fragment_state,
166            primitive: self.primitive,
167            depth_stencil: self.depth_stencil,
168            multisample: self.multisample,
169            multiview: None,
170            cache: None,
171        };
172
173        device.create_render_pipeline(&pipeline_desc)
175    }
176}
177
178pub struct RenderPipelineDescBuilder {
182    pipeline_desc: Option<RenderPipelineDesc>,
183}
184impl Default for RenderPipelineDescBuilder {
185    fn default() -> Self {
186        Self::new()
187    }
188}
189impl RenderPipelineDescBuilder {
190    pub fn new() -> Self {
191        Self {
192            pipeline_desc: Some(RenderPipelineDesc::default()),
193        }
194    }
195
196    #[must_use]
199    pub fn label(mut self, label: &str) -> Self {
200        self.pipeline_desc.as_mut().unwrap().label = Some(String::from(label));
201        self
202    }
203
204    #[must_use]
207    pub fn shader_code(mut self, code: &str) -> Self {
208        self.pipeline_desc.as_mut().unwrap().shader_code = Some(String::from(code));
209        self
210    }
211
212    #[must_use]
215    pub fn shader_code_vert(mut self, code: &str) -> Self {
216        self.pipeline_desc.as_mut().unwrap().shader_code_vert = Some(String::from(code));
217        self
218    }
219
220    #[must_use]
223    pub fn shader_code_frag(mut self, code: &str) -> Self {
224        self.pipeline_desc.as_mut().unwrap().shader_code_frag = Some(String::from(code));
225        self
226    }
227
228    #[must_use]
231    pub fn shader_label(mut self, label: &str) -> Self {
232        self.pipeline_desc.as_mut().unwrap().shader_label = Some(String::from(label));
233        self
234    }
235
236    #[must_use]
239    pub fn disable_fragment_shader(mut self) -> Self {
240        self.pipeline_desc.as_mut().unwrap().disable_fragment_shader = true;
241        self
242    }
243
244    #[must_use]
247    pub fn add_bind_group_layout_desc(mut self, layout_desc: BindGroupLayoutDesc) -> Self {
248        self.pipeline_desc.as_mut().unwrap().bind_group_layouts_desc.push(layout_desc);
249        self
250    }
251
252    #[must_use]
255    pub fn add_vertex_buffer_layout(mut self, vertex_layout: wgpu::VertexBufferLayout<'static>) -> Self {
256        self.pipeline_desc.as_mut().unwrap().vertex_buffers_layouts.push(vertex_layout);
257        self
258    }
259
260    #[must_use]
263    pub fn add_render_target(mut self, render_target: wgpu::ColorTargetState) -> Self {
264        self.pipeline_desc.as_mut().unwrap().render_targets.push(Some(render_target));
265        self
266    }
267
268    #[must_use]
271    pub fn primitive(mut self, primitive: wgpu::PrimitiveState) -> Self {
272        self.pipeline_desc.as_mut().unwrap().primitive = primitive;
273        self
274    }
275
276    #[must_use]
279    pub fn depth_state(mut self, depth_state: Option<wgpu::DepthStencilState>) -> Self {
280        self.pipeline_desc.as_mut().unwrap().depth_stencil = depth_state;
281        self
282    }
283
284    #[must_use]
287    pub fn multisample(mut self, multisample: wgpu::MultisampleState) -> Self {
288        self.pipeline_desc.as_mut().unwrap().multisample = multisample;
289        self
290    }
291
292    pub fn build_desc(&mut self) -> RenderPipelineDesc {
295        self.pipeline_desc.take().unwrap() }
298
299    pub fn build_pipeline(&mut self, device: &wgpu::Device) -> wgpu::RenderPipeline {
302        let desc = self.pipeline_desc.take().unwrap();
303        desc.into_render_pipeline(device)
304    }
305}