risc0_core/field/
mod.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
//! Finite field types and operations
//!
//! Defines base fields and extension fields used for finite field-based
//! operations across the RISC Zero zkVM architecture.
19
20use alloc::vec::Vec;
21use core::{cmp, fmt::Debug, ops};
22
23pub mod baby_bear;
24
25/// A pair of fields, one of which is an extension field of the other.
26pub trait Field {
27    /// An element of the base field
28    type Elem: Elem + RootsOfUnity;
29    /// An element of the extension field
30    type ExtElem: ExtElem<SubElem = Self::Elem>;
31}
32
33/// Subfield elements that can be compared, copied, and operated
34/// on via multiplication, addition, and subtraction
35pub trait Elem: 'static
36    + Clone
37    + Copy
38    + Send
39    + Sync
40    + Debug
41    + Sized
42    + ops::Neg<Output = Self>
43    + ops::SubAssign
44    + cmp::PartialEq
45    + cmp::Eq
46    + core::clone::Clone
47    + core::marker::Copy
48    + bytemuck::NoUninit
49    + bytemuck::CheckedBitPattern
50    + core::default::Default
51    // Operators for Elem (op) Elem -> Elem
52    + ops::Add<Self, Output = Self>
53    + ops::Sub<Self, Output = Self>
54    + ops::Mul<Self, Output = Self>
55    // Operators for Elem (op)= Elem
56    + ops::AddAssign<Self>
57    + ops::SubAssign<Self>
58    + ops::MulAssign<Self>
59{
60    /// Invalid, a value that is not a member of the field.  This
61    /// should only be used with the "is_valid" or "unwrap_or_zero"
62    /// methods.
63    const INVALID: Self;
64
65    /// Zero, the additive identity.
66    const ZERO: Self;
67
68    /// One, the multiplicative identity.
69    const ONE: Self;
70
71    /// How many u32 words are required to hold a single element
72    const WORDS: usize;
73
74    /// Compute the multiplicative inverse of `x` (or `1 / x` in finite field
75    /// terms).
76    fn inv(self) -> Self;
77
78    /// Return an element raised to the given power.
79    fn pow(self, exp: usize) -> Self {
80        debug_assert!(self.is_valid());
81        let mut n = exp;
82        let mut tot = Self::ONE;
83        let mut x = self;
84        while n != 0 {
85            if n % 2 == 1 {
86                tot *= x;
87            }
88            n /= 2;
89            x *= x;
90        }
91        tot
92    }
93
94    /// Returns a random valid field element.
95    fn random(rng: &mut impl rand_core::RngCore) -> Self;
96
97    /// Import a number into the field from the natural numbers.
98    fn from_u64(val: u64) -> Self;
99
100    /// Represent a field element as a sequence of u32s
101    fn to_u32_words(&self) -> Vec<u32>;
102
103    /// Interpret a sequence of u32s as a field element
104    fn from_u32_words(val: &[u32]) -> Self;
105
106    /// Returns true if this element is not INVALID.  Unlike most
107    /// methods, this may be called on an INVALID element.
108    fn is_valid(&self) -> bool;
109
110    /// Returns true if this element is represented in reduced/normalized form.
111    /// Every element has exactly one reduced form. For a field of prime order
112    /// P, this typically means the underlying data is < P, and for an extension
113    /// field, this typically means every component is in reduced form.
114    fn is_reduced(&self) -> bool;
115
116    /// Returns 0 if this element is INVALID, else the value of this
117    /// element.  Unlike most methods, this may be called on an
118    /// INVALID element.
119    fn valid_or_zero(&self) -> Self {
120        if self.is_valid() {
121            *self
122        } else {
123            Self::ZERO
124        }
125    }
126
127    /// Returns this element, but checks to make sure it's valid.
128    fn ensure_valid(&self) -> &Self {
129        debug_assert!(self.is_valid());
130        self
131    }
132
133    /// Returns this element, but checks to make sure it's in reduced form.
134    fn ensure_reduced(&self) -> &Self {
135        assert!(self.is_reduced());
136        self
137    }
138
139    /// Interprets a slice of these elements as u32s.  These elements
140    /// may not be INVALID.
141    fn as_u32_slice(elems: &[Self]) -> &[u32] {
142        if cfg!(debug_assertions) {
143            for elem in elems {
144                elem.ensure_valid();
145            }
146        }
147        Self::as_u32_slice_unchecked(elems)
148    }
149
150    /// Interprets a slice of these elements as u32s.  These elements
151    /// may potentially be INVALID.
152    fn as_u32_slice_unchecked(elems: &[Self]) -> &[u32] {
153        bytemuck::cast_slice(elems)
154    }
155
156    /// Interprets a slice of u32s as a slice of these elements.
157    /// These elements may not be INVALID.
158    fn from_u32_slice(u32s: &[u32]) -> &[Self] {
159        bytemuck::checked::cast_slice(u32s)
160    }
161}
162
163/// A field extension which can be constructed from a subfield element [Elem]
164///
165/// Represents an element of an extension field. This extension field is
166/// associated with a base field (sometimes called "subfield") whose element
167/// type is given by the generic type parameter.
168pub trait ExtElem : Elem
169    + From<Self::SubElem>
170    + ops::Neg<Output = Self>
171    + cmp::PartialEq
172    + cmp::Eq
173
174    // Operators for ExtElem (op) Elem -> ExtElem
175    + ops::Add<Self::SubElem, Output = Self>
176    + ops::Sub<Self::SubElem, Output = Self>
177    + ops::Mul<Self::SubElem, Output = Self>
178
179    // Operators for ExtElem (op)= Elem
180    + ops::AddAssign<Self::SubElem>
181    + ops::SubAssign<Self::SubElem>
182    + ops::MulAssign<Self::SubElem>
183{
184    /// An element of the base field
185    ///
186    /// This type represents an element of the base field (sometimes called
187    /// "subfield") of this extension field.
188    type SubElem: Elem
189        // Operators for Elem (op) ExtElem -> ExtElem
190        + ops::Add<Self, Output = Self>
191        + ops::Sub<Self, Output = Self>
192        + ops::Mul<Self, Output = Self>;
193
194    /// The degree of the field extension
195    ///
196    /// This the degree of the extension field when interpreted as a vector
197    /// space over the base field. Thus, an [ExtElem] can be represented as
198    /// `EXT_SIZE` [SubElem](ExtElem::SubElem)s.
199    const EXT_SIZE: usize;
200
201    /// Interpret a base field element as an extension field element
202    ///
203    /// Every [SubElem](ExtElem::SubElem) is (mathematically) an [ExtElem]. This
204    /// constructs the [ExtElem] equal to the given [SubElem](ExtElem::SubElem).
205    fn from_subfield(elem: &Self::SubElem) -> Self;
206
207    /// Construct an extension field element
208    ///
209    /// Construct an extension field element from a (mathematical) vector of
210    /// [SubElem](ExtElem::SubElem)s. This vector is length
211    /// [EXT_SIZE](ExtElem::EXT_SIZE).
212    fn from_subelems(elems: impl IntoIterator<Item = Self::SubElem>) -> Self;
213
214    /// Express an extension field element in terms of base field elements
215    ///
216    /// Returns the (mathematical) vector of [SubElem](ExtElem::SubElem)s equal
217    /// to the [ExtElem]. This vector is length [EXT_SIZE](ExtElem::EXT_SIZE).
218    fn subelems(&self) -> &[Self::SubElem];
219}
220
/// Power-of-two roots of unity of a field.
///
/// [Field] requires this of its base-field element type ([Field::Elem]), so
/// the tables below hold elements of the implementing type.
pub trait RootsOfUnity: Sized + 'static {
    /// The largest `k` for which a primitive `2^k`-th root of unity exists;
    /// that is, there is a primitive 2^MAX_ROU_PO2-th root of unity but no
    /// primitive 2^(MAX_ROU_PO2+1)-th root.
    const MAX_ROU_PO2: usize;

    /// For each power of two, the 'forward' root of unity for that power,
    /// indexed by the exponent `i` of `2^i`. The list satisfies
    /// `ROU_FWD[i+1]^2 == ROU_FWD[i]`, which implies `ROU_FWD[i]^(2^i) == 1`.
    const ROU_FWD: &'static [Self];

    /// For each power of two, the 'reverse' root of unity: the multiplicative
    /// inverse of the matching forward root, so that
    /// `ROU_FWD[i] * ROU_REV[i] == 1`.
    const ROU_REV: &'static [Self];
}
239
240/// Equivalent to exponents.map(|exponent|
241/// base.pow(exponent)).collect(), but optimized to execute fewer
242/// multiplies.  Exponents must be sorted and strictly increasing.
243pub fn map_pow<E: super::field::Elem>(base: E, exponents: &[usize]) -> Vec<E> {
244    let mut result = Vec::with_capacity(exponents.len());
245
246    let mut prev_exp: usize;
247    match exponents.first() {
248        None => return result,
249        Some(&exp) => {
250            result.push(base.pow(exp));
251            prev_exp = exp;
252        }
253    }
254
255    for exp in exponents.iter().skip(1).copied() {
256        assert!(
257            prev_exp < exp,
258            "Expecting exponents to be strictly increasing but {prev_exp} is not less than {exp}"
259        );
260        if exp == prev_exp + 1 {
261            result.push(*result.last().unwrap() * base);
262        } else {
263            result.push(*result.last().unwrap() * base.pow(exp - prev_exp));
264        }
265        prev_exp = exp;
266    }
267
268    result
269}
270
#[cfg(test)]
mod tests {
    use core::fmt::Debug;

    use rand::Rng;

    use super::{Elem, ExtElem, RootsOfUnity};

    /// Checks that the [RootsOfUnity] tables of `F` are self-consistent.
    ///
    /// Asserts that:
    /// - `ROU_FWD` has `MAX_ROU_PO2 + 1` entries and `ROU_REV` has the same
    ///   length (the pairwise check below uses `zip`, which would otherwise
    ///   silently skip an unmatched tail);
    /// - `ROU_FWD[0]` is one and squaring `ROU_FWD[i + 1]` gives `ROU_FWD[i]`;
    /// - `ROU_FWD[1]`, when present, is not one, so each `ROU_FWD[i]` with
    ///   `i >= 1` is a primitive `2^i`-th root of unity;
    /// - `ROU_FWD[i] * ROU_REV[i]` is one for every `i`.
    pub fn test_roots_of_unity<F: Elem + RootsOfUnity + Debug>() {
        assert_eq!(F::ROU_FWD.len(), F::MAX_ROU_PO2 + 1);
        assert_eq!(F::ROU_REV.len(), F::ROU_FWD.len());

        assert_eq!(F::ROU_FWD.first(), Some(&F::ONE));
        for pair in F::ROU_FWD.windows(2) {
            assert_eq!(pair[1] * pair[1], pair[0]);
        }

        // ROU_FWD[1] squares to one, so in a field it is +1 or -1. Requiring
        // it to differ from +1 forces ROU_FWD[i]^(2^(i-1)) == -1, i.e. every
        // forward root is primitive rather than merely a root of unity.
        if let Some(&half_turn) = F::ROU_FWD.get(1) {
            assert_ne!(half_turn, F::ONE);
        }

        for (&fwd, &rev) in F::ROU_FWD.iter().zip(F::ROU_REV.iter()) {
            assert_eq!(fwd * rev, F::ONE);
        }
    }

    /// Returns a random element of `F` that is not [Elem::ZERO].
    fn non_zero_rand<F: Elem>(r: &mut impl Rng) -> F {
        loop {
            let val = F::random(r);
            if val != F::ZERO {
                return val;
            }
        }
    }

    /// Checks `F`'s arithmetic against plain modular arithmetic modulo
    /// `p_u64`.
    ///
    /// Covers the constants, `inv` (including the convention that
    /// `ZERO.inv() == ZERO`), 1000 random samples of `+`, `*`, `-` and `inv`,
    /// and agreement between [super::map_pow] and [Elem::pow].
    pub fn test_field_ops<F>(p_u64: u64)
    where
        F: Elem + Into<u64> + From<u64> + Debug,
    {
        // For testing, we do 128-bit arithmetic so we don't have to worry about
        // overflows.
        let p: u128 = p_u64 as _;

        assert_eq!(F::from(0), F::ZERO);
        assert_eq!(F::from(p_u64), F::ZERO);
        assert_eq!(F::from(1), F::ONE);
        assert_eq!(F::from(p_u64 - 1) + F::from(1), F::ZERO);

        assert_eq!(F::ZERO.inv(), F::ZERO);
        assert_eq!(F::ONE.inv(), F::ONE);

        // Compare against many randomly generated numbers to make sure results match
        // the expected results for regular modular arithmetic.
        let mut rng = rand::rng();
        let minus_one = -F::ONE;

        for _ in 0..1000 {
            let x: F = non_zero_rand(&mut rng);
            let y: F = non_zero_rand(&mut rng);

            let xi: u128 = x.into() as _;
            let yi: u128 = y.into() as _;

            assert_eq!((x + y).into() as u128, (xi + yi) % p);
            assert_eq!((x * y).into() as u128, (xi * yi) % p);
            assert_eq!((x - y).into() as u128, (xi + p - yi) % p);

            let xinv = x.inv();
            // In a field the only self-inverse elements are 1 and -1, so every
            // other nonzero x must have an inverse distinct from itself.
            if x != F::ONE && x != minus_one {
                assert!(xinv != x);
            }
            assert_eq!(xinv * x, F::ONE);
        }

        // Make sure map_pow produces the same results as calling F::pow.
        let base: F = non_zero_rand(&mut rng);
        let map_pow_cases: &[&[usize]] = &[&[], &[0], &[0, 1, 2, 3], &[1, 18, 19, 1234, 5678]];
        for exps in map_pow_cases {
            let expected: alloc::vec::Vec<_> = exps.iter().map(|&exp| base.pow(exp)).collect();
            let actual = super::map_pow(base, exps);
            assert_eq!(expected, actual);
        }
    }

    /// Make sure extension field operations are consistent, no matter
    /// what order they use, and whether they promote from base
    /// elements or not.
    pub fn test_ext_field_ops<E: ExtElem>() {
        let mut r = rand::rng();
        let x = E::random(&mut r);
        // Zero has no multiplicative inverse, so the division round-trips
        // below only hold for nonzero `y` and `b`.
        let y: E = non_zero_rand(&mut r);

        let mut e = x;

        // Promote E to the extended field.
        let promote = |e| E::from(e);

        // Addition and subtraction operations between two ExtElems
        e += y;
        assert_eq!(e, x + y);
        assert_eq!(e, y + x);
        assert_eq!(x, e - y);
        assert_eq!(-x, y - e);
        e -= y;
        assert_eq!(e, x);

        // Multiplication and inverse operations between two ExtElems
        e *= y;
        assert_eq!(e, x * y);
        assert_eq!(e, y * x);
        assert_eq!(x, e * y.inv());
        assert_eq!(x.inv(), y * e.inv());
        e *= y.inv();
        assert_eq!(e, x);

        // Addition and subtraction between an ExtElem and a base element
        let b: E::SubElem = non_zero_rand(&mut r);
        e += b;
        assert_eq!(e, x + b);
        assert_eq!(e, b + x);
        assert_eq!(e, x + promote(b));
        assert_eq!(x, e - promote(b));
        assert_eq!(x, e - b);
        assert_eq!(-x, b - e);
        assert_eq!(-x, promote(b) - e);
        e -= b;
        assert_eq!(e, x);

        // Multiplication and inverse operations between an ExtElem and a base element
        e *= b;
        assert_eq!(e, x * b);
        assert_eq!(e, b * x);
        assert_eq!(e, x * promote(b));
        assert_eq!(x, e * b.inv());
        assert_eq!(x, b.inv() * e);
        assert_eq!(x, e * promote(b.inv()));
        assert_eq!(x, e * promote(b).inv());
        assert_eq!(x.inv(), b * e.inv());
        e *= b.inv();
        assert_eq!(e, x);
    }
}