kangaroo 0.1.0

Pollard's Kangaroo ECDLP solver for secp256k1 using Vulkan/Metal/DX12 compute
Documentation
// =============================================================================
// Pollard's Kangaroo Algorithm - GPU Kernel (Jacobian Coordinates)
// =============================================================================
// Uses Jacobian coordinates for fast point addition without modular inverse
// Optimized with Workgroup Batch Inversion (Montgomery)

// -----------------------------------------------------------------------------
// Configuration
// -----------------------------------------------------------------------------

struct Config {
    dp_mask_lo: vec4<u32>,
    dp_mask_hi: vec4<u32>,
    num_kangaroos: u32,
    steps_per_call: u32,
    jump_table_size: u32,
    _padding: u32
}

// Kangaroo state - uses Jacobian coordinates
// Must match Rust GpuKangaroo struct layout!
struct Kangaroo {
    x: array<u32, 8>,           // Jacobian X (32 bytes)
    y: array<u32, 8>,           // Jacobian Y (32 bytes)
    z: array<u32, 8>,           // Jacobian Z (32 bytes)
    dist: array<u32, 8>,        // Distance traveled (32 bytes)
    ktype: u32,                 // 0 = Tame, 1 = Wild
    is_active: u32,
    _padding: array<u32, 2>     // Align to 16 bytes
}

// AffinePoint is defined in curve_jacobian.wgsl

struct DistinguishedPoint {
    x: array<u32, 8>,           // Jacobian X
    z: array<u32, 8>,           // Jacobian Z (for affine X = X/Z²)
    dist: array<u32, 8>,
    ktype: u32,
    kangaroo_id: u32,
}

// -----------------------------------------------------------------------------
// Buffers
// -----------------------------------------------------------------------------

@group(0) @binding(0) var<uniform> config: Config;
@group(0) @binding(1) var<storage, read> jump_points: array<AffinePoint, 256>;
@group(0) @binding(2) var<storage, read> jump_distances: array<array<u32, 8>, 256>;
@group(0) @binding(3) var<storage, read_write> kangaroos: array<Kangaroo>;
@group(0) @binding(4) var<storage, read_write> dp_buffer: array<DistinguishedPoint>;
@group(0) @binding(5) var<storage, read_write> dp_count: atomic<u32>;

// Shared memory for batch inversion
var<workgroup> shared_z: array<array<u32, 8>, 64>;
var<workgroup> shared_prod: array<array<u32, 8>, 64>;

// -----------------------------------------------------------------------------
// Store distinguished point (convert Jacobian X to affine approximation)
// Real implementation would compute proper affine conversion
// -----------------------------------------------------------------------------

fn store_dp(k: Kangaroo, kangaroo_id: u32) {
    let idx = atomicAdd(&dp_count, 1u);

    if (idx < 65536u) {
        var dp: DistinguishedPoint;
        // Store Jacobian X and Z (CPU will verify using X1*Z2² == X2*Z1²)
        dp.x = k.x;
        dp.z = k.z;
        dp.dist = k.dist;
        dp.ktype = k.ktype;
        dp.kangaroo_id = kangaroo_id;
        dp_buffer[idx] = dp;
    }
}

// -----------------------------------------------------------------------------
// Main compute shader
// -----------------------------------------------------------------------------

@compute @workgroup_size(64)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>, @builtin(local_invocation_id) local_id_vec: vec3<u32>) {
    let kid = global_id.x;
    let lid = local_id_vec.x;

    // Load kangaroo state (if valid)
    var k: Kangaroo;
    var valid = false;
    if (kid < config.num_kangaroos) {
        k = kangaroos[kid];
        if (k.is_active != 0u) {
            valid = true;
        }
    }

    // Build Jacobian point from state
    var p: JacobianPoint;
    if (valid) {
        p.x = k.x;
        p.y = k.y;
        p.z = k.z;
    } else {
        // Initialize dummy point for inactive threads to avoid messing up batch inversion
        p.x = fe_one();
        p.y = fe_one();
        p.z = fe_one(); // Z=1 ensures safe inversion and no-op
    }

    // Track if we already stored a DP this batch (to avoid flooding)
    var dp_stored = false;

    // Perform jumps
    for (var step = 0u; step < config.steps_per_call; step++) {
        // =====================================================================
        // BATCH NORMALIZATION (Simultaneous Inversion)
        // =====================================================================
        
        // 1. Load Z into shared memory
        shared_z[lid] = p.z;
        workgroupBarrier();

        // 2. Thread 0 computes prefix products and total inverse
        // Using serial implementation for simplicity and robustness
        if (lid == 0u) {
            var prod = fe_one();
            
            // Compute prefix products: shared_prod[i] = z_0 * ... * z_i
            for (var i = 0u; i < 64u; i++) {
                prod = fe_mul(prod, shared_z[i]);
                shared_prod[i] = prod;
            }

            // Invert total product
            var inv = fe_inv(prod);

            // Compute inverses backwards
            // inv_z[i] = inv_total * L[i-1] * R[i+1] (implicitly handled)
            // Algorithm:
            // inv_acc = inv_all
            // for i = N-1 to 1:
            //   inv_z[i] = inv_acc * L[i-1]
            //   inv_acc = inv_acc * z[i]
            // inv_z[0] = inv_acc
            
            var inv_acc = inv;
            for (var i = 63u; i > 0u; i--) {
                let prev_prod = shared_prod[i - 1u];
                let val_inv = fe_mul(inv_acc, prev_prod);
                let val_z = shared_z[i];
                
                // Store result back in shared_z to save space
                // We are done with z[i] after this step
                // CAREFUL: We need shared_z[i] for the update of inv_acc
                
                // We need a temporary for z[i] because we are overwriting shared_z[i]
                // Actually we can just write to shared_prod? No, shared_prod used L[i-1].
                // We can overwrite shared_prod[i] with the result inverse!
                shared_prod[i] = val_inv;
                
                inv_acc = fe_mul(inv_acc, val_z);
            }
            shared_prod[0] = inv_acc;
        }
        workgroupBarrier();

        // 3. Read inverse Z from shared memory
        let z_inv = shared_prod[lid];
        let z_inv2 = fe_square(z_inv);
        
        // 4. Normalize X (and Y if needed, but we only need X for jump)
        // Only normalize if not already normalized
        // For batch inversion, we always run it, but if Z was 1, z_inv is 1.
        
        var x_affine: array<u32, 8>;
        
        // Check if Z was 1 (optimization)
        // Actually, we can just do the mul.
        x_affine = fe_mul(p.x, z_inv2);
        
        // Update P to affine (Optional? We need it for DP check and Jump)
        // Yes, we must update P to affine to ensure paths merge!
        // Otherwise we have (X_af, Y_jac, 1) vs (X_jac, Y_jac, Z).
        // We need to reset Z to 1.
        
        let z_inv3 = fe_mul(z_inv2, z_inv);
        p.x = x_affine;
        p.y = fe_mul(p.y, z_inv3);
        p.z = fe_one(); // Z is now 1

        // =====================================================================
        // LOGIC
        // =====================================================================

        if (valid) {
            // Check for DP
            if (!dp_stored) {
                // Check first limb against mask
                if ((p.x[0] & config.dp_mask_lo.x) == 0u) {
                    k.x = p.x;
                    k.y = p.y;
                    k.z = p.z;
                    store_dp(k, kid);
                    dp_stored = true;
                }
            }

            // Select jump
            let jump_idx = p.x[0] & 0xFFu;
            let jump_point = jump_points[jump_idx];
            let jump_dist = jump_distances[jump_idx];

            // Add jump point (affine)
            p = jac_add_affine(p, jump_point.x, jump_point.y);

            // Update distance
            k.dist = scalar_add_256(k.dist, jump_dist);
        }
    }

    // Write back updated state
    if (valid) {
        k.x = p.x;
        k.y = p.y;
        k.z = p.z;
        kangaroos[kid] = k;
    }
}