1use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11struct UpsampleParams {
12 batch: u32,
13 channels: u32,
14 in_h: u32,
15 in_w: u32,
16 out_h: u32,
17 out_w: u32,
18 _pad: [u32; 2],
19}
20
/// WGSL source for nearest-neighbour 2-D upsampling.
///
/// Bindings (group 0): `0` = uniform parameters (`P`), `1` = read-only input
/// array, `2` = read-write output array.
///
/// Each invocation handles one output element. The flat output index is
/// rebuilt from `gid.x` and `gid.y`, then split into `(n, c, oh, ow)` with
/// `ow` varying fastest. The source coordinate is `oh * in_h / out_h` (and the
/// analogous expression for the width) using integer division, and the value
/// at that input position is copied to the output.
///
/// NOTE(review): the `65535u * 256u` stride applied to `gid.y` presumably
/// matches how `dispatch_1d` spills workgroups into the Y dimension — confirm
/// against that helper.
const SHADER_UPSAMPLE_NEAREST: &str = r"
struct P { batch: u32, channels: u32, in_h: u32, in_w: u32, out_h: u32, out_w: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
 let idx = gid.x + gid.y * 65535u * 256u;
 let total = p.batch * p.channels * p.out_h * p.out_w;
 if idx >= total { return; }

 let ow = idx % p.out_w;
 let oh = (idx / p.out_w) % p.out_h;
 let c = (idx / (p.out_w * p.out_h)) % p.channels;
 let n = idx / (p.out_w * p.out_h * p.channels);

 let ih = oh * p.in_h / p.out_h;
 let iw = ow * p.in_w / p.out_w;

 let in_idx = n * (p.channels * p.in_h * p.in_w)
 + c * (p.in_h * p.in_w)
 + ih * p.in_w + iw;
 out[idx] = input[in_idx];
}
";
46
47impl GpuDevice {
48 pub fn upsample_nearest2d(
50 &self,
51 input: &GpuBuffer,
52 batch: u32, channels: u32, in_h: u32, in_w: u32,
53 scale_h: u32, scale_w: u32,
54 ) -> Result<GpuBuffer> {
55 ensure!(input.len == (batch * channels * in_h * in_w) as usize);
56 let out_h = in_h * scale_h;
57 let out_w = in_w * scale_w;
58 let total = batch * channels * out_h * out_w;
59 let out = self.alloc(total as usize);
60 let params = UpsampleParams { batch, channels, in_h, in_w, out_h, out_w, _pad: [0; 2] };
61 self.dispatch_shader(
62 SHADER_UPSAMPLE_NEAREST, Some("upsample"),
63 ¶ms, &[input], &out,
64 super::dispatch_1d(total),
65 );
66 Ok(out)
67 }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns the crate-wide test device `crate::ops::TEST_DEV`.
    fn dev() -> &'static GpuDevice {
        &crate::ops::TEST_DEV
    }

    /// CPU reference for nearest-neighbour upsampling.
    ///
    /// `input` is read as `[batch, ch, h, w]` with the width varying fastest;
    /// the result has shape `[batch, ch, h * sh, w * sw]` in the same order.
    fn cpu_upsample(input: &[f32], batch: usize, ch: usize, h: usize, w: usize, sh: usize, sw: usize) -> Vec<f32> {
        let (out_h, out_w) = (h * sh, w * sw);
        let mut out = Vec::with_capacity(batch * ch * out_h * out_w);
        // Batch and channel are adjacent in memory, so walk them as one
        // flattened "plane" index and append output rows in order.
        for plane in 0..batch * ch {
            let base = plane * h * w;
            for y in 0..out_h {
                let row = base + (y * h / out_h) * w;
                out.extend((0..out_w).map(|x| input[row + x * w / out_w]));
            }
        }
        out
    }

    #[test]
    fn test_upsample_2x() {
        let gpu = dev();
        let src = gpu.upload(&[1.0, 2.0, 3.0, 4.0]);
        let upsampled = gpu.upsample_nearest2d(&src, 1, 1, 2, 2, 2, 2).unwrap();
        let got = gpu.read(&upsampled).unwrap();
        let want: Vec<f32> = vec![
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ];
        assert_eq!(got, want);
    }

    #[test]
    fn test_upsample_3x_vs_cpu() {
        let data: Vec<f32> = (1..=6).map(|x| x as f32).collect();
        let want = cpu_upsample(&data, 1, 1, 2, 3, 3, 3);

        let gpu = dev();
        let src = gpu.upload(&data);
        let upsampled = gpu.upsample_nearest2d(&src, 1, 1, 2, 3, 3, 3).unwrap();
        let got = gpu.read(&upsampled).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_upsample_batched_multichannel_vs_cpu() {
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let want = cpu_upsample(&data, 2, 3, 2, 2, 2, 2);

        let gpu = dev();
        let src = gpu.upload(&data);
        let upsampled = gpu.upsample_nearest2d(&src, 2, 3, 2, 2, 2, 2).unwrap();
        let got = gpu.read(&upsampled).unwrap();
        assert_eq!(got, want);
    }

    #[test]
    fn test_upsample_1x1() {
        let gpu = dev();
        let src = gpu.upload(&[7.0]);
        let upsampled = gpu.upsample_nearest2d(&src, 1, 1, 1, 1, 3, 3).unwrap();
        let got = gpu.read(&upsampled).unwrap();
        let want: Vec<f32> = vec![7.0; 9];
        assert_eq!(got, want);
    }
}