// ratatui_wgpu/shaders.rs

1use std::{
2    mem::size_of,
3    num::NonZeroU64,
4};
5
6use wgpu::{
7    self,
8    include_wgsl,
9    AddressMode,
10    BindGroupDescriptor,
11    BindGroupEntry,
12    BindGroupLayout,
13    BindGroupLayoutDescriptor,
14    BindGroupLayoutEntry,
15    BindingResource,
16    BindingType,
17    Buffer,
18    BufferBindingType,
19    BufferDescriptor,
20    BufferUsages,
21    Color,
22    ColorTargetState,
23    ColorWrites,
24    FilterMode,
25    FragmentState,
26    LoadOp,
27    MultisampleState,
28    Operations,
29    PipelineCompilationOptions,
30    PipelineLayoutDescriptor,
31    PrimitiveState,
32    PrimitiveTopology,
33    RenderBundle,
34    RenderBundleDescriptor,
35    RenderBundleEncoderDescriptor,
36    RenderPassColorAttachment,
37    RenderPassDescriptor,
38    RenderPipeline,
39    RenderPipelineDescriptor,
40    Sampler,
41    SamplerBindingType,
42    SamplerDescriptor,
43    ShaderStages,
44    StoreOp,
45    TextureSampleType,
46    TextureViewDimension,
47    VertexState,
48};
49
50use crate::backend::PostProcessor;
51
52#[repr(C)]
53#[derive(bytemuck::Pod, bytemuck::Zeroable, Debug, Clone, Copy)]
54struct Uniforms {
55    screen_size: [f32; 2],
56    use_srgb: u32,
57    _pad0: [u32; 5],
58}
59
60/// The default post-processor. Used when you don't want to perform any custom
61/// shading on the output. This just blits the composited text to the surface.
62pub struct DefaultPostProcessor {
63    uniforms: Buffer,
64    bindings: BindGroupLayout,
65    sampler: Sampler,
66    pipeline: RenderPipeline,
67
68    blitter: RenderBundle,
69}
70
71impl PostProcessor for DefaultPostProcessor {
72    type UserData = ();
73
74    fn compile(
75        device: &wgpu::Device,
76        text_view: &wgpu::TextureView,
77        surface_config: &wgpu::SurfaceConfiguration,
78        _user_data: Self::UserData,
79    ) -> Self {
80        let uniforms = device.create_buffer(&BufferDescriptor {
81            label: Some("Text Blit Uniforms"),
82            size: size_of::<Uniforms>() as u64,
83            usage: BufferUsages::COPY_DST | BufferUsages::UNIFORM,
84            mapped_at_creation: false,
85        });
86
87        let sampler = device.create_sampler(&SamplerDescriptor {
88            address_mode_u: AddressMode::ClampToEdge,
89            address_mode_v: AddressMode::ClampToEdge,
90            address_mode_w: AddressMode::ClampToEdge,
91            mag_filter: FilterMode::Nearest,
92            min_filter: FilterMode::Nearest,
93            mipmap_filter: FilterMode::Nearest,
94            ..Default::default()
95        });
96
97        let layout = device.create_bind_group_layout(&BindGroupLayoutDescriptor {
98            label: Some("Text Blit Bindings Layout"),
99            entries: &[
100                BindGroupLayoutEntry {
101                    binding: 0,
102                    visibility: ShaderStages::FRAGMENT,
103                    ty: BindingType::Texture {
104                        sample_type: TextureSampleType::Float { filterable: true },
105                        view_dimension: TextureViewDimension::D2,
106                        multisampled: false,
107                    },
108                    count: None,
109                },
110                BindGroupLayoutEntry {
111                    binding: 1,
112                    visibility: ShaderStages::FRAGMENT,
113                    ty: BindingType::Sampler(SamplerBindingType::Filtering),
114                    count: None,
115                },
116                BindGroupLayoutEntry {
117                    binding: 2,
118                    visibility: ShaderStages::FRAGMENT,
119                    ty: BindingType::Buffer {
120                        ty: BufferBindingType::Uniform,
121                        has_dynamic_offset: false,
122                        min_binding_size: NonZeroU64::new(size_of::<Uniforms>() as u64),
123                    },
124                    count: None,
125                },
126            ],
127        });
128
129        let shader = device.create_shader_module(include_wgsl!("shaders/blit.wgsl"));
130
131        let pipeline_layout = device.create_pipeline_layout(&PipelineLayoutDescriptor {
132            label: Some("Text Blit Layout"),
133            bind_group_layouts: &[&layout],
134            push_constant_ranges: &[],
135        });
136
137        let pipeline = device.create_render_pipeline(&RenderPipelineDescriptor {
138            label: Some("Text Blitter Pipeline"),
139            layout: Some(&pipeline_layout),
140            vertex: VertexState {
141                module: &shader,
142                entry_point: Some("vs_main"),
143                compilation_options: PipelineCompilationOptions::default(),
144                buffers: &[],
145            },
146            primitive: PrimitiveState {
147                topology: PrimitiveTopology::TriangleStrip,
148                ..Default::default()
149            },
150            depth_stencil: None,
151            multisample: MultisampleState::default(),
152            fragment: Some(FragmentState {
153                module: &shader,
154                entry_point: Some("fs_main"),
155                compilation_options: PipelineCompilationOptions::default(),
156                targets: &[Some(ColorTargetState {
157                    format: surface_config.format,
158                    blend: None,
159                    write_mask: ColorWrites::ALL,
160                })],
161            }),
162            multiview: None,
163            cache: None,
164        });
165
166        let blitter = build_blitter(
167            device,
168            &layout,
169            text_view,
170            &sampler,
171            &uniforms,
172            surface_config,
173            &pipeline,
174        );
175
176        Self {
177            uniforms,
178            bindings: layout,
179            sampler,
180            pipeline,
181            blitter,
182        }
183    }
184
185    fn resize(
186        &mut self,
187        device: &wgpu::Device,
188        text_view: &wgpu::TextureView,
189        surface_config: &wgpu::SurfaceConfiguration,
190    ) {
191        self.blitter = build_blitter(
192            device,
193            &self.bindings,
194            text_view,
195            &self.sampler,
196            &self.uniforms,
197            surface_config,
198            &self.pipeline,
199        );
200    }
201
202    fn process(
203        &mut self,
204        encoder: &mut wgpu::CommandEncoder,
205        queue: &wgpu::Queue,
206        _text_view: &wgpu::TextureView,
207        surface_config: &wgpu::SurfaceConfiguration,
208        surface_view: &wgpu::TextureView,
209    ) {
210        {
211            let mut uniforms = queue
212                .write_buffer_with(
213                    &self.uniforms,
214                    0,
215                    NonZeroU64::new(size_of::<Uniforms>() as u64).unwrap(),
216                )
217                .unwrap();
218            uniforms.copy_from_slice(bytemuck::bytes_of(&Uniforms {
219                screen_size: [surface_config.width as f32, surface_config.height as f32],
220                use_srgb: u32::from(surface_config.format.is_srgb()),
221                _pad0: [0; 5],
222            }));
223        }
224
225        let mut pass = encoder.begin_render_pass(&RenderPassDescriptor {
226            label: Some("Text Blit Pass"),
227            color_attachments: &[Some(RenderPassColorAttachment {
228                view: surface_view,
229                resolve_target: None,
230                ops: Operations {
231                    load: LoadOp::Clear(Color::TRANSPARENT),
232                    store: StoreOp::Store,
233                },
234            })],
235            ..Default::default()
236        });
237
238        pass.execute_bundles(Some(&self.blitter));
239    }
240}
241
242fn build_blitter(
243    device: &wgpu::Device,
244    layout: &BindGroupLayout,
245    text_view: &wgpu::TextureView,
246    sampler: &Sampler,
247    uniforms: &Buffer,
248    surface_config: &wgpu::SurfaceConfiguration,
249    pipeline: &RenderPipeline,
250) -> RenderBundle {
251    let bindings = device.create_bind_group(&BindGroupDescriptor {
252        label: Some("Text Blit Bindings"),
253        layout,
254        entries: &[
255            BindGroupEntry {
256                binding: 0,
257                resource: BindingResource::TextureView(text_view),
258            },
259            BindGroupEntry {
260                binding: 1,
261                resource: BindingResource::Sampler(sampler),
262            },
263            BindGroupEntry {
264                binding: 2,
265                resource: uniforms.as_entire_binding(),
266            },
267        ],
268    });
269
270    let mut encoder = device.create_render_bundle_encoder(&RenderBundleEncoderDescriptor {
271        label: Some("Text Blit Pass Encoder"),
272        color_formats: &[Some(surface_config.format)],
273        depth_stencil: None,
274        sample_count: 1,
275        multiview: None,
276    });
277
278    encoder.set_pipeline(pipeline);
279
280    encoder.set_bind_group(0, &bindings, &[]);
281    encoder.draw(0..3, 0..1);
282
283    encoder.finish(&RenderBundleDescriptor {
284        label: Some("Text Blit Pass Bundle"),
285    })
286}