sift-wgpu 0.1.0

High-performance SIFT (Scale-Invariant Feature Transform) implementation in Rust with CPU and WebGPU backends.
Documentation
// extrema_detect.wgsl
// Detect local extrema (maxima/minima) in scale-space DoG pyramid
// Each extremum must be strictly greater than, or strictly less than, all 26 neighbors (3x3x3 cube)

// ===== Bind Groups =====
// @group(0): metadata (read-only)
// Parameters describing the DoG pyramid and the detection thresholds.
// Field order and types define the buffer layout the host must fill in; do not
// reorder without updating the host-side struct.
struct PyramidMetadata {
    // Number of octaves; workgroups whose decoded octave index is >= this exit immediately.
    octaves: u32,          // number of octaves
    // Also the per-octave stride used to index level_offsets / level_widths / level_heights.
    dog_scales: u32,       // DoG layers per octave (scales-1 from Gaussian)
    // Divisor/modulus used to decode workgroup_id.z into (octave, usable_layer).
    usable_scales: u32,    // layers where we can detect extrema (dog_scales-2)
    // NOTE(review): base_width, base_height and sigma are not read by the visible shader
    // code; presumably the base-octave image size and base blur sigma -- confirm against host.
    base_width: u32,
    base_height: u32,
    sigma: f32,
    // Pixels whose absolute DoG value is below this are rejected before the neighbor test.
    contrast_threshold: f32,
    // Curvature ratio r; candidates with trace^2 / det > (r + 1)^2 / r are rejected as edges.
    edge_threshold: f32,
}

// Pyramid description and thresholds (see PyramidMetadata).
@group(0) @binding(0) var<storage, read> pyramid_meta: PyramidMetadata;
// Per-level start offset into `heap`, indexed by octave * dog_scales + layer.
// The offset is added to a u32 word index, so it is expressed in 32-bit words.
@group(0) @binding(1) var<storage, read> level_offsets: array<u32>;
// Per-level image width / height in pixels, indexed the same way as level_offsets.
@group(0) @binding(2) var<storage, read> level_widths: array<u32>;
@group(0) @binding(3) var<storage, read> level_heights: array<u32>;

// @group(1): heap (read-only)
// DoG pixel storage: each u32 word is unpacked as two f16 values (see read_pixel_f16).
@group(1) @binding(0) var<storage, read> heap: array<u32>;

// @group(2): output buffers
// Incremented once per accepted keypoint, including ones that do not fit in the staging
// buffer, so its final value can exceed MAX_STAGING_KEYPOINTS.
// NOTE(review): presumably zeroed by the host before dispatch -- confirm.
@group(2) @binding(0) var<storage, read_write> extrema_counter: atomic<u32>;
// One vec4 per keypoint: (x, y, octave, layer index within the octave).
@group(2) @binding(1) var<storage, read_write> keypoints_staging: array<vec4<f32>>;
// Slots at or beyond this index are never written by the shader.
// NOTE(review): assumes the host allocates at least this many vec4 entries -- confirm.
const MAX_STAGING_KEYPOINTS: u32 = 32768u;

// ===== Push Constants / Specialization =====
// For dispatch: each workgroup's z index selects one (octave_idx, usable_layer) pair,
// decoded in the shader as z / usable_scales and z % usable_scales.
// Dispatch as: (ceil(width/16), ceil(height/16), octaves * usable_scales)

// ===== F16 Unpacking =====
// Fetches one f16 pixel from `heap` and widens it to f32.
//
// Pixels are stored row-major, two per u32 word: the even-indexed pixel sits in the
// low half (unpacked .x) and the odd-indexed pixel in the high half (unpacked .y).
//
//   base_offset  index of the level's first word in `heap`
//   x, y         pixel coordinates within the level
//   width        row length of the level in pixels
fn read_pixel_f16(base_offset: u32, x: u32, y: u32, width: u32) -> f32 {
    let linear = y * width + x;
    let pair = unpack2x16float(heap[base_offset + linear / 2u]);
    if ((linear % 2u) == 1u) {
        return pair.y;
    }
    return pair.x;
}

// ===== Tile-Based Detection =====
// Workgroup: 16×16 threads process 16×16 center pixels
// Shared memory: 18×18 tile for 1-pixel halo around center 16×16
// Load 3 DoG layers (prev, current, next) into shared memory
// Each tile is stored row-major with a row stride of 18 (index = y * 18 + x), so the
// center pixel of thread (lx, ly) lives at (ly + 1) * 18 + (lx + 1).
var<workgroup> tile_prev: array<f32, 324>; // 18×18
var<workgroup> tile_curr: array<f32, 324>;
var<workgroup> tile_next: array<f32, 324>;
// NOTE(review): reset to 0 at the start of detect_extrema but never read in the
// visible source -- possibly reserved for per-tile compaction; confirm or remove.
var<workgroup> tile_count: atomic<u32>;

// Detects scale-space extrema in the middle layer of three adjacent DoG layers.
//
// Work mapping:
//   - workgroup_id.xy selects a 16x16 pixel tile of the octave image.
//   - workgroup_id.z selects the (octave, usable_layer) pair, decoded as
//     z / usable_scales and z % usable_scales.
//
// A pixel is appended to keypoints_staging as (x, y, octave, layer-in-octave) when it
//   1. has |DoG| >= contrast_threshold,
//   2. is strictly greater than, or strictly less than, all 26 neighbors in the
//      3x3x3 cube spanning the previous, current and next DoG layers, and
//   3. passes the Hessian trace^2 / det edge test against edge_threshold.
// extrema_counter is incremented for every accepted pixel, including those that no
// longer fit in the staging buffer.
@compute @workgroup_size(16, 16, 1)
fn detect_extrema(
    @builtin(global_invocation_id) global_id: vec3<u32>,
    @builtin(local_invocation_id) local_id: vec3<u32>,
    @builtin(workgroup_id) workgroup_id: vec3<u32>,
) {
    // workgroup_id.z encodes (octave, usable_layer)
    let octave_idx = workgroup_id.z / pyramid_meta.usable_scales;
    let usable_layer = workgroup_id.z % pyramid_meta.usable_scales;

    // This test depends only on workgroup_id and read-only metadata, so the whole
    // workgroup leaves together, before any barrier is reached.
    if (octave_idx >= pyramid_meta.octaves || usable_layer >= pyramid_meta.usable_scales) {
        return;
    }

    // Only first thread initializes tile counter
    // NOTE(review): tile_count is not read anywhere in this entry point.
    if (local_id.x == 0u && local_id.y == 0u) {
        atomicStore(&tile_count, 0u);
    }
    workgroupBarrier();

    // Get dimensions for this octave (all DoG layers in an octave have same dimensions)
    let dog_scales = pyramid_meta.dog_scales;
    let level_w = level_widths[octave_idx * dog_scales];
    let level_h = level_heights[octave_idx * dog_scales];

    // The three adjacent DoG layers examined by this workgroup. The layer tested for
    // extrema is the MIDDLE one, i.e. layer (usable_layer + 1) of the octave;
    // usable_layer itself is the layer below it.
    // These lookups do not depend on the pixel, so they are done once per thread
    // instead of once per halo-load iteration.
    let dog_prev_idx = octave_idx * dog_scales + usable_layer;
    let offset_prev = level_offsets[dog_prev_idx];
    let offset_curr = level_offsets[dog_prev_idx + 1u];
    let offset_next = level_offsets[dog_prev_idx + 2u];

    // Tile origin in output space (center 16x16 region)
    let tile_origin_x = workgroup_id.x * 16u;
    let tile_origin_y = workgroup_id.y * 16u;

    // Load the 18x18 halo tile for all 3 DoG layers.
    // 324 texels are spread over 256 threads: every thread loads one texel and the
    // first 68 threads load a second one.
    let thread_idx = local_id.y * 16u + local_id.x;
    for (var i = thread_idx; i < 324u; i += 256u) {
        let halo_y = i / 18u;
        let halo_x = i % 18u;

        // Map to global coordinates with halo (-1 to +16)
        let global_x = i32(tile_origin_x) + i32(halo_x) - 1;
        let global_y = i32(tile_origin_y) + i32(halo_y) - 1;

        // Clamp to the valid image range. A pixel on the image border therefore sees
        // a copy of itself among its neighbors, fails the strict comparison below,
        // and is never reported as an extremum.
        let src_x = u32(clamp(global_x, 0, i32(level_w) - 1));
        let src_y = u32(clamp(global_y, 0, i32(level_h) - 1));

        tile_prev[i] = read_pixel_f16(offset_prev, src_x, src_y, level_w);
        tile_curr[i] = read_pixel_f16(offset_curr, src_x, src_y, level_w);
        tile_next[i] = read_pixel_f16(offset_next, src_x, src_y, level_w);
    }
    workgroupBarrier();

    // Each thread checks 1 center pixel (16x16). No barrier follows, so per-thread
    // early returns are safe from here on.
    let center_x = global_id.x;
    let center_y = global_id.y;

    if (center_x >= level_w || center_y >= level_h) {
        return;
    }

    // Map to tile coordinates (+1 offset for halo)
    let tile_x = local_id.x + 1u;
    let tile_y = local_id.y + 1u;
    let tile_idx = tile_y * 18u + tile_x;

    let center_val = tile_curr[tile_idx];

    // Quick contrast check
    if (abs(center_val) < pyramid_meta.contrast_threshold) {
        return;
    }

    // Check if extremum (strict max or strict min among the 26 neighbors)
    var is_max = true;
    var is_min = true;

    for (var dz = -1; dz <= 1; dz++) {
        for (var dy = -1; dy <= 1; dy++) {
            for (var dx = -1; dx <= 1; dx++) {
                if (dx == 0 && dy == 0 && dz == 0) {
                    continue;
                }

                let neighbor_tile_x = i32(tile_x) + dx;
                let neighbor_tile_y = i32(tile_y) + dy;
                let neighbor_tile_idx = u32(neighbor_tile_y * 18 + neighbor_tile_x);

                var neighbor_val: f32;
                if (dz == -1) {
                    neighbor_val = tile_prev[neighbor_tile_idx];
                } else if (dz == 0) {
                    neighbor_val = tile_curr[neighbor_tile_idx];
                } else {
                    neighbor_val = tile_next[neighbor_tile_idx];
                }

                // Ties disqualify both cases: the extremum must be strict.
                if (center_val <= neighbor_val) {
                    is_max = false;
                }
                if (center_val >= neighbor_val) {
                    is_min = false;
                }
            }
        }
    }

    if (!is_max && !is_min) {
        return;
    }

    // Edge suppression via Hessian ratio.
    // Second derivatives are finite differences on the current layer only.
    let idx_m10 = tile_idx - 1u;
    let idx_p10 = tile_idx + 1u;
    let idx_0m1 = tile_idx - 18u;
    let idx_0p1 = tile_idx + 18u;

    let val_m10 = tile_curr[idx_m10];
    let val_p10 = tile_curr[idx_p10];
    let val_0m1 = tile_curr[idx_0m1];
    let val_0p1 = tile_curr[idx_0p1];

    let Dxx = val_p10 + val_m10 - 2.0 * center_val;
    let Dyy = val_0p1 + val_0m1 - 2.0 * center_val;
    let Dxy = (tile_curr[idx_0p1 + 1u] - tile_curr[idx_0p1 - 1u]
              - tile_curr[idx_0m1 + 1u] + tile_curr[idx_0m1 - 1u]) * 0.25;

    let trace = Dxx + Dyy;
    let det = Dxx * Dyy - Dxy * Dxy;

    if (det <= 0.0) {
        return; // not a well-defined extremum
    }

    // Reject when trace^2 / det exceeds (r + 1)^2 / r, with r = edge_threshold.
    let ratio = (trace * trace) / det;
    let r_threshold = (pyramid_meta.edge_threshold + 1.0) * (pyramid_meta.edge_threshold + 1.0) / pyramid_meta.edge_threshold;

    if (ratio > r_threshold) {
        return; // edge-like, reject
    }

    // Passed all tests, add to staging buffer.
    // The counter is bumped even when the buffer is full, so its final value can
    // exceed MAX_STAGING_KEYPOINTS; only in-range slots are written.
    let slot = atomicAdd(&extrema_counter, 1u);
    if (slot < MAX_STAGING_KEYPOINTS) {
        // Store (x, y, octave, scale_within_octave)
        // usable_layer + 1 is the middle layer index we're detecting extrema at
        let scale_in_octave = f32(usable_layer) + 1.0;
        keypoints_staging[slot] = vec4<f32>(
            f32(center_x),
            f32(center_y),
            f32(octave_idx),
            scale_in_octave
        );
    }
}