// risc0-circuit-rv32im 0.12.0
//
// RISC Zero circuit for rv32im
// Documentation
// Copyright 2023 RISC Zero, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

/// \file
/// Defines Fp4, a finite field F_p^4, based on Fp via the irreducible polynomial x^4 + 11.

#include "fp.h"

namespace risc0 {

// Defines instead of constexpr to appease CUDA's limitations around constants.
// undef'd at the end of this file.
//
// BETA is the constant 11 as an Fp element.  NBETA is Fp(Fp::P - 11), i.e. the negation of BETA
// assuming Fp::P is the field modulus (TODO confirm against fp.h).  The Fp4 multiplication below
// uses NBETA when folding powers >= 4 back down, so X^4 is treated as -BETA.
#define BETA Fp(11)
#define NBETA Fp(Fp::P - 11)

/// Instances of Fp4 are elements of a finite field F_p^4.  They are represented as elements of
/// F_p[X] / (X^4 + 11). Basically, this is a 'big' finite field (about 2^128 elements), which is
/// used when the security of various operations depends on the size of the field.  It has the field
/// Fp as a subfield, which means operations by the two are compatible, which is important.  The
/// irreducible polynomial was chosen to be the simplest possible one, x^4 + B, where 11 is the
/// smallest B which makes the polynomial irreducible.
struct Fp4 {
  /// The coefficients of the element: elems[0] + elems[1]*X + elems[2]*X^2 + elems[3]*X^3
  Fp elems[4];

  /// Default constructor makes the zero element (assumes Fp default-constructs to zero; the
  /// coefficients are simply default-constructed here).
  constexpr Fp4() {}

  /// Initialize from uint32_t: x becomes the constant term, the other coefficients are 0.
  explicit constexpr Fp4(uint32_t x) {
    elems[0] = x;
    elems[1] = 0;
    elems[2] = 0;
    elems[3] = 0;
  }

  /// Convert from Fp to Fp4: x becomes the constant term, the other coefficients are 0.
  explicit constexpr Fp4(Fp x) {
    elems[0] = x;
    elems[1] = 0;
    elems[2] = 0;
    elems[3] = 0;
  }

  /// Explicitly construct an Fp4 from its four coefficients: a + b*X + c*X^2 + d*X^3.
  constexpr Fp4(Fp a, Fp b, Fp c, Fp d) {
    elems[0] = a;
    elems[1] = b;
    elems[2] = c;
    elems[3] = d;
  }

  /// Get an 'invalid' Fp4 value: every coefficient is Fp::invalid().
  static constexpr inline Fp4 invalid() {
    return Fp4(Fp::invalid(), Fp::invalid(), Fp::invalid(), Fp::invalid());
  }

  // Implement the addition/subtraction overloads.  Both act coefficient by coefficient.

  /// Add rhs into this element; returns a copy of the updated value.
  constexpr Fp4 operator+=(Fp4 rhs) {
    for (uint32_t i = 0; i < 4; i++) {
      elems[i] += rhs.elems[i];
    }
    return *this;
  }

  /// Subtract rhs from this element; returns a copy of the updated value.
  constexpr Fp4 operator-=(Fp4 rhs) {
    for (uint32_t i = 0; i < 4; i++) {
      elems[i] -= rhs.elems[i];
    }
    return *this;
  }

  /// Coefficient-wise sum.
  constexpr Fp4 operator+(Fp4 rhs) const {
    Fp4 result = *this;
    result += rhs;
    return result;
  }

  /// Coefficient-wise difference.
  constexpr Fp4 operator-(Fp4 rhs) const {
    Fp4 result = *this;
    result -= rhs;
    return result;
  }

  /// Negation, computed as (default-constructed Fp4) - *this.
  constexpr Fp4 operator-() const { return Fp4() - *this; }

  // Implement the simple multiplication case by the subfield Fp
  // Fp * Fp4 is done as a free function due to C++'s operator overloading rules.

  /// Scale every coefficient by rhs in place; returns a copy of the updated value.
  constexpr Fp4 operator*=(Fp rhs) {
    for (uint32_t i = 0; i < 4; i++) {
      elems[i] *= rhs;
    }
    return *this;
  }

  /// Scale every coefficient by rhs.
  constexpr Fp4 operator*(Fp rhs) const {
    Fp4 result = *this;
    result *= rhs;
    return result;
  }

  // Now we get to the interesting case of multiplication.  Basically, multiply out the polynomial
  // representations, and then reduce modulo X^4 + BETA, which means powers >= 4 get shifted back 4
  // and multiplied by -BETA (NBETA).  We could write this as a double loop with some if's and hope
  // it gets unrolled properly, but it's small enough to just hand write.

  /// Full Fp4 * Fp4 product.
  constexpr Fp4 operator*(Fp4 rhs) const {
    // Short aliases for the two coefficient arrays, purely for readability.
    const auto& a = elems;
    const auto& b = rhs.elems;
    return Fp4(a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
               a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
               a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
               a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0]);
  }

  /// Multiply this element by rhs in place; returns a copy of the updated value.
  constexpr Fp4 operator*=(Fp4 rhs) {
    *this = *this * rhs;
    return *this;
  }

  /// Equality: true only when all four coefficients match.
  constexpr bool operator==(Fp4 rhs) const {
    for (uint32_t i = 0; i < 4; i++) {
      if (elems[i] != rhs.elems[i]) {
        return false;
      }
    }
    return true;
  }

  /// Inequality: the negation of operator==.
  constexpr bool operator!=(Fp4 rhs) const { return !(*this == rhs); }

  /// Returns the constant term, elems[0].
  constexpr Fp constPart() const { return elems[0]; }
};

/// Scalar product with the Fp factor on the left-hand side.  The Fp4 * Fp direction is a member
/// of Fp4; this free function supplies the mirrored operand order with the same result.
constexpr inline Fp4 operator*(Fp a, Fp4 b) {
  // b is already a private copy (taken by value), so scale it in place and hand it back.
  b *= a;
  return b;
}

/// Raise an Fp4 to the n-th power by binary square-and-multiply.
/// For n == 0 the result is Fp4(1).
constexpr inline Fp4 pow(Fp4 x, size_t n) {
  Fp4 acc(1);
  // Walk the exponent from its lowest bit upward; x holds the current repeated square.
  for (; n != 0; n >>= 1) {
    if ((n & 1) != 0) {
      acc *= x;
    }
    x *= x;
  }
  return acc;
}

/// Compute the multiplicative inverse of an Fp4.
///
/// The result is built from the Fp inverse of a single norm-like value c (see below).  According
/// to the note carried with this code, the Fp inverse maps 0 to 0, in which case this function
/// returns the zero element as well.
constexpr inline Fp4 inv(Fp4 in) {
  // Short alias for the input coefficients, purely for readability.
  auto& a = in.elems;
  // Compute the multiplicative inverse by basically looking at Fp4 as a composite field and using
  // the same basic method used to invert complex numbers.  We imagine that initially we have a
  // numerator of 1 and a denominator of a, i.e. out = 1 / a.  We set a' to be a with the
  // odd-power coefficients (those of X and X^3) negated.  We then multiply the numerator and the
  // denominator by a', producing out = a' / (a * a').  By construction (a * a') has 0's in its X
  // and X^3 coefficients.  We call this number 'b' = b0 + b2*X^2 and compute it as follows.
  Fp b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]);
  Fp b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]);
  // Now, we make b' by negating b2.  When we multiply both sides by b', we get
  // out = (a' * b') / (b * b').  But by construction b * b' is in fact an element of Fp, call it c.
  Fp c = b0 * b0 + BETA * b2 * b2;
  // But we can now invert c directly, and multiply by a'*b': out = a'*b'*inv(c)
  Fp ic = inv(c);
  // Note: if c == 0 (really should only happen if in == 0), our 'safe' version of inverse results
  // in ic == 0, and thus out = 0, so we have the same 'safe' behavior for Fp4.  Oh, and since we
  // want to multiply everything by ic, it's slightly faster to premultiply the two parts of b by ic
  // (2 multiplies instead of 4)
  b0 *= ic;
  b2 *= ic;
  // Expand a' * b' with X^4 = -BETA, one coefficient per line (X^0, X^1, X^2, X^3).
  return Fp4(a[0] * b0 + BETA * a[2] * b2,
             -a[1] * b0 + NBETA * a[3] * b2,
             -a[0] * b2 + a[2] * b0,
             a[1] * b2 - a[3] * b0);
}

// Remove the helper macros defined near the top of this file so they do not leak into code that
// includes this header.
#undef BETA
#undef NBETA

} // namespace risc0