use crate::{RendererError, RendererResult};
use dear_imgui_rs::render::DrawVert;
use std::mem::size_of;
use wgpu::*;
/// Name of the vertex-stage entry point (`fn vs_main`) defined in [`SHADER_SOURCE`].
pub const VS_ENTRY_POINT: &str = "vs_main";
/// Name of the fragment-stage entry point (`fn fs_main`) defined in [`SHADER_SOURCE`].
pub const FS_ENTRY_POINT: &str = "fs_main";
/// WGSL source for the Dear ImGui vertex and fragment shaders.
///
/// Entry points: `vs_main` ([`VS_ENTRY_POINT`]) and `fs_main` ([`FS_ENTRY_POINT`]).
///
/// Resource bindings declared by the shader:
/// - `@group(0) @binding(0)`: `Uniforms` uniform buffer (`mvp: mat4x4<f32>`, `gamma: f32`).
/// - `@group(0) @binding(1)`: `sampler`.
/// - `@group(1) @binding(0)`: `texture_2d<f32>`.
///
/// Vertex inputs are `position` (location 0, `vec2<f32>`), `uv` (location 1, `vec2<f32>`)
/// and `color` (location 2, `vec4<f32>`). The vertex stage transforms the 2D position by
/// `uniforms.mvp` with z = 0 and w = 1.
///
/// The fragment stage multiplies the vertex color by the sampled texel. It then raises RGB
/// to the power `uniforms.gamma` and passes alpha through unchanged. The embedded WGSL
/// comment says this happens "if uniforms.gamma != 1.0", but `pow` is applied
/// unconditionally; it is only a no-op when `gamma == 1.0`.
///
/// NOTE(review): WGSL `pow` is undefined for negative bases. That is fine for unorm vertex
/// colors and ordinary color textures, but would matter for signed/float textures.
///
/// NOTE(review): under WGSL uniform layout rules, `Uniforms` has 16-byte alignment
/// (from `mat4x4<f32>`). Its size therefore rounds up from 68 to 80 bytes. Confirm the
/// buffer bound at group 0 / binding 0 is at least that large.
pub const SHADER_SOURCE: &str = r#"
// Dear ImGui WGSL Shader
// Vertex and fragment shaders for rendering Dear ImGui draw data
struct VertexInput {
@location(0) position: vec2<f32>,
@location(1) uv: vec2<f32>,
@location(2) color: vec4<f32>,
}
struct VertexOutput {
@builtin(position) position: vec4<f32>,
@location(0) color: vec4<f32>,
@location(1) uv: vec2<f32>,
}
struct Uniforms {
mvp: mat4x4<f32>,
gamma: f32,
}
@group(0) @binding(0)
var<uniform> uniforms: Uniforms;
@group(0) @binding(1)
var u_sampler: sampler;
@group(1) @binding(0)
var u_texture: texture_2d<f32>;
@vertex
fn vs_main(in: VertexInput) -> VertexOutput {
var out: VertexOutput;
out.position = uniforms.mvp * vec4<f32>(in.position, 0.0, 1.0);
out.color = in.color;
out.uv = in.uv;
return out;
}
@fragment
fn fs_main(in: VertexOutput) -> @location(0) vec4<f32> {
let color = in.color * textureSample(u_texture, u_sampler, in.uv);
// Apply gamma curve if uniforms.gamma != 1.0. With gamma=1.0 this is a no-op.
let corrected = pow(color.rgb, vec3<f32>(uniforms.gamma));
return vec4<f32>(corrected, color.a);
}
"#;
/// Holds the shader module built from [`SHADER_SOURCE`] for the Dear ImGui renderer.
///
/// The module does not exist until [`ShaderManager::initialize`] has been called. Before
/// that, [`ShaderManager::shader_module`] yields `None` and
/// [`ShaderManager::get_shader_module`] yields an error.
#[derive(Default)]
pub struct ShaderManager {
    /// `None` until `initialize` stores a module.
    shader_module: Option<ShaderModule>,
}

impl ShaderManager {
    /// Creates a manager that holds no shader module yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates the shader module from [`SHADER_SOURCE`] on `device` and stores it.
    ///
    /// A module stored by an earlier call is replaced (and dropped).
    ///
    /// # Errors
    ///
    /// As written this always returns `Ok(())`. `Device::create_shader_module` hands back
    /// the module directly rather than a `Result`, so WGSL problems are not observed here.
    pub fn initialize(&mut self, device: &Device) -> RendererResult<()> {
        let descriptor = ShaderModuleDescriptor {
            label: Some("Dear ImGui Shader"),
            source: ShaderSource::Wgsl(SHADER_SOURCE.into()),
        };
        self.shader_module = Some(device.create_shader_module(descriptor));
        Ok(())
    }

    /// Borrows the stored shader module, or returns `None` before [`Self::initialize`].
    pub fn shader_module(&self) -> Option<&ShaderModule> {
        self.shader_module.as_ref()
    }

    /// Borrows the stored shader module, failing if [`Self::initialize`] has not run.
    ///
    /// # Errors
    ///
    /// Returns `RendererError::ShaderCompilationFailed` with the message
    /// `"Shader module not initialized"` when no module is stored.
    /// NOTE(review): the variant name suggests a compile failure, but here it means
    /// "not initialized". Check whether a more specific variant exists.
    pub fn get_shader_module(&self) -> RendererResult<&ShaderModule> {
        match &self.shader_module {
            Some(module) => Ok(module),
            None => Err(RendererError::ShaderCompilationFailed(
                "Shader module not initialized".to_string(),
            )),
        }
    }

    /// Returns `true` once [`Self::initialize`] has stored a shader module.
    pub fn is_initialized(&self) -> bool {
        self.shader_module().is_some()
    }
}
/// Describes the per-vertex layout of a buffer of [`DrawVert`]s for the ImGui pipeline.
///
/// Shader locations match `VertexInput` in [`SHADER_SOURCE`]:
/// - location 0: position, `Float32x2` at byte offset 0
/// - location 1: texture coordinates, `Float32x2` at byte offset 8
/// - location 2: color, `Unorm8x4` at byte offset 16 (bytes normalized to `[0, 1]`)
///
/// `vertex_attr_array!` packs attributes back to back. This layout therefore assumes
/// `DrawVert` stores position, uv and color in that order with no padding in between.
/// NOTE(review): confirm against the `DrawVert` definition.
pub fn create_vertex_buffer_layout() -> VertexBufferLayout<'static> {
    static DRAW_VERT_ATTRIBUTES: [VertexAttribute; 3] =
        vertex_attr_array![0 => Float32x2, 1 => Float32x2, 2 => Unorm8x4];

    // One step per vertex; the stride is the full size of a DrawVert.
    let stride = size_of::<DrawVert>() as BufferAddress;
    VertexBufferLayout {
        array_stride: stride,
        step_mode: VertexStepMode::Vertex,
        attributes: &DRAW_VERT_ATTRIBUTES,
    }
}
/// Builds the vertex stage of the ImGui render pipeline.
///
/// Uses the [`VS_ENTRY_POINT`] entry point of `shader_module`, forwards `buffers`
/// unchanged (e.g. the layout from [`create_vertex_buffer_layout`]), and keeps the default
/// pipeline compilation options.
pub fn create_vertex_state<'a>(
    shader_module: &'a ShaderModule,
    buffers: &'a [VertexBufferLayout<'_>],
) -> VertexState<'a> {
    let entry_point = Some(VS_ENTRY_POINT);
    VertexState {
        buffers,
        compilation_options: Default::default(),
        entry_point,
        module: shader_module,
    }
}
/// Builds the fragment stage of the ImGui render pipeline.
///
/// Uses the [`FS_ENTRY_POINT`] entry point of `shader_module`, forwards `targets` unchanged
/// as the color targets, and keeps the default pipeline compilation options.
///
/// `_use_gamma_correction` is accepted for API compatibility but ignored. The same entry
/// point is chosen either way, and `fs_main` always applies `pow(rgb, uniforms.gamma)`.
/// Gamma is therefore controlled through the `gamma` uniform value instead.
pub fn create_fragment_state<'a>(
    shader_module: &'a ShaderModule,
    targets: &'a [Option<ColorTargetState>],
    _use_gamma_correction: bool,
) -> FragmentState<'a> {
    FragmentState {
        module: shader_module,
        entry_point: Some(FS_ENTRY_POINT),
        targets,
        compilation_options: Default::default(),
    }
}