1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
// Copyright 2022 Risc0, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
/// \file
/// Defines Fp4, a finite field F_p^4, based on Fp via the irreducible polynomial x^4 + 11.
#include "risc0/zkp/core/fp.h"
namespace risc0 {
// Defines instead of constexpr to appease CUDA's limitations around constants.
// Both macros are #undef'd at the end of this file.
// BETA: the constant 11 as an Fp element.
#define BETA Fp(11)
// NBETA: Fp::P - 11 as an Fp element; presumably -11 mod P, assuming Fp::P is the field
// modulus (confirm in fp.h).
#define NBETA Fp(Fp::P - 11)
/// Intstances of Fp4 are element of a finite field F_p^4. They are represented as elements of
/// F_p[X] / (X^4 - 11). Basically, this is a 'big' finite field (about 2^128 elements), which is
/// used when the security of various operations depends on the size of the field. It has the field
/// Fp as a subfield, which means operations by the two are compatable, which is important. The
/// irreducible polynomial was choosen to be the simpilest possible one, x^4 - B, where 11 is the
/// smallest B which makes the polynomial irreducable.
struct Fp4 {
/// The elements of Fp4, elems[0] + elems[1]*X + elems[2]*X^2 + elems[3]*x^4
Fp elems[4];
/// Default constructor makes the zero elements
DEVSPEC constexpr Fp4() {}
/// Initialize from uint32_t
DEVSPEC explicit constexpr Fp4(uint32_t x) {
elems[0] = x;
elems[1] = 0;
elems[2] = 0;
elems[3] = 0;
}
/// Convert from Fp to Fp4.
DEVSPEC explicit constexpr Fp4(Fp x) {
elems[0] = x;
elems[1] = 0;
elems[2] = 0;
elems[3] = 0;
}
/// Explicitly construct an Fp4 from parts
DEVSPEC constexpr Fp4(Fp a, Fp b, Fp c, Fp d) {
elems[0] = a;
elems[1] = b;
elems[2] = c;
elems[3] = d;
}
/// Generate a random field element uniformly
template <typename Rng> static Fp4 random(DEVADDR Rng& rng) {
return Fp4(Fp::random(rng), Fp::random(rng), Fp::random(rng), Fp::random(rng));
}
// Implement the addition/subtraction overloads
DEVSPEC constexpr Fp4 operator+=(Fp4 rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] += rhs.elems[i];
}
return *this;
}
DEVSPEC constexpr Fp4 operator-=(Fp4 rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] -= rhs.elems[i];
}
return *this;
}
DEVSPEC constexpr Fp4 operator+(Fp4 rhs) const {
Fp4 result = *this;
result += rhs;
return result;
}
DEVSPEC constexpr Fp4 operator-(Fp4 rhs) const {
Fp4 result = *this;
result -= rhs;
return result;
}
DEVSPEC constexpr Fp4 operator-() const { return Fp4() - *this; }
// Implement the simple multiplication case by the subfield Fp
// Fp * Fp4 is done as a free function due to C++'s operator overloading rules.
DEVSPEC constexpr Fp4 operator*=(Fp rhs) {
for (uint32_t i = 0; i < 4; i++) {
elems[i] *= rhs;
}
return *this;
}
DEVSPEC constexpr Fp4 operator*(Fp rhs) const {
Fp4 result = *this;
result *= rhs;
return result;
}
// Now we get to the interesting case of multiplication. Basically, multiply out the polynomial
// representations, and then reduce module x^4 - B, which means powers >= 4 get shifted back 4 and
// multiplied by -beta. We could write this as a double loops with some if's and hope it gets
// unrolled properly, but it'a small enough to just hand write.
DEVSPEC constexpr Fp4 operator*(Fp4 rhs) const {
// Rename the element arrays to something small for readability
#define a elems
#define b rhs.elems
return Fp4(a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0]);
#undef a
#undef b
}
DEVSPEC constexpr Fp4 operator*=(Fp4 rhs) {
*this = *this * rhs;
return *this;
}
// Equality
DEVSPEC constexpr bool operator==(Fp4 rhs) const {
for (uint32_t i = 0; i < 4; i++) {
if (elems[i] != rhs.elems[i]) {
return false;
}
}
return true;
}
DEVSPEC constexpr bool operator!=(Fp4 rhs) const { return !(*this == rhs); }
DEVSPEC constexpr Fp constPart() const { return elems[0]; }
#ifdef METAL
constexpr Fp4 operator+(Fp4 rhs) device const { return Fp4(*this) + rhs; }
constexpr Fp4 operator-(Fp4 rhs) device const { return Fp4(*this) - rhs; }
constexpr Fp4 operator-() device const { return -Fp4(*this); }
constexpr Fp4 operator*(Fp4 rhs) device const { return Fp4(*this) * rhs; }
constexpr Fp4 operator*(Fp rhs) device const { return Fp4(*this) * rhs; }
constexpr bool operator==(Fp4 rhs) device const { return Fp4(*this) == rhs; }
constexpr bool operator!=(Fp4 rhs) device const { return Fp4(*this) != rhs; }
constexpr Fp constPart() device const { return Fp4(*this).constPart(); }
#endif
};
/// Scalar-on-the-left product Fp * Fp4. The Fp4 * Fp direction is a member of Fp4; this free
/// function covers the case where the Fp operand comes first.
DEVSPEC constexpr inline Fp4 operator*(Fp a, Fp4 b) {
  // b is already a private copy (taken by value), so scale it in place and hand it back.
  b *= a;
  return b;
}
/// ostream support for Fp4 values: prints the four coefficients as a polynomial in x,
/// lowest degree first.
#ifdef CPU
inline std::ostream& operator<<(std::ostream& os, const Fp4& x) {
  os << x.elems[0] << "+";
  os << x.elems[1] << "x+";
  os << x.elems[2] << "x^2+";
  os << x.elems[3] << "x^3";
  return os;
}
#endif
/// Raise an Fp4 to the power n by repeated squaring: the bits of n are consumed from least
/// significant to most significant, multiplying the accumulator by the current square whenever
/// the bit is set. pow(x, 0) yields Fp4(1).
DEVSPEC constexpr inline Fp4 pow(Fp4 x, size_t n) {
  Fp4 acc(1);
  for (; n != 0; n /= 2) {
    if ((n & 1) != 0) {
      acc *= x;
    }
    // Advance x to the square that corresponds to the next bit of n.
    x *= x;
  }
  return acc;
}
/// Compute the multiplicative inverse of an Fp4.
DEVSPEC constexpr inline Fp4 inv(Fp4 in) {
  // Pull the four coefficients of the input (call it a) into short local names.
  Fp a0 = in.elems[0];
  Fp a1 = in.elems[1];
  Fp a2 = in.elems[2];
  Fp a3 = in.elems[3];

  // The approach treats Fp4 as a tower of quadratic extensions and mirrors the way complex
  // numbers are inverted. Start from out = 1 / a. Let a' be a with its X and X^3 coefficients
  // negated, and multiply numerator and denominator by a': out = a' / (a * a'). By construction
  // the X and X^3 coefficients of (a * a') vanish, so the product is n0 + n2 * X^2, with:
  Fp n0 = a0 * a0 + BETA * (a1 * (a3 + a3) - a2 * a2);
  Fp n2 = a0 * (a2 + a2) - a1 * a1 + BETA * (a3 * a3);

  // Repeat the trick one level down: with n' = n0 - n2 * X^2 (n2 negated), the product n * n'
  // collapses to a plain element of Fp, so out = (a' * n') / (n * n').
  Fp norm = n0 * n0 + BETA * n2 * n2;

  // That Fp element can be inverted directly, giving out = a' * n' * inv(norm).
  // NOTE(review): the original comment states that the Fp inverse is 'safe' and returns 0 for an
  // input of 0 (expected only when in == 0), which makes this function return 0 as well --
  // confirm against the Fp implementation of inv.
  Fp normInv = inv(norm);

  // Fold the inverse into the two parts of n' up front: 2 multiplies here instead of scaling all
  // 4 output coefficients afterwards.
  n0 *= normInv;
  n2 *= normInv;

  // Coefficients of a' * n', written out by hand.
  Fp r0 = a0 * n0 + BETA * a2 * n2;
  Fp r1 = -a1 * n0 + NBETA * a3 * n2;
  Fp r2 = -a0 * n2 + a2 * n0;
  Fp r3 = a1 * n2 - a3 * n0;
  return Fp4(r0, r1, r2, r3);
}
// Drop the helper macros defined near the top of this file so they do not leak into files that
// include this header.
#undef BETA
#undef NBETA
} // namespace risc0