Skip to main content

p3_koala_bear/
koala_bear.rs

1use core::fmt::{self, Debug, Display, Formatter};
2use core::iter::{Product, Sum};
3use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
4
5use num_bigint::BigUint;
6use p3_field::{
7    exp_1420470955, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField,
8    PrimeField31, PrimeField32, PrimeField64, TwoAdicField,
9};
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Deserializer, Serialize};
13
/// The KoalaBear prime: `2^31 - 2^24 + 1` (`0x7f00_0001`).
///
/// Among 31-bit primes for which the cube map `x -> x^3` is an automorphism of the
/// multiplicative group, this one has the highest possible two-adicity. It is not unique:
/// `2^30 + 2^27 + 2^24 + 1` has the same two-adicity. There is also a 29-bit prime with
/// higher two-adicity that may suit some applications: `2^29 - 2^26 + 1`.
const P: u32 = (1 << 31) - (1 << 24) + 1;

/// Width of the Montgomery radix `R = 2^MONTY_BITS`.
const MONTY_BITS: u32 = 32;

/// `MU = P^-1 (mod 2^MONTY_BITS)`, i.e. `2^31 + 2^24 + 1`.
///
/// The usual convention is `MU = -P^-1 (mod 2^MONTY_BITS)`; using `+P^-1` instead avoids a
/// carry in `monty_reduce`.
const MONTY_MU: u32 = 0x8100_0001;

/// Mask keeping the low `MONTY_BITS` bits; derived from `MONTY_BITS`.
const MONTY_MASK: u32 = (u64::MAX >> (64 - MONTY_BITS)) as u32;
29
/// The prime field `2^31 - 2^24 + 1`, a.k.a. the Koala Bear field.
///
/// Elements are stored in Montgomery form: `value = x * 2^MONTY_BITS mod P`.
#[derive(Clone, Copy, PartialEq, Eq, Hash, Default)]
// `PackedKoalaBearNeon` relies on this layout guarantee!
#[repr(transparent)]
pub struct KoalaBear {
    // Montgomery representative of the element. Visibility is `pub(crate)` only so tests and
    // delayed-reduction code can reach it; any other direct access is probably a mistake.
    pub(crate) value: u32,
}
38
39impl KoalaBear {
40    /// create a new `KoalaBear` from a canonical `u32`.
41    #[inline]
42    pub(crate) const fn new(n: u32) -> Self {
43        Self { value: to_monty(n) }
44    }
45}
46
47impl Ord for KoalaBear {
48    #[inline]
49    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
50        self.as_canonical_u32().cmp(&other.as_canonical_u32())
51    }
52}
53
54impl PartialOrd for KoalaBear {
55    #[inline]
56    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
57        Some(self.cmp(other))
58    }
59}
60
61impl Display for KoalaBear {
62    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63        Display::fmt(&self.as_canonical_u32(), f)
64    }
65}
66
67impl Debug for KoalaBear {
68    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69        Debug::fmt(&self.as_canonical_u32(), f)
70    }
71}
72
73impl Distribution<KoalaBear> for Standard {
74    #[inline]
75    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> KoalaBear {
76        loop {
77            let next_u31 = rng.next_u32() >> 1;
78            let is_canonical = next_u31 < P;
79            if is_canonical {
80                return KoalaBear { value: next_u31 };
81            }
82        }
83    }
84}
85
86impl Serialize for KoalaBear {
87    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88        serializer.serialize_u32(self.as_canonical_u32())
89    }
90}
91
92impl<'de> Deserialize<'de> for KoalaBear {
93    fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
94        let val = u32::deserialize(d)?;
95        Ok(KoalaBear::from_canonical_u32(val))
96    }
97}
98
99const MONTY_ZERO: u32 = to_monty(0);
100const MONTY_ONE: u32 = to_monty(1);
101const MONTY_TWO: u32 = to_monty(2);
102const MONTY_NEG_ONE: u32 = to_monty(P - 1);
103
// Marker impl: `KoalaBear` opts into `p3_field::Packable` with no overrides.
impl Packable for KoalaBear {}
105
106impl AbstractField for KoalaBear {
107    type F = Self;
108
109    fn zero() -> Self {
110        Self { value: MONTY_ZERO }
111    }
112    fn one() -> Self {
113        Self { value: MONTY_ONE }
114    }
115    fn two() -> Self {
116        Self { value: MONTY_TWO }
117    }
118    fn neg_one() -> Self {
119        Self {
120            value: MONTY_NEG_ONE,
121        }
122    }
123
124    #[inline]
125    fn from_f(f: Self::F) -> Self {
126        f
127    }
128
129    #[inline]
130    fn from_bool(b: bool) -> Self {
131        Self::from_canonical_u32(b as u32)
132    }
133
134    #[inline]
135    fn from_canonical_u8(n: u8) -> Self {
136        Self::from_canonical_u32(n as u32)
137    }
138
139    #[inline]
140    fn from_canonical_u16(n: u16) -> Self {
141        Self::from_canonical_u32(n as u32)
142    }
143
144    #[inline]
145    fn from_canonical_u32(n: u32) -> Self {
146        debug_assert!(n < P);
147        Self::from_wrapped_u32(n)
148    }
149
150    #[inline]
151    fn from_canonical_u64(n: u64) -> Self {
152        debug_assert!(n < P as u64);
153        Self::from_canonical_u32(n as u32)
154    }
155
156    #[inline]
157    fn from_canonical_usize(n: usize) -> Self {
158        debug_assert!(n < P as usize);
159        Self::from_canonical_u32(n as u32)
160    }
161
162    #[inline]
163    fn from_wrapped_u32(n: u32) -> Self {
164        Self { value: to_monty(n) }
165    }
166
167    #[inline]
168    fn from_wrapped_u64(n: u64) -> Self {
169        Self {
170            value: to_monty_64(n),
171        }
172    }
173
174    #[inline]
175    fn generator() -> Self {
176        Self::from_canonical_u32(0x3)
177    }
178}
179
180impl Field for KoalaBear {
181    #[cfg(all(target_arch = "aarch64", target_feature = "neon"))]
182    type Packing = crate::PackedKoalaBearNeon;
183    #[cfg(all(
184        target_arch = "x86_64",
185        target_feature = "avx2",
186        not(target_feature = "avx512f")
187    ))]
188    type Packing = crate::PackedKoalaBearAVX2;
189    #[cfg(all(target_arch = "x86_64", target_feature = "avx512f"))]
190    type Packing = crate::PackedKoalaBearAVX512;
191    #[cfg(not(any(
192        all(target_arch = "aarch64", target_feature = "neon"),
193        all(
194            target_arch = "x86_64",
195            target_feature = "avx2",
196            not(target_feature = "avx512f")
197        ),
198        all(target_arch = "x86_64", target_feature = "avx512f"),
199    )))]
200    type Packing = Self;
201
202    #[inline]
203    fn mul_2exp_u64(&self, exp: u64) -> Self {
204        let product = (self.value as u64) << exp;
205        let value = (product % (P as u64)) as u32;
206        Self { value }
207    }
208
209    #[inline]
210    fn exp_u64_generic<AF: AbstractField<F = Self>>(val: AF, power: u64) -> AF {
211        match power {
212            1420470955 => exp_1420470955(val), // used to compute x^{1/3}
213            _ => exp_u64_by_squaring(val, power),
214        }
215    }
216
217    fn try_inverse(&self) -> Option<Self> {
218        if self.is_zero() {
219            return None;
220        }
221
222        // From Fermat's little theorem, in a prime field `F_p`, the inverse of `a` is `a^(p-2)`.
223        // Here p-2 = 2130706431 = 1111110111111111111111111111111_2
224        // Uses 29 Squares + 7 Multiplications => 36 Operations total.
225
226        let p1 = *self;
227        let p10 = p1.square();
228        let p11 = p10 * p1;
229        let p1100 = p11.exp_power_of_2(2);
230        let p1111 = p1100 * p11;
231        let p110000 = p1100.exp_power_of_2(2);
232        let p111111 = p110000 * p1111;
233        let p1111110000 = p111111.exp_power_of_2(4);
234        let p1111111111 = p1111110000 * p1111;
235        let p11111101111 = p1111111111 * p1111110000;
236        let p111111011110000000000 = p11111101111.exp_power_of_2(10);
237        let p111111011111111111111 = p111111011110000000000 * p1111111111;
238        let p1111110111111111111110000000000 = p111111011111111111111.exp_power_of_2(10);
239        let p1111110111111111111111111111111 = p1111110111111111111110000000000 * p1111111111;
240
241        Some(p1111110111111111111111111111111)
242    }
243
244    #[inline]
245    fn halve(&self) -> Self {
246        KoalaBear {
247            value: halve_u32::<P>(self.value),
248        }
249    }
250
251    #[inline]
252    fn order() -> BigUint {
253        P.into()
254    }
255}
256
257impl PrimeField for KoalaBear {
258    fn as_canonical_biguint(&self) -> BigUint {
259        <Self as PrimeField32>::as_canonical_u32(self).into()
260    }
261}
262
263impl PrimeField64 for KoalaBear {
264    const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
265
266    #[inline]
267    fn as_canonical_u64(&self) -> u64 {
268        u64::from(self.as_canonical_u32())
269    }
270}
271
272impl PrimeField32 for KoalaBear {
273    const ORDER_U32: u32 = P;
274
275    #[inline]
276    fn as_canonical_u32(&self) -> u32 {
277        from_monty(self.value)
278    }
279}
280
// Marker impl: `KoalaBear` opts into `p3_field::PrimeField31` with no overrides.
impl PrimeField31 for KoalaBear {}
282
283impl TwoAdicField for KoalaBear {
284    const TWO_ADICITY: usize = 24;
285
286    fn two_adic_generator(bits: usize) -> Self {
287        assert!(bits <= Self::TWO_ADICITY);
288        match bits {
289            0 => Self::one(),
290            1 => Self::from_canonical_u32(0x7f000000),
291            2 => Self::from_canonical_u32(0x7e010002),
292            3 => Self::from_canonical_u32(0x6832fe4a),
293            4 => Self::from_canonical_u32(0x8dbd69c),
294            5 => Self::from_canonical_u32(0xa28f031),
295            6 => Self::from_canonical_u32(0x5c4a5b99),
296            7 => Self::from_canonical_u32(0x29b75a80),
297            8 => Self::from_canonical_u32(0x17668b8a),
298            9 => Self::from_canonical_u32(0x27ad539b),
299            10 => Self::from_canonical_u32(0x334d48c7),
300            11 => Self::from_canonical_u32(0x7744959c),
301            12 => Self::from_canonical_u32(0x768fc6fa),
302            13 => Self::from_canonical_u32(0x303964b2),
303            14 => Self::from_canonical_u32(0x3e687d4d),
304            15 => Self::from_canonical_u32(0x45a60e61),
305            16 => Self::from_canonical_u32(0x6e2f4d7a),
306            17 => Self::from_canonical_u32(0x163bd499),
307            18 => Self::from_canonical_u32(0x6c4a8a45),
308            19 => Self::from_canonical_u32(0x143ef899),
309            20 => Self::from_canonical_u32(0x514ddcad),
310            21 => Self::from_canonical_u32(0x484ef19b),
311            22 => Self::from_canonical_u32(0x205d63c3),
312            23 => Self::from_canonical_u32(0x68e7dd49),
313            24 => Self::from_canonical_u32(0x6ac49f88),
314            _ => unreachable!("Already asserted that bits <= Self::TWO_ADICITY"),
315        }
316    }
317}
318
319impl Add for KoalaBear {
320    type Output = Self;
321
322    #[inline]
323    fn add(self, rhs: Self) -> Self {
324        let mut sum = self.value + rhs.value;
325        let (corr_sum, over) = sum.overflowing_sub(P);
326        if !over {
327            sum = corr_sum;
328        }
329        Self { value: sum }
330    }
331}
332
333impl AddAssign for KoalaBear {
334    #[inline]
335    fn add_assign(&mut self, rhs: Self) {
336        *self = *self + rhs;
337    }
338}
339
340impl Sum for KoalaBear {
341    #[inline]
342    fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
343        // This is faster than iter.reduce(|x, y| x + y).unwrap_or(Self::zero()) for iterators of length > 2.
344        // There might be a faster reduction method possible for lengths <= 16 which avoids %.
345
346        // This sum will not overflow so long as iter.len() < 2^33.
347        let sum = iter.map(|x| (x.value as u64)).sum::<u64>();
348        Self {
349            value: (sum % P as u64) as u32,
350        }
351    }
352}
353
354impl Sub for KoalaBear {
355    type Output = Self;
356
357    #[inline]
358    fn sub(self, rhs: Self) -> Self {
359        let (mut diff, over) = self.value.overflowing_sub(rhs.value);
360        let corr = if over { P } else { 0 };
361        diff = diff.wrapping_add(corr);
362        Self { value: diff }
363    }
364}
365
366impl SubAssign for KoalaBear {
367    #[inline]
368    fn sub_assign(&mut self, rhs: Self) {
369        *self = *self - rhs;
370    }
371}
372
373impl Neg for KoalaBear {
374    type Output = Self;
375
376    #[inline]
377    fn neg(self) -> Self::Output {
378        Self::zero() - self
379    }
380}
381
382impl Mul for KoalaBear {
383    type Output = Self;
384
385    #[inline]
386    fn mul(self, rhs: Self) -> Self {
387        let long_prod = self.value as u64 * rhs.value as u64;
388        Self {
389            value: monty_reduce(long_prod),
390        }
391    }
392}
393
394impl MulAssign for KoalaBear {
395    #[inline]
396    fn mul_assign(&mut self, rhs: Self) {
397        *self = *self * rhs;
398    }
399}
400
401impl Product for KoalaBear {
402    #[inline]
403    fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
404        iter.reduce(|x, y| x * y).unwrap_or(Self::one())
405    }
406}
407
408impl Div for KoalaBear {
409    type Output = Self;
410
411    #[allow(clippy::suspicious_arithmetic_impl)]
412    #[inline]
413    fn div(self, rhs: Self) -> Self {
414        self * rhs.inverse()
415    }
416}
417
418#[inline]
419#[must_use]
420const fn to_monty(x: u32) -> u32 {
421    (((x as u64) << MONTY_BITS) % P as u64) as u32
422}
423
424/// Convert a constant u32 array into a constant KoalaBear array.
425/// Saves every element in Monty Form
426#[inline]
427#[must_use]
428pub(crate) const fn to_koalabear_array<const N: usize>(input: [u32; N]) -> [KoalaBear; N] {
429    let mut output = [KoalaBear { value: 0 }; N];
430    let mut i = 0;
431    loop {
432        if i == N {
433            break;
434        }
435        output[i].value = to_monty(input[i]);
436        i += 1;
437    }
438    output
439}
440
441#[inline]
442#[must_use]
443const fn to_monty_64(x: u64) -> u32 {
444    (((x as u128) << MONTY_BITS) % P as u128) as u32
445}
446
447#[inline]
448#[must_use]
449const fn from_monty(x: u32) -> u32 {
450    monty_reduce(x as u64)
451}
452
453/// Montgomery reduction of a value in `0..P << MONTY_BITS`.
454#[inline]
455#[must_use]
456pub(crate) const fn monty_reduce(x: u64) -> u32 {
457    let t = x.wrapping_mul(MONTY_MU as u64) & (MONTY_MASK as u64);
458    let u = t * (P as u64);
459
460    let (x_sub_u, over) = x.overflowing_sub(u);
461    let x_sub_u_hi = (x_sub_u >> MONTY_BITS) as u32;
462    let corr = if over { P } else { 0 };
463    x_sub_u_hi.wrapping_add(corr)
464}
465
#[cfg(test)]
mod tests {
    use p3_field_testing::{test_field, test_two_adic_field};

    use super::*;

    type F = KoalaBear;

    /// Serialize `x` to JSON and deserialize it again.
    fn json_round_trip(x: F) -> F {
        let encoded = serde_json::to_string(&x).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    /// Each tabulated two-adic generator must equal the largest one (`0x6ac49f88`) squared
    /// the appropriate number of times.
    #[test]
    fn test_koala_bear_two_adicity_generators() {
        let top = KoalaBear::from_canonical_u32(0x6ac49f88);
        let max_bits = KoalaBear::TWO_ADICITY;
        for log_order in 0..=max_bits {
            let expected = top.exp_power_of_2(max_bits - log_order);
            assert_eq!(KoalaBear::two_adic_generator(log_order), expected);
        }
    }

    #[test]
    fn test_koala_bear() {
        // Construction from canonical and wrapped integers.
        assert_eq!(F::from_canonical_u32(100).as_canonical_u64(), 100);
        assert!(F::from_canonical_u32(0).is_zero());
        let f = F::from_wrapped_u32(F::ORDER_U32);
        assert!(f.is_zero());

        let f_1 = F::one();
        let f_1_copy = F::from_canonical_u32(1);
        let f_2 = F::from_canonical_u32(2);
        let f_p_minus_1 = F::from_canonical_u32(F::ORDER_U32 - 1);
        let f_p_minus_2 = F::from_canonical_u32(F::ORDER_U32 - 2);

        // Small operands: no modular wrap-around involved.
        assert_eq!(f_1 - f_1_copy, F::zero());
        assert_eq!(f_1 + f_1_copy, F::two());
        assert_eq!(f_1 + f_1_copy * f_2, F::from_canonical_u32(3));
        assert_eq!(f_1 + f_2 * f_2, F::from_canonical_u32(5));

        // Operands near P: exercise the correction paths of `+` and `-`.
        assert_eq!(f_1 + f_p_minus_1, F::zero());
        assert_eq!(f_p_minus_1 + f_p_minus_2, F::from_canonical_u32(F::ORDER_U32 - 3));
        assert_eq!(f_p_minus_1 - f_p_minus_2, F::from_canonical_u32(1));
        assert_eq!(f_p_minus_2 - f_p_minus_1, f_p_minus_1);
        assert_eq!(f_p_minus_1 - f_1, f_p_minus_2);

        // A fixed product of two arbitrary elements.
        let m1 = F::from_canonical_u32(0x34167c58);
        let m2 = F::from_canonical_u32(0x61f3207b);
        assert_eq!(m1 * m2, F::from_canonical_u32(0x54b46b81));

        // Raising to 1420470955 and then cubing must give back the input (cube-root check).
        for &x in &[m1, m2, f_2] {
            assert_eq!(x.exp_u64(1420470955).exp_const_u64::<3>(), x);
        }

        // Serde round-trips, including a double round-trip of `one`.
        assert_eq!(f, json_round_trip(f));
        let f_1_once = json_round_trip(f_1);
        assert_eq!(f_1, f_1_once);
        assert_eq!(f_1, json_round_trip(f_1_once));
        for &x in &[f_2, f_p_minus_1, f_p_minus_2, m1, m2] {
            assert_eq!(x, json_round_trip(x));
        }
    }

    test_field!(crate::KoalaBear);
    test_two_adic_field!(crate::KoalaBear);
}