// roast2d_internal 0.4.0 — Roast2D internal crate.
// Retro CRT post-processing shader (`RetroShader`) and its settings.
use std::cell::RefCell;

use bytemuck;
use encase::ShaderType;
use glam::Vec2;

use crate::engine::Engine;
use crate::prelude::renderer_types::{TextureResource, Vertex};
use crate::prelude::wgpu::util::DeviceExt;
use crate::prelude::wgpu::{self, Device, Queue, RenderPass};

use crate::render::BackendState;
use crate::utils::bytes::to_wgsl_bytes;

use crate::renderer::traits::PostShader;

/// Settings for the retro CRT post-processing shader.
///
/// The struct is serialized with its `encase::ShaderType` layout and uploaded as the uniform
/// bound at binding 2 of the retro CRT bind group. Its field order presumably mirrors the
/// settings struct declared in `crt-retro.wgsl` — keep the two in sync when editing.
#[derive(Debug, Copy, Clone, ShaderType)]
pub struct RetroSettings {
    /// The resolution of the screen, as reported by the renderer's `screen_size()`.
    /// This is updated automatically by the renderer on resize. Default is `Vec2::ONE`.
    pub resolution: Vec2,
    /// The total elapsed time, copied from the engine every frame before upload.
    /// This is updated automatically by the renderer. Default is 0.0.
    pub time: f32,
    /// The amount of barrel distortion (fisheye effect).
    /// A value of 0.0 means no distortion. Default is 0.085.
    pub barrel_dist_amount: f32,
    /// The amount of color bleed, simulating chromatic aberration.
    /// Controls how much the red and blue channels are shifted horizontally.
    /// Default is 0.25.
    pub color_bleed_amount: f32,
    /// The size of the scrolling scanlines. Larger values mean thicker lines.
    /// Default is 2.5.
    pub scanline_size: f32,
    /// The speed of the scrolling scanlines.
    /// Default is 1.0.
    pub scanline_speed: f32,
    /// The brightness of the scanlines.
    /// Default is 0.8.
    ///
    /// NOTE(review): an earlier doc said "higher values make scanlines darker", while the
    /// default value was chosen "for lighter scanlines" — confirm the direction against
    /// `crt-retro.wgsl`.
    pub scanline_brightness: f32,
    /// The brightness of the vignette effect, which darkens the corners of the screen.
    /// Default is 0.5.
    pub vignette_brightness: f32,
    /// The strength of the vignette effect. Higher values mean a more pronounced effect.
    /// Default is 0.75.
    pub vignette_amount: f32,
}

impl Default for RetroSettings {
    /// Returns the tuned defaults documented on each field.
    fn default() -> Self {
        Self {
            // Placeholder until the renderer calls `resize`.
            resolution: Vec2::ONE,
            time: 0.0,
            barrel_dist_amount: 0.085,
            color_bleed_amount: 0.25,
            scanline_size: 2.5,
            scanline_speed: 1.0,
            // Raised from 0.5 by the original author, intended to give lighter scanlines.
            scanline_brightness: 0.8,
            vignette_brightness: 0.5,
            vignette_amount: 0.75,
        }
    }
}

/// GPU resources owned by the retro CRT post-processing pass, built once by `State::new`.
struct State {
    /// Device handle used to create resources, including the lazily built bind group.
    device: Device,
    /// Queue used to upload the serialized settings uniform each frame.
    queue: Queue,
    /// Render pipeline running the `vs_main` / `fs_main` entry points of `crt-retro.wgsl`.
    pipeline: wgpu::RenderPipeline,
    /// Layout for the source texture (binding 0), sampler (1) and settings uniform (2).
    bind_group_layout: wgpu::BindGroupLayout,
    /// Four vertices of a full-screen quad in clip space.
    vertex_buffer: wgpu::Buffer,
    /// `u16` indices drawing the quad as two triangles.
    index_buffer: wgpu::Buffer,
    /// Uniform buffer holding the serialized `RetroSettings`.
    uniform_buffer: wgpu::Buffer,
}

impl State {
    /// Creates every GPU resource the retro CRT pass needs.
    ///
    /// Consumes the renderer's `BackendState` (its `surface_config` is not used). It builds:
    /// - the bind-group layout: texture at 0, sampler at 1, settings uniform at 2;
    /// - the render pipeline targeting `surface_view_format`;
    /// - a full-screen quad (vertex and index buffers);
    /// - the settings uniform buffer.
    pub fn new(param: BackendState) -> Self {
        let BackendState {
            device,
            queue,
            surface_config: _,
            surface_view_format: format,
        } = param;

        let bind_group_layout = Self::create_bind_group_layout(&device);
        let pipeline = Self::create_pipeline(&device, &bind_group_layout, format);
        let (vertex_buffer, index_buffer) = Self::create_quad_buffers(&device);
        // Size the uniform buffer from the WGSL (encase) layout, not the Rust layout.
        // WGSL rounds a struct's size up to its alignment, so `size_of::<RetroSettings>()`
        // can be smaller than what the shader expects (and presumably smaller than what
        // `to_wgsl_bytes` writes) once fields are added.
        let uniform_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Retro CRT Settings Buffer"),
            size: RetroSettings::min_size().get(),
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        Self {
            device,
            queue,
            pipeline,
            bind_group_layout,
            vertex_buffer,
            index_buffer,
            uniform_buffer,
        }
    }

    /// Builds the fragment-only layout: filterable 2D texture, filtering sampler, settings uniform.
    fn create_bind_group_layout(device: &Device) -> wgpu::BindGroupLayout {
        device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Retro CRT Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::FRAGMENT,
                    ty: wgpu::BindingType::Texture {
                        sample_type: wgpu::TextureSampleType::Float { filterable: true },
                        view_dimension: wgpu::TextureViewDimension::D2,
                        multisampled: false,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::FRAGMENT,
                    ty: wgpu::BindingType::Sampler(wgpu::SamplerBindingType::Filtering),
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 2,
                    visibility: wgpu::ShaderStages::FRAGMENT,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        // Lets wgpu reject an undersized buffer or a mismatched shader struct
                        // when the pipeline / bind group is created, instead of at draw time.
                        min_binding_size: Some(RetroSettings::min_size()),
                    },
                    count: None,
                },
            ],
        })
    }

    /// Compiles `crt-retro.wgsl` and builds the alpha-blended full-screen pipeline for `format`.
    fn create_pipeline(
        device: &Device,
        bind_group_layout: &wgpu::BindGroupLayout,
        format: wgpu::TextureFormat,
    ) -> wgpu::RenderPipeline {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("retro post-processing vertex module"),
            source: wgpu::ShaderSource::Wgsl(
                include_str!("../../../assets/shaders/crt-retro.wgsl").into(),
            ),
        });
        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("retro post-processing layout"),
            bind_group_layouts: &[bind_group_layout],
            push_constant_ranges: &[],
        });
        device.create_render_pipeline(&wgpu::RenderPipelineDescriptor {
            cache: None,
            label: Some("retro post-processing pipeline"),
            layout: Some(&pipeline_layout),
            vertex: wgpu::VertexState {
                module: &shader,
                entry_point: Some("vs_main"),
                buffers: &[Vertex::desc()],
                compilation_options: Default::default(),
            },
            fragment: Some(wgpu::FragmentState {
                module: &shader,
                entry_point: Some("fs_main"),
                targets: &[Some(wgpu::ColorTargetState {
                    format,
                    blend: Some(wgpu::BlendState::ALPHA_BLENDING),
                    write_mask: wgpu::ColorWrites::ALL,
                })],
                compilation_options: Default::default(),
            }),
            primitive: wgpu::PrimitiveState {
                topology: wgpu::PrimitiveTopology::TriangleList,
                strip_index_format: None,
                // The quad below is wound counter-clockwise, so back-face culling keeps it.
                front_face: wgpu::FrontFace::Ccw,
                cull_mode: Some(wgpu::Face::Back),
                polygon_mode: wgpu::PolygonMode::Fill,
                unclipped_depth: false,
                conservative: false,
            },
            depth_stencil: None,
            multisample: wgpu::MultisampleState::default(),
            multiview: None,
        })
    }

    /// Creates the vertex and `u16` index buffers of a clip-space full-screen quad.
    fn create_quad_buffers(device: &Device) -> (wgpu::Buffer, wgpu::Buffer) {
        // Texture v runs top-to-bottom, so clip-space y = -1 maps to v = 1.
        let vertex_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("retro crt vertex buffer"),
            contents: bytemuck::cast_slice(&[
                Vertex {
                    pos: [-1.0, -1.0, 0.0],
                    color: [1.0, 1.0, 1.0, 1.0],
                    tex_coord: [0.0, 1.0],
                },
                Vertex {
                    pos: [1.0, -1.0, 0.0],
                    color: [1.0, 1.0, 1.0, 1.0],
                    tex_coord: [1.0, 1.0],
                },
                Vertex {
                    pos: [1.0, 1.0, 0.0],
                    color: [1.0, 1.0, 1.0, 1.0],
                    tex_coord: [1.0, 0.0],
                },
                Vertex {
                    pos: [-1.0, 1.0, 0.0],
                    color: [1.0, 1.0, 1.0, 1.0],
                    tex_coord: [0.0, 0.0],
                },
            ]),
            usage: wgpu::BufferUsages::VERTEX,
        });
        let index_buffer = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
            label: Some("retro crt index buffer"),
            contents: bytemuck::cast_slice(&[0u16, 1, 2, 0, 2, 3]),
            usage: wgpu::BufferUsages::INDEX,
        });
        (vertex_buffer, index_buffer)
    }
}

/// Retro CRT post-processing shader.
///
/// It holds the GPU resources, the mutable settings that are uploaded every frame, and a
/// cached bind group. The bind group is built lazily on the first `draw` after creation or
/// after a `resize`.
pub struct RetroShader {
    state: State,
    settings: RefCell<RetroSettings>,
    bind_group: RefCell<Option<wgpu::BindGroup>>,
}

impl RetroShader {
    /// Builds the pipeline and buffers from `param`, starting from the given `settings`.
    ///
    /// No bind group exists yet; it depends on the source texture passed to `draw`.
    pub fn setup(settings: RetroSettings, param: BackendState) -> Self {
        Self {
            state: State::new(param),
            settings: RefCell::new(settings),
            bind_group: RefCell::default(),
        }
    }
}

impl PostShader for RetroShader {
    /// Stores the new screen size in the settings and drops the cached bind group.
    ///
    /// Dropping the bind group means the next `draw` rebuilds it against the current render
    /// target.
    fn resize(&self, g: &Engine) {
        self.settings.borrow_mut().resolution = g.render.screen_size();
        // The render-target view changes on resize, so the old bind group would be stale.
        *self.bind_group.borrow_mut() = None;
    }

    /// Draws the full-screen quad, sampling `texture` through the CRT shader.
    ///
    /// NOTE(review): the bind group is cached and only rebuilt after `resize`. This assumes
    /// the same `texture` is passed on every draw between resizes. A different texture
    /// without a resize would still be sampled through the old bind group.
    fn draw(&self, texture: &TextureResource, pass: &mut RenderPass) {
        let mut cached = self.bind_group.borrow_mut();
        let bind_group: &wgpu::BindGroup = cached.get_or_insert_with(|| {
            self.state
                .device
                .create_bind_group(&wgpu::BindGroupDescriptor {
                    layout: &self.state.bind_group_layout,
                    entries: &[
                        wgpu::BindGroupEntry {
                            binding: 0,
                            resource: wgpu::BindingResource::TextureView(&texture.view),
                        },
                        wgpu::BindGroupEntry {
                            binding: 1,
                            resource: wgpu::BindingResource::Sampler(&texture.sampler),
                        },
                        wgpu::BindGroupEntry {
                            binding: 2,
                            resource: self.state.uniform_buffer.as_entire_binding(),
                        },
                    ],
                    label: Some("Retro CRT bind group"),
                })
        });

        let state = &self.state;
        pass.set_pipeline(&state.pipeline);
        pass.set_bind_group(0, bind_group, &[]);
        pass.set_vertex_buffer(0, state.vertex_buffer.slice(..));
        pass.set_index_buffer(state.index_buffer.slice(..), wgpu::IndexFormat::Uint16);
        // Two triangles, one instance.
        pass.draw_indexed(0..6, 0, 0..1);
    }

    /// Copies the engine time into the settings and uploads them to the uniform buffer.
    fn frame_pass(&self, g: &Engine, _pass: &mut RenderPass) {
        let mut settings = self.settings.borrow_mut();
        settings.time = g.time;
        let bytes = to_wgsl_bytes(&*settings);
        self.state
            .queue
            .write_buffer(&self.state.uniform_buffer, 0, &bytes);
    }

    /// Nothing to do at the end of the frame for this shader.
    fn frame_end(&self, _pass: &mut RenderPass) {}
}