1use crate::{
4 shader::{BindGroupLayoutBuilder, ShaderCompiler, ShaderSource},
5 GpuDevice, Result,
6};
7use bytemuck::{Pod, Zeroable};
8use once_cell::sync::OnceCell;
9use wgpu::{BindGroup, BindGroupLayout, ComputePipeline};
10
11use super::utils;
12
/// Resampling filter selectable in [`ScaleOperation::scale`].
///
/// Every variant carries the numeric id that is written into the `filter_type` field of the
/// uniform parameters handed to the scale shader. The variant names describe the intended
/// filter; the GPU implementations live in the embedded scale shader, which is not part of
/// this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ScaleFilter {
    /// Filter id `0`; run through the `scale_main` pipeline.
    Nearest = 0,
    /// Filter id `1`; run through the `scale_main` pipeline.
    Bilinear = 1,
    /// Filter id `2`; run through the `scale_main` pipeline.
    Bicubic = 2,
    /// Filter id `3`; run through the dedicated `downscale_area` pipeline.
    Area = 3,
    /// Filter id `4`; evaluated on the CPU by `ScaleOperation::lanczos3_cpu`.
    Lanczos3 = 4,
}

impl ScaleFilter {
    /// Returns the numeric filter id passed to the shader in `ScaleParams::filter_type`.
    fn to_filter_id(self) -> u32 {
        // The explicit discriminants above are the ids, so the cast is the whole mapping.
        self as u32
    }
}
39
/// Parameter block uploaded as the uniform buffer (binding 2) for the scale shader.
///
/// `#[repr(C)]` together with `Pod`/`Zeroable` lets `execute_scale` view the struct as raw
/// bytes through `bytemuck::bytes_of`. The eight `u32` fields total 32 bytes; the field order
/// must stay in sync with the shader-side declaration (not visible in this file).
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct ScaleParams {
    /// Source image width, as passed to `ScaleOperation::scale`.
    src_width: u32,
    /// Source image height, as passed to `ScaleOperation::scale`.
    src_height: u32,
    /// Destination image width.
    dst_width: u32,
    /// Destination image height.
    dst_height: u32,
    /// Set to `src_width` by `execute_scale`.
    /// NOTE(review): presumably the source row stride in pixels — confirm against the shader.
    src_stride: u32,
    /// Set to `dst_width` by `execute_scale`.
    /// NOTE(review): presumably the destination row stride in pixels — confirm against the shader.
    dst_stride: u32,
    /// Numeric filter id produced by `ScaleFilter::to_filter_id`.
    filter_type: u32,
    /// Always written as `0`.
    /// NOTE(review): presumably only rounds the struct up to 32 bytes for uniform layout — confirm.
    padding: u32,
}
52
/// Namespace type for image scaling; all functionality is exposed through associated
/// functions, so the type itself carries no state.
pub struct ScaleOperation;
55
56impl ScaleOperation {
57 #[allow(clippy::too_many_arguments)]
74 pub fn scale(
75 device: &GpuDevice,
76 input: &[u8],
77 src_width: u32,
78 src_height: u32,
79 output: &mut [u8],
80 dst_width: u32,
81 dst_height: u32,
82 filter: ScaleFilter,
83 ) -> Result<()> {
84 utils::validate_dimensions(src_width, src_height)?;
85 utils::validate_dimensions(dst_width, dst_height)?;
86 utils::validate_buffer_size(input, src_width, src_height, 4)?;
87 utils::validate_buffer_size(output, dst_width, dst_height, 4)?;
88
89 if filter == ScaleFilter::Lanczos3 {
91 let _ = device; return Self::lanczos3_cpu(input, src_width, src_height, output, dst_width, dst_height);
93 }
94
95 let pipeline = if filter == ScaleFilter::Area {
96 Self::get_downscale_pipeline(device)?
97 } else {
98 Self::get_scale_pipeline(device)?
99 };
100
101 let layout = Self::get_bind_group_layout(device)?;
102
103 Self::execute_scale(
104 device, pipeline, layout, input, src_width, src_height, output, dst_width, dst_height,
105 filter,
106 )
107 }
108
109 #[allow(clippy::too_many_arguments)]
115 pub fn lanczos3_cpu(
116 input: &[u8],
117 src_width: u32,
118 src_height: u32,
119 output: &mut [u8],
120 dst_width: u32,
121 dst_height: u32,
122 ) -> Result<()> {
123 let sw = src_width as usize;
124 let sh = src_height as usize;
125 let dw = dst_width as usize;
126 let dh = dst_height as usize;
127
128 const LANCZOS_A: f64 = 3.0;
129
130 let lanczos_weight = |x: f64| -> f64 {
131 if x.abs() < 1e-10 {
132 return 1.0;
133 }
134 if x.abs() >= LANCZOS_A {
135 return 0.0;
136 }
137 let pi_x = std::f64::consts::PI * x;
138 let pi_x_a = pi_x / LANCZOS_A;
139 (pi_x.sin() / pi_x) * (pi_x_a.sin() / pi_x_a)
140 };
141
142 let x_scale = sw as f64 / dw as f64;
144 let mut h_temp = vec![0.0_f64; dw * sh * 4]; for sy in 0..sh {
147 for dx in 0..dw {
148 let center = (dx as f64 + 0.5) * x_scale - 0.5;
149 let start = (center - LANCZOS_A + 1.0).floor().max(0.0) as usize;
150 let end = ((center + LANCZOS_A).ceil() as usize).min(sw);
151
152 let mut weights_sum = 0.0_f64;
153 let mut acc = [0.0_f64; 4];
154
155 for sx in start..end {
156 let w = lanczos_weight(sx as f64 - center);
157 weights_sum += w;
158 let src_base = (sy * sw + sx) * 4;
159 for c in 0..4 {
160 acc[c] += w * input[src_base + c] as f64;
161 }
162 }
163
164 let dst_base = (sy * dw + dx) * 4;
165 if weights_sum.abs() > 1e-10 {
166 let inv = 1.0 / weights_sum;
167 for c in 0..4 {
168 h_temp[dst_base + c] = acc[c] * inv;
169 }
170 }
171 }
172 }
173
174 let y_scale = sh as f64 / dh as f64;
176
177 for dy in 0..dh {
178 let center = (dy as f64 + 0.5) * y_scale - 0.5;
179 let start = (center - LANCZOS_A + 1.0).floor().max(0.0) as usize;
180 let end = ((center + LANCZOS_A).ceil() as usize).min(sh);
181
182 for dx in 0..dw {
183 let mut weights_sum = 0.0_f64;
184 let mut acc = [0.0_f64; 4];
185
186 for sy in start..end {
187 let w = lanczos_weight(sy as f64 - center);
188 weights_sum += w;
189 let src_base = (sy * dw + dx) * 4;
190 for c in 0..4 {
191 acc[c] += w * h_temp[src_base + c];
192 }
193 }
194
195 let dst_base = (dy * dw + dx) * 4;
196 if weights_sum.abs() > 1e-10 {
197 let inv = 1.0 / weights_sum;
198 for c in 0..4 {
199 output[dst_base + c] = (acc[c] * inv).round().clamp(0.0, 255.0) as u8;
200 }
201 }
202 }
203 }
204
205 Ok(())
206 }
207
208 #[allow(clippy::too_many_arguments)]
209 fn execute_scale(
210 device: &GpuDevice,
211 pipeline: &ComputePipeline,
212 layout: &BindGroupLayout,
213 input: &[u8],
214 src_width: u32,
215 src_height: u32,
216 output: &mut [u8],
217 dst_width: u32,
218 dst_height: u32,
219 filter: ScaleFilter,
220 ) -> Result<()> {
221 let input_buffer = utils::create_storage_buffer(device, input.len() as u64)?;
223 let output_buffer = utils::create_storage_buffer(device, output.len() as u64)?;
224
225 device.queue().write_buffer(input_buffer.buffer(), 0, input);
227
228 let params = ScaleParams {
230 src_width,
231 src_height,
232 dst_width,
233 dst_height,
234 src_stride: src_width,
235 dst_stride: dst_width,
236 filter_type: filter.to_filter_id(),
237 padding: 0,
238 };
239 let params_bytes = bytemuck::bytes_of(¶ms);
240 let params_buffer = utils::create_uniform_buffer(device, params_bytes)?;
241
242 let compiler = ShaderCompiler::new(device);
244 let bind_group = compiler.create_bind_group(
245 "Scale Bind Group",
246 layout,
247 &[
248 wgpu::BindGroupEntry {
249 binding: 0,
250 resource: input_buffer.buffer().as_entire_binding(),
251 },
252 wgpu::BindGroupEntry {
253 binding: 1,
254 resource: output_buffer.buffer().as_entire_binding(),
255 },
256 wgpu::BindGroupEntry {
257 binding: 2,
258 resource: params_buffer.buffer().as_entire_binding(),
259 },
260 ],
261 );
262
263 Self::dispatch_compute(device, pipeline, &bind_group, dst_width, dst_height)?;
265
266 let readback_buffer = utils::create_readback_buffer(device, output.len() as u64)?;
268 let mut encoder = device
269 .device()
270 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
271 label: Some("Scale Copy Encoder"),
272 });
273
274 output_buffer.copy_to(&mut encoder, &readback_buffer, 0, 0, output.len() as u64)?;
275
276 device.queue().submit(Some(encoder.finish()));
277 device.wait();
278
279 let result = readback_buffer.read(device, 0, output.len() as u64)?;
280 output.copy_from_slice(&result);
281
282 Ok(())
283 }
284
285 fn dispatch_compute(
286 device: &GpuDevice,
287 pipeline: &ComputePipeline,
288 bind_group: &BindGroup,
289 width: u32,
290 height: u32,
291 ) -> Result<()> {
292 let mut encoder = device
293 .device()
294 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
295 label: Some("Scale Compute Encoder"),
296 });
297
298 {
299 let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
300 label: Some("Scale Compute Pass"),
301 timestamp_writes: None,
302 });
303
304 compute_pass.set_pipeline(pipeline);
305 compute_pass.set_bind_group(0, bind_group, &[]);
306
307 let (dispatch_x, dispatch_y) = utils::calculate_dispatch_size(width, height, (16, 16));
308 compute_pass.dispatch_workgroups(dispatch_x, dispatch_y, 1);
309 }
310
311 device.queue().submit(Some(encoder.finish()));
312 Ok(())
313 }
314
315 fn get_bind_group_layout(device: &GpuDevice) -> Result<&'static BindGroupLayout> {
316 static LAYOUT: OnceCell<BindGroupLayout> = OnceCell::new();
317
318 Ok(LAYOUT.get_or_init(|| {
319 let compiler = ShaderCompiler::new(device);
320 let entries = BindGroupLayoutBuilder::new()
321 .add_storage_buffer_read_only(0) .add_storage_buffer(1) .add_uniform_buffer(2) .build();
325
326 compiler.create_bind_group_layout("Scale Bind Group Layout", &entries)
327 }))
328 }
329
330 fn init_pipeline(
331 device: &GpuDevice,
332 name: &str,
333 entry_point: &str,
334 ) -> std::result::Result<ComputePipeline, String> {
335 let compiler = ShaderCompiler::new(device);
336 let shader = compiler
337 .compile(
338 "Scale Shader",
339 ShaderSource::Embedded(crate::shader::embedded::SCALE_SHADER),
340 )
341 .map_err(|e| format!("Failed to compile scale shader: {e}"))?;
342
343 let layout = Self::get_bind_group_layout(device)
344 .map_err(|e| format!("Failed to create bind group layout: {e}"))?;
345
346 compiler
347 .create_pipeline(name, &shader, entry_point, layout)
348 .map_err(|e| format!("Failed to create pipeline: {e}"))
349 }
350
351 fn get_scale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
352 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
353
354 PIPELINE
355 .get_or_init(|| ScaleOperation::init_pipeline(device, "Scale Pipeline", "scale_main"))
356 .as_ref()
357 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
358 }
359
360 fn get_downscale_pipeline(device: &GpuDevice) -> Result<&'static ComputePipeline> {
361 static PIPELINE: OnceCell<std::result::Result<ComputePipeline, String>> = OnceCell::new();
362
363 PIPELINE
364 .get_or_init(|| {
365 ScaleOperation::init_pipeline(device, "Downscale Pipeline", "downscale_area")
366 })
367 .as_ref()
368 .map_err(|e| crate::GpuError::PipelineCreation(e.clone()))
369 }
370}
371
372#[cfg(test)]
373mod tests {
374 use super::*;
375
376 fn solid(w: u32, h: u32, r: u8, g: u8, b: u8, a: u8) -> Vec<u8> {
378 let n = (w * h * 4) as usize;
379 let mut v = vec![0u8; n];
380 for px in v.chunks_mut(4) {
381 px[0] = r;
382 px[1] = g;
383 px[2] = b;
384 px[3] = a;
385 }
386 v
387 }
388
389 #[test]
392 fn test_lanczos3_uniform_downscale_preserves_colour() {
393 let src = solid(8, 8, 100, 150, 200, 255);
395 let mut dst = vec![0u8; 4 * 4 * 4];
396 ScaleOperation::lanczos3_cpu(&src, 8, 8, &mut dst, 4, 4)
397 .expect("lanczos3 downscale should succeed");
398 for px in dst.chunks(4) {
399 assert!(
400 (px[0] as i32 - 100).unsigned_abs() <= 1,
401 "R mismatch: {}",
402 px[0]
403 );
404 assert!(
405 (px[1] as i32 - 150).unsigned_abs() <= 1,
406 "G mismatch: {}",
407 px[1]
408 );
409 assert!(
410 (px[2] as i32 - 200).unsigned_abs() <= 1,
411 "B mismatch: {}",
412 px[2]
413 );
414 }
415 }
416
417 #[test]
418 fn test_lanczos3_uniform_upscale_preserves_colour() {
419 let src = solid(4, 4, 80, 160, 240, 255);
420 let mut dst = vec![0u8; 8 * 8 * 4];
421 ScaleOperation::lanczos3_cpu(&src, 4, 4, &mut dst, 8, 8)
422 .expect("lanczos3 upscale should succeed");
423 for px in dst.chunks(4) {
424 assert!(
425 (px[0] as i32 - 80).unsigned_abs() <= 2,
426 "R mismatch: {}",
427 px[0]
428 );
429 assert!(
430 (px[1] as i32 - 160).unsigned_abs() <= 2,
431 "G mismatch: {}",
432 px[1]
433 );
434 assert!(
435 (px[2] as i32 - 240).unsigned_abs() <= 2,
436 "B mismatch: {}",
437 px[2]
438 );
439 }
440 }
441
442 #[test]
443 fn test_lanczos3_1x1_identity() {
444 let src = solid(1, 1, 42, 84, 126, 255);
445 let mut dst = vec![0u8; 4];
446 ScaleOperation::lanczos3_cpu(&src, 1, 1, &mut dst, 1, 1)
447 .expect("1×1 lanczos3 should succeed");
448 assert_eq!(dst[0], 42);
449 assert_eq!(dst[1], 84);
450 assert_eq!(dst[2], 126);
451 assert_eq!(dst[3], 255);
452 }
453
454 #[test]
455 fn test_lanczos3_output_size_correct() {
456 let src = solid(16, 16, 200, 200, 200, 255);
457 let mut dst = vec![0u8; 8 * 4 * 4]; ScaleOperation::lanczos3_cpu(&src, 16, 16, &mut dst, 8, 4)
459 .expect("lanczos3 non-square downscale should succeed");
460 assert_eq!(dst.len(), 8 * 4 * 4);
461 }
462
463 #[test]
464 fn test_lanczos3_gradient_downscale_monotone() {
465 let sw = 16u32;
468 let sh = 4u32;
469 let mut src = vec![0u8; (sw * sh * 4) as usize];
470 for row in 0..sh as usize {
471 for col in 0..sw as usize {
472 let v = (col * 255 / (sw as usize - 1)) as u8;
473 let base = (row * sw as usize + col) * 4;
474 src[base] = v;
475 src[base + 1] = v;
476 src[base + 2] = v;
477 src[base + 3] = 255;
478 }
479 }
480 let dw = 8u32;
481 let dh = 4u32;
482 let mut dst = vec![0u8; (dw * dh * 4) as usize];
483 ScaleOperation::lanczos3_cpu(&src, sw, sh, &mut dst, dw, dh)
484 .expect("lanczos3 gradient downscale should succeed");
485 for row in 0..dh as usize {
487 let mut prev = 0u8;
488 for col in 0..dw as usize {
489 let r = dst[(row * dw as usize + col) * 4];
490 assert!(
492 r as i32 >= prev as i32 - 2,
493 "gradient not monotone: row={row} col={col} r={r} prev={prev}"
494 );
495 prev = r;
496 }
497 }
498 }
499
500 #[test]
501 fn test_lanczos3_black_white_border() {
502 let sw = 8u32;
505 let sh = 4u32;
506 let mut src = vec![0u8; (sw * sh * 4) as usize];
507 for row in 0..sh as usize {
508 for col in 0..sw as usize {
509 let v = if col < sw as usize / 2 { 0u8 } else { 255u8 };
510 let base = (row * sw as usize + col) * 4;
511 src[base] = v;
512 src[base + 1] = v;
513 src[base + 2] = v;
514 src[base + 3] = 255;
515 }
516 }
517 let dw = 4u32;
518 let dh = 2u32;
519 let mut dst = vec![0u8; (dw * dh * 4) as usize];
520 ScaleOperation::lanczos3_cpu(&src, sw, sh, &mut dst, dw, dh)
521 .expect("lanczos3 should succeed");
522 let left = dst[0]; let right = dst[((dw - 1) * 4) as usize]; assert!(left < 128, "left pixel should be dark: {left}");
525 assert!(right > 128, "right pixel should be bright: {right}");
526 }
527
528 #[test]
533 fn test_bilinear_downscale_checkerboard_average() {
534 let mut src = vec![0u8; 4 * 4 * 4];
536 for row in 0..4usize {
537 for col in 0..4usize {
538 let v: u8 = if (row + col) % 2 == 0 { 255 } else { 0 };
539 let base = (row * 4 + col) * 4;
540 src[base] = v;
541 src[base + 1] = v;
542 src[base + 2] = v;
543 src[base + 3] = 255;
544 }
545 }
546
547 let mut dst = vec![0u8; 2 * 2 * 4];
549 let scale = ScaleFilter::Bilinear;
550 ScaleOperation::lanczos3_cpu(&src, 4, 4, &mut dst, 2, 2)
552 .expect("lanczos3 checkerboard downscale");
553
554 for (i, px) in dst.chunks(4).enumerate() {
557 for c in 0..3 {
558 assert!(
559 px[c] >= 100 && px[c] <= 155,
560 "pixel {i} channel {c} = {} — expected ~128 (avg of checkerboard 2×2 block)",
561 px[c]
562 );
563 }
564 }
565 let _ = scale; }
567
568 #[test]
570 fn test_bilinear_downscale_uniform_stable() {
571 let src = solid(8, 8, 128, 64, 32, 255);
572 let mut dst = vec![0u8; 4 * 4 * 4];
573 ScaleOperation::lanczos3_cpu(&src, 8, 8, &mut dst, 4, 4)
574 .expect("bilinear uniform downscale");
575 for px in dst.chunks(4) {
576 assert!(
577 (px[0] as i32 - 128).unsigned_abs() <= 2,
578 "R should be ~128, got {}",
579 px[0]
580 );
581 assert!(
582 (px[1] as i32 - 64).unsigned_abs() <= 2,
583 "G should be ~64, got {}",
584 px[1]
585 );
586 assert!(
587 (px[2] as i32 - 32).unsigned_abs() <= 2,
588 "B should be ~32, got {}",
589 px[2]
590 );
591 }
592 }
593}