Skip to main content

oxigdal_gpu/kernels/
resampling.rs

1//! GPU kernels for raster resampling operations.
2//!
3//! This module provides GPU-accelerated resampling operations including
4//! nearest neighbor, bilinear, and bicubic interpolation.
5
6use crate::buffer::GpuBuffer;
7use crate::context::GpuContext;
8use crate::error::{GpuError, GpuResult};
9use crate::shaders::{
10    ComputePipelineBuilder, WgslShader, create_compute_bind_group_layout, storage_buffer_layout,
11    uniform_buffer_layout,
12};
13use bytemuck::{Pod, Zeroable};
14use tracing::debug;
15use wgpu::{
16    BindGroupDescriptor, BindGroupEntry, BufferUsages, CommandEncoderDescriptor,
17    ComputePassDescriptor, ComputePipeline,
18};
19
/// Interpolation algorithm used when resampling a raster.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResamplingMethod {
    /// Nearest-neighbor sampling (fast, blocky).
    NearestNeighbor,
    /// Bilinear interpolation (smooth, fast).
    Bilinear,
    /// Bicubic interpolation (highest quality, slower).
    Bicubic,
}

impl ResamplingMethod {
    /// Name of the shader entry point that implements this method.
    fn entry_point(&self) -> &'static str {
        use ResamplingMethod::{Bicubic, Bilinear, NearestNeighbor};

        match *self {
            NearestNeighbor => "nearest_neighbor",
            Bilinear => "bilinear",
            Bicubic => "bicubic",
        }
    }
}
41
42/// Resampling parameters.
43#[derive(Debug, Clone, Copy, Pod, Zeroable)]
44#[repr(C)]
45pub struct ResamplingParams {
46    /// Source width.
47    pub src_width: u32,
48    /// Source height.
49    pub src_height: u32,
50    /// Destination width.
51    pub dst_width: u32,
52    /// Destination height.
53    pub dst_height: u32,
54}
55
56impl ResamplingParams {
57    /// Create new resampling parameters.
58    pub fn new(src_width: u32, src_height: u32, dst_width: u32, dst_height: u32) -> Self {
59        Self {
60            src_width,
61            src_height,
62            dst_width,
63            dst_height,
64        }
65    }
66
67    /// Calculate scale factors.
68    pub fn scale_factors(&self) -> (f32, f32) {
69        let scale_x = self.src_width as f32 / self.dst_width as f32;
70        let scale_y = self.src_height as f32 / self.dst_height as f32;
71        (scale_x, scale_y)
72    }
73}
74
75/// GPU kernel for resampling operations.
76pub struct ResamplingKernel {
77    context: GpuContext,
78    pipeline: ComputePipeline,
79    bind_group_layout: wgpu::BindGroupLayout,
80    workgroup_size: (u32, u32),
81    method: ResamplingMethod,
82}
83
84impl ResamplingKernel {
85    /// Create a new resampling kernel.
86    ///
87    /// # Errors
88    ///
89    /// Returns an error if shader compilation or pipeline creation fails.
90    pub fn new(context: &GpuContext, method: ResamplingMethod) -> GpuResult<Self> {
91        debug!("Creating resampling kernel: {:?}", method);
92
93        let shader_source = Self::resampling_shader(method);
94        let mut shader = WgslShader::new(shader_source, method.entry_point());
95        let shader_module = shader.compile(context.device())?;
96
97        let bind_group_layout = create_compute_bind_group_layout(
98            context.device(),
99            &[
100                storage_buffer_layout(0, true),  // input
101                uniform_buffer_layout(1),        // params
102                storage_buffer_layout(2, false), // output
103            ],
104            Some("ResamplingKernel BindGroupLayout"),
105        )?;
106
107        let pipeline =
108            ComputePipelineBuilder::new(context.device(), shader_module, method.entry_point())
109                .bind_group_layout(&bind_group_layout)
110                .label(format!("ResamplingKernel Pipeline: {:?}", method))
111                .build()?;
112
113        Ok(Self {
114            context: context.clone(),
115            pipeline,
116            bind_group_layout,
117            workgroup_size: (16, 16),
118            method,
119        })
120    }
121
122    /// Get shader source for resampling method.
123    fn resampling_shader(method: ResamplingMethod) -> String {
124        let common = r#"
125struct ResamplingParams {
126    src_width: u32,
127    src_height: u32,
128    dst_width: u32,
129    dst_height: u32,
130}
131
132@group(0) @binding(0) var<storage, read> input: array<f32>;
133@group(0) @binding(1) var<uniform> params: ResamplingParams;
134@group(0) @binding(2) var<storage, read_write> output: array<f32>;
135
136fn get_pixel(x: u32, y: u32) -> f32 {
137    if (x >= params.src_width || y >= params.src_height) {
138        return 0.0;
139    }
140    return input[y * params.src_width + x];
141}
142
143fn lerp(a: f32, b: f32, t: f32) -> f32 {
144    return a + (b - a) * t;
145}
146"#;
147
148        match method {
149            ResamplingMethod::NearestNeighbor => format!(
150                r#"
151{}
152
153@compute @workgroup_size(16, 16)
154fn nearest_neighbor(@builtin(global_invocation_id) global_id: vec3<u32>) {{
155    let dst_x = global_id.x;
156    let dst_y = global_id.y;
157
158    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
159        return;
160    }}
161
162    let scale_x = f32(params.src_width) / f32(params.dst_width);
163    let scale_y = f32(params.src_height) / f32(params.dst_height);
164
165    let src_x = u32(f32(dst_x) * scale_x);
166    let src_y = u32(f32(dst_y) * scale_y);
167
168    let value = get_pixel(src_x, src_y);
169    output[dst_y * params.dst_width + dst_x] = value;
170}}
171"#,
172                common
173            ),
174
175            ResamplingMethod::Bilinear => format!(
176                r#"
177{}
178
179@compute @workgroup_size(16, 16)
180fn bilinear(@builtin(global_invocation_id) global_id: vec3<u32>) {{
181    let dst_x = global_id.x;
182    let dst_y = global_id.y;
183
184    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
185        return;
186    }}
187
188    let scale_x = f32(params.src_width) / f32(params.dst_width);
189    let scale_y = f32(params.src_height) / f32(params.dst_height);
190
191    let src_x = f32(dst_x) * scale_x;
192    let src_y = f32(dst_y) * scale_y;
193
194    let x0 = u32(floor(src_x));
195    let y0 = u32(floor(src_y));
196    let x1 = min(x0 + 1u, params.src_width - 1u);
197    let y1 = min(y0 + 1u, params.src_height - 1u);
198
199    let tx = fract(src_x);
200    let ty = fract(src_y);
201
202    let v00 = get_pixel(x0, y0);
203    let v10 = get_pixel(x1, y0);
204    let v01 = get_pixel(x0, y1);
205    let v11 = get_pixel(x1, y1);
206
207    let v0 = lerp(v00, v10, tx);
208    let v1 = lerp(v01, v11, tx);
209    let value = lerp(v0, v1, ty);
210
211    output[dst_y * params.dst_width + dst_x] = value;
212}}
213"#,
214                common
215            ),
216
217            ResamplingMethod::Bicubic => format!(
218                r#"
219{}
220
221fn cubic_interpolate(p0: f32, p1: f32, p2: f32, p3: f32, t: f32) -> f32 {{
222    let a = -0.5 * p0 + 1.5 * p1 - 1.5 * p2 + 0.5 * p3;
223    let b = p0 - 2.5 * p1 + 2.0 * p2 - 0.5 * p3;
224    let c = -0.5 * p0 + 0.5 * p2;
225    let d = p1;
226    return a * t * t * t + b * t * t + c * t + d;
227}}
228
229@compute @workgroup_size(16, 16)
230fn bicubic(@builtin(global_invocation_id) global_id: vec3<u32>) {{
231    let dst_x = global_id.x;
232    let dst_y = global_id.y;
233
234    if (dst_x >= params.dst_width || dst_y >= params.dst_height) {{
235        return;
236    }}
237
238    let scale_x = f32(params.src_width) / f32(params.dst_width);
239    let scale_y = f32(params.src_height) / f32(params.dst_height);
240
241    let src_x = f32(dst_x) * scale_x;
242    let src_y = f32(dst_y) * scale_y;
243
244    let x_floor = floor(src_x);
245    let y_floor = floor(src_y);
246    let tx = fract(src_x);
247    let ty = fract(src_y);
248
249    // Sample 4x4 neighborhood
250    var cols: array<f32, 4>;
251    for (var j = 0; j < 4; j++) {{
252        let y = i32(y_floor) + j - 1;
253        var row: array<f32, 4>;
254        for (var i = 0; i < 4; i++) {{
255            let x = i32(x_floor) + i - 1;
256            if (x >= 0 && x < i32(params.src_width) && y >= 0 && y < i32(params.src_height)) {{
257                row[i] = get_pixel(u32(x), u32(y));
258            }} else {{
259                row[i] = 0.0;
260            }}
261        }}
262        cols[j] = cubic_interpolate(row[0], row[1], row[2], row[3], tx);
263    }}
264
265    let value = cubic_interpolate(cols[0], cols[1], cols[2], cols[3], ty);
266    output[dst_y * params.dst_width + dst_x] = value;
267}}
268"#,
269                common
270            ),
271        }
272    }
273
274    /// Execute resampling on GPU buffer.
275    ///
276    /// # Errors
277    ///
278    /// Returns an error if buffer sizes are invalid or execution fails.
279    pub fn execute<T: Pod>(
280        &self,
281        input: &GpuBuffer<T>,
282        params: ResamplingParams,
283    ) -> GpuResult<GpuBuffer<T>> {
284        // Validate input size
285        let expected_input_size = (params.src_width as usize) * (params.src_height as usize);
286        if input.len() != expected_input_size {
287            return Err(GpuError::invalid_kernel_params(format!(
288                "Input buffer size mismatch: expected {}, got {}",
289                expected_input_size,
290                input.len()
291            )));
292        }
293
294        // Create output buffer
295        let output_size = (params.dst_width as usize) * (params.dst_height as usize);
296        let output = GpuBuffer::new(
297            &self.context,
298            output_size,
299            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
300        )?;
301
302        // Create params buffer
303        let params_buffer = GpuBuffer::from_data(
304            &self.context,
305            &[params],
306            BufferUsages::UNIFORM | BufferUsages::COPY_DST,
307        )?;
308
309        // Create bind group
310        let bind_group = self
311            .context
312            .device()
313            .create_bind_group(&BindGroupDescriptor {
314                label: Some("ResamplingKernel BindGroup"),
315                layout: &self.bind_group_layout,
316                entries: &[
317                    BindGroupEntry {
318                        binding: 0,
319                        resource: input.buffer().as_entire_binding(),
320                    },
321                    BindGroupEntry {
322                        binding: 1,
323                        resource: params_buffer.buffer().as_entire_binding(),
324                    },
325                    BindGroupEntry {
326                        binding: 2,
327                        resource: output.buffer().as_entire_binding(),
328                    },
329                ],
330            });
331
332        // Execute kernel
333        let mut encoder = self
334            .context
335            .device()
336            .create_command_encoder(&CommandEncoderDescriptor {
337                label: Some("ResamplingKernel Encoder"),
338            });
339
340        {
341            let mut compute_pass = encoder.begin_compute_pass(&ComputePassDescriptor {
342                label: Some("ResamplingKernel Pass"),
343                timestamp_writes: None,
344            });
345
346            compute_pass.set_pipeline(&self.pipeline);
347            compute_pass.set_bind_group(0, &bind_group, &[]);
348
349            let workgroups_x =
350                (params.dst_width + self.workgroup_size.0 - 1) / self.workgroup_size.0;
351            let workgroups_y =
352                (params.dst_height + self.workgroup_size.1 - 1) / self.workgroup_size.1;
353
354            compute_pass.dispatch_workgroups(workgroups_x, workgroups_y, 1);
355        }
356
357        self.context.queue().submit(Some(encoder.finish()));
358
359        debug!(
360            "Resampled {}x{} to {}x{} using {:?}",
361            params.src_width, params.src_height, params.dst_width, params.dst_height, self.method
362        );
363
364        Ok(output)
365    }
366}
367
368/// Resize raster using GPU acceleration.
369///
370/// # Errors
371///
372/// Returns an error if GPU operations fail.
373pub fn resize<T: Pod>(
374    context: &GpuContext,
375    input: &GpuBuffer<T>,
376    src_width: u32,
377    src_height: u32,
378    dst_width: u32,
379    dst_height: u32,
380    method: ResamplingMethod,
381) -> GpuResult<GpuBuffer<T>> {
382    let kernel = ResamplingKernel::new(context, method)?;
383    let params = ResamplingParams::new(src_width, src_height, dst_width, dst_height);
384    kernel.execute(input, params)
385}
386
387/// Downscale raster by factor of 2 (fast).
388///
389/// # Errors
390///
391/// Returns an error if GPU operations fail.
392pub fn downscale_2x<T: Pod>(
393    context: &GpuContext,
394    input: &GpuBuffer<T>,
395    width: u32,
396    height: u32,
397) -> GpuResult<GpuBuffer<T>> {
398    resize(
399        context,
400        input,
401        width,
402        height,
403        width / 2,
404        height / 2,
405        ResamplingMethod::Bilinear,
406    )
407}
408
409/// Upscale raster by factor of 2.
410///
411/// # Errors
412///
413/// Returns an error if GPU operations fail.
414pub fn upscale_2x<T: Pod>(
415    context: &GpuContext,
416    input: &GpuBuffer<T>,
417    width: u32,
418    height: u32,
419    method: ResamplingMethod,
420) -> GpuResult<GpuBuffer<T>> {
421    resize(context, input, width, height, width * 2, height * 2, method)
422}
423
#[cfg(test)]
mod tests {
    use super::*;

    /// Every method, so shader and kernel tests cover all entry points.
    const ALL_METHODS: [ResamplingMethod; 3] = [
        ResamplingMethod::NearestNeighbor,
        ResamplingMethod::Bilinear,
        ResamplingMethod::Bicubic,
    ];

    #[test]
    fn test_resampling_params() {
        let params = ResamplingParams::new(1024, 768, 512, 384);
        let (scale_x, scale_y) = params.scale_factors();
        assert!((scale_x - 2.0).abs() < 1e-5);
        assert!((scale_y - 2.0).abs() < 1e-5);
    }

    #[test]
    fn test_resampling_shader() {
        for method in ALL_METHODS {
            let shader = ResamplingKernel::resampling_shader(method);
            assert!(shader.contains("@compute"), "{:?}: no compute stage", method);
            assert!(
                shader.contains(&format!("fn {}(", method.entry_point())),
                "{:?}: entry point missing from shader",
                method
            );
        }
    }

    #[tokio::test]
    async fn test_resampling_kernel() {
        // Machines without a usable GPU adapter (e.g. headless CI) skip the
        // test; once a context exists, kernel creation must succeed.
        let Ok(context) = GpuContext::new().await else {
            return;
        };

        for method in ALL_METHODS {
            assert!(
                ResamplingKernel::new(&context, method).is_ok(),
                "failed to create {:?} kernel",
                method
            );
        }
    }

    #[tokio::test]
    async fn test_resize_operation() {
        // Skip only when no GPU is available; everything after must succeed.
        let Ok(context) = GpuContext::new().await else {
            return;
        };

        // Create a simple 4x4 input
        let input_data: Vec<f32> = (0..16).map(|i| i as f32).collect();
        let Ok(input) = GpuBuffer::from_data(
            &context,
            &input_data,
            BufferUsages::STORAGE | BufferUsages::COPY_SRC,
        ) else {
            panic!("failed to upload the 4x4 input raster");
        };

        let Ok(output) = resize(
            &context,
            &input,
            4,
            4,
            2,
            2,
            ResamplingMethod::NearestNeighbor,
        ) else {
            panic!("resizing 4x4 -> 2x2 failed");
        };
        assert_eq!(output.len(), 4);

        // Declared source dimensions that disagree with the buffer length
        // must be rejected rather than dispatched.
        assert!(
            resize(
                &context,
                &input,
                3,
                3,
                2,
                2,
                ResamplingMethod::NearestNeighbor,
            )
            .is_err()
        );
    }
}