Skip to main content

oximedia_gpu/ops/
scale.rs

1//! Image scaling operations with various interpolation methods
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Resampling filter used by [`ScaleOperation::scale`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ScaleFilter {
    /// Nearest neighbor (fastest, lowest quality).
    Nearest,
    /// Bilinear interpolation (balanced speed and quality).
    Bilinear,
    /// Bicubic interpolation (high quality).
    Bicubic,
    /// Area averaging, intended for downscaling.
    Area,
    /// Lanczos-3 windowed-sinc interpolation (sharpest, highest quality).
    ///
    /// Computed on the CPU rather than on the GPU. Like any Lanczos filter it
    /// has negative lobes, so it can produce slight ringing (halos/overshoot)
    /// next to hard edges.
    Lanczos3,
}
27
28impl ScaleFilter {
29    fn to_filter_id(self) -> u32 {
30        match self {
31            Self::Nearest => 0,
32            Self::Bilinear => 1,
33            Self::Bicubic => 2,
34            Self::Area => 3,
35            Self::Lanczos3 => 4,
36        }
37    }
38}
39
40#[repr(C)]
41#[derive(Copy, Clone, Pod, Zeroable)]
42struct ScaleParams {
43    src_width: u32,
44    src_height: u32,
45    dst_width: u32,
46    dst_height: u32,
47    src_stride: u32,
48    dst_stride: u32,
49    filter_type: u32,
50    padding: u32,
51}
52
/// Image scaling operations.
///
/// A stateless namespace type: all functionality is exposed through associated
/// functions such as [`ScaleOperation::scale`] and
/// [`ScaleOperation::lanczos3_cpu`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ScaleOperation;
55
56impl ScaleOperation {
57    /// Scale an image
58    ///
59    /// # Arguments
60    ///
61    /// * `device` - GPU device
62    /// * `input` - Input image buffer (packed RGBA format)
63    /// * `src_width` - Source image width
64    /// * `src_height` - Source image height
65    /// * `output` - Output image buffer (packed RGBA format)
66    /// * `dst_width` - Destination image width
67    /// * `dst_height` - Destination image height
68    /// * `filter` - Scaling filter type
69    ///
70    /// # Errors
71    ///
72    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
73    #[allow(clippy::too_many_arguments)]
74    pub fn scale(
75        device: &GpuDevice,
76        input: &[u8],
77        src_width: u32,
78        src_height: u32,
79        output: &mut [u8],
80        dst_width: u32,
81        dst_height: u32,
82        filter: ScaleFilter,
83    ) -> Result<()> {
84        utils::validate_dimensions(src_width, src_height)?;
85        utils::validate_dimensions(dst_width, dst_height)?;
86        utils::validate_buffer_size(input, src_width, src_height, 4)?;
87        utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
88
89        // Lanczos uses a CPU path (high-quality resampling kernel).
90        if filter == ScaleFilter::Lanczos3 {
91            let _ = device; // suppress unused warning
92            return Self::lanczos3_cpu(input, src_width, src_height, output, dst_width, dst_height);
93        }
94
95        let pipeline = if filter == ScaleFilter::Area {
96            Self::get_downscale_pipeline(device)?
97        } else {
98            Self::get_scale_pipeline(device)?
99        };
100
101        let layout = Self::get_bind_group_layout(device)?;
102
103        Self::execute_scale(
104            device, pipeline, layout, input, src_width, src_height, output, dst_width, dst_height,
105            filter,
106        )
107    }
108
109    /// CPU Lanczos-3 resampling (a = 3, window = 6 taps).
110    ///
111    /// Uses separable 2-pass approach (horizontal then vertical) for efficiency.
112    /// The sinc-windowed-sinc kernel produces high-quality results with minimal
113    /// ringing artefacts.
114    #[allow(clippy::too_many_arguments)]
115    pub fn lanczos3_cpu(
116        input: &[u8],
117        src_width: u32,
118        src_height: u32,
119        output: &mut [u8],
120        dst_width: u32,
121        dst_height: u32,
122    ) -> Result<()> {
123        let sw = src_width as usize;
124        let sh = src_height as usize;
125        let dw = dst_width as usize;
126        let dh = dst_height as usize;
127
128        const LANCZOS_A: f64 = 3.0;
129
130        let lanczos_weight = |x: f64| -> f64 {
131            if x.abs() < 1e-10 {
132                return 1.0;
133            }
134            if x.abs() >= LANCZOS_A {
135                return 0.0;
136            }
137            let pi_x = std::f64::consts::PI * x;
138            let pi_x_a = pi_x / LANCZOS_A;
139            (pi_x.sin() / pi_x) * (pi_x_a.sin() / pi_x_a)
140        };
141
142        // --- Horizontal pass ---
143        let x_scale = sw as f64 / dw as f64;
144        let mut h_temp = vec![0.0_f64; dw * sh * 4]; // intermediate f64 buffer
145
146        for sy in 0..sh {
147            for dx in 0..dw {
148                let center = (dx as f64 + 0.5) * x_scale - 0.5;
149                let start = (center - LANCZOS_A + 1.0).floor().max(0.0) as usize;
150                let end = ((center + LANCZOS_A).ceil() as usize).min(sw);
151
152                let mut weights_sum = 0.0_f64;
153                let mut acc = [0.0_f64; 4];
154
155                for sx in start..end {
156                    let w = lanczos_weight(sx as f64 - center);
157                    weights_sum += w;
158                    let src_base = (sy * sw + sx) * 4;
159                    for c in 0..4 {
160                        acc[c] += w * input[src_base + c] as f64;
161                    }
162                }
163
164                let dst_base = (sy * dw + dx) * 4;
165                if weights_sum.abs() > 1e-10 {
166                    let inv = 1.0 / weights_sum;
167                    for c in 0..4 {
168                        h_temp[dst_base + c] = acc[c] * inv;
169                    }
170                }
171            }
172        }
173
174        // --- Vertical pass ---
175        let y_scale = sh as f64 / dh as f64;
176
177        for dy in 0..dh {
178            let center = (dy as f64 + 0.5) * y_scale - 0.5;
179            let start = (center - LANCZOS_A + 1.0).floor().max(0.0) as usize;
180            let end = ((center + LANCZOS_A).ceil() as usize).min(sh);
181
182            for dx in 0..dw {
183                let mut weights_sum = 0.0_f64;
184                let mut acc = [0.0_f64; 4];
185
186                for sy in start..end {
187                    let w = lanczos_weight(sy as f64 - center);
188                    weights_sum += w;
189                    let src_base = (sy * dw + dx) * 4;
190                    for c in 0..4 {
191                        acc[c] += w * h_temp[src_base + c];
192                    }
193                }
194
195                let dst_base = (dy * dw + dx) * 4;
196                if weights_sum.abs() > 1e-10 {
197                    let inv = 1.0 / weights_sum;
198                    for c in 0..4 {
199                        output[dst_base + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
200                    }
201                }
202            }
203        }
204
205        Ok(())
206    }
207
208    #[allow(clippy::too_many_arguments)]
209    fn execute_scale(
210        device: &GpuDevice,
211        pipeline: &ComputePipeline,
212        layout: &BindGroupLayout,
213        input: &[u8],
214        src_width: u32,
215        src_height: u32,
216        output: &mut [u8],
217        dst_width: u32,
218        dst_height: u32,
219        filter: ScaleFilter,
220    ) -> Result<()> {
221        // Create buffers
222        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
223        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
224
225        // Upload input data
226        device.queue().write_buffer(input_buffer.buffer(), 0, input);
227
228        // Create uniform buffer for parameters
229        let params = ScaleParams {
230            src_width,
231            src_height,
232            dst_width,
233            dst_height,
234            src_stride: src_width,
235            dst_stride: dst_width,
236            filter_type: filter.to_filter_id(),
237            padding: 0,
238        };
239        let params_bytes = bytemuck::bytes_of(&params);
240        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
241
242        // Create bind group
243        let compiler = ShaderCompiler::new(device);
244        let bind_group = compiler.create_bind_group(
245            "Scale Bind Group",
246            layout,
247            &[
248                wgpu::BindGroupEntry {
249                    binding: 0,
250                    resource: input_buffer.buffer().as_entire_binding(),
251                },
252                wgpu::BindGroupEntry {
253                    binding: 1,
254                    resource: output_buffer.buffer().as_entire_binding(),
255                },
256                wgpu::BindGroupEntry {
257                    binding: 2,
258                    resource: params_buffer.buffer().as_entire_binding(),
259                },
260            ],
261        );
262
263        // Execute compute pass
264        Self::dispatch_compute(device, pipeline, &bind_group, dst_width, dst_height)?;
265
266        // Read back results
267        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
268        let mut encoder = device
269            .device()
270            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
271                label: Some("Scale Copy Encoder"),
272            });
273
274        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
275
276        device.queue().submit(Some(encoder.finish()));
277        device.wait();
278
279        let result = readback_buffer.read(device, 0, output.len() as u64)?;
280        output.copy_from_slice(&result);
281
282        Ok(())
283    }
284
285    fn dispatch_compute(
286        device: &GpuDevice,
287        pipeline: &ComputePipeline,
288        bind_group: &BindGroup,
289        width: u32,
290        height: u32,
291    ) -> Result<()> {
292        let mut encoder = device
293            .device()
294            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
295                label: Some("Scale Compute Encoder"),
296            });
297
298        {
299            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
300                label: Some("Scale Compute Pass"),
301                timestamp_writes: None,
302            });
303
304            compute_pass.set_pipeline(pipeline);
305            compute_pass.set_bind_group(0, bind_group, &[]);
306
307            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
308            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
309        }
310
311        device.queue().submit(Some(encoder.finish()));
312        Ok(())
313    }
314
315    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
316        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
317
318        Ok(LAYOUT.get_or_init(|| {
319            let compiler = ShaderCompiler::new(device);
320            let entries = BindGroupLayoutBuilder::new()
321                .add_storage_buffer_read_only(0) // input
322                .add_storage_buffer(1) // output
323                .add_uniform_buffer(2) // params
324                .build();
325
326            compiler.create_bind_group_layout("Scale Bind Group Layout", &entries)
327        }))
328    }
329
330    fn init_pipeline(
331        device: &GpuDevice,
332        name: &str,
333        entry_point: &str,
334    ) -> std::result::Result<ComputePipeline, String> {
335        let compiler = ShaderCompiler::new(device);
336        let shader = compiler
337            .compile(
338                "Scale Shader",
339                ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
340            )
341            .map_err(|e| format!("Failed to compile scale shader: {e}"))?;
342
343        let layout = Self::get_bind_group_layout(device)
344            .map_err(|e| format!("Failed to create bind group layout: {e}"))?;
345
346        compiler
347            .create_pipeline(name, &shader, entry_point, layout)
348            .map_err(|e| format!("Failed to create pipeline: {e}"))
349    }
350
351    fn get_scale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
352        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
353
354        PIPELINE
355            .get_or_init(|| ScaleOperation::init_pipeline(device, "Scale Pipeline", "scale_main"))
356            .as_ref()
357            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
358    }
359
360    fn get_downscale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
361        static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
362
363        PIPELINE
364            .get_or_init(|| {
365                ScaleOperation::init_pipeline(device, "Downscale Pipeline", "downscale_area")
366            })
367            .as_ref()
368            .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
369    }
370}
371
372#[cfg(test)]
373mod tests {
374    use super::*;
375
376    /// Build a solid-colour RGBA image of size `w × h` with value `(r,g,b,a)`.
377    fn solid(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Vec<u8> {
378        let n = (w * h * 4) as usize;
379        let mut v = vec![0u8; n];
380        for px in v.chunks_mut(4) {
381            px[0] = r;
382            px[1] = g;
383            px[2] = b;
384            px[3] = a;
385        }
386        v
387    }
388
389    // --- Lanczos-3 CPU path ---
390
391    #[test]
392    fn test_lanczos3_uniform_downscale_preserves_colour() {
393        // A uniform image: after scaling, every pixel should stay the same colour.
394        let src = solid(8, 8, 100, 150, 200, 255);
395        let mut dst = vec![0u8; 4 * 4 * 4];
396        ScaleOperation::lanczos3_cpu(&src, 8, 8, &mut dst, 4, 4)
397            .expect("lanczos3 downscale should succeed");
398        for px in dst.chunks(4) {
399            assert!(
400                (px[0] as i32 - 100).unsigned_abs() <= 1,
401                "R mismatch: {}",
402                px[0]
403            );
404            assert!(
405                (px[1] as i32 - 150).unsigned_abs() <= 1,
406                "G mismatch: {}",
407                px[1]
408            );
409            assert!(
410                (px[2] as i32 - 200).unsigned_abs() <= 1,
411                "B mismatch: {}",
412                px[2]
413            );
414        }
415    }
416
417    #[test]
418    fn test_lanczos3_uniform_upscale_preserves_colour() {
419        let src = solid(4, 4, 80, 160, 240, 255);
420        let mut dst = vec![0u8; 8 * 8 * 4];
421        ScaleOperation::lanczos3_cpu(&src, 4, 4, &mut dst, 8, 8)
422            .expect("lanczos3 upscale should succeed");
423        for px in dst.chunks(4) {
424            assert!(
425                (px[0] as i32 - 80).unsigned_abs() <= 2,
426                "R mismatch: {}",
427                px[0]
428            );
429            assert!(
430                (px[1] as i32 - 160).unsigned_abs() <= 2,
431                "G mismatch: {}",
432                px[1]
433            );
434            assert!(
435                (px[2] as i32 - 240).unsigned_abs() <= 2,
436                "B mismatch: {}",
437                px[2]
438            );
439        }
440    }
441
442    #[test]
443    fn test_lanczos3_1x1_identity() {
444        let src = solid(1, 1, 42, 84, 126, 255);
445        let mut dst = vec![0u8; 4];
446        ScaleOperation::lanczos3_cpu(&src, 1, 1, &mut dst, 1, 1)
447            .expect("1×1 lanczos3 should succeed");
448        assert_eq!(dst[0], 42);
449        assert_eq!(dst[1], 84);
450        assert_eq!(dst[2], 126);
451        assert_eq!(dst[3], 255);
452    }
453
454    #[test]
455    fn test_lanczos3_output_size_correct() {
456        let src = solid(16, 16, 200, 200, 200, 255);
457        let mut dst = vec![0u8; 8 * 4 * 4]; // 8 wide × 4 tall
458        ScaleOperation::lanczos3_cpu(&src, 16, 16, &mut dst, 8, 4)
459            .expect("lanczos3 non-square downscale should succeed");
460        assert_eq!(dst.len(), 8 * 4 * 4);
461    }
462
463    #[test]
464    fn test_lanczos3_gradient_downscale_monotone() {
465        // A left-to-right gradient: after downscaling, pixel X values should
466        // still be monotonically non-decreasing across columns.
467        let sw = 16u32;
468        let sh = 4u32;
469        let mut src = vec![0u8; (sw * sh * 4) as usize];
470        for row in 0..sh as usize {
471            for col in 0..sw as usize {
472                let v = (col * 255 / (sw as usize - 1)) as u8;
473                let base = (row * sw as usize + col) * 4;
474                src[base] = v;
475                src[base + 1] = v;
476                src[base + 2] = v;
477                src[base + 3] = 255;
478            }
479        }
480        let dw = 8u32;
481        let dh = 4u32;
482        let mut dst = vec![0u8; (dw * dh * 4) as usize];
483        ScaleOperation::lanczos3_cpu(&src, sw, sh, &mut dst, dw, dh)
484            .expect("lanczos3 gradient downscale should succeed");
485        // Check that each row is non-decreasing in the R channel
486        for row in 0..dh as usize {
487            let mut prev = 0u8;
488            for col in 0..dw as usize {
489                let r = dst[(row * dw as usize + col) * 4];
490                // Allow ±2 due to Lanczos ringing
491                assert!(
492                    r as i32 >= prev as i32 - 2,
493                    "gradient not monotone: row={row} col={col} r={r} prev={prev}"
494                );
495                prev = r;
496            }
497        }
498    }
499
500    #[test]
501    fn test_lanczos3_black_white_border() {
502        // Image split: left half black, right half white.  After downscale the
503        // left-most pixel should be near black and right-most near white.
504        let sw = 8u32;
505        let sh = 4u32;
506        let mut src = vec![0u8; (sw * sh * 4) as usize];
507        for row in 0..sh as usize {
508            for col in 0..sw as usize {
509                let v = if col < sw as usize / 2 { 0u8 } else { 255u8 };
510                let base = (row * sw as usize + col) * 4;
511                src[base] = v;
512                src[base + 1] = v;
513                src[base + 2] = v;
514                src[base + 3] = 255;
515            }
516        }
517        let dw = 4u32;
518        let dh = 2u32;
519        let mut dst = vec![0u8; (dw * dh * 4) as usize];
520        ScaleOperation::lanczos3_cpu(&src, sw, sh, &mut dst, dw, dh)
521            .expect("lanczos3 should succeed");
522        let left = dst[0]; // first pixel, R channel
523        let right = dst[((dw - 1) * 4) as usize]; // last pixel on first row
524        assert!(left < 128, "left pixel should be dark: {left}");
525        assert!(right > 128, "right pixel should be bright: {right}");
526    }
527}