use crate::ecs::world::World;
use crate::render::wgpu::rendergraph::{PassExecutionContext, PassNode};
use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
/// WGSL compute shader that writes a grayscale version of `input_texture` into `output_texture`.
///
/// Each invocation handles one texel. Luminance is `dot(rgb, (0.299, 0.587, 0.114))` (the BT.601
/// luma weights) and is written to all three color channels; alpha is passed through unchanged.
/// Invocations outside the input's dimensions return early, so the dispatch may be rounded up to
/// whole 8x8 workgroups.
///
/// Bindings (group 0): 0 = `texture_2d<f32>` read with `textureLoad` (no sampler),
/// 1 = write-only `rgba8unorm` storage texture.
// NOTE(review): the bounds check uses only the input size; this assumes `output_texture` is at
// least as large as `input_texture` — TODO confirm where the graph allocates "output".
const GRAYSCALE_SHADER: &str = "
@group(0) @binding(0)
var input_texture: texture_2d<f32>;
@group(0) @binding(1)
var output_texture: texture_storage_2d<rgba8unorm, write>;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let texture_size = textureDimensions(input_texture);
if (global_id.x >= texture_size.x || global_id.y >= texture_size.y) {
return;
}
let coords = vec2<i32>(i32(global_id.x), i32(global_id.y));
let color = textureLoad(input_texture, coords, 0);
let luminance = dot(color.rgb, vec3<f32>(0.299, 0.587, 0.114));
let output_color = vec4<f32>(luminance, luminance, luminance, color.a);
textureStore(output_texture, vec2<i32>(i32(global_id.x), i32(global_id.y)), output_color);
}
";
/// WGSL compute shader that copies `input_texture` into `output_texture` texel-for-texel.
///
/// It has the same bindings, the same 8x8 workgroup size and the same early-return bounds check
/// as the grayscale shader, so both can be driven by one pipeline layout and one bind group.
/// Values are stored through the `rgba8unorm` storage format.
// NOTE(review): as with the grayscale shader, bounds are checked against the input size only;
// assumes the output is at least as large — TODO confirm.
const COPY_SHADER: &str = "
@group(0) @binding(0)
var input_texture: texture_2d<f32>;
@group(0) @binding(1)
var output_texture: texture_storage_2d<rgba8unorm, write>;
@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
let texture_size = textureDimensions(input_texture);
if (global_id.x >= texture_size.x || global_id.y >= texture_size.y) {
return;
}
let coords = vec2<i32>(i32(global_id.x), i32(global_id.y));
let color = textureLoad(input_texture, coords, 0);
textureStore(output_texture, vec2<i32>(i32(global_id.x), i32(global_id.y)), color);
}
";
/// Render-graph pass that fills the `"output"` texture from the `"input"` texture with a compute
/// shader: grayscale when the pass is enabled, a plain copy when it is disabled.
pub struct ComputeGrayscalePass {
/// Pipeline built from `GRAYSCALE_SHADER`; selected when the pass is enabled.
grayscale_pipeline: ComputePipeline,
/// Pipeline built from `COPY_SHADER`; selected when the pass is disabled.
copy_pipeline: ComputePipeline,
/// Layout shared by both pipelines: binding 0 = input texture view,
/// binding 1 = write-only `Rgba8Unorm` storage texture view.
bind_group_layout: BindGroupLayout,
/// Bind group for the current "input"/"output" views. It is created lazily on execution and
/// reset to `None` by `invalidate_bind_groups`.
cached_bind_group: Option<BindGroup>,
}
impl ComputeGrayscalePass {
    /// Creates the pass, compiling the grayscale and copy compute pipelines up front.
    ///
    /// Both pipelines use the same pipeline layout, which has one bind group with two bindings:
    /// 0 = non-filterable float 2D texture (read with `textureLoad`, so no sampler is needed) and
    /// 1 = write-only `Rgba8Unorm` 2D storage texture. Sharing the layout lets a single cached
    /// bind group serve whichever pipeline is chosen at execution time. The bind group itself is
    /// not created here; it is built lazily once the graph's texture views are available.
    pub fn new(device: &wgpu::Device) -> Self {
        let grayscale_shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Grayscale Shader"),
            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(GRAYSCALE_SHADER)),
        });
        let copy_shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Copy Shader"),
            source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(COPY_SHADER)),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Compute Grayscale Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Texture {
                        sample_type: wgpu::TextureSampleType::Float { filterable: false },
                        view_dimension: wgpu::TextureViewDimension::D2,
                        multisampled: false,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::StorageTexture {
                        access: wgpu::StorageTextureAccess::WriteOnly,
                        format: wgpu::TextureFormat::Rgba8Unorm,
                        view_dimension: wgpu::TextureViewDimension::D2,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Compute Grayscale Pipeline Layout"),
            bind_group_layouts: &[Some(&bind_group_layout)],
            immediate_size: 0,
        });

        // The two pipelines differ only in label and shader module; everything else is shared.
        let create_pipeline = |label: &'static str, module: &wgpu::ShaderModule| {
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(label),
                layout: Some(&pipeline_layout),
                module,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            })
        };
        let grayscale_pipeline = create_pipeline("Grayscale Pipeline", &grayscale_shader_module);
        let copy_pipeline = create_pipeline("Copy Pipeline", &copy_shader_module);

        Self {
            grayscale_pipeline,
            copy_pipeline,
            bind_group_layout,
            cached_bind_group: None,
        }
    }
}
impl PassNode<World> for ComputeGrayscalePass {
fn name(&self) -> &str {
"compute_grayscale"
}
fn reads(&self) -> Vec<&str> {
vec!["input"]
}
fn writes(&self) -> Vec<&str> {
vec!["output"]
}
fn invalidate_bind_groups(&mut self) {
self.cached_bind_group = None;
}
fn execute<'r, 'e>(
&mut self,
context: PassExecutionContext<'r, 'e, World>,
) -> crate::render::wgpu::rendergraph::Result<
Vec<crate::render::wgpu::rendergraph::SubGraphRunCommand<'r>>,
> {
if self.cached_bind_group.is_none() {
let input_view = context.get_texture_view("input")?;
let output_view = context.get_texture_view("output")?;
self.cached_bind_group = Some(context.device.create_bind_group(
&wgpu::BindGroupDescriptor {
label: Some("Compute Grayscale Bind Group"),
layout: &self.bind_group_layout,
entries: &[
wgpu::BindGroupEntry {
binding: 0,
resource: wgpu::BindingResource::TextureView(input_view),
},
wgpu::BindGroupEntry {
binding: 1,
resource: wgpu::BindingResource::TextureView(output_view),
},
],
},
));
}
let texture_size = context.get_texture_size("input")?;
let is_enabled = context.is_pass_enabled();
let mut compute_pass = context
.encoder
.begin_compute_pass(&wgpu::ComputePassDescriptor {
label: Some("Compute Grayscale Pass"),
timestamp_writes: None,
});
let pipeline = if is_enabled {
&self.grayscale_pipeline
} else {
&self.copy_pipeline
};
compute_pass.set_pipeline(pipeline);
compute_pass.set_bind_group(0, self.cached_bind_group.as_ref().unwrap(), &[]);
let workgroup_count_x = texture_size.0.div_ceil(8);
let workgroup_count_y = texture_size.1.div_ceil(8);
compute_pass.dispatch_workgroups(workgroup_count_x, workgroup_count_y, 1);
drop(compute_pass);
Ok(context.into_sub_graph_commands())
}
}