kangaroo 0.1.0

Pollard's Kangaroo ECDLP solver for secp256k1 using Vulkan/Metal/DX12 compute
Documentation
import csv
import math
import os
import random
import time
from dataclasses import asdict, dataclass, fields

import yaml

# A real benchmark would run on secp256k1 through a native library such as
# 'coincurve'; pure-Python curve arithmetic is too slow for repeated runs.


# To keep runs fast, the simulator replaces the curve with a simple additive
# group of integers modulo a value slightly larger than N (IntegerGroup).
class Group:
    """Minimal interface for the group the kangaroos walk in.

    Subclasses supply the arithmetic. Every method here raises
    ``NotImplementedError``. ``to_int``/``from_int`` convert between group
    elements and plain integers, which the simulator uses for bucket
    selection and distinguished-point tests.
    """

    def add(self, p1, p2):
        """Return the group sum of ``p1`` and ``p2``."""
        raise NotImplementedError

    def sub(self, p1, p2):
        """Return the group difference ``p1 - p2``."""
        raise NotImplementedError

    def mul(self, p, s):
        """Return element ``p`` multiplied by scalar ``s``."""
        raise NotImplementedError

    def to_int(self, p):
        """Return an integer representation of element ``p``."""
        raise NotImplementedError

    def from_int(self, i):
        """Return the group element corresponding to integer ``i``."""
        raise NotImplementedError


class IntegerGroup(Group):
    """Additive group of integers modulo ``modulus``.

    The arithmetic methods return ints reduced into ``[0, modulus)``.
    ``to_int`` is the identity and ``from_int`` reduces its argument.
    ``mul`` multiplies an element by an integer scalar.
    """

    def __init__(self, modulus):
        self.modulus = modulus

    def _reduce(self, value):
        """Reduce ``value`` into the range ``[0, modulus)``."""
        return value % self.modulus

    def add(self, p1, p2):
        return self._reduce(p1 + p2)

    def sub(self, p1, p2):
        return self._reduce(p1 - p2)

    def mul(self, p, s):
        return self._reduce(p * s)

    def to_int(self, p):
        # Elements already are integers.
        return p

    def from_int(self, i):
        return self._reduce(i)


# Optional native secp256k1 backend. The visible code only records whether it
# can be imported; the simulator below always uses the integer group, since
# pure-Python curve arithmetic would be too slow here.
try:
    import coincurve  # noqa: F401  (availability probe)
except ImportError:
    HAS_COINCURVE = False
else:
    HAS_COINCURVE = True


@dataclass
class RunResult:
    """Outcome of a single simulated kangaroo run.

    Field order matches the CSV column order used by the benchmark writer.
    """

    ts: float  # wall-clock start time of the run (time.time())
    strategy: str  # strategy preset name
    N: int  # size of the search range
    t_dp: int  # distinguished-point bit count reported for the run
    m: int  # number of jump-table buckets
    offset_mode: str  # start-offset mode from the strategy config
    seed: int  # per-run RNG seed
    steps: int  # loop iterations executed (each moves both kangaroos)
    time_ms: float  # elapsed time of the run in milliseconds
    dpoints: int  # distinguished points encountered (tame + wild)
    success: bool  # True if the secret key was recovered
    reason: str  # outcome tag, e.g. "timeout" or the collision kind


class KangarooSimulator:
    """Simulate one tame/wild Pollard's Kangaroo run for a benchmark strategy.

    The walk takes place in ``IntegerGroup(N + 37)`` with base element 1
    instead of on secp256k1. This keeps each run cheap enough to repeat many
    times while still exercising the jump table, distinguished-point (DP)
    and collision logic.

    Args:
        strategy_config: Strategy preset. Keys read: ``buckets``,
            ``step_distribution``, ``offset_mode``, ``dp_bits`` (only for
            ranges of at least ``SMALL_RANGE_LIMIT``) and, optionally,
            ``name`` and ``grid_delta_divisor``.
        range_config: Range preset. Key read: ``N``.
        run_seed: Seed for this run's private random number generator.

    Raises:
        ValueError: if ``buckets`` is smaller than 1.
    """

    # Below this range size, production dp_bits would rarely (or never) hit
    # a DP, so a fixed threshold of ~1 DP every 32 steps is used instead.
    SMALL_RANGE_LIMIT = 10_000_000
    SMALL_RANGE_DP_BITS = 5

    def __init__(self, strategy_config: dict, range_config: dict, run_seed: int):
        self.config = strategy_config
        self.range_config = range_config
        self.seed = run_seed
        # Private RNG: reproducible per run without reseeding the global
        # ``random`` module state. ``random.Random(seed)`` yields the same
        # sequence the previous ``random.seed(seed)`` call produced.
        self.rng = random.Random(run_seed)

        self.N = range_config["N"]
        self.buckets = strategy_config["buckets"]
        if self.buckets < 1:
            raise ValueError(f"buckets must be >= 1, got {self.buckets}")

        # Effective DP threshold. It is stored so run() can report the value
        # actually used rather than the configured one.
        if self.N < self.SMALL_RANGE_LIMIT:
            self.dp_bits = self.SMALL_RANGE_DP_BITS
        else:
            self.dp_bits = strategy_config["dp_bits"]
        self.dp_mask = (1 << self.dp_bits) - 1

        # Every generator setting ("simulated" or otherwise) currently uses
        # the integer group, because no elliptic-curve backend is wired in.
        # The +37 offset keeps the modulus off a power of two when N is one.
        self.modulus = self.N + 37
        self.group = IntegerGroup(self.modulus)
        self.base = 1
        self.target_key = self.rng.randint(0, self.N)  # secret key in [0, N]
        self.target_pub = self.group.mul(self.base, self.target_key)

        self.setup_steps()

    @staticmethod
    def _fnv1a_step(index, target_mean):
        """Map a bucket index to a step in ``[1, 2 * target_mean]`` via one FNV-1a round."""
        h = ((0x811C9DC5 ^ index) * 0x01000193) & 0xFFFFFFFF
        return (h % (target_mean * 2)) + 1

    def setup_steps(self):
        """Build ``self.step_table``, one jump size per bucket.

        ``fnv1a``/``rotating``, ``large_drift`` and ``heavy_tail`` are scaled
        around ``target_mean = max(1, int(sqrt(N) / 2))``. ``power_of_2`` and
        the fallback for unknown names use powers of two.
        """
        dist = self.config["step_distribution"]
        m = self.buckets
        target_mean = max(1, int(math.sqrt(self.N) * 0.5))

        if dist == "power_of_2":
            # Powers 2^0 .. 2^max_pow. For N >= 64, 2^max_pow <= N/4.
            max_pow = max(4, int(math.log2(self.N)) - 2)
            # Use the modulus max_pow + 1 so that 2^max_pow itself is included.
            self.step_table = [1 << (i % (max_pow + 1)) for i in range(m)]
        elif dist in ("fnv1a", "rotating"):
            # "rotating" is initialised exactly like "fnv1a". run() does not
            # rotate the table.
            self.step_table = [self._fnv1a_step(i, target_mean) for i in range(m)]
        elif dist == "large_drift":
            self.step_table = [int(target_mean * 1.5)] * m
        elif dist == "heavy_tail":
            # The first ~80% of buckets get small steps; the rest get large ones.
            small = max(1, int(target_mean * 0.5))
            large = int(target_mean * 3)
            self.step_table = [small if i < m * 0.8 else large for i in range(m)]
        else:
            self.step_table = [1 << (i % 10) for i in range(m)]

        # Debug output for the first run only. Guard against an empty table.
        if self.seed == 0 and self.step_table:
            avg = sum(self.step_table) / len(self.step_table)
            print(
                f"DEBUG [{self.config.get('name', 'custom')}] Avg Step: {avg:.1f} (Target: {target_mean})"
            )

    def get_start_offset(self, kangaroo_idx, total_kangaroos):
        """Return a starting offset for kangaroo ``kangaroo_idx``.

        Modes: ``grid`` (evenly spaced, wrapped into ``[0, N)``), ``centered``
        (``N // 2`` plus or minus up to ``N // 10``), and ``random`` or any
        other value (uniform in ``[0, N]``). ``total_kangaroos`` is accepted
        for interface compatibility but is not used.
        """
        mode = self.config["offset_mode"]
        if mode == "grid":
            divisor = self.config.get("grid_delta_divisor", 64)
            # Clamp to at least 1 so small ranges don't put every kangaroo at 0.
            delta = max(1, int(math.sqrt(self.N) / divisor))
            return (kangaroo_idx * delta) % self.N
        if mode == "centered":
            mid = self.N // 2
            spread = self.N // 10
            return mid + self.rng.randint(-spread, spread)
        return self.rng.randint(0, self.N)

    def _check_dp(self, dp_map, pos, herd, dist):
        """Record a DP for ``herd``, or resolve a tame/wild collision there.

        Only the first visit to a DP is stored. A later visit by the same
        herd is ignored.

        Returns:
            The recovered key when the DP was recorded by the other herd and
            the candidate reproduces ``target_pub``; otherwise None.
        """
        key = self.group.to_int(pos)
        entry = dp_map.get(key)
        if entry is None:
            dp_map[key] = (herd, dist)
            return None
        other_herd, other_dist = entry
        if other_herd == herd:
            return None
        # At a shared point: tame_dist == x + wild_dist (mod modulus).
        if herd == "tame":
            tame_dist, wild_dist = dist, other_dist
        else:
            tame_dist, wild_dist = other_dist, dist
        found_x = (tame_dist - wild_dist) % self.modulus
        # Verify against the public target rather than the secret key.
        if self.group.mul(self.base, found_x) == self.target_pub:
            return found_x
        return None

    def run(self, max_steps=1000000) -> "RunResult":
        """Walk one tame and one wild kangaroo until a DP collision or ``max_steps``.

        The tame kangaroo starts at the known exponent ``N // 2``. The wild
        kangaroo starts at ``target_pub``, whose exponent ``x`` is unknown.
        Each iteration moves both kangaroos by the jump that
        ``step_table[position % buckets]`` selects. Positions whose low
        ``dp_bits`` bits are zero are recorded as distinguished points. When
        one herd lands on a DP recorded by the other, ``x`` is recovered as
        ``tame_dist - wild_dist (mod modulus)``.

        Args:
            max_steps: Maximum iterations (each iteration moves both kangaroos).

        Returns:
            RunResult. ``t_dp`` is the effective DP bit count, and ``reason``
            is ``"timeout"`` when no collision was found.
        """
        start_ts = time.time()
        start_clock = time.perf_counter()  # monotonic clock, used for duration

        group = self.group
        step_table = self.step_table
        buckets = self.buckets
        dp_mask = self.dp_mask

        tame_start = self.N // 2  # known exponent for the tame kangaroo
        tame_pos = group.mul(self.base, tame_start)
        tame_dist = tame_start
        wild_pos = self.target_pub
        wild_dist = 0  # distance travelled from the unknown x

        dp_map = {}  # DP value -> (herd, distance at that DP)

        steps = 0
        dpoints = 0
        success = False
        reason = "timeout"

        while steps < max_steps:
            steps += 1

            # Tame jump
            t_step = step_table[group.to_int(tame_pos) % buckets]
            tame_pos = group.add(tame_pos, group.from_int(t_step))
            tame_dist += t_step

            if (group.to_int(tame_pos) & dp_mask) == 0:
                dpoints += 1
                if self._check_dp(dp_map, tame_pos, "tame", tame_dist) is not None:
                    success = True
                    reason = "collision_tame_hits_wild"
                    break

            # Wild jump
            w_step = step_table[group.to_int(wild_pos) % buckets]
            wild_pos = group.add(wild_pos, group.from_int(w_step))
            wild_dist += w_step

            if (group.to_int(wild_pos) & dp_mask) == 0:
                dpoints += 1
                if self._check_dp(dp_map, wild_pos, "wild", wild_dist) is not None:
                    success = True
                    reason = "collision_wild_hits_tame"
                    break

        elapsed_ms = (time.perf_counter() - start_clock) * 1000

        return RunResult(
            ts=start_ts,
            strategy=self.config.get("name", "custom"),
            N=self.N,
            t_dp=self.dp_bits,
            m=self.buckets,
            offset_mode=self.config["offset_mode"],
            seed=self.seed,
            steps=steps,
            time_ms=elapsed_ms,
            dpoints=dpoints,
            success=success,
            reason=reason,
        )


def load_config(base_dir=None):
    """Load the strategy presets and range definitions.

    Args:
        base_dir: Directory containing ``presets.yml`` and ``ranges.yml``.
            Defaults to the directory of this file.

    Returns:
        A ``(presets, ranges)`` tuple as parsed by ``yaml.safe_load``.
    """
    if base_dir is None:
        base_dir = os.path.dirname(os.path.abspath(__file__))

    def _read_yaml(name):
        # Explicit encoding: the platform default is locale-dependent.
        # safe_load only constructs plain data, never arbitrary Python objects.
        with open(os.path.join(base_dir, name), encoding="utf-8") as f:
            return yaml.safe_load(f)

    return _read_yaml("presets.yml"), _read_yaml("ranges.yml")


def main(runs_per_strategy=20):
    """Benchmark every strategy preset and append one CSV row per run.

    Results are appended to ``benchmark_results.csv`` next to this file. The
    header row is written only when that file is missing or empty. The range
    named ``test`` is used when present; otherwise the first configured
    range is used.

    Args:
        runs_per_strategy: Number of seeded runs per strategy (seeds
            ``0 .. runs_per_strategy - 1``).
    """
    presets, ranges = load_config()

    base_dir = os.path.dirname(os.path.abspath(__file__))
    output_file = os.path.join(base_dir, "benchmark_results.csv")
    # An existing but empty file (e.g. left by an interrupted run) still
    # needs a header.
    needs_header = not os.path.isfile(output_file) or os.path.getsize(output_file) == 0

    # Column names come straight from RunResult, so the CSV cannot drift
    # out of sync with the result fields.
    fieldnames = [f.name for f in fields(RunResult)]

    with open(output_file, "a", newline="", encoding="utf-8") as csvfile:
        writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
        if needs_header:
            writer.writeheader()

        print(f"Starting benchmark... Output to {output_file}")

        # Prefer the quick 'test' range; otherwise fall back to the first range.
        range_conf = next(
            (r for r in ranges["ranges"] if r["name"] == "test"), ranges["ranges"][0]
        )
        print(f"Selected range: {range_conf['name']} (N={range_conf['N']})")

        # Budget of 4 * sqrt(N) iterations per run.
        limit_steps = int(4 * math.sqrt(range_conf["N"]))

        for strategy_name, base_conf in presets["strategies"].items():
            # Copy rather than mutate the loaded presets.
            strategy_conf = {**base_conf, "name": strategy_name}
            print(f"Running strategy {strategy_name} on {range_conf['name']} range...")

            for i in range(runs_per_strategy):
                sim = KangarooSimulator(strategy_conf, range_conf, run_seed=i)
                result = sim.run(max_steps=limit_steps)

                writer.writerow(asdict(result))
                csvfile.flush()  # keep partial results if the benchmark is interrupted

                status = "SUCCESS" if result.success else "FAIL"
                print(
                    f"  Run {i}: {status} in {result.steps} steps ({result.time_ms:.2f}ms)"
                )


if __name__ == "__main__":
    # Run the benchmark only when executed as a script, not when imported.
    main()