1use std::borrow::Cow;
13
14use rend3::{
15 graph::{DataHandle, RenderGraph, RenderPassTarget, RenderPassTargets, RenderTargetHandle},
16 util::bind_merge::{BindGroupBuilder, BindGroupLayoutBuilder},
17 Renderer,
18};
19use wgpu::{
20 BindGroup, BindGroupLayout, BindingType, Color, ColorTargetState, ColorWrites, Device, FragmentState, FrontFace,
21 MultisampleState, PipelineLayoutDescriptor, PolygonMode, PrimitiveState, PrimitiveTopology, RenderPipeline,
22 RenderPipelineDescriptor, ShaderModuleDescriptor, ShaderSource, ShaderStages, TextureFormat, TextureSampleType,
23 TextureViewDimension, VertexState,
24};
25
26use crate::{common::WholeFrameInterfaces, shaders::WGSL_SHADERS};
27
28fn create_pipeline(
29 device: &Device,
30 interfaces: &WholeFrameInterfaces,
31 bgl: &BindGroupLayout,
32 output_format: TextureFormat,
33) -> RenderPipeline {
34 profiling::scope!("TonemappingPass::new");
35 let blit_vert = device.create_shader_module(&ShaderModuleDescriptor {
36 label: Some("tonemapping vert"),
37 source: ShaderSource::Wgsl(Cow::Borrowed(
38 WGSL_SHADERS
39 .get_file("blit.vert.wgsl")
40 .unwrap()
41 .contents_utf8()
42 .unwrap(),
43 )),
44 });
45
46 let blit_frag = device.create_shader_module(&ShaderModuleDescriptor {
47 label: Some("tonemapping frag"),
48 source: ShaderSource::Wgsl(Cow::Borrowed(
49 WGSL_SHADERS
50 .get_file(match output_format.describe().srgb {
51 true => "blit-linear.frag.wgsl",
52 false => "blit-srgb.frag.wgsl",
53 })
54 .unwrap()
55 .contents_utf8()
56 .unwrap(),
57 )),
58 });
59
60 let pll = device.create_pipeline_layout(&PipelineLayoutDescriptor {
61 label: Some("tonemapping pass"),
62 bind_group_layouts: &[&interfaces.forward_uniform_bgl, bgl],
63 push_constant_ranges: &[],
64 });
65
66 device.create_render_pipeline(&RenderPipelineDescriptor {
67 label: Some("tonemapping pass"),
68 layout: Some(&pll),
69 vertex: VertexState {
70 module: &blit_vert,
71 entry_point: "main",
72 buffers: &[],
73 },
74 primitive: PrimitiveState {
75 topology: PrimitiveTopology::TriangleList,
76 strip_index_format: None,
77 front_face: FrontFace::Cw,
78 cull_mode: None,
79 unclipped_depth: false,
80 polygon_mode: PolygonMode::Fill,
81 conservative: false,
82 },
83 depth_stencil: None,
84 multisample: MultisampleState::default(),
85 fragment: Some(FragmentState {
86 module: &blit_frag,
87 entry_point: "main",
88 targets: &[ColorTargetState {
89 format: output_format,
90 blend: None,
91 write_mask: ColorWrites::all(),
92 }],
93 }),
94 multiview: None,
95 })
96}
97
98pub struct TonemappingRoutine {
102 bgl: BindGroupLayout,
103 pipeline: RenderPipeline,
104}
105
106impl TonemappingRoutine {
107 pub fn new(renderer: &Renderer, interfaces: &WholeFrameInterfaces, output_format: TextureFormat) -> Self {
108 let bgl = BindGroupLayoutBuilder::new()
109 .append(
110 ShaderStages::FRAGMENT,
111 BindingType::Texture {
112 sample_type: TextureSampleType::Float { filterable: true },
113 view_dimension: TextureViewDimension::D2,
114 multisampled: false,
115 },
116 None,
117 )
118 .build(&renderer.device, Some("bind bgl"));
119
120 let pipeline = create_pipeline(&renderer.device, interfaces, &bgl, output_format);
121
122 Self { bgl, pipeline }
123 }
124
125 pub fn add_to_graph<'node>(
126 &'node self,
127 graph: &mut RenderGraph<'node>,
128 src: RenderTargetHandle,
129 dst: RenderTargetHandle,
130 forward_uniform_bg: DataHandle<BindGroup>,
131 ) {
132 let mut builder = graph.add_node("Tonemapping");
133
134 let input_handle = builder.add_render_target_input(src);
135 let output_handle = builder.add_render_target_output(dst);
136
137 let rpass_handle = builder.add_renderpass(RenderPassTargets {
138 targets: vec![RenderPassTarget {
139 color: output_handle,
140 clear: Color::BLACK,
141 resolve: None,
142 }],
143 depth_stencil: None,
144 });
145
146 let forward_uniform_handle = builder.add_data_input(forward_uniform_bg);
147
148 let pt_handle = builder.passthrough_ref(self);
149
150 builder.build(move |pt, renderer, encoder_or_pass, temps, _ready, graph_data| {
151 let this = pt.get(pt_handle);
152 let rpass = encoder_or_pass.get_rpass(rpass_handle);
153 let forward_uniform_bg = graph_data.get_data(temps, forward_uniform_handle).unwrap();
154 let hdr_color = graph_data.get_render_target(input_handle);
155
156 profiling::scope!("tonemapping");
157
158 let blit_src_bg = temps.add(BindGroupBuilder::new().append_texture_view(hdr_color).build(
159 &renderer.device,
160 Some("blit src bg"),
161 &this.bgl,
162 ));
163
164 rpass.set_pipeline(&this.pipeline);
165 rpass.set_bind_group(0, forward_uniform_bg, &[]);
166 rpass.set_bind_group(1, blit_src_bg, &[]);
167 rpass.draw(0..3, 0..1);
168 });
169 }
170}