nightshade 0.13.3

A cross-platform data-oriented game engine.
Documentation
use crate::ecs::world::World;
use crate::render::wgpu::rendergraph::{PassExecutionContext, PassNode};
use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};

/// WGSL compute shader (8x8 workgroups) that converts each texel of `input_texture` to
/// grayscale and stores it in the write-only `Rgba8Unorm` storage texture `output_texture`.
///
/// Luminance is `dot(rgb, (0.299, 0.587, 0.114))` (the ITU-R BT.601 luma weights). It is
/// applied to the values exactly as `textureLoad` returns them; the shader does no explicit
/// color-space conversion. The input alpha channel is passed through unchanged.
/// Invocations whose id lies outside the input texture's dimensions return without writing.
const GRAYSCALE_SHADER: &str = "
@group(0) @binding(0)
var input_texture: texture_2d<f32>;

@group(0) @binding(1)
var output_texture: texture_storage_2d<rgba8unorm, write>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let texture_size = textureDimensions(input_texture);

    if (global_id.x >= texture_size.x || global_id.y >= texture_size.y) {
        return;
    }

    let coords = vec2<i32>(i32(global_id.x), i32(global_id.y));
    let color = textureLoad(input_texture, coords, 0);

    let luminance = dot(color.rgb, vec3<f32>(0.299, 0.587, 0.114));
    let output_color = vec4<f32>(luminance, luminance, luminance, color.a);

    textureStore(output_texture, vec2<i32>(i32(global_id.x), i32(global_id.y)), output_color);
}
";

/// WGSL compute shader (8x8 workgroups) that copies each texel of `input_texture`, as
/// loaded, into the write-only `Rgba8Unorm` storage texture `output_texture`.
///
/// The pass uses it instead of `GRAYSCALE_SHADER` when it is disabled. It declares the same
/// bindings and performs the same bounds check as `GRAYSCALE_SHADER`. That lets both
/// pipelines share one bind group layout and one bind group.
const COPY_SHADER: &str = "
@group(0) @binding(0)
var input_texture: texture_2d<f32>;

@group(0) @binding(1)
var output_texture: texture_storage_2d<rgba8unorm, write>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let texture_size = textureDimensions(input_texture);

    if (global_id.x >= texture_size.x || global_id.y >= texture_size.y) {
        return;
    }

    let coords = vec2<i32>(i32(global_id.x), i32(global_id.y));
    let color = textureLoad(input_texture, coords, 0);

    textureStore(output_texture, vec2<i32>(i32(global_id.x), i32(global_id.y)), color);
}
";

/// Render-graph pass that converts the `input` texture to grayscale and writes the result
/// to `output` using a compute shader.
///
/// When the pass is disabled (`is_pass_enabled()` returns false), it dispatches a
/// passthrough shader instead. In that case `output` still receives a straight copy of
/// `input`.
pub struct ComputeGrayscalePass {
    /// Pipeline built from `GRAYSCALE_SHADER`; dispatched while the pass is enabled.
    grayscale_pipeline: ComputePipeline,
    /// Pipeline built from `COPY_SHADER`; dispatched while the pass is disabled.
    copy_pipeline: ComputePipeline,
    /// Layout shared by both pipelines. Binding 0 is the sampled input texture; binding 1 is
    /// the write-only `Rgba8Unorm` storage output.
    bind_group_layout: BindGroupLayout,
    /// Bind group for the current `input`/`output` views. It is built lazily in `execute`
    /// and cleared by `invalidate_bind_groups`.
    cached_bind_group: Option<BindGroup>,
}

impl ComputeGrayscalePass {
    /// Creates the pass and compiles both compute pipelines up front.
    ///
    /// The grayscale and passthrough pipelines share one bind group layout and one pipeline
    /// layout. As a result, the single bind group built lazily in `execute` works with
    /// either pipeline. No bind group is created here, because the texture views are only
    /// available at execution time.
    pub fn new(device: &wgpu::Device) -> Self {
        // Compile a WGSL source string into a labelled shader module.
        let compile_wgsl = |label: &str, wgsl: &'static str| {
            device.create_shader_module(wgpu::ShaderModuleDescriptor {
                label: Some(label),
                source: wgpu::ShaderSource::Wgsl(std::borrow::Cow::Borrowed(wgsl)),
            })
        };
        let grayscale_module = compile_wgsl("Grayscale Shader", GRAYSCALE_SHADER);
        let copy_module = compile_wgsl("Copy Shader", COPY_SHADER);

        // Every binding is used only by the compute stage and is not an array.
        let compute_entry = |binding: u32, ty: wgpu::BindingType| wgpu::BindGroupLayoutEntry {
            binding,
            visibility: wgpu::ShaderStages::COMPUTE,
            ty,
            count: None,
        };
        // Binding 0 is the non-filterable sampled input; binding 1 is the write-only output.
        let layout_entries = [
            compute_entry(
                0,
                wgpu::BindingType::Texture {
                    sample_type: wgpu::TextureSampleType::Float { filterable: false },
                    view_dimension: wgpu::TextureViewDimension::D2,
                    multisampled: false,
                },
            ),
            compute_entry(
                1,
                wgpu::BindingType::StorageTexture {
                    access: wgpu::StorageTextureAccess::WriteOnly,
                    format: wgpu::TextureFormat::Rgba8Unorm,
                    view_dimension: wgpu::TextureViewDimension::D2,
                },
            ),
        ];
        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Compute Grayscale Bind Group Layout"),
            entries: &layout_entries,
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Compute Grayscale Pipeline Layout"),
            bind_group_layouts: &[Some(&bind_group_layout)],
            immediate_size: 0,
        });

        // Both pipelines use the shared layout and the shaders' `main` entry point.
        let build_pipeline = |label: &str, module: &wgpu::ShaderModule| {
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some(label),
                layout: Some(&pipeline_layout),
                module,
                entry_point: Some("main"),
                compilation_options: Default::default(),
                cache: None,
            })
        };

        Self {
            grayscale_pipeline: build_pipeline("Grayscale Pipeline", &grayscale_module),
            copy_pipeline: build_pipeline("Copy Pipeline", &copy_module),
            bind_group_layout,
            cached_bind_group: None,
        }
    }
}

impl PassNode<World> for ComputeGrayscalePass {
    fn name(&self) -> &str {
        "compute_grayscale"
    }

    fn reads(&self) -> Vec<&str> {
        vec!["input"]
    }

    fn writes(&self) -> Vec<&str> {
        vec!["output"]
    }

    fn invalidate_bind_groups(&mut self) {
        self.cached_bind_group = None;
    }

    fn execute<'r, 'e>(
        &mut self,
        context: PassExecutionContext<'r, 'e, World>,
    ) -> crate::render::wgpu::rendergraph::Result<
        Vec<crate::render::wgpu::rendergraph::SubGraphRunCommand<'r>>,
    > {
        if self.cached_bind_group.is_none() {
            let input_view = context.get_texture_view("input")?;
            let output_view = context.get_texture_view("output")?;

            self.cached_bind_group = Some(context.device.create_bind_group(
                &wgpu::BindGroupDescriptor {
                    label: Some("Compute Grayscale Bind Group"),
                    layout: &self.bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: wgpu::BindingResource::TextureView(input_view),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: wgpu::BindingResource::TextureView(output_view),
                        },
                    ],
                },
            ));
        }

        let texture_size = context.get_texture_size("input")?;
        let is_enabled = context.is_pass_enabled();

        let mut compute_pass = context
            .encoder
            .begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Compute Grayscale Pass"),
                timestamp_writes: None,
            });

        let pipeline = if is_enabled {
            &self.grayscale_pipeline
        } else {
            &self.copy_pipeline
        };

        compute_pass.set_pipeline(pipeline);
        compute_pass.set_bind_group(0, self.cached_bind_group.as_ref().unwrap(), &[]);

        let workgroup_count_x = texture_size.0.div_ceil(8);
        let workgroup_count_y = texture_size.1.div_ceil(8);
        compute_pass.dispatch_workgroups(workgroup_count_x, workgroup_count_y, 1);

        drop(compute_pass);

        Ok(context.into_sub_graph_commands())
    }
}