Skip to main content

any_gpu/ops/
upsample.rs

// Unlicense — cochranblock.org
// Contributors: GotEmCoach, KOVA, Claude Opus 4.6
//
// Nearest-neighbor upsampling for UNet decoder path.
5
6use crate::device::{GpuBuffer, GpuDevice};
7use anyhow::{ensure, Result};
8
9#[repr(C)]
10#[derive(Copy, Clone, bytemuck::Pod, bytemuck::Zeroable)]
11struct UpsampleParams {
12    batch: u32,
13    channels: u32,
14    in_h: u32,
15    in_w: u32,
16    out_h: u32,
17    out_w: u32,
18    _pad: [u32; 2],
19}
20
// WGSL compute shader for nearest-neighbor upsampling of a flat f32 tensor
// indexed as [n][c][h][w].
//
// Bindings: 0 = uniform parameters (same word order as `UpsampleParams`),
// 1 = read-only input, 2 = read-write output.
//
// Each invocation writes one output element:
//   * `idx` is the flat output index. `gid.y` advances it in strides of
//     65535 * 256 elements, so a dispatch may spill into the Y dimension.
//     NOTE(review): this presumably matches how `dispatch_1d` splits large
//     dispatches — confirm against that helper.
//   * invocations with `idx >= total` return early, so a rounded-up dispatch
//     never writes past the end of the output.
//   * `idx` is decomposed into (n, c, oh, ow) and the source coordinates are
//     `ih = oh * in_h / out_h`, `iw = ow * in_w / out_w` (integer division).
const SHADER_UPSAMPLE_NEAREST: &str = "
struct P { batch: u32, channels: u32, in_h: u32, in_w: u32, out_h: u32, out_w: u32, _p0: u32, _p1: u32, }
@group(0) @binding(0) var<uniform> p: P;
@group(0) @binding(1) var<storage, read> input: array<f32>;
@group(0) @binding(2) var<storage, read_write> out: array<f32>;
@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x + gid.y * 65535u * 256u;
    let total = p.batch * p.channels * p.out_h * p.out_w;
    if idx >= total { return; }

    let ow = idx % p.out_w;
    let oh = (idx / p.out_w) % p.out_h;
    let c  = (idx / (p.out_w * p.out_h)) % p.channels;
    let n  = idx / (p.out_w * p.out_h * p.channels);

    let ih = oh * p.in_h / p.out_h;
    let iw = ow * p.in_w / p.out_w;

    let in_idx = n * (p.channels * p.in_h * p.in_w)
               + c * (p.in_h * p.in_w)
               + ih * p.in_w + iw;
    out[idx] = input[in_idx];
}
";
46
47impl GpuDevice {
48    /// Nearest-neighbor 2D upsample. Input: [N,C,H,W], output: [N,C,H*scale_h,W*scale_w].
49    pub fn upsample_nearest2d(
50        &self,
51        input: &GpuBuffer,
52        batch: u32, channels: u32, in_h: u32, in_w: u32,
53        scale_h: u32, scale_w: u32,
54    ) -> Result<GpuBuffer> {
55        ensure!(input.len == (batch * channels * in_h * in_w) as usize);
56        let out_h = in_h * scale_h;
57        let out_w = in_w * scale_w;
58        let total = batch * channels * out_h * out_w;
59        let out = self.alloc(total as usize);
60        let params = UpsampleParams { batch, channels, in_h, in_w, out_h, out_w, _pad: [0; 2] };
61        self.dispatch_shader(
62            SHADER_UPSAMPLE_NEAREST, Some("upsample"),
63            &params, &[input], &out,
64            super::dispatch_1d(total),
65        );
66        Ok(out)
67    }
68}
69
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared GPU device used by every test in this module.
    fn dev() -> &'static GpuDevice { &crate::ops::TEST_DEV }

    /// CPU reference for nearest-neighbor upsampling of a flat `[batch, ch, h, w]`
    /// slice by integer factors `sh` (height) and `sw` (width).
    ///
    /// Uses the same source-coordinate formula as the shader
    /// (`iy = y * h / oh`, `ix = x * w / ow`).
    fn cpu_upsample(input: &[f32], batch: usize, ch: usize, h: usize, w: usize, sh: usize, sw: usize) -> Vec<f32> {
        let oh = h * sh;
        let ow = w * sw;
        let mut out = vec![0.0f32; batch * ch * oh * ow];
        for n in 0..batch {
            for c in 0..ch {
                for y in 0..oh {
                    for x in 0..ow {
                        let iy = y * h / oh;
                        let ix = x * w / ow;
                        out[n * ch * oh * ow + c * oh * ow + y * ow + x] =
                            input[n * ch * h * w + c * h * w + iy * w + ix];
                    }
                }
            }
        }
        out
    }

    #[test]
    fn test_upsample_2x() {
        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        let result = dev().read(&dev().upsample_nearest2d(&input, 1, 1, 2, 2, 2, 2).unwrap()).unwrap();
        assert_eq!(result, vec![
            1.0, 1.0, 2.0, 2.0,
            1.0, 1.0, 2.0, 2.0,
            3.0, 3.0, 4.0, 4.0,
            3.0, 3.0, 4.0, 4.0,
        ]);
    }

    #[test]
    fn test_upsample_3x_vs_cpu() {
        // Non-power-of-2 scale
        let data: Vec<f32> = (1..=6).map(|x| x as f32).collect(); // 1x1x2x3
        let expected = cpu_upsample(&data, 1, 1, 2, 3, 3, 3);
        let result = dev().read(&dev().upsample_nearest2d(&dev().upload(&data), 1, 1, 2, 3, 3, 3).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_upsample_batched_multichannel_vs_cpu() {
        // batch=2, channels=3, 2x2 spatial, scale 2x
        let data: Vec<f32> = (0..24).map(|i| i as f32).collect();
        let expected = cpu_upsample(&data, 2, 3, 2, 2, 2, 2);
        let result = dev().read(&dev().upsample_nearest2d(&dev().upload(&data), 2, 3, 2, 2, 2, 2).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_upsample_1x1() {
        // 1x1 spatial -> 3x3 spatial
        let result = dev().read(&dev().upsample_nearest2d(&dev().upload(&[7.0]), 1, 1, 1, 1, 3, 3).unwrap()).unwrap();
        assert_eq!(result, vec![7.0; 9]);
    }

    #[test]
    fn test_upsample_asymmetric_scale_vs_cpu() {
        // Different height/width scales on a non-square input, so a swapped
        // h/w anywhere in the pipeline cannot cancel out.
        let data: Vec<f32> = (0..12).map(|i| i as f32).collect(); // 1x2x2x3
        let expected = cpu_upsample(&data, 1, 2, 2, 3, 2, 3);
        let result = dev().read(&dev().upsample_nearest2d(&dev().upload(&data), 1, 2, 2, 3, 2, 3).unwrap()).unwrap();
        assert_eq!(result, expected);
    }

    #[test]
    fn test_upsample_rejects_wrong_input_len() {
        // Buffer holds 4 elements but the declared shape needs 1*1*3*3 = 9.
        let input = dev().upload(&[1.0, 2.0, 3.0, 4.0]);
        assert!(dev().upsample_nearest2d(&input, 1, 1, 3, 3, 2, 2).is_err());
    }
}