risc0_core/field/baby_bear.rs

1// Copyright 2025 RISC Zero, Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Baby bear field.
16//!
17//! Support for the finite field of order `15 * 2^27 + 1`, and its degree 4
18//! extension field. This field choice allows for 32-bit addition without
19//! overflow.
20
21use alloc::{fmt, vec::Vec};
22use core::{
23    cmp::{Ordering, PartialEq},
24    ops,
25};
26
27use bytemuck::{CheckedBitPattern, NoUninit, Zeroable};
28
29use crate::field::{self, Elem as FieldElem};
30
31/// Definition of this field for operations that operate on the baby
32/// bear field and its 4th degree extension.
33pub struct BabyBear;
34
35impl field::Field for BabyBear {
36    type Elem = Elem;
37    type ExtElem = ExtElem;
38}
39
// Montgomery form constants (the Montgomery radix is R = 2^32).

/// `P^-1 mod 2^32`: `mul` multiplies by this to pick the reduction factor.
const M: u32 = 0x8800_0001;
/// `R^2 mod P`, i.e. `2^64 mod P`; `encode` multiplies by this to enter
/// Montgomery form.
const R2: u32 = 1_172_168_163;
43
44/// The BabyBear class is an element of the finite field F_p, where P is the
45/// prime number 15*2^27 + 1. Put another way, Fp is basically integer
46/// arithmetic modulo P.
47///
48/// The `Fp` datatype is the core type of all of the operations done within the
49/// zero knowledge proofs, and is the smallest 'addressable' datatype, and the
50/// base type of which all composite types are built. In many ways, one can
51/// imagine it as the word size of a very strange architecture.
52///
53/// This specific prime P was chosen to:
54/// - Be less than 2^31 so that it fits within a 32 bit word and doesn't
55///   overflow on addition.
56/// - Otherwise have as large a power of 2 in the factors of P-1 as possible.
57///
58/// This last property is useful for number theoretical transforms (the fast
59/// fourier transform equivalent on finite fields). See risc0_zkp::core::ntt
60/// for details.
61///
62/// The Fp class wraps all the standard arithmetic operations to make the finite
63/// field elements look basically like ordinary numbers (which they mostly are).
64#[derive(Eq, Clone, Copy, NoUninit, Zeroable)]
65#[repr(transparent)]
66pub struct Elem(u32);
67
68/// Alias for the Baby Bear [Elem]
69pub type BabyBearElem = Elem;
70
71impl Default for Elem {
72    fn default() -> Self {
73        Self::ZERO
74    }
75}
76
77impl fmt::Debug for Elem {
78    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
79        write!(f, "0x{:08x}", decode(self.0))
80    }
81}
82
/// The modulus of the field: `15 * 2^27 + 1` (= 2013265921).
pub const P: u32 = 0x7800_0001;

/// [`P`] widened to `u64`, for reducing 64-bit intermediates.
const P_U64: u64 = P as u64;

/// Number of `u32` words needed to store one field element.
const WORDS: usize = 1;
91
92impl field::Elem for Elem {
93    const INVALID: Self = Elem(0xffffffff);
94    const ZERO: Self = Elem::new(0);
95    const ONE: Self = Elem::new(1);
96    const WORDS: usize = WORDS;
97
98    /// Compute the multiplicative inverse of `x`, or `1 / x` in finite field
99    /// terms. Since we know by Fermat's Little Theorem that
100    /// `x ^ (P - 1) == 1 % P` for any `x != 0`,
101    /// it follows that `x * x ^ (P - 2) == 1 % P` for `x != 0`.
102    /// That is, `x ^ (P - 2)` is the multiplicative inverse of `x`.
103    /// Note that if computed this way, the *inverse* of zero comes out as zero,
104    /// which we allow because it is convenient in many cases.
105    fn inv(self) -> Self {
106        self.ensure_valid().pow((P - 2) as usize)
107    }
108
109    /// Generate a random value within the Baby Bear field
110    fn random(rng: &mut impl rand_core::RngCore) -> Self {
111        // Normally, we would use rejection sampling here, but our specialized
112        // verifier circuit really wants an O(1) solution to sampling.  So instead, we
113        // sample [0, 2^192) % P.  This is very close to uniform, as we have 2^192 / P
114        // full copies of P, with only 2^192%P left over elements in the 'partial' copy
115        // (which we would normally reject with rejection sampling).
116        //
117        // Even if we imagined that this failure to reject totally destroys soundness,
118        // the probability of it occurring even once during proving is vanishingly low
119        // (for the about 50 samples our current verifier pulls and at a probability of
120        // less than2^-161 per sample, this is less than 2^-155).  Even if we target
121        // a soundness of 128 bits, we are millions of times more likely to let an
122        // invalid proof by due to normal low probability events which are part of
123        // soundness analysis than due to imperfect sampling.
124        //
125        // Finally, from an implementation perspective, we can view generating a number
126        // in the [0, 2^192) range as using a linear combination of uniform u32s, r0,
127        // r1, etc and the following formula:
128        // u192 = r0 + 2^32 * r1 + 2^64 * r2 + ... + 2^160 * r5
129        // This is turn can be computed as:
130        // u192 = 2^32*(2^32*(2^32*(2^32*(2^32*(r5) + r4) + r3) + r2) + r1) + r0.
131        // Since we only need the final result modulo P, we can compute the entire
132        // expression above modulo P, and get the following implementation:
133        let mut val: u64 = 0;
134        for _ in 0..6 {
135            val <<= 32;
136            val += rng.next_u32() as u64;
137            val %= P as u64;
138        }
139        Elem::from(val as u32)
140    }
141
142    fn from_u64(val: u64) -> Self {
143        Elem::from(val)
144    }
145
146    fn to_u32_words(&self) -> Vec<u32> {
147        Vec::<u32>::from([self.0])
148    }
149
150    fn from_u32_words(val: &[u32]) -> Self {
151        Self(val[0])
152    }
153
154    fn is_valid(&self) -> bool {
155        self.0 != Self::INVALID.0
156    }
157
158    fn is_reduced(&self) -> bool {
159        self.0 < P
160    }
161}
162
163unsafe impl CheckedBitPattern for Elem {
164    type Bits = u32;
165
166    /// Checks that the u32 is less than the modulus.
167    fn is_valid_bit_pattern(bits: &u32) -> bool {
168        *bits < P
169    }
170}
171
/// Expands a comma-separated list of integer literals into an array of
/// [`Elem`]s, converting each one with `Elem::new`.
macro_rules! rou_array {
    ($($x:literal),* $(,)?) => {
        [$(Elem::new($x)),*]
    };
}
177
178impl field::RootsOfUnity for Elem {
179    /// Maximum power of two for which we have a root of unity using Baby Bear
180    /// field
181    const MAX_ROU_PO2: usize = 27;
182
183    /// 'Forward' root of unity for each power of two.
184    const ROU_FWD: &'static [Elem] = &rou_array![
185        1, 2013265920, 284861408, 1801542727, 567209306, 740045640, 918899846, 1881002012,
186        1453957774, 65325759, 1538055801, 515192888, 483885487, 157393079, 1695124103, 2005211659,
187        1540072241, 88064245, 1542985445, 1269900459, 1461624142, 825701067, 682402162, 1311873874,
188        1164520853, 352275361, 18769, 137
189    ];
190
191    /// 'Reverse' root of unity for each power of two.
192    const ROU_REV: &'static [Elem] = &rou_array![
193        1, 2013265920, 1728404513, 1592366214, 196396260, 1253260071, 72041623, 1091445674,
194        145223211, 1446820157, 1030796471, 2010749425, 1827366325, 1239938613, 246299276,
195        596347512, 1893145354, 246074437, 1525739923, 1194341128, 1463599021, 704606912, 95395244,
196        15672543, 647517488, 584175179, 137728885, 749463956
197    ];
198}
199
200impl Elem {
201    /// Create a new [BabyBear] from a raw integer.
202    pub const fn new(x: u32) -> Self {
203        Self(encode(x % P))
204    }
205
206    /// Create a new [BabyBear] from a Montgomery form representation
207    ///
208    /// Requires that `x` comes pre-encoded in Montgomery form.
209    pub const fn new_raw(x: u32) -> Self {
210        Self(x)
211    }
212
213    /// Cast a [BabyBear] to an integer
214    pub const fn as_u32(&self) -> u32 {
215        decode(self.0)
216    }
217
218    /// Return the Montgomery form representation used for byte-based
219    /// hashes of slices of [BabyBear]s.
220    pub const fn as_u32_montgomery(&self) -> u32 {
221        self.0
222    }
223}
224
225impl ops::Add for Elem {
226    type Output = Self;
227
228    /// Addition for Baby Bear [Elem]
229    fn add(self, rhs: Self) -> Self {
230        Elem(add(self.ensure_valid().0, rhs.ensure_valid().0))
231    }
232}
233
234impl ops::AddAssign for Elem {
235    /// Simple addition case for Baby Bear [Elem]
236    fn add_assign(&mut self, rhs: Self) {
237        self.0 = add(self.ensure_valid().0, rhs.ensure_valid().0)
238    }
239}
240
241impl ops::Sub for Elem {
242    type Output = Self;
243
244    /// Subtraction for Baby Bear [Elem]
245    fn sub(self, rhs: Self) -> Self {
246        Elem(sub(self.ensure_valid().0, rhs.ensure_valid().0))
247    }
248}
249
250impl ops::SubAssign for Elem {
251    /// Simple subtraction case for Baby Bear [Elem]
252    fn sub_assign(&mut self, rhs: Self) {
253        self.0 = sub(self.ensure_valid().0, rhs.ensure_valid().0)
254    }
255}
256
257impl ops::Mul for Elem {
258    type Output = Self;
259
260    /// Multiplication for Baby Bear [Elem]
261    fn mul(self, rhs: Self) -> Self {
262        Elem(mul(self.ensure_valid().0, rhs.ensure_valid().0))
263    }
264}
265
266impl ops::MulAssign for Elem {
267    /// Simple multiplication case for Baby Bear [Elem]
268    fn mul_assign(&mut self, rhs: Self) {
269        self.0 = mul(self.ensure_valid().0, rhs.ensure_valid().0)
270    }
271}
272
273impl ops::Neg for Elem {
274    type Output = Self;
275
276    fn neg(self) -> Self {
277        Elem(0) - *self.ensure_valid()
278    }
279}
280
281impl PartialEq<Elem> for Elem {
282    fn eq(&self, rhs: &Self) -> bool {
283        self.ensure_valid().0 == rhs.ensure_valid().0
284    }
285}
286
287impl Ord for Elem {
288    fn cmp(&self, rhs: &Self) -> Ordering {
289        decode(self.ensure_valid().0).cmp(&decode(rhs.ensure_valid().0))
290    }
291}
292
293impl PartialOrd for Elem {
294    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
295        Some(self.cmp(rhs))
296    }
297}
298
299impl From<Elem> for u32 {
300    fn from(x: Elem) -> Self {
301        decode(x.0)
302    }
303}
304
305impl From<Elem> for u64 {
306    fn from(x: Elem) -> Self {
307        decode(x.0).into()
308    }
309}
310
311impl From<u32> for Elem {
312    fn from(x: u32) -> Self {
313        Elem::new(x)
314    }
315}
316
317impl From<u64> for Elem {
318    fn from(x: u64) -> Self {
319        Elem::new((x % P_U64) as u32)
320    }
321}
322
323/// Wrapping addition of [Elem] using Baby Bear field modulus
324fn add(lhs: u32, rhs: u32) -> u32 {
325    let x = lhs.wrapping_add(rhs);
326    if x >= P {
327        x - P
328    } else {
329        x
330    }
331}
332
333/// Wrapping subtraction of [Elem] using Baby Bear field modulus
334fn sub(lhs: u32, rhs: u32) -> u32 {
335    let x = lhs.wrapping_sub(rhs);
336    if x > P {
337        x.wrapping_add(P)
338    } else {
339        x
340    }
341}
342
343/// Wrapping multiplication of [Elem]  using Baby Bear field modulus
344// Copied from the C++ implementation (fp.h)
345const fn mul(lhs: u32, rhs: u32) -> u32 {
346    // uint64_t o64 = uint64_t(a) * uint64_t(b);
347    let mut o64: u64 = (lhs as u64).wrapping_mul(rhs as u64);
348    // uint32_t low = -uint32_t(o64);
349    let low: u32 = 0u32.wrapping_sub(o64 as u32);
350    // uint32_t red = M * low;
351    let red = M.wrapping_mul(low);
352    // o64 += uint64_t(red) * uint64_t(P);
353    o64 += (red as u64).wrapping_mul(P_U64);
354    // uint32_t ret = o64 >> 32;
355    let ret = (o64 >> 32) as u32;
356    // return (ret >= P ? ret - P : ret);
357    if ret >= P {
358        ret - P
359    } else {
360        ret
361    }
362}
363
364/// Encode to Montgomery form from direct form.
365const fn encode(a: u32) -> u32 {
366    mul(R2, a)
367}
368
369/// Decode from Montgomery form to direct form.
370const fn decode(a: u32) -> u32 {
371    mul(1, a)
372}
373
374/// The size of the extension field in elements, 4 in this case.
375const EXT_SIZE: usize = 4;
376
377/// Instances of `ExtElem` are elements of a finite field `F_p^4`. They are
378/// represented as elements of `F_p[X] / (X^4 + 11)`. This large
379/// finite field (about `2^128` elements) is used when the security of
380/// operations depends on the size of the field. The field extension `ExtElem`
381/// has `Elem` as a subfield, so operations on elements of each are compatible.
382/// The irreducible polynomial `x^4 + 11` was chosen because `11` is
383/// the simplest choice of `BETA` for `x^4 + BETA` that makes this polynomial
384/// irreducible.
385#[derive(Eq, Clone, Copy, Zeroable)]
386#[repr(transparent)]
387pub struct ExtElem([Elem; EXT_SIZE]);
388
389// ExtElem has no padding bytes as Elem has none and is 32 bits in width.
390// See bytemuck docs for a full list of requirements.
391// https://docs.rs/bytemuck/latest/bytemuck/trait.NoUninit.html#safety
392unsafe impl NoUninit for ExtElem {}
393
394unsafe impl CheckedBitPattern for ExtElem {
395    type Bits = [u32; EXT_SIZE];
396
397    /// Checks that the u32 array elements are all less than the modulus.
398    fn is_valid_bit_pattern(bits: &[u32; EXT_SIZE]) -> bool {
399        let mut valid = true;
400        for x in bits {
401            valid &= *x < P;
402        }
403        valid
404    }
405}
406
407/// Alias for the Baby Bear [ExtElem]
408pub type BabyBearExtElem = ExtElem;
409
410impl Default for ExtElem {
411    fn default() -> Self {
412        Self::ZERO
413    }
414}
415
416impl fmt::Debug for ExtElem {
417    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
418        write!(
419            f,
420            "[{:?}, {:?}, {:?}, {:?}]",
421            self.0[0], self.0[1], self.0[2], self.0[3]
422        )
423    }
424}
425
426impl field::Elem for ExtElem {
427    const INVALID: Self = ExtElem([Elem::INVALID, Elem::INVALID, Elem::INVALID, Elem::INVALID]);
428    const ZERO: Self = ExtElem::zero();
429    const ONE: Self = ExtElem::one();
430    const WORDS: usize = WORDS * EXT_SIZE;
431
432    /// Generate a random field element uniformly.
433    fn random(rng: &mut impl rand_core::RngCore) -> Self {
434        // NOTE: It's possible this could be made more efficient since each field element uses 192
435        // random bits while the total entropy needed for a negligibly biased ExtElem is 288.
436        Self([
437            Elem::random(rng),
438            Elem::random(rng),
439            Elem::random(rng),
440            Elem::random(rng),
441        ])
442    }
443
444    /// Raise a [ExtElem] to a power of `n`.
445    fn pow(self, n: usize) -> Self {
446        let mut n = n;
447        let mut tot = ExtElem::ONE;
448        let mut x = *self.ensure_valid();
449        while n != 0 {
450            if n % 2 == 1 {
451                tot *= x;
452            }
453            n /= 2;
454            x *= x;
455        }
456        tot
457    }
458
459    /// Compute the multiplicative inverse of an `ExtElem`.
460    fn inv(self) -> Self {
461        let a = &self.ensure_valid().0;
462        // Compute the multiplicative inverse by looking at `ExtElem` as a composite
463        // field and using the same basic methods used to invert complex
464        // numbers. We imagine that initially we have a numerator of `1`, and a
465        // denominator of `a`. `out = 1 / a`; We set `a'` to be a with the first
466        // and third components negated. We then multiply the numerator and the
467        // denominator by `a'`, producing `out = a' / (a * a')`. By construction
468        // `(a * a')` has `0`s in its first and third elements. We call this
469        // number, `b` and compute it as follows.
470        let mut b0 = a[0] * a[0] + BETA * (a[1] * (a[3] + a[3]) - a[2] * a[2]);
471        let mut b2 = a[0] * (a[2] + a[2]) - a[1] * a[1] + BETA * (a[3] * a[3]);
472        // Now, we make `b'` by inverting `b2`. When we multiply both sizes by `b'`, we
473        // get `out = (a' * b') / (b * b')`.  But by construction `b * b'` is in
474        // fact an element of `Elem`, call it `c`.
475        let c = b0 * b0 + BETA * b2 * b2;
476        // But we can now invert `C` directly, and multiply by `a' * b'`:
477        // `out = a' * b' * inv(c)`
478        let ic = c.inv();
479        // Note: if c == 0 (really should only happen if in == 0), our
480        // 'safe' version of inverse results in ic == 0, and thus out
481        // = 0, so we have the same 'safe' behavior for ExtElem.  Oh,
482        // and since we want to multiply everything by ic, it's
483        // slightly faster to pre-multiply the two parts of b by ic (2
484        // multiplies instead of 4).
485        b0 *= ic;
486        b2 *= ic;
487        ExtElem([
488            a[0] * b0 + BETA * a[2] * b2,
489            -a[1] * b0 + NBETA * a[3] * b2,
490            -a[0] * b2 + a[2] * b0,
491            a[1] * b2 - a[3] * b0,
492        ])
493    }
494
495    /// Convert from a u64 to a base field elem, then cast to the extension field.
496    fn from_u64(val: u64) -> Self {
497        Self([Elem::from_u64(val), Elem::ZERO, Elem::ZERO, Elem::ZERO])
498    }
499
500    fn to_u32_words(&self) -> Vec<u32> {
501        self.elems()
502            .iter()
503            .flat_map(|elem| elem.to_u32_words())
504            .collect()
505    }
506
507    fn from_u32_words(val: &[u32]) -> Self {
508        field::ExtElem::from_subelems(val.iter().map(|word| Elem(*word)))
509    }
510
511    // So we're not checking every subfield element every time we do
512    // anything, assume that if our first subelement is valid, the
513    // whole thing is valid.  Any subfield elements will be double checked
514    // when we do operations on them anyways.
515    fn is_valid(&self) -> bool {
516        self.0[0].is_valid()
517    }
518
519    fn is_reduced(&self) -> bool {
520        self.0.iter().all(|x| x.is_reduced())
521    }
522}
523
524impl field::ExtElem for ExtElem {
525    const EXT_SIZE: usize = EXT_SIZE;
526
527    type SubElem = Elem;
528
529    fn from_subfield(elem: &Elem) -> Self {
530        Self::from([*elem.ensure_valid(), Elem::ZERO, Elem::ZERO, Elem::ZERO])
531    }
532
533    fn from_subelems(elems: impl IntoIterator<Item = Self::SubElem>) -> Self {
534        let mut iter = elems.into_iter();
535        let elem = Self::from([
536            *iter.next().unwrap().ensure_valid(),
537            *iter.next().unwrap().ensure_valid(),
538            *iter.next().unwrap().ensure_valid(),
539            *iter.next().unwrap().ensure_valid(),
540        ]);
541        assert!(
542            iter.next().is_none(),
543            "Extra elements passed to create element in extension field"
544        );
545        elem
546    }
547
548    /// Returns the subelements of a [Elem].
549    fn subelems(&self) -> &[Elem] {
550        &self.ensure_valid().0
551    }
552}
553
554impl PartialEq<ExtElem> for ExtElem {
555    fn eq(&self, rhs: &Self) -> bool {
556        self.ensure_valid().0 == rhs.ensure_valid().0
557    }
558}
559
560impl From<[Elem; EXT_SIZE]> for ExtElem {
561    fn from(val: [Elem; EXT_SIZE]) -> Self {
562        if cfg!(debug_assertions) {
563            for elem in val.iter() {
564                elem.ensure_valid();
565            }
566        }
567        ExtElem(val)
568    }
569}
570
571const BETA: Elem = Elem::new(11);
572const NBETA: Elem = Elem::new(P - 11);
573
574// TODO: refactor if rust gets const trait methods.
575const fn const_ensure_valid(x: Elem) -> Elem {
576    debug_assert!(x.0 != Elem::INVALID.0);
577    x
578}
579
580impl ExtElem {
581    /// Explicitly construct an ExtElem from parts.
582    pub const fn new(x0: Elem, x1: Elem, x2: Elem, x3: Elem) -> Self {
583        Self([
584            const_ensure_valid(x0),
585            const_ensure_valid(x1),
586            const_ensure_valid(x2),
587            const_ensure_valid(x3),
588        ])
589    }
590
591    /// Create an [ExtElem] from an [Elem].
592    pub fn from_fp(x: Elem) -> Self {
593        Self([x, Elem::new(0), Elem::new(0), Elem::new(0)])
594    }
595
596    /// Create an [ExtElem] from a raw integer.
597    pub const fn from_u32(x0: u32) -> Self {
598        Self([Elem::new(x0), Elem::new(0), Elem::new(0), Elem::new(0)])
599    }
600
601    /// Return the value zero.
602    const fn zero() -> Self {
603        Self::from_u32(0)
604    }
605
606    /// Return the value one.
607    const fn one() -> Self {
608        Self::from_u32(1)
609    }
610
611    /// Return the base field term of an [Elem].
612    pub fn const_part(self) -> Elem {
613        self.ensure_valid().0[0]
614    }
615
616    /// Return [Elem] as a vector of base field values.
617    pub fn elems(&self) -> &[Elem] {
618        &self.ensure_valid().0
619    }
620}
621
622impl ops::Add for ExtElem {
623    type Output = Self;
624
625    /// Addition for Baby Bear [ExtElem]
626    fn add(self, rhs: Self) -> Self {
627        let mut lhs = self;
628        lhs += rhs;
629        lhs
630    }
631}
632
633impl ops::AddAssign for ExtElem {
634    /// Simple addition case for Baby Bear [ExtElem]
635    fn add_assign(&mut self, rhs: Self) {
636        for i in 0..self.0.len() {
637            self.0[i] += rhs.0[i];
638        }
639    }
640}
641
642impl ops::Add<Elem> for ExtElem {
643    type Output = Self;
644
645    /// Addition for Baby Bear [Elem]
646    fn add(self, rhs: Elem) -> Self {
647        let mut lhs = self;
648        lhs += rhs;
649        lhs
650    }
651}
652
653impl ops::Add<ExtElem> for Elem {
654    type Output = ExtElem;
655
656    /// Addition for Baby Bear [Elem]
657    fn add(self, rhs: ExtElem) -> ExtElem {
658        let mut lhs = ExtElem::from(self);
659        lhs += rhs;
660        lhs
661    }
662}
663
664impl ops::AddAssign<Elem> for ExtElem {
665    /// Promoting addition case for BabyBear [Elem]
666    fn add_assign(&mut self, rhs: Elem) {
667        self.0[0] += rhs;
668    }
669}
670
671impl ops::Sub for ExtElem {
672    type Output = Self;
673
674    /// Subtraction for Baby Bear [ExtElem]
675    fn sub(self, rhs: Self) -> Self {
676        let mut lhs = self;
677        lhs -= rhs;
678        lhs
679    }
680}
681
682impl ops::SubAssign for ExtElem {
683    /// Simple subtraction case for Baby Bear [ExtElem]
684    fn sub_assign(&mut self, rhs: Self) {
685        for i in 0..self.0.len() {
686            self.0[i] -= rhs.0[i];
687        }
688    }
689}
690
691impl ops::Sub<Elem> for ExtElem {
692    type Output = Self;
693
694    /// Subtraction for Baby Bear [ExtElem]
695    fn sub(self, rhs: Elem) -> Self {
696        let mut lhs = self;
697        lhs -= rhs;
698        lhs
699    }
700}
701
702impl ops::Sub<ExtElem> for Elem {
703    type Output = ExtElem;
704
705    /// Subtraction for Baby Bear [ExtElem]
706    fn sub(self, rhs: ExtElem) -> ExtElem {
707        let mut lhs = ExtElem::from(self);
708        lhs -= rhs;
709        lhs
710    }
711}
712
713impl ops::SubAssign<Elem> for ExtElem {
714    /// Promoting subtraction case for BabyBear [Elem]
715    fn sub_assign(&mut self, rhs: Elem) {
716        self.0[0] -= rhs;
717    }
718}
719
720impl ops::MulAssign<Elem> for ExtElem {
721    /// Simple multiplication case by a
722    /// Baby Bear [Elem]
723    fn mul_assign(&mut self, rhs: Elem) {
724        for i in 0..self.0.len() {
725            self.0[i] *= rhs;
726        }
727    }
728}
729
730impl ops::Mul<Elem> for ExtElem {
731    type Output = Self;
732
733    /// Multiplication by a Baby Bear [Elem]
734    fn mul(self, rhs: Elem) -> Self {
735        let mut lhs = self;
736        lhs *= rhs;
737        lhs
738    }
739}
740
741impl ops::Mul<ExtElem> for Elem {
742    type Output = ExtElem;
743
744    /// Multiplication for a subfield [Elem] by an [ExtElem]
745    fn mul(self, rhs: ExtElem) -> ExtElem {
746        rhs * self
747    }
748}
749
750// Now we get to the interesting case of multiplication. Basically,
751// multiply out the polynomial representations, and then reduce module
752// `x^4 - B`, which means powers >= 4 get shifted back 4 and
753// multiplied by `-beta`. We could write this as a double loops with
754// some `if`s and hope it gets unrolled properly, but it's small
755// enough to just hand write.
756impl ops::MulAssign for ExtElem {
757    #[inline(always)]
758    fn mul_assign(&mut self, rhs: Self) {
759        // Rename the element arrays to something small for readability.
760        let a = &self.0;
761        let b = &rhs.0;
762        self.0 = [
763            a[0] * b[0] + NBETA * (a[1] * b[3] + a[2] * b[2] + a[3] * b[1]),
764            a[0] * b[1] + a[1] * b[0] + NBETA * (a[2] * b[3] + a[3] * b[2]),
765            a[0] * b[2] + a[1] * b[1] + a[2] * b[0] + NBETA * (a[3] * b[3]),
766            a[0] * b[3] + a[1] * b[2] + a[2] * b[1] + a[3] * b[0],
767        ];
768    }
769}
770
771impl ops::Mul for ExtElem {
772    type Output = ExtElem;
773
774    #[inline(always)]
775    fn mul(self, rhs: ExtElem) -> ExtElem {
776        let mut lhs = self;
777        lhs *= rhs;
778        lhs
779    }
780}
781
782impl ops::Neg for ExtElem {
783    type Output = Self;
784
785    fn neg(self) -> Self {
786        ExtElem::ZERO - self
787    }
788}
789
790impl From<u32> for ExtElem {
791    fn from(x: u32) -> Self {
792        Self([Elem::from(x), Elem::ZERO, Elem::ZERO, Elem::ZERO])
793    }
794}
795
796impl From<Elem> for ExtElem {
797    fn from(x: Elem) -> Self {
798        Self([x, Elem::ZERO, Elem::ZERO, Elem::ZERO])
799    }
800}
801
#[cfg(test)]
mod tests {
    use alloc::{vec, vec::Vec};

    use rand::{Rng, SeedableRng};

    use super::{field, Elem, ExtElem, P, P_U64};
    use crate::field::Elem as FieldElem;

    /// Seed shared by the randomized tests so failures are reproducible.
    const SEED: u64 = 2;

    /// Build an [ExtElem] from four canonical coefficients, constant term first.
    fn ext(coeffs: [u32; 4]) -> ExtElem {
        let [c0, c1, c2, c3] = coeffs;
        ExtElem::new(Elem::new(c0), Elem::new(c1), Elem::new(c2), Elem::new(c3))
    }

    #[test]
    pub fn roots_of_unity() {
        field::tests::test_roots_of_unity::<Elem>();
    }

    #[test]
    pub fn field_ops() {
        field::tests::test_field_ops::<Elem>(P_U64);
    }

    #[test]
    pub fn ext_field_ops() {
        field::tests::test_ext_field_ops::<ExtElem>();
    }

    #[test]
    pub fn linear() {
        let x = ext([1880084280, 1788985953, 1273325207, 277471107]);
        let c0 = ext([1582815482, 2011839994, 589901, 698998108]);
        let c1 = ext([1262573828, 1903841444, 1738307519, 100967278]);

        let x_c1 = x * c1;
        assert_eq!(x_c1, ext([876029217, 1948387849, 498773186, 1997003991]));
        assert_eq!(c0 + x_c1, ext([445578778, 1946961922, 499363087, 682736178]));
    }

    #[test]
    fn isa_field() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(SEED);
        // Draw triples of random extension field elements and check that they
        // satisfy the field axioms.
        for _ in 0..1_000 {
            let (a, b, c) = (
                ExtElem::random(&mut rng),
                ExtElem::random(&mut rng),
                ExtElem::random(&mut rng),
            );
            // Commutativity of addition and multiplication.
            assert_eq!(a + b, b + a);
            assert_eq!(a * b, b * a);
            // Associativity of addition and multiplication.
            assert_eq!(a + (b + c), (a + b) + c);
            assert_eq!(a * (b * c), (a * b) * c);
            // Distributivity.
            assert_eq!(a * (b + c), a * b + a * c);
            // Multiplicative inverse (for nonzero a) and additive inverse.
            if a != ExtElem::ZERO {
                assert_eq!(a.inv() * a, ExtElem::ONE);
            }
            assert_eq!(ExtElem::ZERO - a, -a);
            assert_eq!(a + (-a), ExtElem::ZERO);
        }
    }

    #[test]
    fn inv() {
        // Smoke test for inv.
        let five = Elem::new(5);
        assert_eq!(five.inv() * five, Elem::new(1));
    }

    #[test]
    fn pow() {
        // Smoke tests for pow.
        let five = Elem::new(5);
        assert_eq!(five.pow(0), Elem::new(1));
        assert_eq!(five.pow(1), five);
        assert_eq!(five.pow(2), Elem::new(25));
        // Mathematica says PowerMod[5, 1000, 15*2^27 + 1] == 589699054
        assert_eq!(five.pow(1000), Elem::new(589699054));
        // Fermat: 5^(P-2) is the inverse of 5, and 5^(P-1) == 1.
        assert_eq!(five.pow((P - 2) as usize) * five, Elem::new(1));
        assert_eq!(five.pow((P - 1) as usize), Elem::new(1));
    }

    #[test]
    fn compare_native() {
        // Compare core operations against plain `% P` arithmetic on u64s.
        let mut rng = rand::rngs::SmallRng::seed_from_u64(SEED);
        for _ in 0..100_000 {
            let (fa, fb) = (Elem::random(&mut rng), Elem::random(&mut rng));
            let (a, b) = (u64::from(fa), u64::from(fb));
            assert_eq!(fa + fb, Elem::from(a + b));
            assert_eq!(fa - fb, Elem::from(a + (P_U64 - b)));
            assert_eq!(fa * fb, Elem::from(a * b));
        }
    }

    #[test]
    #[cfg_attr(not(debug_assertions), ignore)]
    #[should_panic(expected = "assertion failed: self.is_valid")]
    fn compare_against_invalid() {
        let _ = Elem::ZERO == Elem::INVALID;
    }

    #[test]
    fn u32s_conversions() {
        let mut rng = rand::rngs::SmallRng::seed_from_u64(SEED);
        for _ in 0..100 {
            let elem = Elem::random(&mut rng);
            let words = elem.to_u32_words();
            assert_eq!(elem, Elem::from_u32_words(&words));

            let raw: u32 = rng.gen();
            assert_eq!(raw, Elem::from_u32_words(&[raw]).to_u32_words()[0]);
        }
        for _ in 0..100 {
            let elem = ExtElem::random(&mut rng);
            let words = elem.to_u32_words();
            assert_eq!(elem, ExtElem::from_u32_words(&words));

            let raw: Vec<u32> = vec![rng.gen(), rng.gen(), rng.gen(), rng.gen()];
            assert_eq!(raw, ExtElem::from_u32_words(&raw).to_u32_words());
        }
    }
}