// Adapted from the Metal implementation in ministark:
// https://github.com/andrewmilson/ministark/blob/main/gpu-poly/src/metal/felt_u256.h.metal
#ifndef felt_u256_h
#define felt_u256_h
#include "../unsigned_integer/u256.cuh"
// Prime field element for a 256-bit modulus N, using Montgomery arithmetic
// with R = 2^256.
//
// Each 256-bit constant is passed as four 64-bit limbs in the same order as
// the `u256(a, b, c, d)` constructor. Judging from the p256::Fp instantiation
// in this file, the first limb is the most significant one (TODO confirm
// against u256.cuh).
//   N_*         : the modulus N (must be odd so that N_PRIME exists).
//   R_SQUARED_* : expected to be R^2 mod N; used to build the Montgomery one.
//   N_PRIME_*   : expected to be -N^-1 mod R, as REDC in `mul` requires.
//
// `inner` is treated by `pow`/`mul` as a Montgomery representation
// (a * R mod N). The constructors store their argument verbatim and perform
// no conversion. Presumably callers pass values that are already in
// Montgomery form.
template <
    /* =N **/ unsigned long N_0, unsigned long N_1, unsigned long N_2,
    unsigned long N_3,
    /* =R_SQUARED **/ unsigned long R_SQUARED_0, unsigned long R_SQUARED_1,
    unsigned long R_SQUARED_2, unsigned long R_SQUARED_3,
    /* =N_PRIME **/ unsigned long N_PRIME_0, unsigned long N_PRIME_1,
    unsigned long N_PRIME_2, unsigned long N_PRIME_3>
class Fp256 {
public:
  Fp256() = default;
  // Wraps a raw value as-is (no Montgomery conversion).
  __device__ constexpr Fp256(unsigned long v) : inner(v) {}
  // Wraps a raw 256-bit value as-is (no Montgomery conversion).
  __device__ constexpr Fp256(u256 v) : inner(v) {}
  // Returns the stored representation unchanged.
  __device__ constexpr explicit operator u256() const { return inner; }

  // Field addition: (this + rhs) mod N.
  __device__ constexpr Fp256 operator+(const Fp256 rhs) const {
    return Fp256(add(inner, rhs.inner));
  }
  // Field subtraction: (this - rhs) mod N.
  __device__ constexpr Fp256 operator-(const Fp256 rhs) const {
    return Fp256(sub(inner, rhs.inner));
  }
  // Field multiplication of two Montgomery-form values (see `mul`).
  __device__ Fp256 operator*(const Fp256 rhs) const {
    return Fp256(mul(inner, rhs.inner));
  }

  // TODO: make method for all fields
  // Returns this^exp, where `exp` is a plain integer (not a field element).
  // Uses square-and-multiply, scanning `exp` from its least significant bit.
  // The running square is kept in a local copy so the receiver is left
  // unchanged.
  __device__ Fp256 pow(unsigned exp) const {
    // Montgomery form of 1: mul(1, R^2) = 1 * R^2 * R^-1 = R mod N.
    // TODO find a way to generate on compile time
    Fp256 const ONE = mul(u256(1), R_SQUARED);
    Fp256 base = *this;
    Fp256 res = ONE;
    while (exp > 0) {
      if (exp & 1) {
        res = res * base;
      }
      exp >>= 1;
      base = base * base;
    }
    return res;
  }

  // Multiplicative inverse via a fixed addition chain generated with
  // https://github.com/mmcloughlin/addchain. It presumably computes
  // this^(N-2) (Fermat's little theorem).
  //
  // NOTE: the chain is hard-coded and does not depend on the template
  // parameters. It is therefore only correct for the modulus it was
  // generated for (presumably the p256::Fp prime).
  // The body is built only from multiplications, so an input of zero
  // returns zero; no error is signalled.
  __device__ Fp256 inverse() const {
    // used addchain
    // https://github.com/mmcloughlin/addchain
    u256 _10 = mul(inner, inner);
    u256 _11 = mul(_10, inner);
    u256 _1100 = sqn<2>(_11);
    u256 _1101 = mul(inner, _1100);
    u256 _1111 = mul(_10, _1101);
    u256 _11001 = mul(_1100, _1101);
    u256 _110010 = mul(_11001, _11001);
    u256 _110011 = mul(inner, _110010);
    u256 _1000010 = mul(_1111, _110011);
    u256 _1001110 = mul(_1100, _1000010);
    u256 _10000001 = mul(_110011, _1001110);
    u256 _11001111 = mul(_1001110, _10000001);
    u256 i14 = mul(_11001111, _11001111);
    u256 i15 = mul(_10000001, i14);
    u256 i16 = mul(i14, i15);
    u256 x10 = mul(_1000010, i16);
    u256 i27 = sqn<10>(x10);
    u256 i28 = mul(i16, i27);
    u256 i38 = sqn<10>(i27);
    u256 i39 = mul(i28, i38);
    u256 i49 = sqn<10>(i38);
    u256 i50 = mul(i39, i49);
    u256 i60 = sqn<10>(i49);
    u256 i61 = mul(i50, i60);
    u256 i72 = mul(sqn<10>(i60), i61);
    u256 x60 = mul(_1000010, i72);
    u256 i76 = sqn<2>(mul(i72, x60));
    u256 x64 = mul(mul(i15, i76), i76);
    u256 i208 = mul(sqn<64>(mul(sqn<63>(mul(i15, x64)), x64)), x64);
    return Fp256(mul(sqn<60>(i208), x60));
  }

  // Additive inverse: (0 - this) mod N. Zero maps to zero.
  __device__ Fp256 neg() const {
    // TODO: can improve
    return Fp256(sub(0, inner));
  }

private:
  u256 inner;
  constexpr static const u256 N = u256(N_0, N_1, N_2, N_3);
  constexpr static const u256 R_SQUARED =
      u256(R_SQUARED_0, R_SQUARED_1, R_SQUARED_2, R_SQUARED_3);
  constexpr static const u256 N_PRIME =
      u256(N_PRIME_0, N_PRIME_1, N_PRIME_2, N_PRIME_3);
  // Equates to `(1 << 256) - N`: added back when a 256-bit sum wraps.
  constexpr static const u256 R_SUB_N =
      u256(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
           0xFFFFFFFFFFFFFFFF) -
      N + u256(1);

  // Squares `base` N_ACC times with Montgomery multiplication, so the
  // exponent becomes 2^N_ACC. With N_ACC == 0, `base` is returned unchanged.
  template <unsigned N_ACC> __device__ u256 sqn(u256 base) const {
    u256 result = base;
#pragma unroll
    for (unsigned i = 0; i < N_ACC; i++) {
      result = mul(result, result);
    }
    return result;
  }

  // Computes `lhs + rhs mod N`.
  // Assumes `lhs` is in [0, N) and `rhs` is in [0, N]; `rhs == N` occurs via
  // `sub`. Returns a value in [0, N).
  // The `addition < lhs` term detects a wrap past 2^256 and adds back
  // 2^256 - N. A wrap is only reachable when N > 2^255.
  __device__ inline u256 add(const u256 lhs, const u256 rhs) const {
    u256 addition = lhs + rhs;
    // TODO: determine if an if statement here are more optimal
    return addition - u256(addition >= N) * N +
           u256(addition < lhs) * R_SUB_N;
  }

  // Computes `lhs - rhs mod N` as `lhs + (N - rhs)`.
  // Assumes `rhs` value in range [0,N). For `rhs == 0`, the second operand
  // is N itself, which `add` still reduces correctly.
  __device__ inline u256 sub(const u256 lhs, const u256 rhs) const {
    // TODO: figure what goes on here with "constant" scope variables
    return add(lhs, ((u256)N) - rhs);
  }

  // Montgomery multiplication (REDC) with R = 2^256.
  //
  // Returns `lhs * rhs * R^-1 mod N`. For Montgomery-form inputs a*R and
  // b*R, the result is (a*b)*R, i.e. the Montgomery form of a*b.
  // Inputs are assumed to be in [0, N); the result is brought back into
  // [0, N) by one final branchless correction.
  //
  // Reference:
  // - https://en.wikipedia.org/wiki/Montgomery_modular_multiplication (REDC)
  // - https://www.youtube.com/watch?v=2UmQDKcelBQ
  __device__ u256 mul(const u256 lhs, const u256 rhs) const {
    // T = lhs * rhs as a 512-bit value (t_high:t_low), assembled from four
    // half-width partial products.
    u256 lhs_low = lhs.low;
    u256 lhs_high = lhs.high;
    u256 rhs_low = rhs.low;
    u256 rhs_high = rhs.high;
    u256 partial_t_high = lhs_high * rhs_high;
    u256 partial_t_mid_a = lhs_high * rhs_low;
    u256 partial_t_mid_a_low = partial_t_mid_a.low;
    u256 partial_t_mid_a_high = partial_t_mid_a.high;
    u256 partial_t_mid_b = rhs_high * lhs_low;
    u256 partial_t_mid_b_low = partial_t_mid_b.low;
    u256 partial_t_mid_b_high = partial_t_mid_b.high;
    u256 partial_t_low = lhs_low * rhs_low;
    u256 tmp = partial_t_mid_a_low + partial_t_mid_b_low + partial_t_low.high;
    u256 carry = tmp.high;
    u256 t_low = u256(tmp.low, partial_t_low.low);
    u256 t_high =
        partial_t_high + partial_t_mid_a_high + partial_t_mid_b_high + carry;

    // Compute `m = T * N' mod R`. Only the low 256 bits are kept, assuming
    // u256 `*` wraps modulo 2^256.
    u256 m = t_low * N_PRIME;

    // Compute `t = (T + m * N) / R`: build m * N the same way as T, then add
    // the high halves plus the carry out of the low halves. By construction,
    // the low halves sum to a multiple of R, which the division discards.
    u256 n = N;
    u256 n_low = n.low;
    u256 n_high = n.high;
    u256 m_low = m.low;
    u256 m_high = m.high;
    u256 partial_mn_high = m_high * n_high;
    u256 partial_mn_mid_a = m_high * n_low;
    u256 partial_mn_mid_a_low = partial_mn_mid_a.low;
    u256 partial_mn_mid_a_high = partial_mn_mid_a.high;
    u256 partial_mn_mid_b = n_high * m_low;
    u256 partial_mn_mid_b_low = partial_mn_mid_b.low;
    u256 partial_mn_mid_b_high = partial_mn_mid_b.high;
    u256 partial_mn_low = m_low * n_low;
    tmp =
        partial_mn_mid_a_low + partial_mn_mid_b_low + u256(partial_mn_low.high);
    carry = tmp.high;
    u256 mn_low = u256(tmp.low, partial_mn_low.low);
    u256 mn_high =
        partial_mn_high + partial_mn_mid_a_high + partial_mn_mid_b_high + carry;
    // Carry out of the (discarded) low-half sum.
    u256 overflow = mn_low + t_low < mn_low;
    u256 t_tmp = t_high + overflow;
    u256 t = t_tmp + mn_high;

    // Final reduction into [0, N).
    // If the 256-bit sum wrapped (only possible when N > 2^255), add back
    // 2^256 - N. Otherwise subtract N when t >= N.
    u256 overflows_r = t < t_tmp;
    u256 overflows_modulus = t >= N;
    return t + overflows_r * R_SUB_N - overflows_modulus * N;
  }
};
namespace p256 {
// Prime field used by StarkWare / Cairo:
//   P = 2^251 + 17 * 2^192 + 1
//     = 3618502788666131213697322783095070105623107215331596699973092056135872020481
//
// Every constant below is written as four 64-bit limbs, most significant
// limb first:
//   N         : the prime P itself.
//   R_SQUARED : R^2 mod P, with R = 2^256 (Montgomery conversion constant).
//   N_PRIME   : -P^-1 mod R, which for this prime equals P - 2.
typedef Fp256<
    // N
    0x0800000000000011, 0x0000000000000000, 0x0000000000000000,
    0x0000000000000001,
    // R_SQUARED
    0x07FFD4AB5E008810, 0xFFFFFFFFFF6F8000, 0x00000001330FFFFF,
    0xFFFFFD737E000401,
    // N_PRIME
    0x0800000000000010, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF,
    0xFFFFFFFFFFFFFFFF>
    Fp;
} // namespace p256
#endif