#pragma once
#include "risc0/zkp/core/devs.h"
#define DIRECT 0
namespace risc0 {
/// Element of the prime field F_P with P = 15 * 2^27 + 1 (= 2013265921).
///
/// When DIRECT is 0, `val` holds the Montgomery form x * 2^32 mod P. The public
/// uint32_t constructor encodes and asUInt32() decodes. The arithmetic operators
/// work directly on the encoded representation. Every value produced by the
/// constructors and operators is reduced into [0, P), except the sentinel
/// returned by invalid().
class Fp {
public:
  /// Field modulus, 15 * 2^27 + 1.
  static CONSTSCALAR uint32_t P = 15 * (uint32_t(1) << 27) + 1;
  /// P^-1 mod 2^32 (M * P == 1 mod 2^32); drives the Montgomery reduction in mul().
  static CONSTSCALAR uint32_t M = 0x88000001;
  /// R^2 mod P with R = 2^32; mul(R2, a) maps a into Montgomery form.
  static CONSTSCALAR uint32_t R2 = 1172168163;

private:
  // Raw representation: Montgomery-encoded unless DIRECT is nonzero.
  uint32_t val;

  /// (a + b) mod P for reduced inputs a, b < P; since 2 * P < 2^32 the sum cannot wrap.
  static DEVSPEC constexpr uint32_t add(uint32_t a, uint32_t b) {
    uint32_t r = a + b;
    return (r >= P ? r - P : r);
  }

  /// (a - b) mod P for reduced inputs a, b < P.
  static DEVSPEC constexpr uint32_t sub(uint32_t a, uint32_t b) {
    uint32_t r = a - b;
    // On borrow, r wraps to at least 2^32 - P + 1, which is greater than P;
    // without borrow r < P. Adding P (mod 2^32) then yields a - b + P.
    return (r > P ? r + P : r);
  }

  /// Product of two representations: a * b mod P in DIRECT mode, otherwise the
  /// Montgomery product a * b * 2^-32 mod P.
  static DEVSPEC constexpr uint32_t mul(uint32_t a, uint32_t b) {
#if DIRECT
    return (uint64_t(a) * uint64_t(b)) % P;
#else
    uint64_t o64 = uint64_t(a) * uint64_t(b);
    // Pick red so that o64 + red * P is divisible by 2^32 (uses M * P == 1 mod 2^32).
    uint32_t low = -uint32_t(o64);
    uint32_t red = M * low;
    // For a, b < P (and for a == R2 with any 32-bit b) the sum stays below 2^64
    // and the shifted result is below 2 * P, so one conditional subtraction suffices.
    o64 += uint64_t(red) * uint64_t(P);
    uint32_t ret = o64 >> 32;
    return (ret >= P ? ret - P : ret);
#endif
  }

  /// Converts a plain integer (any 32-bit value) into the internal representation.
  static DEVSPEC constexpr uint32_t encode(uint32_t a) {
#if DIRECT
    return a;
#else
    return mul(R2, a);
#endif
  }

  /// Converts the internal representation back to the canonical integer in [0, P).
  static DEVSPEC constexpr uint32_t decode(uint32_t a) {
#if DIRECT
    return a;
#else
    // Montgomery-multiplying by 1 strips the 2^32 factor.
    return mul(1, a);
#endif
  }

  /// Wraps an already-encoded raw value without re-encoding; the bool only selects this overload.
  DEVSPEC constexpr Fp(uint32_t val, bool ignore) : val(val) {}

public:
  /// Zero element.
  DEVSPEC constexpr Fp() : val(0) {}
  /// Implicit conversion from an integer; values >= P are reduced mod P.
  DEVSPEC constexpr Fp(uint32_t val) : val(encode(val)) {}
  /// Canonical integer value in [0, P).
  DEVSPEC constexpr uint32_t asUInt32() const { return decode(val); }
#ifdef METAL
  constexpr uint32_t asUInt32() device const { return decode(val); }
#endif
  /// Largest field element, P - 1.
  DEVSPEC static constexpr Fp maxVal() { return P - 1; }
  /// Sentinel whose raw representation (0xffffffff) lies outside [0, P).
  DEVSPEC static constexpr Fp invalid() { return Fp(0xfffffffful, true); }

  /// Draws a uniformly distributed field element from `rng`.
  ///
  /// Builds a 64-bit candidate from two successive rng.generate() calls. This
  /// assumes each call yields 32 random bits; TODO confirm against the Rng types
  /// in use. Candidates at or above the largest multiple of P that fits in
  /// 64 bits are redrawn in full, so the final `% P` is exactly uniform.
  template <typename Rng> static Fp random(DEVADDR Rng& rng) {
    // Largest multiple of P representable in 64 bits; [0, kLimit) covers each residue equally often.
    constexpr uint64_t kLimit = ~uint64_t(0) - (~uint64_t(0) % uint64_t(P));
    uint64_t val = 0;
    do {
      // Separate statements: the operands of `|` are unsequenced, so a single
      // expression would make the word order compiler-dependent.
      const uint64_t hi = rng.generate();
      const uint64_t lo = rng.generate();
      val = (hi << 32) | lo;
    } while (val >= kLimit);
    return Fp(uint32_t(val % uint64_t(P)));
  }

  DEVSPEC constexpr Fp operator+(Fp rhs) const { return Fp(add(val, rhs.val), true); }
  DEVSPEC constexpr Fp operator-() const { return Fp(sub(0, val), true); }
  DEVSPEC constexpr Fp operator-(Fp rhs) const { return Fp(sub(val, rhs.val), true); }
  DEVSPEC constexpr Fp operator*(Fp rhs) const { return Fp(mul(val, rhs.val), true); }
#ifdef METAL
  constexpr Fp operator+(Fp rhs) device const { return Fp(add(val, rhs.val), true); }
  constexpr Fp operator-() device const { return Fp(sub(0, val), true); }
  constexpr Fp operator-(Fp rhs) device const { return Fp(sub(val, rhs.val), true); }
  constexpr Fp operator*(Fp rhs) device const { return Fp(mul(val, rhs.val), true); }
#endif

  // Compound assignments update in place and return a copy of the new value.
  DEVSPEC constexpr Fp operator+=(Fp rhs) {
    val = add(val, rhs.val);
    return *this;
  }
  DEVSPEC constexpr Fp operator-=(Fp rhs) {
    val = sub(val, rhs.val);
    return *this;
  }
  DEVSPEC constexpr Fp operator*=(Fp rhs) {
    val = mul(val, rhs.val);
    return *this;
  }

  // Equality compares raw representations; ordering compares canonical integer values.
  DEVSPEC constexpr bool operator==(Fp rhs) const { return val == rhs.val; }
  DEVSPEC constexpr bool operator!=(Fp rhs) const { return val != rhs.val; }
  DEVSPEC constexpr bool operator<(Fp rhs) const { return decode(val) < decode(rhs.val); }
  DEVSPEC constexpr bool operator<=(Fp rhs) const { return decode(val) <= decode(rhs.val); }
  DEVSPEC constexpr bool operator>(Fp rhs) const { return decode(val) > decode(rhs.val); }
  DEVSPEC constexpr bool operator>=(Fp rhs) const { return decode(val) >= decode(rhs.val); }
#ifdef METAL
  constexpr bool operator==(Fp rhs) device const { return val == rhs.val; }
  constexpr bool operator!=(Fp rhs) device const { return val != rhs.val; }
  constexpr bool operator<(Fp rhs) device const { return decode(val) < decode(rhs.val); }
  constexpr bool operator<=(Fp rhs) device const { return decode(val) <= decode(rhs.val); }
  constexpr bool operator>(Fp rhs) device const { return decode(val) > decode(rhs.val); }
  constexpr bool operator>=(Fp rhs) device const { return decode(val) >= decode(rhs.val); }
#endif

  /// Post-increment: adds one (mod P) and returns the previous value.
  DEVSPEC constexpr Fp operator++(int) {
    Fp r = *this;
    val = add(val, encode(1));
    return r;
  }
  /// Post-decrement: subtracts one (mod P) and returns the previous value.
  DEVSPEC constexpr Fp operator--(int) {
    Fp r = *this;
    val = sub(val, encode(1));
    return r;
  }
  /// Pre-increment: adds one (mod P) and returns the new value.
  DEVSPEC constexpr Fp operator++() {
    val = add(val, encode(1));
    return *this;
  }
  /// Pre-decrement: subtracts one (mod P) and returns the new value.
  DEVSPEC constexpr Fp operator--() {
    val = sub(val, encode(1));
    return *this;
  }

#ifdef CPU
  /// Decimal string of the canonical value.
  std::string str() const { return std::to_string(decode(val)); }
#endif
};
#ifdef CPU
/// Writes the canonical integer value of `x` (not its internal encoding) to `os`.
inline std::ostream& operator<<(std::ostream& os, const Fp& x) {
  const uint32_t canonical = x.asUInt32();
  return os << canonical;
}
#endif
/// Computes x^n by right-to-left binary exponentiation; pow(x, 0) == 1.
DEVSPEC constexpr inline Fp pow(Fp x, size_t n) {
  Fp result = 1;
  // Consume the exponent one bit at a time, least significant first; `x`
  // always holds the base raised to the weight of the current bit.
  for (; n != 0; n >>= 1) {
    if (n & 1) {
      result *= x;
    }
    x *= x;
  }
  return result;
}
/// Multiplicative inverse of `x` via Fermat's little theorem: for nonzero x,
/// x^(P-1) == 1, so x^(P-2) == x^-1. For x == 0 this returns 0; no error is signalled.
DEVSPEC constexpr inline Fp inv(Fp x) {
  return pow(x, size_t(Fp::P - 2));
}
}