// oximedia-gpu 0.1.8
//
// GPU compute pipeline using WGPU for OxiMedia: cross-platform acceleration.
// GPU block-matching pass for motion estimation using SAD cost metric.
//
// Each workgroup handles ONE block. Within the workgroup the 16×16 = 256
// threads each evaluate one candidate motion vector (centred around the seed
// from the previous pyramid level). A workgroup-shared parallel tree reduction
// finds the minimum SAD candidate without any subgroup operations, so the
// shader compiles on all wgpu backends.
//
// Uniform layout (must match `BlockMatchUniforms` on the Rust side):
//   block_size   (u32) — block edge length in pixels
//   search_half  (u32) — half the search-window edge; the window is fixed at 16×16 by the workgroup size, so this should be 8
//   frame_width  (u32)
//   frame_height (u32)
//   mv_seed_x    (i32) — integer MV seed from the coarser pyramid level
//   mv_seed_y    (i32)
//   blocks_x     (u32) — number of blocks along X
//   blocks_y     (u32)
//
// Output buffer: `vec4<i32>` per block — (dx, dy, sad, _pad).

// Per-dispatch parameters for `block_match`: eight 4-byte scalar fields.
// Field order is the memory layout and must stay in sync with the Rust-side
// `BlockMatchUniforms`.
struct BlockMatchUniforms {
    block_size: u32,   // block edge length in pixels
    search_half: u32,  // half the search-window edge (8 for the 16×16 window)
    frame_width: u32,  // frame width in pixels (also the row stride)
    frame_height: u32, // frame height in pixels
    mv_seed_x: i32,    // integer MV seed (x) from the coarser pyramid level
    mv_seed_y: i32,    // integer MV seed (y) from the coarser pyramid level
    blocks_x: u32,     // number of blocks along X
    blocks_y: u32,     // number of blocks along Y
}

// Bind group 0. Binding indices must agree with the host-side bind group
// layout.

// Dispatch parameters.
@group(0) @binding(0) var<uniform> u: BlockMatchUniforms;
// Reference frame: one u32 sample per pixel, row-major, row stride
// `u.frame_width`.
@group(0) @binding(1) var<storage, read> ref_buf: array<u32>;
// Current frame: same layout as `ref_buf`.
@group(0) @binding(2) var<storage, read> cur_buf: array<u32>;
// One entry per block, row-major over the blocks_x × blocks_y grid:
// (dx, dy, sad, padding).
@group(0) @binding(3) var<storage, read_write> mv_out: array<vec4<i32>>;

// Workgroup-shared scratch for the min-SAD reduction. Slot i belongs to local
// invocation i of the 16×16 = 256-invocation workgroup: its SAD and its
// candidate index.
var<workgroup> sad_cache: array<u32, 256>;
var<workgroup> idx_cache: array<u32, 256>;

// Row-major buffer index of pixel (x, y) after clamping both coordinates into
// the frame. Out-of-frame reads therefore return the nearest edge pixel
// instead of indexing outside the frame.
// NOTE(review): assumes frame_width and frame_height are >= 1. With a zero
// dimension the clamp upper bound becomes -1 and the index wraps.
fn clamped_pixel_index(x: i32, y: i32) -> u32 {
    let cx = u32(clamp(x, 0, i32(u.frame_width)  - 1));
    let cy = u32(clamp(y, 0, i32(u.frame_height) - 1));
    return cy * u.frame_width + cx;
}

// Reads reference-frame pixel (x, y) with edge clamping.
fn load_ref(x: i32, y: i32) -> u32 {
    return ref_buf[clamped_pixel_index(x, y)];
}

// Reads current-frame pixel (x, y) with edge clamping.
fn load_cur(x: i32, y: i32) -> u32 {
    return cur_buf[clamped_pixel_index(x, y)];
}

// L1 distance, in candidate-grid steps, between candidate slot `idx` and the
// slot whose displacement equals the MV seed. Slots are laid out as
// idx = row * 16 + column, matching local_invocation_index for a 16×16
// workgroup. Slot (col, row) is displaced (col - half, row - half) from the
// seed.
fn candidate_seed_distance(idx: u32, half: i32) -> i32 {
    return abs(i32(idx % 16u) - half) + abs(i32(idx / 16u) - half);
}

// Strict ordering used by the min-SAD reduction: true when candidate A is
// preferred over candidate B.
// - Lower SAD wins.
// - On equal SAD, the candidate nearer the seed wins. This keeps flat regions
//   from snapping to the window corner.
// - Any remaining tie goes to the lower slot index, so the result does not
//   depend on the reduction tree's shape.
fn candidate_beats(sad_a: u32, idx_a: u32, sad_b: u32, idx_b: u32, half: i32) -> bool {
    if sad_a != sad_b {
        return sad_a < sad_b;
    }
    let dist_a = candidate_seed_distance(idx_a, half);
    let dist_b = candidate_seed_distance(idx_b, half);
    if dist_a != dist_b {
        return dist_a < dist_b;
    }
    return idx_a < idx_b;
}

// One workgroup per block; each of the 16×16 invocations scores one candidate
// displacement around the seed. Invocation 0 writes (dx, dy, sad, 0) for the
// best candidate to mv_out[wgid.y * blocks_x + wgid.x].
@compute @workgroup_size(16, 16)
fn block_match(
    @builtin(workgroup_id)           wgid: vec3<u32>,
    @builtin(local_invocation_id)    lid:  vec3<u32>,
    @builtin(local_invocation_index) lidx: u32,
) {
    // The dispatch grid may extend past blocks_x × blocks_y (for example when
    // the host rounds it up). Such workgroups have no block to process. The
    // test depends only on workgroup_id and uniforms, so the whole workgroup
    // exits together and no later barrier is reached by only part of it.
    if wgid.x >= u.blocks_x || wgid.y >= u.blocks_y {
        return;
    }

    let half = i32(u.search_half);
    let bs = i32(u.block_size);

    // Block origin in the current frame (pixel coords).
    let bx = i32(wgid.x * u.block_size);
    let by = i32(wgid.y * u.block_size);

    // This invocation's candidate displacement. lid == (half, half) maps
    // exactly onto the seed.
    let dx = i32(lid.x) - half + u.mv_seed_x;
    let dy = i32(lid.y) - half + u.mv_seed_y;

    // SAD between the current block at (bx, by) and the reference block
    // shifted by (dx, dy). Out-of-frame reads are edge-clamped by the loaders.
    var sad: u32 = 0u;
    for (var ry: i32 = 0; ry < bs; ry = ry + 1) {
        for (var rx: i32 = 0; rx < bs; rx = rx + 1) {
            let cur_val = load_cur(bx + rx,      by + ry);
            let ref_val = load_ref(bx + rx + dx, by + ry + dy);
            sad = sad + u32(abs(i32(cur_val) - i32(ref_val)));
        }
    }

    sad_cache[lidx] = sad;
    idx_cache[lidx] = lidx;
    workgroupBarrier();

    // Tree reduction over the 256 slots (stride 128 → 64 → … → 1). Within a
    // step, each slot below `stride` is touched only by its own invocation,
    // and slots at or above `stride` are only read. The barrier at the end of
    // each step orders it against the next.
    for (var stride: u32 = 128u; stride > 0u; stride = stride >> 1u) {
        if lidx < stride {
            let other = lidx + stride;
            if candidate_beats(sad_cache[other], idx_cache[other],
                               sad_cache[lidx],  idx_cache[lidx], half) {
                sad_cache[lidx] = sad_cache[other];
                idx_cache[lidx] = idx_cache[other];
            }
        }
        workgroupBarrier();
    }

    if lidx == 0u {
        let best      = idx_cache[0];
        let best_dx   = i32(best % 16u) - half + u.mv_seed_x;
        let best_dy   = i32(best / 16u) - half + u.mv_seed_y;
        let block_idx = wgid.y * u.blocks_x + wgid.x;
        mv_out[block_idx] = vec4<i32>(best_dx, best_dy, i32(sad_cache[0]), 0);
    }
}