Skip to main content

p3_koala_bear/
koala_bear.rs

1use core::fmt::{self, Debug, Display, Formatter};
2use core::iter::{Product, Sum};
3use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
4
5use num_bigint::BigUint;
6use p3_field::{
7    exp_1420470955, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField,
8    PrimeField31, PrimeField32, PrimeField64, TwoAdicField,
9};
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Deserializer, Serialize};
13
/// The KoalaBear prime: `2^31 - 2^24 + 1`.
///
/// This is a 31-bit prime with the highest possible two-adicity if we additionally demand that
/// the cube map (`x -> x^3`) is an automorphism of the multiplicative group.
/// It is not unique: there is one other option with equal two-adicity, `2^30 + 2^27 + 2^24 + 1`.
/// There is also one 29-bit prime with higher two-adicity which might be appropriate for some
/// applications: `2^29 - 2^26 + 1`.
const P: u32 = (1 << 31) - (1 << 24) + 1;

/// Bit width of the Montgomery radix: elements are stored as `a * 2^MONTY_BITS mod P`.
const MONTY_BITS: u32 = 32;

/// `MU = P^-1 (mod 2^MONTY_BITS)`.
///
/// This differs from the usual convention (`MU = -P^-1 (mod 2^MONTY_BITS)`), but it avoids a
/// carry during reduction.
const MONTY_MU: u32 = 0x81000001;

/// Mask selecting the low `MONTY_BITS` bits; derived from `MONTY_BITS`.
const MONTY_MASK: u32 = u32::MAX >> (32 - MONTY_BITS);
29
/// The prime field `2^31 - 2^24 + 1`, a.k.a. the Koala Bear field.
///
/// Elements are kept in Montgomery form: `value` holds `a * 2^32 mod P` for the represented
/// element `a`. The derived `Default` (`value == 0`) is therefore the zero element, since zero's
/// Montgomery form is `0`.
#[repr(transparent)] // `PackedKoalaBearNeon` relies on this!
#[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct KoalaBear {
    // Raw Montgomery-form value. Crate-visible only for tests and delayed-reduction strategies;
    // touching it anywhere else is probably a mistake.
    pub(crate) value: u32,
}
38
39impl KoalaBear {
40    /// create a new `KoalaBear` from a canonical `u32`.
41    #[inline]
42    pub(crate) const fn new(n: u32) -> Self {
43        Self { value: to_monty(n) }
44    }
45}
46
47impl Ord for KoalaBear {
48    #[inline]
49    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
50        self.as_canonical_u32().cmp(&other.as_canonical_u32())
51    }
52}
53
54impl PartialOrd for KoalaBear {
55    #[inline]
56    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
57        Some(self.cmp(other))
58    }
59}
60
61impl Display for KoalaBear {
62    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63        Display::fmt(&self.as_canonical_u32(), f)
64    }
65}
66
67impl Debug for KoalaBear {
68    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69        Debug::fmt(&self.as_canonical_u32(), f)
70    }
71}
72
73impl Distribution<KoalaBear> for Standard {
74    #[inline]
75    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> KoalaBear {
76        loop {
77            let next_u31 = rng.next_u32() >> 1;
78            let is_canonical = next_u31 < P;
79            if is_canonical {
80                return KoalaBear { value: next_u31 };
81            }
82        }
83    }
84}
85
86impl Serialize for KoalaBear {
87    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88        serializer.serialize_u32(self.as_canonical_u32())
89    }
90}
91
92impl<'de> Deserialize<'de> for KoalaBear {
93    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
94        let val = u32::deserialize(d)?;
95        Ok(KoalaBear::from_canonical_u32(val))
96    }
97}
98
99const MONTY_ZERO: u32 = to_monty(0);
100const MONTY_ONE: u32 = to_monty(1);
101const MONTY_TWO: u32 = to_monty(2);
102const MONTY_NEG_ONE: u32 = to_monty(P - 1);
103
// Empty marker impl: opts `KoalaBear` into `p3_field::Packable` with no methods of its own.
impl Packable for KoalaBear {}
105
106impl AbstractField for KoalaBear {
107    type F = Self;
108
109    fn zero() -> Self {
110        Self { value: MONTY_ZERO }
111    }
112    fn one() -> Self {
113        Self { value: MONTY_ONE }
114    }
115    fn two() -> Self {
116        Self { value: MONTY_TWO }
117    }
118    fn neg_one() -> Self {
119        Self {
120            value: MONTY_NEG_ONE,
121        }
122    }
123
124    #[inline]
125    fn from_f(f: Self::F) -> Self {
126        f
127    }
128
129    #[inline]
130    fn from_bool(b: bool) -> Self {
131        Self::from_canonical_u32(b as u32)
132    }
133
134    #[inline]
135    fn from_canonical_u8(n: u8) -> Self {
136        Self::from_canonical_u32(n as u32)
137    }
138
139    #[inline]
140    fn from_canonical_u16(n: u16) -> Self {
141        Self::from_canonical_u32(n as u32)
142    }
143
144    #[inline]
145    fn from_canonical_u32(n: u32) -> Self {
146        debug_assert!(n < P);
147        Self::from_wrapped_u32(n)
148    }
149
150    #[inline]
151    fn from_canonical_u64(n: u64) -> Self {
152        debug_assert!(n < P as u64);
153        Self::from_canonical_u32(n as u32)
154    }
155
156    #[inline]
157    fn from_canonical_usize(n: usize) -> Self {
158        debug_assert!(n < P as usize);
159        Self::from_canonical_u32(n as u32)
160    }
161
162    #[inline]
163    fn from_wrapped_u32(n: u32) -> Self {
164        Self { value: to_monty(n) }
165    }
166
167    #[inline]
168    fn from_wrapped_u64(n: u64) -> Self {
169        Self {
170            value: to_monty_64(n),
171        }
172    }
173
174    #[inline]
175    fn generator() -> Self {
176        Self::from_canonical_u32(0x3)
177    }
178}
179
180impl Field for KoalaBear {
181    cfg_if::cfg_if! {
182        if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
183            type Packing = crate::PackedKoalaBearNeon;
184        } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f", rustc_version_1_89_or_later))] {
185            type Packing = crate::PackedKoalaBearAVX512;
186        } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
187            type Packing = crate::PackedKoalaBearAVX2;
188        } else {
189            type Packing = Self;
190        }
191    }
192
193    #[inline]
194    fn mul_2exp_u64(&self, exp: u64) -> Self {
195        let product = (self.value as u64) << exp;
196        let value = (product % (P as u64)) as u32;
197        Self { value }
198    }
199
200    #[inline]
201    fn exp_u64_generic<AF: AbstractField<F = Self>>(val: AF, power: u64) -> AF {
202        match power {
203            1420470955 => exp_1420470955(val), // used to compute x^{1/3}
204            _ => exp_u64_by_squaring(val, power),
205        }
206    }
207
208    fn try_inverse(&self) -> Option<Self> {
209        if self.is_zero() {
210            return None;
211        }
212
213        // From Fermat's little theorem, in a prime field `F_p`, the inverse of `a` is `a^(p-2)`.
214        // Here p-2 = 2130706431 = 1111110111111111111111111111111_2
215        // Uses 29 Squares + 7 Multiplications => 36 Operations total.
216
217        let p1 = *self;
218        let p10 = p1.square();
219        let p11 = p10 * p1;
220        let p1100 = p11.exp_power_of_2(2);
221        let p1111 = p1100 * p11;
222        let p110000 = p1100.exp_power_of_2(2);
223        let p111111 = p110000 * p1111;
224        let p1111110000 = p111111.exp_power_of_2(4);
225        let p1111111111 = p1111110000 * p1111;
226        let p11111101111 = p1111111111 * p1111110000;
227        let p111111011110000000000 = p11111101111.exp_power_of_2(10);
228        let p111111011111111111111 = p111111011110000000000 * p1111111111;
229        let p1111110111111111111110000000000 = p111111011111111111111.exp_power_of_2(10);
230        let p1111110111111111111111111111111 = p1111110111111111111110000000000 * p1111111111;
231
232        Some(p1111110111111111111111111111111)
233    }
234
235    #[inline]
236    fn halve(&self) -> Self {
237        KoalaBear {
238            value: halve_u32::<P>(self.value),
239        }
240    }
241
242    #[inline]
243    fn order() -> BigUint {
244        P.into()
245    }
246}
247
248impl PrimeField for KoalaBear {
249    fn as_canonical_biguint(&self) -> BigUint {
250        <Self as PrimeField32>::as_canonical_u32(self).into()
251    }
252}
253
254impl PrimeField64 for KoalaBear {
255    const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
256
257    #[inline]
258    fn as_canonical_u64(&self) -> u64 {
259        u64::from(self.as_canonical_u32())
260    }
261}
262
263impl PrimeField32 for KoalaBear {
264    const ORDER_U32: u32 = P;
265
266    #[inline]
267    fn as_canonical_u32(&self) -> u32 {
268        from_monty(self.value)
269    }
270}
271
// Empty marker impl: opts `KoalaBear` into `p3_field::PrimeField31` with no methods of its own.
impl PrimeField31 for KoalaBear {}
273
274impl TwoAdicField for KoalaBear {
275    const TWO_ADICITY: usize = 24;
276
277    fn two_adic_generator(bits: usize) -> Self {
278        assert!(bits <= Self::TWO_ADICITY);
279        match bits {
280            0 => Self::one(),
281            1 => Self::from_canonical_u32(0x7f000000),
282            2 => Self::from_canonical_u32(0x7e010002),
283            3 => Self::from_canonical_u32(0x6832fe4a),
284            4 => Self::from_canonical_u32(0x8dbd69c),
285            5 => Self::from_canonical_u32(0xa28f031),
286            6 => Self::from_canonical_u32(0x5c4a5b99),
287            7 => Self::from_canonical_u32(0x29b75a80),
288            8 => Self::from_canonical_u32(0x17668b8a),
289            9 => Self::from_canonical_u32(0x27ad539b),
290            10 => Self::from_canonical_u32(0x334d48c7),
291            11 => Self::from_canonical_u32(0x7744959c),
292            12 => Self::from_canonical_u32(0x768fc6fa),
293            13 => Self::from_canonical_u32(0x303964b2),
294            14 => Self::from_canonical_u32(0x3e687d4d),
295            15 => Self::from_canonical_u32(0x45a60e61),
296            16 => Self::from_canonical_u32(0x6e2f4d7a),
297            17 => Self::from_canonical_u32(0x163bd499),
298            18 => Self::from_canonical_u32(0x6c4a8a45),
299            19 => Self::from_canonical_u32(0x143ef899),
300            20 => Self::from_canonical_u32(0x514ddcad),
301            21 => Self::from_canonical_u32(0x484ef19b),
302            22 => Self::from_canonical_u32(0x205d63c3),
303            23 => Self::from_canonical_u32(0x68e7dd49),
304            24 => Self::from_canonical_u32(0x6ac49f88),
305            _ => unreachable!("Already asserted that bits <= Self::TWO_ADICITY"),
306        }
307    }
308}
309
310impl Add for KoalaBear {
311    type Output = Self;
312
313    #[inline]
314    fn add(self, rhs: Self) -> Self {
315        let mut sum = self.value + rhs.value;
316        let (corr_sum, over) = sum.overflowing_sub(P);
317        if !over {
318            sum = corr_sum;
319        }
320        Self { value: sum }
321    }
322}
323
324impl AddAssign for KoalaBear {
325    #[inline]
326    fn add_assign(&mut self, rhs: Self) {
327        *self = *self + rhs;
328    }
329}
330
331impl Sum for KoalaBear {
332    #[inline]
333    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
334        // This is faster than iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) for iterators of length > 2.
335        // There might be a faster reduction method possible for lengths <= 16 which avoids %.
336
337        // This sum will not overflow so long as iter.len() < 2^33.
338        let sum = iter.map(|x| (x.value as u64)).sum::<u64>();
339        Self {
340            value: (sum % P as u64) as u32,
341        }
342    }
343}
344
345impl Sub for KoalaBear {
346    type Output = Self;
347
348    #[inline]
349    fn sub(self, rhs: Self) -> Self {
350        let (mut diff, over) = self.value.overflowing_sub(rhs.value);
351        let corr = if over { P } else { 0 };
352        diff = diff.wrapping_add(corr);
353        Self { value: diff }
354    }
355}
356
357impl SubAssign for KoalaBear {
358    #[inline]
359    fn sub_assign(&mut self, rhs: Self) {
360        *self = *self - rhs;
361    }
362}
363
364impl Neg for KoalaBear {
365    type Output = Self;
366
367    #[inline]
368    fn neg(self) -> Self::Output {
369        Self::zero() - self
370    }
371}
372
373impl Mul for KoalaBear {
374    type Output = Self;
375
376    #[inline]
377    fn mul(self, rhs: Self) -> Self {
378        let long_prod = self.value as u64 * rhs.value as u64;
379        Self {
380            value: monty_reduce(long_prod),
381        }
382    }
383}
384
385impl MulAssign for KoalaBear {
386    #[inline]
387    fn mul_assign(&mut self, rhs: Self) {
388        *self = *self * rhs;
389    }
390}
391
392impl Product for KoalaBear {
393    #[inline]
394    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
395        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
396    }
397}
398
399impl Div for KoalaBear {
400    type Output = Self;
401
402    #[allow(clippy::suspicious_arithmetic_impl)]
403    #[inline]
404    fn div(self, rhs: Self) -> Self {
405        self * rhs.inverse()
406    }
407}
408
409#[inline]
410#[must_use]
411const fn to_monty(x: u32) -> u32 {
412    (((x as u64) << MONTY_BITS) % P as u64) as u32
413}
414
415/// Convert a constant u32 array into a constant KoalaBear array.
416/// Saves every element in Monty Form
417#[inline]
418#[must_use]
419pub(crate) const fn to_koalabear_array<const N: usize>(input: [u32; N]) -> [KoalaBear; N] {
420    let mut output = [KoalaBear { value: 0 }; N];
421    let mut i = 0;
422    loop {
423        if i == N {
424            break;
425        }
426        output[i].value = to_monty(input[i]);
427        i += 1;
428    }
429    output
430}
431
432#[inline]
433#[must_use]
434const fn to_monty_64(x: u64) -> u32 {
435    (((x as u128) << MONTY_BITS) % P as u128) as u32
436}
437
438#[inline]
439#[must_use]
440const fn from_monty(x: u32) -> u32 {
441    monty_reduce(x as u64)
442}
443
444/// Montgomery reduction of a value in `0..P << MONTY_BITS`.
445#[inline]
446#[must_use]
447pub(crate) const fn monty_reduce(x: u64) -> u32 {
448    let t = x.wrapping_mul(MONTY_MU as u64) & (MONTY_MASK as u64);
449    let u = t * (P as u64);
450
451    let (x_sub_u, over) = x.overflowing_sub(u);
452    let x_sub_u_hi = (x_sub_u >> MONTY_BITS) as u32;
453    let corr = if over { P } else { 0 };
454    x_sub_u_hi.wrapping_add(corr)
455}
456
#[cfg(test)]
mod tests {
    use p3_field_testing::{test_field, test_two_adic_field};

    use super::*;

    type F = KoalaBear;

    /// Serializes `x` to JSON and parses it back.
    fn json_roundtrip(x: F) -> F {
        let encoded = serde_json::to_string(&x).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn test_koala_bear_two_adicity_generators() {
        let base = F::from_canonical_u32(0x6ac49f88);
        for bits in 0..=F::TWO_ADICITY {
            let expected = base.exp_power_of_2(F::TWO_ADICITY - bits);
            assert_eq!(F::two_adic_generator(bits), expected);
        }
    }

    #[test]
    fn test_koala_bear() {
        // Conversions.
        assert_eq!(F::from_canonical_u32(100).as_canonical_u64(), 100);
        assert!(F::from_canonical_u32(0).is_zero());
        let wrapped_order = F::from_wrapped_u32(F::ORDER_U32);
        assert!(wrapped_order.is_zero());

        // Small-value arithmetic.
        let f_1 = F::one();
        let f_1_copy = F::from_canonical_u32(1);
        assert_eq!(f_1 - f_1_copy, F::zero());
        assert_eq!(f_1 + f_1_copy, F::two());

        let f_2 = F::from_canonical_u32(2);
        assert_eq!(f_1 + f_1_copy * f_2, F::from_canonical_u32(3));
        assert_eq!(f_1 + f_2 * f_2, F::from_canonical_u32(5));

        // Wrap-around at the modulus.
        let f_p_minus_1 = F::from_canonical_u32(F::ORDER_U32 - 1);
        assert_eq!(f_1 + f_p_minus_1, F::zero());

        let f_p_minus_2 = F::from_canonical_u32(F::ORDER_U32 - 2);
        assert_eq!(
            f_p_minus_1 + f_p_minus_2,
            F::from_canonical_u32(F::ORDER_U32 - 3)
        );
        assert_eq!(f_p_minus_1 - f_p_minus_2, F::from_canonical_u32(1));
        assert_eq!(f_p_minus_2 - f_p_minus_1, f_p_minus_1);
        assert_eq!(f_p_minus_1 - f_1, f_p_minus_2);

        // Known product.
        let m1 = F::from_canonical_u32(0x34167c58);
        let m2 = F::from_canonical_u32(0x61f3207b);
        assert_eq!(m1 * m2, F::from_canonical_u32(0x54b46b81));

        // Cube roots via the specialised exponent.
        for x in [m1, m2, f_2] {
            assert_eq!(x.exp_u64(1420470955).exp_const_u64::<3>(), x);
        }

        // Serde round-trips.
        assert_eq!(wrapped_order, json_roundtrip(wrapped_order));

        let f_1_once = json_roundtrip(f_1);
        let f_1_twice = json_roundtrip(f_1_once);
        assert_eq!(f_1, f_1_once);
        assert_eq!(f_1, f_1_twice);

        for x in [f_2, f_p_minus_1, f_p_minus_2, m1, m2] {
            assert_eq!(x, json_roundtrip(x));
        }
    }

    test_field!(crate::KoalaBear);
    test_two_adic_field!(crate::KoalaBear);
}