1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Resampling filter applied by [`ScaleOperation::scale`].
///
/// [`ScaleFilter::Area`] is routed to a dedicated `downscale_area` shader entry point; the other
/// variants share the `scale_main` entry point and are told apart by a numeric filter id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScaleFilter {
    /// Nearest-neighbour sampling (presumably; sent to the shader as filter id 0).
    Nearest,
    /// Bilinear interpolation (presumably; sent to the shader as filter id 1).
    Bilinear,
    /// Bicubic interpolation (presumably; sent to the shader as filter id 2).
    Bicubic,
    /// Area averaging, presumably meant for downscaling; runs the `downscale_area` entry point.
    Area,
}

26impl ScaleFilter {
27 fn to_filter_id(self) -> u32 {
28 match self {
29 Self::Nearest => 0,
30 Self::Bilinear => 1,
31 Self::Bicubic => 2,
32 Self::Area => 3,
33 }
34 }
35}
36
37#[repr(C)]
38#[derive(Copy, Clone, Pod, Zeroable)]
39struct ScaleParams {
40 src_width: u32,
41 src_height: u32,
42 dst_width: u32,
43 dst_height: u32,
44 src_stride: u32,
45 dst_stride: u32,
46 filter_type: u32,
47 padding: u32,
48}
49
/// GPU image-resize operation; all functionality is exposed through associated functions such
/// as [`ScaleOperation::scale`], so the type itself carries no state.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScaleOperation;

53impl ScaleOperation {
54 #[allow(clippy::too_many_arguments)]
71 pub fn scale(
72 device: &GpuDevice,
73 input: &[u8],
74 src_width: u32,
75 src_height: u32,
76 output: &mut [u8],
77 dst_width: u32,
78 dst_height: u32,
79 filter: ScaleFilter,
80 ) -> Result<()> {
81 utils::validate_dimensions(src_width, src_height)?;
82 utils::validate_dimensions(dst_width, dst_height)?;
83 utils::validate_buffer_size(input, src_width, src_height, 4)?;
84 utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
85
86 let pipeline = if filter == ScaleFilter::Area {
87 Self::get_downscale_pipeline(device)?
88 } else {
89 Self::get_scale_pipeline(device)?
90 };
91
92 let layout = Self::get_bind_group_layout(device)?;
93
94 Self::execute_scale(
95 device, pipeline, layout, input, src_width, src_height, output, dst_width, dst_height,
96 filter,
97 )
98 }
99
100 #[allow(clippy::too_many_arguments)]
101 fn execute_scale(
102 device: &GpuDevice,
103 pipeline: &ComputePipeline,
104 layout: &BindGroupLayout,
105 input: &[u8],
106 src_width: u32,
107 src_height: u32,
108 output: &mut [u8],
109 dst_width: u32,
110 dst_height: u32,
111 filter: ScaleFilter,
112 ) -> Result<()> {
113 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
115 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
116
117 device.queue().write_buffer(input_buffer.buffer(), 0, input);
119
120 let params = ScaleParams {
122 src_width,
123 src_height,
124 dst_width,
125 dst_height,
126 src_stride: src_width,
127 dst_stride: dst_width,
128 filter_type: filter.to_filter_id(),
129 padding: 0,
130 };
131 let params_bytes = bytemuck::bytes_of(¶ms);
132 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
133
134 let compiler = ShaderCompiler::new(device);
136 let bind_group = compiler.create_bind_group(
137 "Scale Bind Group",
138 layout,
139 &[
140 wgpu::BindGroupEntry {
141 binding: 0,
142 resource: input_buffer.buffer().as_entire_binding(),
143 },
144 wgpu::BindGroupEntry {
145 binding: 1,
146 resource: output_buffer.buffer().as_entire_binding(),
147 },
148 wgpu::BindGroupEntry {
149 binding: 2,
150 resource: params_buffer.buffer().as_entire_binding(),
151 },
152 ],
153 );
154
155 Self::dispatch_compute(device, pipeline, &bind_group, dst_width, dst_height)?;
157
158 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
160 let mut encoder = device
161 .device()
162 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
163 label: Some("Scale Copy Encoder"),
164 });
165
166 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
167
168 device.queue().submit(Some(encoder.finish()));
169 device.wait();
170
171 let result = readback_buffer.read(device, 0, output.len() as u64)?;
172 output.copy_from_slice(&result);
173
174 Ok(())
175 }
176
177 fn dispatch_compute(
178 device: &GpuDevice,
179 pipeline: &ComputePipeline,
180 bind_group: &BindGroup,
181 width: u32,
182 height: u32,
183 ) -> Result<()> {
184 let mut encoder = device
185 .device()
186 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
187 label: Some("Scale Compute Encoder"),
188 });
189
190 {
191 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
192 label: Some("Scale Compute Pass"),
193 timestamp_writes: None,
194 });
195
196 compute_pass.set_pipeline(pipeline);
197 compute_pass.set_bind_group(0, bind_group, &[]);
198
199 let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
200 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
201 }
202
203 device.queue().submit(Some(encoder.finish()));
204 Ok(())
205 }
206
207 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
208 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
209
210 Ok(LAYOUT.get_or_init(|| {
211 let compiler = ShaderCompiler::new(device);
212 let entries = BindGroupLayoutBuilder::new()
213 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
217
218 compiler.create_bind_group_layout("Scale Bind Group Layout", &entries)
219 }))
220 }
221
222 fn get_scale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
223 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
224
225 Ok(PIPELINE.get_or_init(|| {
226 let compiler = ShaderCompiler::new(device);
227 let shader = compiler
228 .compile(
229 "Scale Shader",
230 ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
231 )
232 .expect("Failed to compile scale shader");
233
234 let layout =
235 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
236
237 compiler
238 .create_pipeline("Scale Pipeline", &shader, "scale_main", layout)
239 .expect("Failed to create pipeline")
240 }))
241 }
242
243 fn get_downscale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
244 static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
245
246 Ok(PIPELINE.get_or_init(|| {
247 let compiler = ShaderCompiler::new(device);
248 let shader = compiler
249 .compile(
250 "Scale Shader",
251 ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
252 )
253 .expect("Failed to compile scale shader");
254
255 let layout =
256 Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
257
258 compiler
259 .create_pipeline("Downscale Pipeline", &shader, "downscale_area", layout)
260 .expect("Failed to create pipeline")
261 }))
262 }
263}