dear_imgui_wgpu/
shaders.rs1use crate::{RendererError, RendererResult};
7use dear_imgui::render::DrawVert;
8use std::mem::size_of;
9use wgpu::*;
10
/// Name of the vertex-stage entry function declared in [`SHADER_SOURCE`].
pub const VS_ENTRY_POINT: &str = "vs_main";

/// Name of the fragment-stage entry function declared in [`SHADER_SOURCE`].
pub const FS_ENTRY_POINT: &str = "fs_main";
15
/// WGSL source containing both the vertex (`vs_main`) and fragment (`fs_main`) stages.
///
/// Resources the shader declares:
///
/// * `@group(0) @binding(0)`: uniform block holding `mvp` (a 4x4 matrix) and `gamma`
/// * `@group(0) @binding(1)`: the sampler
/// * `@group(1) @binding(0)`: the 2D texture that is sampled
///
/// The vertex stage multiplies the 2D position by `mvp` and forwards color and UV.
/// The fragment stage multiplies the vertex color by the sampled texel and raises the
/// RGB channels to the power `gamma`, leaving alpha untouched.
pub const SHADER_SOURCE: &str = r#"
// Dear ImGui WGSL Shader
// Vertex and fragment shaders for rendering Dear ImGui draw data

struct VertexInput {
    @location(0) position: vec2<f32>,
    @location(1) uv: vec2<f32>,
    @location(2) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
    @location(1) uv: vec2<f32>,
}

struct Uniforms {
    mvp: mat4x4<f32>,
    gamma: f32,
}

@group(0) @binding(0)
var<uniform> uniforms: Uniforms;

@group(0) @binding(1)
var u_sampler: sampler;

@group(1) @binding(0)
var u_texture: texture_2d<f32>;

@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
    var out: VertexOutput;
    out.position = uniforms.mvp * vec4<f32>(in.position, 0.0, 1.0);
    out.color = in.color;
    out.uv = in.uv;
    return out;
}

@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let color = in.color * textureSample(u_texture, u_sampler, in.uv);
    // Apply gamma curve if uniforms.gamma != 1.0. With gamma=1.0 this is a no-op.
    let corrected = pow(color.rgb, vec3<f32>(uniforms.gamma));
    return vec4<f32>(corrected, color.a);
}
"#;
67
68pub struct ShaderManager {
70 shader_module: Option<ShaderModule>,
71}
72
73impl ShaderManager {
74 pub fn new() -> Self {
76 Self {
77 shader_module: None,
78 }
79 }
80
81 pub fn initialize(&mut self, device: &Device) -> RendererResult<()> {
83 let shader_module = device.create_shader_module(ShaderModuleDescriptor {
84 label: Some("Dear ImGui Shader"),
85 source: ShaderSource::Wgsl(SHADER_SOURCE.into()),
86 });
87
88 self.shader_module = Some(shader_module);
89 Ok(())
90 }
91
92 pub fn shader_module(&self) -> Option<&ShaderModule> {
94 self.shader_module.as_ref()
95 }
96
97 pub fn get_shader_module(&self) -> RendererResult<&ShaderModule> {
99 self.shader_module.as_ref().ok_or_else(|| {
100 RendererError::ShaderCompilationFailed("Shader module not initialized".to_string())
101 })
102 }
103
104 pub fn is_initialized(&self) -> bool {
106 self.shader_module.is_some()
107 }
108}
109
110impl Default for ShaderManager {
111 fn default() -> Self {
112 Self::new()
113 }
114}
115
116pub fn create_vertex_buffer_layout() -> VertexBufferLayout<'static> {
118 const VERTEX_ATTRIBUTES: &[VertexAttribute] = &vertex_attr_array![
119 0 => Float32x2, 1 => Float32x2, 2 => Unorm8x4 ];
123
124 VertexBufferLayout {
125 array_stride: size_of::<DrawVert>() as BufferAddress,
126 step_mode: VertexStepMode::Vertex,
127 attributes: VERTEX_ATTRIBUTES,
128 }
129}
130
131pub fn create_vertex_state<'a>(
133 shader_module: &'a ShaderModule,
134 buffers: &'a [VertexBufferLayout],
135) -> VertexState<'a> {
136 VertexState {
137 module: shader_module,
138 entry_point: Some(VS_ENTRY_POINT),
139 compilation_options: Default::default(),
140 buffers,
141 }
142}
143
144pub fn create_fragment_state<'a>(
146 shader_module: &'a ShaderModule,
147 targets: &'a [Option<ColorTargetState>],
148 _use_gamma_correction: bool,
149) -> FragmentState<'a> {
150 let entry_point = FS_ENTRY_POINT;
151
152 FragmentState {
153 module: shader_module,
154 entry_point: Some(entry_point),
155 compilation_options: Default::default(),
156 targets,
157 }
158}
159
160pub fn create_bind_group_layouts(device: &Device) -> (BindGroupLayout, BindGroupLayout) {
162 let common_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
164 label: Some("Dear ImGui Common Bind Group Layout"),
165 entries: &[
166 BindGroupLayoutEntry {
167 binding: 0,
168 visibility: ShaderStages::VERTEX | ShaderStages::FRAGMENT,
169 ty: BindingType::Buffer {
170 ty: BufferBindingType::Uniform,
171 has_dynamic_offset: false,
172 min_binding_size: None,
173 },
174 count: None,
175 },
176 BindGroupLayoutEntry {
177 binding: 1,
178 visibility: ShaderStages::FRAGMENT,
179 ty: BindingType::Sampler(SamplerBindingType::Filtering),
180 count: None,
181 },
182 ],
183 });
184
185 let image_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
187 label: Some("Dear ImGui Image Bind Group Layout"),
188 entries: &[BindGroupLayoutEntry {
189 binding: 0,
190 visibility: ShaderStages::FRAGMENT,
191 ty: BindingType::Texture {
192 multisampled: false,
193 sample_type: TextureSampleType::Float { filterable: true },
194 view_dimension: TextureViewDimension::D2,
195 },
196 count: None,
197 }],
198 });
199
200 (common_layout, image_layout)
201}