from sympy import *
def to_limbs(value, num_limbs=4, limb_size=64):
    """Split an integer into fixed-width limbs, least-significant limb first.

    Args:
        value: Integer to decompose.
        num_limbs: Number of limbs to produce.
        limb_size: Width of each limb in bits.

    Returns:
        A list of ``num_limbs`` integers where entry ``i`` holds bits
        ``[i * limb_size, (i + 1) * limb_size)`` of ``value``. Bits at or
        above ``num_limbs * limb_size`` are silently dropped.
    """
    limb_mask = (1 << limb_size) - 1
    return [
        (value >> (index * limb_size)) & limb_mask
        for index in range(num_limbs)
    ]
def from_limbs(limbs, limb_size=64):
    """Recombine little-endian limbs into a single integer.

    Args:
        limbs: Iterable of integers, least-significant limb first.
        limb_size: Width of each limb in bits.

    Returns:
        The bitwise OR of every limb shifted left by its position times
        ``limb_size``; ``0`` for an empty iterable.
    """
    combined = 0
    offset = 0
    for limb in limbs:
        # OR (not add) on purpose: a limb wider than limb_size overlaps its
        # neighbour instead of carrying into it.
        combined |= limb << offset
        offset += limb_size
    return combined
def debug_rust_montgomery():
    """Print a step-by-step Montgomery round trip for the value 10000.

    The value is multiplied by the constant stored in
    ``R_SQUARED_MOD_L_LIMBS`` and reduced once to obtain its Montgomery
    form, which is printed next to ``(value * R) % L``. That form is then
    reduced a second time with a zero high half, and the result is compared
    with the original value. Everything is written to stdout; nothing is
    returned.
    """
    L = 2**252 + 27742317777372353535851937790883648493
    R = 2**256
    # NOTE(review): presumably R^2 mod L and -L^-1 mod R respectively; the
    # constants are taken as given and not re-derived here.
    R_SQUARED_MOD_L_LIMBS = [
        0xa40611e3449c0f01,
        0xd00e1ba768859347,
        0xceec73d217f5be65,
        0x0399411b7c309a3d,
    ]
    N_PRIME_L_LIMBS = [
        0xd2b51da312547e1b,
        0xb1a206f2fdba84ff,
        0x14e75438ffa36bea,
        0x9db6c6f26fe91836,
    ]
    low_mask = (1 << 256) - 1

    print("=== Rust-style Montgomery Reduction Debug ===")
    print(f"L limbs: {to_limbs(L)}")
    print(f"R_SQUARED_MOD_L limbs: {R_SQUARED_MOD_L_LIMBS}")
    print(f"N_PRIME_L limbs: {N_PRIME_L_LIMBS}")

    test_value = 10000
    print(f"\nTesting with value = {test_value}")
    print(f"Value limbs: {to_limbs(test_value)}")

    r_squared = from_limbs(R_SQUARED_MOD_L_LIMBS)
    n_prime = from_limbs(N_PRIME_L_LIMBS)
    print(f"value = {test_value}")
    print(f"r_squared = {r_squared}")
    print(f"n = {L}")
    print(f"n_prime = {n_prime}")

    # Split the double-width product into the two 256-bit halves REDC takes.
    product = test_value * r_squared
    product_low_limbs = to_limbs(product & low_mask)
    product_high_limbs = to_limbs(product >> 256)
    print(f"product = value * r_squared = {product}")
    print(f"product limbs (low 256): {product_low_limbs}")
    print(f"product limbs (high 256): {product_high_limbs}")

    modulus_limbs = to_limbs(L)
    n_prime_limbs = to_limbs(n_prime)

    # Into Montgomery form: REDC(value * r_squared).
    mont_limbs = rust_montgomery_reduce_with_high(
        product_low_limbs, product_high_limbs, modulus_limbs, n_prime_limbs
    )
    print(f"Montgomery value: {from_limbs(mont_limbs)}")
    print(f"Expected: {(test_value * R) % L}")

    # Out of Montgomery form: REDC of the Montgomery value with a zero high half.
    recovered_limbs = rust_montgomery_reduce_with_high(
        mont_limbs, [0, 0, 0, 0], modulus_limbs, n_prime_limbs
    )
    recovered_value = from_limbs(recovered_limbs)
    print(f"Recovered value: {recovered_value}")
    print(f"Expected: {test_value}")
    print(f"Success: {recovered_value == test_value}")
def rust_montgomery_reduce_with_high(t_low_limbs, t_high_limbs, n_limbs, n_prime_limbs):
    """Montgomery-reduce a double-width value given as two limb vectors, tracing each step.

    With ``R = 2**256``, ``T = t_high * R + t_low`` and
    ``m = (t_low * n_prime) mod R``, this computes ``(T + m * n) // R`` and
    subtracts ``n`` once if the quotient is not below ``n``. Every
    intermediate value is printed to stdout.

    The sum ``T + m * n`` is formed the limb-wise way (low halves, then high
    halves, then carries). The high half is tracked as an unbounded integer
    so that a carry out of bit 255 of either addition keeps its real weight
    and ``total_sum`` always equals the ``T + m*n`` printed earlier.

    Args:
        t_low_limbs: Low 256 bits of ``T`` as four 64-bit limbs,
            least-significant first.
        t_high_limbs: High 256 bits of ``T`` in the same layout.
        n_limbs: The modulus ``n`` as limbs.
        n_prime_limbs: The Montgomery constant as limbs. NOTE(review):
            assumed to be ``-n**-1 mod R``; this is not checked here.

    Returns:
        The result as four 64-bit limbs, least-significant first. Only one
        conditional subtraction is performed, so the result is fully
        reduced only when ``T < n * R``; for larger inputs any bits above
        256 are dropped by the final 4-limb conversion.
    """
    print(f"\n--- Montgomery REDC ---")
    print(f"t_low: {t_low_limbs}")
    print(f"t_high: {t_high_limbs}")
    print(f"n: {n_limbs}")
    print(f"n_prime: {n_prime_limbs}")

    t_low = from_limbs(t_low_limbs)
    t_high = from_limbs(t_high_limbs)
    n = from_limbs(n_limbs)
    n_prime = from_limbs(n_prime_limbs)
    print(f"t_low value: {t_low}")
    print(f"t_high value: {t_high}")
    print(f"n value: {n}")
    print(f"n_prime value: {n_prime}")

    R = 2**256

    def limbs_view(value):
        # Show a fifth limb only when the value no longer fits in 256 bits,
        # so traces for in-range inputs keep their four-limb shape.
        return to_limbs(value, num_limbs=5 if value >= R else 4)

    m = (t_low * n_prime) % R
    print(f"m = (t_low * n_prime) % R = {m}")
    print(f"m limbs: {to_limbs(m)}")

    mn = m * n
    mn_high, mn_low = divmod(mn, R)
    print(f"m * n = {mn}")
    print(f"mn_low: {to_limbs(mn_low)}")
    print(f"mn_high: {to_limbs(mn_high)}")

    T = t_high * R + t_low
    print(f"T = t_high * 2^256 + t_low = {T}")
    print(f"T + m*n = {T + mn}")

    sum_low, carry1 = add_with_carry(t_low_limbs, to_limbs(mn_low))
    print(f"sum_low = t_low + mn_low: {sum_low}, carry1: {carry1}")
    sum_high, carry2 = add_with_carry(t_high_limbs, to_limbs(mn_high))
    print(f"sum_high = t_high + mn_high: {sum_high}, carry2: {carry2}")

    # Keep the high half as a plain integer from here on: squeezing it back
    # into four limbs would discard a carry out of bit 255.
    sum_high_value = from_limbs(sum_high)
    if carry1:
        # Carry out of the low half has weight R, i.e. 1 in the high half.
        sum_high_value += 1
        print(f"sum_high + carry1: {limbs_view(sum_high_value)}")
    if carry2:
        # Carry out of the high half has weight R**2, i.e. R in the high half.
        sum_high_value += R
        print(f"sum_high + carry2: {limbs_view(sum_high_value)}")

    total_sum = sum_high_value * R + from_limbs(sum_low)
    result_before_shift = total_sum // R
    print(f"total_sum = {total_sum}")
    print(f"result_before_shift = total_sum / R = {result_before_shift}")
    print(f"result_before_shift limbs: {limbs_view(result_before_shift)}")

    needs_subtraction = result_before_shift >= n
    print(f"needs_subtraction = {result_before_shift} >= {n} = {needs_subtraction}")
    if needs_subtraction:
        final_result = result_before_shift - n
        print(f"final_result = {result_before_shift} - {n} = {final_result}")
    else:
        final_result = result_before_shift
        print(f"final_result = {result_before_shift}")

    return to_limbs(final_result)
def add_with_carry(a_limbs, b_limbs, limb_size=64):
    """Add two little-endian limb vectors, rippling the carry limb by limb.

    The result is as wide as the longer operand; the shorter operand is
    treated as zero-extended. For two 4-limb operands with the default
    ``limb_size`` this is the classic 256-bit add-with-carry.

    Args:
        a_limbs: First operand, least-significant limb first.
        b_limbs: Second operand, same layout.
        limb_size: Width of each limb in bits.

    Returns:
        A tuple ``(result_limbs, overflow)``: the sum truncated to the
        operand width as a list of limbs, and a bool that is ``True`` when
        a carry left the most-significant limb.
    """
    width = max(len(a_limbs), len(b_limbs))
    limb_mask = (1 << limb_size) - 1
    result_limbs = []
    carry = 0
    for index in range(width):
        a_limb = a_limbs[index] if index < len(a_limbs) else 0
        b_limb = b_limbs[index] if index < len(b_limbs) else 0
        limb_sum = a_limb + b_limb + carry
        result_limbs.append(limb_sum & limb_mask)
        carry = limb_sum >> limb_size
    return (result_limbs, carry != 0)
# Run the Montgomery round-trip trace only when executed as a script,
# not when this module is imported.
if __name__ == "__main__":
    debug_rust_montgomery()