Skip to main content

dear_imgui_wgpu/
shaders.rs

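//! Shader management for the WGPU renderer
//!
//! This module holds the WGSL shader source used to draw Dear ImGui geometry, a
//! [`ShaderManager`] that compiles and stores the shader module, and helpers that
//! build the vertex buffer layout and the vertex/fragment state for the render pipeline.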
5
6use crate::{RendererError, RendererResult};
7use dear_imgui_rs::render::DrawVert;
8use std::mem::size_of;
9use wgpu::*;
10
/// Name of the vertex entry point in [`SHADER_SOURCE`].
///
/// Must match the `fn vs_main` declared in the WGSL source.
pub const VS_ENTRY_POINT: &str = "vs_main";
/// Name of the fragment entry point in [`SHADER_SOURCE`].
///
/// Must match the `fn fs_main` declared in the WGSL source.
pub const FS_ENTRY_POINT: &str = "fs_main";

/// WGSL source for the Dear ImGui vertex and fragment shaders.
///
/// Modeled on the shader in Dear ImGui's `imgui_impl_wgpu.cpp` backend.
///
/// Interface expected by the pipeline:
/// - Vertex inputs: `@location(0)` position (`vec2<f32>`), `@location(1)` uv
///   (`vec2<f32>`), `@location(2)` color (`vec4<f32>`); see
///   [`create_vertex_buffer_layout`] for the matching vertex attributes.
/// - `@group(0) @binding(0)`: uniform block holding an `mvp` matrix
///   (`mat4x4<f32>`) followed by a `gamma` exponent (`f32`).
/// - `@group(0) @binding(1)`: sampler.
/// - `@group(1) @binding(0)`: `texture_2d<f32>` sampled for every fragment.
///
/// The vertex stage transforms the 2D position by `mvp` (with z = 0, w = 1) and
/// forwards color and uv. The fragment stage multiplies the vertex color by the
/// sampled texel and raises the RGB channels to the power `gamma`; alpha passes
/// through unchanged, so a gamma of 1.0 leaves the color unmodified.
///
/// NOTE(review): under WGSL layout rules the `Uniforms` struct is 80 bytes (68
/// bytes of data rounded up to its 16-byte alignment); confirm the uniform buffer
/// and its binding size are at least that large.
pub const SHADER_SOURCE: &str = r#"
// Dear ImGui WGSL Shader
// Vertex and fragment shaders for rendering Dear ImGui draw data

struct VertexInput {
    @location(0) position: vec2<f32>,
    @location(1) uv: vec2<f32>,
    @location(2) color: vec4<f32>,
}

struct VertexOutput {
    @builtin(position) position: vec4<f32>,
    @location(0) color: vec4<f32>,
    @location(1) uv: vec2<f32>,
}

struct Uniforms {
    mvp: mat4x4<f32>,
    gamma: f32,
}

@group(0) @binding(0)
var<uniform> uniforms: Uniforms;

@group(0) @binding(1)
var u_sampler: sampler;

@group(1) @binding(0)
var u_texture: texture_2d<f32>;

@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
    var out: VertexOutput;
    out.position = uniforms.mvp * vec4<f32>(in.position, 0.0, 1.0);
    out.color = in.color;
    out.uv = in.uv;
    return out;
}

@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
    let color = in.color * textureSample(u_texture, u_sampler, in.uv);
    // Apply gamma curve if uniforms.gamma != 1.0. With gamma=1.0 this is a no-op.
    let corrected = pow(color.rgb, vec3<f32>(uniforms.gamma));
    return vec4<f32>(corrected, color.a);
}
"#;
67
68/// Shader manager
69pub struct ShaderManager {
70    shader_module: Option<ShaderModule>,
71}
72
73impl ShaderManager {
74    /// Create a new shader manager
75    pub fn new() -> Self {
76        Self {
77            shader_module: None,
78        }
79    }
80
81    /// Initialize shaders
82    pub fn initialize(&mut self, device: &Device) -> RendererResult<()> {
83        let shader_module = device.create_shader_module(ShaderModuleDescriptor {
84            label: Some("Dear ImGui Shader"),
85            source: ShaderSource::Wgsl(SHADER_SOURCE.into()),
86        });
87
88        self.shader_module = Some(shader_module);
89        Ok(())
90    }
91
92    /// Get the shader module
93    pub fn shader_module(&self) -> Option<&ShaderModule> {
94        self.shader_module.as_ref()
95    }
96
97    /// Get the shader module reference (with error handling)
98    pub fn get_shader_module(&self) -> RendererResult<&ShaderModule> {
99        self.shader_module.as_ref().ok_or_else(|| {
100            RendererError::ShaderCompilationFailed("Shader module not initialized".to_string())
101        })
102    }
103
104    /// Check if shaders are initialized
105    pub fn is_initialized(&self) -> bool {
106        self.shader_module.is_some()
107    }
108}
109
110impl Default for ShaderManager {
111    fn default() -> Self {
112        Self::new()
113    }
114}
115
116/// Create vertex buffer layout for Dear ImGui vertices
117pub fn create_vertex_buffer_layout() -> VertexBufferLayout<'static> {
118    const VERTEX_ATTRIBUTES: &[VertexAttribute] = &vertex_attr_array![
119        0 => Float32x2,  // position
120        1 => Float32x2,  // uv
121        2 => Unorm8x4    // color
122    ];
123
124    VertexBufferLayout {
125        array_stride: size_of::<DrawVert>() as BufferAddress,
126        step_mode: VertexStepMode::Vertex,
127        attributes: VERTEX_ATTRIBUTES,
128    }
129}
130
131/// Create vertex state for render pipeline
132pub fn create_vertex_state<'a>(
133    shader_module: &'a ShaderModule,
134    buffers: &'a [VertexBufferLayout],
135) -> VertexState<'a> {
136    VertexState {
137        module: shader_module,
138        entry_point: Some(VS_ENTRY_POINT),
139        compilation_options: Default::default(),
140        buffers,
141    }
142}
143
144/// Create fragment state for render pipeline
145pub fn create_fragment_state<'a>(
146    shader_module: &'a ShaderModule,
147    targets: &'a [Option<ColorTargetState>],
148    _use_gamma_correction: bool,
149) -> FragmentState<'a> {
150    let entry_point = FS_ENTRY_POINT;
151
152    FragmentState {
153        module: shader_module,
154        entry_point: Some(entry_point),
155        compilation_options: Default::default(),
156        targets,
157    }
158}