use imgui::{
Context, DrawCmd::Elements, DrawData, DrawIdx, DrawList, DrawVert, TextureId, Textures,
};
use smallvec::SmallVec;
use std::mem::size_of;
use wgpu::util::{BufferInitDescriptor, DeviceExt};
use wgpu::*;
/// Convenience alias for results whose error type is [`RendererError`].
pub type RendererResult<T> = Result<T, RendererError>;
/// Transparent newtype over imgui's [`DrawVert`] so the `bytemuck` marker traits can be
/// implemented for it (presumably because both `DrawVert` and the traits come from other
/// crates). Used when turning a vertex slice into raw bytes for upload.
#[repr(transparent)]
#[derive(Debug, Copy, Clone)]
struct DrawVertPod(DrawVert);
// SAFETY: NOTE(review): relies on `DrawVert` being plain data for which the all-zero bit
// pattern is a valid value — confirm against the imgui crate's definition.
unsafe impl bytemuck::Zeroable for DrawVertPod {}
// SAFETY: NOTE(review): relies on `DrawVert` having a fixed layout with no padding bytes and
// no invalid bit patterns — confirm against the imgui crate's definition.
unsafe impl bytemuck::Pod for DrawVertPod {}
/// Errors that can be reported while rendering imgui draw data.
#[derive(Clone, Debug)]
pub enum RendererError {
    /// A draw command referenced a texture id that is not registered with the renderer.
    BadTexture(TextureId),
}

impl std::fmt::Display for RendererError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // `TextureId` is only known to be `Debug` here (the enum derives `Debug`),
            // so format it with `{:?}`.
            RendererError::BadTexture(id) => {
                write!(f, "imgui render error: bad texture id {:?}", id)
            }
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn std::error::Error>` and similar.
impl std::error::Error for RendererError {}
/// Shader stage selector used when compiling GLSL sources to SPIR-V.
///
/// This is a crate-local enum, distinct from `wgpu::ShaderStage`, which the rest of the file
/// refers to by its fully qualified path.
// Only referenced from code gated behind the `glsl-to-spirv` feature, so without that
// feature the enum is unused; hence the `dead_code` allowance.
#[allow(dead_code)]
enum ShaderStage {
Vertex,
Fragment,
Compute,
}
/// Namespace for the GLSL helpers that are available with the `glsl-to-spirv` feature.
#[cfg(feature = "glsl-to-spirv")]
struct Shaders;

#[cfg(feature = "glsl-to-spirv")]
impl Shaders {
    /// Compiles GLSL `code` for the given `stage` and returns an owned SPIR-V module source.
    ///
    /// # Panics
    ///
    /// Panics if compilation fails or the compiler output cannot be read.
    fn compile_glsl(code: &str, stage: ShaderStage) -> ShaderModuleSource<'static> {
        use std::io::Read as _;

        let shader_type = match stage {
            ShaderStage::Vertex => glsl_to_spirv::ShaderType::Vertex,
            ShaderStage::Fragment => glsl_to_spirv::ShaderType::Fragment,
            ShaderStage::Compute => glsl_to_spirv::ShaderType::Compute,
        };

        let mut bytes = Vec::new();
        glsl_to_spirv::compile(code, shader_type)
            .unwrap()
            .read_to_end(&mut bytes)
            .unwrap();

        // `make_spirv` may borrow from `bytes`; detach the words from the local buffer so
        // the returned source can be `'static`.
        match util::make_spirv(&bytes) {
            ShaderModuleSource::SpirV(words) => {
                ShaderModuleSource::SpirV(std::borrow::Cow::Owned(words.into_owned()))
            }
            _ => unreachable!(),
        }
    }

    /// Returns the bundled `(vertex, fragment)` GLSL sources.
    fn get_program_code() -> (&'static str, &'static str) {
        let vertex = include_str!("imgui.vert");
        let fragment = include_str!("imgui.frag");
        (vertex, fragment)
    }
}
/// Builder-style description of a texture to be created for use with the renderer.
#[derive(Clone)]
pub struct TextureConfig<'a> {
/// Extent of the texture.
pub size: Extent3d,
/// Optional debug label, applied to both the texture and its bind group.
pub label: Option<&'a str>,
/// Texture format; `None` means the renderer's configured texture format is used.
pub format: Option<TextureFormat>,
/// Usage flags the texture is created with.
pub usage: TextureUsage,
/// Number of mip levels.
pub mip_level_count: u32,
/// Number of samples per texel.
pub sample_count: u32,
/// Dimensionality of the texture.
pub dimension: TextureDimension,
}
impl<'a> TextureConfig<'a> {
    /// Creates a configuration for a 2D texture of `width` x `height`.
    ///
    /// Defaults: depth 1, no label, no explicit format (the renderer's texture format is
    /// used when the texture is built), `SAMPLED | COPY_DST` usage, one mip level and one
    /// sample.
    pub fn new(width: u32, height: u32) -> TextureConfig<'static> {
        TextureConfig {
            size: Extent3d {
                width,
                height,
                depth: 1,
            },
            label: None,
            format: None,
            usage: TextureUsage::SAMPLED | TextureUsage::COPY_DST,
            mip_level_count: 1,
            sample_count: 1,
            dimension: TextureDimension::D2,
        }
    }

    /// Sets the depth component of the texture extent.
    pub fn set_depth(mut self, depth: u32) -> Self {
        self.size.depth = depth;
        self
    }

    /// Sets the debug label, re-tying the configuration to the label's lifetime.
    pub fn set_label<'b>(self, label: &'b str) -> TextureConfig<'b> {
        // Rebuild the struct field by field: this changes the lifetime parameter without
        // any `unsafe`, because the old label (the only borrowed field) is simply dropped.
        TextureConfig {
            size: self.size,
            label: Some(label),
            format: self.format,
            usage: self.usage,
            mip_level_count: self.mip_level_count,
            sample_count: self.sample_count,
            dimension: self.dimension,
        }
    }

    /// Sets an explicit texture format, overriding the renderer's default.
    pub fn set_format(mut self, format: TextureFormat) -> Self {
        self.format = Some(format);
        self
    }

    /// Sets the usage flags.
    pub fn set_usage(mut self, usage: TextureUsage) -> Self {
        self.usage = usage;
        self
    }

    /// Sets the number of mip levels.
    pub fn set_mip_level_count(mut self, mip_level_count: u32) -> Self {
        self.mip_level_count = mip_level_count;
        self
    }

    /// Sets the sample count.
    pub fn set_sample_count(mut self, sample_count: u32) -> Self {
        self.sample_count = sample_count;
        self
    }

    /// Sets the texture dimension.
    pub fn set_dimension(mut self, dimension: TextureDimension) -> Self {
        self.dimension = dimension;
        self
    }

    /// Creates the texture described by this configuration.
    pub fn build(self, device: &Device, renderer: &Renderer) -> Texture {
        Texture::new(device, renderer, self)
    }
}
/// A GPU texture together with the view and bind group the renderer needs to sample it.
pub struct Texture {
// Underlying wgpu texture; the target of `write`.
texture: wgpu::Texture,
// View of `texture` that the bind group refers to.
view: wgpu::TextureView,
// Bind group set at index 1 when a draw command uses this texture.
bind_group: BindGroup,
// Extent the texture was created with.
size: Extent3d,
}
impl Texture {
pub fn from_raw_parts(
texture: wgpu::Texture,
view: wgpu::TextureView,
bind_group: BindGroup,
size: Extent3d,
) -> Self {
Texture {
texture,
view,
bind_group,
size,
}
}
pub fn new(device: &Device, renderer: &Renderer, config: TextureConfig) -> Self {
let texture = device.create_texture(&TextureDescriptor {
label: config.label,
size: config.size,
mip_level_count: config.mip_level_count,
sample_count: config.sample_count,
dimension: config.dimension,
format: config.format.unwrap_or(renderer.config.texture_format),
usage: config.usage,
});
let view = texture.create_view(&TextureViewDescriptor::default());
let sampler = device.create_sampler(&SamplerDescriptor {
label: Some("imgui-wgpu sampler"),
address_mode_u: AddressMode::ClampToEdge,
address_mode_v: AddressMode::ClampToEdge,
address_mode_w: AddressMode::ClampToEdge,
mag_filter: FilterMode::Linear,
min_filter: FilterMode::Linear,
mipmap_filter: FilterMode::Linear,
lod_min_clamp: -100.0,
lod_max_clamp: 100.0,
compare: None,
anisotropy_clamp: None,
});
let bind_group = device.create_bind_group(&BindGroupDescriptor {
label: config.label,
layout: &renderer.texture_layout,
entries: &[
BindGroupEntry {
binding: 0,
resource: BindingResource::TextureView(&view),
},
BindGroupEntry {
binding: 1,
resource: BindingResource::Sampler(&sampler),
},
],
});
Texture {
texture,
view,
bind_group,
size: config.size,
}
}
pub fn write(&self, queue: &Queue, data: &[u8], width: u32, height: u32) {
queue.write_texture(
TextureCopyView {
texture: &self.texture,
mip_level: 0,
origin: Origin3d { x: 0, y: 0, z: 0 },
},
data,
TextureDataLayout {
offset: 0,
bytes_per_row: width * 4,
rows_per_image: height,
},
Extent3d {
width,
height,
depth: 1,
},
);
}
pub fn width(&self) -> u32 {
self.size.width
}
pub fn height(&self) -> u32 {
self.size.height
}
pub fn depth(&self) -> u32 {
self.size.depth
}
pub fn size(&self) -> Extent3d {
self.size
}
pub fn texture(&self) -> &wgpu::Texture {
&self.texture
}
pub fn view(&self) -> &wgpu::TextureView {
&self.view
}
}
/// Settings used to construct a [`Renderer`].
pub struct RendererConfig<'vs, 'fs> {
// Format of the pipeline's color target; also the fallback format for textures built
// without an explicit format.
texture_format: TextureFormat,
// When set, the pipeline is created with a depth-stencil state of this format.
depth_format: Option<TextureFormat>,
// Sample count the pipeline is created with.
sample_count: u32,
// Shader sources consumed when the renderer is created; the copy of the configuration
// kept inside the renderer stores `None` for both.
vertex_shader: Option<ShaderModuleSource<'vs>>,
fragment_shader: Option<ShaderModuleSource<'fs>>,
}
impl RendererConfig<'_, '_> {
    /// Creates a configuration that uses the supplied shader sources.
    ///
    /// Defaults: `Rgba8Unorm` texture format, no depth format, sample count 1.
    pub fn with_shaders<'vs, 'fs>(
        vertex_shader: ShaderModuleSource<'vs>,
        fragment_shader: ShaderModuleSource<'fs>,
    ) -> RendererConfig<'vs, 'fs> {
        let vertex_shader = Some(vertex_shader);
        let fragment_shader = Some(fragment_shader);
        RendererConfig {
            texture_format: TextureFormat::Rgba8Unorm,
            depth_format: None,
            sample_count: 1,
            vertex_shader,
            fragment_shader,
        }
    }

    /// Creates a configuration that uses the precompiled SPIR-V shaders bundled next to
    /// this source file.
    pub fn new() -> RendererConfig<'static, 'static> {
        let vertex = include_spirv!("imgui.vert.spv");
        let fragment = include_spirv!("imgui.frag.spv");
        Self::with_shaders(vertex, fragment)
    }

    /// Creates a configuration whose shaders are compiled from the bundled GLSL sources.
    #[cfg(feature = "glsl-to-spirv")]
    pub fn new_glsl() -> RendererConfig<'static, 'static> {
        let (vertex_glsl, fragment_glsl) = Shaders::get_program_code();
        Self::with_shaders(
            Shaders::compile_glsl(vertex_glsl, ShaderStage::Vertex),
            Shaders::compile_glsl(fragment_glsl, ShaderStage::Fragment),
        )
    }

    /// Sets the color target / default texture format.
    pub fn set_texture_format(self, texture_format: TextureFormat) -> Self {
        Self {
            texture_format,
            ..self
        }
    }

    /// Enables a depth-stencil state with the given format on the pipeline.
    pub fn set_depth_format(self, depth_format: TextureFormat) -> Self {
        Self {
            depth_format: Some(depth_format),
            ..self
        }
    }

    /// Sets the pipeline sample count.
    pub fn set_sample_count(self, sample_count: u32) -> Self {
        Self {
            sample_count,
            ..self
        }
    }

    /// Creates a [`Renderer`] from this configuration.
    pub fn build(self, imgui: &mut Context, device: &Device, queue: &Queue) -> Renderer {
        Renderer::new(imgui, device, queue, self)
    }
}
/// Renders imgui draw data into a wgpu render pass.
pub struct Renderer {
// Render pipeline used for every draw.
pipeline: RenderPipeline,
// Buffer holding the 4x4 projection matrix; rewritten on every `render` call.
uniform_buffer: Buffer,
// Bind group (index 0) exposing `uniform_buffer` to the vertex stage.
uniform_bind_group: BindGroup,
/// Textures that draw commands can reference by their imgui texture id.
pub textures: Textures<Texture>,
// Layout (texture view + sampler) used for the bind group of every `Texture`.
texture_layout: BindGroupLayout,
// Per-draw-list index/vertex buffers; cleared and refilled on every `render` call.
index_buffers: SmallVec<[Buffer; 4]>,
vertex_buffers: SmallVec<[Buffer; 4]>,
// Settings the renderer was created with; the shader sources are not retained.
config: RendererConfig<'static, 'static>,
}
impl Renderer {
pub fn new(
imgui: &mut Context,
device: &Device,
queue: &Queue,
config: RendererConfig,
) -> Renderer {
let RendererConfig {
texture_format,
depth_format,
sample_count,
vertex_shader,
fragment_shader,
} = config;
let vs_module = device.create_shader_module(vertex_shader.unwrap());
let fs_module = device.create_shader_module(fragment_shader.unwrap());
let size = 64;
let uniform_buffer = device.create_buffer(&BufferDescriptor {
label: Some("imgui-wgpu uniform buffer"),
size,
usage: BufferUsage::UNIFORM | BufferUsage::COPY_DST,
mapped_at_creation: false,
});
let uniform_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: None,
entries: &[BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStage::VERTEX,
ty: BindingType::UniformBuffer {
dynamic: false,
min_binding_size: None,
},
count: None,
}],
});
let uniform_bind_group = device.create_bind_group(&BindGroupDescriptor {
label: Some("imgui-wgpu bind group"),
layout: &uniform_layout,
entries: &[BindGroupEntry {
binding: 0,
resource: BindingResource::Buffer(uniform_buffer.slice(..)),
}],
});
let texture_layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
label: Some("imgui-wgpu bind group layout"),
entries: &[
BindGroupLayoutEntry {
binding: 0,
visibility: wgpu::ShaderStage::FRAGMENT,
ty: BindingType::SampledTexture {
multisampled: false,
component_type: TextureComponentType::Float,
dimension: TextureViewDimension::D2,
},
count: None,
},
BindGroupLayoutEntry {
binding: 1,
visibility: wgpu::ShaderStage::FRAGMENT,
ty: BindingType::Sampler { comparison: false },
count: None,
},
],
});
let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
label: Some("imgui-wgpu pipeline layout"),
bind_group_layouts: &[&uniform_layout, &texture_layout],
push_constant_ranges: &[],
});
let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
label: Some("imgui-wgpu pipeline"),
layout: Some(&pipeline_layout),
vertex_stage: ProgrammableStageDescriptor {
module: &vs_module,
entry_point: "main",
},
fragment_stage: Some(ProgrammableStageDescriptor {
module: &fs_module,
entry_point: "main",
}),
rasterization_state: Some(RasterizationStateDescriptor {
front_face: FrontFace::Cw,
cull_mode: CullMode::None,
clamp_depth: false,
depth_bias: 0,
depth_bias_slope_scale: 0.0,
depth_bias_clamp: 0.0,
}),
primitive_topology: PrimitiveTopology::TriangleList,
color_states: &[ColorStateDescriptor {
format: texture_format,
color_blend: BlendDescriptor {
src_factor: BlendFactor::SrcAlpha,
dst_factor: BlendFactor::OneMinusSrcAlpha,
operation: BlendOperation::Add,
},
alpha_blend: BlendDescriptor {
src_factor: BlendFactor::OneMinusDstAlpha,
dst_factor: BlendFactor::One,
operation: BlendOperation::Add,
},
write_mask: ColorWrite::ALL,
}],
depth_stencil_state: depth_format.map(|format| wgpu::DepthStencilStateDescriptor {
format,
depth_write_enabled: false,
depth_compare: wgpu::CompareFunction::Always,
stencil: wgpu::StencilStateDescriptor::default(),
}),
vertex_state: VertexStateDescriptor {
index_format: IndexFormat::Uint16,
vertex_buffers: &[VertexBufferDescriptor {
stride: size_of::<DrawVert>() as BufferAddress,
step_mode: InputStepMode::Vertex,
attributes: &vertex_attr_array![0 => Float2, 1 => Float2, 2 => Uint],
}],
},
sample_count,
sample_mask: !0,
alpha_to_coverage_enabled: false,
});
let mut renderer = Renderer {
pipeline,
uniform_buffer,
uniform_bind_group,
textures: Textures::new(),
texture_layout,
vertex_buffers: SmallVec::new(),
index_buffers: SmallVec::new(),
config: RendererConfig {
texture_format,
depth_format,
sample_count,
vertex_shader: None,
fragment_shader: None,
},
};
renderer.reload_font_texture(imgui, device, queue);
renderer
}
pub fn render<'r>(
&'r mut self,
draw_data: &DrawData,
queue: &Queue,
device: &Device,
rpass: &mut RenderPass<'r>,
) -> RendererResult<()> {
let fb_width = draw_data.display_size[0] * draw_data.framebuffer_scale[0];
let fb_height = draw_data.display_size[1] * draw_data.framebuffer_scale[1];
if !(fb_width > 0.0 && fb_height > 0.0) {
return Ok(());
}
let width = draw_data.display_size[0];
let height = draw_data.display_size[1];
let offset_x = draw_data.display_pos[0] / width;
let offset_y = draw_data.display_pos[1] / height;
let matrix = [
[2.0 / width, 0.0, 0.0, 0.0],
[0.0, 2.0 / -height as f32, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[-1.0 - offset_x * 2.0, 1.0 + offset_y * 2.0, 0.0, 1.0],
];
self.update_uniform_buffer(queue, &matrix);
rpass.set_pipeline(&self.pipeline);
rpass.set_bind_group(0, &self.uniform_bind_group, &[]);
self.vertex_buffers.clear();
self.index_buffers.clear();
for draw_list in draw_data.draw_lists() {
self.vertex_buffers
.push(self.upload_vertex_buffer(device, draw_list.vtx_buffer()));
self.index_buffers
.push(self.upload_index_buffer(device, draw_list.idx_buffer()));
}
for (draw_list_buffers_index, draw_list) in draw_data.draw_lists().enumerate() {
self.render_draw_list(
rpass,
&draw_list,
draw_data.display_pos,
draw_data.framebuffer_scale,
draw_list_buffers_index,
)?;
}
Ok(())
}
fn render_draw_list<'render>(
&'render self,
rpass: &mut RenderPass<'render>,
draw_list: &DrawList,
clip_off: [f32; 2],
clip_scale: [f32; 2],
draw_list_buffers_index: usize,
) -> RendererResult<()> {
let mut start = 0;
let index_buffer = &self.index_buffers[draw_list_buffers_index];
let vertex_buffer = &self.vertex_buffers[draw_list_buffers_index];
rpass.set_index_buffer(index_buffer.slice(..));
rpass.set_vertex_buffer(0, vertex_buffer.slice(..));
for cmd in draw_list.commands() {
match cmd {
Elements { count, cmd_params } => {
let clip_rect = [
(cmd_params.clip_rect[0] - clip_off[0]) * clip_scale[0],
(cmd_params.clip_rect[1] - clip_off[1]) * clip_scale[1],
(cmd_params.clip_rect[2] - clip_off[0]) * clip_scale[0],
(cmd_params.clip_rect[3] - clip_off[1]) * clip_scale[1],
];
let texture_id = cmd_params.texture_id.into();
let tex = self
.textures
.get(texture_id)
.ok_or_else(|| RendererError::BadTexture(texture_id))?;
rpass.set_bind_group(1, &tex.bind_group, &[]);
let scissors = (
clip_rect[0].max(0.0).floor() as u32,
clip_rect[1].max(0.0).floor() as u32,
(clip_rect[2] - clip_rect[0]).abs().ceil() as u32,
(clip_rect[3] - clip_rect[1]).abs().ceil() as u32,
);
rpass.set_scissor_rect(scissors.0, scissors.1, scissors.2, scissors.3);
let end = start + count as u32;
rpass.draw_indexed(start..end, 0, 0..1);
start = end;
}
_ => {}
}
}
Ok(())
}
fn update_uniform_buffer(&mut self, queue: &Queue, matrix: &[[f32; 4]; 4]) {
let data = bytemuck::bytes_of(matrix);
queue.write_buffer(&self.uniform_buffer, 0, data);
}
fn upload_vertex_buffer(&self, device: &Device, vertices: &[DrawVert]) -> Buffer {
let vertices = unsafe {
std::slice::from_raw_parts(vertices.as_ptr() as *mut DrawVertPod, vertices.len())
};
let data = bytemuck::cast_slice(&vertices);
device.create_buffer_init(&BufferInitDescriptor {
label: Some("imgui-wgpu vertex buffer"),
contents: data,
usage: BufferUsage::VERTEX,
})
}
fn upload_index_buffer(&self, device: &Device, indices: &[DrawIdx]) -> Buffer {
let data = bytemuck::cast_slice(&indices);
device.create_buffer_init(&BufferInitDescriptor {
label: Some("imgui-wgpu index buffer"),
contents: data,
usage: BufferUsage::INDEX,
})
}
pub fn reload_font_texture(&mut self, imgui: &mut Context, device: &Device, queue: &Queue) {
let mut fonts = imgui.fonts();
self.textures.remove(fonts.tex_id);
let handle = fonts.build_rgba32_texture();
let font_texture = TextureConfig::new(handle.width, handle.height)
.set_label("imgui-wgpu font atlas")
.build(&device, self);
font_texture.write(&queue, handle.data, handle.width, handle.height);
fonts.tex_id = self.textures.insert(font_texture);
fonts.clear_tex_data();
}
}