clock-curve-math 0.5.0

High-performance, constant-time, cryptography-grade number theory library for the ClockCurve ecosystem
Documentation
#!/usr/bin/env python3
"""
Debug script that mimics the exact Rust Montgomery reduction implementation.
"""

from sympy import *

def to_limbs(value, num_limbs=4, limb_size=64):
    """Split ``value`` into ``num_limbs`` little-endian limbs of ``limb_size`` bits.

    Index 0 holds the least significant limb, matching the Rust BigInt layout.
    Bits at or above ``num_limbs * limb_size`` are discarded.
    """
    mask = (1 << limb_size) - 1
    return [(value >> (limb_size * idx)) & mask for idx in range(num_limbs)]

def from_limbs(limbs, limb_size=64):
    """Reassemble little-endian ``limbs`` into a single integer.

    Limb ``i`` is shifted left by ``i * limb_size`` bits and OR-ed into the
    result. Because OR (not addition) is used, a limb that is wider than
    ``limb_size`` overlaps its neighbour rather than carrying into it.
    """
    result = 0
    shift = 0
    for limb in limbs:
        result |= limb << shift
        shift += limb_size
    return result

def debug_rust_montgomery():
    """Trace a to_montgomery / from_montgomery round trip for one sample value.

    Step 1 multiplies the sample by the precomputed R^2 mod L constant and
    applies REDC; the result is printed next to the expected (value * R) % L.
    Step 2 applies REDC again to that result with a zero high half and checks
    that the original sample value comes back.
    """
    modulus = 2**252 + 27742317777372353535851937790883648493  # scalar modulus L
    radix = 2**256  # Montgomery radix R

    # Precomputed constants, stored as little-endian 64-bit limbs.
    r_squared_limbs = [0xa40611e3449c0f01, 0xd00e1ba768859347, 0xceec73d217f5be65, 0x0399411b7c309a3d]
    n_prime_limbs = [0xd2b51da312547e1b, 0xb1a206f2fdba84ff, 0x14e75438ffa36bea, 0x9db6c6f26fe91836]

    print("=== Rust-style Montgomery Reduction Debug ===")
    print(f"L limbs: {to_limbs(modulus)}")
    print(f"R_SQUARED_MOD_L limbs: {r_squared_limbs}")
    print(f"N_PRIME_L limbs: {n_prime_limbs}")

    # The value that was observed to fail.
    sample = 10000
    print(f"\nTesting with value = {sample}")
    print(f"Value limbs: {to_limbs(sample)}")

    r_squared = from_limbs(r_squared_limbs)
    n_prime = from_limbs(n_prime_limbs)

    print(f"value = {sample}")
    print(f"r_squared = {r_squared}")
    print(f"n = {modulus}")
    print(f"n_prime = {n_prime}")

    def redc(low_limbs, high_limbs):
        # REDC of (high * 2^256 + low) against the fixed modulus / n_prime.
        return rust_montgomery_reduce_with_high(
            low_limbs, high_limbs, to_limbs(modulus), to_limbs(n_prime)
        )

    # Step 1 (to_montgomery): REDC(value * R^2) should equal value * R mod L.
    product = sample * r_squared  # up to 512 bits
    product_low = product & ((1 << 256) - 1)
    product_high = product >> 256
    print(f"product = value * r_squared = {product}")
    print(f"product limbs (low 256): {to_limbs(product_low)}")
    print(f"product limbs (high 256): {to_limbs(product_high)}")

    mont_limbs = redc(to_limbs(product_low), to_limbs(product_high))
    print(f"Montgomery value: {from_limbs(mont_limbs)}")
    print(f"Expected: {(sample * radix) % modulus}")

    # Step 2 (from_montgomery): REDC(mont_value) should give back the sample.
    recovered = from_limbs(redc(mont_limbs, [0, 0, 0, 0]))
    print(f"Recovered value: {recovered}")
    print(f"Expected: {sample}")
    print(f"Success: {recovered == sample}")

def rust_montgomery_reduce_with_high(t_low_limbs, t_high_limbs, n_limbs, n_prime_limbs):
    """
    Mimic the Rust montgomery_reduce_with_high implementation (Montgomery REDC).

    Computes (T + m*N) / R with R = 2^256, T = t_high * 2^256 + t_low and
    m = (t_low * n_prime) mod R, then subtracts N once if the quotient is
    >= N. Every intermediate value is printed so the trace can be diffed
    against the Rust implementation.

    The result equals T * R^(-1) mod N only if n_prime == -N^(-1) mod R and
    T < N * R. Neither precondition is enforced. A non-zero low half of
    T + m*N (printed below) indicates that n_prime is wrong.

    Args:
        t_low_limbs: Low 256 bits of T as four little-endian 64-bit limbs.
        t_high_limbs: High 256 bits of T as four little-endian 64-bit limbs.
        n_limbs: Modulus N as four little-endian 64-bit limbs.
        n_prime_limbs: Montgomery constant as four little-endian 64-bit limbs.

    Returns:
        The reduced value as four little-endian 64-bit limbs. ``to_limbs``
        keeps only the low 256 bits.
    """
    print(f"\n--- Montgomery REDC ---")
    print(f"t_low: {t_low_limbs}")
    print(f"t_high: {t_high_limbs}")
    print(f"n: {n_limbs}")
    print(f"n_prime: {n_prime_limbs}")

    # Convert to integers so the reference arithmetic below is easy to follow.
    t_low = from_limbs(t_low_limbs)
    t_high = from_limbs(t_high_limbs)
    n = from_limbs(n_limbs)
    n_prime = from_limbs(n_prime_limbs)

    print(f"t_low value: {t_low}")
    print(f"t_high value: {t_high}")
    print(f"n value: {n}")
    print(f"n_prime value: {n_prime}")

    R = 2**256

    # m = (t_low * n_prime) mod R
    m = (t_low * n_prime) % R
    print(f"m = (t_low * n_prime) % R = {m}")
    print(f"m limbs: {to_limbs(m)}")

    # m * N fits in 512 bits; split it into 256-bit halves.
    mn = m * n
    mn_low = mn % R
    mn_high = mn // R
    print(f"m * n = {mn}")
    print(f"mn_low: {to_limbs(mn_low)}")
    print(f"mn_high: {to_limbs(mn_high)}")

    # T = t_high * 2^256 + t_low
    T = t_high * R + t_low
    print(f"T = t_high * 2^256 + t_low = {T}")
    print(f"T + m*n = {T + mn}")

    # Limb-level addition as in Rust: low halves first, then high halves.
    sum_low, carry1 = add_with_carry(t_low_limbs, to_limbs(mn_low))
    print(f"sum_low = t_low + mn_low: {sum_low}, carry1: {carry1}")

    sum_high, carry2 = add_with_carry(t_high_limbs, to_limbs(mn_high))
    print(f"sum_high = t_high + mn_high: {sum_high}, carry2: {carry2}")

    # carry1 has weight 2^256, i.e. +1 on the high half. Propagate it with
    # add_with_carry so that an overflow of the high half is reported as a
    # carry instead of being silently truncated by to_limbs.
    carry3 = False
    if carry1:
        sum_high, carry3 = add_with_carry(sum_high, [1, 0, 0, 0])
        print(f"sum_high + carry1: {sum_high}")

    # A carry out of the high half has weight 2^512 in T + m*N. After the
    # division by R that is 2^256, i.e. bit 256 of the quotient, not +1 on
    # sum_high. At most one of carry2/carry3 can be set, because
    # (t_high + mn_high) - 2^256 + 1 < 2^256.
    top_carry = int(carry2 or carry3)
    if top_carry:
        print(f"carry out of high half (bit 256 of the quotient): {top_carry}")

    # With a correct n_prime, T + m*N is divisible by R, so the low half is 0.
    print(f"sum_low is zero (T + m*N divisible by R): {from_limbs(sum_low) == 0}")

    # Now result = (T + m*N) / R, keeping the 257th bit if present.
    total_sum = (top_carry << 512) + from_limbs(sum_high) * R + from_limbs(sum_low)
    result_before_shift = total_sum // R
    print(f"total_sum = {total_sum}")
    print(f"result_before_shift = total_sum / R = {result_before_shift}")
    print(f"result_before_shift limbs: {to_limbs(result_before_shift)}")

    # Conditional subtraction: if T < N*R the quotient is < 2N, so one
    # subtraction brings it into [0, N).
    needs_subtraction = result_before_shift >= n
    print(f"needs_subtraction = {result_before_shift} >= {n} = {needs_subtraction}")

    if needs_subtraction:
        final_result = result_before_shift - n
        print(f"final_result = {result_before_shift} - {n} = {final_result}")
    else:
        final_result = result_before_shift
        print(f"final_result = {result_before_shift}")

    return to_limbs(final_result)

def add_with_carry(a_limbs, b_limbs):
    """Add two four-limb little-endian numbers, mimicking the Rust add_with_carry.

    Only the first four limbs of each input are read. Each output limb is the
    low 64 bits of ``a[i] + b[i] + carry_in``; everything above bit 63 is
    carried into the next limb.

    Returns:
        A tuple ``(limbs, overflow)``: the four result limbs, and True when a
        carry leaves the most significant limb.
    """
    mask = (1 << 64) - 1
    out = []
    carry_in = 0
    for idx in range(4):
        total = a_limbs[idx] + b_limbs[idx] + carry_in
        out.append(total & mask)
        carry_in = total >> 64
    return out, carry_in != 0

if __name__ == "__main__":
    # Only run the (print-heavy) debug trace when executed as a script.
    debug_rust_montgomery()