1use core::fmt::{self, Debug, Display, Formatter};
2use core::iter::{Product, Sum};
3use core::ops::{Add, AddAssign, Div, Mul, MulAssign, Neg, Sub, SubAssign};
4
5use num_bigint::BigUint;
6use p3_field::{
7 exp_1420470955, exp_u64_by_squaring, halve_u32, AbstractField, Field, Packable, PrimeField,
8 PrimeField31, PrimeField32, PrimeField64, TwoAdicField,
9};
10use rand::distributions::{Distribution, Standard};
11use rand::Rng;
12use serde::{Deserialize, Deserializer, Serialize};
13
/// The KoalaBear prime `P = 2^31 - 2^24 + 1 = 127 * 2^24 + 1`.
///
/// Because `P - 1 = 127 * 2^24`, the multiplicative group contains a subgroup of order `2^24`.
const P: u32 = 0x7f00_0001;

/// Montgomery radix exponent: an element `x` is stored as `x * 2^MONTY_BITS mod P`.
const MONTY_BITS: u32 = 32;

/// `P^{-1} mod 2^MONTY_BITS`, used during Montgomery reduction so that the low
/// `MONTY_BITS` bits of `x - t * P` cancel out.
const MONTY_MU: u32 = 0x8100_0001;

/// Mask selecting the low `MONTY_BITS` bits of a value.
const MONTY_MASK: u32 = u32::MAX >> (u32::BITS - MONTY_BITS);
29
30#[derive(Copy, Clone, Default, Eq, Hash, PartialEq)]
32#[repr(transparent)] pub struct KoalaBear {
34 pub(crate) value: u32,
37}
38
39impl KoalaBear {
40 #[inline]
42 pub(crate) const fn new(n: u32) -> Self {
43 Self { value: to_monty(n) }
44 }
45}
46
47impl Ord for KoalaBear {
48 #[inline]
49 fn cmp(&self, other: &Self) -> core::cmp::Ordering {
50 self.as_canonical_u32().cmp(&other.as_canonical_u32())
51 }
52}
53
54impl PartialOrd for KoalaBear {
55 #[inline]
56 fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
57 Some(self.cmp(other))
58 }
59}
60
61impl Display for KoalaBear {
62 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
63 Display::fmt(&self.as_canonical_u32(), f)
64 }
65}
66
67impl Debug for KoalaBear {
68 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
69 Debug::fmt(&self.as_canonical_u32(), f)
70 }
71}
72
73impl Distribution<KoalaBear> for Standard {
74 #[inline]
75 fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> KoalaBear {
76 loop {
77 let next_u31 = rng.next_u32() >> 1;
78 let is_canonical = next_u31 < P;
79 if is_canonical {
80 return KoalaBear { value: next_u31 };
81 }
82 }
83 }
84}
85
86impl Serialize for KoalaBear {
87 fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
88 serializer.serialize_u32(self.as_canonical_u32())
89 }
90}
91
92impl<'de> Deserialize<'de> for KoalaBear {
93 fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Self, D::Error> {
94 let val = u32::deserialize(d)?;
95 Ok(KoalaBear::from_canonical_u32(val))
96 }
97}
98
99const MONTY_ZERO: u32 = to_monty(0);
100const MONTY_ONE: u32 = to_monty(1);
101const MONTY_TWO: u32 = to_monty(2);
102const MONTY_NEG_ONE: u32 = to_monty(P - 1);
103
// Opts `KoalaBear` into `p3_field::Packable` without overriding anything. Presumably this is
// what lets it serve as a lane of the `Packing` types selected in `impl Field` — confirm in p3_field.
impl Packable for KoalaBear {}
105
106impl AbstractField for KoalaBear {
107 type F = Self;
108
109 fn zero() -> Self {
110 Self { value: MONTY_ZERO }
111 }
112 fn one() -> Self {
113 Self { value: MONTY_ONE }
114 }
115 fn two() -> Self {
116 Self { value: MONTY_TWO }
117 }
118 fn neg_one() -> Self {
119 Self {
120 value: MONTY_NEG_ONE,
121 }
122 }
123
124 #[inline]
125 fn from_f(f: Self::F) -> Self {
126 f
127 }
128
129 #[inline]
130 fn from_bool(b: bool) -> Self {
131 Self::from_canonical_u32(b as u32)
132 }
133
134 #[inline]
135 fn from_canonical_u8(n: u8) -> Self {
136 Self::from_canonical_u32(n as u32)
137 }
138
139 #[inline]
140 fn from_canonical_u16(n: u16) -> Self {
141 Self::from_canonical_u32(n as u32)
142 }
143
144 #[inline]
145 fn from_canonical_u32(n: u32) -> Self {
146 debug_assert!(n < P);
147 Self::from_wrapped_u32(n)
148 }
149
150 #[inline]
151 fn from_canonical_u64(n: u64) -> Self {
152 debug_assert!(n < P as u64);
153 Self::from_canonical_u32(n as u32)
154 }
155
156 #[inline]
157 fn from_canonical_usize(n: usize) -> Self {
158 debug_assert!(n < P as usize);
159 Self::from_canonical_u32(n as u32)
160 }
161
162 #[inline]
163 fn from_wrapped_u32(n: u32) -> Self {
164 Self { value: to_monty(n) }
165 }
166
167 #[inline]
168 fn from_wrapped_u64(n: u64) -> Self {
169 Self {
170 value: to_monty_64(n),
171 }
172 }
173
174 #[inline]
175 fn generator() -> Self {
176 Self::from_canonical_u32(0x3)
177 }
178}
179
180impl Field for KoalaBear {
181 cfg_if::cfg_if! {
182 if #[cfg(all(target_arch = "aarch64", target_feature = "neon"))] {
183 type Packing = crate::PackedKoalaBearNeon;
184 } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx512f", rustc_version_1_89_or_later))] {
185 type Packing = crate::PackedKoalaBearAVX512;
186 } else if #[cfg(all(target_arch = "x86_64", target_feature = "avx2"))] {
187 type Packing = crate::PackedKoalaBearAVX2;
188 } else {
189 type Packing = Self;
190 }
191 }
192
193 #[inline]
194 fn mul_2exp_u64(&self, exp: u64) -> Self {
195 let product = (self.value as u64) << exp;
196 let value = (product % (P as u64)) as u32;
197 Self { value }
198 }
199
200 #[inline]
201 fn exp_u64_generic<AF: AbstractField<F = Self>>(val: AF, power: u64) -> AF {
202 match power {
203 1420470955 => exp_1420470955(val), _ => exp_u64_by_squaring(val, power),
205 }
206 }
207
208 fn try_inverse(&self) -> Option<Self> {
209 if self.is_zero() {
210 return None;
211 }
212
213 let p1 = *self;
218 let p10 = p1.square();
219 let p11 = p10 * p1;
220 let p1100 = p11.exp_power_of_2(2);
221 let p1111 = p1100 * p11;
222 let p110000 = p1100.exp_power_of_2(2);
223 let p111111 = p110000 * p1111;
224 let p1111110000 = p111111.exp_power_of_2(4);
225 let p1111111111 = p1111110000 * p1111;
226 let p11111101111 = p1111111111 * p1111110000;
227 let p111111011110000000000 = p11111101111.exp_power_of_2(10);
228 let p111111011111111111111 = p111111011110000000000 * p1111111111;
229 let p1111110111111111111110000000000 = p111111011111111111111.exp_power_of_2(10);
230 let p1111110111111111111111111111111 = p1111110111111111111110000000000 * p1111111111;
231
232 Some(p1111110111111111111111111111111)
233 }
234
235 #[inline]
236 fn halve(&self) -> Self {
237 KoalaBear {
238 value: halve_u32::<P>(self.value),
239 }
240 }
241
242 #[inline]
243 fn order() -> BigUint {
244 P.into()
245 }
246}
247
248impl PrimeField for KoalaBear {
249 fn as_canonical_biguint(&self) -> BigUint {
250 <Self as PrimeField32>::as_canonical_u32(self).into()
251 }
252}
253
254impl PrimeField64 for KoalaBear {
255 const ORDER_U64: u64 = <Self as PrimeField32>::ORDER_U32 as u64;
256
257 #[inline]
258 fn as_canonical_u64(&self) -> u64 {
259 u64::from(self.as_canonical_u32())
260 }
261}
262
263impl PrimeField32 for KoalaBear {
264 const ORDER_U32: u32 = P;
265
266 #[inline]
267 fn as_canonical_u32(&self) -> u32 {
268 from_monty(self.value)
269 }
270}
271
272impl PrimeField31 for KoalaBear {}
273
274impl TwoAdicField for KoalaBear {
275 const TWO_ADICITY: usize = 24;
276
277 fn two_adic_generator(bits: usize) -> Self {
278 assert!(bits <= Self::TWO_ADICITY);
279 match bits {
280 0 => Self::one(),
281 1 => Self::from_canonical_u32(0x7f000000),
282 2 => Self::from_canonical_u32(0x7e010002),
283 3 => Self::from_canonical_u32(0x6832fe4a),
284 4 => Self::from_canonical_u32(0x8dbd69c),
285 5 => Self::from_canonical_u32(0xa28f031),
286 6 => Self::from_canonical_u32(0x5c4a5b99),
287 7 => Self::from_canonical_u32(0x29b75a80),
288 8 => Self::from_canonical_u32(0x17668b8a),
289 9 => Self::from_canonical_u32(0x27ad539b),
290 10 => Self::from_canonical_u32(0x334d48c7),
291 11 => Self::from_canonical_u32(0x7744959c),
292 12 => Self::from_canonical_u32(0x768fc6fa),
293 13 => Self::from_canonical_u32(0x303964b2),
294 14 => Self::from_canonical_u32(0x3e687d4d),
295 15 => Self::from_canonical_u32(0x45a60e61),
296 16 => Self::from_canonical_u32(0x6e2f4d7a),
297 17 => Self::from_canonical_u32(0x163bd499),
298 18 => Self::from_canonical_u32(0x6c4a8a45),
299 19 => Self::from_canonical_u32(0x143ef899),
300 20 => Self::from_canonical_u32(0x514ddcad),
301 21 => Self::from_canonical_u32(0x484ef19b),
302 22 => Self::from_canonical_u32(0x205d63c3),
303 23 => Self::from_canonical_u32(0x68e7dd49),
304 24 => Self::from_canonical_u32(0x6ac49f88),
305 _ => unreachable!("Already asserted that bits <= Self::TWO_ADICITY"),
306 }
307 }
308}
309
310impl Add for KoalaBear {
311 type Output = Self;
312
313 #[inline]
314 fn add(self, rhs: Self) -> Self {
315 let mut sum = self.value + rhs.value;
316 let (corr_sum, over) = sum.overflowing_sub(P);
317 if !over {
318 sum = corr_sum;
319 }
320 Self { value: sum }
321 }
322}
323
324impl AddAssign for KoalaBear {
325 #[inline]
326 fn add_assign(&mut self, rhs: Self) {
327 *self = *self + rhs;
328 }
329}
330
331impl Sum for KoalaBear {
332 #[inline]
333 fn sum<I: Iterator<Item = Self>>(iter: I) -> Self {
334 let sum = iter.map(|x| (x.value as u64)).sum::<u64>();
339 Self {
340 value: (sum % P as u64) as u32,
341 }
342 }
343}
344
345impl Sub for KoalaBear {
346 type Output = Self;
347
348 #[inline]
349 fn sub(self, rhs: Self) -> Self {
350 let (mut diff, over) = self.value.overflowing_sub(rhs.value);
351 let corr = if over { P } else { 0 };
352 diff = diff.wrapping_add(corr);
353 Self { value: diff }
354 }
355}
356
357impl SubAssign for KoalaBear {
358 #[inline]
359 fn sub_assign(&mut self, rhs: Self) {
360 *self = *self - rhs;
361 }
362}
363
364impl Neg for KoalaBear {
365 type Output = Self;
366
367 #[inline]
368 fn neg(self) -> Self::Output {
369 Self::zero() - self
370 }
371}
372
373impl Mul for KoalaBear {
374 type Output = Self;
375
376 #[inline]
377 fn mul(self, rhs: Self) -> Self {
378 let long_prod = self.value as u64 * rhs.value as u64;
379 Self {
380 value: monty_reduce(long_prod),
381 }
382 }
383}
384
385impl MulAssign for KoalaBear {
386 #[inline]
387 fn mul_assign(&mut self, rhs: Self) {
388 *self = *self * rhs;
389 }
390}
391
392impl Product for KoalaBear {
393 #[inline]
394 fn product<I: Iterator<Item = Self>>(iter: I) -> Self {
395 iter.reduce(|x, y| x * y).unwrap_or(Self::one())
396 }
397}
398
399impl Div for KoalaBear {
400 type Output = Self;
401
402 #[allow(clippy::suspicious_arithmetic_impl)]
403 #[inline]
404 fn div(self, rhs: Self) -> Self {
405 self * rhs.inverse()
406 }
407}
408
409#[inline]
410#[must_use]
411const fn to_monty(x: u32) -> u32 {
412 (((x as u64) << MONTY_BITS) % P as u64) as u32
413}
414
415#[inline]
418#[must_use]
419pub(crate) const fn to_koalabear_array<const N: usize>(input: [u32; N]) -> [KoalaBear; N] {
420 let mut output = [KoalaBear { value: 0 }; N];
421 let mut i = 0;
422 loop {
423 if i == N {
424 break;
425 }
426 output[i].value = to_monty(input[i]);
427 i += 1;
428 }
429 output
430}
431
432#[inline]
433#[must_use]
434const fn to_monty_64(x: u64) -> u32 {
435 (((x as u128) << MONTY_BITS) % P as u128) as u32
436}
437
438#[inline]
439#[must_use]
440const fn from_monty(x: u32) -> u32 {
441 monty_reduce(x as u64)
442}
443
444#[inline]
446#[must_use]
447pub(crate) const fn monty_reduce(x: u64) -> u32 {
448 let t = x.wrapping_mul(MONTY_MU as u64) & (MONTY_MASK as u64);
449 let u = t * (P as u64);
450
451 let (x_sub_u, over) = x.overflowing_sub(u);
452 let x_sub_u_hi = (x_sub_u >> MONTY_BITS) as u32;
453 let corr = if over { P } else { 0 };
454 x_sub_u_hi.wrapping_add(corr)
455}
456
#[cfg(test)]
mod tests {
    use p3_field_testing::{test_field, test_two_adic_field};

    use super::*;

    type F = KoalaBear;

    /// Serializes `x` to JSON and decodes it again.
    fn json_roundtrip(x: &F) -> F {
        let encoded = serde_json::to_string(x).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn test_koala_bear_two_adicity_generators() {
        let base = KoalaBear::from_canonical_u32(0x6ac49f88);
        for bits in 0..=KoalaBear::TWO_ADICITY {
            let expected = base.exp_power_of_2(KoalaBear::TWO_ADICITY - bits);
            assert_eq!(KoalaBear::two_adic_generator(bits), expected);
        }
    }

    #[test]
    fn test_koala_bear() {
        let hundred = F::from_canonical_u32(100);
        assert_eq!(hundred.as_canonical_u64(), 100);

        assert!(F::from_canonical_u32(0).is_zero());

        // Wrapping the modulus itself must land on zero.
        let f = F::from_wrapped_u32(F::ORDER_U32);
        assert!(f.is_zero());

        let f_1 = F::one();
        let f_1_copy = F::from_canonical_u32(1);
        let f_2 = F::from_canonical_u32(2);
        let f_p_minus_1 = F::from_canonical_u32(F::ORDER_U32 - 1);
        let f_p_minus_2 = F::from_canonical_u32(F::ORDER_U32 - 2);

        assert_eq!(f_1 - f_1_copy, F::zero());
        assert_eq!(f_1 + f_1_copy, F::two());
        assert_eq!(f_1 + f_1_copy * f_2, F::from_canonical_u32(3));
        assert_eq!(f_1 + f_2 * f_2, F::from_canonical_u32(5));
        assert_eq!(f_1 + f_p_minus_1, F::zero());
        assert_eq!(
            f_p_minus_1 + f_p_minus_2,
            F::from_canonical_u32(F::ORDER_U32 - 3)
        );
        assert_eq!(f_p_minus_1 - f_p_minus_2, F::from_canonical_u32(1));
        assert_eq!(f_p_minus_2 - f_p_minus_1, f_p_minus_1);
        assert_eq!(f_p_minus_1 - f_1, f_p_minus_2);

        let m1 = F::from_canonical_u32(0x34167c58);
        let m2 = F::from_canonical_u32(0x61f3207b);
        assert_eq!(m1 * m2, F::from_canonical_u32(0x54b46b81));

        // 1420470955 is the inverse of 3 modulo P - 1, so cubing undoes it.
        for x in [m1, m2, f_2] {
            assert_eq!(x.exp_u64(1420470955).exp_const_u64::<3>(), x);
        }

        assert_eq!(f, json_roundtrip(&f));

        let f_1_once = json_roundtrip(&f_1);
        let f_1_twice = json_roundtrip(&f_1_once);
        assert_eq!(f_1, f_1_once);
        assert_eq!(f_1, f_1_twice);

        for x in [f_2, f_p_minus_1, f_p_minus_2, m1, m2] {
            assert_eq!(x, json_roundtrip(&x));
        }
    }

    test_field!(crate::KoalaBear);
    test_two_adic_field!(crate::KoalaBear);
}