import argparse
import json
import sys
from math import ceil, floor, gcd, log, log2
# Supported prime fields, keyed by the name accepted by the ``--field`` option.
# Each entry provides:
#   "prime"        -- the field modulus p
#   "valid_widths" -- the state widths t that main() accepts for this field
FIELDS = {
    "babybear": {
        "prime": 2013265921,  # 2^31 - 2^27 + 1
        "valid_widths": [16, 24],
    },
    "koalabear": {
        "prime": 2130706433,  # 2^31 - 2^24 + 1
        "valid_widths": [16, 24],
    },
    "goldilocks": {
        "prime": (1 << 64) - (1 << 32) + 1,  # 2^64 - 2^32 + 1
        "valid_widths": [8, 12, 16, 20],
    },
    "mersenne31": {
        "prime": (1 << 31) - 1,  # 2^31 - 1
        "valid_widths": [16, 24, 32],
    },
}
def compute_alpha(p):
    """Return the smallest integer ``alpha >= 3`` coprime to ``p - 1``.

    Candidates are tried in increasing order from 3 up to ``p - 1``.

    Raises:
        ValueError: if no candidate in ``range(3, p)`` is coprime to ``p - 1``.
    """
    group_order = p - 1
    # The generator is lazy, so the search stops at the first coprime value.
    found = next((a for a in range(3, p) if gcd(a, group_order) == 1), None)
    if found is None:
        raise ValueError(f"No valid alpha found for p={p}")
    return found
def extended_gcd(a, b):
    """Extended Euclidean algorithm.

    Returns a tuple ``(g, x, y)`` with ``a * x + b * y == g``, where ``g`` is
    the value the Euclidean remainder sequence on ``(a, b)`` ends with.
    """
    # Forward pass: run the remainder sequence, remembering each quotient.
    quotients = []
    while a != 0:
        quotients.append(b // a)
        a, b = b % a, a

    # Backward pass: start from the base case (b, 0, 1) and replay the
    # quotients in reverse to rebuild the Bezout coefficients.
    x, y = 0, 1
    for q in reversed(quotients):
        x, y = y - q * x, x
    return b, x, y
def mod_inv(a, p):
    """Return the multiplicative inverse of ``a`` modulo ``p``.

    Raises:
        ZeroDivisionError: if ``a`` is exactly zero.
        ValueError: if ``a`` and ``p`` are not coprime.
    """
    if a == 0:
        raise ZeroDivisionError("Cannot invert zero")
    divisor, coefficient, _ = extended_gcd(a % p, p)
    if divisor == 1:
        # The Bezout coefficient may be negative; reduce it into [0, p).
        return coefficient % p
    raise ValueError(f"{a} has no inverse mod {p}")
def mat_identity(t, p):
    """Return the ``t`` x ``t`` identity matrix as a list of row lists.

    ``p`` is accepted for signature symmetry with the other matrix helpers
    but is not used.
    """
    return [[1 if row == col else 0 for col in range(t)] for row in range(t)]
def mat_mul(A, B, p):
    """Return the matrix product ``A * B`` with every entry reduced modulo ``p``.

    The dimension is taken from ``len(A)``; both operands are indexed as
    ``len(A)`` x ``len(A)`` matrices.
    """
    size = len(A)
    indices = range(size)
    return [
        [sum(A[r][k] * B[k][c] for k in indices) % p for c in indices]
        for r in indices
    ]
def mat_vec_mul(M, v, p):
    """Return the matrix-vector product ``M * v`` reduced modulo ``p``.

    The dimension is taken from ``len(M)``; the first ``len(M)`` entries of
    each row and of ``v`` are used.
    """
    size = len(M)
    return [sum(M[r][c] * v[c] for c in range(size)) % p for r in range(size)]
def mat_pow(M, n, p):
    """Return ``M`` raised to the power ``n`` modulo ``p``.

    Uses binary (square-and-multiply) exponentiation. For ``n <= 0`` the loop
    body never runs and the identity matrix is returned.
    """
    accumulator = mat_identity(len(M), p)
    square = [list(row) for row in M]  # private copy; M is never modified
    remaining = n
    while remaining > 0:
        if remaining & 1:
            accumulator = mat_mul(accumulator, square, p)
        square = mat_mul(square, square, p)
        remaining >>= 1
    return accumulator
def mat_add_scalar_diag(M, c, p):
    """Return a copy of ``M`` with ``c`` added to each diagonal entry (mod ``p``).

    Only the diagonal entries are reduced modulo ``p``; off-diagonal entries
    are copied unchanged. ``M`` itself is not modified.
    """
    shifted = [list(row) for row in M]
    for idx, row in enumerate(shifted):
        row[idx] = (row[idx] + c) % p
    return shifted
def mat_trace(M, p):
    """Return the sum of the diagonal entries of ``M``, reduced modulo ``p``."""
    total = 0
    for idx, row in enumerate(M):
        total += row[idx]
    return total % p
def char_poly(M, p):
    """Return the characteristic polynomial of ``M`` over the integers mod ``p``.

    The result is a coefficient list in ascending-degree order
    (``coeffs[k]`` multiplies ``x**k``) with leading coefficient 1.

    The coefficients are produced by the Faddeev-LeVerrier recurrence, which
    divides by ``2 .. len(M)``; each of those must be invertible modulo ``p``
    (``mod_inv`` raises otherwise).
    """
    size = len(M)
    coeffs = [0] * size + [1]
    running = [list(row) for row in M]
    coeffs[size - 1] = (-mat_trace(running, p)) % p
    for step in range(2, size + 1):
        # running <- M * (running + c_{size-step+1} * I)
        shifted = mat_add_scalar_diag(running, coeffs[size - step + 1], p)
        running = mat_mul(M, shifted, p)
        # c_{size-step} = -trace(running) / step
        coeffs[size - step] = (-(mod_inv(step, p) * mat_trace(running, p))) % p
    return coeffs
def poly_strip(f):
    """Return a copy of ``f`` without trailing (highest-degree) zero coefficients.

    At least one coefficient is always kept, so the zero polynomial is
    returned as ``[0]``; an empty input stays empty.
    """
    coeffs = list(f)
    keep = len(coeffs)
    while keep > 1 and coeffs[keep - 1] == 0:
        keep -= 1
    return coeffs[:keep]
def poly_mul(f, g, p):
    """Return the product of polynomials ``f`` and ``g`` modulo ``p``.

    Polynomials are coefficient lists in ascending-degree order. An empty
    operand yields the zero polynomial ``[0]``.
    """
    if not (f and g):
        return [0]
    product = [0] * (len(f) + len(g) - 1)
    for i, left in enumerate(f):
        if left == 0:
            # A zero coefficient contributes nothing to any output term.
            continue
        for j, right in enumerate(g):
            product[i + j] = (product[i + j] + left * right) % p
    return poly_strip(product)
def poly_divmod(f, g, p):
    """Divide polynomial ``f`` by ``g`` over the integers modulo ``p``.

    Polynomials are coefficient lists in ascending-degree order.

    Returns:
        ``(quotient, remainder)``, both stripped of trailing zeros.

    Raises:
        ZeroDivisionError: if ``g`` strips down to the zero polynomial ``[0]``.
    """
    remainder = list(f)
    divisor = poly_strip(g)
    if divisor == [0]:
        raise ZeroDivisionError("Division by zero polynomial")

    divisor_len = len(divisor)
    lead_inverse = mod_inv(divisor[-1], p)
    quotient = [0] * max(len(remainder) - (divisor_len - 1), 1)

    while len(remainder) >= divisor_len and remainder != [0]:
        remainder = poly_strip(remainder)
        if len(remainder) < divisor_len:
            break
        # Cancel the current leading term of the remainder.
        offset = len(remainder) - divisor_len
        factor = (remainder[-1] * lead_inverse) % p
        quotient[offset] = factor
        for idx, g_coeff in enumerate(divisor):
            remainder[offset + idx] = (remainder[offset + idx] - factor * g_coeff) % p
        remainder = poly_strip(remainder)

    stripped_quotient = poly_strip(quotient)
    if not remainder:
        return stripped_quotient, [0]
    return stripped_quotient, poly_strip(remainder)
def poly_mod(f, g, p):
    """Return the remainder of dividing ``f`` by ``g`` over the integers mod ``p``."""
    return poly_divmod(f, g, p)[1]
def poly_gcd(f, g, p):
    """Return the greatest common divisor of ``f`` and ``g`` modulo ``p``.

    Runs the Euclidean algorithm on the coefficient lists and scales a
    non-zero result so that its leading coefficient is 1.
    """
    current = poly_strip(list(f))
    following = poly_strip(list(g))
    while following != [0]:
        current, following = following, poly_mod(current, following, p)
    if current and current[-1] != 0:
        # Normalise to a monic polynomial.
        lead_inverse = mod_inv(current[-1], p)
        current = [(coeff * lead_inverse) % p for coeff in current]
    return poly_strip(current)
def poly_pow_mod(base, exp, modulus, p):
    """Return ``base ** exp`` reduced modulo the polynomial ``modulus``.

    Coefficients live in the integers modulo ``p``. Uses binary
    (square-and-multiply) exponentiation. An exponent of zero returns ``[1]``
    without touching ``modulus``.
    """
    if exp == 0:
        return [1]

    def mul_reduce(left, right):
        # Multiply two residues and reduce the product by the modulus.
        return poly_mod(poly_mul(left, right, p), modulus, p)

    accumulator = [1]
    square = poly_mod(base, modulus, p)
    remaining = exp
    while remaining > 0:
        if remaining & 1:
            accumulator = mul_reduce(accumulator, square)
        square = mul_reduce(square, square)
        remaining >>= 1
    return accumulator
def poly_sub(f, g, p):
    """Return the polynomial difference ``f - g`` with trailing zeros stripped.

    Only positions covered by ``g`` are reduced modulo ``p``; coefficients of
    ``f`` beyond ``len(g)`` are copied through unchanged.
    """
    # Start from f, zero-padded up to the length of the longer operand.
    difference = list(f) + [0] * max(len(g) - len(f), 0)
    for idx, coeff in enumerate(g):
        difference[idx] = (difference[idx] - coeff) % p
    return poly_strip(difference)
def prime_factors(n):
    """Return the set of distinct prime factors of ``n`` by trial division.

    Values of ``n`` below 2 produce an empty set.
    """
    found = set()
    remaining = n
    divisor = 2
    while divisor * divisor <= remaining:
        if remaining % divisor == 0:
            found.add(divisor)
            # Remove every power of this divisor before moving on.
            while remaining % divisor == 0:
                remaining //= divisor
        divisor += 1
    if remaining > 1:
        # Whatever is left after trial division is itself prime.
        found.add(remaining)
    return found
def is_irreducible(f, p):
    """Return True if polynomial ``f`` is irreducible over the integers mod ``p``.

    ``f`` is a coefficient list in ascending-degree order; its degree is
    taken as ``len(f) - 1``. Degree <= 0 is reported as not irreducible and
    degree 1 as irreducible. Higher degrees use Rabin's test:

    * ``x**(p**n) == x (mod f)`` must hold for ``n = deg f``, and
    * ``gcd(x**(p**(n/q)) - x, f)`` must be constant for every prime ``q | n``.
    """
    degree = len(f) - 1
    if degree <= 0:
        return False
    if degree == 1:
        return True

    x = [0, 1]
    # frobenius[i] holds x**(p**i) mod f.
    frobenius = [x]
    for _ in range(degree):
        frobenius.append(poly_pow_mod(frobenius[-1], p, f, p))

    if poly_mod(poly_sub(frobenius[degree], x, p), f, p) != [0]:
        return False

    for q in prime_factors(degree):
        shared = poly_gcd(poly_sub(frobenius[degree // q], x, p), f, p)
        if len(shared) > 1:
            # f shares a non-constant factor with a proper subfield polynomial.
            return False
    return True
class GrainLFSR:
    """80-bit linear feedback shift register used as a deterministic bit source.

    The register is seeded from the instance parameters, clocked 160 times to
    discard the initial output, and then queried through ``next_bit`` and
    ``random_field_element``.
    """

    # State positions XORed together to form the feedback bit.
    TAPS = [0, 13, 23, 38, 51, 62]

    def __init__(self, n, t, R_F, R_P):
        """Seed the register and run the warm-up clocks.

        Args:
            n: field size in bits (must fit in 12 bits).
            t: state width (must fit in 12 bits).
            R_F: number of full rounds (must fit in 10 bits).
            R_P: number of partial rounds (must fit in 10 bits).

        Raises:
            ValueError: if any parameter is negative or too large for its
                slot in the 80-bit seed.
        """
        # NOTE(review): 1 / 0 presumably encode "prime field" / "x^alpha S-box"
        # as in the Poseidon reference parameter script -- confirm.
        field_type = 1
        sbox = 0
        layout = (
            (field_type, 2),
            (sbox, 4),
            (n, 12),
            (t, 12),
            (R_F, 10),
            (R_P, 10),
        )
        bits = []
        for value, width in layout:
            bits += self._to_bits(value, width)
        bits += [1] * 30
        # The slot widths sum to 50, plus 30 padding ones: always 80 bits,
        # because _to_bits rejects anything that would overflow its slot.
        self.state = bits
        for _ in range(160):
            self._clock()

    @staticmethod
    def _to_bits(value, width):
        """Return ``value`` as a big-endian list of exactly ``width`` bits.

        Raises:
            ValueError: if ``value`` is negative or needs more than ``width``
                bits. Previously such values silently produced a list of the
                wrong length, caught only by an ``assert`` that is removed
                when Python runs with ``-O``.
        """
        if not 0 <= value < (1 << width):
            raise ValueError(f"value {value} does not fit in {width} bits")
        return [int(b) for b in bin(value)[2:].zfill(width)]

    def _clock(self):
        """Advance the register by one step and return the bit shifted in."""
        new_bit = 0
        for tap in self.TAPS:
            new_bit ^= self.state[tap]
        self.state.pop(0)
        self.state.append(new_bit)
        return new_bit

    def next_bit(self):
        """Return the next output bit.

        Bits are produced in pairs; the second bit of a pair is emitted only
        when the first bit is 1, otherwise the whole pair is discarded.
        """
        while True:
            selector = self._clock()
            candidate = self._clock()
            if selector == 1:
                return candidate

    def random_field_element(self, n, p):
        """Return an integer in ``[0, p)`` built from ``n`` output bits.

        The bits are read most-significant first; candidates ``>= p`` are
        rejected and a fresh ``n``-bit value is drawn.
        """
        while True:
            value = 0
            for _ in range(n):
                value = (value << 1) | self.next_bit()
            if value < p:
                return value
def sat_inequalities(p, t, R_F, R_P, alpha, M, n):
    """Check whether ``(R_F, R_P)`` meets the round-number security bounds.

    Args:
        p: field modulus.
        t: state width.
        R_F: candidate number of full rounds.
        R_P: candidate number of partial rounds.
        alpha: S-box exponent.
        M: target security level in bits.
        n: field size in bits.

    Returns:
        True when ``R_F`` is at least the largest of five lower bounds and
        the binomial-based cost estimate is at least ``M``.
    """
    log2_p = log(p, 2)
    log_alpha_2 = log(2, alpha)

    # Five lower bounds on R_F; the candidate must satisfy the largest one.
    bound_1 = 6 if M <= (floor(log2_p - ((alpha - 1) / 2.0))) * (t + 1) else 10
    bound_2 = 1 + ceil(log_alpha_2 * min(M, n)) + ceil(log(t, alpha)) - R_P
    bound_3 = log_alpha_2 * min(M, log2_p) - R_P
    bound_4 = t - 1 + log_alpha_2 * min(M / float(t + 1), log2_p / 2.0) - R_P
    bound_5 = (t - 2 + (M / (2.0 * log(alpha, 2))) - R_P) / float(t - 1)
    required_full = max(
        ceil(bound) for bound in (bound_1, bound_2, bound_3, bound_4, bound_5)
    )

    # Additional cost estimate: twice the log2 of a binomial coefficient.
    third = floor(t / 3.0)
    over = (R_F - 1) * t + R_P + third + third * (R_F / 2.0) + R_P + alpha
    under = third * (R_F / 2.0) + R_P + alpha
    try:
        from math import comb

        binomial = comb(int(over), int(under))
        binom_log = 0 if binomial == 0 else log2(binomial)
    except (ValueError, OverflowError):
        # Not computable: treat the estimate as comfortably above M.
        binom_log = M + 1
    cost_gb4 = ceil(2 * binom_log)

    return (R_F >= required_full) and (cost_gb4 >= M)
def compute_round_numbers(p, t, alpha, M=128):
    """Search for the cheapest ``(R_F, R_P)`` that passes ``sat_inequalities``.

    Every even ``R_F`` in ``[4, 100)`` is tried against every ``R_P`` in
    ``[1, 500)``. Each passing pair gets a margin (two extra full rounds,
    partial rounds scaled by 1.075 and rounded up) and is scored as
    ``t * R_F + R_P``. The lowest score wins; a later candidate with an equal
    score replaces the current best only if it has fewer full rounds.

    Returns:
        ``(R_F, R_P)`` including the margin.

    Raises:
        ValueError: if no candidate pair passes the checks.
    """
    n = p.bit_length()
    chosen_full = 0
    chosen_partial = 0
    lowest_cost = float("inf")
    tie_break_full = 0

    for partial in range(1, 500):
        for full in range(4, 100, 2):
            if not sat_inequalities(p, t, full, partial, alpha, M, n):
                continue
            full_with_margin = full + 2
            partial_with_margin = int(ceil(partial * 1.075))
            cost = t * full_with_margin + partial_with_margin
            cheaper = cost < lowest_cost
            equal_but_fewer_full = (
                cost == lowest_cost and full_with_margin < tie_break_full
            )
            if cheaper or equal_but_fewer_full:
                chosen_partial = partial_with_margin
                chosen_full = full_with_margin
                lowest_cost = cost
                tie_break_full = chosen_full

    if chosen_full == 0:
        raise ValueError(
            f"No valid round numbers found for p={p}, t={t}, alpha={alpha}"
        )
    return (chosen_full, chosen_partial)
def generate_round_constants_poseidon2(grain, p, n, t, R_F, R_P):
    """Draw all round constants from ``grain`` and split them by round type.

    ``R_F * t + R_P`` field elements are drawn in one pass via
    ``grain.random_field_element(n, p)`` and consumed in order: first
    ``R_F // 2`` rows of ``t`` values, then ``R_P`` scalars, then another
    ``R_F // 2`` rows of ``t`` values.

    Returns:
        ``(external_initial, internal, external_final)``.
    """
    half_full = R_F // 2
    num_constants = R_F * t + R_P
    stream = [grain.random_field_element(n, p) for _ in range(num_constants)]

    internal_start = half_full * t
    final_start = internal_start + R_P
    consumed = final_start + half_full * t

    external_initial = [
        stream[row * t : (row + 1) * t] for row in range(half_full)
    ]
    internal = stream[internal_start:final_start]
    external_final = [
        stream[final_start + row * t : final_start + (row + 1) * t]
        for row in range(half_full)
    ]
    # Every drawn constant must have been assigned to exactly one round.
    assert consumed == num_constants
    return external_initial, internal, external_final
# 4x4 base block from which generate_external_matrix() builds the external
# matrix: used directly (reduced mod p) for t == 4 and tiled for larger
# widths divisible by 4.
M4 = [
    [5, 7, 1, 3],
    [4, 6, 1, 1],
    [1, 3, 5, 7],
    [1, 1, 4, 6],
]
def generate_external_matrix(t, p):
    """Build the ``t`` x ``t`` external matrix with entries reduced mod ``p``.

    * ``t`` in {2, 3}: 2 on the diagonal and 1 everywhere else.
    * ``t == 4``: the ``M4`` block itself.
    * other ``t`` divisible by 4: a grid of 4x4 blocks, ``2 * M4`` on the
      block diagonal and ``M4`` elsewhere.

    Raises:
        ValueError: for any other width.
    """
    if t in (2, 3):
        diagonal = 2 % p
        return [
            [diagonal if row == col else 1 for col in range(t)]
            for row in range(t)
        ]
    if t == 4:
        return [[entry % p for entry in row] for row in M4]
    if t % 4 != 0:
        raise ValueError(
            f"Unsupported width t={t} for external matrix (must be 2, 3, or divisible by 4)"
        )
    # Entry (row, col) lies in block (row // 4, col // 4) at offset
    # (row % 4, col % 4); diagonal blocks carry an extra factor of 2.
    return [
        [
            ((2 if row // 4 == col // 4 else 1) * M4[row % 4][col % 4]) % p
            for col in range(t)
        ]
        for row in range(t)
    ]
def check_minpoly_condition(M, t, p):
    """Return True if ``M**1 .. M**(2*t)`` all have irreducible char. polynomials.

    Each power's characteristic polynomial (over the integers mod ``p``) is
    tested with ``is_irreducible``; the first failure ends the check.
    """
    power = [list(row) for row in M]
    for _ in range(2 * t):
        if not is_irreducible(char_poly(power, p), p):
            return False
        power = mat_mul(M, power, p)
    return True
def generate_internal_matrix(grain, t, n, p, verbose=False):
    """Sample diagonals from ``grain`` until the internal matrix is accepted.

    Each attempt draws ``t`` field elements for the diagonal, fills every
    off-diagonal entry with 1, and tests the candidate with
    ``check_minpoly_condition``. The loop has no attempt limit.

    Returns:
        The accepted diagonal with 1 subtracted from every entry (mod ``p``).
    """
    attempt = 0
    while True:
        attempt += 1
        diag = [grain.random_field_element(n, p) for _ in range(t)]
        candidate = [
            [diag[row] % p if row == col else 1 for col in range(t)]
            for row in range(t)
        ]
        if not check_minpoly_condition(candidate, t, p):
            continue
        if verbose:
            print(f" Internal matrix found after {attempt} attempt(s)")
        return [(entry - 1) % p for entry in diag]
def poseidon2_permutation(
    state,
    external_matrix,
    internal_matrix_diag_m1,
    external_initial,
    internal_constants,
    external_final,
    alpha,
    p,
    t,
):
    """Apply the reference permutation to ``state`` and return the new state.

    Order of operations:

    1. multiply by ``external_matrix``;
    2. one full round per row of ``external_initial``;
    3. one partial round per entry of ``internal_constants``;
    4. one full round per row of ``external_final``.

    A full round adds the round constants, raises every lane to ``alpha``
    and multiplies by ``external_matrix``. A partial round touches only lane
    0 with the constant and the power map, then applies the internal matrix
    as ``sum(state) + internal_matrix_diag_m1[i] * state[i]``.

    The input ``state`` is not modified.
    """

    def full_round(lanes, constants):
        # Add constants and apply the power map lane by lane, then mix.
        for idx in range(t):
            lanes[idx] = pow((lanes[idx] + constants[idx]) % p, alpha, p)
        return mat_vec_mul(external_matrix, lanes, p)

    lanes = mat_vec_mul(external_matrix, list(state), p)

    for constants in external_initial:
        lanes = full_round(lanes, constants)

    for constant in internal_constants:
        lanes[0] = pow((lanes[0] + constant) % p, alpha, p)
        lane_sum = sum(lanes) % p
        lanes = [
            (lane_sum + internal_matrix_diag_m1[idx] * lanes[idx]) % p
            for idx in range(t)
        ]

    for constants in external_final:
        lanes = full_round(lanes, constants)

    return lanes
def format_hex(value, n):
    """Format ``value`` as ``0x``-prefixed lowercase hex, zero-padded for ``n`` bits.

    The digit count is ``n`` bits rounded up to whole hex digits; wider
    values are not truncated.
    """
    digits = (n + 3) // 4
    return "0x" + format(value, f"0{digits}x")
def _wrap_hex_row(values, n, indent=4, max_width=100):
    """Format ``values`` as hex and pack them into indented lines.

    Items are separated by a single space. A new line is started when adding
    the next item would push the current line past ``max_width`` characters,
    unless the current line holds nothing but indentation.

    Returns:
        A list of lines; empty when ``values`` is empty.
    """
    pad = " " * indent
    wrapped = []
    line = pad
    for position, value in enumerate(values):
        token = format_hex(value, n)
        joiner = " " if position > 0 else ""
        too_long = len(line) + len(joiner) + len(token) > max_width
        if too_long and line.strip():
            wrapped.append(line)
            line = pad + token
        else:
            line += joiner + token
    if line.strip():
        wrapped.append(line)
    return wrapped
def format_default_poseidon2(
    field_name, width, external_initial, internal, external_final,
    diag_m1, p, n, alpha, R_F, R_P, skip_matrix,
):
    """Render the generated constants as a human-readable text report.

    The report contains a parameter summary, the initial external round
    constants, the internal round constants, the final external round
    constants and, unless ``skip_matrix`` is set, the internal matrix
    diagonal. Returns the report as one newline-joined string.
    """
    R_f = R_F // 2
    rule = "─" * 72

    def banner(title):
        # A section heading framed by two horizontal rules.
        return [rule, title, rule]

    def rounds(rows, first_index):
        # One labelled, wrapped hex block per round, each preceded by a blank.
        block = []
        for offset, row in enumerate(rows):
            block.append("")
            block.append(f" round {first_index + offset}:")
            block.extend(_wrap_hex_row(row, n))
        return block

    report = banner(f" Poseidon2 Constants — {field_name} (width {width})")
    report += [
        "",
        f" Field {field_name}",
        f" Prime (p) {p}",
        f" Bit length {n}",
        f" S-box (α) x^{alpha}",
        f" Width (t) {width}",
        f" Full rounds {R_F} ({R_f} initial + {R_f} final)",
        f" Partial {R_P}",
        f" Constants {R_F * width + R_P} ({R_F}×{width} + {R_P})",
        "",
    ]

    report += banner(f" External Round Constants — Initial ({R_f} rounds × {width})")
    report += rounds(external_initial, 0)
    report.append("")

    report += banner(f" Internal Round Constants ({R_P} scalars)")
    report.append("")
    report += _wrap_hex_row(internal, n)
    report.append("")

    report += banner(f" External Round Constants — Final ({R_f} rounds × {width})")
    # Final rounds are numbered after all initial and partial rounds.
    report += rounds(external_final, R_f + R_P)

    if not skip_matrix:
        report.append("")
        report.append(rule)
        report.append(f" Internal Matrix Diagonal − 1 (Grain-generated, {width} entries)")
        report.append(" note: production builds use hand-optimized diagonals")
        report.append(rule)
        report.append("")
        report += _wrap_hex_row(diag_m1, n)

    report.append("")
    report.append(rule)
    return "\n".join(report)
def format_json_poseidon2(
    field_name, width, external_initial, internal, external_final,
    diag_m1, p, n, alpha, R_F, R_P, skip_matrix,
):
    """Render the generated constants as a JSON document (2-space indent).

    Field elements are emitted as hex strings and the prime as a decimal
    string. The ``matrix_diag_minus_1_grain`` key is present only when
    ``skip_matrix`` is false.
    """

    def hex_list(values):
        return [format_hex(value, n) for value in values]

    payload = {
        "field": field_name,
        "prime": str(p),
        "width": width,
        "alpha": alpha,
        "R_F": R_F,
        "R_P": R_P,
        "external_initial": [hex_list(row) for row in external_initial],
        "internal": hex_list(internal),
        "external_final": [hex_list(row) for row in external_final],
    }
    if not skip_matrix:
        payload["matrix_diag_minus_1_grain"] = hex_list(diag_m1)
    return json.dumps(payload, indent=2)
def main():
    """Command-line entry point.

    Parses the arguments, validates the width for the chosen field, derives
    the S-box exponent and round numbers, generates the round constants (and
    the internal matrix diagonal unless ``--skip-matrix`` is given), prints
    them in the requested format and optionally prints a test vector.

    Exits with status 1 if the width is not valid for the selected field.
    """
    parser = argparse.ArgumentParser(
        description="Generate Poseidon2 round constants for various prime fields.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog=(
            "Examples:\n"
            " python poseidon2/generate_constants.py --field babybear --width 16\n"
            " python poseidon2/generate_constants.py --field goldilocks --width 8 --format json\n"
            " python poseidon2/generate_constants.py --field koalabear --width 24 -v\n"
        ),
    )
    parser.add_argument(
        "--field", required=True, choices=list(FIELDS.keys()), help="Target field"
    )
    parser.add_argument("--width", required=True, type=int, help="State width (t)")
    parser.add_argument(
        "--format",
        default="default",
        choices=["default", "json"],
        help="Output format (default: human-readable summary)",
    )
    parser.add_argument(
        "--security-level",
        default=128,
        type=int,
        help="Security level in bits (default: 128)",
    )
    parser.add_argument(
        "-v", "--verbose", action="store_true", help="Print verbose progress information"
    )
    parser.add_argument(
        "--skip-matrix",
        action="store_true",
        help="Skip internal matrix generation (faster; only generates round constants)",
    )
    parser.add_argument(
        "--test-vector",
        action="store_true",
        help="Compute and print a test vector using the reference permutation",
    )
    args = parser.parse_args()

    def progress(*parts, **options):
        # Progress output is shown only with -v/--verbose.
        if args.verbose:
            print(*parts, **options)

    field_info = FIELDS[args.field]
    p = field_info["prime"]
    t = args.width
    allowed_widths = field_info["valid_widths"]
    if t not in allowed_widths:
        print(
            f"Error: width {t} not valid for {args.field}. "
            f"Valid widths: {field_info['valid_widths']}",
            file=sys.stderr,
        )
        sys.exit(1)

    n = p.bit_length()
    alpha = compute_alpha(p)
    progress(f"Field: {args.field}")
    progress(f"Prime: {p} ({n} bits)")
    progress(f"Width: {t}")
    progress(f"Alpha: {alpha}")
    progress()

    progress("Computing round numbers...", flush=True)
    R_F, R_P = compute_round_numbers(p, t, alpha, args.security_level)
    progress(f"R_F = {R_F}, R_P = {R_P}")
    progress(f"Total raw constants: R_F*t + R_P = {R_F * t + R_P}")
    progress()

    # One generator instance feeds both the round constants and the matrix,
    # so the matrix diagonal continues the same bit stream.
    grain = GrainLFSR(n, t, R_F, R_P)
    progress("Generating round constants...", flush=True)
    external_initial, internal, external_final = generate_round_constants_poseidon2(
        grain, p, n, t, R_F, R_P
    )

    if args.skip_matrix:
        diag_m1 = [0] * t
        progress("Skipping internal matrix generation")
    else:
        progress("Generating internal matrix (this may take a moment)...", flush=True)
        diag_m1 = generate_internal_matrix(grain, t, n, p, verbose=args.verbose)
    progress()

    fmt_args = (
        args.field, t, external_initial, internal, external_final,
        diag_m1, p, n, alpha, R_F, R_P, args.skip_matrix,
    )
    formatter = {
        "default": format_default_poseidon2,
        "json": format_json_poseidon2,
    }.get(args.format)
    if formatter is not None:
        print(formatter(*fmt_args))

    if not args.test_vector:
        return

    if args.skip_matrix:
        print(
            "\nWarning: test vector uses zero diagonal (--skip-matrix). "
            "Output will NOT match production.",
            file=sys.stderr,
        )
    ext_mat = generate_external_matrix(t, p)
    state_in = list(range(t))
    state_out = poseidon2_permutation(
        state_in,
        ext_mat,
        diag_m1,
        external_initial,
        internal,
        external_final,
        alpha,
        p,
        t,
    )
    rendered_in = ", ".join(format_hex(v, n) for v in state_in)
    rendered_out = ", ".join(format_hex(v, n) for v in state_out)
    print()
    print(f"Test vector (input = [0, 1, ..., {t - 1}]):")
    print(f" Input: [{rendered_in}]")
    print(f" Output: [{rendered_out}]")
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()