kangaroo 0.1.0

Pollard's Kangaroo ECDLP solver for secp256k1 using Vulkan/Metal/DX12 compute
Documentation
#!/usr/bin/env python3
"""Analyze public key patterns for private keys 0-10000."""

from collections import Counter
from dataclasses import dataclass

# secp256k1 parameters
# P: prime modulus of the coordinate field; point_add reduces every
#    coordinate and slope computation mod P.
P = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEFFFFFC2F
# N: order of the base point G (group order). Not referenced elsewhere in
#    this script.
N = 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFEBAAEDCE6AF48A03BBFD25E8CD0364141
# (Gx, Gy): affine coordinates of the base point G used by get_pubkey and
#           generate_keys.
Gx = 0x79BE667EF9DCBBAC55A06295CE870B07029BFCDB2DCE28D959F2815B16F81798
Gy = 0x483ADA7726A3C4655DA4FBFC0E1108A8FD17B448A68554199C47D08FFB10D4B8


def mod_inverse(a: int, m: int) -> int:
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    A negative ``a`` is first reduced with ``a % m``. The inverse is the
    Bezout coefficient of ``a`` from ``extended_gcd``, normalised mod ``m``.

    Raises:
        ValueError: if ``gcd(a, m) != 1`` so no inverse exists.
    """
    reduced = a % m if a < 0 else a
    divisor, coeff, _ = extended_gcd(reduced, m)
    if divisor == 1:
        return coeff % m
    raise ValueError("No inverse")


def extended_gcd(a: int, b: int):
    """Return ``(g, x, y)`` where ``g = gcd(a, b)`` and ``a*x + b*y == g``.

    Iterative extended Euclidean algorithm. It walks the same quotient
    sequence as the textbook recursive version and yields the same Bezout
    coefficients. When ``a == 0`` the result is ``(b, 0, 1)``.
    """
    # Each remainder r is carried with (s, t) such that r == s*a + t*b.
    # Seed with r = b -> (0, 1) and r = a -> (1, 0).
    prev_r, prev_s, prev_t = b, 0, 1
    cur_r, cur_s, cur_t = a, 1, 0
    while cur_r != 0:
        q = prev_r // cur_r
        prev_r, cur_r = cur_r, prev_r - q * cur_r
        prev_s, cur_s = cur_s, prev_s - q * cur_s
        prev_t, cur_t = cur_t, prev_t - q * cur_t
    return prev_r, prev_s, prev_t


def point_add(p1, p2):
    """Return ``p1 + p2`` on secp256k1 in affine coordinates.

    Points are ``(x, y)`` tuples and ``None`` is the point at infinity (the
    identity). Equal points are doubled. Points sharing an x but with
    different y are treated as negatives of each other, and the sum is
    ``None``.
    """
    # The identity leaves the other operand unchanged.
    if p1 is None:
        return p2
    if p2 is None:
        return p1

    (ax, ay), (bx, by) = p1, p2

    if ax != bx:
        # Chord through two distinct points.
        slope = (by - ay) * mod_inverse(bx - ax, P) % P
    elif ay != by:
        return None  # Point at infinity
    else:
        # Tangent (doubling): slope = 3x^2 / 2y, the a == 0 curve form.
        slope = 3 * ax * ax * mod_inverse(2 * ay, P) % P

    rx = (slope * slope - ax - bx) % P
    ry = (slope * (ax - rx) - ay) % P
    return (rx, ry)


def scalar_mult(k: int, point):
    """Multiply ``point`` by the integer scalar ``k``.

    Uses right-to-left double-and-add. It walks the bits of ``k`` from least
    to most significant, adding the running power-of-two multiple of
    ``point`` whenever the bit is set.

    Args:
        k: Scalar multiplier. Negative values return ``-(|k| * point)``,
            i.e. the same x with y negated mod P.
        point: Affine ``(x, y)`` tuple, or ``None`` for the point at infinity.

    Returns:
        The resulting ``(x, y)`` tuple, or ``None`` for the point at infinity
        (always the case for ``k == 0``).
    """
    if k == 0:
        return None
    if k < 0:
        # For negative k, ``k >>= 1`` stalls at -1 and the loop below would
        # never terminate. Compute |k| * point and negate it instead.
        positive = scalar_mult(-k, point)
        if positive is None:
            return None
        x, y = positive
        return (x, (-y) % P)
    result = None
    addend = point
    while k:
        if k & 1:
            result = point_add(result, addend)
        addend = point_add(addend, addend)
        k >>= 1
    return result


def get_pubkey(privkey: int) -> tuple:
    """Return the public key point ``privkey * G``.

    Private key 0 maps to ``None`` (the point at infinity), despite the
    ``tuple`` annotation. Any other value is delegated to ``scalar_mult``
    with the base point ``(Gx, Gy)``.
    """
    return None if privkey == 0 else scalar_mult(privkey, (Gx, Gy))


def pubkey_to_compressed(point) -> str:
    """Encode a point as a compressed public key in hex.

    The output is ``02`` for even Y or ``03`` for odd Y, followed by X
    zero-padded to 64 lowercase hex digits. ``None`` (point at infinity) is
    encoded as 33 zero bytes.
    """
    if point is None:
        return "00" * 33
    x, y = point
    tag = "03" if y % 2 else "02"
    return f"{tag}{x:064x}"


def pubkey_to_uncompressed(point) -> str:
    """Encode a point as an uncompressed public key in hex.

    The output is ``04`` followed by X and Y, each zero-padded to 64
    lowercase hex digits. ``None`` (point at infinity) becomes ``04`` plus
    64 zero bytes.
    """
    if point is not None:
        x, y = point
        return f"04{x:064x}{y:064x}"
    return "04" + "00" * 64


@dataclass
class KeyData:
    """One generated key: the private scalar, its public point, and hex encodings."""

    privkey: int  # private key scalar
    x: int  # X of privkey * G; 0 in the placeholder generate_keys uses for privkey 0
    y: int  # Y of privkey * G; 0 in the placeholder generate_keys uses for privkey 0
    compressed: str  # pubkey_to_compressed() hex; "00" * 33 in the privkey-0 placeholder
    uncompressed: str  # pubkey_to_uncompressed() hex; "00" * 65 in the privkey-0 placeholder


def generate_keys(start: int, end: int) -> list[KeyData]:
    """Generate a KeyData record for every private key in ``[start, end]``.

    Only the first point needs a full scalar multiplication. Every later
    point is obtained with a single addition of G.

    Args:
        start: First private key (inclusive). Must be non-negative.
        end: Last private key (inclusive). If ``end < start`` the result is
            an empty list.

    Returns:
        One KeyData per private key, in increasing order. Private key 0 has
        no public point and is stored as a zero-filled placeholder
        ``KeyData(0, 0, 0, "00" * 33, "00" * 65)``.

    Raises:
        ValueError: if ``start`` is negative.

    Side effect: prints a progress line whenever the key index is divisible
    by 1000.
    """
    if start < 0:
        raise ValueError(f"start must be non-negative, got {start}")

    G = (Gx, Gy)
    keys = []

    # Invariant: at the top of each iteration current_point == k * G, where
    # None is the point at infinity (0 * G). The point is advanced only AFTER
    # the record for k is stored; otherwise every key from a start >= 2 would
    # be paired with the point of the next key.
    current_point = scalar_mult(start, G) if start > 0 else None

    for k in range(start, end + 1):
        if current_point is None:
            keys.append(KeyData(k, 0, 0, "00" * 33, "00" * 65))
        else:
            x, y = current_point
            keys.append(
                KeyData(
                    k,
                    x,
                    y,
                    pubkey_to_compressed(current_point),
                    pubkey_to_uncompressed(current_point),
                )
            )

        if k % 1000 == 0:
            print(f"Generated {k} keys...")

        current_point = point_add(current_point, G)

    return keys


def analyze_prefix_patterns(keys: list[KeyData]):
    """Print statistics about the prefixes of the compressed public keys.

    Reports three things:
    - the 02/03 (even/odd Y) split;
    - the ten most common first bytes of X;
    - how many distinct prefixes occur for the tag plus 1, 2 and 3 bytes of X.

    Entries with ``privkey == 0`` (the zero-filled placeholder for private
    key 0, which has no public point) are excluded. Every percentage is
    taken over the remaining keys, so the counted set and the denominator
    always match.
    """
    print("\n" + "=" * 60)
    print("PREFIX ANALYSIS (Compressed Keys)")
    print("=" * 60)

    # Filter the placeholder by value rather than assuming it is keys[0];
    # that only holds when the key range starts at 0.
    points = [k for k in keys if k.privkey != 0]
    total = len(points)
    if total == 0:
        print("\nNo public keys to analyze.")
        return

    # Count 02 vs 03 prefixes
    prefix_02 = sum(1 for k in points if k.compressed.startswith("02"))
    prefix_03 = sum(1 for k in points if k.compressed.startswith("03"))
    print("\nPrefix distribution:")
    print(f"  02 (even Y): {prefix_02} ({100 * prefix_02 / total:.1f}%)")
    print(f"  03 (odd Y):  {prefix_03} ({100 * prefix_03 / total:.1f}%)")

    # First byte of X: hex chars 2-3, right after the 02/03 tag
    first_byte = Counter(k.compressed[2:4] for k in points)

    print("\nFirst byte after prefix (top 10):")
    for byte, count in first_byte.most_common(10):
        print(f"  0x{byte}: {count} ({100 * count / total:.2f}%)")

    # Number of distinct prefixes (tag + 1, 2 or 3 bytes of X)
    prefix_4 = {k.compressed[:4] for k in points}
    prefix_6 = {k.compressed[:6] for k in points}
    prefix_8 = {k.compressed[:8] for k in points}

    print("\nUnique prefixes:")
    print(f"  4 chars (prefix+1byte): {len(prefix_4)} unique")
    print(f"  6 chars (prefix+2bytes): {len(prefix_6)} unique")
    print(f"  8 chars (prefix+3bytes): {len(prefix_8)} unique")


def analyze_x_coordinate(keys: list[KeyData]):
    """Print statistics about the X coordinates of the public keys.

    Three reports:
    - a histogram of the top byte of X (``X >> 248``);
    - the signed differences between the X values of adjacent entries among
      the first eleven;
    - every X value shared by more than one entry, with the first five shown.

    The histogram and collision scan skip ``keys[0]``. The difference sample
    skips any pair in which either X is 0.
    """
    bar = "=" * 60
    print("\n" + bar)
    print("X COORDINATE ANALYSIS")
    print(bar)

    # keys[0] is skipped -- presumably the privkey-0 placeholder (x == 0);
    # confirm if callers can pass a range not starting at 0.
    without_first = keys[1:]

    top_byte_hist = Counter(entry.x >> 248 for entry in without_first)
    print("\nTop byte of X distribution (showing non-zero):")
    for top in sorted(top_byte_hist):
        print(f"  0x{top:02x}: {top_byte_hist[top]}")

    print("\nDifferences between consecutive X values (sample):")
    for idx, (earlier, later) in enumerate(zip(keys[:10], keys[1:11]), start=1):
        # A zero X (the placeholder entry) suppresses the pair.
        if not (later.x and earlier.x):
            continue
        print(f"  key[{idx}] - key[{idx - 1}] = {later.x - earlier.x:+d}")

    x_tally = Counter(entry.x for entry in without_first)
    repeated = [(xv, n) for xv, n in x_tally.items() if n > 1]
    print(f"\nX coordinate collisions: {len(repeated)}")
    for xv, n in repeated[:5]:
        owners = [entry.privkey for entry in keys if entry.x == xv]
        print(f"  X=0x{xv:064x}... appears {n} times: keys {owners}")


def analyze_y_parity(keys: list[KeyData]):
    """Print statistics about the parity (even/odd) of the Y coordinates.

    Reports three things:
    - run lengths of equal consecutive parities;
    - the first 100 parities;
    - for each lag in a fixed list, the fraction of positions whose parity
      equals the parity ``period`` positions later.

    A lag is printed only when that fraction deviates from 0.5 by more than
    0.1. Lags with no comparable pairs (fewer keys than the lag) are skipped.
    Entries with ``privkey == 0`` (the zero-filled placeholder for private
    key 0) are excluded.
    """
    from itertools import groupby

    print("\n" + "=" * 60)
    print("Y PARITY ANALYSIS")
    print("=" * 60)

    parities = [k.y % 2 for k in keys if k.privkey != 0]
    if not parities:
        print("No public keys to analyze.")
        return

    # groupby yields one group per maximal run of identical parities.
    runs = [sum(1 for _ in group) for _, group in groupby(parities)]

    print("Parity run statistics:")
    print(f"  Max run length: {max(runs)}")
    print(f"  Average run length: {sum(runs) / len(runs):.2f}")
    print(f"  Total runs: {len(runs)}")

    # Look for patterns in parity sequence
    print("\nFirst 100 parities (0=even, 1=odd):")
    parity_str = "".join(str(p) for p in parities[:100])
    print(f"  {parity_str}")

    # Check for period
    print("\nChecking for periodicity...")
    for period in [2, 3, 4, 5, 6, 7, 8, 16, 32, 64, 128, 256]:
        compared = len(parities) - period
        if compared <= 0:
            # No pairs this far apart; the ratio would otherwise be 0/0
            # (ZeroDivisionError) or 0/negative (a bogus "-0.000" line).
            continue
        # zip stops at the shorter input: pairs parities[i] with parities[i + period].
        matches = sum(a == b for a, b in zip(parities, parities[period:]))
        ratio = matches / compared
        if abs(ratio - 0.5) > 0.1:  # Significant deviation from random
            print(f"  Period {period}: {ratio:.3f} match ratio (0.5 = random)")


def analyze_bit_distribution(keys: list[KeyData]):
    """Print the X-coordinate bit positions whose set-frequency is far from 50%.

    For each of the 256 bit positions (bit 0 = least significant), compute
    the fraction of X values that have the bit set. Positions deviating from
    0.5 by more than 0.05 are listed, largest deviation first, at most 20 of
    them. Entries with ``privkey == 0`` (the zero-filled placeholder for
    private key 0) are excluded.
    """
    print("\n" + "=" * 60)
    print("BIT DISTRIBUTION ANALYSIS")
    print("=" * 60)

    # Filter the placeholder by value; ``len(keys) - 1`` would divide by zero
    # for <= 1 key and miscount when the range does not start at 0.
    xs = [k.x for k in keys if k.privkey != 0]
    total = len(xs)
    if total == 0:
        print("\nNo public keys to analyze.")
        return

    # bit_counts[b] = how many X values have bit b set
    bit_counts = [sum((x >> bit) & 1 for x in xs) for bit in range(256)]

    print("\nBit frequencies (showing deviations from 50%):")
    significant = []
    for bit, count in enumerate(bit_counts):
        freq = count / total
        if abs(freq - 0.5) > 0.05:  # More than 5% deviation
            significant.append((bit, freq))

    if significant:
        # Stable sort: ties keep ascending bit order.
        significant.sort(key=lambda item: -abs(item[1] - 0.5))
        for bit, freq in significant[:20]:
            print(f"  Bit {bit:3d}: {freq:.3f} ({(freq - 0.5) * 100:+.1f}%)")
    else:
        print("  No significant deviations found (all bits within 5% of 50%)")


def analyze_modular_patterns(keys: list[KeyData]):
    """Print tables of X mod small primes whose residues look non-uniform.

    For each prime ``p``, X mod p is tallied over every residue ``0..p-1``,
    including residues that never occur. The table is printed when some
    residue count deviates from the uniform expectation ``n / p`` by more
    than 20%. Entries with ``privkey == 0`` (the zero-filled placeholder for
    private key 0) are excluded.
    """
    print("\n" + "=" * 60)
    print("MODULAR PATTERNS")
    print("=" * 60)

    xs = [k.x for k in keys if k.privkey != 0]
    if not xs:
        print("\nNo public keys to analyze.")
        return

    # Check X mod small primes
    small_primes = [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31]

    for p in small_primes:
        residues = Counter(x % p for x in xs)
        # True (fractional) uniform expectation over the counted sample.
        expected = len(xs) / p
        # Iterate all residue classes: Counter has no entry for a class that
        # never occurred, and a missing class is the largest deviation.
        max_dev = max(abs(residues[r] - expected) for r in range(p))
        if max_dev > expected * 0.2:  # 20% deviation
            print(f"\n  X mod {p} distribution (non-uniform!):")
            for r in range(p):
                print(f"    {r}: {residues[r]}")


def analyze_consecutive_relationships(keys: list[KeyData]):
    """Print relationships between public keys of adjacent entries.

    Tallies the top byte of ``x_i XOR x_{i+1}`` over adjacent entries and
    prints the ten most common values. It then prints a sample
    ``privkey → first 10 hex chars of the compressed key`` mapping for the
    first 19 entries. Entries with ``privkey == 0`` (the zero-filled
    placeholder for private key 0) are excluded.
    """
    print("\n" + "=" * 60)
    print("CONSECUTIVE KEY RELATIONSHIPS")
    print("=" * 60)

    # For consecutive private keys k and k+1:
    # P(k+1) = P(k) + G
    # So X(k+1) is related to X(k) through point addition

    points = [k for k in keys if k.privkey != 0]

    # Top byte of the XOR of each adjacent pair of X coordinates
    xor_high_byte = Counter((a.x ^ b.x) >> 248 for a, b in zip(points, points[1:]))

    print("\nXOR of consecutive X coordinates (top byte):")
    for byte, count in xor_high_byte.most_common(10):
        print(f"  0x{byte:02x}: {count}")

    # Check if any private key patterns map to public key patterns
    print("\nPrivate key → Public key prefix mapping (sample):")
    for k in points[:19]:
        # The separator keeps the right-aligned privkey from running into
        # the hex digits (e.g. "    10279be..." for privkey 1).
        print(f"  {k.privkey:5d} → {k.compressed[:10]}...")


def main():
    """Generate keys 0..10000, print a short sample, then run every analysis."""
    separator = "=" * 60

    for line in (
        separator,
        "BITCOIN PUBLIC KEY PATTERN ANALYSIS",
        "Private keys: 0 to 10000",
        separator,
    ):
        print(line)

    print("\nGenerating keys...")
    keys = generate_keys(0, 10000)

    print(f"\nGenerated {len(keys)} keys")
    print("\nSample keys:")

    def show(entries):
        # Private key and the first 20 hex chars of its compressed form.
        for entry in entries:
            print(f"  {entry.privkey}: {entry.compressed[:20]}...")

    show(keys[:5])
    print("  ...")
    show(keys[-3:])

    # Run analyses
    for analysis in (
        analyze_prefix_patterns,
        analyze_x_coordinate,
        analyze_y_parity,
        analyze_bit_distribution,
        analyze_modular_patterns,
        analyze_consecutive_relationships,
    ):
        analysis(keys)

    print("\n" + separator)
    print("ANALYSIS COMPLETE")
    print(separator)


# Run the analysis only when executed as a script, not when imported.
if __name__ == "__main__":
    main()