// oximedia-gpu 0.1.8
//
// GPU compute pipeline using WGPU for OxiMedia - cross-platform acceleration.
// Documentation:
// GPU sub-pixel refinement pass for motion estimation.
//
// Reads integer motion vectors from `mv_in` and performs a 3-point parabola
// fit at ±0.5 px to produce a refined floating-point MV in `mv_out`.
// One thread per block.

// Uniform parameters for the sub-pixel refinement pass.
struct SubpixUniforms {
    // Frame width; also the row stride used when indexing `ref_buf` / `cur_buf`
    // (element index = y * frame_width + x).
    frame_width:  u32,
    // Frame height; sample rows are clamped to [0, frame_height - 1].
    frame_height: u32,
    // Side length of a block; SAD is accumulated over block_size x block_size
    // samples.
    block_size:   u32,
    // Number of blocks to refine; invocations with gid.x >= num_blocks exit
    // without writing.
    num_blocks:   u32,
}

// Pass parameters (frame dimensions, block size, block count).
@group(0) @binding(0) var<uniform>              u:       SubpixUniforms;
// Reference frame: one u32 element per sample, indexed as y * frame_width + x.
@group(0) @binding(1) var<storage, read>        ref_buf: array<u32>;
// Current frame, indexed the same way as `ref_buf`.
@group(0) @binding(2) var<storage, read>        cur_buf: array<u32>;
// Input motion vectors, one vec4 per block; only .x and .y are read by this pass.
@group(0) @binding(3) var<storage, read>        mv_in:   array<vec4<i32>>;
// Output refined motion vectors (x, y), one vec2 written per block.
@group(0) @binding(4) var<storage, read_write>  mv_out:  array<vec2<f32>>;

// Fetches the reference-frame sample at integer position (xi, yi) as f32.
// Positions outside the frame are clamped to the nearest edge sample.
fn load_ref_f(xi: i32, yi: i32) -> f32 {
    let last_col = i32(u.frame_width)  - 1;
    let last_row = i32(u.frame_height) - 1;
    // Integer clamp spelled out as min(max(e, low), high).
    let col = u32(min(max(xi, 0), last_col));
    let row = u32(min(max(yi, 0), last_row));
    let idx = row * u.frame_width + col;
    return f32(ref_buf[idx]);
}

// Fetches the current-frame sample at integer position (xi, yi) as f32.
// Positions outside the frame are clamped to the nearest edge sample.
fn load_cur_f(xi: i32, yi: i32) -> f32 {
    let upper = vec2<i32>(i32(u.frame_width) - 1, i32(u.frame_height) - 1);
    // Component-wise clamp of both coordinates at once.
    let pos = vec2<u32>(clamp(vec2<i32>(xi, yi), vec2<i32>(0), upper));
    return f32(cur_buf[pos.y * u.frame_width + pos.x]);
}

// Bilinearly samples `ref_buf` at the floating-point position (xf, yf).
//
// Off-frame positions use clamp-to-edge addressing: the coordinate itself is
// clamped to [0, frame_width - 1] x [0, frame_height - 1] *before* it is split
// into an integer tap and a fractional weight. Clamping only the integer taps
// while taking the weight from the unclamped coordinate would blend in the
// neighbouring interior sample for off-frame positions (e.g. xf = -0.5 would
// return the average of columns 0 and 1 instead of column 0).
//
// Returns the interpolated sample value as f32.
fn sample_bilinear_ref(xf: f32, yf: f32) -> f32 {
    let max_x = max(i32(u.frame_width)  - 1, 0);
    let max_y = max(i32(u.frame_height) - 1, 0);

    // Clamp the sampling position first so taps and weights stay consistent.
    let xc = clamp(xf, 0.0, f32(max_x));
    let yc = clamp(yf, 0.0, f32(max_y));
    let x_floor = floor(xc);
    let y_floor = floor(yc);

    let x0 = i32(x_floor);
    let y0 = i32(y_floor);
    // The second tap collapses onto the first at the right/bottom edge.
    let x1 = min(x0 + 1, max_x);
    let y1 = min(y0 + 1, max_y);

    let fx = xc - x_floor;
    let fy = yc - y_floor;

    let v00 = f32(ref_buf[u32(y0) * u.frame_width + u32(x0)]);
    let v10 = f32(ref_buf[u32(y0) * u.frame_width + u32(x1)]);
    let v01 = f32(ref_buf[u32(y1) * u.frame_width + u32(x0)]);
    let v11 = f32(ref_buf[u32(y1) * u.frame_width + u32(x1)]);
    // Interpolate horizontally on both rows, then vertically between them.
    return mix(mix(v00, v10, fx), mix(v01, v11, fx), fy);
}

// Sum of absolute differences (SAD) for one block.
//
// Compares the block of the current frame whose top-left sample is (bx, by)
// against the reference frame displaced by the floating-point vector (dx, dy).
// Reference samples are fetched with bilinear interpolation. The block is
// block_size x block_size samples and is visited row by row.
fn compute_sad_subpix(bx: i32, by: i32, dx: f32, dy: f32) -> f32 {
    let side = i32(u.block_size);
    var sad: f32 = 0.0;
    for (var row = 0; row < side; row++) {
        // Values that are constant across a row.
        let py    = by + row;
        let ref_y = f32(py) + dy;
        for (var col = 0; col < side; col++) {
            let px         = bx + col;
            let cur_sample = load_cur_f(px, py);
            let ref_sample = sample_bilinear_ref(f32(px) + dx, ref_y);
            sad += abs(cur_sample - ref_sample);
        }
    }
    return sad;
}

// Sub-pixel offset (relative to the centre sample) that minimises the SAD along
// one axis, given SAD values sampled at -0.5, 0 and +0.5 px.
//
//   s_neg: SAD at centre - 0.5
//   s_mid: SAD at centre
//   s_pos: SAD at centre + 0.5
//
// If an outer sample is the strict minimum the result snaps to that sample
// (-0.5 or +0.5). Otherwise the vertex of the parabola through the three
// points is returned. For sample spacing h the vertex lies at
//   h * (s_neg - s_pos) / (2 * (s_neg - 2*s_mid + s_pos)),
// which for h = 0.5 gives the 0.25 factor below. Returns 0 when the curvature
// is too small to divide by safely.
fn subpix_parabola_offset(s_neg: f32, s_mid: f32, s_pos: f32) -> f32 {
    if s_neg < s_mid && s_neg < s_pos {
        return -0.5;
    }
    if s_pos < s_mid {
        return 0.5;
    }
    let curvature = s_neg - 2.0 * s_mid + s_pos;
    if abs(curvature) > 1e-6 {
        return 0.25 * (s_neg - s_pos) / curvature;
    }
    return 0.0;
}

// Entry point: refines one block's integer motion vector to sub-pixel precision.
//
// Each invocation handles the block with index gid.x, reads its integer motion
// vector from `mv_in`, refines X first and then Y (using the refined X), and
// writes the result to `mv_out[gid.x]`. Invocations with gid.x >= num_blocks
// exit without writing.
@compute @workgroup_size(64)
fn subpixel_refine(@builtin(global_invocation_id) gid: vec3<u32>) {
    if gid.x >= u.num_blocks { return; }

    let block_idx = i32(gid.x);
    let mv        = mv_in[block_idx];

    // Blocks are laid out row-major; at least one block per row is assumed so
    // the modulo/divide below never use a zero divisor.
    let blocks_x = max(i32(u.frame_width)  / i32(u.block_size), 1);
    let bx        = (block_idx % blocks_x) * i32(u.block_size);
    let by        = (block_idx / blocks_x) * i32(u.block_size);

    let cx = f32(mv.x);
    let cy = f32(mv.y);

    // 3-point fit along X at the integer Y displacement.
    let s0x = compute_sad_subpix(bx, by, cx - 0.5, cy);
    let s1x = compute_sad_subpix(bx, by, cx,       cy);
    let s2x = compute_sad_subpix(bx, by, cx + 0.5, cy);
    let best_dx = cx + subpix_parabola_offset(s0x, s1x, s2x);

    // 3-point fit along Y, evaluated at the refined X displacement.
    let s0y = compute_sad_subpix(bx, by, best_dx, cy - 0.5);
    let s1y = compute_sad_subpix(bx, by, best_dx, cy);
    let s2y = compute_sad_subpix(bx, by, best_dx, cy + 0.5);
    let best_dy = cy + subpix_parabola_offset(s0y, s1y, s2y);

    mv_out[gid.x] = vec2<f32>(best_dx, best_dy);
}