// lambdaworks-math 0.10.0
//
// Modular math library for cryptography.
// https://github.com/andrewmilson/ministark/blob/main/gpu-poly/src/metal/felt_u256.h.metal

#ifndef felt_u256_h
#define felt_u256_h

#include "../unsigned_integer/u256.h.metal"

/// Element of the prime field Z/NZ for a 256-bit modulus N, stored in
/// Montgomery form with R = 2^256 (the stored `inner` is `a * R mod N` for the
/// field element `a`).
///
/// Each template value is supplied as four 64-bit limbs, in the order accepted
/// by the `u256(a, b, c, d)` constructor:
///   - N_*:         the modulus N.
///   - R_SQUARED_*: expected to be R^2 mod N. `mul(1, R_SQUARED)` then yields
///                  R mod N, the Montgomery encoding of one (used by pow()).
///   - N_PRIME_*:   must satisfy N * N_PRIME == -1 (mod 2^256). Otherwise the
///                  REDC step in mul() does not divide exactly by R.
template <
    /* =N **/ unsigned long N_0, unsigned long N_1, unsigned long N_2, unsigned long N_3,
    /* =R_SQUARED **/ unsigned long R_SQUARED_0, unsigned long R_SQUARED_1, unsigned long R_SQUARED_2, unsigned long R_SQUARED_3,
    /* =N_PRIME **/ unsigned long N_PRIME_0, unsigned long N_PRIME_1, unsigned long N_PRIME_2, unsigned long N_PRIME_3>
class Fp256 {
public:
    Fp256() = default;

    // NOTE: both constructors store `v` verbatim as the internal representation.
    // They do NOT convert into Montgomery form, so `v` must already be
    // Montgomery-encoded. For example, Fp256(1u) is R^-1 mod N, not one.
    constexpr Fp256(unsigned long v) : inner(v) {}
    constexpr Fp256(u256 v) : inner(v) {}

    // Returns the raw internal (Montgomery-form) representation.
    constexpr explicit operator u256() const
    {
        return inner;
    }

    // Field addition: (this + rhs) mod N.
    constexpr Fp256 operator+(const Fp256 rhs) const
    {
        return Fp256(add(inner, rhs.inner));
    }

    // Field subtraction: (this - rhs) mod N.
    constexpr Fp256 operator-(const Fp256 rhs) const
    {
        return Fp256(sub(inner, rhs.inner));
    }

    // Field multiplication. Montgomery multiplication of two Montgomery-form
    // values yields the Montgomery form of the product.
    Fp256 operator*(const Fp256 rhs) const
    {
        return Fp256(mul(inner, rhs.inner));
    }

    // TODO: make method for all fields
    /// Returns `this^exp`. It uses square-and-multiply, scanning `exp` from its
    /// least significant bit. `pow(0)` returns one.
    ///
    /// The receiver is left unchanged. Previously the running square was
    /// written back into `*this`, which overwrote the caller's value with
    /// this^(2^k).
    Fp256 pow(unsigned exp) const
    {
        // TODO find a way to generate on compile time
        // mul(1, R^2) = R mod N, i.e. the Montgomery representation of one.
        Fp256 const ONE = mul(u256(1), R_SQUARED);
        Fp256 res = ONE;
        // Square a local copy so that `*this` is never modified.
        Fp256 base = *this;

        while (exp > 0)
        {
            if (exp & 1)
            {
                res = res * base;
            }
            exp >>= 1;
            base = base * base;
        }

        return res;
    }

    /// Multiplicative inverse, computed by exponentiation along a fixed
    /// addition chain generated with addchain
    /// (https://github.com/mmcloughlin/addchain).
    ///
    /// NOTE(review): the chain below is hard-coded for a single exponent. It is
    /// presumably N - 2 (Fermat inversion) for the modulus it was generated for.
    /// It is NOT derived from the N template parameter, so it gives wrong
    /// results if this template is instantiated with a different modulus.
    /// The inverse of zero yields zero, because every product of 0 is 0.
    Fp256 inverse() const
    {
        // used addchain
        // https://github.com/mmcloughlin/addchain
        // Local names encode the exponent reached so far, e.g. `_1101` holds
        // inner^0b1101. `sqn<k>` squares k times, i.e. multiplies the exponent
        // by 2^k.
        u256 _10 = mul(inner, inner);
        u256 _11 = mul(_10, inner);
        u256 _1100 = sqn<2>(_11);
        u256 _1101 = mul(inner, _1100);
        u256 _1111 = mul(_10, _1101);
        u256 _11001 = mul(_1100, _1101);
        u256 _110010 = mul(_11001, _11001);
        u256 _110011 = mul(inner, _110010);
        u256 _1000010 = mul(_1111, _110011);
        u256 _1001110 = mul(_1100, _1000010);
        u256 _10000001 = mul(_110011, _1001110);
        u256 _11001111 = mul(_1001110, _10000001);
        u256 i14 = mul(_11001111, _11001111);
        u256 i15 = mul(_10000001, i14);
        u256 i16 = mul(i14, i15);
        u256 x10 = mul(_1000010, i16);
        u256 i27 = sqn<10>(x10);
        u256 i28 = mul(i16, i27);
        u256 i38 = sqn<10>(i27);
        u256 i39 = mul(i28, i38);
        u256 i49 = sqn<10>(i38);
        u256 i50 = mul(i39, i49);
        u256 i60 = sqn<10>(i49);
        u256 i61 = mul(i50, i60);
        u256 i72 = mul(sqn<10>(i60), i61);
        u256 x60 = mul(_1000010, i72);
        u256 i76 = sqn<2>(mul(i72, x60));
        u256 x64 = mul(mul(i15, i76), i76);
        u256 i208 = mul(sqn<64>(mul(sqn<63>(mul(i15, x64)), x64)), x64);
        return Fp256(mul(sqn<60>(i208), x60));
    }

    /// Additive inverse: returns 0 for zero, otherwise N - this.
    Fp256 neg() const
    {
        // TODO: can improve
        // sub(0, x) = add(0, N - x). For x == 0, add() reduces N down to 0.
        return Fp256(sub(0, inner));
    }

private:
    // Montgomery-form value. It is expected to lie in [0, N) after every
    // arithmetic operation.
    u256 inner;

    constexpr static const constant u256 N = u256(N_0, N_1, N_2, N_3);
    constexpr static const constant u256 R_SQUARED = u256(R_SQUARED_0, R_SQUARED_1, R_SQUARED_2, R_SQUARED_3);
    constexpr static const constant u256 N_PRIME = u256(N_PRIME_0, N_PRIME_1, N_PRIME_2, N_PRIME_3);

    // Equates to `(1 << 256) - N`, computed as (2^256 - 1) - N + 1 so that
    // nothing overflows. It is added whenever a 256-bit result wrapped past
    // 2^256: (wrapped + 2^256) - N == wrapped + R_SUB_N (mod 2^256).
    constexpr static const constant u256 R_SUB_N =
        u256(0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF, 0xFFFFFFFFFFFFFFFF) - N + u256(1);

    // Squares `base` N_ACC times in the Montgomery domain. Each squaring
    // doubles the exponent, so the result corresponds to base^(2^N_ACC).
    template<unsigned N_ACC>
    u256 sqn(u256 base) const {
        u256 result = base;
#pragma unroll
        for (unsigned i = 0; i < N_ACC; i++) {
            result = mul(result, result);
        }
        return result;
    }

    // Computes `lhs + rhs mod N`
    // Assumes both inputs are in [0,N). Returns value in range [0,N).
    // Two branch-free corrections are applied:
    //  - if the 256-bit sum did not wrap but is >= N, subtract N;
    //  - if it wrapped (detected by `addition < lhs`), the true sum is
    //    addition + 2^256, so adding R_SUB_N yields sum - N. In that case
    //    `addition` < N, so the first correction is never also applied.
    inline u256 add(const u256 lhs, const u256 rhs) const
    {
        u256 addition = (lhs + rhs);
        u256 res = addition;
        // TODO: determine if an if statement here are more optimal
        return res - u256(addition >= N) * N + u256(addition < lhs) * R_SUB_N;
    }

    // Computes `lhs - rhs mod N` as lhs + (N - rhs).
    // Assumes `rhs` value in range [0,N). For rhs == 0 the second operand is
    // exactly N, which add() still reduces correctly.
    inline u256 sub(const u256 lhs, const u256 rhs) const
    {
        // TODO: figure what goes on here with "constant" scope variables
        return add(lhs, ((u256)N) - rhs);
    }

    // Montgomery product: computes `lhs * rhs * R^-1 mod N` with R = 2^256.
    //
    // Essential that inputs are already in the range [0,N) and are in montgomery
    // form. In that case the output is the Montgomery form of the product and
    // lies in [0,N). Multiplication performs single round of montgomery
    // reduction (REDC).
    //
    // Reference:
    // - https://en.wikipedia.org/wiki/Montgomery_modular_multiplication (REDC)
    // - https://www.youtube.com/watch?v=2UmQDKcelBQ
    u256 mul(const u256 lhs, const u256 rhs) const
    {
        // Schoolbook 512-bit product T = lhs * rhs, built from 128-bit halves.
        // Each half-by-half product fits exactly in a u256.
        u256 lhs_low = lhs.low;
        u256 lhs_high = lhs.high;
        u256 rhs_low = rhs.low;
        u256 rhs_high = rhs.high;

        u256 partial_t_high = lhs_high * rhs_high;
        u256 partial_t_mid_a = lhs_high * rhs_low;
        u256 partial_t_mid_a_low = partial_t_mid_a.low;
        u256 partial_t_mid_a_high = partial_t_mid_a.high;
        u256 partial_t_mid_b = rhs_high * lhs_low;
        u256 partial_t_mid_b_low = partial_t_mid_b.low;
        u256 partial_t_mid_b_high = partial_t_mid_b.high;
        u256 partial_t_low = lhs_low * rhs_low;

        // Sum of the three 128-bit terms landing in bits [128, 256). Any carry
        // out of that window ends up in tmp.high and propagates into t_high.
        u256 tmp = partial_t_mid_a_low +
                   partial_t_mid_b_low + partial_t_low.high;
        u256 carry = tmp.high;
        u256 t_low = u256(tmp.low, partial_t_low.low);
        u256 t_high = partial_t_high + partial_t_mid_a_high + partial_t_mid_b_high + carry;

        // Compute `m = T * N' mod R`
        u256 m = t_low * N_PRIME;

        // Compute `t = (T + m * N) / R`
        // First form the 512-bit product m * N in the same way as T above.
        u256 n = N;
        u256 n_low = n.low;
        u256 n_high = n.high;
        u256 m_low = m.low;
        u256 m_high = m.high;

        u256 partial_mn_high = m_high * n_high;
        u256 partial_mn_mid_a = m_high * n_low;
        u256 partial_mn_mid_a_low = partial_mn_mid_a.low;
        u256 partial_mn_mid_a_high = partial_mn_mid_a.high;
        u256 partial_mn_mid_b = n_high * m_low;
        u256 partial_mn_mid_b_low = partial_mn_mid_b.low;
        u256 partial_mn_mid_b_high = partial_mn_mid_b.high;
        u256 partial_mn_low = m_low * n_low;

        tmp = partial_mn_mid_a_low + partial_mn_mid_b_low + u256(partial_mn_low.high);
        carry = tmp.high;
        u256 mn_low = u256(tmp.low, partial_mn_low.low);
        u256 mn_high = partial_mn_high + partial_mn_mid_a_high + partial_mn_mid_b_high + carry;

        // By choice of m, t_low + mn_low == 0 (mod R). Only its carry into the
        // high half matters, and that carry is 1 exactly when the low sum wrapped.
        u256 overflow = mn_low + t_low < mn_low;
        u256 t_tmp = t_high + overflow;
        u256 t = t_tmp + mn_high;
        // The true (T + mN) / R is < 2N and may exceed 2^256 when N > 2^255.
        // Either a wrap (add R_SUB_N) or a value >= N (subtract N) is corrected;
        // the two cases are mutually exclusive.
        u256 overflows_r = t < t_tmp;
        u256 overflows_modulus = t >= N;

        return t + overflows_r * R_SUB_N - overflows_modulus * N;
    }
};



#endif