import argparse
import sys
from typing import List, Tuple
# Built-in field definitions selectable with ``--field``.
# Each entry supplies:
#   'prime'     - the field modulus,
#   'bits'      - its declared bit length (drives the limb count),
#   'module'    - Rust module name written into the generated macro call,
#   'classname' - Rust type name written into the generated macro call.
KNOWN_PRIMES = {
    'fp256': {
        'prime': 65000549695646603732796438742359905742825358107623003571877145026864184071783,
        'bits': 256,
        'module': 'fp_256',
        'classname': 'Fp256',
    },
    'fp480': {
        'prime': 3121577065842246806003085452055281276803074876175537384188619957989004527066410274868798956582915008874704066849018213144375771284425395508176023,
        'bits': 480,
        'module': 'fp_480',
        'classname': 'Fp480',
    },
}
def to_limbs(n: int, limb_size: int, num_limbs: int) -> List[int]:
    """Split ``n`` into ``num_limbs`` limbs of ``limb_size`` bits each.

    Limbs are returned least significant first. Any bits of ``n`` above
    ``limb_size * num_limbs`` are discarded.
    """
    limb_mask = (1 << limb_size) - 1
    return [(n >> (limb_size * idx)) & limb_mask for idx in range(num_limbs)]
def from_limbs(limbs: List[int], limb_size: int) -> int:
    """Reassemble an integer from limbs given least significant first.

    Each limb is OR-ed into place at ``limb_size`` bits above the previous one.
    This is the counterpart of ``to_limbs``.
    """
    value = 0
    shift = 0
    for limb in limbs:
        value |= limb << shift
        shift += limb_size
    return value
def mod_inverse(a: int, m: int) -> int:
    """Return the multiplicative inverse of ``a`` modulo ``m``.

    Uses the iterative extended Euclidean algorithm.

    Args:
        a: Value to invert. It may be negative; it is reduced modulo ``m`` first.
        m: Positive modulus.

    Returns:
        The unique ``x`` in ``[0, m)`` with ``a * x % m == 1 % m``.

    Raises:
        ValueError: if ``m`` is not positive, or if ``gcd(a, m) != 1`` (no
            inverse exists). Previously such input silently produced a
            meaningless result.
    """
    if m <= 0:
        raise ValueError(f"modulus must be positive, got {m}")
    # Invariant: old_s * a ≡ old_r (mod m) and s * a ≡ r (mod m).
    old_r, r = a % m, m
    old_s, s = 1, 0
    while r:
        q = old_r // r
        old_r, r = r, old_r - q * r
        old_s, s = s, old_s - q * s
    # old_r is now gcd(a, m); an inverse exists only when it is 1.
    if old_r != 1:
        raise ValueError(f"{a} has no inverse modulo {m}")
    return old_s % m
def compute_num_limbs(bits: int, limb_size: int) -> int:
    """Return how many ``limb_size``-bit limbs are needed to hold ``bits`` bits.

    This is ceiling division of ``bits`` by ``limb_size``.
    """
    return (bits - 1) // limb_size + 1
def compute_constants(prime: int, limb_size: int, bits: int) -> Tuple[List[int], List[int], List[int], List[int], int, int]:
    """Compute the limb-encoded constants for Montgomery arithmetic modulo ``prime``.

    Args:
        prime: Field modulus. It must be odd, so that it is invertible modulo
            the limb base ``2**limb_size``, and it must fit in ``bits`` bits.
        limb_size: Bits per limb.
        bits: Declared bit length of ``prime``. It determines the limb count.

    Returns:
        ``(prime_limbs, reduction_limbs, montgomery_one_limbs,
        montgomery_r2_limbs, montm0inv, num_limbs)``. Every limb list is
        least significant limb first and has ``num_limbs`` entries.

    Raises:
        ValueError: if ``limb_size`` or ``bits`` is not positive, if ``prime``
            is not an odd integer greater than 1, or if ``prime`` needs more
            than ``bits`` bits. In that last case its limbs would otherwise
            be silently truncated.
    """
    if limb_size <= 0:
        raise ValueError(f"limb_size must be positive, got {limb_size}")
    if bits <= 0:
        raise ValueError(f"bits must be positive, got {bits}")
    if prime <= 1 or prime % 2 == 0:
        raise ValueError(f"prime must be an odd integer greater than 1, got {prime}")
    if prime.bit_length() > bits:
        raise ValueError(
            f"prime needs {prime.bit_length()} bits but only {bits} were declared; "
            "its limbs would be truncated"
        )

    num_limbs = compute_num_limbs(bits, limb_size)
    prime_limbs = to_limbs(prime, limb_size, num_limbs)

    # -p^-1 mod 2^w: only the lowest limb of p matters modulo the limb base.
    montm0inv = mod_inverse(-prime_limbs[0], 1 << limb_size)

    # Montgomery radix R = 2^(w*N).
    R = 1 << (limb_size * num_limbs)
    montgomery_one_limbs = to_limbs(R % prime, limb_size, num_limbs)
    montgomery_r2_limbs = to_limbs((R * R) % prime, limb_size, num_limbs)

    # 2^(w*(2N-1)) mod p, emitted as the (Barrett) reduction constant for
    # double-width values.
    reduction_const = pow(2, limb_size * (2 * num_limbs - 1), prime)
    reduction_limbs = to_limbs(reduction_const, limb_size, num_limbs)

    return prime_limbs, reduction_limbs, montgomery_one_limbs, montgomery_r2_limbs, montm0inv, num_limbs
def format_limb_array(limbs: List[int], indent: int = 8, max_width: int = 100) -> str:
    """Render limbs as comma-separated decimal text wrapped over several lines.

    Every limb except the last gets a trailing comma. Limbs are packed onto
    each line while the line's text (excluding the indent) stays within
    ``max_width`` characters. A single limb longer than ``max_width`` still
    gets its own line.

    Args:
        limbs: Limb values to print, in order.
        indent: Number of spaces prefixed to every output line.
        max_width: Maximum line length, not counting the indent.

    Returns:
        The formatted lines joined with newlines, or ``""`` for no limbs.
    """
    indent_str = ' ' * indent
    separator = ' '
    lines: List[str] = []
    current_line: List[str] = []
    current_length = 0
    last_index = len(limbs) - 1
    for i, limb in enumerate(limbs):
        limb_str = f"{limb}," if i < last_index else str(limb)
        # The comma is already part of limb_str, so only the separator's real
        # length is added (the old code counted 2 and wrapped too early).
        added = len(limb_str) + (len(separator) if current_line else 0)
        if current_line and current_length + added > max_width:
            lines.append(indent_str + separator.join(current_line))
            current_line = [limb_str]
            current_length = len(limb_str)
        else:
            current_line.append(limb_str)
            current_length += added
    if current_line:
        lines.append(indent_str + separator.join(current_line))
    return '\n'.join(lines)
def generate_macro_invocation(module: str, classname: str, bits: int, limb_size: int,
                              prime_limbs: List[int], reduction_limbs: List[int],
                              montgomery_one_limbs: List[int], montgomery_r2_limbs: List[int],
                              montm0inv: int, num_limbs: int) -> str:
    """Build the Rust ``fp<limb_size>!`` macro call that instantiates one field.

    Arguments are emitted in this order:
      1. module name,
      2. class name,
      3. prime bit count,
      4. limb count,
      5. the limb arrays for p, the reduction constant, Montgomery R mod p
         and R^2 mod p (each least significant limb first),
      6. -p[0]^-1 mod 2^limb_size.

    The text is assembled from explicit lines instead of a triple-quoted
    template, so the output indentation does not depend on source whitespace.
    Macro arguments are indented 4 spaces. Array contents use 8 spaces, which
    matches ``format_limb_array``'s default indent.
    """
    radix_bits = limb_size * num_limbs        # R = 2^radix_bits
    reduction_exp = 2 * radix_bits - limb_size  # exponent of the reduction constant
    pieces = [
        f"fp{limb_size}!(",
        f"    {module}, // Name of mod",
        f"    {classname}, // Name of class",
        f"    {bits}, // Number of bits for prime",
        f"    {num_limbs}, // Number of limbs (ceil(bits/{limb_size}))",
        "    [",
        "        // prime number in limbs, least significant first",
        f"        // get this from sage with p.digits(2^{limb_size})",
        format_limb_array(prime_limbs),
        "    ],",
        "    [",
        "        // Barrett reduction constant for reducing values up to twice",
        "        // the number of prime bits (double limbs):",
        f"        // 2^({limb_size}*{num_limbs}*2 - {limb_size}) mod p = 2^{reduction_exp} mod p",
        format_limb_array(reduction_limbs),
        "    ],",
        "    [",
        "        // Montgomery R = 2^(W*N) where W = word size and N = limbs",
        f"        // R = 2^({num_limbs}*{limb_size}) = 2^{radix_bits}",
        "        // Montgomery R mod p",
        format_limb_array(montgomery_one_limbs),
        "    ],",
        "    [",
        "        // Montgomery R^2 mod p",
        format_limb_array(montgomery_r2_limbs),
        "    ],",
        f"    // -p[0]^-1 mod 2^{limb_size}",
        f"    // in sage: m = p.digits(2^{limb_size})[0]",
        f"    // (-m).inverse_mod(2^{limb_size})",
        f"    {montm0inv}",
        ");",
    ]
    return "\n".join(pieces)
def _build_parser() -> argparse.ArgumentParser:
    """Create the command-line parser for this script."""
    parser = argparse.ArgumentParser(
        description='Compute Montgomery constants for gridiron finite fields',
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Generate 31-bit limb constants for Fp256
  %(prog)s --limb-size 31 --field fp256

  # Generate 62-bit limb constants for Fp480
  %(prog)s --limb-size 62 --field fp480

  # Use custom prime
  %(prog)s --limb-size 31 --prime 12345... --module my_fp --classname MyFp --bits 256

  # Quiet mode (only output Rust macro)
  %(prog)s --limb-size 62 --field fp256 --quiet
""",
    )
    parser.add_argument('--limb-size', type=int, choices=[31, 62], required=True,
                        help='Limb size in bits (31 or 62)')
    # --field and --prime are alternative ways of choosing the modulus.
    # Accepting both used to silently ignore --prime, so argparse now rejects
    # the combination.
    source = parser.add_mutually_exclusive_group()
    known_fields = sorted(KNOWN_PRIMES)  # keep choices in sync with the table
    source.add_argument('--field', type=str, choices=known_fields,
                        help=f"Use known field ({' or '.join(known_fields)})")
    source.add_argument('--prime', type=str,
                        help='Prime number (decimal or hex with 0x prefix)')
    parser.add_argument('--module', type=str,
                        help='Module name (e.g., fp_256)')
    parser.add_argument('--classname', type=str,
                        help='Class name (e.g., Fp256)')
    parser.add_argument('--bits', type=int,
                        help='Number of bits in prime')
    parser.add_argument('--quiet', action='store_true',
                        help='Only output Rust macro (suppress explanations)')
    return parser


def _resolve_field(parser: argparse.ArgumentParser,
                   args: argparse.Namespace) -> Tuple[int, int, str, str]:
    """Return ``(prime, bits, module, classname)`` selected by the CLI options.

    Invalid or incomplete options are reported through ``parser.error``,
    which prints the usage message and exits.
    """
    if args.field:
        info = KNOWN_PRIMES[args.field]
        return info['prime'], info['bits'], info['module'], info['classname']
    if not args.prime:
        parser.error("Must specify either --field or --prime")
    text = args.prime.strip()
    try:
        prime = int(text, 16) if text.startswith(('0x', '0X')) else int(text)
    except ValueError:
        parser.error(f"invalid --prime value {args.prime!r} (expected decimal or 0x-prefixed hex)")
    if not all([args.module, args.classname, args.bits]):
        parser.error("--prime requires --module, --classname, and --bits")
    if args.bits < 0:
        parser.error("--bits must be a positive integer")
    return prime, args.bits, args.module, args.classname


def _print_summary(classname: str, bits: int, prime: int, limb_size: int,
                   prime_limbs: List[int], reduction_limbs: List[int],
                   montgomery_one_limbs: List[int], montgomery_r2_limbs: List[int],
                   montm0inv: int, num_limbs: int) -> None:
    """Print a human-readable breakdown of the computed constants."""
    print("=" * 80)
    print(f"{classname} ({bits}-bit prime, {num_limbs} limbs @ {limb_size}-bit)")
    print("=" * 80)
    print(f"\nPrime (decimal): {prime}")
    print(f"\nPrime in {limb_size}-bit limbs:")
    print(f" {prime_limbs}")
    print(f"\nMontgomery m0_inv ((-p[0])^-1 mod 2^{limb_size}):")
    print(f" {montm0inv}")
    print(f"\nR = 2^({limb_size} * {num_limbs}) = 2^{limb_size * num_limbs}")
    print("\nMontgomery One (R mod p):")
    print(f" {montgomery_one_limbs}")
    print("\nMontgomery R^2 (R^2 mod p):")
    print(f" {montgomery_r2_limbs}")
    print(f"\nReduction constant (2^{limb_size * (2 * num_limbs - 1)} mod p):")
    print(f" {reduction_limbs}")
    print("\n" + "=" * 80)
    print("Rust Macro Invocation:")
    print("=" * 80)
    print()


def main():
    """CLI entry point: parse options, compute the constants and print the Rust macro.

    Bad input (unparseable prime, missing options, or a prime rejected by
    ``compute_constants``) is reported as a usage error rather than a traceback.
    """
    parser = _build_parser()
    args = parser.parse_args()
    prime, bits, module, classname = _resolve_field(parser, args)
    try:
        (prime_limbs, reduction_limbs, montgomery_one_limbs, montgomery_r2_limbs,
         montm0inv, num_limbs) = compute_constants(prime, args.limb_size, bits)
    except ValueError as err:
        parser.error(str(err))
    if not args.quiet:
        _print_summary(classname, bits, prime, args.limb_size,
                       prime_limbs, reduction_limbs, montgomery_one_limbs,
                       montgomery_r2_limbs, montm0inv, num_limbs)
    macro_code = generate_macro_invocation(
        module, classname, bits, args.limb_size,
        prime_limbs, reduction_limbs, montgomery_one_limbs, montgomery_r2_limbs,
        montm0inv, num_limbs
    )
    print(macro_code)
    if not args.quiet:
        print("\n" + "=" * 80)
        print("Copy the above macro invocation into your Rust code.")
        print("=" * 80)
# Run the command-line interface only when executed as a script, not on import.
if __name__ == '__main__':
    main()