Skip to main content

oximedia_gpu/ops/
scale.rs

1//! Image scaling operations with various interpolation methods
2
3use crate::{
4    shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5    GpuDevice, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Interpolation strategy used when resizing an image.
#[derive(Debug, PartialEq, Eq, Clone, Copy)]
pub enum ScaleFilter {
    /// Nearest-neighbor sampling: fastest, lowest quality.
    Nearest,
    /// Bilinear interpolation: a balance of speed and quality.
    Bilinear,
    /// Bicubic interpolation: highest quality.
    Bicubic,
    /// Area averaging, intended for downscaling.
    Area,
}

impl ScaleFilter {
    /// Numeric id stored in the `filter_type` field of the shader parameters.
    ///
    /// NOTE(review): these values presumably match the constants the scale shader
    /// switches on; confirm against the WGSL before renumbering.
    fn to_filter_id(self) -> u32 {
        use ScaleFilter::{Area, Bicubic, Bilinear, Nearest};

        match self {
            Nearest => 0,
            Bilinear => 1,
            Bicubic => 2,
            Area => 3,
        }
    }
}
36
/// Parameters uploaded to the scale compute shaders as a uniform buffer.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets the struct be turned into
/// bytes with `bytemuck::bytes_of` and uploaded verbatim. It is eight `u32`
/// fields, i.e. 32 bytes.
///
/// NOTE(review): field order and widths presumably mirror the params struct
/// declared in `SCALE_SHADER`; confirm against the WGSL before reordering,
/// resizing, or adding fields.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct ScaleParams {
    // Source image width (RGBA, 4 bytes per pixel).
    src_width: u32,
    // Source image height.
    src_height: u32,
    // Destination image width.
    dst_width: u32,
    // Destination image height.
    dst_height: u32,
    // Source row stride; currently set equal to `src_width` by the caller
    // (unit presumably pixels — confirm in the shader).
    src_stride: u32,
    // Destination row stride; currently set equal to `dst_width` by the caller.
    dst_stride: u32,
    // Value from `ScaleFilter::to_filter_id`.
    filter_type: u32,
    // Unused filler; presumably rounds the struct up to a 16-byte multiple
    // for uniform-buffer layout rules (7 x u32 = 28 bytes -> 32).
    padding: u32,
}
49
/// Image scaling operations.
///
/// A stateless namespace type: all functionality lives in associated
/// functions such as [`ScaleOperation::scale`], so the value carries no data.
#[derive(Debug, Clone, Copy, Default)]
pub struct ScaleOperation;
52
53impl ScaleOperation {
54    /// Scale an image
55    ///
56    /// # Arguments
57    ///
58    /// * `device` - GPU device
59    /// * `input` - Input image buffer (packed RGBA format)
60    /// * `src_width` - Source image width
61    /// * `src_height` - Source image height
62    /// * `output` - Output image buffer (packed RGBA format)
63    /// * `dst_width` - Destination image width
64    /// * `dst_height` - Destination image height
65    /// * `filter` - Scaling filter type
66    ///
67    /// # Errors
68    ///
69    /// Returns an error if buffer sizes are invalid or if the GPU operation fails.
70    #[allow(clippy::too_many_arguments)]
71    pub fn scale(
72        device: &GpuDevice,
73        input: &[u8],
74        src_width: u32,
75        src_height: u32,
76        output: &mut [u8],
77        dst_width: u32,
78        dst_height: u32,
79        filter: ScaleFilter,
80    ) -> Result<()> {
81        utils::validate_dimensions(src_width, src_height)?;
82        utils::validate_dimensions(dst_width, dst_height)?;
83        utils::validate_buffer_size(input, src_width, src_height, 4)?;
84        utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
85
86        let pipeline = if filter == ScaleFilter::Area {
87            Self::get_downscale_pipeline(device)?
88        } else {
89            Self::get_scale_pipeline(device)?
90        };
91
92        let layout = Self::get_bind_group_layout(device)?;
93
94        Self::execute_scale(
95            device, pipeline, layout, input, src_width, src_height, output, dst_width, dst_height,
96            filter,
97        )
98    }
99
100    #[allow(clippy::too_many_arguments)]
101    fn execute_scale(
102        device: &GpuDevice,
103        pipeline: &ComputePipeline,
104        layout: &BindGroupLayout,
105        input: &[u8],
106        src_width: u32,
107        src_height: u32,
108        output: &mut [u8],
109        dst_width: u32,
110        dst_height: u32,
111        filter: ScaleFilter,
112    ) -> Result<()> {
113        // Create buffers
114        let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
115        let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
116
117        // Upload input data
118        device.queue().write_buffer(input_buffer.buffer(), 0, input);
119
120        // Create uniform buffer for parameters
121        let params = ScaleParams {
122            src_width,
123            src_height,
124            dst_width,
125            dst_height,
126            src_stride: src_width,
127            dst_stride: dst_width,
128            filter_type: filter.to_filter_id(),
129            padding: 0,
130        };
131        let params_bytes = bytemuck::bytes_of(&params);
132        let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
133
134        // Create bind group
135        let compiler = ShaderCompiler::new(device);
136        let bind_group = compiler.create_bind_group(
137            "Scale Bind Group",
138            layout,
139            &[
140                wgpu::BindGroupEntry {
141                    binding: 0,
142                    resource: input_buffer.buffer().as_entire_binding(),
143                },
144                wgpu::BindGroupEntry {
145                    binding: 1,
146                    resource: output_buffer.buffer().as_entire_binding(),
147                },
148                wgpu::BindGroupEntry {
149                    binding: 2,
150                    resource: params_buffer.buffer().as_entire_binding(),
151                },
152            ],
153        );
154
155        // Execute compute pass
156        Self::dispatch_compute(device, pipeline, &bind_group, dst_width, dst_height)?;
157
158        // Read back results
159        let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
160        let mut encoder = device
161            .device()
162            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
163                label: Some("Scale Copy Encoder"),
164            });
165
166        output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
167
168        device.queue().submit(Some(encoder.finish()));
169        device.wait();
170
171        let result = readback_buffer.read(device, 0, output.len() as u64)?;
172        output.copy_from_slice(&result);
173
174        Ok(())
175    }
176
177    fn dispatch_compute(
178        device: &GpuDevice,
179        pipeline: &ComputePipeline,
180        bind_group: &BindGroup,
181        width: u32,
182        height: u32,
183    ) -> Result<()> {
184        let mut encoder = device
185            .device()
186            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
187                label: Some("Scale Compute Encoder"),
188            });
189
190        {
191            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
192                label: Some("Scale Compute Pass"),
193                timestamp_writes: None,
194            });
195
196            compute_pass.set_pipeline(pipeline);
197            compute_pass.set_bind_group(0, bind_group, &[]);
198
199            let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
200            compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
201        }
202
203        device.queue().submit(Some(encoder.finish()));
204        Ok(())
205    }
206
207    fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
208        static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
209
210        Ok(LAYOUT.get_or_init(|| {
211            let compiler = ShaderCompiler::new(device);
212            let entries = BindGroupLayoutBuilder::new()
213                .add_storage_buffer_read_only(0) // input
214                .add_storage_buffer(1) // output
215                .add_uniform_buffer(2) // params
216                .build();
217
218            compiler.create_bind_group_layout("Scale Bind Group Layout", &entries)
219        }))
220    }
221
222    fn get_scale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
223        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
224
225        Ok(PIPELINE.get_or_init(|| {
226            let compiler = ShaderCompiler::new(device);
227            let shader = compiler
228                .compile(
229                    "Scale Shader",
230                    ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
231                )
232                .expect("Failed to compile scale shader");
233
234            let layout =
235                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
236
237            compiler
238                .create_pipeline("Scale Pipeline", &shader, "scale_main", layout)
239                .expect("Failed to create pipeline")
240        }))
241    }
242
243    fn get_downscale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
244        static PIPELINE: OnceCell<ComputePipeline> = OnceCell::new();
245
246        Ok(PIPELINE.get_or_init(|| {
247            let compiler = ShaderCompiler::new(device);
248            let shader = compiler
249                .compile(
250                    "Scale Shader",
251                    ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
252                )
253                .expect("Failed to compile scale shader");
254
255            let layout =
256                Self::get_bind_group_layout(device).expect("Failed to create bind group layout");
257
258            compiler
259                .create_pipeline("Downscale Pipeline", &shader, "downscale_area", layout)
260                .expect("Failed to create pipeline")
261        }))
262    }
263}