import sys
from math import ceil, log, gcd
# Maximum number of columns one range-check row may look up in parallel;
# cost_range_check tries every column count from 1 up to this value.
NB_PARALLEL_LOOKUPS = 4
# Extra diagnostics are printed when the script is launched with --debug.
DEBUG = any(arg == "--debug" for arg in sys.argv)
# Possible values of Params.validity.
VALID_PARAMETERS = "valid"  # every check passed
NOT_ENOUGH_MODULI = "not_enough_moduli"  # lcm of the auxiliary moduli too small
INVALID_AUX_MODULUS = "invalid_aux_modulus"  # an auxiliary modulus fails its p check
def lcm(l):
    """Return the least common multiple of the integers in *l*.

    An empty iterable yields 1 and negative entries contribute their absolute
    value. A zero entry makes the result 0; a further zero after that raises
    ZeroDivisionError (gcd(0, 0) == 0).
    """
    result = 1
    for value in l:
        product = abs(result * value)
        result = product // gcd(result, value)
    return result
def log2(n):
    """Return the exponent k such that 2**k == n.

    Uses exact integer arithmetic, so log2(1) == 0. (The previous
    ``len(bin(n - 1)) - 2`` formula returned 1 for n == 1 and then failed its
    own assertion.)

    Raises:
        ValueError: if n is not a positive power of two. An explicit check is
            used instead of ``assert`` so it is not stripped under ``python -O``.
    """
    if n <= 0 or n & (n - 1):
        raise ValueError("log2 expects a positive power of two, got %r" % (n,))
    return n.bit_length() - 1
def signed_mod(x, m):
    """Reduce *x* modulo *m* to a centred representative.

    For positive m, residues strictly above m // 2 are shifted down by m, so
    the result lies in [m // 2 - m + 1, m // 2].
    """
    residue = x % m
    return residue - m if residue > m // 2 else residue
def next_cheapest_power_of_2(MAX_BIT_LEN, x, cost_fn=None):
    """Return the power of two >= x whose exponent is cheapest to range-check.

    The starting exponent is ``len(bin(x - 1)) - 2``. That is ceil(log2(x))
    for x >= 2, and 1 for x == 1, so the result is always >= 2. The next 127
    exponents are also priced. A candidate replaces the current best only if
    it is strictly cheaper, so ties go to the smaller exponent.

    Every candidate is measured from the fixed starting exponent. The previous
    loop added ``i`` to the already-updated best exponent, so after the first
    improvement it skipped candidates.

    Args:
        MAX_BIT_LEN: forwarded unchanged to the cost function.
        x: lower bound on the returned power of two.
        cost_fn: cost model ``(MAX_BIT_LEN, nb_bits) -> cost``; defaults to
            ``cost_range_check``.

    Returns:
        ``2 ** best_exponent``.
    """
    if cost_fn is None:
        cost_fn = cost_range_check
    base_log = len(bin(x - 1)) - 2
    best_log = base_log
    best = cost_fn(MAX_BIT_LEN, base_log)
    for i in range(1, 128):
        candidate_log = base_log + i
        cost = cost_fn(MAX_BIT_LEN, candidate_log)
        if cost < best:
            best = cost
            best_log = candidate_log
    return 2**best_log
def mul_expr_bounds(q, n, B, base_powers, double_base_powers, **_):
    """Return ``(min, max)`` integer bounds for the multiplication expression.

    Each limb is treated as bounded by B - 1. With
    ``L = (B - 1) * sum(base_powers)`` and
    ``Q = (B - 1)**2 * sum(double_base_powers)``, the bounds are
    ``(-L, Q + 2 * L)``.

    ``q`` and ``n`` are unused; they keep the signature uniform with the other
    ``*_expr_bounds`` functions. Extra keyword arguments are ignored.
    """
    limb_max = B - 1
    linear = limb_max * sum(base_powers)
    quadratic = limb_max**2 * sum(double_base_powers)
    return (-linear, quadratic + 2 * linear)
def tangent_expr_bounds(q, n, B, base_powers, double_base_powers, a_plus_one=1, **_):
    """Return ``(min, max)`` integer bounds for the tangent expression.

    Each limb is treated as bounded by B - 1. With
    ``L = (B - 1) * sum(base_powers)`` and
    ``Q = (B - 1)**2 * sum(double_base_powers)``, the bounds are
    ``(a_plus_one - 2 * (2L + Q), a_plus_one + 3 * (2L + Q))``. The constant
    ``a_plus_one`` shifts both bounds equally.

    ``q`` and ``n`` are unused and extra keyword arguments are ignored.
    """
    limb_max = B - 1
    linear = limb_max * sum(base_powers)
    quadratic = limb_max**2 * sum(double_base_powers)
    spread = 2 * linear + quadratic
    lower = a_plus_one - 2 * spread
    upper = a_plus_one + 3 * spread
    return (lower, upper)
def on_curve_expr_bounds(q, n, B, base_powers, double_base_powers, a_plus_one=1, a_plus_b=0, **_):
    """Return ``(min, max)`` integer bounds for the on-curve expression.

    Each limb is treated as bounded by B - 1. With
    ``L = (B - 1) * sum(base_powers)`` and
    ``Q = (B - 1)**2 * sum(double_base_powers)``, the term
    ``-a_plus_one * x`` is bounded by whichever sign ``a_plus_one * L`` has.
    The constant ``-a_plus_b`` shifts both bounds.

    ``q`` and ``n`` are unused and extra keyword arguments are ignored.
    """
    limb_max = B - 1
    linear = limb_max * sum(base_powers)
    quadratic = limb_max**2 * sum(double_base_powers)
    a_term = a_plus_one * linear
    lower = -(quadratic + linear + max(a_term, 0)) - a_plus_b
    upper = 2 * linear + quadratic - min(a_term, 0) - a_plus_b
    return (lower, upper)
def lambda2_expr_bounds(q, n, B, base_powers, double_base_powers, **_):
    """Return ``(min, max)`` integer bounds for the lambda-squared expression.

    Each limb is treated as bounded by B - 1. With
    ``L = (B - 1) * sum(base_powers)`` and
    ``Q = (B - 1)**2 * sum(double_base_powers)``, the bounds are
    ``(2 - (2L + Q), 2 + 3L)``.

    ``q`` and ``n`` are unused and extra keyword arguments are ignored.
    """
    linear = (B - 1) * sum(base_powers)
    quadratic = (B - 1)**2 * sum(double_base_powers)
    return (2 - (2 * linear + quadratic), 2 + 3 * linear)
def slope_expr_bounds(q, n, B, base_powers, double_base_powers, **_):
    """Return ``(min, max)`` integer bounds for the slope expression.

    Each limb is treated as bounded by B - 1. With
    ``L = (B - 1) * sum(base_powers)`` and
    ``Q = (B - 1)**2 * sum(double_base_powers)``, the bounds are
    ``(-(2 + 3L + Q), 2L + Q)``.

    ``q`` and ``n`` are unused and extra keyword arguments are ignored.
    """
    linear = (B - 1) * sum(base_powers)
    quadratic = (B - 1)**2 * sum(double_base_powers)
    lower = -(2 + 3 * linear + quadratic)
    upper = 2 * linear + quadratic
    return (lower, upper)
class Params:
    """Derived bounds and validity checks for emulating arithmetic modulo q.

    Given a native modulus p, an emulated modulus q, a limb base B and a list
    of auxiliary moduli, the constructor computes:

    - the number of limbs n;
    - the range of the quotient by q (``k_min``, ``u_max``);
    - per auxiliary modulus bounds (``ls_min``, ``vs_max``).

    It records the outcome in ``self.validity``:

    - VALID_PARAMETERS: every check passed.
    - NOT_ENOUGH_MODULI: lcm(auxiliary_moduli) <= the computed lcm threshold.
    - INVALID_AUX_MODULUS: for some auxiliary modulus other than p, the
      computed p_threshold is >= p. This is assigned after the lcm check, so
      it overrides NOT_ENOUGH_MODULI when both apply; optimization_round
      relies on that when probing single moduli.
    """

    @staticmethod
    def _nb_limbs(q, B):
        """Return the smallest n >= 0 such that B**n >= q, computed exactly.

        This replaces ``ceil(log(q) / log(B))``. Floating-point rounding in
        that formula can be off by one when q is close to a power of B.

        Raises:
            ValueError: if B < 2, since no limb count would terminate.
        """
        if B < 2:
            raise ValueError("limb base B must be at least 2, got %r" % (B,))
        n, power = 0, 1
        while power < q:
            power *= B
            n += 1
        return n

    def __init__(self, p, q, B, auxiliary_moduli, RC_len,
                 expr_bounds, double_base_powers=None, curve_constants=None):
        """Compute the parameters and run the validity checks.

        Args:
            p: native modulus. Entries of ``auxiliary_moduli`` equal to p are
                skipped by the per-modulus check.
            q: emulated modulus.
            B: limb base.
            auxiliary_moduli: list of moduli, stored by reference.
            RC_len: MAX_BIT_LEN forwarded to the range-check cost model.
            expr_bounds: callable
                ``(q, n, B, base_powers, double_base_powers, **constants)``
                returning ``(expr_min, expr_max)``.
            double_base_powers: optional precomputed
                ``[B**(i+j) % q for i in range(n) for j in range(n)]``.
            curve_constants: keyword arguments forwarded to ``expr_bounds``.
                Per auxiliary modulus they are reduced with ``signed_mod``.
        """
        if curve_constants is None:
            curve_constants = {}
        n = self._nb_limbs(q, B)
        self.p = p
        self.q = q
        self.B = B
        self.n = n
        self.auxiliary_moduli = auxiliary_moduli
        self.validity = VALID_PARAMETERS

        base_powers = [B**i % q for i in range(n)]
        if double_base_powers is None:
            double_base_powers = [B**(i + j) % q for i in range(n) for j in range(n)]

        (expr_min, expr_max) = expr_bounds(q, n, B, base_powers, double_base_powers,
                                           **curve_constants)
        assert (expr_min < 0)
        # Presumably the constrained expression equals k * q for an integer k;
        # k is bounded from the expression bounds and shifted to
        # u = k - k_min with 0 <= u < u_max.
        k_min = -(abs(expr_min) // q)
        k_max = expr_max // q
        u_max = next_cheapest_power_of_2(RC_len, k_max - k_min + 1)
        self.k_min = k_min
        self.u_max = u_max

        lcm_lower_bound = expr_min - (u_max + k_min) * q
        lcm_upper_bound = expr_max - k_min * q
        lcm_threshold = max(-lcm_lower_bound, lcm_upper_bound)
        moduli_lcm = lcm(auxiliary_moduli)
        if moduli_lcm <= lcm_threshold:
            if DEBUG:
                print("You must consider more auxiliary moduli:")
                print(" lcm_threshold:", lcm_threshold)
                print(" lcm(auxiliary_moduli): ", moduli_lcm)
                # A difference of logs avoids a float overflow when the ratio
                # of these huge integers exceeds ~2**1024.
                print("About another %d bits to go"
                      % int((log(lcm_threshold) - log(moduli_lcm)) / log(2)))
            self.validity = NOT_ENOUGH_MODULI

        self.ls_min = []
        self.vs_max = []
        for mj in auxiliary_moduli:
            if mj == p:
                continue
            # Coefficients are reduced mod mj; limb bounds still use the full B.
            bi_mod_q_mod_mj = [bp % mj for bp in base_powers]
            bij_mod_q_mod_mj = [b % mj for b in double_base_powers]
            cc_mj = {k: signed_mod(v, mj) for k, v in curve_constants.items()}
            (expr_mj_min, expr_mj_max) = expr_bounds(q, n, B, bi_mod_q_mod_mj,
                                                     bij_mod_q_mod_mj, **cc_mj)
            k_min_q_mod_mj = (k_min * q) % mj
            u_q_mod_mj = u_max * (q % mj)
            lj_min = -(abs(expr_mj_min - u_q_mod_mj - k_min_q_mod_mj) // mj)
            lj_max = (expr_mj_max - k_min_q_mod_mj) // mj
            vj_max = next_cheapest_power_of_2(RC_len, lj_max - lj_min + 1)
            self.ls_min.append(lj_min)
            self.vs_max.append(vj_max)
            lower_bound = expr_mj_min - u_q_mod_mj - k_min_q_mod_mj - (vj_max + lj_min) * mj
            upper_bound = expr_mj_max - k_min_q_mod_mj - lj_min * mj
            p_threshold = max(-lower_bound, upper_bound)
            if p <= p_threshold:
                self.validity = INVALID_AUX_MODULUS
                if DEBUG:
                    print("Auxiliary modulus %d is not valid (there will be wrap-around)" % mj)
                    print(" bij_q_mj:", bij_mod_q_mod_mj)
                    print(" bi_q_mj:", bi_mod_q_mod_mj)
                    print(" u_max:", u_max)
                    print(" l: [%d, %d]" % (lj_min, lj_max))
                    print(" vj_max:", vj_max)
                    print(" lower_bound:", lower_bound)
                    print(" upper_bound:", upper_bound)
                    print(" p_threshold:", p_threshold)
                    print("Threshold violated by about %d bits"
                          % int((log(p_threshold) - log(p)) / log(2)))
# Memoisation for cost_range_check: Tables[MAX_BIT_LEN][n] -> cost.
Tables = {}


def cost_range_check(MAX_BIT_LEN, n):
    """Return the cost of range-checking an n-bit value.

    One decomposition step handles ``nb_cols * bit_len`` bits, with
    ``nb_cols`` in [1, NB_PARALLEL_LOOKUPS] and ``bit_len`` in
    [1, MAX_BIT_LEN]. The cost is 1 + the cheapest cost of the remaining
    ``n - nb_cols * bit_len`` bits, where overshooting n is not allowed and
    cost(0) == 0. Each step is presumably one range-check row.

    Results are memoised per MAX_BIT_LEN in ``Tables``. The table is filled
    bottom-up instead of recursively, so large n no longer hits Python's
    recursion limit; returned values are unchanged.
    """
    table = Tables.setdefault(MAX_BIT_LEN, {})
    if n == 0:
        return 0
    cached = table.get(n)
    if cached is not None:
        return cached
    if n < 0:
        # No step fits, so the initial upper bound n is kept, as before.
        table[n] = n + 1
        return n + 1
    # Distinct bit counts a single step can consume, in increasing order.
    steps = sorted({nb_cols * bit_len
                    for nb_cols in range(1, NB_PARALLEL_LOOKUPS + 1)
                    for bit_len in range(1, MAX_BIT_LEN + 1)})
    for m in range(1, n + 1):
        if m in table:
            continue
        best = m  # trivial upper bound
        for step in steps:
            if step > m:
                break
            # All of 1..m-1 are already in the table; cost(0) is 0.
            sub = table[m - step] if step < m else 0
            if sub < best:
                best = sub
        table[m] = best + 1
    return table[n]
def _cost_u_and_vs(RC_len, params):
    """Return the cost of range-checking u and every v_j of *params*.

    u is checked on log2(params.u_max) bits and each v_j on
    log2(params.vs_max[j]) bits. All cost_* functions in this group share
    this term.
    """
    cost = cost_range_check(RC_len, log2(params.u_max))
    cost += sum(cost_range_check(RC_len, log2(vj)) for vj in params.vs_max)
    return cost


def cost_mul(RC_len, params):
    """Cost of the multiplication constraint.

    Fixed base of 2, plus n range checks of log2(B) bits, plus the u/v_j
    range checks.
    """
    cost = 2 + params.n * cost_range_check(RC_len, log2(params.B))
    return cost + _cost_u_and_vs(RC_len, params)


def cost_tangent(RC_len, params):
    """Cost of the tangent constraint: fixed base of 2 plus the u/v_j range checks."""
    return 2 + _cost_u_and_vs(RC_len, params)


def cost_on_curve(RC_len, params, mul_cost):
    """Cost of the on-curve constraint.

    The given ``mul_cost``, plus a fixed 2, plus the u/v_j range checks.
    """
    return mul_cost + 2 + _cost_u_and_vs(RC_len, params)


def cost_lambda2(RC_len, params):
    """Cost of the lambda-squared constraint: fixed base of 3 plus the u/v_j range checks."""
    return 3 + _cost_u_and_vs(RC_len, params)


def cost_slope(RC_len, params):
    """Cost of the slope constraint: fixed base of 3 plus the u/v_j range checks."""
    return 3 + _cost_u_and_vs(RC_len, params)
def cost_incomplete_point_add(B, n, RC_len, lambda2_cost, slope_cost):
    """Return the estimated cost of one incomplete point addition.

    Assigning an n-limb value costs n range checks of log2(B) bits.
    Assigning a point costs 1 + twice that (presumably one per coordinate).
    The total adds one point assignment, one more value assignment, one
    lambda-squared constraint and two slope constraints.
    """
    per_value = n * cost_range_check(RC_len, log2(B))
    per_point = 2 * per_value + 1
    return per_point + per_value + lambda2_cost + 2 * slope_cost
def cost_scalar_mul(B, n, WS, RC_len, norm_cost, lambda2_cost, slope_cost,
                    tangent_cost, nb_scalar_bits=256):
    """Estimate the cost of a scalar multiplication with window size WS.

    Operation counts appear to follow three phases:

    - window-table setup: 1 assignment, 2**WS incomplete additions,
      1 doubling, 1 negation, 2**WS - 1 distinct-x assertions and
      2**WS multi-selects;
    - ceil(nb_scalar_bits / WS) iterations: each adds 1 incomplete addition,
      1 distinct-x assertion and 1 multi-select, with nb_scalar_bits
      doublings in total;
    - final step: 1 negation and 1 incomplete addition.

    Args:
        nb_scalar_bits: scalar length in bits. Defaults to 256, the previously
            hard-coded value.
    """
    table_size = 2**WS
    # Exact integer ceil(nb_scalar_bits / WS).
    nb_iterations = -(-nb_scalar_bits // WS)

    nb_assign = 1
    nb_incomplete_add = 1 + (table_size - 1) + nb_iterations + 1
    nb_double = 1 + nb_scalar_bits
    nb_negate = 1 + 1
    nb_incomplete_assert_different_x = (table_size - 1) + nb_iterations
    nb_multi_select = table_size + nb_iterations

    cost_assign = n * cost_range_check(RC_len, log2(B))
    cost_assign_point = 1 + 2 * cost_assign
    cost_negate = n + norm_cost
    cost_incomplete_add = cost_incomplete_point_add(B, n, RC_len, lambda2_cost, slope_cost)
    cost_double = (cost_assign_point + 1 + cost_assign + tangent_cost
                   + lambda2_cost + slope_cost)
    cost_incomplete_assert_different_x = 4
    cost_multi_select = 1

    return (cost_assign * nb_assign
            + cost_negate * nb_negate
            + cost_incomplete_add * nb_incomplete_add
            + cost_double * nb_double
            + cost_incomplete_assert_different_x * nb_incomplete_assert_different_x
            + cost_multi_select * nb_multi_select)
def optimization_round(p, q, RC_len, nb_limbs, expr_bounds, curve_constants=None):
    """Search for valid Params for a given number of limbs.

    The limb size is the smallest number of bits k with
    ``2**(k * nb_limbs) >= q``. B is then the cheapest power of two
    >= 2**k under the range-check cost model. The search proceeds:

    1. If p alone is enough, return ``(params, ['native'])``.
    2. Otherwise scan k = 1, 2, ... below nb_bits, probing ``[2**k]`` alone,
       and keep the last k before the first INVALID_AUX_MODULUS result.
       Return None if none qualifies.
    3. Starting from ``[p, m]`` with ``m = 2**k``, greedily try m - 1, m - 2,
       ... and keep every candidate that does not yield
       INVALID_AUX_MODULUS, until the parameters are VALID_PARAMETERS.

    Returns:
        ``(params, auxiliary_moduli_str)``, or None if no valid configuration
        is found. This includes running out of candidates ``m - i >= 2``;
        previously that case looped on to modulus 0 and crashed.
    """
    # Exact ceil(log2(q)); the float log ratio can be off by one near 2**k.
    nb_bits = (q - 1).bit_length()
    # Smallest k with k * nb_limbs >= nb_bits, i.e. 2**(k * nb_limbs) >= q.
    log2_B = -(-nb_bits // nb_limbs)
    B = next_cheapest_power_of_2(RC_len, 2**log2_B)

    params = Params(p, q, B, [p], RC_len, expr_bounds,
                    curve_constants=curve_constants)
    if params.validity == VALID_PARAMETERS:
        return (params, ['native'])

    log2_m = 0
    for k in range(1, nb_bits):
        trial = Params(p, q, B, [2**k], RC_len, expr_bounds,
                       curve_constants=curve_constants)
        if trial.validity == INVALID_AUX_MODULUS:
            break
        log2_m = k
    if log2_m == 0:
        return None

    m = 2**log2_m
    auxiliary_moduli = [p, m]
    auxiliary_moduli_str = ['native', '2^' + str(log2_m)]
    params = Params(p, q, B, auxiliary_moduli, RC_len, expr_bounds,
                    curve_constants=curve_constants)
    i = 1
    while params.validity != VALID_PARAMETERS:
        candidate = m - i
        if candidate < 2:
            # Moduli 1 and 0 cannot help (0 would divide by zero).
            return None
        params = Params(p, q, B, auxiliary_moduli + [candidate], RC_len,
                        expr_bounds, curve_constants=curve_constants)
        if params.validity != INVALID_AUX_MODULUS:
            # Rebind rather than mutate: earlier Params hold the old list.
            auxiliary_moduli = auxiliary_moduli + [candidate]
            auxiliary_moduli_str = auxiliary_moduli_str + ['2^' + str(log2_m) + '-' + str(i)]
        i += 1
    return (params, auxiliary_moduli_str)
def pp_params(params, RC_len, auxiliary_moduli_str):
    """Return a one-line, human-readable summary of valid parameters.

    All exponents come from ``log2`` (exact integer arithmetic). The former
    ``int(log(B) / log(2))`` is not guaranteed to be exact, and int()
    truncates. The vs_max exponents are rendered as one bracketed list,
    e.g. ``{[4, 5]}``; the old ``"[%s]" % str(list)`` doubled the brackets.
    """
    assert (params.validity == VALID_PARAMETERS)
    moduli = ", ".join(auxiliary_moduli_str)
    vs_exponents = ", ".join(str(log2(v)) for v in params.vs_max)
    info = "B = 2^%d, nb_limbs = %d, moduli = {%s}" % (log2(params.B), params.n, moduli)
    info += ", u_max = {%d}" % log2(params.u_max)
    info += ", vs_max = {[%s]}" % vs_exponents
    info += ", MAX_BIT_LEN = %d" % RC_len
    return info
def optimize(p, q, a=0, b=0, nb_limbs_range=range(2, 8),
             rc_len_range=range(8, 20 + 1)):
    """Print cost rows for candidate parameters (native p, emulated q).

    The curve is assumed to be y^2 = x^3 + a*x + b. For each limb count, a
    blank line is printed and every range-check length in ``rc_len_range``
    is tried:

    - parameters are searched for the tangent constraint;
    - the multiplication, lambda-squared, slope and on-curve constraints are
      priced with the same B and auxiliary moduli.

    A row is printed whenever the incomplete-addition cost is <= the best
    seen so far for that limb count.

    Args:
        nb_limbs_range, rc_len_range: search ranges. The defaults reproduce
            the previously hard-coded 2..7 and 8..20.
    """
    tangent_cc = {'a_plus_one': a + 1}
    on_curve_cc = {'a_plus_one': a + 1, 'a_plus_b': a + b}
    expr_bounds = tangent_expr_bounds
    for nb_limbs in nb_limbs_range:
        best_cost = None
        print()
        for RC_len in rc_len_range:
            opt = optimization_round(p, q, RC_len, nb_limbs, expr_bounds,
                                     curve_constants=tangent_cc)
            if opt is None:
                continue
            (params, auxiliary_moduli_str) = opt
            B, moduli = params.B, params.auxiliary_moduli
            tangent_cost = cost_tangent(RC_len, params)
            params_mul = Params(p, q, B, moduli, RC_len, mul_expr_bounds)
            mul_cost = cost_mul(RC_len, params_mul)
            params_lambda2 = Params(p, q, B, moduli, RC_len, lambda2_expr_bounds)
            lambda2_cost = cost_lambda2(RC_len, params_lambda2)
            params_slope = Params(p, q, B, moduli, RC_len, slope_expr_bounds)
            slope_cost = cost_slope(RC_len, params_slope)
            params_on_curve = Params(p, q, B, moduli, RC_len,
                                     on_curve_expr_bounds, curve_constants=on_curve_cc)
            on_curve_cost = cost_on_curve(RC_len, params_on_curve, mul_cost)
            cost = cost_incomplete_point_add(params.B, params.n, RC_len, lambda2_cost, slope_cost)
            if best_cost is None or cost <= best_cost:
                best_cost = cost
                info = pp_params(params, RC_len, auxiliary_moduli_str)
                print("%d (incomplete_add) | %d (mul) | %d (slope) | %d (λ²) | %d (tangent) | %d (on_curve) \t%s" %
                      (cost, mul_cost, slope_cost, lambda2_cost, tangent_cost, on_curve_cost, info))
PLUTO_SCALAR = 0x24000000000024000130e0000d7f70e4a803ca76f439266f443f9a5c7a8a6c7be4a775fe8e177fd69ca7e85d60050af41ffffcd300000001
ERIS_SCALAR = 0x24000000000024000130e0000d7f70e4a803ca76f439266f443f9a5cda8a6c7be4a7a5fe8fadffd6a2a7e8c30006b9459ffffcd300000001

# Weierstrass coefficients (a, b) for y^2 = x^3 + a*x + b, by curve name.
CURVES = {
    'secp256k1': (0, 7),
    'secp256r1': (-3, 0x5ac635d8aa3a93e7b3ebbd55769886bc651d06b0cc53b0f63bce3c3e27d2604b),
    'bls12-381': (0, 4),
    'bn254': (0, 3),
}

# Named moduli accepted on the command line. Pluto and Eris form a cycle:
# each curve's base field is the other's scalar field.
ORDERS = {
    'secp256k1-base': 2**256 - 2**32 - 2**9 - 2**8 - 2**7 - 2**6 - 2**4 - 1,
    'secp256k1-scalar': 0xfffffffffffffffffffffffffffffffebaaedce6af48a03bbfd25e8cd0364141,
    'secp256r1-base': 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff,
    'secp256r1-scalar': 0xffffffff00000000ffffffffffffffffbce6faada7179e84f3b9cac2fc632551,
    'pluto-base': ERIS_SCALAR,
    'pluto-scalar': PLUTO_SCALAR,
    'eris-base': PLUTO_SCALAR,
    'eris-scalar': ERIS_SCALAR,
    'bn254-base': 0x30644e72e131a029b85045b68181585d97816a916871ca8d3c208c16d87cfd47,
    'bn254-scalar': 0x30644e72e131a029b85045b68181585d2833e84879b9709143e1f593f0000001,
    'bls12-381-base': 0x1a0111ea397fe69a4b1ba7b6434bacd764774b84f38512bf6730d2a0f6b0f6241eabfffeb153ffffb9feffffffffaaab,
    'bls12-381-scalar': 0x73eda753299d7d483339d80809a1d80553bda402fffe5bfeffffffff00000001,
    'curve25519-base': 0x7fffffffffffffffffffffffffffffffffffffffffffffffffffffffffffffed,
    'curve25519-scalar': 0x1000000000000000000000000000000014def9dea2f79cd65812631a5cf5d3ed,
}


def parse_modulus(m):
    """Return the modulus named *m* in ORDERS, or evaluate *m* as an expression.

    Any other text, e.g. ``"2**255 - 19"`` or ``"0x10"``, is passed to eval().
    """
    fetched = ORDERS.get(m)
    if fetched is not None:
        return fetched
    # NOTE(security): eval() runs arbitrary Python taken from the command
    # line. That is acceptable for a local developer tool, but never feed
    # this function untrusted input.
    return eval(m)


def parse_curve(name):
    """Return the (a, b) coefficients of the named curve, or None if unknown."""
    return CURVES.get(name)
if __name__ == '__main__':
    # Command line: NATIVE EMULATED [CURVE | a [b]] [--debug].
    # "--debug" only toggles DEBUG (read from sys.argv). It is dropped here so
    # it is not mistaken for a modulus, a curve name or a coefficient.
    args = [arg for arg in sys.argv[1:] if arg != "--debug"]
    if len(args) < 2:
        keys = "\n".join([" - " + k for k in ORDERS.keys()])
        curves = "\n".join([" - " + k for k in CURVES.keys()])
        sys.exit("Usage: python3 foreign_params_gen.py NATIVE EMULATED [CURVE | a b]\n"
                 "Where NATIVE and EMULATED must be replaced by concrete constants or "
                 "one of the following supported values:\n" + keys + "\n\n"
                 "Weierstrass coefficients (y^2 = x^3 + ax + b) can be specified as a\n"
                 "curve name or explicit a b values (default 0). Supported curves:\n" + curves)
    p = parse_modulus(args[0])
    q = parse_modulus(args[1])
    a, b = 0, 0
    if len(args) >= 3:
        curve = parse_curve(args[2])
        if curve is not None:
            a, b = curve
        else:
            try:
                a = int(args[2], 0)
                b = int(args[3], 0) if len(args) >= 4 else 0
            except ValueError:
                sys.exit("Invalid curve name or Weierstrass coefficients: %s"
                         % " ".join(args[2:4]))
    print("Optimizing parameters for:\n Native modulus: %d\n Emulated modulus: %d" % (p, q))
    if a != 0 or b != 0:
        print(" Weierstrass coefficient a: %d\n Weierstrass coefficient b: %d" % (a, b))
    optimize(p, q, a=a, b=b)