1use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10 ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11 uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16 BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17 ComputePassDescriptor, ComputePipeline,
18};
19
/// Interpolation strategy used when resampling an image on the GPU.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResamplingMethod {
    /// Maps each destination pixel to a single source pixel; fastest, blocky.
    NearestNeighbor,
    /// Linear blend of the 2x2 source neighborhood; good speed/quality balance.
    Bilinear,
    /// Cubic interpolation over a 4x4 source neighborhood; smoothest, slowest.
    Bicubic,
}

impl ResamplingMethod {
    /// Name of the WGSL `@compute` entry point generated for this method.
    fn entry_point(&self) -> &'static str {
        match *self {
            ResamplingMethod::NearestNeighbor => "nearest_neighbor",
            ResamplingMethod::Bilinear => "bilinear",
            ResamplingMethod::Bicubic => "bicubic",
        }
    }
}
41
42#[derive(Debug, Clone, Copy, Pod, Zeroable)]
44#[repr(C)]
45pub struct ResamplingParams {
46 pub src_width: u32,
48 pub src_height: u32,
50 pub dst_width: u32,
52 pub dst_height: u32,
54}
55
56impl ResamplingParams {
57 pub fn new(src_width: u32, src_height: u32, dst_width: u32, dst_height: u32) -> Self {
59 Self {
60 src_width,
61 src_height,
62 dst_width,
63 dst_height,
64 }
65 }
66
67 pub fn scale_factors(&self) -> (f32, f32) {
69 let scale_x = self.src_width as f32 / self.dst_width as f32;
70 let scale_y = self.src_height as f32 / self.dst_height as f32;
71 (scale_x, scale_y)
72 }
73}
74
/// GPU compute kernel that resamples a row-major image buffer using one of
/// the [`ResamplingMethod`] strategies (shaders operate on `array<f32>`).
pub struct ResamplingKernel {
    // Cloned handle to the device/queue used for all GPU work.
    context: GpuContext,
    // Compiled compute pipeline for the chosen method's entry point.
    pipeline: ComputePipeline,
    // Bindings: 0 = input (read-only storage), 1 = params (uniform),
    // 2 = output (read-write storage) — mirrors the WGSL declarations.
    bind_group_layout: wgpu::BindGroupLayout,
    // Threads per workgroup; must stay in sync with the hard-coded
    // `@workgroup_size(16, 16)` in the generated shaders.
    workgroup_size: (u32, u32),
    // Interpolation method this kernel was built for (used for logging).
    method: ResamplingMethod,
}
83
impl ResamplingKernel {
    /// Compiles the WGSL shader for `method` and builds the compute pipeline.
    ///
    /// The bind-group layout mirrors the shader declarations produced by
    /// [`Self::resampling_shader`]: binding 0 is the read-only input buffer,
    /// binding 1 the uniform `ResamplingParams`, binding 2 the read-write
    /// output buffer.
    ///
    /// # Errors
    /// Propagates failures from shader compilation, bind-group-layout
    /// creation, or pipeline construction.
    pub fn new(context: &GpuContext, method: ResamplingMethod) -> GpuResult<Self> {
        debug!("Creating resampling kernel: {:?}", method);

        let shader_source = Self::resampling_shader(method);
        let mut shader = WgslShader::new(shader_source, method.entry_point());
        let shader_module = shader.compile(context.device())?;

        // Entry order must match the @binding indices in the WGSL source:
        // 0 = read-only input, 1 = uniform params, 2 = writable output.
        let bind_group_layout = create_compute_bind_group_layout(
            context.device(),
            &[
                storage_buffer_layout(0, true), uniform_buffer_layout(1), storage_buffer_layout(2, false), ],
            Some("ResamplingKernel BindGroupLayout"),
        )?;

        let pipeline =
            ComputePipelineBuilder::new(context.device(), shader_module, method.entry_point())
                .bind_group_layout(&bind_group_layout)
                .label(format!("ResamplingKernel Pipeline: {:?}", method))
                .build()?;

        Ok(Self {
            context: context.clone(),
            pipeline,
            bind_group_layout,
            // Must stay in sync with @workgroup_size(16, 16) in the shaders.
            workgroup_size: (16, 16),
            method,
        })
    }

    /// Generates the WGSL source for `method`.
    ///
    /// All variants share a common prelude (the params struct, the three
    /// bindings, and the bounds-checked `get_pixel` / `lerp` helpers); each
    /// variant then appends its own `@compute` entry point whose name matches
    /// [`ResamplingMethod::entry_point`].
    fn resampling_shader(method: ResamplingMethod) -> String {
        let common = r#"
struct ResamplingParams {
    src_width: u32,
    src_height: u32,
    dst_width: u32,
    dst_height: u32,
}

@group(0) @binding(0) var<storage, read> input: array<f32>;
@group(0) @binding(1) var<uniform> params: ResamplingParams;
@group(0) @binding(2) var<storage, read_write> output: array<f32>;

fn get_pixel(x: u32, y: u32) -> f32 {
    if (x >= params.src_width || y >= params.src_height) {
        return 0.0;
    }
    return input[y * params.src_width + x];
}

fn lerp(a: f32, b: f32, t: f32) -> f32 {
    return a + (b - a) * t;
}
"#;

        match method {
            ResamplingMethod::NearestNeighbor => format!(
                r#"
{}

@compute @workgroup_size(16, 16)
fn nearest_neighbor(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let dst_x = global_id.x;
    let dst_y = global_id.y;

    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
        return;
    }}

    let scale_x = f32(params.src_width) / f32(params.dst_width);
    let scale_y = f32(params.src_height) / f32(params.dst_height);

    let src_x = u32(f32(dst_x) * scale_x);
    let src_y = u32(f32(dst_y) * scale_y);

    let value = get_pixel(src_x, src_y);
    output[dst_y * params.dst_width + dst_x] = value;
}}
"#,
                common
            ),

            ResamplingMethod::Bilinear => format!(
                r#"
{}

@compute @workgroup_size(16, 16)
fn bilinear(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let dst_x = global_id.x;
    let dst_y = global_id.y;

    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
        return;
    }}

    let scale_x = f32(params.src_width) / f32(params.dst_width);
    let scale_y = f32(params.src_height) / f32(params.dst_height);

    let src_x = f32(dst_x) * scale_x;
    let src_y = f32(dst_y) * scale_y;

    let x0 = u32(floor(src_x));
    let y0 = u32(floor(src_y));
    let x1 = min(x0 + 1u, params.src_width - 1u);
    let y1 = min(y0 + 1u, params.src_height - 1u);

    let tx = fract(src_x);
    let ty = fract(src_y);

    let v00 = get_pixel(x0, y0);
    let v10 = get_pixel(x1, y0);
    let v01 = get_pixel(x0, y1);
    let v11 = get_pixel(x1, y1);

    let v0 = lerp(v00, v10, tx);
    let v1 = lerp(v01, v11, tx);
    let value = lerp(v0, v1, ty);

    output[dst_y * params.dst_width + dst_x] = value;
}}
"#,
                common
            ),

            ResamplingMethod::Bicubic => format!(
                r#"
{}

fn cubic_interpolate(p0: f32, p1: f32, p2: f32, p3: f32, t: f32) -> f32 {{
    let a = -0.5 * p0 + 1.5 * p1 - 1.5 * p2 + 0.5 * p3;
    let b = p0 - 2.5 * p1 + 2.0 * p2 - 0.5 * p3;
    let c = -0.5 * p0 + 0.5 * p2;
    let d = p1;
    return a * t * t * t + b * t * t + c * t + d;
}}

@compute @workgroup_size(16, 16)
fn bicubic(@builtin(global_invocation_id) global_id: vec3<u32>) {{
    let dst_x = global_id.x;
    let dst_y = global_id.y;

    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
        return;
    }}

    let scale_x = f32(params.src_width) / f32(params.dst_width);
    let scale_y = f32(params.src_height) / f32(params.dst_height);

    let src_x = f32(dst_x) * scale_x;
    let src_y = f32(dst_y) * scale_y;

    let x_floor = floor(src_x);
    let y_floor = floor(src_y);
    let tx = fract(src_x);
    let ty = fract(src_y);

    // Sample 4x4 neighborhood
    var cols: array<f32, 4>;
    for (var j = 0; j < 4; j++) {{
        let y = i32(y_floor) + j - 1;
        var row: array<f32, 4>;
        for (var i = 0; i < 4; i++) {{
            let x = i32(x_floor) + i - 1;
            if (x >= 0 && x < i32(params.src_width) && y >= 0 && y < i32(params.src_height)) {{
                row[i] = get_pixel(u32(x), u32(y));
            }} else {{
                row[i] = 0.0;
            }}
        }}
        cols[j] = cubic_interpolate(row[0], row[1], row[2], row[3], tx);
    }}

    let value = cubic_interpolate(cols[0], cols[1], cols[2], cols[3], ty);
    output[dst_y * params.dst_width + dst_x] = value;
}}
"#,
                common
            ),
        }
    }

    /// Dispatches the kernel, resampling `input` (row-major,
    /// `src_width x src_height`) into a newly allocated output buffer of
    /// `dst_width x dst_height` elements.
    ///
    /// The work is submitted to the queue without waiting for completion;
    /// the returned buffer has `COPY_SRC` usage so the caller can read it back.
    ///
    /// NOTE(review): `T` is bound only by `Pod`, but the generated shaders
    /// declare the buffers as `array<f32>` — presumably callers use `T = f32`;
    /// confirm before instantiating with other element types.
    ///
    /// # Errors
    /// Returns `GpuError::invalid_kernel_params` if `input.len()` does not
    /// equal `src_width * src_height`; propagates buffer-creation failures.
    pub fn execute<T: Pod>(
        &self,
        input: &GpuBuffer<T>,
        params: ResamplingParams,
    ) -> GpuResult<GpuBuffer<T>> {
        // Validate the input buffer against the declared source dimensions.
        let expected_input_size = (params.src_width as usize) * (params.src_height as usize);
        if input.len() != expected_input_size {
            return Err(GpuError::invalid_kernel_params(format!(
                "Input buffer size mismatch: expected {}, got {}",
                expected_input_size,
                input.len()
            )));
        }

        let output_size = (params.dst_width as usize) * (params.dst_height as usize);
        let output = GpuBuffer::new(
            &self.context,
            output_size,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        )?;

        // Upload the dimensions as a uniform (binding 1 in the shader).
        let params_buffer = GpuBuffer::from_data(
            &self.context,
            &[params],
            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
        )?;

        let bind_group = self
            .context
            .device()
            .create_bind_group(&BindGroupDescriptor {
                label: Some("ResamplingKernel BindGroup"),
                layout: &self.bind_group_layout,
                entries: &[
                    BindGroupEntry {
                        binding: 0,
                        resource: input.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 1,
                        resource: params_buffer.buffer().as_entire_binding(),
                    },
                    BindGroupEntry {
                        binding: 2,
                        resource: output.buffer().as_entire_binding(),
                    },
                ],
            });

        let mut encoder = self
            .context
            .device()
            .create_command_encoder(&CommandEncoderDescriptor {
                label: Some("ResamplingKernel Encoder"),
            });

        // Scoped so the pass is ended before `encoder.finish()`.
        {
            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
                label: Some("ResamplingKernel Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&self.pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            // Ceiling division: enough workgroups to cover every destination
            // pixel; the shader's own bounds check discards the overhang.
            let workgroups_x =
                (params.dst_width + self.workgroup_size.0 - 1) / self.workgroup_size.0;
            let workgroups_y =
                (params.dst_height + self.workgroup_size.1 - 1) / self.workgroup_size.1;

            compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
        }

        self.context.queue().submit(Some(encoder.finish()));

        debug!(
            "Resampled {}x{} to {}x{} using {:?}",
            params.src_width, params.src_height, params.dst_width, params.dst_height, self.method
        );

        Ok(output)
    }
}
367
368pub fn resize<T: Pod>(
374 context: &GpuContext,
375 input: &GpuBuffer<T>,
376 src_width: u32,
377 src_height: u32,
378 dst_width: u32,
379 dst_height: u32,
380 method: ResamplingMethod,
381) -> GpuResult<GpuBuffer<T>> {
382 let kernel = ResamplingKernel::new(context, method)?;
383 let params = ResamplingParams::new(src_width, src_height, dst_width, dst_height);
384 kernel.execute(input, params)
385}
386
387pub fn downscale_2x<T: Pod>(
393 context: &GpuContext,
394 input: &GpuBuffer<T>,
395 width: u32,
396 height: u32,
397) -> GpuResult<GpuBuffer<T>> {
398 resize(
399 context,
400 input,
401 width,
402 height,
403 width / 2,
404 height / 2,
405 ResamplingMethod::Bilinear,
406 )
407}
408
409pub fn upscale_2x<T: Pod>(
415 context: &GpuContext,
416 input: &GpuBuffer<T>,
417 width: u32,
418 height: u32,
419 method: ResamplingMethod,
420) -> GpuResult<GpuBuffer<T>> {
421 resize(context, input, width, height, width * 2, height * 2, method)
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_resampling_params() {
        let params = ResamplingParams::new(1024, 768, 512, 384);
        let (scale_x, scale_y) = params.scale_factors();
        assert!((scale_x - 2.0).abs() < 1e-5);
        assert!((scale_y - 2.0).abs() < 1e-5);

        // Asymmetric case: each axis must get its own scale factor.
        let params = ResamplingParams::new(100, 200, 50, 50);
        let (scale_x, scale_y) = params.scale_factors();
        assert!((scale_x - 2.0).abs() < 1e-5);
        assert!((scale_y - 4.0).abs() < 1e-5);
    }

    #[test]
    fn test_resampling_shader() {
        // Every method must generate a shader containing its own entry point
        // and the shared bind-group declarations from the common prelude.
        for method in [
            ResamplingMethod::NearestNeighbor,
            ResamplingMethod::Bilinear,
            ResamplingMethod::Bicubic,
        ] {
            let shader = ResamplingKernel::resampling_shader(method);
            assert!(shader.contains("@compute"));
            assert!(shader.contains(method.entry_point()));
            assert!(shader.contains("@group(0) @binding(0)"));
            assert!(shader.contains("@group(0) @binding(1)"));
            assert!(shader.contains("@group(0) @binding(2)"));
        }
    }

    #[tokio::test]
    async fn test_resampling_kernel() {
        // A GPU may be unavailable in CI, so skip silently when no context
        // can be created; once we have one, kernel creation must succeed.
        if let Ok(context) = GpuContext::new().await {
            let kernel = ResamplingKernel::new(&context, ResamplingMethod::NearestNeighbor);
            assert!(kernel.is_ok());
        }
    }

    #[tokio::test]
    async fn test_resize_operation() {
        if let Ok(context) = GpuContext::new().await {
            let input_data: Vec<f32> = (0..16).map(|i| i as f32).collect();

            let input = GpuBuffer::from_data(
                &context,
                &input_data,
                BufferUsages::STORAGE | BufferUsages::COPY_SRC,
            )
            .expect("buffer upload should succeed on a working context");

            // 4x4 -> 2x2 nearest-neighbor resize must dispatch cleanly.
            let output = resize(
                &context,
                &input,
                4,
                4,
                2,
                2,
                ResamplingMethod::NearestNeighbor,
            );
            assert!(output.is_ok());
        }
    }
}