risc0-zkp-core 0.4.0

RISC Zero zero-knowledge proof system core crate
Documentation
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

/// \file
/// Defines the core finite field data type, Fp, and some free functions on the type.

#include "risc0/zkp/core/devs.h"
#include <limits>

// Determine whether to use Montgomery representation or direct representation.  Currently
// Montgomery seems faster in most cases; we might want to revisit this with more testing.
#define DIRECT 0

namespace risc0 {

/// The Fp class is an element of the finite field F_p, where P is the prime number 15*2^27 + 1.
/// Put another way, Fp is basically integer arithmetic modulo P.
///
/// The 'Fp' datatype is the core type of all of the operations done within the zero knowledge
/// proofs, and is the smallest 'addressable' datatype, and the base type from which all composite
/// types are built.  In many ways, one can imagine it as the word size of a very strange
/// architecture.
///
/// This specific prime P was chosen to:
/// - Be less than 2^31 so that it fits within a 32 bit word and doesn't overflow on addition.
/// - Otherwise have as large a power of 2 in the factors of P-1 as possible.
///
/// This last property is useful for number theoretical transforms (the fast Fourier transform
/// equivalent on finite fields).  See NTT.h for details.
///
/// The Fp class wraps all the standard arithmetic operations to make the finite field elements look
/// basically like ordinary numbers (which they mostly are).
///
/// Unless the DIRECT macro is nonzero, the stored word is kept in Montgomery form (the value
/// multiplied by 2^32, mod P); encode() / decode() convert to and from that form.
class Fp {
public:
  /// The value of P, the modulus of Fp.
  static CONSTSCALAR uint32_t P = 15 * (uint32_t(1) << 27) + 1;
  /// The inverse of P modulo 2^32 (P * M == 1 mod 2^32), used by the Montgomery reduction in mul().
  static CONSTSCALAR uint32_t M = 0x88000001;
  /// (2^32)^2 mod P; multiplying by this in mul() converts a plain value into Montgomery form.
  static CONSTSCALAR uint32_t R2 = 1172168163;

private:
  // The actual value (in internal form), always < P.
  uint32_t val;

  // We make 'impls' of the core ops which all the other uses call.  This is done to allow for
  // tweaking of the implementation later, for example switching to montgomery representation or
  // doing inline assembly or some crazy CUDA stuff.

  // Add two numbers.  Both inputs are < P < 2^31, so the 32-bit sum cannot overflow and a single
  // conditional subtraction reduces it.
  static DEVSPEC constexpr uint32_t add(uint32_t a, uint32_t b) {
    uint32_t r = a + b;
    return (r >= P ? r - P : r);
  }

  // Subtract two numbers.  When a < b the unsigned subtraction wraps around to a value above P;
  // adding P (again modulo 2^32) then yields a - b + P, which is back in [0, P).
  static DEVSPEC constexpr uint32_t sub(uint32_t a, uint32_t b) {
    uint32_t r = a - b;
    return (r > P ? r + P : r);
  }

  // Multiply two numbers.  In direct mode this is a plain 64-bit product reduced mod P.  Otherwise
  // it is a Montgomery multiplication, returning a * b * 2^-32 (mod P).
  static DEVSPEC constexpr uint32_t mul(uint32_t a, uint32_t b) {
#if DIRECT
    return (uint64_t(a) * uint64_t(b)) % P;
#else
    uint64_t o64 = uint64_t(a) * uint64_t(b);
    // Pick red so that red * P == -o64 (mod 2^32); adding it clears the low 32 bits of o64.
    uint32_t low = -uint32_t(o64);
    uint32_t red = M * low;
    o64 += uint64_t(red) * uint64_t(P);
    // The low word is now zero, so this shift is an exact division by 2^32.
    uint32_t ret = o64 >> 32;
    // As long as a * b < P * 2^32 the quotient is < 2 * P, so one subtraction suffices.
    return (ret >= P ? ret - P : ret);
#endif
  }

  // Encode: convert a plain integer into the internal form, reducing it modulo P.
  static DEVSPEC constexpr uint32_t encode(uint32_t a) {
#if DIRECT
    // Reduce here so that values >= P still satisfy the 'val < P' invariant; the Montgomery
    // branch below gets the same reduction for free from mul().
    return a % P;
#else
    return mul(R2, a);
#endif
  }
  // Decode: convert the internal form back into a plain integer in [0, P).
  static DEVSPEC constexpr uint32_t decode(uint32_t a) {
#if DIRECT
    return a;
#else
    return mul(1, a);
#endif
  }

  // A private constructor that takes the 'internal' form.  The bool only serves to select this
  // overload; its value is never read.
  DEVSPEC constexpr Fp(uint32_t val, bool ignore) : val(val) {}

public:
  /// Default constructor, sets value to 0.
  DEVSPEC constexpr Fp() : val(0) {}
  /// Construct an Fp from a uint32_t, wrapping modulo P if needed.
  DEVSPEC constexpr Fp(uint32_t val) : val(encode(val)) {}

  /// Convert to a uint32_t in the range [0, P).
  DEVSPEC constexpr uint32_t asUInt32() const { return decode(val); }
#ifdef METAL
  constexpr uint32_t asUInt32() device const { return decode(val); }
#endif
  /// Get the largest value, basically P - 1.
  DEVSPEC static constexpr Fp maxVal() { return P - 1; }
  /// Get an 'invalid' Fp value.  Its internal word is 0xffffffff, which is >= P and so does not
  /// satisfy the invariant that ordinary values maintain.
  DEVSPEC static constexpr Fp invalid() { return Fp(0xfffffffful, true); }
  /// Generate a uniform random value.  Rng must provide a generate() member whose result is
  /// assignable to uint32_t.
  template <typename Rng> static Fp random(DEVADDR Rng& rng) {
    // Reject the last modulo-P region of possible uint32_t values, since it's uneven
    // and will only return random values less than (2^32 % P).
    constexpr uint32_t reject_cutoff = (std::numeric_limits<uint32_t>::max() / P) * P;

    uint32_t val;
    do {
      val = rng.generate();
    } while (val >= reject_cutoff);
    return val % P;
  }

  // Implement all the various overloads.  These operate directly on the internal form.
  DEVSPEC constexpr Fp operator+(Fp rhs) const { return Fp(add(val, rhs.val), true); }
  DEVSPEC constexpr Fp operator-() const { return Fp(sub(0, val), true); }
  DEVSPEC constexpr Fp operator-(Fp rhs) const { return Fp(sub(val, rhs.val), true); }
  DEVSPEC constexpr Fp operator*(Fp rhs) const { return Fp(mul(val, rhs.val), true); }
#ifdef METAL
  constexpr Fp operator+(Fp rhs) device const { return Fp(add(val, rhs.val), true); }
  constexpr Fp operator-() device const { return Fp(sub(0, val), true); }
  constexpr Fp operator-(Fp rhs) device const { return Fp(sub(val, rhs.val), true); }
  constexpr Fp operator*(Fp rhs) device const { return Fp(mul(val, rhs.val), true); }
#endif

  // Compound assignment.  Note that these return a copy of the updated value, not a reference.
  DEVSPEC constexpr Fp operator+=(Fp rhs) {
    val = add(val, rhs.val);
    return *this;
  }
  DEVSPEC constexpr Fp operator-=(Fp rhs) {
    val = sub(val, rhs.val);
    return *this;
  }
  DEVSPEC constexpr Fp operator*=(Fp rhs) {
    val = mul(val, rhs.val);
    return *this;
  }

  // Equality compares the internal words directly; ordering compares the decoded integers.
  DEVSPEC constexpr bool operator==(Fp rhs) const { return val == rhs.val; }
  DEVSPEC constexpr bool operator!=(Fp rhs) const { return val != rhs.val; }
  DEVSPEC constexpr bool operator<(Fp rhs) const { return decode(val) < decode(rhs.val); }
  DEVSPEC constexpr bool operator<=(Fp rhs) const { return decode(val) <= decode(rhs.val); }
  DEVSPEC constexpr bool operator>(Fp rhs) const { return decode(val) > decode(rhs.val); }
  DEVSPEC constexpr bool operator>=(Fp rhs) const { return decode(val) >= decode(rhs.val); }
#ifdef METAL
  constexpr bool operator==(Fp rhs) device const { return val == rhs.val; }
  constexpr bool operator!=(Fp rhs) device const { return val != rhs.val; }
  constexpr bool operator<(Fp rhs) device const { return decode(val) < decode(rhs.val); }
  constexpr bool operator<=(Fp rhs) device const { return decode(val) <= decode(rhs.val); }
  constexpr bool operator>(Fp rhs) device const { return decode(val) > decode(rhs.val); }
  constexpr bool operator>=(Fp rhs) device const { return decode(val) >= decode(rhs.val); }
#endif

  // Post-inc/dec: step by one (mod P) and return the value held before the step.
  DEVSPEC constexpr Fp operator++(int) {
    Fp r = *this;
    val = add(val, encode(1));
    return r;
  }
  DEVSPEC constexpr Fp operator--(int) {
    Fp r = *this;
    val = sub(val, encode(1));
    return r;
  }

  // Pre-inc/dec: step by one (mod P) and return a copy of the updated value.
  DEVSPEC constexpr Fp operator++() {
    val = add(val, encode(1));
    return *this;
  }
  DEVSPEC constexpr Fp operator--() {
    val = sub(val, encode(1));
    return *this;
  }

#ifdef CPU
  /// Decimal string of the decoded value (CPU builds only).  Const so that it can be called on
  /// const objects and through const references.
  std::string str() const { return std::to_string(decode(val)); }
#endif
};

/// Stream insertion for Fp values (CPU builds only): writes the decoded integer value of x to os
/// and returns os so that insertions can be chained.
#ifdef CPU
inline std::ostream& operator<<(std::ostream& os, const Fp& x) {
  return os << x.asUInt32();
}
#endif

/// Raise a value to a power using square-and-multiply over the bits of n, lowest bit first.
/// Returns 1 when n is 0, whatever x is.
DEVSPEC constexpr inline Fp pow(Fp x, size_t n) {
  Fp result = 1;
  for (; n > 0; n >>= 1) {
    // The current low bit of n says whether this power of x belongs in the product.
    if ((n & 1) != 0) {
      result *= x;
    }
    // Advance x to the next repeated square.
    x *= x;
  }
  return result;
}

/// Compute the multiplicative inverse of x, i.e. `1/x` in finite field terms.
///
/// By Fermat's little theorem, `x^(P-1) == 1 (mod P)` for every x != 0, and therefore
/// `x * x^(P-2) == 1 (mod P)`.  In other words, `x^(P-2)` is the multiplicative inverse of x.
/// Computed this way, the 'inverse' of zero comes out as zero, which is convenient in many cases,
/// so we leave it.
DEVSPEC constexpr inline Fp inv(Fp x) {
  const size_t fermatExponent = Fp::P - 2;
  return pow(x, fermatExponent);
}

} // namespace risc0